1
2
3
4
5
6
7 package httputil
8
9 import (
10 "context"
11 "fmt"
12 "io"
13 "log"
14 "net"
15 "net/http"
16 "net/http/internal/ascii"
17 "net/textproto"
18 "net/url"
19 "strings"
20 "sync"
21 "time"
22
23 "golang.org/x/net/http/httpguts"
24 )
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42 type ReverseProxy struct {
43
44
45
46
47
48
49 Director func(*http.Request)
50
51
52
53 Transport http.RoundTripper
54
55
56
57
58
59
60
61
62
63
64
65 FlushInterval time.Duration
66
67
68
69
70 ErrorLog *log.Logger
71
72
73
74
75 BufferPool BufferPool
76
77
78
79
80
81
82
83
84
85
86 ModifyResponse func(*http.Response) error
87
88
89
90
91
92
93 ErrorHandler func(http.ResponseWriter, *http.Request, error)
94 }
95
96
97
98 type BufferPool interface {
99 Get() []byte
100 Put([]byte)
101 }
102
103 func singleJoiningSlash(a, b string) string {
104 aslash := strings.HasSuffix(a, "/")
105 bslash := strings.HasPrefix(b, "/")
106 switch {
107 case aslash && bslash:
108 return a + b[1:]
109 case !aslash && !bslash:
110 return a + "/" + b
111 }
112 return a + b
113 }
114
115 func joinURLPath(a, b *url.URL) (path, rawpath string) {
116 if a.RawPath == "" && b.RawPath == "" {
117 return singleJoiningSlash(a.Path, b.Path), ""
118 }
119
120
121 apath := a.EscapedPath()
122 bpath := b.EscapedPath()
123
124 aslash := strings.HasSuffix(apath, "/")
125 bslash := strings.HasPrefix(bpath, "/")
126
127 switch {
128 case aslash && bslash:
129 return a.Path + b.Path[1:], apath + bpath[1:]
130 case !aslash && !bslash:
131 return a.Path + "/" + b.Path, apath + "/" + bpath
132 }
133 return a.Path + b.Path, apath + bpath
134 }
135
136
137
138
139
140
141
142
143 func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
144 targetQuery := target.RawQuery
145 director := func(req *http.Request) {
146 req.URL.Scheme = target.Scheme
147 req.URL.Host = target.Host
148 req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
149 if targetQuery == "" || req.URL.RawQuery == "" {
150 req.URL.RawQuery = targetQuery + req.URL.RawQuery
151 } else {
152 req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
153 }
154 if _, ok := req.Header["User-Agent"]; !ok {
155
156 req.Header.Set("User-Agent", "")
157 }
158 }
159 return &ReverseProxy{Director: director}
160 }
161
162 func copyHeader(dst, src http.Header) {
163 for k, vv := range src {
164 for _, v := range vv {
165 dst.Add(k, v)
166 }
167 }
168 }
169
170
171
172
173
174
175 var hopHeaders = []string{
176 "Connection",
177 "Proxy-Connection",
178 "Keep-Alive",
179 "Proxy-Authenticate",
180 "Proxy-Authorization",
181 "Te",
182 "Trailer",
183 "Transfer-Encoding",
184 "Upgrade",
185 }
186
187 func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
188 p.logf("http: proxy error: %v", err)
189 rw.WriteHeader(http.StatusBadGateway)
190 }
191
192 func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
193 if p.ErrorHandler != nil {
194 return p.ErrorHandler
195 }
196 return p.defaultErrorHandler
197 }
198
199
200
201 func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
202 if p.ModifyResponse == nil {
203 return true
204 }
205 if err := p.ModifyResponse(res); err != nil {
206 res.Body.Close()
207 p.getErrorHandler()(rw, req, err)
208 return false
209 }
210 return true
211 }
212
213 func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
214 transport := p.Transport
215 if transport == nil {
216 transport = http.DefaultTransport
217 }
218
219 ctx := req.Context()
220 if cn, ok := rw.(http.CloseNotifier); ok {
221 var cancel context.CancelFunc
222 ctx, cancel = context.WithCancel(ctx)
223 defer cancel()
224 notifyChan := cn.CloseNotify()
225 go func() {
226 select {
227 case <-notifyChan:
228 cancel()
229 case <-ctx.Done():
230 }
231 }()
232 }
233
234 outreq := req.Clone(ctx)
235 if req.ContentLength == 0 {
236 outreq.Body = nil
237 }
238 if outreq.Body != nil {
239
240
241
242
243
244
245 defer outreq.Body.Close()
246 }
247 if outreq.Header == nil {
248 outreq.Header = make(http.Header)
249 }
250
251 p.Director(outreq)
252 outreq.Close = false
253
254 reqUpType := upgradeType(outreq.Header)
255 if !ascii.IsPrint(reqUpType) {
256 p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
257 return
258 }
259 removeConnectionHeaders(outreq.Header)
260
261
262
263
264 for _, h := range hopHeaders {
265 outreq.Header.Del(h)
266 }
267
268
269
270
271
272
273 if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
274 outreq.Header.Set("Te", "trailers")
275 }
276
277
278
279 if reqUpType != "" {
280 outreq.Header.Set("Connection", "Upgrade")
281 outreq.Header.Set("Upgrade", reqUpType)
282 }
283
284 if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
285
286
287
288 prior, ok := outreq.Header["X-Forwarded-For"]
289 omit := ok && prior == nil
290 if len(prior) > 0 {
291 clientIP = strings.Join(prior, ", ") + ", " + clientIP
292 }
293 if !omit {
294 outreq.Header.Set("X-Forwarded-For", clientIP)
295 }
296 }
297
298 res, err := transport.RoundTrip(outreq)
299 if err != nil {
300 p.getErrorHandler()(rw, outreq, err)
301 return
302 }
303
304
305 if res.StatusCode == http.StatusSwitchingProtocols {
306 if !p.modifyResponse(rw, res, outreq) {
307 return
308 }
309 p.handleUpgradeResponse(rw, outreq, res)
310 return
311 }
312
313 removeConnectionHeaders(res.Header)
314
315 for _, h := range hopHeaders {
316 res.Header.Del(h)
317 }
318
319 if !p.modifyResponse(rw, res, outreq) {
320 return
321 }
322
323 copyHeader(rw.Header(), res.Header)
324
325
326
327 announcedTrailers := len(res.Trailer)
328 if announcedTrailers > 0 {
329 trailerKeys := make([]string, 0, len(res.Trailer))
330 for k := range res.Trailer {
331 trailerKeys = append(trailerKeys, k)
332 }
333 rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
334 }
335
336 rw.WriteHeader(res.StatusCode)
337
338 err = p.copyResponse(rw, res.Body, p.flushInterval(res))
339 if err != nil {
340 defer res.Body.Close()
341
342
343
344 if !shouldPanicOnCopyError(req) {
345 p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
346 return
347 }
348 panic(http.ErrAbortHandler)
349 }
350 res.Body.Close()
351
352 if len(res.Trailer) > 0 {
353
354
355
356 if fl, ok := rw.(http.Flusher); ok {
357 fl.Flush()
358 }
359 }
360
361 if len(res.Trailer) == announcedTrailers {
362 copyHeader(rw.Header(), res.Trailer)
363 return
364 }
365
366 for k, vv := range res.Trailer {
367 k = http.TrailerPrefix + k
368 for _, v := range vv {
369 rw.Header().Add(k, v)
370 }
371 }
372 }
373
374 var inOurTests bool
375
376
377
378
379
380
381 func shouldPanicOnCopyError(req *http.Request) bool {
382 if inOurTests {
383
384 return true
385 }
386 if req.Context().Value(http.ServerContextKey) != nil {
387
388
389 return true
390 }
391
392
393 return false
394 }
395
396
397
398 func removeConnectionHeaders(h http.Header) {
399 for _, f := range h["Connection"] {
400 for _, sf := range strings.Split(f, ",") {
401 if sf = textproto.TrimString(sf); sf != "" {
402 h.Del(sf)
403 }
404 }
405 }
406 }
407
408
409
410 func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration {
411 resCT := res.Header.Get("Content-Type")
412
413
414
415 if resCT == "text/event-stream" {
416 return -1
417 }
418
419
420 if res.ContentLength == -1 {
421 return -1
422 }
423
424 return p.FlushInterval
425 }
426
427 func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
428 if flushInterval != 0 {
429 if wf, ok := dst.(writeFlusher); ok {
430 mlw := &maxLatencyWriter{
431 dst: wf,
432 latency: flushInterval,
433 }
434 defer mlw.stop()
435
436
437 mlw.flushPending = true
438 mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
439
440 dst = mlw
441 }
442 }
443
444 var buf []byte
445 if p.BufferPool != nil {
446 buf = p.BufferPool.Get()
447 defer p.BufferPool.Put(buf)
448 }
449 _, err := p.copyBuffer(dst, src, buf)
450 return err
451 }
452
453
454
455 func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
456 if len(buf) == 0 {
457 buf = make([]byte, 32*1024)
458 }
459 var written int64
460 for {
461 nr, rerr := src.Read(buf)
462 if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
463 p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
464 }
465 if nr > 0 {
466 nw, werr := dst.Write(buf[:nr])
467 if nw > 0 {
468 written += int64(nw)
469 }
470 if werr != nil {
471 return written, werr
472 }
473 if nr != nw {
474 return written, io.ErrShortWrite
475 }
476 }
477 if rerr != nil {
478 if rerr == io.EOF {
479 rerr = nil
480 }
481 return written, rerr
482 }
483 }
484 }
485
486 func (p *ReverseProxy) logf(format string, args ...interface{}) {
487 if p.ErrorLog != nil {
488 p.ErrorLog.Printf(format, args...)
489 } else {
490 log.Printf(format, args...)
491 }
492 }
493
494 type writeFlusher interface {
495 io.Writer
496 http.Flusher
497 }
498
499 type maxLatencyWriter struct {
500 dst writeFlusher
501 latency time.Duration
502
503 mu sync.Mutex
504 t *time.Timer
505 flushPending bool
506 }
507
508 func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
509 m.mu.Lock()
510 defer m.mu.Unlock()
511 n, err = m.dst.Write(p)
512 if m.latency < 0 {
513 m.dst.Flush()
514 return
515 }
516 if m.flushPending {
517 return
518 }
519 if m.t == nil {
520 m.t = time.AfterFunc(m.latency, m.delayedFlush)
521 } else {
522 m.t.Reset(m.latency)
523 }
524 m.flushPending = true
525 return
526 }
527
528 func (m *maxLatencyWriter) delayedFlush() {
529 m.mu.Lock()
530 defer m.mu.Unlock()
531 if !m.flushPending {
532 return
533 }
534 m.dst.Flush()
535 m.flushPending = false
536 }
537
538 func (m *maxLatencyWriter) stop() {
539 m.mu.Lock()
540 defer m.mu.Unlock()
541 m.flushPending = false
542 if m.t != nil {
543 m.t.Stop()
544 }
545 }
546
547 func upgradeType(h http.Header) string {
548 if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
549 return ""
550 }
551 return h.Get("Upgrade")
552 }
553
554 func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
555 reqUpType := upgradeType(req.Header)
556 resUpType := upgradeType(res.Header)
557 if !ascii.IsPrint(resUpType) {
558 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
559 }
560 if !ascii.EqualFold(reqUpType, resUpType) {
561 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
562 return
563 }
564
565 hj, ok := rw.(http.Hijacker)
566 if !ok {
567 p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
568 return
569 }
570 backConn, ok := res.Body.(io.ReadWriteCloser)
571 if !ok {
572 p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
573 return
574 }
575
576 backConnCloseCh := make(chan bool)
577 go func() {
578
579
580 select {
581 case <-req.Context().Done():
582 case <-backConnCloseCh:
583 }
584 backConn.Close()
585 }()
586
587 defer close(backConnCloseCh)
588
589 conn, brw, err := hj.Hijack()
590 if err != nil {
591 p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
592 return
593 }
594 defer conn.Close()
595
596 copyHeader(rw.Header(), res.Header)
597
598 res.Header = rw.Header()
599 res.Body = nil
600 if err := res.Write(brw); err != nil {
601 p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
602 return
603 }
604 if err := brw.Flush(); err != nil {
605 p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
606 return
607 }
608 errc := make(chan error, 1)
609 spc := switchProtocolCopier{user: conn, backend: backConn}
610 go spc.copyToBackend(errc)
611 go spc.copyFromBackend(errc)
612 <-errc
613 return
614 }
615
616
617
618 type switchProtocolCopier struct {
619 user, backend io.ReadWriter
620 }
621
622 func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
623 _, err := io.Copy(c.user, c.backend)
624 errc <- err
625 }
626
627 func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
628 _, err := io.Copy(c.backend, c.user)
629 errc <- err
630 }
631
View as plain text