...

Source file src/net/http/httputil/reverseproxy_test.go

Documentation: net/http/httputil

		 1  // Copyright 2011 The Go Authors. All rights reserved.
		 2  // Use of this source code is governed by a BSD-style
		 3  // license that can be found in the LICENSE file.
		 4  
		 5  // Reverse proxy tests.
		 6  
		 7  package httputil
		 8  
		 9  import (
		10  	"bufio"
		11  	"bytes"
		12  	"context"
		13  	"errors"
		14  	"fmt"
		15  	"io"
		16  	"log"
		17  	"net/http"
		18  	"net/http/httptest"
		19  	"net/http/internal/ascii"
		20  	"net/url"
		21  	"os"
		22  	"reflect"
		23  	"sort"
		24  	"strconv"
		25  	"strings"
		26  	"sync"
		27  	"testing"
		28  	"time"
		29  )
		30  
		31  const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
		32  
		33  func init() {
		34  	inOurTests = true
		35  	hopHeaders = append(hopHeaders, fakeHopHeader)
		36  }
		37  
		38  func TestReverseProxy(t *testing.T) {
		39  	const backendResponse = "I am the backend"
		40  	const backendStatus = 404
		41  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		42  		if r.Method == "GET" && r.FormValue("mode") == "hangup" {
		43  			c, _, _ := w.(http.Hijacker).Hijack()
		44  			c.Close()
		45  			return
		46  		}
		47  		if len(r.TransferEncoding) > 0 {
		48  			t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
		49  		}
		50  		if r.Header.Get("X-Forwarded-For") == "" {
		51  			t.Errorf("didn't get X-Forwarded-For header")
		52  		}
		53  		if c := r.Header.Get("Connection"); c != "" {
		54  			t.Errorf("handler got Connection header value %q", c)
		55  		}
		56  		if c := r.Header.Get("Te"); c != "trailers" {
		57  			t.Errorf("handler got Te header value %q; want 'trailers'", c)
		58  		}
		59  		if c := r.Header.Get("Upgrade"); c != "" {
		60  			t.Errorf("handler got Upgrade header value %q", c)
		61  		}
		62  		if c := r.Header.Get("Proxy-Connection"); c != "" {
		63  			t.Errorf("handler got Proxy-Connection header value %q", c)
		64  		}
		65  		if g, e := r.Host, "some-name"; g != e {
		66  			t.Errorf("backend got Host header %q, want %q", g, e)
		67  		}
		68  		w.Header().Set("Trailers", "not a special header field name")
		69  		w.Header().Set("Trailer", "X-Trailer")
		70  		w.Header().Set("X-Foo", "bar")
		71  		w.Header().Set("Upgrade", "foo")
		72  		w.Header().Set(fakeHopHeader, "foo")
		73  		w.Header().Add("X-Multi-Value", "foo")
		74  		w.Header().Add("X-Multi-Value", "bar")
		75  		http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
		76  		w.WriteHeader(backendStatus)
		77  		w.Write([]byte(backendResponse))
		78  		w.Header().Set("X-Trailer", "trailer_value")
		79  		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
		80  	}))
		81  	defer backend.Close()
		82  	backendURL, err := url.Parse(backend.URL)
		83  	if err != nil {
		84  		t.Fatal(err)
		85  	}
		86  	proxyHandler := NewSingleHostReverseProxy(backendURL)
		87  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
		88  	frontend := httptest.NewServer(proxyHandler)
		89  	defer frontend.Close()
		90  	frontendClient := frontend.Client()
		91  
		92  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
		93  	getReq.Host = "some-name"
		94  	getReq.Header.Set("Connection", "close, TE")
		95  	getReq.Header.Add("Te", "foo")
		96  	getReq.Header.Add("Te", "bar, trailers")
		97  	getReq.Header.Set("Proxy-Connection", "should be deleted")
		98  	getReq.Header.Set("Upgrade", "foo")
		99  	getReq.Close = true
	 100  	res, err := frontendClient.Do(getReq)
	 101  	if err != nil {
	 102  		t.Fatalf("Get: %v", err)
	 103  	}
	 104  	if g, e := res.StatusCode, backendStatus; g != e {
	 105  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
	 106  	}
	 107  	if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
	 108  		t.Errorf("got X-Foo %q; expected %q", g, e)
	 109  	}
	 110  	if c := res.Header.Get(fakeHopHeader); c != "" {
	 111  		t.Errorf("got %s header value %q", fakeHopHeader, c)
	 112  	}
	 113  	if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e {
	 114  		t.Errorf("header Trailers = %q; want %q", g, e)
	 115  	}
	 116  	if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
	 117  		t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
	 118  	}
	 119  	if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
	 120  		t.Fatalf("got %d SetCookies, want %d", g, e)
	 121  	}
	 122  	if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
	 123  		t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
	 124  	}
	 125  	if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
	 126  		t.Errorf("unexpected cookie %q", cookie.Name)
	 127  	}
	 128  	bodyBytes, _ := io.ReadAll(res.Body)
	 129  	if g, e := string(bodyBytes), backendResponse; g != e {
	 130  		t.Errorf("got body %q; expected %q", g, e)
	 131  	}
	 132  	if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
	 133  		t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
	 134  	}
	 135  	if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e {
	 136  		t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e)
	 137  	}
	 138  
	 139  	// Test that a backend failing to be reached or one which doesn't return
	 140  	// a response results in a StatusBadGateway.
	 141  	getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
	 142  	getReq.Close = true
	 143  	res, err = frontendClient.Do(getReq)
	 144  	if err != nil {
	 145  		t.Fatal(err)
	 146  	}
	 147  	res.Body.Close()
	 148  	if res.StatusCode != http.StatusBadGateway {
	 149  		t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status)
	 150  	}
	 151  
	 152  }
	 153  
	 154  // Issue 16875: remove any proxied headers mentioned in the "Connection"
	 155  // header value.
	 156  func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
	 157  	const fakeConnectionToken = "X-Fake-Connection-Token"
	 158  	const backendResponse = "I am the backend"
	 159  
	 160  	// someConnHeader is some arbitrary header to be declared as a hop-by-hop header
	 161  	// in the Request's Connection header.
	 162  	const someConnHeader = "X-Some-Conn-Header"
	 163  
	 164  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	 165  		if c := r.Header.Get("Connection"); c != "" {
	 166  			t.Errorf("handler got header %q = %q; want empty", "Connection", c)
	 167  		}
	 168  		if c := r.Header.Get(fakeConnectionToken); c != "" {
	 169  			t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
	 170  		}
	 171  		if c := r.Header.Get(someConnHeader); c != "" {
	 172  			t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
	 173  		}
	 174  		w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken)
	 175  		w.Header().Add("Connection", someConnHeader)
	 176  		w.Header().Set(someConnHeader, "should be deleted")
	 177  		w.Header().Set(fakeConnectionToken, "should be deleted")
	 178  		io.WriteString(w, backendResponse)
	 179  	}))
	 180  	defer backend.Close()
	 181  	backendURL, err := url.Parse(backend.URL)
	 182  	if err != nil {
	 183  		t.Fatal(err)
	 184  	}
	 185  	proxyHandler := NewSingleHostReverseProxy(backendURL)
	 186  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	 187  		proxyHandler.ServeHTTP(w, r)
	 188  		if c := r.Header.Get(someConnHeader); c != "should be deleted" {
	 189  			t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
	 190  		}
	 191  		if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" {
	 192  			t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted")
	 193  		}
	 194  		c := r.Header["Connection"]
	 195  		var cf []string
	 196  		for _, f := range c {
	 197  			for _, sf := range strings.Split(f, ",") {
	 198  				if sf = strings.TrimSpace(sf); sf != "" {
	 199  					cf = append(cf, sf)
	 200  				}
	 201  			}
	 202  		}
	 203  		sort.Strings(cf)
	 204  		expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken}
	 205  		sort.Strings(expectedValues)
	 206  		if !reflect.DeepEqual(cf, expectedValues) {
	 207  			t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues)
	 208  		}
	 209  	}))
	 210  	defer frontend.Close()
	 211  
	 212  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
	 213  	getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken)
	 214  	getReq.Header.Add("Connection", someConnHeader)
	 215  	getReq.Header.Set(someConnHeader, "should be deleted")
	 216  	getReq.Header.Set(fakeConnectionToken, "should be deleted")
	 217  	res, err := frontend.Client().Do(getReq)
	 218  	if err != nil {
	 219  		t.Fatalf("Get: %v", err)
	 220  	}
	 221  	defer res.Body.Close()
	 222  	bodyBytes, err := io.ReadAll(res.Body)
	 223  	if err != nil {
	 224  		t.Fatalf("reading body: %v", err)
	 225  	}
	 226  	if got, want := string(bodyBytes), backendResponse; got != want {
	 227  		t.Errorf("got body %q; want %q", got, want)
	 228  	}
	 229  	if c := res.Header.Get("Connection"); c != "" {
	 230  		t.Errorf("handler got header %q = %q; want empty", "Connection", c)
	 231  	}
	 232  	if c := res.Header.Get(someConnHeader); c != "" {
	 233  		t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
	 234  	}
	 235  	if c := res.Header.Get(fakeConnectionToken); c != "" {
	 236  		t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
	 237  	}
	 238  }
	 239  
	 240  func TestReverseProxyStripEmptyConnection(t *testing.T) {
	 241  	// See Issue 46313.
	 242  	const backendResponse = "I am the backend"
	 243  
	 244  	// someConnHeader is some arbitrary header to be declared as a hop-by-hop header
	 245  	// in the Request's Connection header.
	 246  	const someConnHeader = "X-Some-Conn-Header"
	 247  
	 248  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	 249  		if c := r.Header.Values("Connection"); len(c) != 0 {
	 250  			t.Errorf("handler got header %q = %v; want empty", "Connection", c)
	 251  		}
	 252  		if c := r.Header.Get(someConnHeader); c != "" {
	 253  			t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
	 254  		}
	 255  		w.Header().Add("Connection", "")
	 256  		w.Header().Add("Connection", someConnHeader)
	 257  		w.Header().Set(someConnHeader, "should be deleted")
	 258  		io.WriteString(w, backendResponse)
	 259  	}))
	 260  	defer backend.Close()
	 261  	backendURL, err := url.Parse(backend.URL)
	 262  	if err != nil {
	 263  		t.Fatal(err)
	 264  	}
	 265  	proxyHandler := NewSingleHostReverseProxy(backendURL)
	 266  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	 267  		proxyHandler.ServeHTTP(w, r)
	 268  		if c := r.Header.Get(someConnHeader); c != "should be deleted" {
	 269  			t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
	 270  		}
	 271  	}))
	 272  	defer frontend.Close()
	 273  
	 274  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
	 275  	getReq.Header.Add("Connection", "")
	 276  	getReq.Header.Add("Connection", someConnHeader)
	 277  	getReq.Header.Set(someConnHeader, "should be deleted")
	 278  	res, err := frontend.Client().Do(getReq)
	 279  	if err != nil {
	 280  		t.Fatalf("Get: %v", err)
	 281  	}
	 282  	defer res.Body.Close()
	 283  	bodyBytes, err := io.ReadAll(res.Body)
	 284  	if err != nil {
	 285  		t.Fatalf("reading body: %v", err)
	 286  	}
	 287  	if got, want := string(bodyBytes), backendResponse; got != want {
	 288  		t.Errorf("got body %q; want %q", got, want)
	 289  	}
	 290  	if c := res.Header.Get("Connection"); c != "" {
	 291  		t.Errorf("handler got header %q = %q; want empty", "Connection", c)
	 292  	}
	 293  	if c := res.Header.Get(someConnHeader); c != "" {
	 294  		t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
	 295  	}
	 296  }
	 297  
	 298  func TestXForwardedFor(t *testing.T) {
	 299  	const prevForwardedFor = "client ip"
	 300  	const backendResponse = "I am the backend"
	 301  	const backendStatus = 404
	 302  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	 303  		if r.Header.Get("X-Forwarded-For") == "" {
	 304  			t.Errorf("didn't get X-Forwarded-For header")
	 305  		}
	 306  		if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
	 307  			t.Errorf("X-Forwarded-For didn't contain prior data")
	 308  		}
	 309  		w.WriteHeader(backendStatus)
	 310  		w.Write([]byte(backendResponse))
	 311  	}))
	 312  	defer backend.Close()
	 313  	backendURL, err := url.Parse(backend.URL)
	 314  	if err != nil {
	 315  		t.Fatal(err)
	 316  	}
	 317  	proxyHandler := NewSingleHostReverseProxy(backendURL)
	 318  	frontend := httptest.NewServer(proxyHandler)
	 319  	defer frontend.Close()
	 320  
	 321  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
	 322  	getReq.Host = "some-name"
	 323  	getReq.Header.Set("Connection", "close")
	 324  	getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
	 325  	getReq.Close = true
	 326  	res, err := frontend.Client().Do(getReq)
	 327  	if err != nil {
	 328  		t.Fatalf("Get: %v", err)
	 329  	}
	 330  	if g, e := res.StatusCode, backendStatus; g != e {
	 331  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
	 332  	}
	 333  	bodyBytes, _ := io.ReadAll(res.Body)
	 334  	if g, e := string(bodyBytes), backendResponse; g != e {
	 335  		t.Errorf("got body %q; expected %q", g, e)
	 336  	}
	 337  }
	 338  
	 339  // Issue 38079: don't append to X-Forwarded-For if it's present but nil
	 340  func TestXForwardedFor_Omit(t *testing.T) {
	 341  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	 342  		if v := r.Header.Get("X-Forwarded-For"); v != "" {
	 343  			t.Errorf("got X-Forwarded-For header: %q", v)
	 344  		}
	 345  		w.Write([]byte("hi"))
	 346  	}))
	 347  	defer backend.Close()
	 348  	backendURL, err := url.Parse(backend.URL)
	 349  	if err != nil {
	 350  		t.Fatal(err)
	 351  	}
	 352  	proxyHandler := NewSingleHostReverseProxy(backendURL)
	 353  	frontend := httptest.NewServer(proxyHandler)
	 354  	defer frontend.Close()
	 355  
	 356  	oldDirector := proxyHandler.Director
	 357  	proxyHandler.Director = func(r *http.Request) {
	 358  		r.Header["X-Forwarded-For"] = nil
	 359  		oldDirector(r)
	 360  	}
	 361  
	 362  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
	 363  	getReq.Host = "some-name"
	 364  	getReq.Close = true
	 365  	res, err := frontend.Client().Do(getReq)
	 366  	if err != nil {
	 367  		t.Fatalf("Get: %v", err)
	 368  	}
	 369  	res.Body.Close()
	 370  }
	 371  
	 372  var proxyQueryTests = []struct {
	 373  	baseSuffix string // suffix to add to backend URL
	 374  	reqSuffix	string // suffix to add to frontend's request URL
	 375  	want			 string // what backend should see for final request URL (without ?)
	 376  }{
	 377  	{"", "", ""},
	 378  	{"?sta=tic", "?us=er", "sta=tic&us=er"},
	 379  	{"", "?us=er", "us=er"},
	 380  	{"?sta=tic", "", "sta=tic"},
	 381  }
	 382  
	 383  func TestReverseProxyQuery(t *testing.T) {
	 384  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	 385  		w.Header().Set("X-Got-Query", r.URL.RawQuery)
	 386  		w.Write([]byte("hi"))
	 387  	}))
	 388  	defer backend.Close()
	 389  
	 390  	for i, tt := range proxyQueryTests {
	 391  		backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
	 392  		if err != nil {
	 393  			t.Fatal(err)
	 394  		}
	 395  		frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
	 396  		req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
	 397  		req.Close = true
	 398  		res, err := frontend.Client().Do(req)
	 399  		if err != nil {
	 400  			t.Fatalf("%d. Get: %v", i, err)
	 401  		}
	 402  		if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
	 403  			t.Errorf("%d. got query %q; expected %q", i, g, e)
	 404  		}
	 405  		res.Body.Close()
	 406  		frontend.Close()
	 407  	}
	 408  }
	 409  
	 410  func TestReverseProxyFlushInterval(t *testing.T) {
	 411  	const expected = "hi"
	 412  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	 413  		w.Write([]byte(expected))
	 414  	}))
	 415  	defer backend.Close()
	 416  
	 417  	backendURL, err := url.Parse(backend.URL)
	 418  	if err != nil {
	 419  		t.Fatal(err)
	 420  	}
	 421  
	 422  	proxyHandler := NewSingleHostReverseProxy(backendURL)
	 423  	proxyHandler.FlushInterval = time.Microsecond
	 424  
	 425  	frontend := httptest.NewServer(proxyHandler)
	 426  	defer frontend.Close()
	 427  
	 428  	req, _ := http.NewRequest("GET", frontend.URL, nil)
	 429  	req.Close = true
	 430  	res, err := frontend.Client().Do(req)
	 431  	if err != nil {
	 432  		t.Fatalf("Get: %v", err)
	 433  	}
	 434  	defer res.Body.Close()
	 435  	if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
	 436  		t.Errorf("got body %q; expected %q", bodyBytes, expected)
	 437  	}
	 438  }
	 439  
	 440  func TestReverseProxyFlushIntervalHeaders(t *testing.T) {
	 441  	const expected = "hi"
	 442  	stopCh := make(chan struct{})
	 443  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	 444  		w.Header().Add("MyHeader", expected)
	 445  		w.WriteHeader(200)
	 446  		w.(http.Flusher).Flush()
	 447  		<-stopCh
	 448  	}))
	 449  	defer backend.Close()
	 450  	defer close(stopCh)
	 451  
	 452  	backendURL, err := url.Parse(backend.URL)
	 453  	if err != nil {
	 454  		t.Fatal(err)
	 455  	}
	 456  
	 457  	proxyHandler := NewSingleHostReverseProxy(backendURL)
	 458  	proxyHandler.FlushInterval = time.Microsecond
	 459  
	 460  	frontend := httptest.NewServer(proxyHandler)
	 461  	defer frontend.Close()
	 462  
	 463  	req, _ := http.NewRequest("GET", frontend.URL, nil)
	 464  	req.Close = true
	 465  
	 466  	ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
	 467  	defer cancel()
	 468  	req = req.WithContext(ctx)
	 469  
	 470  	res, err := frontend.Client().Do(req)
	 471  	if err != nil {
	 472  		t.Fatalf("Get: %v", err)
	 473  	}
	 474  	defer res.Body.Close()
	 475  
	 476  	if res.Header.Get("MyHeader") != expected {
	 477  		t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected)
	 478  	}
	 479  }
	 480  
	 481  func TestReverseProxyCancellation(t *testing.T) {
	 482  	const backendResponse = "I am the backend"
	 483  
	 484  	reqInFlight := make(chan struct{})
	 485  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	 486  		close(reqInFlight) // cause the client to cancel its request
	 487  
	 488  		select {
	 489  		case <-time.After(10 * time.Second):
	 490  			// Note: this should only happen in broken implementations, and the
	 491  			// closenotify case should be instantaneous.
	 492  			t.Error("Handler never saw CloseNotify")
	 493  			return
	 494  		case <-w.(http.CloseNotifier).CloseNotify():
	 495  		}
	 496  
	 497  		w.WriteHeader(http.StatusOK)
	 498  		w.Write([]byte(backendResponse))
	 499  	}))
	 500  
	 501  	defer backend.Close()
	 502  
	 503  	backend.Config.ErrorLog = log.New(io.Discard, "", 0)
	 504  
	 505  	backendURL, err := url.Parse(backend.URL)
	 506  	if err != nil {
	 507  		t.Fatal(err)
	 508  	}
	 509  
	 510  	proxyHandler := NewSingleHostReverseProxy(backendURL)
	 511  
	 512  	// Discards errors of the form:
	 513  	// http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection
	 514  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
	 515  
	 516  	frontend := httptest.NewServer(proxyHandler)
	 517  	defer frontend.Close()
	 518  	frontendClient := frontend.Client()
	 519  
	 520  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
	 521  	go func() {
	 522  		<-reqInFlight
	 523  		frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
	 524  	}()
	 525  	res, err := frontendClient.Do(getReq)
	 526  	if res != nil {
	 527  		t.Errorf("got response %v; want nil", res.Status)
	 528  	}
	 529  	if err == nil {
	 530  		// This should be an error like:
	 531  		// Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079:
	 532  		//		use of closed network connection
	 533  		t.Error("Server.Client().Do() returned nil error; want non-nil error")
	 534  	}
	 535  }
	 536  
	 537  func req(t *testing.T, v string) *http.Request {
	 538  	req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
	 539  	if err != nil {
	 540  		t.Fatal(err)
	 541  	}
	 542  	return req
	 543  }
	 544  
	 545  // Issue 12344
	 546  func TestNilBody(t *testing.T) {
	 547  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	 548  		w.Write([]byte("hi"))
	 549  	}))
	 550  	defer backend.Close()
	 551  
	 552  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
	 553  		backURL, _ := url.Parse(backend.URL)
	 554  		rp := NewSingleHostReverseProxy(backURL)
	 555  		r := req(t, "GET / HTTP/1.0\r\n\r\n")
	 556  		r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working
	 557  		rp.ServeHTTP(w, r)
	 558  	}))
	 559  	defer frontend.Close()
	 560  
	 561  	res, err := http.Get(frontend.URL)
	 562  	if err != nil {
	 563  		t.Fatal(err)
	 564  	}
	 565  	defer res.Body.Close()
	 566  	slurp, err := io.ReadAll(res.Body)
	 567  	if err != nil {
	 568  		t.Fatal(err)
	 569  	}
	 570  	if string(slurp) != "hi" {
	 571  		t.Errorf("Got %q; want %q", slurp, "hi")
	 572  	}
	 573  }
	 574  
	 575  // Issue 15524
	 576  func TestUserAgentHeader(t *testing.T) {
	 577  	const explicitUA = "explicit UA"
	 578  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	 579  		if r.URL.Path == "/noua" {
	 580  			if c := r.Header.Get("User-Agent"); c != "" {
	 581  				t.Errorf("handler got non-empty User-Agent header %q", c)
	 582  			}
	 583  			return
	 584  		}
	 585  		if c := r.Header.Get("User-Agent"); c != explicitUA {
	 586  			t.Errorf("handler got unexpected User-Agent header %q", c)
	 587  		}
	 588  	}))
	 589  	defer backend.Close()
	 590  	backendURL, err := url.Parse(backend.URL)
	 591  	if err != nil {
	 592  		t.Fatal(err)
	 593  	}
	 594  	proxyHandler := NewSingleHostReverseProxy(backendURL)
	 595  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
	 596  	frontend := httptest.NewServer(proxyHandler)
	 597  	defer frontend.Close()
	 598  	frontendClient := frontend.Client()
	 599  
	 600  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
	 601  	getReq.Header.Set("User-Agent", explicitUA)
	 602  	getReq.Close = true
	 603  	res, err := frontendClient.Do(getReq)
	 604  	if err != nil {
	 605  		t.Fatalf("Get: %v", err)
	 606  	}
	 607  	res.Body.Close()
	 608  
	 609  	getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil)
	 610  	getReq.Header.Set("User-Agent", "")
	 611  	getReq.Close = true
	 612  	res, err = frontendClient.Do(getReq)
	 613  	if err != nil {
	 614  		t.Fatalf("Get: %v", err)
	 615  	}
	 616  	res.Body.Close()
	 617  }
	 618  
	 619  type bufferPool struct {
	 620  	get func() []byte
	 621  	put func([]byte)
	 622  }
	 623  
	 624  func (bp bufferPool) Get() []byte	{ return bp.get() }
	 625  func (bp bufferPool) Put(v []byte) { bp.put(v) }
	 626  
	 627  func TestReverseProxyGetPutBuffer(t *testing.T) {
	 628  	const msg = "hi"
	 629  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	 630  		io.WriteString(w, msg)
	 631  	}))
	 632  	defer backend.Close()
	 633  
	 634  	backendURL, err := url.Parse(backend.URL)
	 635  	if err != nil {
	 636  		t.Fatal(err)
	 637  	}
	 638  
	 639  	var (
	 640  		mu	sync.Mutex
	 641  		log []string
	 642  	)
	 643  	addLog := func(event string) {
	 644  		mu.Lock()
	 645  		defer mu.Unlock()
	 646  		log = append(log, event)
	 647  	}
	 648  	rp := NewSingleHostReverseProxy(backendURL)
	 649  	const size = 1234
	 650  	rp.BufferPool = bufferPool{
	 651  		get: func() []byte {
	 652  			addLog("getBuf")
	 653  			return make([]byte, size)
	 654  		},
	 655  		put: func(p []byte) {
	 656  			addLog("putBuf-" + strconv.Itoa(len(p)))
	 657  		},
	 658  	}
	 659  	frontend := httptest.NewServer(rp)
	 660  	defer frontend.Close()
	 661  
	 662  	req, _ := http.NewRequest("GET", frontend.URL, nil)
	 663  	req.Close = true
	 664  	res, err := frontend.Client().Do(req)
	 665  	if err != nil {
	 666  		t.Fatalf("Get: %v", err)
	 667  	}
	 668  	slurp, err := io.ReadAll(res.Body)
	 669  	res.Body.Close()
	 670  	if err != nil {
	 671  		t.Fatalf("reading body: %v", err)
	 672  	}
	 673  	if string(slurp) != msg {
	 674  		t.Errorf("msg = %q; want %q", slurp, msg)
	 675  	}
	 676  	wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
	 677  	mu.Lock()
	 678  	defer mu.Unlock()
	 679  	if !reflect.DeepEqual(log, wantLog) {
	 680  		t.Errorf("Log events = %q; want %q", log, wantLog)
	 681  	}
	 682  }
	 683  
	 684  func TestReverseProxy_Post(t *testing.T) {
	 685  	const backendResponse = "I am the backend"
	 686  	const backendStatus = 200
	 687  	var requestBody = bytes.Repeat([]byte("a"), 1<<20)
	 688  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	 689  		slurp, err := io.ReadAll(r.Body)
	 690  		if err != nil {
	 691  			t.Errorf("Backend body read = %v", err)
	 692  		}
	 693  		if len(slurp) != len(requestBody) {
	 694  			t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
	 695  		}
	 696  		if !bytes.Equal(slurp, requestBody) {
	 697  			t.Error("Backend read wrong request body.") // 1MB; omitting details
	 698  		}
	 699  		w.Write([]byte(backendResponse))
	 700  	}))
	 701  	defer backend.Close()
	 702  	backendURL, err := url.Parse(backend.URL)
	 703  	if err != nil {
	 704  		t.Fatal(err)
	 705  	}
	 706  	proxyHandler := NewSingleHostReverseProxy(backendURL)
	 707  	frontend := httptest.NewServer(proxyHandler)
	 708  	defer frontend.Close()
	 709  
	 710  	postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
	 711  	res, err := frontend.Client().Do(postReq)
	 712  	if err != nil {
	 713  		t.Fatalf("Do: %v", err)
	 714  	}
	 715  	if g, e := res.StatusCode, backendStatus; g != e {
	 716  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
	 717  	}
	 718  	bodyBytes, _ := io.ReadAll(res.Body)
	 719  	if g, e := string(bodyBytes), backendResponse; g != e {
	 720  		t.Errorf("got body %q; expected %q", g, e)
	 721  	}
	 722  }
	 723  
	 724  type RoundTripperFunc func(*http.Request) (*http.Response, error)
	 725  
	 726  func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	 727  	return fn(req)
	 728  }
	 729  
	 730  // Issue 16036: send a Request with a nil Body when possible
	 731  func TestReverseProxy_NilBody(t *testing.T) {
	 732  	backendURL, _ := url.Parse("http://fake.tld/")
	 733  	proxyHandler := NewSingleHostReverseProxy(backendURL)
	 734  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
	 735  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
	 736  		if req.Body != nil {
	 737  			t.Error("Body != nil; want a nil Body")
	 738  		}
	 739  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
	 740  	})
	 741  	frontend := httptest.NewServer(proxyHandler)
	 742  	defer frontend.Close()
	 743  
	 744  	res, err := frontend.Client().Get(frontend.URL)
	 745  	if err != nil {
	 746  		t.Fatal(err)
	 747  	}
	 748  	defer res.Body.Close()
	 749  	if res.StatusCode != 502 {
	 750  		t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status)
	 751  	}
	 752  }
	 753  
	 754  // Issue 33142: always allocate the request headers
	 755  func TestReverseProxy_AllocatedHeader(t *testing.T) {
	 756  	proxyHandler := new(ReverseProxy)
	 757  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
	 758  	proxyHandler.Director = func(*http.Request) {}		 // noop
	 759  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
	 760  		if req.Header == nil {
	 761  			t.Error("Header == nil; want a non-nil Header")
	 762  		}
	 763  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
	 764  	})
	 765  
	 766  	proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{
	 767  		Method:		 "GET",
	 768  		URL:				&url.URL{Scheme: "http", Host: "fake.tld", Path: "/"},
	 769  		Proto:			"HTTP/1.0",
	 770  		ProtoMajor: 1,
	 771  	})
	 772  }
	 773  
	 774  // Issue 14237. Test ModifyResponse and that an error from it
	 775  // causes the proxy to return StatusBadGateway, or StatusOK otherwise.
	 776  func TestReverseProxyModifyResponse(t *testing.T) {
	 777  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	 778  		w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod"))
	 779  	}))
	 780  	defer backendServer.Close()
	 781  
	 782  	rpURL, _ := url.Parse(backendServer.URL)
	 783  	rproxy := NewSingleHostReverseProxy(rpURL)
	 784  	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
	 785  	rproxy.ModifyResponse = func(resp *http.Response) error {
	 786  		if resp.Header.Get("X-Hit-Mod") != "true" {
	 787  			return fmt.Errorf("tried to by-pass proxy")
	 788  		}
	 789  		return nil
	 790  	}
	 791  
	 792  	frontendProxy := httptest.NewServer(rproxy)
	 793  	defer frontendProxy.Close()
	 794  
	 795  	tests := []struct {
	 796  		url			string
	 797  		wantCode int
	 798  	}{
	 799  		{frontendProxy.URL + "/mod", http.StatusOK},
	 800  		{frontendProxy.URL + "/schedule", http.StatusBadGateway},
	 801  	}
	 802  
	 803  	for i, tt := range tests {
	 804  		resp, err := http.Get(tt.url)
	 805  		if err != nil {
	 806  			t.Fatalf("failed to reach proxy: %v", err)
	 807  		}
	 808  		if g, e := resp.StatusCode, tt.wantCode; g != e {
	 809  			t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e)
	 810  		}
	 811  		resp.Body.Close()
	 812  	}
	 813  }
	 814  
	 815  type failingRoundTripper struct{}
	 816  
	 817  func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
	 818  	return nil, errors.New("some error")
	 819  }
	 820  
	 821  type staticResponseRoundTripper struct{ res *http.Response }
	 822  
	 823  func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
	 824  	return rt.res, nil
	 825  }
	 826  
	 827  func TestReverseProxyErrorHandler(t *testing.T) {
	 828  	tests := []struct {
	 829  		name					 string
	 830  		wantCode			 int
	 831  		errorHandler	 func(http.ResponseWriter, *http.Request, error)
	 832  		transport			http.RoundTripper // defaults to failingRoundTripper
	 833  		modifyResponse func(*http.Response) error
	 834  	}{
	 835  		{
	 836  			name:		 "default",
	 837  			wantCode: http.StatusBadGateway,
	 838  		},
	 839  		{
	 840  			name:				 "errorhandler",
	 841  			wantCode:		 http.StatusTeapot,
	 842  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
	 843  		},
	 844  		{
	 845  			name: "modifyresponse_noerr",
	 846  			transport: staticResponseRoundTripper{
	 847  				&http.Response{StatusCode: 345, Body: http.NoBody},
	 848  			},
	 849  			modifyResponse: func(res *http.Response) error {
	 850  				res.StatusCode++
	 851  				return nil
	 852  			},
	 853  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
	 854  			wantCode:		 346,
	 855  		},
	 856  		{
	 857  			name: "modifyresponse_err",
	 858  			transport: staticResponseRoundTripper{
	 859  				&http.Response{StatusCode: 345, Body: http.NoBody},
	 860  			},
	 861  			modifyResponse: func(res *http.Response) error {
	 862  				res.StatusCode++
	 863  				return errors.New("some error to trigger errorHandler")
	 864  			},
	 865  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
	 866  			wantCode:		 http.StatusTeapot,
	 867  		},
	 868  	}
	 869  
	 870  	for _, tt := range tests {
	 871  		t.Run(tt.name, func(t *testing.T) {
	 872  			target := &url.URL{
	 873  				Scheme: "http",
	 874  				Host:	 "dummy.tld",
	 875  				Path:	 "/",
	 876  			}
	 877  			rproxy := NewSingleHostReverseProxy(target)
	 878  			rproxy.Transport = tt.transport
	 879  			rproxy.ModifyResponse = tt.modifyResponse
	 880  			if rproxy.Transport == nil {
	 881  				rproxy.Transport = failingRoundTripper{}
	 882  			}
	 883  			rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
	 884  			if tt.errorHandler != nil {
	 885  				rproxy.ErrorHandler = tt.errorHandler
	 886  			}
	 887  			frontendProxy := httptest.NewServer(rproxy)
	 888  			defer frontendProxy.Close()
	 889  
	 890  			resp, err := http.Get(frontendProxy.URL + "/test")
	 891  			if err != nil {
	 892  				t.Fatalf("failed to reach proxy: %v", err)
	 893  			}
	 894  			if g, e := resp.StatusCode, tt.wantCode; g != e {
	 895  				t.Errorf("got res.StatusCode %d; expected %d", g, e)
	 896  			}
	 897  			resp.Body.Close()
	 898  		})
	 899  	}
	 900  }
	 901  
	 902  // Issue 16659: log errors from short read
	 903  func TestReverseProxy_CopyBuffer(t *testing.T) {
	 904  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	 905  		out := "this call was relayed by the reverse proxy"
	 906  		// Coerce a wrong content length to induce io.UnexpectedEOF
	 907  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
	 908  		fmt.Fprintln(w, out)
	 909  	}))
	 910  	defer backendServer.Close()
	 911  
	 912  	rpURL, err := url.Parse(backendServer.URL)
	 913  	if err != nil {
	 914  		t.Fatal(err)
	 915  	}
	 916  
	 917  	var proxyLog bytes.Buffer
	 918  	rproxy := NewSingleHostReverseProxy(rpURL)
	 919  	rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
	 920  	donec := make(chan bool, 1)
	 921  	frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	 922  		defer func() { donec <- true }()
	 923  		rproxy.ServeHTTP(w, r)
	 924  	}))
	 925  	defer frontendProxy.Close()
	 926  
	 927  	if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil {
	 928  		t.Fatalf("want non-nil error")
	 929  	}
	 930  	// The race detector complains about the proxyLog usage in logf in copyBuffer
	 931  	// and our usage below with proxyLog.Bytes() so we're explicitly using a
	 932  	// channel to ensure that the ReverseProxy's ServeHTTP is done before we
	 933  	// continue after Get.
	 934  	<-donec
	 935  
	 936  	expected := []string{
	 937  		"EOF",
	 938  		"read",
	 939  	}
	 940  	for _, phrase := range expected {
	 941  		if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) {
	 942  			t.Errorf("expected log to contain phrase %q", phrase)
	 943  		}
	 944  	}
	 945  }
	 946  
	 947  type staticTransport struct {
	 948  	res *http.Response
	 949  }
	 950  
	 951  func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
	 952  	return t.res, nil
	 953  }
	 954  
	 955  func BenchmarkServeHTTP(b *testing.B) {
	 956  	res := &http.Response{
	 957  		StatusCode: 200,
	 958  		Body:			 io.NopCloser(strings.NewReader("")),
	 959  	}
	 960  	proxy := &ReverseProxy{
	 961  		Director:	func(*http.Request) {},
	 962  		Transport: &staticTransport{res},
	 963  	}
	 964  
	 965  	w := httptest.NewRecorder()
	 966  	r := httptest.NewRequest("GET", "/", nil)
	 967  
	 968  	b.ReportAllocs()
	 969  	for i := 0; i < b.N; i++ {
	 970  		proxy.ServeHTTP(w, r)
	 971  	}
	 972  }
	 973  
	 974  func TestServeHTTPDeepCopy(t *testing.T) {
	 975  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	 976  		w.Write([]byte("Hello Gopher!"))
	 977  	}))
	 978  	defer backend.Close()
	 979  	backendURL, err := url.Parse(backend.URL)
	 980  	if err != nil {
	 981  		t.Fatal(err)
	 982  	}
	 983  
	 984  	type result struct {
	 985  		before, after string
	 986  	}
	 987  
	 988  	resultChan := make(chan result, 1)
	 989  	proxyHandler := NewSingleHostReverseProxy(backendURL)
	 990  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	 991  		before := r.URL.String()
	 992  		proxyHandler.ServeHTTP(w, r)
	 993  		after := r.URL.String()
	 994  		resultChan <- result{before: before, after: after}
	 995  	}))
	 996  	defer frontend.Close()
	 997  
	 998  	want := result{before: "/", after: "/"}
	 999  
	1000  	res, err := frontend.Client().Get(frontend.URL)
	1001  	if err != nil {
	1002  		t.Fatalf("Do: %v", err)
	1003  	}
	1004  	res.Body.Close()
	1005  
	1006  	got := <-resultChan
	1007  	if got != want {
	1008  		t.Errorf("got = %+v; want = %+v", got, want)
	1009  	}
	1010  }
	1011  
	1012  // Issue 18327: verify we always do a deep copy of the Request.Header map
	1013  // before any mutations.
	1014  func TestClonesRequestHeaders(t *testing.T) {
	1015  	log.SetOutput(io.Discard)
	1016  	defer log.SetOutput(os.Stderr)
	1017  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
	1018  	req.RemoteAddr = "1.2.3.4:56789"
	1019  	rp := &ReverseProxy{
	1020  		Director: func(req *http.Request) {
	1021  			req.Header.Set("From-Director", "1")
	1022  		},
	1023  		Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
	1024  			if v := req.Header.Get("From-Director"); v != "1" {
	1025  				t.Errorf("From-Directory value = %q; want 1", v)
	1026  			}
	1027  			return nil, io.EOF
	1028  		}),
	1029  	}
	1030  	rp.ServeHTTP(httptest.NewRecorder(), req)
	1031  
	1032  	if req.Header.Get("From-Director") == "1" {
	1033  		t.Error("Director header mutation modified caller's request")
	1034  	}
	1035  	if req.Header.Get("X-Forwarded-For") != "" {
	1036  		t.Error("X-Forward-For header mutation modified caller's request")
	1037  	}
	1038  
	1039  }
	1040  
	1041  type roundTripperFunc func(req *http.Request) (*http.Response, error)
	1042  
	1043  func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	1044  	return fn(req)
	1045  }
	1046  
	1047  func TestModifyResponseClosesBody(t *testing.T) {
	1048  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
	1049  	req.RemoteAddr = "1.2.3.4:56789"
	1050  	closeCheck := new(checkCloser)
	1051  	logBuf := new(bytes.Buffer)
	1052  	outErr := errors.New("ModifyResponse error")
	1053  	rp := &ReverseProxy{
	1054  		Director: func(req *http.Request) {},
	1055  		Transport: &staticTransport{&http.Response{
	1056  			StatusCode: 200,
	1057  			Body:			 closeCheck,
	1058  		}},
	1059  		ErrorLog: log.New(logBuf, "", 0),
	1060  		ModifyResponse: func(*http.Response) error {
	1061  			return outErr
	1062  		},
	1063  	}
	1064  	rec := httptest.NewRecorder()
	1065  	rp.ServeHTTP(rec, req)
	1066  	res := rec.Result()
	1067  	if g, e := res.StatusCode, http.StatusBadGateway; g != e {
	1068  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
	1069  	}
	1070  	if !closeCheck.closed {
	1071  		t.Errorf("body should have been closed")
	1072  	}
	1073  	if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
	1074  		t.Errorf("ErrorLog %q does not contain %q", g, e)
	1075  	}
	1076  }
	1077  
	1078  type checkCloser struct {
	1079  	closed bool
	1080  }
	1081  
	1082  func (cc *checkCloser) Close() error {
	1083  	cc.closed = true
	1084  	return nil
	1085  }
	1086  
	1087  func (cc *checkCloser) Read(b []byte) (int, error) {
	1088  	return len(b), nil
	1089  }
	1090  
	1091  // Issue 23643: panic on body copy error
	1092  func TestReverseProxy_PanicBodyError(t *testing.T) {
	1093  	log.SetOutput(io.Discard)
	1094  	defer log.SetOutput(os.Stderr)
	1095  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	1096  		out := "this call was relayed by the reverse proxy"
	1097  		// Coerce a wrong content length to induce io.ErrUnexpectedEOF
	1098  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
	1099  		fmt.Fprintln(w, out)
	1100  	}))
	1101  	defer backendServer.Close()
	1102  
	1103  	rpURL, err := url.Parse(backendServer.URL)
	1104  	if err != nil {
	1105  		t.Fatal(err)
	1106  	}
	1107  
	1108  	rproxy := NewSingleHostReverseProxy(rpURL)
	1109  
	1110  	// Ensure that the handler panics when the body read encounters an
	1111  	// io.ErrUnexpectedEOF
	1112  	defer func() {
	1113  		err := recover()
	1114  		if err == nil {
	1115  			t.Fatal("handler should have panicked")
	1116  		}
	1117  		if err != http.ErrAbortHandler {
	1118  			t.Fatal("expected ErrAbortHandler, got", err)
	1119  		}
	1120  	}()
	1121  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
	1122  	rproxy.ServeHTTP(httptest.NewRecorder(), req)
	1123  }
	1124  
	1125  // Issue #46866: panic without closing incoming request body causes a panic
	1126  func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) {
	1127  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	1128  		out := "this call was relayed by the reverse proxy"
	1129  		// Coerce a wrong content length to induce io.ErrUnexpectedEOF
	1130  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
	1131  		fmt.Fprintln(w, out)
	1132  	}))
	1133  	defer backend.Close()
	1134  	backendURL, err := url.Parse(backend.URL)
	1135  	if err != nil {
	1136  		t.Fatal(err)
	1137  	}
	1138  	proxyHandler := NewSingleHostReverseProxy(backendURL)
	1139  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
	1140  	frontend := httptest.NewServer(proxyHandler)
	1141  	defer frontend.Close()
	1142  	frontendClient := frontend.Client()
	1143  
	1144  	var wg sync.WaitGroup
	1145  	for i := 0; i < 2; i++ {
	1146  		wg.Add(1)
	1147  		go func() {
	1148  			defer wg.Done()
	1149  			for j := 0; j < 10; j++ {
	1150  				const reqLen = 6 * 1024 * 1024
	1151  				req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
	1152  				req.ContentLength = reqLen
	1153  				resp, _ := frontendClient.Transport.RoundTrip(req)
	1154  				if resp != nil {
	1155  					io.Copy(io.Discard, resp.Body)
	1156  					resp.Body.Close()
	1157  				}
	1158  			}
	1159  		}()
	1160  	}
	1161  	wg.Wait()
	1162  }
	1163  
	1164  func TestSelectFlushInterval(t *testing.T) {
	1165  	tests := []struct {
	1166  		name string
	1167  		p		*ReverseProxy
	1168  		res	*http.Response
	1169  		want time.Duration
	1170  	}{
	1171  		{
	1172  			name: "default",
	1173  			res:	&http.Response{},
	1174  			p:		&ReverseProxy{FlushInterval: 123},
	1175  			want: 123,
	1176  		},
	1177  		{
	1178  			name: "server-sent events overrides non-zero",
	1179  			res: &http.Response{
	1180  				Header: http.Header{
	1181  					"Content-Type": {"text/event-stream"},
	1182  				},
	1183  			},
	1184  			p:		&ReverseProxy{FlushInterval: 123},
	1185  			want: -1,
	1186  		},
	1187  		{
	1188  			name: "server-sent events overrides zero",
	1189  			res: &http.Response{
	1190  				Header: http.Header{
	1191  					"Content-Type": {"text/event-stream"},
	1192  				},
	1193  			},
	1194  			p:		&ReverseProxy{FlushInterval: 0},
	1195  			want: -1,
	1196  		},
	1197  		{
	1198  			name: "Content-Length: -1, overrides non-zero",
	1199  			res: &http.Response{
	1200  				ContentLength: -1,
	1201  			},
	1202  			p:		&ReverseProxy{FlushInterval: 123},
	1203  			want: -1,
	1204  		},
	1205  		{
	1206  			name: "Content-Length: -1, overrides zero",
	1207  			res: &http.Response{
	1208  				ContentLength: -1,
	1209  			},
	1210  			p:		&ReverseProxy{FlushInterval: 0},
	1211  			want: -1,
	1212  		},
	1213  	}
	1214  	for _, tt := range tests {
	1215  		t.Run(tt.name, func(t *testing.T) {
	1216  			got := tt.p.flushInterval(tt.res)
	1217  			if got != tt.want {
	1218  				t.Errorf("flushLatency = %v; want %v", got, tt.want)
	1219  			}
	1220  		})
	1221  	}
	1222  }
	1223  
	1224  func TestReverseProxyWebSocket(t *testing.T) {
	1225  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	1226  		if upgradeType(r.Header) != "websocket" {
	1227  			t.Error("unexpected backend request")
	1228  			http.Error(w, "unexpected request", 400)
	1229  			return
	1230  		}
	1231  		c, _, err := w.(http.Hijacker).Hijack()
	1232  		if err != nil {
	1233  			t.Error(err)
	1234  			return
	1235  		}
	1236  		defer c.Close()
	1237  		io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
	1238  		bs := bufio.NewScanner(c)
	1239  		if !bs.Scan() {
	1240  			t.Errorf("backend failed to read line from client: %v", bs.Err())
	1241  			return
	1242  		}
	1243  		fmt.Fprintf(c, "backend got %q\n", bs.Text())
	1244  	}))
	1245  	defer backendServer.Close()
	1246  
	1247  	backURL, _ := url.Parse(backendServer.URL)
	1248  	rproxy := NewSingleHostReverseProxy(backURL)
	1249  	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
	1250  	rproxy.ModifyResponse = func(res *http.Response) error {
	1251  		res.Header.Add("X-Modified", "true")
	1252  		return nil
	1253  	}
	1254  
	1255  	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
	1256  		rw.Header().Set("X-Header", "X-Value")
	1257  		rproxy.ServeHTTP(rw, req)
	1258  		if got, want := rw.Header().Get("X-Modified"), "true"; got != want {
	1259  			t.Errorf("response writer X-Modified header = %q; want %q", got, want)
	1260  		}
	1261  	})
	1262  
	1263  	frontendProxy := httptest.NewServer(handler)
	1264  	defer frontendProxy.Close()
	1265  
	1266  	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
	1267  	req.Header.Set("Connection", "Upgrade")
	1268  	req.Header.Set("Upgrade", "websocket")
	1269  
	1270  	c := frontendProxy.Client()
	1271  	res, err := c.Do(req)
	1272  	if err != nil {
	1273  		t.Fatal(err)
	1274  	}
	1275  	if res.StatusCode != 101 {
	1276  		t.Fatalf("status = %v; want 101", res.Status)
	1277  	}
	1278  
	1279  	got := res.Header.Get("X-Header")
	1280  	want := "X-Value"
	1281  	if got != want {
	1282  		t.Errorf("Header(XHeader) = %q; want %q", got, want)
	1283  	}
	1284  
	1285  	if !ascii.EqualFold(upgradeType(res.Header), "websocket") {
	1286  		t.Fatalf("not websocket upgrade; got %#v", res.Header)
	1287  	}
	1288  	rwc, ok := res.Body.(io.ReadWriteCloser)
	1289  	if !ok {
	1290  		t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
	1291  	}
	1292  	defer rwc.Close()
	1293  
	1294  	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
	1295  		t.Errorf("response X-Modified header = %q; want %q", got, want)
	1296  	}
	1297  
	1298  	io.WriteString(rwc, "Hello\n")
	1299  	bs := bufio.NewScanner(rwc)
	1300  	if !bs.Scan() {
	1301  		t.Fatalf("Scan: %v", bs.Err())
	1302  	}
	1303  	got = bs.Text()
	1304  	want = `backend got "Hello"`
	1305  	if got != want {
	1306  		t.Errorf("got %#q, want %#q", got, want)
	1307  	}
	1308  }
	1309  
	1310  func TestReverseProxyWebSocketCancellation(t *testing.T) {
	1311  	n := 5
	1312  	triggerCancelCh := make(chan bool, n)
	1313  	nthResponse := func(i int) string {
	1314  		return fmt.Sprintf("backend response #%d\n", i)
	1315  	}
	1316  	terminalMsg := "final message"
	1317  
	1318  	cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	1319  		if g, ws := upgradeType(r.Header), "websocket"; g != ws {
	1320  			t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
	1321  			http.Error(w, "Unexpected request", 400)
	1322  			return
	1323  		}
	1324  		conn, bufrw, err := w.(http.Hijacker).Hijack()
	1325  		if err != nil {
	1326  			t.Error(err)
	1327  			return
	1328  		}
	1329  		defer conn.Close()
	1330  
	1331  		upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
	1332  		if _, err := io.WriteString(conn, upgradeMsg); err != nil {
	1333  			t.Error(err)
	1334  			return
	1335  		}
	1336  		if _, _, err := bufrw.ReadLine(); err != nil {
	1337  			t.Errorf("Failed to read line from client: %v", err)
	1338  			return
	1339  		}
	1340  
	1341  		for i := 0; i < n; i++ {
	1342  			if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
	1343  				select {
	1344  				case <-triggerCancelCh:
	1345  				default:
	1346  					t.Errorf("Writing response #%d failed: %v", i, err)
	1347  				}
	1348  				return
	1349  			}
	1350  			bufrw.Flush()
	1351  			time.Sleep(time.Second)
	1352  		}
	1353  		if _, err := bufrw.WriteString(terminalMsg); err != nil {
	1354  			select {
	1355  			case <-triggerCancelCh:
	1356  			default:
	1357  				t.Errorf("Failed to write terminal message: %v", err)
	1358  			}
	1359  		}
	1360  		bufrw.Flush()
	1361  	}))
	1362  	defer cst.Close()
	1363  
	1364  	backendURL, _ := url.Parse(cst.URL)
	1365  	rproxy := NewSingleHostReverseProxy(backendURL)
	1366  	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
	1367  	rproxy.ModifyResponse = func(res *http.Response) error {
	1368  		res.Header.Add("X-Modified", "true")
	1369  		return nil
	1370  	}
	1371  
	1372  	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
	1373  		rw.Header().Set("X-Header", "X-Value")
	1374  		ctx, cancel := context.WithCancel(req.Context())
	1375  		go func() {
	1376  			<-triggerCancelCh
	1377  			cancel()
	1378  		}()
	1379  		rproxy.ServeHTTP(rw, req.WithContext(ctx))
	1380  	})
	1381  
	1382  	frontendProxy := httptest.NewServer(handler)
	1383  	defer frontendProxy.Close()
	1384  
	1385  	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
	1386  	req.Header.Set("Connection", "Upgrade")
	1387  	req.Header.Set("Upgrade", "websocket")
	1388  
	1389  	res, err := frontendProxy.Client().Do(req)
	1390  	if err != nil {
	1391  		t.Fatalf("Dialing to frontend proxy: %v", err)
	1392  	}
	1393  	defer res.Body.Close()
	1394  	if g, w := res.StatusCode, 101; g != w {
	1395  		t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
	1396  	}
	1397  
	1398  	if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
	1399  		t.Errorf("X-Header mismatch\n\tgot:	%q\n\twant: %q", g, w)
	1400  	}
	1401  
	1402  	if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) {
	1403  		t.Fatalf("Upgrade header mismatch\n\tgot:	%q\n\twant: %q", g, w)
	1404  	}
	1405  
	1406  	rwc, ok := res.Body.(io.ReadWriteCloser)
	1407  	if !ok {
	1408  		t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
	1409  	}
	1410  
	1411  	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
	1412  		t.Errorf("response X-Modified header = %q; want %q", got, want)
	1413  	}
	1414  
	1415  	if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
	1416  		t.Fatalf("Failed to write first message: %v", err)
	1417  	}
	1418  
	1419  	// Read loop.
	1420  
	1421  	br := bufio.NewReader(rwc)
	1422  	for {
	1423  		line, err := br.ReadString('\n')
	1424  		switch {
	1425  		case line == terminalMsg: // this case before "err == io.EOF"
	1426  			t.Fatalf("The websocket request was not canceled, unfortunately!")
	1427  
	1428  		case err == io.EOF:
	1429  			return
	1430  
	1431  		case err != nil:
	1432  			t.Fatalf("Unexpected error: %v", err)
	1433  
	1434  		case line == nthResponse(0): // We've gotten the first response back
	1435  			// Let's trigger a cancel.
	1436  			close(triggerCancelCh)
	1437  		}
	1438  	}
	1439  }
	1440  
	1441  func TestUnannouncedTrailer(t *testing.T) {
	1442  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	1443  		w.WriteHeader(http.StatusOK)
	1444  		w.(http.Flusher).Flush()
	1445  		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
	1446  	}))
	1447  	defer backend.Close()
	1448  	backendURL, err := url.Parse(backend.URL)
	1449  	if err != nil {
	1450  		t.Fatal(err)
	1451  	}
	1452  	proxyHandler := NewSingleHostReverseProxy(backendURL)
	1453  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
	1454  	frontend := httptest.NewServer(proxyHandler)
	1455  	defer frontend.Close()
	1456  	frontendClient := frontend.Client()
	1457  
	1458  	res, err := frontendClient.Get(frontend.URL)
	1459  	if err != nil {
	1460  		t.Fatalf("Get: %v", err)
	1461  	}
	1462  
	1463  	io.ReadAll(res.Body)
	1464  
	1465  	if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w {
	1466  		t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w)
	1467  	}
	1468  
	1469  }
	1470  
	1471  func TestSingleJoinSlash(t *testing.T) {
	1472  	tests := []struct {
	1473  		slasha	 string
	1474  		slashb	 string
	1475  		expected string
	1476  	}{
	1477  		{"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"},
	1478  		{"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"},
	1479  		{"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"},
	1480  		{"https://www.google.com", "", "https://www.google.com/"},
	1481  		{"", "favicon.ico", "/favicon.ico"},
	1482  	}
	1483  	for _, tt := range tests {
	1484  		if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected {
	1485  			t.Errorf("singleJoiningSlash(%q,%q) want %q got %q",
	1486  				tt.slasha,
	1487  				tt.slashb,
	1488  				tt.expected,
	1489  				got)
	1490  		}
	1491  	}
	1492  }
	1493  
	1494  func TestJoinURLPath(t *testing.T) {
	1495  	tests := []struct {
	1496  		a				*url.URL
	1497  		b				*url.URL
	1498  		wantPath string
	1499  		wantRaw	string
	1500  	}{
	1501  		{&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""},
	1502  		{&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"},
	1503  		{&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
	1504  		{&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
	1505  		{&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"},
	1506  		{&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"},
	1507  	}
	1508  
	1509  	for _, tt := range tests {
	1510  		p, rp := joinURLPath(tt.a, tt.b)
	1511  		if p != tt.wantPath || rp != tt.wantRaw {
	1512  			t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)",
	1513  				tt.a.Path, tt.a.RawPath,
	1514  				tt.b.Path, tt.b.RawPath,
	1515  				tt.wantPath, tt.wantRaw,
	1516  				p, rp)
	1517  		}
	1518  	}
	1519  }
	1520  

View as plain text