Source file
src/net/http/serve_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bufio"
11 "bytes"
12 "compress/gzip"
13 "compress/zlib"
14 "context"
15 "crypto/tls"
16 "encoding/json"
17 "errors"
18 "fmt"
19 "internal/testenv"
20 "io"
21 "log"
22 "math/rand"
23 "net"
24 . "net/http"
25 "net/http/httptest"
26 "net/http/httputil"
27 "net/http/internal"
28 "net/http/internal/testcert"
29 "net/url"
30 "os"
31 "os/exec"
32 "path/filepath"
33 "reflect"
34 "regexp"
35 "runtime"
36 "runtime/debug"
37 "strconv"
38 "strings"
39 "sync"
40 "sync/atomic"
41 "syscall"
42 "testing"
43 "time"
44 )
45
// dummyAddr is a minimal net.Addr: the same underlying string is reported
// both as the network name and as the address.
type dummyAddr string

// oneConnListener is a net.Listener that yields one preset connection and
// reports io.EOF on every Accept call after that.
type oneConnListener struct {
	conn net.Conn
}

// Accept hands out the stored connection once, clearing it as it does so.
// When no connection is stored it returns a nil connection and io.EOF.
func (l *oneConnListener) Accept() (c net.Conn, err error) {
	if l.conn == nil {
		return nil, io.EOF
	}
	c, l.conn = l.conn, nil
	return c, nil
}

// Close does nothing and never fails.
func (l *oneConnListener) Close() error { return nil }

// Addr reports a fixed placeholder address.
func (l *oneConnListener) Addr() net.Addr {
	const addr dummyAddr = "test-address"
	return addr
}

// Network returns the address string itself.
func (a dummyAddr) Network() string { return string(a) }

// String returns the address string itself.
func (a dummyAddr) String() string { return string(a) }
77
78 type noopConn struct{}
79
80 func (noopConn) LocalAddr() net.Addr { return dummyAddr("local-addr") }
81 func (noopConn) RemoteAddr() net.Addr { return dummyAddr("remote-addr") }
82 func (noopConn) SetDeadline(t time.Time) error { return nil }
83 func (noopConn) SetReadDeadline(t time.Time) error { return nil }
84 func (noopConn) SetWriteDeadline(t time.Time) error { return nil }
85
86 type rwTestConn struct {
87 io.Reader
88 io.Writer
89 noopConn
90
91 closeFunc func() error
92 closec chan bool
93 }
94
95 func (c *rwTestConn) Close() error {
96 if c.closeFunc != nil {
97 return c.closeFunc()
98 }
99 select {
100 case c.closec <- true:
101 default:
102 }
103 return nil
104 }
105
106 type testConn struct {
107 readMu sync.Mutex
108 readBuf bytes.Buffer
109 writeBuf bytes.Buffer
110 closec chan bool
111 noopConn
112 }
113
114 func (c *testConn) Read(b []byte) (int, error) {
115 c.readMu.Lock()
116 defer c.readMu.Unlock()
117 return c.readBuf.Read(b)
118 }
119
120 func (c *testConn) Write(b []byte) (int, error) {
121 return c.writeBuf.Write(b)
122 }
123
124 func (c *testConn) Close() error {
125 select {
126 case c.closec <- true:
127 default:
128 }
129 return nil
130 }
131
132
133
// reqBytes converts a request written with plain "\n" line endings into wire
// format: surrounding whitespace is trimmed, every "\n" becomes "\r\n", and a
// terminating blank line ("\r\n\r\n") is appended.
func reqBytes(req string) []byte {
	trimmed := strings.TrimSpace(req)
	wire := strings.ReplaceAll(trimmed, "\n", "\r\n")
	return []byte(wire + "\r\n\r\n")
}
137
// handlerTest pairs a Handler with a buffer that receives the server's error
// log while the handler is exercised.
type handlerTest struct {
	logbuf  bytes.Buffer
	handler Handler
}

// newHandlerTest returns a handlerTest for h with an empty log buffer.
func newHandlerTest(h Handler) handlerTest {
	var ht handlerTest
	ht.handler = h
	return ht
}
146
147 func (ht *handlerTest) rawResponse(req string) string {
148 reqb := reqBytes(req)
149 var output bytes.Buffer
150 conn := &rwTestConn{
151 Reader: bytes.NewReader(reqb),
152 Writer: &output,
153 closec: make(chan bool, 1),
154 }
155 ln := &oneConnListener{conn: conn}
156 srv := &Server{
157 ErrorLog: log.New(&ht.logbuf, "", 0),
158 Handler: ht.handler,
159 }
160 go srv.Serve(ln)
161 <-conn.closec
162 return output.String()
163 }
164
165 func TestConsumingBodyOnNextConn(t *testing.T) {
166 t.Parallel()
167 defer afterTest(t)
168 conn := new(testConn)
169 for i := 0; i < 2; i++ {
170 conn.readBuf.Write([]byte(
171 "POST / HTTP/1.1\r\n" +
172 "Host: test\r\n" +
173 "Content-Length: 11\r\n" +
174 "\r\n" +
175 "foo=1&bar=1"))
176 }
177
178 reqNum := 0
179 ch := make(chan *Request)
180 servech := make(chan error)
181 listener := &oneConnListener{conn}
182 handler := func(res ResponseWriter, req *Request) {
183 reqNum++
184 ch <- req
185 }
186
187 go func() {
188 servech <- Serve(listener, HandlerFunc(handler))
189 }()
190
191 var req *Request
192 req = <-ch
193 if req == nil {
194 t.Fatal("Got nil first request.")
195 }
196 if req.Method != "POST" {
197 t.Errorf("For request #1's method, got %q; expected %q",
198 req.Method, "POST")
199 }
200
201 req = <-ch
202 if req == nil {
203 t.Fatal("Got nil first request.")
204 }
205 if req.Method != "POST" {
206 t.Errorf("For request #2's method, got %q; expected %q",
207 req.Method, "POST")
208 }
209
210 if serveerr := <-servech; serveerr != io.EOF {
211 t.Errorf("Serve returned %q; expected EOF", serveerr)
212 }
213 }
214
// stringHandler is a Handler that identifies itself by setting the "Result"
// response header to its own string value. It writes no status or body.
type stringHandler string

// ServeHTTP sets the "Result" header to s.
func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) {
	hdr := w.Header()
	hdr.Set("Result", string(s))
}
220
// handlers lists mux patterns together with the message that the handler
// registered for each pattern reports.
var handlers = []struct {
	pattern string
	msg     string
}{
	{pattern: "/", msg: "Default"},
	{pattern: "/someDir/", msg: "someDir"},
	{pattern: "/#/", msg: "hash"},
	{pattern: "someHost.com/someDir/", msg: "someHost.com/someDir"},
}
230
// vtests maps request URLs to the value expected back: the "Result" header
// for a 200 response, or the "Location" header for a 301 response.
var vtests = []struct {
	url      string
	expected string
}{
	// Requests answered directly by a registered handler.
	{url: "http://localhost/someDir/apage", expected: "someDir"},
	{url: "http://localhost/%23/apage", expected: "hash"},
	{url: "http://localhost/otherDir/apage", expected: "Default"},
	{url: "http://someHost.com/someDir/apage", expected: "someHost.com/someDir"},
	{url: "http://otherHost.com/someDir/apage", expected: "someDir"},
	{url: "http://otherHost.com/aDir/apage", expected: "Default"},

	// Requests whose expected value is a path ending in a slash.
	{url: "http://localhost/someDir", expected: "/someDir/"},
	{url: "http://localhost/%23", expected: "/%23/"},
	{url: "http://someHost.com/someDir", expected: "/someDir/"},
}
246
247 func TestHostHandlers(t *testing.T) {
248 setParallel(t)
249 defer afterTest(t)
250 mux := NewServeMux()
251 for _, h := range handlers {
252 mux.Handle(h.pattern, stringHandler(h.msg))
253 }
254 ts := httptest.NewServer(mux)
255 defer ts.Close()
256
257 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
258 if err != nil {
259 t.Fatal(err)
260 }
261 defer conn.Close()
262 cc := httputil.NewClientConn(conn, nil)
263 for _, vt := range vtests {
264 var r *Response
265 var req Request
266 if req.URL, err = url.Parse(vt.url); err != nil {
267 t.Errorf("cannot parse url: %v", err)
268 continue
269 }
270 if err := cc.Write(&req); err != nil {
271 t.Errorf("writing request: %v", err)
272 continue
273 }
274 r, err := cc.Read(&req)
275 if err != nil {
276 t.Errorf("reading response: %v", err)
277 continue
278 }
279 switch r.StatusCode {
280 case StatusOK:
281 s := r.Header.Get("Result")
282 if s != vt.expected {
283 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
284 }
285 case StatusMovedPermanently:
286 s := r.Header.Get("Location")
287 if s != vt.expected {
288 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
289 }
290 default:
291 t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode)
292 }
293 }
294 }
295
296 var serveMuxRegister = []struct {
297 pattern string
298 h Handler
299 }{
300 {"/dir/", serve(200)},
301 {"/search", serve(201)},
302 {"codesearch.google.com/search", serve(202)},
303 {"codesearch.google.com/", serve(203)},
304 {"example.com/", HandlerFunc(checkQueryStringHandler)},
305 }
306
307
// serve returns a handler that responds with the given status code and
// nothing else.
func serve(code int) HandlerFunc {
	return HandlerFunc(func(w ResponseWriter, _ *Request) {
		w.WriteHeader(code)
	})
}
313
314
315
316
// checkQueryStringHandler replies 200 when the request's raw query string,
// prefixed with "http://", equals the request URL rebuilt with scheme
// "http", the request's Host and no query. Otherwise it replies 500.
func checkQueryStringHandler(w ResponseWriter, r *Request) {
	// Work on a copy so the request's own URL is left untouched.
	rebuilt := *r.URL
	rebuilt.Scheme = "http"
	rebuilt.Host = r.Host
	rebuilt.RawQuery = ""

	status := 500
	if rebuilt.String() == "http://"+r.URL.RawQuery {
		status = 200
	}
	w.WriteHeader(status)
}
328
// serveMuxTests lists requests for TestServeMuxHandler. For each request,
// code is the status the selected handler is expected to write and pattern
// is the pattern mux.Handler is expected to report ("" when none).
var serveMuxTests = []struct {
	method  string
	host    string
	path    string
	code    int
	pattern string
}{
	// Host-independent patterns.
	{"GET", "google.com", "/", 404, ""},
	{"GET", "google.com", "/dir", 301, "/dir/"},
	{"GET", "google.com", "/dir/", 200, "/dir/"},
	{"GET", "google.com", "/dir/file", 200, "/dir/"},
	{"GET", "google.com", "/search", 201, "/search"},
	{"GET", "google.com", "/search/", 404, ""},
	{"GET", "google.com", "/search/foo", 404, ""},

	// Host-specific patterns.
	{"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"},
	{"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com:443", "/", 203, "codesearch.google.com/"},
	{"GET", "images.google.com", "/search", 201, "/search"},
	{"GET", "images.google.com", "/search/", 404, ""},
	{"GET", "images.google.com", "/search/foo", 404, ""},

	// Paths containing "." or ".." elements: GET expects a 301.
	{"GET", "google.com", "/../search", 301, "/search"},
	{"GET", "google.com", "/dir/..", 301, ""},
	{"GET", "google.com", "/dir/..", 301, ""},
	{"GET", "google.com", "/dir/./file", 301, "/dir/"},

	// The same paths with CONNECT: only "/dir" still expects a 301; the
	// paths containing "." or ".." are matched as given.
	{"CONNECT", "google.com", "/dir", 301, "/dir/"},
	{"CONNECT", "google.com", "/../search", 404, ""},
	{"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
	{"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
	{"CONNECT", "google.com", "/dir/./file", 200, "/dir/"},
}
364
365 func TestServeMuxHandler(t *testing.T) {
366 setParallel(t)
367 mux := NewServeMux()
368 for _, e := range serveMuxRegister {
369 mux.Handle(e.pattern, e.h)
370 }
371
372 for _, tt := range serveMuxTests {
373 r := &Request{
374 Method: tt.method,
375 Host: tt.host,
376 URL: &url.URL{
377 Path: tt.path,
378 },
379 }
380 h, pattern := mux.Handler(r)
381 rr := httptest.NewRecorder()
382 h.ServeHTTP(rr, r)
383 if pattern != tt.pattern || rr.Code != tt.code {
384 t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern)
385 }
386 }
387 }
388
389
390 func TestServeMuxHandleFuncWithNilHandler(t *testing.T) {
391 setParallel(t)
392 defer func() {
393 if err := recover(); err == nil {
394 t.Error("expected call to mux.HandleFunc to panic")
395 }
396 }()
397 mux := NewServeMux()
398 mux.HandleFunc("/", nil)
399 }
400
// serveMuxTests2 lists requests for TestServeMuxHandlerRedirects. code is the
// final (non-301) status expected, and redirOk says whether a 301 may be
// seen on the way to it.
var serveMuxTests2 = []struct {
	method  string
	host    string
	url     string
	code    int
	redirOk bool
}{
	{method: "GET", host: "google.com", url: "/", code: 404, redirOk: false},
	{method: "GET", host: "example.com", url: "/test/?example.com/test/", code: 200, redirOk: false},
	{method: "GET", host: "example.com", url: "test/?example.com/test/", code: 200, redirOk: true},
}
412
413
414
415 func TestServeMuxHandlerRedirects(t *testing.T) {
416 setParallel(t)
417 mux := NewServeMux()
418 for _, e := range serveMuxRegister {
419 mux.Handle(e.pattern, e.h)
420 }
421
422 for _, tt := range serveMuxTests2 {
423 tries := 1
424 turl := tt.url
425 for {
426 u, e := url.Parse(turl)
427 if e != nil {
428 t.Fatal(e)
429 }
430 r := &Request{
431 Method: tt.method,
432 Host: tt.host,
433 URL: u,
434 }
435 h, _ := mux.Handler(r)
436 rr := httptest.NewRecorder()
437 h.ServeHTTP(rr, r)
438 if rr.Code != 301 {
439 if rr.Code != tt.code {
440 t.Errorf("%s %s %s = %d, want %d", tt.method, tt.host, tt.url, rr.Code, tt.code)
441 }
442 break
443 }
444 if !tt.redirOk {
445 t.Errorf("%s %s %s, unexpected redirect", tt.method, tt.host, tt.url)
446 break
447 }
448 turl = rr.HeaderMap.Get("Location")
449 tries--
450 }
451 if tries < 0 {
452 t.Errorf("%s %s %s, too many redirects", tt.method, tt.host, tt.url)
453 }
454 }
455 }
456
457
458 func TestMuxRedirectLeadingSlashes(t *testing.T) {
459 setParallel(t)
460 paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
461 for _, path := range paths {
462 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
463 if err != nil {
464 t.Errorf("%s", err)
465 }
466 mux := NewServeMux()
467 resp := httptest.NewRecorder()
468
469 mux.ServeHTTP(resp, req)
470
471 if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected {
472 t.Errorf("Expected Location header set to %q; got %q", expected, loc)
473 return
474 }
475
476 if code, expected := resp.Code, StatusMovedPermanently; code != expected {
477 t.Errorf("Expected response code of StatusMovedPermanently; got %d", code)
478 return
479 }
480 }
481 }
482
483
484
485
486
487 func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) {
488 setParallel(t)
489 defer afterTest(t)
490
491 writeBackQuery := func(w ResponseWriter, r *Request) {
492 fmt.Fprintf(w, "%s", r.URL.RawQuery)
493 }
494
495 mux := NewServeMux()
496 mux.HandleFunc("/testOne", writeBackQuery)
497 mux.HandleFunc("/testTwo/", writeBackQuery)
498 mux.HandleFunc("/testThree", writeBackQuery)
499 mux.HandleFunc("/testThree/", func(w ResponseWriter, r *Request) {
500 fmt.Fprintf(w, "%s:bar", r.URL.RawQuery)
501 })
502
503 ts := httptest.NewServer(mux)
504 defer ts.Close()
505
506 tests := [...]struct {
507 path string
508 method string
509 want string
510 statusOk bool
511 }{
512 0: {"/testOne?this=that", "GET", "this=that", true},
513 1: {"/testTwo?foo=bar", "GET", "foo=bar", true},
514 2: {"/testTwo?a=1&b=2&a=3", "GET", "a=1&b=2&a=3", true},
515 3: {"/testTwo?", "GET", "", true},
516 4: {"/testThree?foo", "GET", "foo", true},
517 5: {"/testThree/?foo", "GET", "foo:bar", true},
518 6: {"/testThree?foo", "CONNECT", "foo", true},
519 7: {"/testThree/?foo", "CONNECT", "foo:bar", true},
520
521
522 8: {"/testOne/foo/..?foo", "GET", "foo", true},
523 9: {"/testOne/foo/..?foo", "CONNECT", "404 page not found\n", false},
524 }
525
526 for i, tt := range tests {
527 req, _ := NewRequest(tt.method, ts.URL+tt.path, nil)
528 res, err := ts.Client().Do(req)
529 if err != nil {
530 continue
531 }
532 slurp, _ := io.ReadAll(res.Body)
533 res.Body.Close()
534 if !tt.statusOk {
535 if got, want := res.StatusCode, 404; got != want {
536 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
537 }
538 }
539 if got, want := string(slurp), tt.want; got != want {
540 t.Errorf("#%d: Body = %q; want = %q", i, got, want)
541 }
542 }
543 }
544
545 func TestServeWithSlashRedirectForHostPatterns(t *testing.T) {
546 setParallel(t)
547 defer afterTest(t)
548
549 mux := NewServeMux()
550 mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/"))
551 mux.Handle("example.com/pkg/bar", stringHandler("example.com/pkg/bar"))
552 mux.Handle("example.com/pkg/bar/", stringHandler("example.com/pkg/bar/"))
553 mux.Handle("example.com:3000/pkg/connect/", stringHandler("example.com:3000/pkg/connect/"))
554 mux.Handle("example.com:9000/", stringHandler("example.com:9000/"))
555 mux.Handle("/pkg/baz/", stringHandler("/pkg/baz/"))
556
557 tests := []struct {
558 method string
559 url string
560 code int
561 loc string
562 want string
563 }{
564 {"GET", "http://example.com/", 404, "", ""},
565 {"GET", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
566 {"GET", "http://example.com/pkg/bar", 200, "", "example.com/pkg/bar"},
567 {"GET", "http://example.com/pkg/bar/", 200, "", "example.com/pkg/bar/"},
568 {"GET", "http://example.com/pkg/baz", 301, "/pkg/baz/", ""},
569 {"GET", "http://example.com:3000/pkg/foo", 301, "/pkg/foo/", ""},
570 {"CONNECT", "http://example.com/", 404, "", ""},
571 {"CONNECT", "http://example.com:3000/", 404, "", ""},
572 {"CONNECT", "http://example.com:9000/", 200, "", "example.com:9000/"},
573 {"CONNECT", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
574 {"CONNECT", "http://example.com:3000/pkg/foo", 404, "", ""},
575 {"CONNECT", "http://example.com:3000/pkg/baz", 301, "/pkg/baz/", ""},
576 {"CONNECT", "http://example.com:3000/pkg/connect", 301, "/pkg/connect/", ""},
577 }
578
579 ts := httptest.NewServer(mux)
580 defer ts.Close()
581
582 for i, tt := range tests {
583 req, _ := NewRequest(tt.method, tt.url, nil)
584 w := httptest.NewRecorder()
585 mux.ServeHTTP(w, req)
586
587 if got, want := w.Code, tt.code; got != want {
588 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
589 }
590
591 if tt.code == 301 {
592 if got, want := w.HeaderMap.Get("Location"), tt.loc; got != want {
593 t.Errorf("#%d: Location = %q; want = %q", i, got, want)
594 }
595 } else {
596 if got, want := w.HeaderMap.Get("Result"), tt.want; got != want {
597 t.Errorf("#%d: Result = %q; want = %q", i, got, want)
598 }
599 }
600 }
601 }
602
603 func TestShouldRedirectConcurrency(t *testing.T) {
604 setParallel(t)
605 defer afterTest(t)
606
607 mux := NewServeMux()
608 ts := httptest.NewServer(mux)
609 defer ts.Close()
610 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {})
611 }
612
613 func BenchmarkServeMux(b *testing.B) { benchmarkServeMux(b, true) }
614 func BenchmarkServeMux_SkipServe(b *testing.B) { benchmarkServeMux(b, false) }
615 func benchmarkServeMux(b *testing.B, runHandler bool) {
616 type test struct {
617 path string
618 code int
619 req *Request
620 }
621
622
623 var tests []test
624 endpoints := []string{"search", "dir", "file", "change", "count", "s"}
625 for _, e := range endpoints {
626 for i := 200; i < 230; i++ {
627 p := fmt.Sprintf("/%s/%d/", e, i)
628 tests = append(tests, test{
629 path: p,
630 code: i,
631 req: &Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: p}},
632 })
633 }
634 }
635 mux := NewServeMux()
636 for _, tt := range tests {
637 mux.Handle(tt.path, serve(tt.code))
638 }
639
640 rw := httptest.NewRecorder()
641 b.ReportAllocs()
642 b.ResetTimer()
643 for i := 0; i < b.N; i++ {
644 for _, tt := range tests {
645 *rw = httptest.ResponseRecorder{}
646 h, pattern := mux.Handler(tt.req)
647 if runHandler {
648 h.ServeHTTP(rw, tt.req)
649 if pattern != tt.path || rw.Code != tt.code {
650 b.Fatalf("got %d, %q, want %d, %q", rw.Code, pattern, tt.code, tt.path)
651 }
652 }
653 }
654 }
655 }
656
657 func TestServerTimeouts(t *testing.T) {
658 setParallel(t)
659 defer afterTest(t)
660
661 tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
662 for i, timeout := range tries {
663 err := testServerTimeouts(timeout)
664 if err == nil {
665 return
666 }
667 t.Logf("failed at %v: %v", timeout, err)
668 if i != len(tries)-1 {
669 t.Logf("retrying at %v ...", tries[i+1])
670 }
671 }
672 t.Fatal("all attempts failed")
673 }
674
// testServerTimeouts starts a server whose ReadTimeout and WriteTimeout are
// both set to timeout and checks, in order:
//
//  1. a normal GET succeeds and returns "req=1";
//  2. a connection that sends nothing is closed by the server (the client
//     reads EOF) no sooner than 4/5 of timeout after dialing;
//  3. a second normal GET still succeeds and returns "req=2";
//  4. unless running in -short mode, five requests written on a single
//     connection, timeout/2 apart, can all be written without error.
//
// It returns a descriptive error for the first check that fails, or nil.
func testServerTimeouts(timeout time.Duration) error {
	reqNum := 0
	ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {
		reqNum++
		fmt.Fprintf(res, "req=%d", reqNum)
	}))
	ts.Config.ReadTimeout = timeout
	ts.Config.WriteTimeout = timeout
	ts.Start()
	defer ts.Close()

	// 1. Plain request.
	c := ts.Client()
	r, err := c.Get(ts.URL)
	if err != nil {
		return fmt.Errorf("http Get #1: %v", err)
	}
	got, err := io.ReadAll(r.Body)
	// The body of the first response was previously never closed.
	r.Body.Close()
	expected := "req=1"
	if string(got) != expected || err != nil {
		return fmt.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil",
			string(got), err, expected)
	}

	// 2. Silent client: dial, send nothing, wait for the server to hang up.
	t1 := time.Now()
	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		return fmt.Errorf("Dial: %v", err)
	}
	buf := make([]byte, 1)
	n, err := conn.Read(buf)
	conn.Close()
	latency := time.Since(t1)
	if n != 0 || err != io.EOF {
		return fmt.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
	}
	minLatency := timeout / 5 * 4
	if latency < minLatency {
		return fmt.Errorf("got EOF after %s, want >= %s", latency, minLatency)
	}

	// 3. The server must still answer ordinary requests.
	r, err = c.Get(ts.URL)
	if err != nil {
		return fmt.Errorf("http Get #2: %v", err)
	}
	got, err = io.ReadAll(r.Body)
	r.Body.Close()
	expected = "req=2"
	if string(got) != expected || err != nil {
		return fmt.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected)
	}

	// 4. Several requests on one connection, spaced timeout/2 apart.
	if !testing.Short() {
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			return fmt.Errorf("long Dial: %v", err)
		}
		defer conn.Close()
		go io.Copy(io.Discard, conn)
		for i := 0; i < 5; i++ {
			_, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
			if err != nil {
				return fmt.Errorf("on write %d: %v", i, err)
			}
			time.Sleep(timeout / 2)
		}
	}
	return nil
}
748
749
750 func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) {
751 if testing.Short() {
752 t.Skip("skipping in short mode")
753 }
754 setParallel(t)
755 defer afterTest(t)
756 ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {}))
757 ts.Config.WriteTimeout = 250 * time.Millisecond
758 ts.TLS = &tls.Config{NextProtos: []string{"h2"}}
759 ts.StartTLS()
760 defer ts.Close()
761
762 c := ts.Client()
763 if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil {
764 t.Fatal(err)
765 }
766
767 for i := 1; i <= 3; i++ {
768 req, err := NewRequest("GET", ts.URL, nil)
769 if err != nil {
770 t.Fatal(err)
771 }
772
773
774 ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
775 defer cancel()
776 req = req.WithContext(ctx)
777
778 r, err := c.Do(req)
779 if ctx.Err() == context.DeadlineExceeded {
780 t.Fatalf("http2 Get #%d response timed out", i)
781 }
782 if err != nil {
783 t.Fatalf("http2 Get #%d: %v", i, err)
784 }
785 r.Body.Close()
786 if r.ProtoMajor != 2 {
787 t.Fatalf("http2 Get expected HTTP/2.0, got %q", r.Proto)
788 }
789 time.Sleep(ts.Config.WriteTimeout / 2)
790 }
791 }
792
793
794
// tryTimeouts calls testFunc with timeouts of 250ms, 500ms and 1s in turn and
// returns as soon as one call succeeds. Each failure is logged; if all three
// attempts fail the test is aborted with t.Fatal.
func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) {
	schedule := [...]time.Duration{
		250 * time.Millisecond,
		500 * time.Millisecond,
		1 * time.Second,
	}
	last := len(schedule) - 1
	for i := range schedule {
		err := testFunc(schedule[i])
		if err == nil {
			return
		}
		t.Logf("failed at %v: %v", schedule[i], err)
		if i < last {
			t.Logf("retrying at %v ...", schedule[i+1])
		}
	}
	t.Fatal("all attempts failed")
}
809
810
811 func TestHTTP2WriteDeadlineEnforcedPerStream(t *testing.T) {
812 if testing.Short() {
813 t.Skip("skipping in short mode")
814 }
815 setParallel(t)
816 defer afterTest(t)
817 tryTimeouts(t, testHTTP2WriteDeadlineEnforcedPerStream)
818 }
819
820 func testHTTP2WriteDeadlineEnforcedPerStream(timeout time.Duration) error {
821 reqNum := 0
822 ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {
823 reqNum++
824 if reqNum == 1 {
825 return
826 }
827 time.Sleep(timeout)
828 }))
829 ts.Config.WriteTimeout = timeout / 2
830 ts.TLS = &tls.Config{NextProtos: []string{"h2"}}
831 ts.StartTLS()
832 defer ts.Close()
833
834 c := ts.Client()
835 if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil {
836 return fmt.Errorf("ExportHttp2ConfigureTransport: %v", err)
837 }
838
839 req, err := NewRequest("GET", ts.URL, nil)
840 if err != nil {
841 return fmt.Errorf("NewRequest: %v", err)
842 }
843 r, err := c.Do(req)
844 if err != nil {
845 return fmt.Errorf("http2 Get #1: %v", err)
846 }
847 r.Body.Close()
848 if r.ProtoMajor != 2 {
849 return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto)
850 }
851
852 req, err = NewRequest("GET", ts.URL, nil)
853 if err != nil {
854 return fmt.Errorf("NewRequest: %v", err)
855 }
856 r, err = c.Do(req)
857 if err == nil {
858 r.Body.Close()
859 if r.ProtoMajor != 2 {
860 return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto)
861 }
862 return fmt.Errorf("http2 Get #2 expected error, got nil")
863 }
864 expected := "stream ID 3; INTERNAL_ERROR"
865 if !strings.Contains(err.Error(), expected) {
866 return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err)
867 }
868 return nil
869 }
870
871
872 func TestHTTP2NoWriteDeadline(t *testing.T) {
873 if testing.Short() {
874 t.Skip("skipping in short mode")
875 }
876 setParallel(t)
877 defer afterTest(t)
878 tryTimeouts(t, testHTTP2NoWriteDeadline)
879 }
880
881 func testHTTP2NoWriteDeadline(timeout time.Duration) error {
882 reqNum := 0
883 ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {
884 reqNum++
885 if reqNum == 1 {
886 return
887 }
888 time.Sleep(timeout)
889 }))
890 ts.TLS = &tls.Config{NextProtos: []string{"h2"}}
891 ts.StartTLS()
892 defer ts.Close()
893
894 c := ts.Client()
895 if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil {
896 return fmt.Errorf("ExportHttp2ConfigureTransport: %v", err)
897 }
898
899 for i := 0; i < 2; i++ {
900 req, err := NewRequest("GET", ts.URL, nil)
901 if err != nil {
902 return fmt.Errorf("NewRequest: %v", err)
903 }
904 r, err := c.Do(req)
905 if err != nil {
906 return fmt.Errorf("http2 Get #%d: %v", i, err)
907 }
908 r.Body.Close()
909 if r.ProtoMajor != 2 {
910 return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto)
911 }
912 }
913 return nil
914 }
915
916
917
918
919 func TestOnlyWriteTimeout(t *testing.T) {
920 setParallel(t)
921 defer afterTest(t)
922 var (
923 mu sync.RWMutex
924 conn net.Conn
925 )
926 var afterTimeoutErrc = make(chan error, 1)
927 ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, req *Request) {
928 buf := make([]byte, 512<<10)
929 _, err := w.Write(buf)
930 if err != nil {
931 t.Errorf("handler Write error: %v", err)
932 return
933 }
934 mu.RLock()
935 defer mu.RUnlock()
936 if conn == nil {
937 t.Error("no established connection found")
938 return
939 }
940 conn.SetWriteDeadline(time.Now().Add(-30 * time.Second))
941 _, err = w.Write(buf)
942 afterTimeoutErrc <- err
943 }))
944 ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn}
945 ts.Start()
946 defer ts.Close()
947
948 c := ts.Client()
949
950 errc := make(chan error, 1)
951 go func() {
952 res, err := c.Get(ts.URL)
953 if err != nil {
954 errc <- err
955 return
956 }
957 _, err = io.Copy(io.Discard, res.Body)
958 res.Body.Close()
959 errc <- err
960 }()
961 select {
962 case err := <-errc:
963 if err == nil {
964 t.Errorf("expected an error from Get request")
965 }
966 case <-time.After(10 * time.Second):
967 t.Fatal("timeout waiting for Get error")
968 }
969 if err := <-afterTimeoutErrc; err == nil {
970 t.Error("expected write error after timeout")
971 }
972 }
973
974
// trackLastConnListener wraps a net.Listener and stores the most recently
// accepted connection in *last, guarded by mu.
type trackLastConnListener struct {
	net.Listener

	mu   *sync.RWMutex
	last *net.Conn // destination for the latest accepted connection
}

// Accept accepts from the wrapped listener. On success the connection is
// recorded in *last under the write lock; on failure *last is left alone.
// The wrapped listener's results are returned unchanged either way.
func (l trackLastConnListener) Accept() (c net.Conn, err error) {
	if c, err = l.Listener.Accept(); err != nil {
		return c, err
	}
	l.mu.Lock()
	*l.last = c
	l.mu.Unlock()
	return c, nil
}
991
992
993 func TestIdentityResponse(t *testing.T) {
994 setParallel(t)
995 defer afterTest(t)
996 handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
997 rw.Header().Set("Content-Length", "3")
998 rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
999 switch {
1000 case req.FormValue("overwrite") == "1":
1001 _, err := rw.Write([]byte("foo TOO LONG"))
1002 if err != ErrContentLength {
1003 t.Errorf("expected ErrContentLength; got %v", err)
1004 }
1005 case req.FormValue("underwrite") == "1":
1006 rw.Header().Set("Content-Length", "500")
1007 rw.Write([]byte("too short"))
1008 default:
1009 rw.Write([]byte("foo"))
1010 }
1011 })
1012
1013 ts := httptest.NewServer(handler)
1014 defer ts.Close()
1015
1016 c := ts.Client()
1017
1018
1019
1020
1021
1022 for _, te := range []string{"", "identity"} {
1023 url := ts.URL + "/?te=" + te
1024 res, err := c.Get(url)
1025 if err != nil {
1026 t.Fatalf("error with Get of %s: %v", url, err)
1027 }
1028 if cl, expected := res.ContentLength, int64(3); cl != expected {
1029 t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl)
1030 }
1031 if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected {
1032 t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl)
1033 }
1034 if tl, expected := len(res.TransferEncoding), 0; tl != expected {
1035 t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)",
1036 url, expected, tl, res.TransferEncoding)
1037 }
1038 res.Body.Close()
1039 }
1040
1041
1042 url := ts.URL + "/?overwrite=1"
1043 res, err := c.Get(url)
1044 if err != nil {
1045 t.Fatalf("error with Get of %s: %v", url, err)
1046 }
1047 res.Body.Close()
1048
1049
1050
1051 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1052 if err != nil {
1053 t.Fatalf("error dialing: %v", err)
1054 }
1055 _, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n"))
1056 if err != nil {
1057 t.Fatalf("error writing: %v", err)
1058 }
1059
1060
1061 got, _ := io.ReadAll(conn)
1062 expectedSuffix := "\r\n\r\ntoo short"
1063 if !strings.HasSuffix(string(got), expectedSuffix) {
1064 t.Errorf("Expected output to end with %q; got response body %q",
1065 expectedSuffix, string(got))
1066 }
1067 }
1068
1069 func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
1070 setParallel(t)
1071 defer afterTest(t)
1072 s := httptest.NewServer(h)
1073 defer s.Close()
1074
1075 conn, err := net.Dial("tcp", s.Listener.Addr().String())
1076 if err != nil {
1077 t.Fatal("dial error:", err)
1078 }
1079 defer conn.Close()
1080
1081 _, err = fmt.Fprint(conn, req)
1082 if err != nil {
1083 t.Fatal("print error:", err)
1084 }
1085
1086 r := bufio.NewReader(conn)
1087 res, err := ReadResponse(r, &Request{Method: "GET"})
1088 if err != nil {
1089 t.Fatal("ReadResponse error:", err)
1090 }
1091
1092 didReadAll := make(chan bool, 1)
1093 go func() {
1094 select {
1095 case <-time.After(5 * time.Second):
1096 t.Error("body not closed after 5s")
1097 return
1098 case <-didReadAll:
1099 }
1100 }()
1101
1102 _, err = io.ReadAll(r)
1103 if err != nil {
1104 t.Fatal("read error:", err)
1105 }
1106 didReadAll <- true
1107
1108 if !res.Close {
1109 t.Errorf("Response.Close = false; want true")
1110 }
1111 }
1112
1113 func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) {
1114 setParallel(t)
1115 defer afterTest(t)
1116 ts := httptest.NewServer(handler)
1117 defer ts.Close()
1118 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1119 if err != nil {
1120 t.Fatal(err)
1121 }
1122 defer conn.Close()
1123 br := bufio.NewReader(conn)
1124 for i := 0; i < 2; i++ {
1125 if _, err := io.WriteString(conn, req); err != nil {
1126 t.Fatal(err)
1127 }
1128 res, err := ReadResponse(br, nil)
1129 if err != nil {
1130 t.Fatalf("res %d: %v", i+1, err)
1131 }
1132 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1133 t.Fatalf("res %d body copy: %v", i+1, err)
1134 }
1135 res.Body.Close()
1136 }
1137 }
1138
1139
1140 func TestServeHTTP10Close(t *testing.T) {
1141 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1142 ServeFile(w, r, "testdata/file")
1143 }))
1144 }
1145
1146
1147 func TestClientCanClose(t *testing.T) {
1148 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1149
1150 }))
1151 }
1152
1153
1154
1155 func TestHandlersCanSetConnectionClose11(t *testing.T) {
1156 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1157 w.Header().Set("Connection", "close")
1158 }))
1159 }
1160
1161 func TestHandlersCanSetConnectionClose10(t *testing.T) {
1162 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1163 w.Header().Set("Connection", "close")
1164 }))
1165 }
1166
1167 func TestHTTP2UpgradeClosesConnection(t *testing.T) {
1168 testTCPConnectionCloses(t, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1169
1170
1171 }))
1172 }
1173
// send204 replies with status 204 (StatusNoContent).
func send204(w ResponseWriter, r *Request) {
	w.WriteHeader(StatusNoContent)
}

// send304 replies with status 304 (StatusNotModified).
func send304(w ResponseWriter, r *Request) {
	w.WriteHeader(StatusNotModified)
}
1176
1177
1178 func TestHTTP10KeepAlive204Response(t *testing.T) {
1179 testTCPConnectionStaysOpen(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(send204))
1180 }
1181
1182 func TestHTTP11KeepAlive204Response(t *testing.T) {
1183 testTCPConnectionStaysOpen(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n", HandlerFunc(send204))
1184 }
1185
1186 func TestHTTP10KeepAlive304Response(t *testing.T) {
1187 testTCPConnectionStaysOpen(t,
1188 "GET / HTTP/1.0\r\nConnection: keep-alive\r\nIf-Modified-Since: Mon, 02 Jan 2006 15:04:05 GMT\r\n\r\n",
1189 HandlerFunc(send304))
1190 }
1191
1192
1193 func TestKeepAliveFinalChunkWithEOF(t *testing.T) {
1194 setParallel(t)
1195 defer afterTest(t)
1196 cst := newClientServerTest(t, false , HandlerFunc(func(w ResponseWriter, r *Request) {
1197 w.(Flusher).Flush()
1198 w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}"))
1199 }))
1200 defer cst.close()
1201 type data struct {
1202 Addr string
1203 }
1204 var addrs [2]data
1205 for i := range addrs {
1206 res, err := cst.c.Get(cst.ts.URL)
1207 if err != nil {
1208 t.Fatal(err)
1209 }
1210 if err := json.NewDecoder(res.Body).Decode(&addrs[i]); err != nil {
1211 t.Fatal(err)
1212 }
1213 if addrs[i].Addr == "" {
1214 t.Fatal("no address")
1215 }
1216 res.Body.Close()
1217 }
1218 if addrs[0] != addrs[1] {
1219 t.Fatalf("connection not reused")
1220 }
1221 }
1222
1223 func TestSetsRemoteAddr_h1(t *testing.T) { testSetsRemoteAddr(t, h1Mode) }
1224 func TestSetsRemoteAddr_h2(t *testing.T) { testSetsRemoteAddr(t, h2Mode) }
1225
1226 func testSetsRemoteAddr(t *testing.T, h2 bool) {
1227 setParallel(t)
1228 defer afterTest(t)
1229 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1230 fmt.Fprintf(w, "%s", r.RemoteAddr)
1231 }))
1232 defer cst.close()
1233
1234 res, err := cst.c.Get(cst.ts.URL)
1235 if err != nil {
1236 t.Fatalf("Get error: %v", err)
1237 }
1238 body, err := io.ReadAll(res.Body)
1239 if err != nil {
1240 t.Fatalf("ReadAll error: %v", err)
1241 }
1242 ip := string(body)
1243 if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") {
1244 t.Fatalf("Expected local addr; got %q", ip)
1245 }
1246 }
1247
1248 type blockingRemoteAddrListener struct {
1249 net.Listener
1250 conns chan<- net.Conn
1251 }
1252
1253 func (l *blockingRemoteAddrListener) Accept() (net.Conn, error) {
1254 c, err := l.Listener.Accept()
1255 if err != nil {
1256 return nil, err
1257 }
1258 brac := &blockingRemoteAddrConn{
1259 Conn: c,
1260 addrs: make(chan net.Addr, 1),
1261 }
1262 l.conns <- brac
1263 return brac, nil
1264 }
1265
// blockingRemoteAddrConn is a net.Conn whose RemoteAddr blocks until an
// address is delivered on addrs.
type blockingRemoteAddrConn struct {
	net.Conn
	addrs chan net.Addr
}

// RemoteAddr receives one value from c.addrs and returns it, blocking until
// a value is available.
func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr {
	addr := <-c.addrs
	return addr
}
1274
1275
1276 func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
1277 defer afterTest(t)
1278 ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
1279 fmt.Fprintf(w, "RA:%s", r.RemoteAddr)
1280 }))
1281 conns := make(chan net.Conn)
1282 ts.Listener = &blockingRemoteAddrListener{
1283 Listener: ts.Listener,
1284 conns: conns,
1285 }
1286 ts.Start()
1287 defer ts.Close()
1288
1289 c := ts.Client()
1290 c.Timeout = time.Second
1291
1292 c.Transport.(*Transport).DisableKeepAlives = true
1293
1294 fetch := func(num int, response chan<- string) {
1295 resp, err := c.Get(ts.URL)
1296 if err != nil {
1297 t.Errorf("Request %d: %v", num, err)
1298 response <- ""
1299 return
1300 }
1301 defer resp.Body.Close()
1302 body, err := io.ReadAll(resp.Body)
1303 if err != nil {
1304 t.Errorf("Request %d: %v", num, err)
1305 response <- ""
1306 return
1307 }
1308 response <- string(body)
1309 }
1310
1311
1312 response1c := make(chan string, 1)
1313 go fetch(1, response1c)
1314
1315
1316 conn1 := <-conns
1317
1318
1319 response2c := make(chan string, 1)
1320 go fetch(2, response2c)
1321 var conn2 net.Conn
1322
1323 select {
1324 case conn2 = <-conns:
1325 case <-time.After(time.Second):
1326 t.Fatal("Second Accept didn't happen")
1327 }
1328
1329
1330 conn2.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
1331 IP: net.ParseIP("12.12.12.12"), Port: 12}
1332
1333
1334 response2 := <-response2c
1335 if g, e := response2, "RA:12.12.12.12:12"; g != e {
1336 t.Fatalf("response 2 addr = %q; want %q", g, e)
1337 }
1338
1339
1340 conn1.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
1341 IP: net.ParseIP("21.21.21.21"), Port: 21}
1342
1343
1344 response1 := <-response1c
1345 if g, e := response1, "RA:21.21.21.21:21"; g != e {
1346 t.Fatalf("response 1 addr = %q; want %q", g, e)
1347 }
1348 }
1349
1350
1351
1352 func TestHeadResponses_h1(t *testing.T) { testHeadResponses(t, h1Mode) }
1353 func TestHeadResponses_h2(t *testing.T) { testHeadResponses(t, h2Mode) }
1354
1355 func testHeadResponses(t *testing.T, h2 bool) {
1356 setParallel(t)
1357 defer afterTest(t)
1358 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1359 _, err := w.Write([]byte("<html>"))
1360 if err != nil {
1361 t.Errorf("ResponseWriter.Write: %v", err)
1362 }
1363
1364
1365 _, err = io.Copy(w, strings.NewReader("789a"))
1366 if err != nil {
1367 t.Errorf("Copy(ResponseWriter, ...): %v", err)
1368 }
1369 }))
1370 defer cst.close()
1371 res, err := cst.c.Head(cst.ts.URL)
1372 if err != nil {
1373 t.Error(err)
1374 }
1375 if len(res.TransferEncoding) > 0 {
1376 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
1377 }
1378 if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" {
1379 t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct)
1380 }
1381 if v := res.ContentLength; v != 10 {
1382 t.Errorf("Content-Length: %d; want 10", v)
1383 }
1384 body, err := io.ReadAll(res.Body)
1385 if err != nil {
1386 t.Error(err)
1387 }
1388 if len(body) > 0 {
1389 t.Errorf("got unexpected body %q", string(body))
1390 }
1391 }
1392
1393 func TestTLSHandshakeTimeout(t *testing.T) {
1394 setParallel(t)
1395 defer afterTest(t)
1396 ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
1397 errc := make(chanWriter, 10)
1398 ts.Config.ReadTimeout = 250 * time.Millisecond
1399 ts.Config.ErrorLog = log.New(errc, "", 0)
1400 ts.StartTLS()
1401 defer ts.Close()
1402 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1403 if err != nil {
1404 t.Fatalf("Dial: %v", err)
1405 }
1406 defer conn.Close()
1407
1408 var buf [1]byte
1409 n, err := conn.Read(buf[:])
1410 if err == nil || n != 0 {
1411 t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
1412 }
1413
1414 select {
1415 case v := <-errc:
1416 if !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") {
1417 t.Errorf("expected a TLS handshake timeout error; got %q", v)
1418 }
1419 case <-time.After(5 * time.Second):
1420 t.Errorf("timeout waiting for logged error")
1421 }
1422 }
1423
1424 func TestTLSServer(t *testing.T) {
1425 setParallel(t)
1426 defer afterTest(t)
1427 ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
1428 if r.TLS != nil {
1429 w.Header().Set("X-TLS-Set", "true")
1430 if r.TLS.HandshakeComplete {
1431 w.Header().Set("X-TLS-HandshakeComplete", "true")
1432 }
1433 }
1434 }))
1435 ts.Config.ErrorLog = log.New(io.Discard, "", 0)
1436 defer ts.Close()
1437
1438
1439
1440
1441
1442
1443 idleConn, err := net.Dial("tcp", ts.Listener.Addr().String())
1444 if err != nil {
1445 t.Fatalf("Dial: %v", err)
1446 }
1447 defer idleConn.Close()
1448
1449 if !strings.HasPrefix(ts.URL, "https://") {
1450 t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
1451 return
1452 }
1453 client := ts.Client()
1454 res, err := client.Get(ts.URL)
1455 if err != nil {
1456 t.Error(err)
1457 return
1458 }
1459 if res == nil {
1460 t.Errorf("got nil Response")
1461 return
1462 }
1463 defer res.Body.Close()
1464 if res.Header.Get("X-TLS-Set") != "true" {
1465 t.Errorf("expected X-TLS-Set response header")
1466 return
1467 }
1468 if res.Header.Get("X-TLS-HandshakeComplete") != "true" {
1469 t.Errorf("expected X-TLS-HandshakeComplete header")
1470 }
1471 }
1472
1473 func TestServeTLS(t *testing.T) {
1474 CondSkipHTTP2(t)
1475
1476 defer afterTest(t)
1477 defer SetTestHookServerServe(nil)
1478
1479 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1480 if err != nil {
1481 t.Fatal(err)
1482 }
1483 tlsConf := &tls.Config{
1484 Certificates: []tls.Certificate{cert},
1485 }
1486
1487 ln := newLocalListener(t)
1488 defer ln.Close()
1489 addr := ln.Addr().String()
1490
1491 serving := make(chan bool, 1)
1492 SetTestHookServerServe(func(s *Server, ln net.Listener) {
1493 serving <- true
1494 })
1495 handler := HandlerFunc(func(w ResponseWriter, r *Request) {})
1496 s := &Server{
1497 Addr: addr,
1498 TLSConfig: tlsConf,
1499 Handler: handler,
1500 }
1501 errc := make(chan error, 1)
1502 go func() { errc <- s.ServeTLS(ln, "", "") }()
1503 select {
1504 case err := <-errc:
1505 t.Fatalf("ServeTLS: %v", err)
1506 case <-serving:
1507 case <-time.After(5 * time.Second):
1508 t.Fatal("timeout")
1509 }
1510
1511 c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
1512 InsecureSkipVerify: true,
1513 NextProtos: []string{"h2", "http/1.1"},
1514 })
1515 if err != nil {
1516 t.Fatal(err)
1517 }
1518 defer c.Close()
1519 if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
1520 t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
1521 }
1522 if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
1523 t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
1524 }
1525 }
1526
1527
1528 func TestTLSServerRejectHTTPRequests(t *testing.T) {
1529 setParallel(t)
1530 defer afterTest(t)
1531 ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
1532 t.Error("unexpected HTTPS request")
1533 }))
1534 var errBuf bytes.Buffer
1535 ts.Config.ErrorLog = log.New(&errBuf, "", 0)
1536 defer ts.Close()
1537 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1538 if err != nil {
1539 t.Fatal(err)
1540 }
1541 defer conn.Close()
1542 io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
1543 slurp, err := io.ReadAll(conn)
1544 if err != nil {
1545 t.Fatal(err)
1546 }
1547 const wantPrefix = "HTTP/1.0 400 Bad Request\r\n"
1548 if !strings.HasPrefix(string(slurp), wantPrefix) {
1549 t.Errorf("response = %q; wanted prefix %q", slurp, wantPrefix)
1550 }
1551 }
1552
1553
1554 func TestAutomaticHTTP2_Serve_NoTLSConfig(t *testing.T) {
1555 testAutomaticHTTP2_Serve(t, nil, true)
1556 }
1557
1558 func TestAutomaticHTTP2_Serve_NonH2TLSConfig(t *testing.T) {
1559 testAutomaticHTTP2_Serve(t, &tls.Config{}, false)
1560 }
1561
1562 func TestAutomaticHTTP2_Serve_H2TLSConfig(t *testing.T) {
1563 testAutomaticHTTP2_Serve(t, &tls.Config{NextProtos: []string{"h2"}}, true)
1564 }
1565
1566 func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) {
1567 setParallel(t)
1568 defer afterTest(t)
1569 ln := newLocalListener(t)
1570 ln.Close()
1571 var s Server
1572 s.TLSConfig = tlsConf
1573 if err := s.Serve(ln); err == nil {
1574 t.Fatal("expected an error")
1575 }
1576 gotH2 := s.TLSNextProto["h2"] != nil
1577 if gotH2 != wantH2 {
1578 t.Errorf("http2 configured = %v; want %v", gotH2, wantH2)
1579 }
1580 }
1581
1582 func TestAutomaticHTTP2_Serve_WithTLSConfig(t *testing.T) {
1583 setParallel(t)
1584 defer afterTest(t)
1585 ln := newLocalListener(t)
1586 ln.Close()
1587 var s Server
1588
1589
1590 s.TLSConfig = &tls.Config{
1591 NextProtos: []string{"h2"},
1592 }
1593 if err := s.Serve(ln); err == nil {
1594 t.Fatal("expected an error")
1595 }
1596 on := s.TLSNextProto["h2"] != nil
1597 if !on {
1598 t.Errorf("http2 wasn't automatically enabled")
1599 }
1600 }
1601
1602 func TestAutomaticHTTP2_ListenAndServe(t *testing.T) {
1603 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1604 if err != nil {
1605 t.Fatal(err)
1606 }
1607 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1608 Certificates: []tls.Certificate{cert},
1609 })
1610 }
1611
1612 func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) {
1613 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1614 if err != nil {
1615 t.Fatal(err)
1616 }
1617 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1618 GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
1619 return &cert, nil
1620 },
1621 })
1622 }
1623
1624 func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) {
1625 CondSkipHTTP2(t)
1626
1627 defer afterTest(t)
1628 defer SetTestHookServerServe(nil)
1629 var ok bool
1630 var s *Server
1631 const maxTries = 5
1632 var ln net.Listener
1633 Try:
1634 for try := 0; try < maxTries; try++ {
1635 ln = newLocalListener(t)
1636 addr := ln.Addr().String()
1637 ln.Close()
1638 t.Logf("Got %v", addr)
1639 lnc := make(chan net.Listener, 1)
1640 SetTestHookServerServe(func(s *Server, ln net.Listener) {
1641 lnc <- ln
1642 })
1643 s = &Server{
1644 Addr: addr,
1645 TLSConfig: tlsConf,
1646 }
1647 errc := make(chan error, 1)
1648 go func() { errc <- s.ListenAndServeTLS("", "") }()
1649 select {
1650 case err := <-errc:
1651 t.Logf("On try #%v: %v", try+1, err)
1652 continue
1653 case ln = <-lnc:
1654 ok = true
1655 t.Logf("Listening on %v", ln.Addr().String())
1656 break Try
1657 }
1658 }
1659 if !ok {
1660 t.Fatalf("Failed to start up after %d tries", maxTries)
1661 }
1662 defer ln.Close()
1663 c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
1664 InsecureSkipVerify: true,
1665 NextProtos: []string{"h2", "http/1.1"},
1666 })
1667 if err != nil {
1668 t.Fatal(err)
1669 }
1670 defer c.Close()
1671 if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
1672 t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
1673 }
1674 if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
1675 t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
1676 }
1677 }
1678
// serverExpectTest describes one request sent by TestServerExpect and the
// status text expected on the first response line.
type serverExpectTest struct {
	contentLength    int    // value used for the Content-Length header and the body size
	chunked          bool   // send "Transfer-Encoding: chunked" instead of Content-Length
	expectation      string // value of the Expect request header
	readBody         bool   // value of the readbody query parameter
	expectedResponse string // substring required in the first response line
}

// expectTest builds a non-chunked serverExpectTest from its four remaining
// fields.
func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest {
	var tc serverExpectTest
	tc.contentLength = contentLength
	tc.expectation = expectation
	tc.readBody = readBody
	tc.expectedResponse = expectedResponse
	return tc
}
1695
// serverExpectTests is the table driven by TestServerExpect. Each entry
// gives the request's body length, Expect header value, the readbody query
// value, and the status text required in the first response line.
var serverExpectTests = []serverExpectTest{
	// 100-continue expectation (in two letter casings), readbody=true.
	expectTest(100, "100-continue", true, "100 Continue"),
	expectTest(100, "100-cOntInUE", true, "100 Continue"),

	// Empty Expect header, readbody=true.
	expectTest(100, "", true, "200 OK"),

	// 100-continue expectation, readbody=false (the handler replies 401
	// without reading the body).
	expectTest(100, "100-continue", false, "401 Unauthorized"),
	// The same without an expectation.
	expectTest(100, "", false, "401 Unauthorized"),

	// An expectation other than 100-continue.
	expectTest(0, "a-pony", false, "417 Expectation Failed"),

	// 100-continue expectation with a zero Content-Length, readbody=true.
	expectTest(0, "100-continue", true, "200 OK"),
	// 100-continue expectation with a zero Content-Length, readbody=false.
	expectTest(0, "100-continue", false, "401 Unauthorized"),
	// 100-continue expectation with a chunked request and no Content-Length.
	{
		expectation:      "100-continue",
		readBody:         true,
		chunked:          true,
		expectedResponse: "100 Continue",
	},
}
1725
1726
1727
1728
1729 func TestServerExpect(t *testing.T) {
1730 setParallel(t)
1731 defer afterTest(t)
1732 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
1733
1734
1735
1736 if strings.Contains(r.URL.RawQuery, "readbody=true") {
1737 io.ReadAll(r.Body)
1738 w.Write([]byte("Hi"))
1739 } else {
1740 w.WriteHeader(StatusUnauthorized)
1741 }
1742 }))
1743 defer ts.Close()
1744
1745 runTest := func(test serverExpectTest) {
1746 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1747 if err != nil {
1748 t.Fatalf("Dial: %v", err)
1749 }
1750 defer conn.Close()
1751
1752
1753
1754 writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue"
1755
1756 go func() {
1757 contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength)
1758 if test.chunked {
1759 contentLen = "Transfer-Encoding: chunked"
1760 }
1761 _, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+
1762 "Connection: close\r\n"+
1763 "%s\r\n"+
1764 "Expect: %s\r\nHost: foo\r\n\r\n",
1765 test.readBody, contentLen, test.expectation)
1766 if err != nil {
1767 t.Errorf("On test %#v, error writing request headers: %v", test, err)
1768 return
1769 }
1770 if writeBody {
1771 var targ io.WriteCloser = struct {
1772 io.Writer
1773 io.Closer
1774 }{
1775 conn,
1776 io.NopCloser(nil),
1777 }
1778 if test.chunked {
1779 targ = httputil.NewChunkedWriter(conn)
1780 }
1781 body := strings.Repeat("A", test.contentLength)
1782 _, err = fmt.Fprint(targ, body)
1783 if err == nil {
1784 err = targ.Close()
1785 }
1786 if err != nil {
1787 if !test.readBody {
1788
1789
1790 t.Logf("On test %#v, acceptable error writing request body: %v", test, err)
1791 return
1792 }
1793 t.Errorf("On test %#v, error writing request body: %v", test, err)
1794 }
1795 }
1796 }()
1797 bufr := bufio.NewReader(conn)
1798 line, err := bufr.ReadString('\n')
1799 if err != nil {
1800 if writeBody && !test.readBody {
1801
1802
1803
1804
1805
1806 t.Logf("On test %#v, acceptable error from ReadString: %v", test, err)
1807 return
1808 }
1809 t.Fatalf("On test %#v, ReadString: %v", test, err)
1810 }
1811 if !strings.Contains(line, test.expectedResponse) {
1812 t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse)
1813 }
1814 }
1815
1816 for _, test := range serverExpectTests {
1817 runTest(test)
1818 }
1819 }
1820
1821
1822
1823 func TestServerUnreadRequestBodyLittle(t *testing.T) {
1824 setParallel(t)
1825 defer afterTest(t)
1826 conn := new(testConn)
1827 body := strings.Repeat("x", 100<<10)
1828 conn.readBuf.Write([]byte(fmt.Sprintf(
1829 "POST / HTTP/1.1\r\n"+
1830 "Host: test\r\n"+
1831 "Content-Length: %d\r\n"+
1832 "\r\n", len(body))))
1833 conn.readBuf.Write([]byte(body))
1834
1835 done := make(chan bool)
1836
1837 readBufLen := func() int {
1838 conn.readMu.Lock()
1839 defer conn.readMu.Unlock()
1840 return conn.readBuf.Len()
1841 }
1842
1843 ls := &oneConnListener{conn}
1844 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
1845 defer close(done)
1846 if bufLen := readBufLen(); bufLen < len(body)/2 {
1847 t.Errorf("on request, read buffer length is %d; expected about 100 KB", bufLen)
1848 }
1849 rw.WriteHeader(200)
1850 rw.(Flusher).Flush()
1851 if g, e := readBufLen(), 0; g != e {
1852 t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e)
1853 }
1854 if c := rw.Header().Get("Connection"); c != "" {
1855 t.Errorf(`Connection header = %q; want ""`, c)
1856 }
1857 }))
1858 <-done
1859 }
1860
1861
1862
1863
1864 func TestServerUnreadRequestBodyLarge(t *testing.T) {
1865 setParallel(t)
1866 if testing.Short() && testenv.Builder() == "" {
1867 t.Log("skipping in short mode")
1868 }
1869 conn := new(testConn)
1870 body := strings.Repeat("x", 1<<20)
1871 conn.readBuf.Write([]byte(fmt.Sprintf(
1872 "POST / HTTP/1.1\r\n"+
1873 "Host: test\r\n"+
1874 "Content-Length: %d\r\n"+
1875 "\r\n", len(body))))
1876 conn.readBuf.Write([]byte(body))
1877 conn.closec = make(chan bool, 1)
1878
1879 ls := &oneConnListener{conn}
1880 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
1881 if conn.readBuf.Len() < len(body)/2 {
1882 t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
1883 }
1884 rw.WriteHeader(200)
1885 rw.(Flusher).Flush()
1886 if conn.readBuf.Len() < len(body)/2 {
1887 t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
1888 }
1889 }))
1890 <-conn.closec
1891
1892 if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") {
1893 t.Errorf("Expected a Connection: close header; got response: %s", res)
1894 }
1895 }
1896
// handlerBodyCloseTest describes one scenario for testHandlerBodyClose: the
// shape of the request body and what is expected after the handler closes
// it.
type handlerBodyCloseTest struct {
	bodySize     int  // number of body bytes in the request
	bodyChunked  bool // send the body chunked instead of with Content-Length
	reqConnClose bool // add "Connection: close" to the request

	wantEOFSearch bool // expect the read buffer size to change across Body.Close
	wantNextReq   bool // expect a second request to reach the handler
}

// connectionHeader returns the "Connection: close" request header line when
// reqConnClose is set, and the empty string otherwise.
func (tt handlerBodyCloseTest) connectionHeader() string {
	header := ""
	if tt.reqConnClose {
		header = "Connection: close\r\n"
	}
	return header
}
1912
// handlerBodyCloseTests is the scenario table for TestHandlerBodyClose.
// Entries are indexed explicitly; the index is reported in failure messages.
var handlerBodyCloseTests = [...]handlerBodyCloseTest{
	// 20 KiB body with Content-Length, no Connection: close.
	0: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// 20 KiB chunked body, no Connection: close.
	1: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// 20 KiB body with Content-Length and Connection: close.
	2: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// 20 KiB chunked body with Connection: close.
	3: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// 1 MiB body with Content-Length, no Connection: close.
	4: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// 1 MiB chunked body, no Connection: close.
	5: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// 1 MiB chunked body with Connection: close.
	6: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// 1 MiB body with Content-Length and Connection: close.
	7: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},
}
1997
1998 func TestHandlerBodyClose(t *testing.T) {
1999 setParallel(t)
2000 if testing.Short() && testenv.Builder() == "" {
2001 t.Skip("skipping in -short mode")
2002 }
2003 for i, tt := range handlerBodyCloseTests {
2004 testHandlerBodyClose(t, i, tt)
2005 }
2006 }
2007
2008 func testHandlerBodyClose(t *testing.T, i int, tt handlerBodyCloseTest) {
2009 conn := new(testConn)
2010 body := strings.Repeat("x", tt.bodySize)
2011 if tt.bodyChunked {
2012 conn.readBuf.WriteString("POST / HTTP/1.1\r\n" +
2013 "Host: test\r\n" +
2014 tt.connectionHeader() +
2015 "Transfer-Encoding: chunked\r\n" +
2016 "\r\n")
2017 cw := internal.NewChunkedWriter(&conn.readBuf)
2018 io.WriteString(cw, body)
2019 cw.Close()
2020 conn.readBuf.WriteString("\r\n")
2021 } else {
2022 conn.readBuf.Write([]byte(fmt.Sprintf(
2023 "POST / HTTP/1.1\r\n"+
2024 "Host: test\r\n"+
2025 tt.connectionHeader()+
2026 "Content-Length: %d\r\n"+
2027 "\r\n", len(body))))
2028 conn.readBuf.Write([]byte(body))
2029 }
2030 if !tt.reqConnClose {
2031 conn.readBuf.WriteString("GET / HTTP/1.1\r\nHost: test\r\n\r\n")
2032 }
2033 conn.closec = make(chan bool, 1)
2034
2035 readBufLen := func() int {
2036 conn.readMu.Lock()
2037 defer conn.readMu.Unlock()
2038 return conn.readBuf.Len()
2039 }
2040
2041 ls := &oneConnListener{conn}
2042 var numReqs int
2043 var size0, size1 int
2044 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2045 numReqs++
2046 if numReqs == 1 {
2047 size0 = readBufLen()
2048 req.Body.Close()
2049 size1 = readBufLen()
2050 }
2051 }))
2052 <-conn.closec
2053 if numReqs < 1 || numReqs > 2 {
2054 t.Fatalf("%d. bug in test. unexpected number of requests = %d", i, numReqs)
2055 }
2056 didSearch := size0 != size1
2057 if didSearch != tt.wantEOFSearch {
2058 t.Errorf("%d. did EOF search = %v; want %v (size went from %d to %d)", i, didSearch, !didSearch, size0, size1)
2059 }
2060 if tt.wantNextReq && numReqs != 2 {
2061 t.Errorf("%d. numReq = %d; want 2", i, numReqs)
2062 }
2063 }
2064
2065
2066
// testHandlerBodyConsumer is a named way for a handler to deal with a
// request body.
type testHandlerBodyConsumer struct {
	name string
	f    func(io.ReadCloser)
}

// testHandlerBodyConsumers lists the body-handling strategies used by the
// connection-closing tests: leave the body alone, close it, or drain it.
var testHandlerBodyConsumers = []testHandlerBodyConsumer{
	{
		name: "nil",
		f:    func(io.ReadCloser) {},
	},
	{
		name: "close",
		f:    func(body io.ReadCloser) { body.Close() },
	},
	{
		name: "discard",
		f:    func(body io.ReadCloser) { io.Copy(io.Discard, body) },
	},
}
2077
2078 func TestRequestBodyReadErrorClosesConnection(t *testing.T) {
2079 setParallel(t)
2080 defer afterTest(t)
2081 for _, handler := range testHandlerBodyConsumers {
2082 conn := new(testConn)
2083 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2084 "Host: test\r\n" +
2085 "Transfer-Encoding: chunked\r\n" +
2086 "\r\n" +
2087 "hax\r\n" +
2088 "GET /secret HTTP/1.1\r\n" +
2089 "Host: test\r\n" +
2090 "\r\n")
2091
2092 conn.closec = make(chan bool, 1)
2093 ls := &oneConnListener{conn}
2094 var numReqs int
2095 go Serve(ls, HandlerFunc(func(_ ResponseWriter, req *Request) {
2096 numReqs++
2097 if strings.Contains(req.URL.Path, "secret") {
2098 t.Error("Request for /secret encountered, should not have happened.")
2099 }
2100 handler.f(req.Body)
2101 }))
2102 <-conn.closec
2103 if numReqs != 1 {
2104 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2105 }
2106 }
2107 }
2108
2109 func TestInvalidTrailerClosesConnection(t *testing.T) {
2110 setParallel(t)
2111 defer afterTest(t)
2112 for _, handler := range testHandlerBodyConsumers {
2113 conn := new(testConn)
2114 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2115 "Host: test\r\n" +
2116 "Trailer: hack\r\n" +
2117 "Transfer-Encoding: chunked\r\n" +
2118 "\r\n" +
2119 "3\r\n" +
2120 "hax\r\n" +
2121 "0\r\n" +
2122 "I'm not a valid trailer\r\n" +
2123 "GET /secret HTTP/1.1\r\n" +
2124 "Host: test\r\n" +
2125 "\r\n")
2126
2127 conn.closec = make(chan bool, 1)
2128 ln := &oneConnListener{conn}
2129 var numReqs int
2130 go Serve(ln, HandlerFunc(func(_ ResponseWriter, req *Request) {
2131 numReqs++
2132 if strings.Contains(req.URL.Path, "secret") {
2133 t.Errorf("Handler %s, Request for /secret encountered, should not have happened.", handler.name)
2134 }
2135 handler.f(req.Body)
2136 }))
2137 <-conn.closec
2138 if numReqs != 1 {
2139 t.Errorf("Handler %s: got %d reqs; want 1", handler.name, numReqs)
2140 }
2141 }
2142 }
2143
2144
2145
2146
2147 type slowTestConn struct {
2148
2149 script []interface{}
2150 closec chan bool
2151
2152 mu sync.Mutex
2153 rd, wd time.Time
2154 noopConn
2155 }
2156
2157 func (c *slowTestConn) SetDeadline(t time.Time) error {
2158 c.SetReadDeadline(t)
2159 c.SetWriteDeadline(t)
2160 return nil
2161 }
2162
2163 func (c *slowTestConn) SetReadDeadline(t time.Time) error {
2164 c.mu.Lock()
2165 defer c.mu.Unlock()
2166 c.rd = t
2167 return nil
2168 }
2169
2170 func (c *slowTestConn) SetWriteDeadline(t time.Time) error {
2171 c.mu.Lock()
2172 defer c.mu.Unlock()
2173 c.wd = t
2174 return nil
2175 }
2176
2177 func (c *slowTestConn) Read(b []byte) (n int, err error) {
2178 c.mu.Lock()
2179 defer c.mu.Unlock()
2180 restart:
2181 if !c.rd.IsZero() && time.Now().After(c.rd) {
2182 return 0, syscall.ETIMEDOUT
2183 }
2184 if len(c.script) == 0 {
2185 return 0, io.EOF
2186 }
2187
2188 switch cue := c.script[0].(type) {
2189 case time.Duration:
2190 if !c.rd.IsZero() {
2191
2192
2193 if remaining := time.Until(c.rd); remaining < cue {
2194 c.script[0] = cue - remaining
2195 time.Sleep(remaining)
2196 return 0, syscall.ETIMEDOUT
2197 }
2198 }
2199 c.script = c.script[1:]
2200 time.Sleep(cue)
2201 goto restart
2202
2203 case string:
2204 n = copy(b, cue)
2205
2206 if len(cue) > n {
2207 c.script[0] = cue[n:]
2208 } else {
2209 c.script = c.script[1:]
2210 }
2211
2212 default:
2213 panic("unknown cue in slowTestConn script")
2214 }
2215
2216 return
2217 }
2218
2219 func (c *slowTestConn) Close() error {
2220 select {
2221 case c.closec <- true:
2222 default:
2223 }
2224 return nil
2225 }
2226
2227 func (c *slowTestConn) Write(b []byte) (int, error) {
2228 if !c.wd.IsZero() && time.Now().After(c.wd) {
2229 return 0, syscall.ETIMEDOUT
2230 }
2231 return len(b), nil
2232 }
2233
2234 func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
2235 if testing.Short() {
2236 t.Skip("skipping in -short mode")
2237 }
2238 defer afterTest(t)
2239 for _, handler := range testHandlerBodyConsumers {
2240 conn := &slowTestConn{
2241 script: []interface{}{
2242 "POST /public HTTP/1.1\r\n" +
2243 "Host: test\r\n" +
2244 "Content-Length: 10000\r\n" +
2245 "\r\n",
2246 "foo bar baz",
2247 600 * time.Millisecond,
2248 "GET /secret HTTP/1.1\r\n" +
2249 "Host: test\r\n" +
2250 "\r\n",
2251 },
2252 closec: make(chan bool, 1),
2253 }
2254 ls := &oneConnListener{conn}
2255
2256 var numReqs int
2257 s := Server{
2258 Handler: HandlerFunc(func(_ ResponseWriter, req *Request) {
2259 numReqs++
2260 if strings.Contains(req.URL.Path, "secret") {
2261 t.Error("Request for /secret encountered, should not have happened.")
2262 }
2263 handler.f(req.Body)
2264 }),
2265 ReadTimeout: 400 * time.Millisecond,
2266 }
2267 go s.Serve(ls)
2268 <-conn.closec
2269
2270 if numReqs != 1 {
2271 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2272 }
2273 }
2274 }
2275
2276 func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) }
2277 func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) }
2278 func testTimeoutHandler(t *testing.T, h2 bool) {
2279 setParallel(t)
2280 defer afterTest(t)
2281 sendHi := make(chan bool, 1)
2282 writeErrors := make(chan error, 1)
2283 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2284 <-sendHi
2285 _, werr := w.Write([]byte("hi"))
2286 writeErrors <- werr
2287 })
2288 timeout := make(chan time.Time, 1)
2289 cst := newClientServerTest(t, h2, NewTestTimeoutHandler(sayHi, timeout))
2290 defer cst.close()
2291
2292
2293 sendHi <- true
2294 res, err := cst.c.Get(cst.ts.URL)
2295 if err != nil {
2296 t.Error(err)
2297 }
2298 if g, e := res.StatusCode, StatusOK; g != e {
2299 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2300 }
2301 body, _ := io.ReadAll(res.Body)
2302 if g, e := string(body), "hi"; g != e {
2303 t.Errorf("got body %q; expected %q", g, e)
2304 }
2305 if g := <-writeErrors; g != nil {
2306 t.Errorf("got unexpected Write error on first request: %v", g)
2307 }
2308
2309
2310 timeout <- time.Time{}
2311 res, err = cst.c.Get(cst.ts.URL)
2312 if err != nil {
2313 t.Error(err)
2314 }
2315 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2316 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2317 }
2318 body, _ = io.ReadAll(res.Body)
2319 if !strings.Contains(string(body), "<title>Timeout</title>") {
2320 t.Errorf("expected timeout body; got %q", string(body))
2321 }
2322 if g, w := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != w {
2323 t.Errorf("response content-type = %q; want %q", g, w)
2324 }
2325
2326
2327
2328 sendHi <- true
2329 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2330 t.Errorf("expected Write error of %v; got %v", e, g)
2331 }
2332 }
2333
2334
2335 func TestTimeoutHandlerRace(t *testing.T) {
2336 setParallel(t)
2337 defer afterTest(t)
2338
2339 delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2340 ms, _ := strconv.Atoi(r.URL.Path[1:])
2341 if ms == 0 {
2342 ms = 1
2343 }
2344 for i := 0; i < ms; i++ {
2345 w.Write([]byte("hi"))
2346 time.Sleep(time.Millisecond)
2347 }
2348 })
2349
2350 ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, ""))
2351 defer ts.Close()
2352
2353 c := ts.Client()
2354
2355 var wg sync.WaitGroup
2356 gate := make(chan bool, 10)
2357 n := 50
2358 if testing.Short() {
2359 n = 10
2360 gate = make(chan bool, 3)
2361 }
2362 for i := 0; i < n; i++ {
2363 gate <- true
2364 wg.Add(1)
2365 go func() {
2366 defer wg.Done()
2367 defer func() { <-gate }()
2368 res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50)))
2369 if err == nil {
2370 io.Copy(io.Discard, res.Body)
2371 res.Body.Close()
2372 }
2373 }()
2374 }
2375 wg.Wait()
2376 }
2377
2378
2379
2380 func TestTimeoutHandlerRaceHeader(t *testing.T) {
2381 setParallel(t)
2382 defer afterTest(t)
2383
2384 delay204 := HandlerFunc(func(w ResponseWriter, r *Request) {
2385 w.WriteHeader(204)
2386 })
2387
2388 ts := httptest.NewServer(TimeoutHandler(delay204, time.Nanosecond, ""))
2389 defer ts.Close()
2390
2391 var wg sync.WaitGroup
2392 gate := make(chan bool, 50)
2393 n := 500
2394 if testing.Short() {
2395 n = 10
2396 }
2397
2398 c := ts.Client()
2399 for i := 0; i < n; i++ {
2400 gate <- true
2401 wg.Add(1)
2402 go func() {
2403 defer wg.Done()
2404 defer func() { <-gate }()
2405 res, err := c.Get(ts.URL)
2406 if err != nil {
2407
2408
2409 t.Log(err)
2410 return
2411 }
2412 defer res.Body.Close()
2413 io.Copy(io.Discard, res.Body)
2414 }()
2415 }
2416 wg.Wait()
2417 }
2418
2419
2420 func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) {
2421 setParallel(t)
2422 defer afterTest(t)
2423 sendHi := make(chan bool, 1)
2424 writeErrors := make(chan error, 1)
2425 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2426 w.Header().Set("Content-Type", "text/plain")
2427 <-sendHi
2428 _, werr := w.Write([]byte("hi"))
2429 writeErrors <- werr
2430 })
2431 timeout := make(chan time.Time, 1)
2432 cst := newClientServerTest(t, h1Mode, NewTestTimeoutHandler(sayHi, timeout))
2433 defer cst.close()
2434
2435
2436 sendHi <- true
2437 res, err := cst.c.Get(cst.ts.URL)
2438 if err != nil {
2439 t.Error(err)
2440 }
2441 if g, e := res.StatusCode, StatusOK; g != e {
2442 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2443 }
2444 body, _ := io.ReadAll(res.Body)
2445 if g, e := string(body), "hi"; g != e {
2446 t.Errorf("got body %q; expected %q", g, e)
2447 }
2448 if g := <-writeErrors; g != nil {
2449 t.Errorf("got unexpected Write error on first request: %v", g)
2450 }
2451
2452
2453 timeout <- time.Time{}
2454 res, err = cst.c.Get(cst.ts.URL)
2455 if err != nil {
2456 t.Error(err)
2457 }
2458 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2459 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2460 }
2461 body, _ = io.ReadAll(res.Body)
2462 if !strings.Contains(string(body), "<title>Timeout</title>") {
2463 t.Errorf("expected timeout body; got %q", string(body))
2464 }
2465
2466
2467
2468 sendHi <- true
2469 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2470 t.Errorf("expected Write error of %v; got %v", e, g)
2471 }
2472 }
2473
2474
2475 func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
2476 if testing.Short() {
2477 t.Skip("skipping sleeping test in -short mode")
2478 }
2479 defer afterTest(t)
2480 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2481 w.WriteHeader(StatusNoContent)
2482 }
2483 timeout := 300 * time.Millisecond
2484 ts := httptest.NewServer(TimeoutHandler(handler, timeout, ""))
2485 defer ts.Close()
2486
2487 c := ts.Client()
2488
2489
2490
2491
2492 time.Sleep(2 * timeout)
2493 res, err := c.Get(ts.URL)
2494 if err != nil {
2495 t.Fatal(err)
2496 }
2497 defer res.Body.Close()
2498 if res.StatusCode != StatusNoContent {
2499 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusNoContent)
2500 }
2501 }
2502
2503
2504 func TestTimeoutHandlerEmptyResponse(t *testing.T) {
2505 setParallel(t)
2506 defer afterTest(t)
2507 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2508
2509 }
2510 timeout := 300 * time.Millisecond
2511 ts := httptest.NewServer(TimeoutHandler(handler, timeout, ""))
2512 defer ts.Close()
2513
2514 c := ts.Client()
2515
2516 res, err := c.Get(ts.URL)
2517 if err != nil {
2518 t.Fatal(err)
2519 }
2520 defer res.Body.Close()
2521 if res.StatusCode != StatusOK {
2522 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusOK)
2523 }
2524 }
2525
2526
2527 func TestTimeoutHandlerPanicRecovery(t *testing.T) {
2528 wrapper := func(h Handler) Handler {
2529 return TimeoutHandler(h, time.Second, "")
2530 }
2531 testHandlerPanic(t, false, false, wrapper, "intentional death for testing")
2532 }
2533
// TestRedirectBadPath calls Redirect for a request whose URL path is
// non-empty but has no leading slash, and checks that the requested status
// code is still recorded.
func TestRedirectBadPath(t *testing.T) {
	const code = 304
	badURL := &url.URL{
		Scheme: "http",
		Path:   "not-empty-but-no-leading-slash",
	}
	req := &Request{Method: "GET", URL: badURL}

	rec := httptest.NewRecorder()
	Redirect(rec, req, "", code)
	if rec.Code != code {
		t.Errorf("Code = %d; want 304", rec.Code)
	}
}
2550
2551
// TestRedirect checks the Location header that Redirect produces for a range
// of targets, relative to a request for http://example.com/qux/.
func TestRedirect(t *testing.T) {
	req, _ := NewRequest("GET", "http://example.com/qux/", nil)

	cases := []struct{ in, want string }{
		// Targets with a scheme are passed through unchanged.
		{"http://foobar.com/baz", "http://foobar.com/baz"},
		{"https://foobar.com/baz", "https://foobar.com/baz"},
		{"test://foobar.com/baz", "test://foobar.com/baz"},
		// A scheme-relative target is passed through unchanged.
		{"//foobar.com/baz", "//foobar.com/baz"},
		// An absolute path is passed through unchanged.
		{"/foobar.com/baz", "/foobar.com/baz"},
		// Relative paths are resolved against the request path.
		{"foobar.com/baz", "/qux/foobar.com/baz"},
		{"../quux/foobar.com/baz", "/quux/foobar.com/baz"},
		// Extra leading slashes collapse to a single one.
		{"///foobar.com/baz", "/foobar.com/baz"},

		// URLs embedded in the query string are left alone.
		{"/foo?next=http://bar.com/", "/foo?next=http://bar.com/"},
		{"http://localhost:8080/_ah/login?continue=http://localhost:8080/",
			"http://localhost:8080/_ah/login?continue=http://localhost:8080/"},

		// Non-ASCII bytes are percent-encoded.
		{"/фубар", "/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
		{"http://foo.com/фубар", "http://foo.com/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
	}

	for i := range cases {
		target, wantLoc := cases[i].in, cases[i].want
		rec := httptest.NewRecorder()
		Redirect(rec, req, target, 302)
		if got, want := rec.Code, 302; got != want {
			t.Errorf("Redirect(%q) generated status code %v; want %v", target, got, want)
		}
		if got := rec.Header().Get("Location"); got != wantLoc {
			t.Errorf("Redirect(%q) generated Location header %q; want %q", target, got, wantLoc)
		}
	}
}
2596
2597
2598
// TestRedirectContentTypeAndBody checks which Content-Type header and body
// Redirect produces, depending on the request method and on whether the
// Content-Type header key was already present in the response headers.
func TestRedirectContentTypeAndBody(t *testing.T) {
	// ctHeader wraps the pre-set Content-Type values so that a nil
	// *ctHeader can mean "header key not set at all".
	type ctHeader struct {
		Values []string
	}

	cases := []struct {
		method   string
		ct       *ctHeader
		wantCT   string
		wantBody string
	}{
		{MethodGet, nil, "text/html; charset=utf-8", "<a href=\"/foo\">Found</a>.\n\n"},
		{MethodHead, nil, "text/html; charset=utf-8", ""},
		{MethodPost, nil, "", ""},
		{MethodDelete, nil, "", ""},
		{"foo", nil, "", ""},
		{MethodGet, &ctHeader{[]string{"application/test"}}, "application/test", ""},
		{MethodGet, &ctHeader{[]string{}}, "", ""},
		{MethodGet, &ctHeader{nil}, "", ""},
	}

	for _, tt := range cases {
		rec := httptest.NewRecorder()
		if preset := tt.ct; preset != nil {
			rec.Header()["Content-Type"] = preset.Values
		}
		req := httptest.NewRequest(tt.method, "http://example.com/qux/", nil)

		Redirect(rec, req, "/foo", 302)

		if got, want := rec.Code, 302; got != want {
			t.Errorf("Redirect(%q, %#v) generated status code %v; want %v", tt.method, tt.ct, got, want)
		}
		if got, want := rec.Header().Get("Content-Type"), tt.wantCT; got != want {
			t.Errorf("Redirect(%q, %#v) generated Content-Type header %q; want %q", tt.method, tt.ct, got, want)
		}
		raw, err := io.ReadAll(rec.Result().Body)
		if err != nil {
			t.Fatal(err)
		}
		if got, want := string(raw), tt.wantBody; got != want {
			t.Errorf("Redirect(%q, %#v) generated Body %q; want %q", tt.method, tt.ct, got, want)
		}
	}
}
2642
2643
2644
2645
2646
2647
2648
2649 func TestZeroLengthPostAndResponse_h1(t *testing.T) {
2650 testZeroLengthPostAndResponse(t, h1Mode)
2651 }
2652 func TestZeroLengthPostAndResponse_h2(t *testing.T) {
2653 testZeroLengthPostAndResponse(t, h2Mode)
2654 }
2655
2656 func testZeroLengthPostAndResponse(t *testing.T, h2 bool) {
2657 setParallel(t)
2658 defer afterTest(t)
2659 cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) {
2660 all, err := io.ReadAll(r.Body)
2661 if err != nil {
2662 t.Fatalf("handler ReadAll: %v", err)
2663 }
2664 if len(all) != 0 {
2665 t.Errorf("handler got %d bytes; expected 0", len(all))
2666 }
2667 rw.Header().Set("Content-Length", "0")
2668 }))
2669 defer cst.close()
2670
2671 req, err := NewRequest("POST", cst.ts.URL, strings.NewReader(""))
2672 if err != nil {
2673 t.Fatal(err)
2674 }
2675 req.ContentLength = 0
2676
2677 var resp [5]*Response
2678 for i := range resp {
2679 resp[i], err = cst.c.Do(req)
2680 if err != nil {
2681 t.Fatalf("client post #%d: %v", i, err)
2682 }
2683 }
2684
2685 for i := range resp {
2686 all, err := io.ReadAll(resp[i].Body)
2687 if err != nil {
2688 t.Fatalf("req #%d: client ReadAll: %v", i, err)
2689 }
2690 if len(all) != 0 {
2691 t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all))
2692 }
2693 }
2694 }
2695
2696 func TestHandlerPanicNil_h1(t *testing.T) { testHandlerPanic(t, false, h1Mode, nil, nil) }
2697 func TestHandlerPanicNil_h2(t *testing.T) { testHandlerPanic(t, false, h2Mode, nil, nil) }
2698
2699 func TestHandlerPanic_h1(t *testing.T) {
2700 testHandlerPanic(t, false, h1Mode, nil, "intentional death for testing")
2701 }
2702 func TestHandlerPanic_h2(t *testing.T) {
2703 testHandlerPanic(t, false, h2Mode, nil, "intentional death for testing")
2704 }
2705
2706 func TestHandlerPanicWithHijack(t *testing.T) {
2707
2708 testHandlerPanic(t, true, h1Mode, nil, "intentional death for testing")
2709 }
2710
2711 func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) Handler, panicValue interface{}) {
2712 defer afterTest(t)
2713
2714
2715
2716
2717
2718
2719
2720
2721
2722
2723
2724
2725
2726
2727
2728
2729 pr, pw := io.Pipe()
2730 log.SetOutput(pw)
2731 defer log.SetOutput(os.Stderr)
2732 defer pw.Close()
2733
2734 var handler Handler = HandlerFunc(func(w ResponseWriter, r *Request) {
2735 if withHijack {
2736 rwc, _, err := w.(Hijacker).Hijack()
2737 if err != nil {
2738 t.Logf("unexpected error: %v", err)
2739 }
2740 defer rwc.Close()
2741 }
2742 panic(panicValue)
2743 })
2744 if wrapper != nil {
2745 handler = wrapper(handler)
2746 }
2747 cst := newClientServerTest(t, h2, handler)
2748 defer cst.close()
2749
2750
2751
2752
2753 done := make(chan bool, 1)
2754 go func() {
2755 buf := make([]byte, 4<<10)
2756 _, err := pr.Read(buf)
2757 pr.Close()
2758 if err != nil && err != io.EOF {
2759 t.Error(err)
2760 }
2761 done <- true
2762 }()
2763
2764 _, err := cst.c.Get(cst.ts.URL)
2765 if err == nil {
2766 t.Logf("expected an error")
2767 }
2768
2769 if panicValue == nil {
2770 return
2771 }
2772
2773 select {
2774 case <-done:
2775 return
2776 case <-time.After(5 * time.Second):
2777 t.Fatal("expected server handler to log an error")
2778 }
2779 }
2780
// terrorWriter is a writer that reports every chunk written to it as an
// error on the wrapped *testing.T.
type terrorWriter struct{ t *testing.T }

// Write reports p via t.Errorf and claims to have consumed all of it.
func (tw terrorWriter) Write(p []byte) (int, error) {
	consumed := len(p)
	tw.t.Errorf("%s", p)
	return consumed, nil
}
2787
2788
2789
2790 func TestServerWriteHijackZeroBytes(t *testing.T) {
2791 defer afterTest(t)
2792 done := make(chan struct{})
2793 ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
2794 defer close(done)
2795 w.(Flusher).Flush()
2796 conn, _, err := w.(Hijacker).Hijack()
2797 if err != nil {
2798 t.Errorf("Hijack: %v", err)
2799 return
2800 }
2801 defer conn.Close()
2802 _, err = w.Write(nil)
2803 if err != ErrHijacked {
2804 t.Errorf("Write error = %v; want ErrHijacked", err)
2805 }
2806 }))
2807 ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0)
2808 ts.Start()
2809 defer ts.Close()
2810
2811 c := ts.Client()
2812 res, err := c.Get(ts.URL)
2813 if err != nil {
2814 t.Fatal(err)
2815 }
2816 res.Body.Close()
2817 select {
2818 case <-done:
2819 case <-time.After(5 * time.Second):
2820 t.Fatal("timeout")
2821 }
2822 }
2823
2824 func TestServerNoDate_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Date") }
2825 func TestServerNoDate_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Date") }
2826 func TestServerNoContentType_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Content-Type") }
2827 func TestServerNoContentType_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Content-Type") }
2828
2829 func testServerNoHeader(t *testing.T, h2 bool, header string) {
2830 setParallel(t)
2831 defer afterTest(t)
2832 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
2833 w.Header()[header] = nil
2834 io.WriteString(w, "<html>foo</html>")
2835 }))
2836 defer cst.close()
2837 res, err := cst.c.Get(cst.ts.URL)
2838 if err != nil {
2839 t.Fatal(err)
2840 }
2841 res.Body.Close()
2842 if got, ok := res.Header[header]; ok {
2843 t.Fatalf("Expected no %s header; got %q", header, got)
2844 }
2845 }
2846
2847 func TestStripPrefix(t *testing.T) {
2848 setParallel(t)
2849 defer afterTest(t)
2850 h := HandlerFunc(func(w ResponseWriter, r *Request) {
2851 w.Header().Set("X-Path", r.URL.Path)
2852 w.Header().Set("X-RawPath", r.URL.RawPath)
2853 })
2854 ts := httptest.NewServer(StripPrefix("/foo/bar", h))
2855 defer ts.Close()
2856
2857 c := ts.Client()
2858
2859 cases := []struct {
2860 reqPath string
2861 path string
2862 rawPath string
2863 }{
2864 {"/foo/bar/qux", "/qux", ""},
2865 {"/foo/bar%2Fqux", "/qux", "%2Fqux"},
2866 {"/foo%2Fbar/qux", "", ""},
2867 {"/bar", "", ""},
2868 }
2869 for _, tc := range cases {
2870 t.Run(tc.reqPath, func(t *testing.T) {
2871 res, err := c.Get(ts.URL + tc.reqPath)
2872 if err != nil {
2873 t.Fatal(err)
2874 }
2875 res.Body.Close()
2876 if tc.path == "" {
2877 if res.StatusCode != StatusNotFound {
2878 t.Errorf("got %q, want 404 Not Found", res.Status)
2879 }
2880 return
2881 }
2882 if res.StatusCode != StatusOK {
2883 t.Fatalf("got %q, want 200 OK", res.Status)
2884 }
2885 if g, w := res.Header.Get("X-Path"), tc.path; g != w {
2886 t.Errorf("got Path %q, want %q", g, w)
2887 }
2888 if g, w := res.Header.Get("X-RawPath"), tc.rawPath; g != w {
2889 t.Errorf("got RawPath %q, want %q", g, w)
2890 }
2891 })
2892 }
2893 }
2894
2895
// TestStripPrefixNotModifyRequest checks that serving a request through
// StripPrefix leaves the URL path of the caller's Request untouched.
func TestStripPrefixNotModifyRequest(t *testing.T) {
	const target = "/foo/bar"
	req := httptest.NewRequest("GET", target, nil)

	stripped := StripPrefix("/foo", NotFoundHandler())
	stripped.ServeHTTP(httptest.NewRecorder(), req)

	if got := req.URL.Path; got != target {
		t.Errorf("StripPrefix should not modify the provided Request, but it did")
	}
}
2904
2905 func TestRequestLimit_h1(t *testing.T) { testRequestLimit(t, h1Mode) }
2906 func TestRequestLimit_h2(t *testing.T) { testRequestLimit(t, h2Mode) }
2907 func testRequestLimit(t *testing.T, h2 bool) {
2908 setParallel(t)
2909 defer afterTest(t)
2910 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
2911 t.Fatalf("didn't expect to get request in Handler")
2912 }), optQuietLog)
2913 defer cst.close()
2914 req, _ := NewRequest("GET", cst.ts.URL, nil)
2915 var bytesPerHeader = len("header12345: val12345\r\n")
2916 for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
2917 req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
2918 }
2919 res, err := cst.c.Do(req)
2920 if res != nil {
2921 defer res.Body.Close()
2922 }
2923 if h2 {
2924
2925
2926
2927
2928 if err == nil && res.StatusCode != 431 {
2929 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
2930 }
2931 } else {
2932
2933
2934
2935
2936 if err != nil {
2937 t.Fatalf("Do: %v", err)
2938 }
2939 if res.StatusCode != 431 {
2940 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
2941 }
2942 }
2943 }
2944
// neverEnding is a reader that yields an endless stream of its own byte value.
type neverEnding byte

// Read fills all of p with the receiver's byte value and never returns an
// error.
func (b neverEnding) Read(p []byte) (n int, err error) {
	fill := byte(b)
	for i := 0; i < len(p); i++ {
		p[i] = fill
	}
	return len(p), nil
}
2953
// countReader wraps a reader and atomically adds the number of bytes
// returned by each Read to the counter that n points to.
type countReader struct {
	r io.Reader
	n *int64
}

// Read forwards to the wrapped reader and adds the byte count it reports to
// the shared counter, regardless of whether the read returned an error.
func (cr countReader) Read(p []byte) (n int, err error) {
	read, readErr := cr.r.Read(p)
	atomic.AddInt64(cr.n, int64(read))
	return read, readErr
}
2964
2965 func TestRequestBodyLimit_h1(t *testing.T) { testRequestBodyLimit(t, h1Mode) }
2966 func TestRequestBodyLimit_h2(t *testing.T) { testRequestBodyLimit(t, h2Mode) }
2967 func testRequestBodyLimit(t *testing.T, h2 bool) {
2968 setParallel(t)
2969 defer afterTest(t)
2970 const limit = 1 << 20
2971 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
2972 r.Body = MaxBytesReader(w, r.Body, limit)
2973 n, err := io.Copy(io.Discard, r.Body)
2974 if err == nil {
2975 t.Errorf("expected error from io.Copy")
2976 }
2977 if n != limit {
2978 t.Errorf("io.Copy = %d, want %d", n, limit)
2979 }
2980 }))
2981 defer cst.close()
2982
2983 nWritten := new(int64)
2984 req, _ := NewRequest("POST", cst.ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200))
2985
2986
2987
2988
2989
2990
2991
2992
2993
2994
2995 _, _ = cst.c.Do(req)
2996
2997 if atomic.LoadInt64(nWritten) > limit*100 {
2998 t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
2999 limit, nWritten)
3000 }
3001 }
3002
3003
3004
3005 func TestClientWriteShutdown(t *testing.T) {
3006 if runtime.GOOS == "plan9" {
3007 t.Skip("skipping test; see https://golang.org/issue/17906")
3008 }
3009 defer afterTest(t)
3010 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
3011 defer ts.Close()
3012 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3013 if err != nil {
3014 t.Fatalf("Dial: %v", err)
3015 }
3016 err = conn.(*net.TCPConn).CloseWrite()
3017 if err != nil {
3018 t.Fatalf("CloseWrite: %v", err)
3019 }
3020 donec := make(chan bool)
3021 go func() {
3022 defer close(donec)
3023 bs, err := io.ReadAll(conn)
3024 if err != nil {
3025 t.Errorf("ReadAll: %v", err)
3026 }
3027 got := string(bs)
3028 if got != "" {
3029 t.Errorf("read %q from server; want nothing", got)
3030 }
3031 }()
3032 select {
3033 case <-donec:
3034 case <-time.After(10 * time.Second):
3035 t.Fatalf("timeout")
3036 }
3037 }
3038
3039
3040
3041 func TestServerBufferedChunking(t *testing.T) {
3042 conn := new(testConn)
3043 conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
3044 conn.closec = make(chan bool, 1)
3045 ls := &oneConnListener{conn}
3046 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
3047 rw.(Flusher).Flush()
3048 rw.Write([]byte{'x'})
3049 rw.Write([]byte{'y'})
3050 rw.Write([]byte{'z'})
3051 }))
3052 <-conn.closec
3053 if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
3054 t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
3055 conn.writeBuf.Bytes())
3056 }
3057 }
3058
3059
3060
3061
3062
3063 func TestServerGracefulClose(t *testing.T) {
3064 setParallel(t)
3065 defer afterTest(t)
3066 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
3067 Error(w, "bye", StatusUnauthorized)
3068 }))
3069 defer ts.Close()
3070
3071 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3072 if err != nil {
3073 t.Fatal(err)
3074 }
3075 defer conn.Close()
3076 const bodySize = 5 << 20
3077 req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
3078 for i := 0; i < bodySize; i++ {
3079 req = append(req, 'x')
3080 }
3081 writeErr := make(chan error)
3082 go func() {
3083 _, err := conn.Write(req)
3084 writeErr <- err
3085 }()
3086 br := bufio.NewReader(conn)
3087 lineNum := 0
3088 for {
3089 line, err := br.ReadString('\n')
3090 if err == io.EOF {
3091 break
3092 }
3093 if err != nil {
3094 t.Fatalf("ReadLine: %v", err)
3095 }
3096 lineNum++
3097 if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
3098 t.Errorf("Response line = %q; want a 401", line)
3099 }
3100 }
3101
3102
3103
3104 <-writeErr
3105 }
3106
3107 func TestCaseSensitiveMethod_h1(t *testing.T) { testCaseSensitiveMethod(t, h1Mode) }
3108 func TestCaseSensitiveMethod_h2(t *testing.T) { testCaseSensitiveMethod(t, h2Mode) }
3109 func testCaseSensitiveMethod(t *testing.T, h2 bool) {
3110 defer afterTest(t)
3111 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
3112 if r.Method != "get" {
3113 t.Errorf(`Got method %q; want "get"`, r.Method)
3114 }
3115 }))
3116 defer cst.close()
3117 req, _ := NewRequest("get", cst.ts.URL, nil)
3118 res, err := cst.c.Do(req)
3119 if err != nil {
3120 t.Error(err)
3121 return
3122 }
3123
3124 res.Body.Close()
3125 }
3126
3127
3128
3129
3130
// TestContentLengthZero sends a keep-alive GET to a handler that writes
// nothing, once as HTTP/1.0 and once as HTTP/1.1, and checks that the
// response carries no Transfer-Encoding and a Content-Length of zero.
func TestContentLengthZero(t *testing.T) {
	ts := httptest.NewServer(HandlerFunc(func(ResponseWriter, *Request) {}))
	defer ts.Close()

	versions := []string{"HTTP/1.0", "HTTP/1.1"}
	for _, version := range versions {
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatalf("error dialing: %v", err)
		}
		if _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version); err != nil {
			t.Fatalf("error writing: %v", err)
		}

		// The request passed to ReadResponse only describes what was sent.
		sent, _ := NewRequest("GET", "/", nil)
		res, err := ReadResponse(bufio.NewReader(conn), sent)
		if err != nil {
			t.Fatalf("error reading response: %v", err)
		}
		if te := res.TransferEncoding; len(te) > 0 {
			t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
		}
		if cl := res.ContentLength; cl != 0 {
			t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
		}
		conn.Close()
	}
}
3158
3159 func TestCloseNotifier(t *testing.T) {
3160 defer afterTest(t)
3161 gotReq := make(chan bool, 1)
3162 sawClose := make(chan bool, 1)
3163 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
3164 gotReq <- true
3165 cc := rw.(CloseNotifier).CloseNotify()
3166 <-cc
3167 sawClose <- true
3168 }))
3169 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3170 if err != nil {
3171 t.Fatalf("error dialing: %v", err)
3172 }
3173 diec := make(chan bool)
3174 go func() {
3175 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
3176 if err != nil {
3177 t.Error(err)
3178 return
3179 }
3180 <-diec
3181 conn.Close()
3182 }()
3183 For:
3184 for {
3185 select {
3186 case <-gotReq:
3187 diec <- true
3188 case <-sawClose:
3189 break For
3190 case <-time.After(5 * time.Second):
3191 t.Fatal("timeout")
3192 }
3193 }
3194 ts.Close()
3195 }
3196
3197
3198
3199
3200
3201 func TestCloseNotifierPipelined(t *testing.T) {
3202 setParallel(t)
3203 defer afterTest(t)
3204 gotReq := make(chan bool, 2)
3205 sawClose := make(chan bool, 2)
3206 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
3207 gotReq <- true
3208 cc := rw.(CloseNotifier).CloseNotify()
3209 select {
3210 case <-cc:
3211 t.Error("unexpected CloseNotify")
3212 case <-time.After(100 * time.Millisecond):
3213 }
3214 sawClose <- true
3215 }))
3216 defer ts.Close()
3217 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3218 if err != nil {
3219 t.Fatalf("error dialing: %v", err)
3220 }
3221 diec := make(chan bool, 1)
3222 defer close(diec)
3223 go func() {
3224 const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"
3225 _, err = io.WriteString(conn, req+req)
3226 if err != nil {
3227 t.Error(err)
3228 return
3229 }
3230 <-diec
3231 conn.Close()
3232 }()
3233 reqs := 0
3234 closes := 0
3235 for {
3236 select {
3237 case <-gotReq:
3238 reqs++
3239 if reqs > 2 {
3240 t.Fatal("too many requests")
3241 }
3242 case <-sawClose:
3243 closes++
3244 if closes > 1 {
3245 return
3246 }
3247 case <-time.After(5 * time.Second):
3248 ts.CloseClientConnections()
3249 t.Fatal("timeout")
3250 }
3251 }
3252 }
3253
3254 func TestCloseNotifierChanLeak(t *testing.T) {
3255 defer afterTest(t)
3256 req := reqBytes("GET / HTTP/1.0\nHost: golang.org")
3257 for i := 0; i < 20; i++ {
3258 var output bytes.Buffer
3259 conn := &rwTestConn{
3260 Reader: bytes.NewReader(req),
3261 Writer: &output,
3262 closec: make(chan bool, 1),
3263 }
3264 ln := &oneConnListener{conn: conn}
3265 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3266
3267
3268
3269 _ = rw.(CloseNotifier).CloseNotify()
3270 })
3271 go Serve(ln, handler)
3272 <-conn.closec
3273 }
3274 }
3275
3276
3277
3278
3279
3280
3281
3282
3283
3284
3285 func TestHijackAfterCloseNotifier(t *testing.T) {
3286 defer afterTest(t)
3287 script := make(chan string, 2)
3288 script <- "closenotify"
3289 script <- "hijack"
3290 close(script)
3291 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
3292 plan := <-script
3293 switch plan {
3294 default:
3295 panic("bogus plan; too many requests")
3296 case "closenotify":
3297 w.(CloseNotifier).CloseNotify()
3298 w.Header().Set("X-Addr", r.RemoteAddr)
3299 case "hijack":
3300 c, _, err := w.(Hijacker).Hijack()
3301 if err != nil {
3302 t.Errorf("Hijack in Handler: %v", err)
3303 return
3304 }
3305 if _, ok := c.(*net.TCPConn); !ok {
3306
3307
3308 t.Errorf("type of hijacked conn is %T; want *net.TCPConn", c)
3309 }
3310 fmt.Fprintf(c, "HTTP/1.0 200 OK\r\nX-Addr: %v\r\nContent-Length: 0\r\n\r\n", r.RemoteAddr)
3311 c.Close()
3312 return
3313 }
3314 }))
3315 defer ts.Close()
3316 res1, err := Get(ts.URL)
3317 if err != nil {
3318 log.Fatal(err)
3319 }
3320 res2, err := Get(ts.URL)
3321 if err != nil {
3322 log.Fatal(err)
3323 }
3324 addr1 := res1.Header.Get("X-Addr")
3325 addr2 := res2.Header.Get("X-Addr")
3326 if addr1 == "" || addr1 != addr2 {
3327 t.Errorf("addr1, addr2 = %q, %q; want same", addr1, addr2)
3328 }
3329 }
3330
3331 func TestHijackBeforeRequestBodyRead(t *testing.T) {
3332 setParallel(t)
3333 defer afterTest(t)
3334 var requestBody = bytes.Repeat([]byte("a"), 1<<20)
3335 bodyOkay := make(chan bool, 1)
3336 gotCloseNotify := make(chan bool, 1)
3337 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
3338 defer close(bodyOkay)
3339
3340 reqBody := r.Body
3341 r.Body = nil
3342
3343 gone := w.(CloseNotifier).CloseNotify()
3344 slurp, err := io.ReadAll(reqBody)
3345 if err != nil {
3346 t.Errorf("Body read: %v", err)
3347 return
3348 }
3349 if len(slurp) != len(requestBody) {
3350 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
3351 return
3352 }
3353 if !bytes.Equal(slurp, requestBody) {
3354 t.Error("Backend read wrong request body.")
3355 return
3356 }
3357 bodyOkay <- true
3358 select {
3359 case <-gone:
3360 gotCloseNotify <- true
3361 case <-time.After(5 * time.Second):
3362 gotCloseNotify <- false
3363 }
3364 }))
3365 defer ts.Close()
3366
3367 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3368 if err != nil {
3369 t.Fatal(err)
3370 }
3371 defer conn.Close()
3372
3373 fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: foo\r\nContent-Length: %d\r\n\r\n%s",
3374 len(requestBody), requestBody)
3375 if !<-bodyOkay {
3376
3377 return
3378 }
3379 conn.Close()
3380 if !<-gotCloseNotify {
3381 t.Error("timeout waiting for CloseNotify")
3382 }
3383 }
3384
// TestOptions checks how a ServeMux-backed server treats the "*" request
// target: "OPTIONS *" must get a 200, "GET *" must get a 400, and a later
// ordinary request for /second must be the first request the mux handler
// reports.
func TestOptions(t *testing.T) {
	uric := make(chan string, 2) // request URIs seen by the mux handler
	mux := NewServeMux()
	mux.HandleFunc("/", func(_ ResponseWriter, r *Request) {
		uric <- r.RequestURI
	})
	ts := httptest.NewServer(mux)
	defer ts.Close()

	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()
	br := bufio.NewReader(conn)

	// starRequest sends "<method> * HTTP/1.1" on conn and returns the
	// parsed response, failing the test on any I/O or parse error.
	starRequest := func(method string) *Response {
		if _, err := conn.Write([]byte(method + " * HTTP/1.1\r\nHost: foo.com\r\n\r\n")); err != nil {
			t.Fatal(err)
		}
		res, err := ReadResponse(br, &Request{Method: method})
		if err != nil {
			t.Fatal(err)
		}
		return res
	}

	if res := starRequest("OPTIONS"); res.StatusCode != 200 {
		t.Errorf("Got non-200 response to OPTIONS *: %#v", res)
	}
	if res := starRequest("GET"); res.StatusCode != 400 {
		t.Errorf("Got non-400 response to GET *: %#v", res)
	}

	res, err := Get(ts.URL + "/second")
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	if got := <-uric; got != "/second" {
		t.Errorf("Handler saw request for %q; want /second", got)
	}
}
3436
3437
3438
3439
3440
3441
3442
3443
3444
3445
3446 func TestHeaderToWire(t *testing.T) {
3447 tests := []struct {
3448 name string
3449 handler func(ResponseWriter, *Request)
3450 check func(got, logs string) error
3451 }{
3452 {
3453 name: "write without Header",
3454 handler: func(rw ResponseWriter, r *Request) {
3455 rw.Write([]byte("hello world"))
3456 },
3457 check: func(got, logs string) error {
3458 if !strings.Contains(got, "Content-Length:") {
3459 return errors.New("no content-length")
3460 }
3461 if !strings.Contains(got, "Content-Type: text/plain") {
3462 return errors.New("no content-type")
3463 }
3464 return nil
3465 },
3466 },
3467 {
3468 name: "Header mutation before write",
3469 handler: func(rw ResponseWriter, r *Request) {
3470 h := rw.Header()
3471 h.Set("Content-Type", "some/type")
3472 rw.Write([]byte("hello world"))
3473 h.Set("Too-Late", "bogus")
3474 },
3475 check: func(got, logs string) error {
3476 if !strings.Contains(got, "Content-Length:") {
3477 return errors.New("no content-length")
3478 }
3479 if !strings.Contains(got, "Content-Type: some/type") {
3480 return errors.New("wrong content-type")
3481 }
3482 if strings.Contains(got, "Too-Late") {
3483 return errors.New("don't want too-late header")
3484 }
3485 return nil
3486 },
3487 },
3488 {
3489 name: "write then useless Header mutation",
3490 handler: func(rw ResponseWriter, r *Request) {
3491 rw.Write([]byte("hello world"))
3492 rw.Header().Set("Too-Late", "Write already wrote headers")
3493 },
3494 check: func(got, logs string) error {
3495 if strings.Contains(got, "Too-Late") {
3496 return errors.New("header appeared from after WriteHeader")
3497 }
3498 return nil
3499 },
3500 },
3501 {
3502 name: "flush then write",
3503 handler: func(rw ResponseWriter, r *Request) {
3504 rw.(Flusher).Flush()
3505 rw.Write([]byte("post-flush"))
3506 rw.Header().Set("Too-Late", "Write already wrote headers")
3507 },
3508 check: func(got, logs string) error {
3509 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3510 return errors.New("not chunked")
3511 }
3512 if strings.Contains(got, "Too-Late") {
3513 return errors.New("header appeared from after WriteHeader")
3514 }
3515 return nil
3516 },
3517 },
3518 {
3519 name: "header then flush",
3520 handler: func(rw ResponseWriter, r *Request) {
3521 rw.Header().Set("Content-Type", "some/type")
3522 rw.(Flusher).Flush()
3523 rw.Write([]byte("post-flush"))
3524 rw.Header().Set("Too-Late", "Write already wrote headers")
3525 },
3526 check: func(got, logs string) error {
3527 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3528 return errors.New("not chunked")
3529 }
3530 if strings.Contains(got, "Too-Late") {
3531 return errors.New("header appeared from after WriteHeader")
3532 }
3533 if !strings.Contains(got, "Content-Type: some/type") {
3534 return errors.New("wrong content-type")
3535 }
3536 return nil
3537 },
3538 },
3539 {
3540 name: "sniff-on-first-write content-type",
3541 handler: func(rw ResponseWriter, r *Request) {
3542 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3543 rw.Header().Set("Content-Type", "x/wrong")
3544 },
3545 check: func(got, logs string) error {
3546 if !strings.Contains(got, "Content-Type: text/html") {
3547 return errors.New("wrong content-type; want html")
3548 }
3549 return nil
3550 },
3551 },
3552 {
3553 name: "explicit content-type wins",
3554 handler: func(rw ResponseWriter, r *Request) {
3555 rw.Header().Set("Content-Type", "some/type")
3556 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3557 },
3558 check: func(got, logs string) error {
3559 if !strings.Contains(got, "Content-Type: some/type") {
3560 return errors.New("wrong content-type; want html")
3561 }
3562 return nil
3563 },
3564 },
3565 {
3566 name: "empty handler",
3567 handler: func(rw ResponseWriter, r *Request) {
3568 },
3569 check: func(got, logs string) error {
3570 if !strings.Contains(got, "Content-Length: 0") {
3571 return errors.New("want 0 content-length")
3572 }
3573 return nil
3574 },
3575 },
3576 {
3577 name: "only Header, no write",
3578 handler: func(rw ResponseWriter, r *Request) {
3579 rw.Header().Set("Some-Header", "some-value")
3580 },
3581 check: func(got, logs string) error {
3582 if !strings.Contains(got, "Some-Header") {
3583 return errors.New("didn't get header")
3584 }
3585 return nil
3586 },
3587 },
3588 {
3589 name: "WriteHeader call",
3590 handler: func(rw ResponseWriter, r *Request) {
3591 rw.WriteHeader(404)
3592 rw.Header().Set("Too-Late", "some-value")
3593 },
3594 check: func(got, logs string) error {
3595 if !strings.Contains(got, "404") {
3596 return errors.New("wrong status")
3597 }
3598 if strings.Contains(got, "Too-Late") {
3599 return errors.New("shouldn't have seen Too-Late")
3600 }
3601 return nil
3602 },
3603 },
3604 }
3605 for _, tc := range tests {
3606 ht := newHandlerTest(HandlerFunc(tc.handler))
3607 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
3608 logs := ht.logbuf.String()
3609 if err := tc.check(got, logs); err != nil {
3610 t.Errorf("%s: %v\nGot response:\n%s\n\n%s", tc.name, err, got, logs)
3611 }
3612 }
3613 }
3614
// errorListener is a listener whose Accept hands back a scripted sequence
// of errors, one per call, and io.EOF once the script is exhausted.
type errorListener struct {
	errs []error // errors still to be returned, in order
}

// Accept never yields a connection. It pops and returns the next scripted
// error, or io.EOF when none remain.
func (l *errorListener) Accept() (c net.Conn, err error) {
	if len(l.errs) > 0 {
		err, l.errs = l.errs[0], l.errs[1:]
		return nil, err
	}
	return nil, io.EOF
}

// Close does nothing and always succeeds.
func (l *errorListener) Close() error { return nil }
3631
3632 func (l *errorListener) Addr() net.Addr {
3633 return dummyAddr("test-address")
3634 }
3635
3636 func TestAcceptMaxFds(t *testing.T) {
3637 setParallel(t)
3638
3639 ln := &errorListener{[]error{
3640 &net.OpError{
3641 Op: "accept",
3642 Err: syscall.EMFILE,
3643 }}}
3644 server := &Server{
3645 Handler: HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})),
3646 ErrorLog: log.New(io.Discard, "", 0),
3647 }
3648 err := server.Serve(ln)
3649 if err != io.EOF {
3650 t.Errorf("got error %v, want EOF", err)
3651 }
3652 }
3653
3654 func TestWriteAfterHijack(t *testing.T) {
3655 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
3656 var buf bytes.Buffer
3657 wrotec := make(chan bool, 1)
3658 conn := &rwTestConn{
3659 Reader: bytes.NewReader(req),
3660 Writer: &buf,
3661 closec: make(chan bool, 1),
3662 }
3663 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3664 conn, bufrw, err := rw.(Hijacker).Hijack()
3665 if err != nil {
3666 t.Error(err)
3667 return
3668 }
3669 go func() {
3670 bufrw.Write([]byte("[hijack-to-bufw]"))
3671 bufrw.Flush()
3672 conn.Write([]byte("[hijack-to-conn]"))
3673 conn.Close()
3674 wrotec <- true
3675 }()
3676 })
3677 ln := &oneConnListener{conn: conn}
3678 go Serve(ln, handler)
3679 <-conn.closec
3680 <-wrotec
3681 if g, w := buf.String(), "[hijack-to-bufw][hijack-to-conn]"; g != w {
3682 t.Errorf("wrote %q; want %q", g, w)
3683 }
3684 }
3685
3686 func TestDoubleHijack(t *testing.T) {
3687 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
3688 var buf bytes.Buffer
3689 conn := &rwTestConn{
3690 Reader: bytes.NewReader(req),
3691 Writer: &buf,
3692 closec: make(chan bool, 1),
3693 }
3694 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3695 conn, _, err := rw.(Hijacker).Hijack()
3696 if err != nil {
3697 t.Error(err)
3698 return
3699 }
3700 _, _, err = rw.(Hijacker).Hijack()
3701 if err == nil {
3702 t.Errorf("got err = nil; want err != nil")
3703 }
3704 conn.Close()
3705 })
3706 ln := &oneConnListener{conn: conn}
3707 go Serve(ln, handler)
3708 <-conn.closec
3709 }
3710
3711
3712
3713
3714
3715
3716
3717 func TestHTTP10ConnectionHeader(t *testing.T) {
3718 defer afterTest(t)
3719
3720 mux := NewServeMux()
3721 mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {}))
3722 ts := httptest.NewServer(mux)
3723 defer ts.Close()
3724
3725
3726 tests := []struct {
3727 req string
3728 expect []string
3729 }{
3730 {
3731 req: "GET / HTTP/1.0\r\n\r\n",
3732 expect: nil,
3733 },
3734 {
3735 req: "OPTIONS * HTTP/1.0\r\n\r\n",
3736 expect: nil,
3737 },
3738 {
3739 req: "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n",
3740 expect: []string{"keep-alive"},
3741 },
3742 }
3743
3744 for _, tt := range tests {
3745 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3746 if err != nil {
3747 t.Fatal("dial err:", err)
3748 }
3749
3750 _, err = fmt.Fprint(conn, tt.req)
3751 if err != nil {
3752 t.Fatal("conn write err:", err)
3753 }
3754
3755 resp, err := ReadResponse(bufio.NewReader(conn), &Request{Method: "GET"})
3756 if err != nil {
3757 t.Fatal("ReadResponse err:", err)
3758 }
3759 conn.Close()
3760 resp.Body.Close()
3761
3762 got := resp.Header["Connection"]
3763 if !reflect.DeepEqual(got, tt.expect) {
3764 t.Errorf("wrong Connection headers for request %q. Got %q expect %q", tt.req, got, tt.expect)
3765 }
3766 }
3767 }
3768
3769
3770 func TestServerReaderFromOrder_h1(t *testing.T) { testServerReaderFromOrder(t, h1Mode) }
3771 func TestServerReaderFromOrder_h2(t *testing.T) { testServerReaderFromOrder(t, h2Mode) }
3772 func testServerReaderFromOrder(t *testing.T, h2 bool) {
3773 setParallel(t)
3774 defer afterTest(t)
3775 pr, pw := io.Pipe()
3776 const size = 3 << 20
3777 cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) {
3778 rw.Header().Set("Content-Type", "text/plain")
3779 done := make(chan bool)
3780 go func() {
3781 io.Copy(rw, pr)
3782 close(done)
3783 }()
3784 time.Sleep(25 * time.Millisecond)
3785 n, err := io.Copy(io.Discard, req.Body)
3786 if err != nil {
3787 t.Errorf("handler Copy: %v", err)
3788 return
3789 }
3790 if n != size {
3791 t.Errorf("handler Copy = %d; want %d", n, size)
3792 }
3793 pw.Write([]byte("hi"))
3794 pw.Close()
3795 <-done
3796 }))
3797 defer cst.close()
3798
3799 req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size))
3800 if err != nil {
3801 t.Fatal(err)
3802 }
3803 res, err := cst.c.Do(req)
3804 if err != nil {
3805 t.Fatal(err)
3806 }
3807 all, err := io.ReadAll(res.Body)
3808 if err != nil {
3809 t.Fatal(err)
3810 }
3811 res.Body.Close()
3812 if string(all) != "hi" {
3813 t.Errorf("Body = %q; want hi", all)
3814 }
3815 }
3816
3817
3818 func TestCodesPreventingContentTypeAndBody(t *testing.T) {
3819 for _, code := range []int{StatusNotModified, StatusNoContent, StatusContinue} {
3820 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
3821 if r.URL.Path == "/header" {
3822 w.Header().Set("Content-Length", "123")
3823 }
3824 w.WriteHeader(code)
3825 if r.URL.Path == "/more" {
3826 w.Write([]byte("stuff"))
3827 }
3828 }))
3829 for _, req := range []string{
3830 "GET / HTTP/1.0",
3831 "GET /header HTTP/1.0",
3832 "GET /more HTTP/1.0",
3833 "GET / HTTP/1.1\nHost: foo",
3834 "GET /header HTTP/1.1\nHost: foo",
3835 "GET /more HTTP/1.1\nHost: foo",
3836 } {
3837 got := ht.rawResponse(req)
3838 wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
3839 if !strings.Contains(got, wantStatus) {
3840 t.Errorf("Code %d: Wanted %q Modified for %q: %s", code, wantStatus, req, got)
3841 } else if strings.Contains(got, "Content-Length") {
3842 t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got)
3843 } else if strings.Contains(got, "stuff") {
3844 t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got)
3845 }
3846 }
3847 }
3848 }
3849
3850 func TestContentTypeOkayOn204(t *testing.T) {
3851 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
3852 w.Header().Set("Content-Length", "123")
3853 w.Header().Set("Content-Type", "foo/bar")
3854 w.WriteHeader(204)
3855 }))
3856 got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
3857 if !strings.Contains(got, "Content-Type: foo/bar") {
3858 t.Errorf("Response = %q; want Content-Type: foo/bar", got)
3859 }
3860 if strings.Contains(got, "Content-Length: 123") {
3861 t.Errorf("Response = %q; don't want a Content-Length", got)
3862 }
3863 }
3864
3865
3866
3867
3868
3869
3870
// TestTransportAndServerSharedBodyRace_h1 runs
// testTransportAndServerSharedBodyRace with h1Mode.
func TestTransportAndServerSharedBodyRace_h1(t *testing.T) {
	testTransportAndServerSharedBodyRace(t, h1Mode)
}

// TestTransportAndServerSharedBodyRace_h2 runs
// testTransportAndServerSharedBodyRace with h2Mode.
func TestTransportAndServerSharedBodyRace_h2(t *testing.T) {
	testTransportAndServerSharedBodyRace(t, h2Mode)
}

// testTransportAndServerSharedBodyRace sets up a client, a proxy handler and
// a backend handler. The proxy passes the incoming request's Body straight
// through as the body of an outbound POST to the backend, so the same body
// value is in use by both the proxy's server side and its outbound request.
// The proxy reads half of the backend's response, cancels the
// outbound request, and replies "OK". The test fails if any step errors, and
// aborts the whole process if teardown takes longer than seven seconds.
func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) {
	setParallel(t)
	defer afterTest(t)

	const bodySize = 1 << 20

	// errorf records a test failure with t.Error and also prints the same
	// message with the println builtin. It is the error reporter used inside
	// the proxy handler.
	errorf := func(format string, args ...interface{}) {
		v := fmt.Sprintf(format, args...)
		println(v)
		t.Error(v)
	}

	// unblockBackend is closed (via defer, below) to let the backend handler
	// return.
	unblockBackend := make(chan bool)
	backend := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) {
		gone := rw.(CloseNotifier).CloseNotify()
		didCopy := make(chan interface{})
		// Echo bodySize bytes of the request body back as the response.
		go func() {
			n, err := io.CopyN(rw, req.Body, bodySize)
			didCopy <- []interface{}{n, err}
		}()
		isGone := false
	Loop:
		// Wait for the echo to finish, noting a close notification and
		// printing a progress line for every second that passes.
		for {
			select {
			case <-didCopy:
				break Loop
			case <-gone:
				isGone = true
			case <-time.After(time.Second):
				println("1 second passes in backend, proxygone=", isGone)
			}
		}
		<-unblockBackend
	}))
	// Deferred calls run last-in, first-out. This Stop is registered first,
	// so it runs after every other deferred call below, including the one
	// that assigns quitTimer.
	var quitTimer *time.Timer
	defer func() { quitTimer.Stop() }()
	defer backend.close()

	// backendRespc carries the backend response from the proxy handler to
	// the end of the test, where its body is closed.
	backendRespc := make(chan *Response, 1)
	var proxy *clientServerTest
	proxy = newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) {
		// req.Body is handed to the outbound request unchanged.
		req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
		req2.ContentLength = bodySize
		cancel := make(chan struct{})
		req2.Cancel = cancel

		bresp, err := proxy.c.Do(req2)
		if err != nil {
			errorf("Proxy outbound request: %v", err)
			return
		}
		// Consume only the first half of the backend's response.
		_, err = io.CopyN(io.Discard, bresp.Body, bodySize/2)
		if err != nil {
			errorf("Proxy copy error: %v", err)
			return
		}
		backendRespc <- bresp

		// Cancel the outbound request: by closing req2.Cancel in h2 mode,
		// and through Transport.CancelRequest otherwise.
		if h2 {
			close(cancel)
		} else {
			proxy.c.Transport.(*Transport).CancelRequest(req2)
		}
		rw.Write([]byte("OK"))
	}))
	defer proxy.close()
	defer func() {
		// This runs before proxy.close and backend.close (LIFO), so the
		// watchdog covers them. If it fires after seven seconds it dumps all
		// goroutine stacks to stderr and exits the process via log.Fatalf.
		quitTimer = time.AfterFunc(7*time.Second, func() {
			debug.SetTraceback("ALL")
			stacks := make([]byte, 1<<20)
			stacks = stacks[:runtime.Stack(stacks, true)]
			fmt.Fprintf(os.Stderr, "%s", stacks)
			log.Fatalf("Timeout.")
		})
	}()

	// Registered last, so it is the first deferred call to run: the backend
	// handler is released before the servers are closed.
	defer close(unblockBackend)
	req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
	res, err := proxy.c.Do(req)
	if err != nil {
		t.Fatalf("Original request: %v", err)
	}

	// Close the proxy's response, then the backend's response if the proxy
	// handler got far enough to publish one.
	res.Body.Close()
	select {
	case res := <-backendRespc:
		res.Body.Close()
	default:
		// The proxy handler returned early on an error (already reported by
		// errorf), so there is no backend response to close.
	}
}
3980
3981
3982
3983
3984 func TestRequestBodyCloseDoesntBlock(t *testing.T) {
3985 if testing.Short() {
3986 t.Skip("skipping in -short mode")
3987 }
3988 defer afterTest(t)
3989
3990 readErrCh := make(chan error, 1)
3991 errCh := make(chan error, 2)
3992
3993 server := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
3994 go func(body io.Reader) {
3995 _, err := body.Read(make([]byte, 100))
3996 readErrCh <- err
3997 }(req.Body)
3998 time.Sleep(500 * time.Millisecond)
3999 }))
4000 defer server.Close()
4001
4002 closeConn := make(chan bool)
4003 defer close(closeConn)
4004 go func() {
4005 conn, err := net.Dial("tcp", server.Listener.Addr().String())
4006 if err != nil {
4007 errCh <- err
4008 return
4009 }
4010 defer conn.Close()
4011 _, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n"))
4012 if err != nil {
4013 errCh <- err
4014 return
4015 }
4016
4017
4018 <-closeConn
4019 }()
4020 select {
4021 case err := <-readErrCh:
4022 if err == nil {
4023 t.Error("Read was nil. Expected error.")
4024 }
4025 case err := <-errCh:
4026 t.Error(err)
4027 case <-time.After(5 * time.Second):
4028 t.Error("timeout")
4029 }
4030 }
4031
4032
4033 func TestResponseWriterWriteString(t *testing.T) {
4034 okc := make(chan bool, 1)
4035 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4036 _, ok := w.(io.StringWriter)
4037 okc <- ok
4038 }))
4039 ht.rawResponse("GET / HTTP/1.0")
4040 select {
4041 case ok := <-okc:
4042 if !ok {
4043 t.Error("ResponseWriter did not implement io.StringWriter")
4044 }
4045 default:
4046 t.Error("handler was never called")
4047 }
4048 }
4049
4050 func TestAppendTime(t *testing.T) {
4051 var b [len(TimeFormat)]byte
4052 t1 := time.Date(2013, 9, 21, 15, 41, 0, 0, time.FixedZone("CEST", 2*60*60))
4053 res := ExportAppendTime(b[:0], t1)
4054 t2, err := ParseTime(string(res))
4055 if err != nil {
4056 t.Fatalf("Error parsing time: %s", err)
4057 }
4058 if !t1.Equal(t2) {
4059 t.Fatalf("Times differ; expected: %v, got %v (%s)", t1, t2, string(res))
4060 }
4061 }
4062
// TestServerConnState drives a test server through several connection
// lifecycles (keep-alive then close, hijack, hijack followed by a panic,
// immediate close, a bogus request, and an idle connection closed by the
// client) and checks the exact sequence of ConnState values reported for
// each.
func TestServerConnState(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	// handler maps a request path to the behavior the server should exhibit.
	handler := map[string]func(w ResponseWriter, r *Request){
		"/": func(w ResponseWriter, r *Request) {
			fmt.Fprintf(w, "Hello.")
		},
		"/close": func(w ResponseWriter, r *Request) {
			w.Header().Set("Connection", "close")
			fmt.Fprintf(w, "Hello.")
		},
		"/hijack": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
		},
		"/hijack-panic": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
			panic("intentional panic")
		},
	}

	// A stateLog records the states seen for one connection.
	type stateLog struct {
		active   net.Conn         // the connection being tracked, set on its StateNew
		got      []ConnState      // states observed so far
		want     []ConnState      // states expected
		complete chan<- struct{}  // closed once got is as long as want or has diverged from it
	}
	// activeLog holds at most one stateLog. The ConnState hook takes the log
	// out of the channel, updates it and puts it back, so whoever holds the
	// value has exclusive access to it.
	activeLog := make(chan *stateLog, 1)

	// wantLog installs a fresh stateLog expecting want, runs doRequests, and
	// waits for the log to complete (or for a timeout) before comparing the
	// observed states with the expected ones.
	wantLog := func(doRequests func(), want ...ConnState) {
		t.Helper()
		complete := make(chan struct{})
		activeLog <- &stateLog{want: want, complete: complete}

		doRequests()

		stateDelay := 5 * time.Second
		if deadline, ok := t.Deadline(); ok {
			// When the test run has a deadline, wait until just before it
			// instead of the fixed default, keeping a margin in hand for
			// cleanup.
			const arbitraryCleanupMargin = 1 * time.Second
			stateDelay = time.Until(deadline) - arbitraryCleanupMargin
		}
		timer := time.NewTimer(stateDelay)
		select {
		case <-timer.C:
			t.Errorf("Timed out after %v waiting for connection to change state.", stateDelay)
		case <-complete:
			timer.Stop()
		}
		sl := <-activeLog
		if !reflect.DeepEqual(sl.got, sl.want) {
			t.Errorf("Request(s) produced unexpected state sequence.\nGot:  %v\nWant: %v", sl.got, sl.want)
		}
		// sl is not put back, so activeLog is left empty: a ConnState
		// callback arriving now blocks on its receive until the next wantLog
		// call (or the deferred cleanup) supplies a new log.
	}

	ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		handler[r.URL.Path](w, r)
	}))
	defer func() {
		// Supply a log before closing; presumably so that any ConnState
		// callback made during Close has something to receive.
		activeLog <- &stateLog{}
		ts.Close()
	}()

	ts.Config.ErrorLog = log.New(io.Discard, "", 0)
	ts.Config.ConnState = func(c net.Conn, state ConnState) {
		if c == nil {
			t.Errorf("nil conn seen in state %s", state)
			return
		}
		sl := <-activeLog
		if sl.active == nil && state == StateNew {
			sl.active = c
		} else if sl.active != c {
			t.Errorf("unexpected conn in state %s", state)
			activeLog <- sl
			return
		}
		sl.got = append(sl.got, state)
		// Signal completion once enough states have been seen, or as soon as
		// the observed sequence stops being a prefix of the expected one.
		if sl.complete != nil && (len(sl.got) >= len(sl.want) || !reflect.DeepEqual(sl.got, sl.want[:len(sl.got)])) {
			close(sl.complete)
			sl.complete = nil
		}
		activeLog <- sl
	}

	ts.Start()
	c := ts.Client()

	// mustGet issues a GET for url with the given header name/value pairs
	// and reads the whole response body.
	mustGet := func(url string, headers ...string) {
		t.Helper()
		req, err := NewRequest("GET", url, nil)
		if err != nil {
			t.Fatal(err)
		}
		for len(headers) > 0 {
			req.Header.Add(headers[0], headers[1])
			headers = headers[2:]
		}
		res, err := c.Do(req)
		if err != nil {
			t.Errorf("Error fetching %s: %v", url, err)
			return
		}
		_, err = io.ReadAll(res.Body)
		defer res.Body.Close()
		if err != nil {
			t.Errorf("Error reading %s: %v", url, err)
		}
	}

	// Two requests on one connection; the handler asks for the close.
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL + "/close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	// Two requests on one connection; the client asks for the close.
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL+"/", "Connection", "close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	// A hijacked connection ends in StateHijacked.
	wantLog(func() {
		mustGet(ts.URL + "/hijack")
	}, StateNew, StateActive, StateHijacked)

	// The same holds when the handler panics after hijacking.
	wantLog(func() {
		mustGet(ts.URL + "/hijack-panic")
	}, StateNew, StateActive, StateHijacked)

	// A connection closed before any bytes are sent.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateClosed)

	// A request the server cannot parse.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		c.Read(make([]byte, 1))
		c.Close()
	}, StateNew, StateActive, StateClosed)

	// One complete request, after which the client closes the idle
	// connection.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		res, err := ReadResponse(bufio.NewReader(c), nil)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.Copy(io.Discard, res.Body); err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateActive, StateIdle, StateClosed)
}
4242
4243 func TestServerKeepAlivesEnabled(t *testing.T) {
4244 defer afterTest(t)
4245 ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
4246 ts.Config.SetKeepAlivesEnabled(false)
4247 ts.Start()
4248 defer ts.Close()
4249 res, err := Get(ts.URL)
4250 if err != nil {
4251 t.Fatal(err)
4252 }
4253 defer res.Body.Close()
4254 if !res.Close {
4255 t.Errorf("Body.Close == false; want true")
4256 }
4257 }
4258
4259
4260 func TestServerEmptyBodyRace_h1(t *testing.T) { testServerEmptyBodyRace(t, h1Mode) }
4261 func TestServerEmptyBodyRace_h2(t *testing.T) { testServerEmptyBodyRace(t, h2Mode) }
4262 func testServerEmptyBodyRace(t *testing.T, h2 bool) {
4263 setParallel(t)
4264 defer afterTest(t)
4265 var n int32
4266 cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) {
4267 atomic.AddInt32(&n, 1)
4268 }), optQuietLog)
4269 defer cst.close()
4270 var wg sync.WaitGroup
4271 const reqs = 20
4272 for i := 0; i < reqs; i++ {
4273 wg.Add(1)
4274 go func() {
4275 defer wg.Done()
4276 res, err := cst.c.Get(cst.ts.URL)
4277 if err != nil {
4278
4279
4280 time.Sleep(10 * time.Millisecond)
4281 res, err = cst.c.Get(cst.ts.URL)
4282 if err != nil {
4283 t.Error(err)
4284 return
4285 }
4286 }
4287 defer res.Body.Close()
4288 _, err = io.Copy(io.Discard, res.Body)
4289 if err != nil {
4290 t.Error(err)
4291 return
4292 }
4293 }()
4294 }
4295 wg.Wait()
4296 if got := atomic.LoadInt32(&n); got != reqs {
4297 t.Errorf("handler ran %d times; want %d", got, reqs)
4298 }
4299 }
4300
4301 func TestServerConnStateNew(t *testing.T) {
4302 sawNew := false
4303 srv := &Server{
4304 ConnState: func(c net.Conn, state ConnState) {
4305 if state == StateNew {
4306 sawNew = true
4307 }
4308 },
4309 Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}),
4310 }
4311 srv.Serve(&oneConnListener{
4312 conn: &rwTestConn{
4313 Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"),
4314 Writer: io.Discard,
4315 },
4316 })
4317 if !sawNew {
4318 t.Error("StateNew not seen")
4319 }
4320 }
4321
4322 type closeWriteTestConn struct {
4323 rwTestConn
4324 didCloseWrite bool
4325 }
4326
4327 func (c *closeWriteTestConn) CloseWrite() error {
4328 c.didCloseWrite = true
4329 return nil
4330 }
4331
4332 func TestCloseWrite(t *testing.T) {
4333 setParallel(t)
4334 var srv Server
4335 var testConn closeWriteTestConn
4336 c := ExportServerNewConn(&srv, &testConn)
4337 ExportCloseWriteAndWait(c)
4338 if !testConn.didCloseWrite {
4339 t.Error("didn't see CloseWrite call")
4340 }
4341 }
4342
4343
4344
4345
4346
4347
4348
4349
4350 func TestServerFlushAndHijack(t *testing.T) {
4351 defer afterTest(t)
4352 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
4353 io.WriteString(w, "Hello, ")
4354 w.(Flusher).Flush()
4355 conn, buf, _ := w.(Hijacker).Hijack()
4356 buf.WriteString("6\r\nworld!\r\n0\r\n\r\n")
4357 if err := buf.Flush(); err != nil {
4358 t.Error(err)
4359 }
4360 if err := conn.Close(); err != nil {
4361 t.Error(err)
4362 }
4363 }))
4364 defer ts.Close()
4365 res, err := Get(ts.URL)
4366 if err != nil {
4367 t.Fatal(err)
4368 }
4369 defer res.Body.Close()
4370 all, err := io.ReadAll(res.Body)
4371 if err != nil {
4372 t.Fatal(err)
4373 }
4374 if want := "Hello, world!"; string(all) != want {
4375 t.Errorf("Got %q; want %q", all, want)
4376 }
4377 }
4378
4379
4380
4381
4382
4383
4384
4385 func TestServerKeepAliveAfterWriteError(t *testing.T) {
4386 if testing.Short() {
4387 t.Skip("skipping in -short mode")
4388 }
4389 defer afterTest(t)
4390 const numReq = 3
4391 addrc := make(chan string, numReq)
4392 ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
4393 addrc <- r.RemoteAddr
4394 time.Sleep(500 * time.Millisecond)
4395 w.(Flusher).Flush()
4396 }))
4397 ts.Config.WriteTimeout = 250 * time.Millisecond
4398 ts.Start()
4399 defer ts.Close()
4400
4401 errc := make(chan error, numReq)
4402 go func() {
4403 defer close(errc)
4404 for i := 0; i < numReq; i++ {
4405 res, err := Get(ts.URL)
4406 if res != nil {
4407 res.Body.Close()
4408 }
4409 errc <- err
4410 }
4411 }()
4412
4413 timeout := time.NewTimer(numReq * 2 * time.Second)
4414 defer timeout.Stop()
4415 addrSeen := map[string]bool{}
4416 numOkay := 0
4417 for {
4418 select {
4419 case v := <-addrc:
4420 addrSeen[v] = true
4421 case err, ok := <-errc:
4422 if !ok {
4423 if len(addrSeen) != numReq {
4424 t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq)
4425 }
4426 if numOkay != 0 {
4427 t.Errorf("got %d successful client requests; want 0", numOkay)
4428 }
4429 return
4430 }
4431 if err == nil {
4432 numOkay++
4433 }
4434 case <-timeout.C:
4435 t.Fatal("timeout waiting for requests to complete")
4436 }
4437 }
4438 }
4439
4440
4441
4442 func TestNoContentLengthIfTransferEncoding(t *testing.T) {
4443 defer afterTest(t)
4444 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
4445 w.Header().Set("Transfer-Encoding", "foo")
4446 io.WriteString(w, "<html>")
4447 }))
4448 defer ts.Close()
4449 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4450 if err != nil {
4451 t.Fatalf("Dial: %v", err)
4452 }
4453 defer c.Close()
4454 if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
4455 t.Fatal(err)
4456 }
4457 bs := bufio.NewScanner(c)
4458 var got bytes.Buffer
4459 for bs.Scan() {
4460 if strings.TrimSpace(bs.Text()) == "" {
4461 break
4462 }
4463 got.WriteString(bs.Text())
4464 got.WriteByte('\n')
4465 }
4466 if err := bs.Err(); err != nil {
4467 t.Fatal(err)
4468 }
4469 if strings.Contains(got.String(), "Content-Length") {
4470 t.Errorf("Unexpected Content-Length in response headers: %s", got.String())
4471 }
4472 if strings.Contains(got.String(), "Content-Type") {
4473 t.Errorf("Unexpected Content-Type in response headers: %s", got.String())
4474 }
4475 }
4476
4477
4478
4479 func TestTolerateCRLFBeforeRequestLine(t *testing.T) {
4480 req := []byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" +
4481 "\r\n\r\n" +
4482 "GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n")
4483 var buf bytes.Buffer
4484 conn := &rwTestConn{
4485 Reader: bytes.NewReader(req),
4486 Writer: &buf,
4487 closec: make(chan bool, 1),
4488 }
4489 ln := &oneConnListener{conn: conn}
4490 numReq := 0
4491 go Serve(ln, HandlerFunc(func(rw ResponseWriter, r *Request) {
4492 numReq++
4493 }))
4494 <-conn.closec
4495 if numReq != 2 {
4496 t.Errorf("num requests = %d; want 2", numReq)
4497 t.Logf("Res: %s", buf.Bytes())
4498 }
4499 }
4500
4501 func TestIssue13893_Expect100(t *testing.T) {
4502
4503 req := reqBytes(`PUT /readbody HTTP/1.1
4504 User-Agent: PycURL/7.22.0
4505 Host: 127.0.0.1:9000
4506 Accept: */*
4507 Expect: 100-continue
4508 Content-Length: 10
4509
4510 HelloWorld
4511
4512 `)
4513 var buf bytes.Buffer
4514 conn := &rwTestConn{
4515 Reader: bytes.NewReader(req),
4516 Writer: &buf,
4517 closec: make(chan bool, 1),
4518 }
4519 ln := &oneConnListener{conn: conn}
4520 go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
4521 if _, ok := r.Header["Expect"]; !ok {
4522 t.Error("Expect header should not be filtered out")
4523 }
4524 }))
4525 <-conn.closec
4526 }
4527
4528 func TestIssue11549_Expect100(t *testing.T) {
4529 req := reqBytes(`PUT /readbody HTTP/1.1
4530 User-Agent: PycURL/7.22.0
4531 Host: 127.0.0.1:9000
4532 Accept: */*
4533 Expect: 100-continue
4534 Content-Length: 10
4535
4536 HelloWorldPUT /noreadbody HTTP/1.1
4537 User-Agent: PycURL/7.22.0
4538 Host: 127.0.0.1:9000
4539 Accept: */*
4540 Expect: 100-continue
4541 Content-Length: 10
4542
4543 GET /should-be-ignored HTTP/1.1
4544 Host: foo
4545
4546 `)
4547 var buf bytes.Buffer
4548 conn := &rwTestConn{
4549 Reader: bytes.NewReader(req),
4550 Writer: &buf,
4551 closec: make(chan bool, 1),
4552 }
4553 ln := &oneConnListener{conn: conn}
4554 numReq := 0
4555 go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
4556 numReq++
4557 if r.URL.Path == "/readbody" {
4558 io.ReadAll(r.Body)
4559 }
4560 io.WriteString(w, "Hello world!")
4561 }))
4562 <-conn.closec
4563 if numReq != 2 {
4564 t.Errorf("num requests = %d; want 2", numReq)
4565 }
4566 if !strings.Contains(buf.String(), "Connection: close\r\n") {
4567 t.Errorf("expected 'Connection: close' in response; got: %s", buf.String())
4568 }
4569 }
4570
4571
4572
4573 func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
4574 setParallel(t)
4575 conn := &testConn{closec: make(chan bool)}
4576 conn.readBuf.Write([]byte(fmt.Sprintf(
4577 "POST / HTTP/1.1\r\n" +
4578 "Host: test\r\n" +
4579 "Content-Length: 9999999999\r\n" +
4580 "\r\n" + strings.Repeat("a", 1<<20))))
4581
4582 ls := &oneConnListener{conn}
4583 var inHandlerLen int
4584 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
4585 inHandlerLen = conn.readBuf.Len()
4586 rw.WriteHeader(404)
4587 }))
4588 <-conn.closec
4589 afterHandlerLen := conn.readBuf.Len()
4590
4591 if afterHandlerLen != inHandlerLen {
4592 t.Errorf("unexpected implicit read. Read buffer went from %d -> %d", inHandlerLen, afterHandlerLen)
4593 }
4594 }
4595
4596 func TestHandlerSetsBodyNil_h1(t *testing.T) { testHandlerSetsBodyNil(t, h1Mode) }
4597 func TestHandlerSetsBodyNil_h2(t *testing.T) { testHandlerSetsBodyNil(t, h2Mode) }
4598 func testHandlerSetsBodyNil(t *testing.T, h2 bool) {
4599 defer afterTest(t)
4600 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
4601 r.Body = nil
4602 fmt.Fprintf(w, "%v", r.RemoteAddr)
4603 }))
4604 defer cst.close()
4605 get := func() string {
4606 res, err := cst.c.Get(cst.ts.URL)
4607 if err != nil {
4608 t.Fatal(err)
4609 }
4610 defer res.Body.Close()
4611 slurp, err := io.ReadAll(res.Body)
4612 if err != nil {
4613 t.Fatal(err)
4614 }
4615 return string(slurp)
4616 }
4617 a, b := get(), get()
4618 if a != b {
4619 t.Errorf("Failed to reuse connections between requests: %v vs %v", a, b)
4620 }
4621 }
4622
4623
4624
4625 func TestServerValidatesHostHeader(t *testing.T) {
4626 tests := []struct {
4627 proto string
4628 host string
4629 want int
4630 }{
4631 {"HTTP/0.9", "", 505},
4632
4633 {"HTTP/1.1", "", 400},
4634 {"HTTP/1.1", "Host: \r\n", 200},
4635 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
4636 {"HTTP/1.1", "Host: foo.com\r\n", 200},
4637 {"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
4638 {"HTTP/1.1", "Host: foo.com:80\r\n", 200},
4639 {"HTTP/1.1", "Host: ::1\r\n", 200},
4640 {"HTTP/1.1", "Host: [::1]\r\n", 200},
4641 {"HTTP/1.1", "Host: [::1]:80\r\n", 200},
4642 {"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
4643 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
4644 {"HTTP/1.1", "Host: \x06\r\n", 400},
4645 {"HTTP/1.1", "Host: \xff\r\n", 400},
4646 {"HTTP/1.1", "Host: {\r\n", 400},
4647 {"HTTP/1.1", "Host: }\r\n", 400},
4648 {"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400},
4649
4650
4651
4652 {"HTTP/1.0", "", 200},
4653 {"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
4654 {"HTTP/1.0", "Host: \xff\r\n", 400},
4655
4656
4657 {"PRI * HTTP/2.0", "", 200},
4658
4659
4660 {"CONNECT golang.org:443 HTTP/1.1", "", 200},
4661
4662
4663 {"PRI / HTTP/2.0", "", 505},
4664 {"GET / HTTP/2.0", "", 505},
4665 {"GET / HTTP/3.0", "", 505},
4666 }
4667 for _, tt := range tests {
4668 conn := &testConn{closec: make(chan bool, 1)}
4669 methodTarget := "GET / "
4670 if !strings.HasPrefix(tt.proto, "HTTP/") {
4671 methodTarget = ""
4672 }
4673 io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n")
4674
4675 ln := &oneConnListener{conn}
4676 srv := Server{
4677 ErrorLog: quietLog,
4678 Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
4679 }
4680 go srv.Serve(ln)
4681 <-conn.closec
4682 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
4683 if err != nil {
4684 t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, res)
4685 continue
4686 }
4687 if res.StatusCode != tt.want {
4688 t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
4689 }
4690 }
4691 }
4692
4693 func TestServerHandlersCanHandleH2PRI(t *testing.T) {
4694 const upgradeResponse = "upgrade here"
4695 defer afterTest(t)
4696 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
4697 conn, br, err := w.(Hijacker).Hijack()
4698 if err != nil {
4699 t.Error(err)
4700 return
4701 }
4702 defer conn.Close()
4703 if r.Method != "PRI" || r.RequestURI != "*" {
4704 t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI)
4705 return
4706 }
4707 if !r.Close {
4708 t.Errorf("Request.Close = true; want false")
4709 }
4710 const want = "SM\r\n\r\n"
4711 buf := make([]byte, len(want))
4712 n, err := io.ReadFull(br, buf)
4713 if err != nil || string(buf[:n]) != want {
4714 t.Errorf("Read = %v, %v (%q), want %q", n, err, buf[:n], want)
4715 return
4716 }
4717 io.WriteString(conn, upgradeResponse)
4718 }))
4719 defer ts.Close()
4720
4721 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4722 if err != nil {
4723 t.Fatalf("Dial: %v", err)
4724 }
4725 defer c.Close()
4726 io.WriteString(c, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
4727 slurp, err := io.ReadAll(c)
4728 if err != nil {
4729 t.Fatal(err)
4730 }
4731 if string(slurp) != upgradeResponse {
4732 t.Errorf("Handler response = %q; want %q", slurp, upgradeResponse)
4733 }
4734 }
4735
4736
4737
4738 func TestServerValidatesHeaders(t *testing.T) {
4739 setParallel(t)
4740 tests := []struct {
4741 header string
4742 want int
4743 }{
4744 {"", 200},
4745 {"Foo: bar\r\n", 200},
4746 {"X-Foo: bar\r\n", 200},
4747 {"Foo: a space\r\n", 200},
4748
4749 {"A space: foo\r\n", 400},
4750 {"foo\xffbar: foo\r\n", 400},
4751 {"foo\x00bar: foo\r\n", 400},
4752 {"Foo: " + strings.Repeat("x", 1<<21) + "\r\n", 431},
4753
4754
4755 {"Foo : bar\r\n", 400},
4756 {"Foo\t: bar\r\n", 400},
4757
4758 {"foo: foo foo\r\n", 200},
4759 {"foo: foo\tfoo\r\n", 200},
4760 {"foo: foo\x00foo\r\n", 400},
4761 {"foo: foo\x7ffoo\r\n", 400},
4762 {"foo: foo\xfffoo\r\n", 200},
4763 }
4764 for _, tt := range tests {
4765 conn := &testConn{closec: make(chan bool, 1)}
4766 io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n")
4767
4768 ln := &oneConnListener{conn}
4769 srv := Server{
4770 ErrorLog: quietLog,
4771 Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
4772 }
4773 go srv.Serve(ln)
4774 <-conn.closec
4775 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
4776 if err != nil {
4777 t.Errorf("For %q, ReadResponse: %v", tt.header, res)
4778 continue
4779 }
4780 if res.StatusCode != tt.want {
4781 t.Errorf("For %q, Status = %d; want %d", tt.header, res.StatusCode, tt.want)
4782 }
4783 }
4784 }
4785
4786 func TestServerRequestContextCancel_ServeHTTPDone_h1(t *testing.T) {
4787 testServerRequestContextCancel_ServeHTTPDone(t, h1Mode)
4788 }
4789 func TestServerRequestContextCancel_ServeHTTPDone_h2(t *testing.T) {
4790 testServerRequestContextCancel_ServeHTTPDone(t, h2Mode)
4791 }
4792 func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) {
4793 setParallel(t)
4794 defer afterTest(t)
4795 ctxc := make(chan context.Context, 1)
4796 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
4797 ctx := r.Context()
4798 select {
4799 case <-ctx.Done():
4800 t.Error("should not be Done in ServeHTTP")
4801 default:
4802 }
4803 ctxc <- ctx
4804 }))
4805 defer cst.close()
4806 res, err := cst.c.Get(cst.ts.URL)
4807 if err != nil {
4808 t.Fatal(err)
4809 }
4810 res.Body.Close()
4811 ctx := <-ctxc
4812 select {
4813 case <-ctx.Done():
4814 default:
4815 t.Error("context should be done after ServeHTTP completes")
4816 }
4817 }
4818
4819
4820
4821
4822
4823 func TestServerRequestContextCancel_ConnClose(t *testing.T) {
4824 setParallel(t)
4825 defer afterTest(t)
4826 inHandler := make(chan struct{})
4827 handlerDone := make(chan struct{})
4828 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
4829 close(inHandler)
4830 select {
4831 case <-r.Context().Done():
4832 case <-time.After(3 * time.Second):
4833 t.Errorf("timeout waiting for context to be done")
4834 }
4835 close(handlerDone)
4836 }))
4837 defer ts.Close()
4838 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4839 if err != nil {
4840 t.Fatal(err)
4841 }
4842 defer c.Close()
4843 io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
4844 select {
4845 case <-inHandler:
4846 case <-time.After(3 * time.Second):
4847 t.Fatalf("timeout waiting to see ServeHTTP get called")
4848 }
4849 c.Close()
4850
4851 select {
4852 case <-handlerDone:
4853 case <-time.After(4 * time.Second):
4854 t.Fatalf("timeout waiting to see ServeHTTP exit")
4855 }
4856 }
4857
// TestServerContext_ServerContextKey_h1 runs the ServerContextKey test with
// h1Mode.
func TestServerContext_ServerContextKey_h1(t *testing.T) {
	testServerContext_ServerContextKey(t, h1Mode)
}
// TestServerContext_ServerContextKey_h2 runs the ServerContextKey test with
// h2Mode.
func TestServerContext_ServerContextKey_h2(t *testing.T) {
	testServerContext_ServerContextKey(t, h2Mode)
}
4864 func testServerContext_ServerContextKey(t *testing.T, h2 bool) {
4865 setParallel(t)
4866 defer afterTest(t)
4867 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
4868 ctx := r.Context()
4869 got := ctx.Value(ServerContextKey)
4870 if _, ok := got.(*Server); !ok {
4871 t.Errorf("context value = %T; want *http.Server", got)
4872 }
4873 }))
4874 defer cst.close()
4875 res, err := cst.c.Get(cst.ts.URL)
4876 if err != nil {
4877 t.Fatal(err)
4878 }
4879 res.Body.Close()
4880 }
4881
// TestServerContext_LocalAddrContextKey_h1 runs the LocalAddrContextKey test
// with h1Mode.
func TestServerContext_LocalAddrContextKey_h1(t *testing.T) {
	testServerContext_LocalAddrContextKey(t, h1Mode)
}
// TestServerContext_LocalAddrContextKey_h2 runs the LocalAddrContextKey test
// with h2Mode.
func TestServerContext_LocalAddrContextKey_h2(t *testing.T) {
	testServerContext_LocalAddrContextKey(t, h2Mode)
}
4888 func testServerContext_LocalAddrContextKey(t *testing.T, h2 bool) {
4889 setParallel(t)
4890 defer afterTest(t)
4891 ch := make(chan interface{}, 1)
4892 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
4893 ch <- r.Context().Value(LocalAddrContextKey)
4894 }))
4895 defer cst.close()
4896 if _, err := cst.c.Head(cst.ts.URL); err != nil {
4897 t.Fatal(err)
4898 }
4899
4900 host := cst.ts.Listener.Addr().String()
4901 select {
4902 case got := <-ch:
4903 if addr, ok := got.(net.Addr); !ok {
4904 t.Errorf("local addr value = %T; want net.Addr", got)
4905 } else if fmt.Sprint(addr) != host {
4906 t.Errorf("local addr = %v; want %v", addr, host)
4907 }
4908 case <-time.After(5 * time.Second):
4909 t.Error("timed out")
4910 }
4911 }
4912
4913
4914 func TestHandlerSetTransferEncodingChunked(t *testing.T) {
4915 setParallel(t)
4916 defer afterTest(t)
4917 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4918 w.Header().Set("Transfer-Encoding", "chunked")
4919 w.Write([]byte("hello"))
4920 }))
4921 resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
4922 const hdr = "Transfer-Encoding: chunked"
4923 if n := strings.Count(resp, hdr); n != 1 {
4924 t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
4925 }
4926 }
4927
4928
4929 func TestHandlerSetTransferEncodingGzip(t *testing.T) {
4930 setParallel(t)
4931 defer afterTest(t)
4932 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4933 w.Header().Set("Transfer-Encoding", "gzip")
4934 gz := gzip.NewWriter(w)
4935 gz.Write([]byte("hello"))
4936 gz.Close()
4937 }))
4938 resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
4939 for _, v := range []string{"gzip", "chunked"} {
4940 hdr := "Transfer-Encoding: " + v
4941 if n := strings.Count(resp, hdr); n != 1 {
4942 t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
4943 }
4944 }
4945 }
4946
// BenchmarkClientServer measures sequential GET round trips against an
// httptest server using the package-level Get, reading and verifying the
// whole body each time. Server setup is excluded from the timed region.
func BenchmarkClientServer(b *testing.B) {
	b.ReportAllocs()
	b.StopTimer()
	hello := func(rw ResponseWriter, r *Request) {
		fmt.Fprintf(rw, "Hello world.\n")
	}
	ts := httptest.NewServer(HandlerFunc(hello))
	defer ts.Close()
	b.StartTimer()

	for iter := 0; iter < b.N; iter++ {
		resp, err := Get(ts.URL)
		if err != nil {
			b.Fatal("Get:", err)
		}
		raw, err := io.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			b.Fatal("ReadAll:", err)
		}
		if got := string(raw); got != "Hello world.\n" {
			b.Fatal("Got body:", got)
		}
	}

	b.StopTimer()
}
4974
// BenchmarkClientServerParallel4 runs benchmarkClientServerParallel with a
// parallelism of 4 and TLS disabled.
func BenchmarkClientServerParallel4(b *testing.B) {
	benchmarkClientServerParallel(b, 4, false)
}
4978
// BenchmarkClientServerParallel64 runs benchmarkClientServerParallel with a
// parallelism of 64 and TLS disabled.
func BenchmarkClientServerParallel64(b *testing.B) {
	benchmarkClientServerParallel(b, 64, false)
}
4982
// BenchmarkClientServerParallelTLS4 runs benchmarkClientServerParallel with a
// parallelism of 4 and TLS enabled.
func BenchmarkClientServerParallelTLS4(b *testing.B) {
	benchmarkClientServerParallel(b, 4, true)
}
4986
// BenchmarkClientServerParallelTLS64 runs benchmarkClientServerParallel with
// a parallelism of 64 and TLS enabled.
func BenchmarkClientServerParallelTLS64(b *testing.B) {
	benchmarkClientServerParallel(b, 64, true)
}
4990
// benchmarkClientServerParallel measures concurrent GET round trips against
// an httptest server. parallelism is passed to b.SetParallelism, and useTLS
// selects StartTLS instead of Start for the server.
//
// Request and read errors are logged and the iteration is skipped; an
// unexpected body panics.
func benchmarkClientServerParallel(b *testing.B, parallelism int, useTLS bool) {
	b.ReportAllocs()
	ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
		fmt.Fprintf(rw, "Hello world.\n")
	}))
	start := ts.Start
	if useTLS {
		start = ts.StartTLS
	}
	start()
	defer ts.Close()

	b.ResetTimer()
	b.SetParallelism(parallelism)
	b.RunParallel(func(pb *testing.PB) {
		client := ts.Client()
		for pb.Next() {
			resp, err := client.Get(ts.URL)
			if err != nil {
				b.Logf("Get: %v", err)
				continue
			}
			raw, err := io.ReadAll(resp.Body)
			resp.Body.Close()
			if err != nil {
				b.Logf("ReadAll: %v", err)
				continue
			}
			if got := string(raw); got != "Hello world.\n" {
				panic("Got body: " + got)
			}
		}
	})
}
5025
5026
5027
5028
5029
5030
5031
5032
5033
// BenchmarkServer measures the server side by driving it from a separate
// client process.
//
// The parent process starts an httptest server and re-executes the test
// binary with TEST_BENCH_SERVER_URL and TEST_BENCH_CLIENT_N set. When the
// binary sees TEST_BENCH_SERVER_URL it acts as the client: it issues
// TEST_BENCH_CLIENT_N GET requests, verifies each body, and exits.
func BenchmarkServer(b *testing.B) {
	b.ReportAllocs()

	// Child (client) mode.
	if target := os.Getenv("TEST_BENCH_SERVER_URL"); target != "" {
		total, err := strconv.Atoi(os.Getenv("TEST_BENCH_CLIENT_N"))
		if err != nil {
			panic(err)
		}
		for done := 0; done < total; done++ {
			resp, err := Get(target)
			if err != nil {
				log.Panicf("Get: %v", err)
			}
			raw, err := io.ReadAll(resp.Body)
			resp.Body.Close()
			if err != nil {
				log.Panicf("ReadAll: %v", err)
			}
			if got := string(raw); got != "Hello world.\n" {
				log.Panicf("Got body: %q", got)
			}
		}
		os.Exit(0)
		return
	}

	// Parent (server) mode.
	payload := []byte("Hello world.\n")
	b.StopTimer()
	ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
		rw.Header().Set("Content-Type", "text/html; charset=utf-8")
		rw.Write(payload)
	}))
	defer ts.Close()
	b.StartTimer()

	child := exec.Command(os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkServer$")
	childVars := []string{
		fmt.Sprintf("TEST_BENCH_CLIENT_N=%d", b.N),
		fmt.Sprintf("TEST_BENCH_SERVER_URL=%s", ts.URL),
	}
	child.Env = append(childVars, os.Environ()...)
	if out, err := child.CombinedOutput(); err != nil {
		b.Errorf("Test failure: %v, with output: %s", err, out)
	}
}
5080
5081
// getNoBody issues a GET for urlStr with the package-level Get and closes
// the response body before returning. On failure it returns a nil response
// together with the error from Get.
func getNoBody(urlStr string) (*Response, error) {
	resp, err := Get(urlStr)
	if err == nil {
		resp.Body.Close()
		return resp, nil
	}
	return nil, err
}
5090
5091
5092
5093 func BenchmarkClient(b *testing.B) {
5094 b.ReportAllocs()
5095 b.StopTimer()
5096 defer afterTest(b)
5097
5098 var data = []byte("Hello world.\n")
5099 if server := os.Getenv("TEST_BENCH_SERVER"); server != "" {
5100
5101 port := os.Getenv("TEST_BENCH_SERVER_PORT")
5102 if port == "" {
5103 port = "0"
5104 }
5105 ln, err := net.Listen("tcp", "localhost:"+port)
5106 if err != nil {
5107 fmt.Fprintln(os.Stderr, err.Error())
5108 os.Exit(1)
5109 }
5110 fmt.Println(ln.Addr().String())
5111 HandleFunc("/", func(w ResponseWriter, r *Request) {
5112 r.ParseForm()
5113 if r.Form.Get("stop") != "" {
5114 os.Exit(0)
5115 }
5116 w.Header().Set("Content-Type", "text/html; charset=utf-8")
5117 w.Write(data)
5118 })
5119 var srv Server
5120 log.Fatal(srv.Serve(ln))
5121 }
5122
5123
5124 cmd := exec.Command(os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkClient$")
5125 cmd.Env = append(os.Environ(), "TEST_BENCH_SERVER=yes")
5126 cmd.Stderr = os.Stderr
5127 stdout, err := cmd.StdoutPipe()
5128 if err != nil {
5129 b.Fatal(err)
5130 }
5131 if err := cmd.Start(); err != nil {
5132 b.Fatalf("subprocess failed to start: %v", err)
5133 }
5134 defer cmd.Process.Kill()
5135
5136
5137
5138 timer := time.AfterFunc(10*time.Second, func() {
5139 cmd.Process.Kill()
5140 })
5141 defer timer.Stop()
5142 bs := bufio.NewScanner(stdout)
5143 if !bs.Scan() {
5144 b.Fatalf("failed to read listening URL from child: %v", bs.Err())
5145 }
5146 url := "http://" + strings.TrimSpace(bs.Text()) + "/"
5147 timer.Stop()
5148 if _, err := getNoBody(url); err != nil {
5149 b.Fatalf("initial probe of child process failed: %v", err)
5150 }
5151
5152 done := make(chan error)
5153 stop := make(chan struct{})
5154 defer close(stop)
5155 go func() {
5156 select {
5157 case <-stop:
5158 return
5159 case done <- cmd.Wait():
5160 }
5161 }()
5162
5163
5164 b.StartTimer()
5165 for i := 0; i < b.N; i++ {
5166 res, err := Get(url)
5167 if err != nil {
5168 b.Fatalf("Get: %v", err)
5169 }
5170 body, err := io.ReadAll(res.Body)
5171 res.Body.Close()
5172 if err != nil {
5173 b.Fatalf("ReadAll: %v", err)
5174 }
5175 if !bytes.Equal(body, data) {
5176 b.Fatalf("Got body: %q", body)
5177 }
5178 }
5179 b.StopTimer()
5180
5181
5182 getNoBody(url + "?stop=yes")
5183 select {
5184 case err := <-done:
5185 if err != nil {
5186 b.Fatalf("subprocess failed: %v", err)
5187 }
5188 case <-time.After(5 * time.Second):
5189 b.Fatalf("subprocess did not stop")
5190 }
5191 }
5192
5193 func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) {
5194 b.ReportAllocs()
5195 req := reqBytes(`GET / HTTP/1.0
5196 Host: golang.org
5197 Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
5198 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
5199 Accept-Encoding: gzip,deflate,sdch
5200 Accept-Language: en-US,en;q=0.8
5201 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
5202 `)
5203 res := []byte("Hello world!\n")
5204
5205 conn := &testConn{
5206
5207
5208 closec: make(chan bool, 1),
5209 }
5210 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5211 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5212 rw.Write(res)
5213 })
5214 ln := new(oneConnListener)
5215 for i := 0; i < b.N; i++ {
5216 conn.readBuf.Reset()
5217 conn.writeBuf.Reset()
5218 conn.readBuf.Write(req)
5219 ln.conn = conn
5220 Serve(ln, handler)
5221 <-conn.closec
5222 }
5223 }
5224
5225
// repeatReader is an io.Reader that yields content count times in a row and
// then reports io.EOF.
type repeatReader struct {
	content []byte // bytes returned on each pass
	count   int    // passes still to be delivered
	off     int    // read position within the current pass
}

// Read copies the unread remainder of the current pass into p. A single
// call never crosses a pass boundary; when a pass is completed the offset
// rewinds and count is decremented. Once count is exhausted Read returns
// (0, io.EOF).
func (r *repeatReader) Read(p []byte) (int, error) {
	if r.count < 1 {
		return 0, io.EOF
	}
	copied := copy(p, r.content[r.off:])
	if r.off += copied; r.off == len(r.content) {
		r.off = 0
		r.count--
	}
	return copied, nil
}
5244
5245 func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) {
5246 b.ReportAllocs()
5247
5248 req := reqBytes(`GET / HTTP/1.1
5249 Host: golang.org
5250 Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
5251 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
5252 Accept-Encoding: gzip,deflate,sdch
5253 Accept-Language: en-US,en;q=0.8
5254 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
5255 `)
5256 res := []byte("Hello world!\n")
5257
5258 conn := &rwTestConn{
5259 Reader: &repeatReader{content: req, count: b.N},
5260 Writer: io.Discard,
5261 closec: make(chan bool, 1),
5262 }
5263 handled := 0
5264 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5265 handled++
5266 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5267 rw.Write(res)
5268 })
5269 ln := &oneConnListener{conn: conn}
5270 go Serve(ln, handler)
5271 <-conn.closec
5272 if b.N != handled {
5273 b.Errorf("b.N=%d but handled %d", b.N, handled)
5274 }
5275 }
5276
5277
5278
5279 func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) {
5280 b.ReportAllocs()
5281
5282 req := reqBytes(`GET / HTTP/1.1
5283 Host: golang.org
5284 `)
5285 res := []byte("Hello world!\n")
5286
5287 conn := &rwTestConn{
5288 Reader: &repeatReader{content: req, count: b.N},
5289 Writer: io.Discard,
5290 closec: make(chan bool, 1),
5291 }
5292 handled := 0
5293 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5294 handled++
5295 rw.Write(res)
5296 })
5297 ln := &oneConnListener{conn: conn}
5298 go Serve(ln, handler)
5299 <-conn.closec
5300 if b.N != handled {
5301 b.Errorf("b.N=%d but handled %d", b.N, handled)
5302 }
5303 }
5304
// someResponse is the HTML fragment from which the benchmark body is built.
const someResponse = "<html>some response</html>"

// response holds as many whole copies of someResponse as fit in 2 KiB
// (2<<10 bytes); the division truncates, so it may be slightly shorter.
var response = bytes.Repeat([]byte(someResponse), (2<<10)/len(someResponse))
5309
5310
5311 func BenchmarkServerHandlerTypeLen(b *testing.B) {
5312 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5313 w.Header().Set("Content-Type", "text/html")
5314 w.Header().Set("Content-Length", strconv.Itoa(len(response)))
5315 w.Write(response)
5316 }))
5317 }
5318
5319
5320 func BenchmarkServerHandlerNoLen(b *testing.B) {
5321 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5322 w.Header().Set("Content-Type", "text/html")
5323 w.Write(response)
5324 }))
5325 }
5326
5327
5328 func BenchmarkServerHandlerNoType(b *testing.B) {
5329 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5330 w.Header().Set("Content-Length", strconv.Itoa(len(response)))
5331 w.Write(response)
5332 }))
5333 }
5334
5335
5336 func BenchmarkServerHandlerNoHeader(b *testing.B) {
5337 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5338 w.Write(response)
5339 }))
5340 }
5341
5342 func benchmarkHandler(b *testing.B, h Handler) {
5343 b.ReportAllocs()
5344 req := reqBytes(`GET / HTTP/1.1
5345 Host: golang.org
5346 `)
5347 conn := &rwTestConn{
5348 Reader: &repeatReader{content: req, count: b.N},
5349 Writer: io.Discard,
5350 closec: make(chan bool, 1),
5351 }
5352 handled := 0
5353 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5354 handled++
5355 h.ServeHTTP(rw, r)
5356 })
5357 ln := &oneConnListener{conn: conn}
5358 go Serve(ln, handler)
5359 <-conn.closec
5360 if b.N != handled {
5361 b.Errorf("b.N=%d but handled %d", b.N, handled)
5362 }
5363 }
5364
5365 func BenchmarkServerHijack(b *testing.B) {
5366 b.ReportAllocs()
5367 req := reqBytes(`GET / HTTP/1.1
5368 Host: golang.org
5369 `)
5370 h := HandlerFunc(func(w ResponseWriter, r *Request) {
5371 conn, _, err := w.(Hijacker).Hijack()
5372 if err != nil {
5373 panic(err)
5374 }
5375 conn.Close()
5376 })
5377 conn := &rwTestConn{
5378 Writer: io.Discard,
5379 closec: make(chan bool, 1),
5380 }
5381 ln := &oneConnListener{conn: conn}
5382 for i := 0; i < b.N; i++ {
5383 conn.Reader = bytes.NewReader(req)
5384 ln.conn = conn
5385 Serve(ln, h)
5386 <-conn.closec
5387 }
5388 }
5389
// BenchmarkCloseNotifier measures how long it takes a handler blocked on
// CloseNotify to observe a client disconnect. Each iteration dials the
// server, sends one request, closes the connection, and waits up to five
// seconds for the handler to report the close.
func BenchmarkCloseNotifier(b *testing.B) {
	b.ReportAllocs()
	b.StopTimer()
	closeSeen := make(chan bool)
	ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
		<-rw.(CloseNotifier).CloseNotify()
		closeSeen <- true
	}))
	defer ts.Close()
	// One timer is reused across iterations rather than allocating a new
	// one each time.
	deadline := time.NewTimer(5 * time.Second)
	defer deadline.Stop()
	b.StartTimer()
	for iter := 0; iter < b.N; iter++ {
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			b.Fatalf("error dialing: %v", err)
		}
		if _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"); err != nil {
			b.Fatal(err)
		}
		conn.Close()
		deadline.Reset(5 * time.Second)
		select {
		case <-closeSeen:
		case <-deadline.C:
			b.Fatal("timeout")
		}
	}
	b.StopTimer()
}
5421
5422
5423 func TestConcurrentServerServe(t *testing.T) {
5424 setParallel(t)
5425 for i := 0; i < 100; i++ {
5426 ln1 := &oneConnListener{conn: nil}
5427 ln2 := &oneConnListener{conn: nil}
5428 srv := Server{}
5429 go func() { srv.Serve(ln1) }()
5430 go func() { srv.Serve(ln2) }()
5431 }
5432 }
5433
5434 func TestServerIdleTimeout(t *testing.T) {
5435 if testing.Short() {
5436 t.Skip("skipping in short mode")
5437 }
5438 setParallel(t)
5439 defer afterTest(t)
5440 ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
5441 io.Copy(io.Discard, r.Body)
5442 io.WriteString(w, r.RemoteAddr)
5443 }))
5444 ts.Config.ReadHeaderTimeout = 1 * time.Second
5445 ts.Config.IdleTimeout = 2 * time.Second
5446 ts.Start()
5447 defer ts.Close()
5448 c := ts.Client()
5449
5450 get := func() string {
5451 res, err := c.Get(ts.URL)
5452 if err != nil {
5453 t.Fatal(err)
5454 }
5455 defer res.Body.Close()
5456 slurp, err := io.ReadAll(res.Body)
5457 if err != nil {
5458 t.Fatal(err)
5459 }
5460 return string(slurp)
5461 }
5462
5463 a1, a2 := get(), get()
5464 if a1 != a2 {
5465 t.Fatalf("did requests on different connections")
5466 }
5467 time.Sleep(3 * time.Second)
5468 a3 := get()
5469 if a2 == a3 {
5470 t.Fatal("request three unexpectedly on same connection")
5471 }
5472
5473
5474 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
5475 if err != nil {
5476 t.Fatal(err)
5477 }
5478 defer conn.Close()
5479 conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
5480 time.Sleep(2 * time.Second)
5481 if _, err := io.CopyN(io.Discard, conn, 1); err == nil {
5482 t.Fatal("copy byte succeeded; want err")
5483 }
5484 }
5485
// get performs a GET of url with c and returns the entire response body as
// a string. A request or read error fails the test immediately via t.Fatal.
func get(t *testing.T, c *Client, url string) string {
	resp, err := c.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	var body strings.Builder
	if _, err := io.Copy(&body, resp.Body); err != nil {
		t.Fatal(err)
	}
	return body.String()
}
5498
5499
5500
5501 func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
5502 setParallel(t)
5503 defer afterTest(t)
5504 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
5505 io.WriteString(w, r.RemoteAddr)
5506 }))
5507 defer ts.Close()
5508
5509 c := ts.Client()
5510 tr := c.Transport.(*Transport)
5511
5512 get := func() string { return get(t, c, ts.URL) }
5513
5514 a1, a2 := get(), get()
5515 if a1 != a2 {
5516 t.Fatal("expected first two requests on same connection")
5517 }
5518 addr := strings.TrimPrefix(ts.URL, "http://")
5519
5520
5521
5522
5523
5524 n := tr.IdleConnCountForTesting("http", addr)
5525 if n != 1 {
5526 t.Fatalf("idle count for %q after 2 gets = %d, want 1", addr, n)
5527 }
5528
5529
5530 ts.Config.SetKeepAlivesEnabled(false)
5531
5532 var idle1 int
5533 if !waitCondition(2*time.Second, 10*time.Millisecond, func() bool {
5534 idle1 = tr.IdleConnCountForTesting("http", addr)
5535 return idle1 == 0
5536 }) {
5537 t.Fatalf("idle count after SetKeepAlivesEnabled called = %v; want 0", idle1)
5538 }
5539
5540 a3 := get()
5541 if a3 == a2 {
5542 t.Fatal("expected third request on new connection")
5543 }
5544 }
5545
// TestServerShutdown_h1 runs the shared Shutdown test with h1Mode.
func TestServerShutdown_h1(t *testing.T) {
	testServerShutdown(t, h1Mode)
}
// TestServerShutdown_h2 runs the shared Shutdown test with h2Mode.
func TestServerShutdown_h2(t *testing.T) {
	testServerShutdown(t, h2Mode)
}
5552
5553 func testServerShutdown(t *testing.T, h2 bool) {
5554 setParallel(t)
5555 defer afterTest(t)
5556 var doShutdown func()
5557 var doStateCount func()
5558 var shutdownRes = make(chan error, 1)
5559 var statesRes = make(chan map[ConnState]int, 1)
5560 var gotOnShutdown = make(chan struct{}, 1)
5561 handler := HandlerFunc(func(w ResponseWriter, r *Request) {
5562 doStateCount()
5563 go doShutdown()
5564
5565
5566
5567
5568 time.Sleep(20 * time.Millisecond)
5569 io.WriteString(w, r.RemoteAddr)
5570 })
5571 cst := newClientServerTest(t, h2, handler, func(srv *httptest.Server) {
5572 srv.Config.RegisterOnShutdown(func() { gotOnShutdown <- struct{}{} })
5573 })
5574 defer cst.close()
5575
5576 doShutdown = func() {
5577 shutdownRes <- cst.ts.Config.Shutdown(context.Background())
5578 }
5579 doStateCount = func() {
5580 statesRes <- cst.ts.Config.ExportAllConnsByState()
5581 }
5582 get(t, cst.c, cst.ts.URL)
5583
5584 if err := <-shutdownRes; err != nil {
5585 t.Fatalf("Shutdown: %v", err)
5586 }
5587 select {
5588 case <-gotOnShutdown:
5589 case <-time.After(5 * time.Second):
5590 t.Errorf("onShutdown callback not called, RegisterOnShutdown broken?")
5591 }
5592
5593 if states := <-statesRes; states[StateActive] != 1 {
5594 t.Errorf("connection in wrong state, %v", states)
5595 }
5596
5597 res, err := cst.c.Get(cst.ts.URL)
5598 if err == nil {
5599 res.Body.Close()
5600 t.Fatal("second request should fail. server should be shut down")
5601 }
5602 }
5603
5604 func TestServerShutdownStateNew(t *testing.T) {
5605 if testing.Short() {
5606 t.Skip("test takes 5-6 seconds; skipping in short mode")
5607 }
5608 setParallel(t)
5609 defer afterTest(t)
5610
5611 ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
5612
5613 }))
5614 var connAccepted sync.WaitGroup
5615 ts.Config.ConnState = func(conn net.Conn, state ConnState) {
5616 if state == StateNew {
5617 connAccepted.Done()
5618 }
5619 }
5620 ts.Start()
5621 defer ts.Close()
5622
5623
5624 connAccepted.Add(1)
5625 c, err := net.Dial("tcp", ts.Listener.Addr().String())
5626 if err != nil {
5627 t.Fatal(err)
5628 }
5629 defer c.Close()
5630
5631
5632
5633
5634
5635 connAccepted.Wait()
5636
5637 shutdownRes := make(chan error, 1)
5638 go func() {
5639 shutdownRes <- ts.Config.Shutdown(context.Background())
5640 }()
5641 readRes := make(chan error, 1)
5642 go func() {
5643 _, err := c.Read([]byte{0})
5644 readRes <- err
5645 }()
5646
5647 const expectTimeout = 5 * time.Second
5648 t0 := time.Now()
5649 select {
5650 case got := <-shutdownRes:
5651 d := time.Since(t0)
5652 if got != nil {
5653 t.Fatalf("shutdown error after %v: %v", d, err)
5654 }
5655 if d < expectTimeout/2 {
5656 t.Errorf("shutdown too soon after %v", d)
5657 }
5658 case <-time.After(expectTimeout * 3 / 2):
5659 t.Fatalf("timeout waiting for shutdown")
5660 }
5661
5662
5663
5664 select {
5665 case err := <-readRes:
5666 if err == nil {
5667 t.Error("expected error from Read")
5668 }
5669 case <-time.After(2 * time.Second):
5670 t.Errorf("timeout waiting for Read to unblock")
5671 }
5672 }
5673
5674
// TestServerCloseDeadlock calls Close twice on a zero-value Server. It
// passes as long as both calls return.
func TestServerCloseDeadlock(t *testing.T) {
	var srv Server
	for attempt := 0; attempt < 2; attempt++ {
		srv.Close()
	}
}
5680
5681
5682
// TestServerKeepAlivesEnabled_h1 runs the keep-alives-disabled test with
// h1Mode.
func TestServerKeepAlivesEnabled_h1(t *testing.T) { testServerKeepAlivesEnabled(t, h1Mode) }
// TestServerKeepAlivesEnabled_h2 runs the keep-alives-disabled test with
// h2Mode.
func TestServerKeepAlivesEnabled_h2(t *testing.T) { testServerKeepAlivesEnabled(t, h2Mode) }
5685 func testServerKeepAlivesEnabled(t *testing.T, h2 bool) {
5686 if h2 {
5687 restore := ExportSetH2GoawayTimeout(10 * time.Millisecond)
5688 defer restore()
5689 }
5690
5691 defer afterTest(t)
5692 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
5693 fmt.Fprintf(w, "%v", r.RemoteAddr)
5694 }))
5695 defer cst.close()
5696 srv := cst.ts.Config
5697 srv.SetKeepAlivesEnabled(false)
5698 a := cst.getURL(cst.ts.URL)
5699 if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) {
5700 t.Fatalf("test server has active conns")
5701 }
5702 b := cst.getURL(cst.ts.URL)
5703 if a == b {
5704 t.Errorf("got same connection between first and second requests")
5705 }
5706 if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) {
5707 t.Fatalf("test server has active conns")
5708 }
5709 }
5710
5711
5712
5713
5714 func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) {
5715 setParallel(t)
5716 defer afterTest(t)
5717 runTimeSensitiveTest(t, []time.Duration{
5718 10 * time.Millisecond,
5719 50 * time.Millisecond,
5720 250 * time.Millisecond,
5721 time.Second,
5722 2 * time.Second,
5723 }, func(t *testing.T, timeout time.Duration) error {
5724 ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
5725 select {
5726 case <-time.After(2 * timeout):
5727 fmt.Fprint(w, "ok")
5728 case <-r.Context().Done():
5729 fmt.Fprint(w, r.Context().Err())
5730 }
5731 }))
5732 ts.Config.ReadTimeout = timeout
5733 ts.Start()
5734 defer ts.Close()
5735
5736 c := ts.Client()
5737
5738 res, err := c.Get(ts.URL)
5739 if err != nil {
5740 return fmt.Errorf("Get: %v", err)
5741 }
5742 slurp, err := io.ReadAll(res.Body)
5743 res.Body.Close()
5744 if err != nil {
5745 return fmt.Errorf("Body ReadAll: %v", err)
5746 }
5747 if string(slurp) != "ok" {
5748 return fmt.Errorf("got: %q, want ok", slurp)
5749 }
5750 return nil
5751 })
5752 }
5753
5754
5755
// runTimeSensitiveTest calls test once per entry of durations, in order,
// and stops at the first call that returns nil. Errors from earlier
// durations are discarded; only an error from the final duration fails t,
// via t.Fatalf. With an empty durations slice test is never called.
func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *testing.T, d time.Duration) error) {
	last := len(durations) - 1
	for idx, dur := range durations {
		err := test(t, dur)
		switch {
		case err == nil:
			return
		case idx == last:
			t.Fatalf("failed with duration %v: %v", dur, err)
		}
	}
}
5767
5768
5769
5770 func TestServerDuplicateBackgroundRead(t *testing.T) {
5771 if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" {
5772 testenv.SkipFlaky(t, 24826)
5773 }
5774
5775 setParallel(t)
5776 defer afterTest(t)
5777
5778 goroutines := 5
5779 requests := 2000
5780 if testing.Short() {
5781 goroutines = 3
5782 requests = 100
5783 }
5784
5785 hts := httptest.NewServer(HandlerFunc(NotFound))
5786 defer hts.Close()
5787
5788 reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
5789
5790 var wg sync.WaitGroup
5791 for i := 0; i < goroutines; i++ {
5792 wg.Add(1)
5793 go func() {
5794 defer wg.Done()
5795 cn, err := net.Dial("tcp", hts.Listener.Addr().String())
5796 if err != nil {
5797 t.Error(err)
5798 return
5799 }
5800 defer cn.Close()
5801
5802 wg.Add(1)
5803 go func() {
5804 defer wg.Done()
5805 io.Copy(io.Discard, cn)
5806 }()
5807
5808 for j := 0; j < requests; j++ {
5809 if t.Failed() {
5810 return
5811 }
5812 _, err := cn.Write(reqBytes)
5813 if err != nil {
5814 t.Error(err)
5815 return
5816 }
5817 }
5818 }()
5819 }
5820 wg.Wait()
5821 }
5822
5823
5824
5825
5826
5827
5828 func TestServerHijackGetsBackgroundByte(t *testing.T) {
5829 if runtime.GOOS == "plan9" {
5830 t.Skip("skipping test; see https://golang.org/issue/18657")
5831 }
5832 setParallel(t)
5833 defer afterTest(t)
5834 done := make(chan struct{})
5835 inHandler := make(chan bool, 1)
5836 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
5837 defer close(done)
5838
5839
5840 inHandler <- true
5841
5842 conn, buf, err := w.(Hijacker).Hijack()
5843 if err != nil {
5844 t.Error(err)
5845 return
5846 }
5847 defer conn.Close()
5848
5849 peek, err := buf.Reader.Peek(3)
5850 if string(peek) != "foo" || err != nil {
5851 t.Errorf("Peek = %q, %v; want foo, nil", peek, err)
5852 }
5853
5854 select {
5855 case <-r.Context().Done():
5856 t.Error("context unexpectedly canceled")
5857 default:
5858 }
5859 }))
5860 defer ts.Close()
5861
5862 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
5863 if err != nil {
5864 t.Fatal(err)
5865 }
5866 defer cn.Close()
5867 if _, err := cn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
5868 t.Fatal(err)
5869 }
5870 <-inHandler
5871 if _, err := cn.Write([]byte("foo")); err != nil {
5872 t.Fatal(err)
5873 }
5874
5875 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
5876 t.Fatal(err)
5877 }
5878 select {
5879 case <-done:
5880 case <-time.After(2 * time.Second):
5881 t.Error("timeout")
5882 }
5883 }
5884
5885
5886
5887
5888 func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
5889 if runtime.GOOS == "plan9" {
5890 t.Skip("skipping test; see https://golang.org/issue/18657")
5891 }
5892 setParallel(t)
5893 defer afterTest(t)
5894 done := make(chan struct{})
5895 const size = 8 << 10
5896 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
5897 defer close(done)
5898
5899 conn, buf, err := w.(Hijacker).Hijack()
5900 if err != nil {
5901 t.Error(err)
5902 return
5903 }
5904 defer conn.Close()
5905 slurp, err := io.ReadAll(buf.Reader)
5906 if err != nil {
5907 t.Errorf("Copy: %v", err)
5908 }
5909 allX := true
5910 for _, v := range slurp {
5911 if v != 'x' {
5912 allX = false
5913 }
5914 }
5915 if len(slurp) != size {
5916 t.Errorf("read %d; want %d", len(slurp), size)
5917 } else if !allX {
5918 t.Errorf("read %q; want %d 'x'", slurp, size)
5919 }
5920 }))
5921 defer ts.Close()
5922
5923 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
5924 if err != nil {
5925 t.Fatal(err)
5926 }
5927 defer cn.Close()
5928 if _, err := fmt.Fprintf(cn, "GET / HTTP/1.1\r\nHost: e.com\r\n\r\n%s",
5929 strings.Repeat("x", size)); err != nil {
5930 t.Fatal(err)
5931 }
5932 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
5933 t.Fatal(err)
5934 }
5935
5936 select {
5937 case <-done:
5938 case <-time.After(2 * time.Second):
5939 t.Error("timeout")
5940 }
5941 }
5942
5943
5944 func TestServerValidatesMethod(t *testing.T) {
5945 tests := []struct {
5946 method string
5947 want int
5948 }{
5949 {"GET", 200},
5950 {"GE(T", 400},
5951 }
5952 for _, tt := range tests {
5953 conn := &testConn{closec: make(chan bool, 1)}
5954 io.WriteString(&conn.readBuf, tt.method+" / HTTP/1.1\r\nHost: foo.example\r\n\r\n")
5955
5956 ln := &oneConnListener{conn}
5957 go Serve(ln, serve(200))
5958 <-conn.closec
5959 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
5960 if err != nil {
5961 t.Errorf("For %s, ReadResponse: %v", tt.method, res)
5962 continue
5963 }
5964 if res.StatusCode != tt.want {
5965 t.Errorf("For %s, Status = %d; want %d", tt.method, res.StatusCode, tt.want)
5966 }
5967 }
5968 }
5969
5970
// eofListenerNotComparable is a net.Listener backed by a slice type, so its
// values cannot be compared with ==. It never produces a connection.
type eofListenerNotComparable []int

// Accept always fails with io.EOF and returns no connection.
func (eofListenerNotComparable) Accept() (net.Conn, error) {
	var none net.Conn
	return none, io.EOF
}

// Addr returns a nil address.
func (eofListenerNotComparable) Addr() net.Addr {
	var none net.Addr
	return none
}

// Close does nothing and reports success.
func (eofListenerNotComparable) Close() error {
	return nil
}
5976
5977
5978 func TestServerListenNotComparableListener(t *testing.T) {
5979 var s Server
5980 s.Serve(make(eofListenerNotComparable, 1))
5981 }
5982
5983
// countCloseListener wraps a net.Listener and counts calls to Close.
type countCloseListener struct {
	net.Listener
	closes int32 // number of Close calls; updated atomically
}

// Close records the call and forwards only the first one to the wrapped
// listener, if there is one. Every later call, and every call when no
// listener is wrapped, returns nil.
func (p *countCloseListener) Close() error {
	if atomic.AddInt32(&p.closes, 1) != 1 || p.Listener == nil {
		return nil
	}
	return p.Listener.Close()
}
5996
5997
5998 func TestServerCloseListenerOnce(t *testing.T) {
5999 setParallel(t)
6000 defer afterTest(t)
6001
6002 ln := newLocalListener(t)
6003 defer ln.Close()
6004
6005 cl := &countCloseListener{Listener: ln}
6006 server := &Server{}
6007 sdone := make(chan bool, 1)
6008
6009 go func() {
6010 server.Serve(cl)
6011 sdone <- true
6012 }()
6013 time.Sleep(10 * time.Millisecond)
6014 server.Shutdown(context.Background())
6015 ln.Close()
6016 <-sdone
6017
6018 nclose := atomic.LoadInt32(&cl.closes)
6019 if nclose != 1 {
6020 t.Errorf("Close calls = %v; want 1", nclose)
6021 }
6022 }
6023
6024
6025 func TestServerShutdownThenServe(t *testing.T) {
6026 var srv Server
6027 cl := &countCloseListener{Listener: nil}
6028 srv.Shutdown(context.Background())
6029 got := srv.Serve(cl)
6030 if got != ErrServerClosed {
6031 t.Errorf("Serve err = %v; want ErrServerClosed", got)
6032 }
6033 nclose := atomic.LoadInt32(&cl.closes)
6034 if nclose != 1 {
6035 t.Errorf("Close calls = %v; want 1", nclose)
6036 }
6037 }
6038
6039
// TestStripPortFromHost checks that ServeMux ignores the port of the request
// host when matching host-specific patterns: a request for
// example.com:9000 must be routed to the "example.com/" handler and not to
// the one registered as "example.com:9000/".
func TestStripPortFromHost(t *testing.T) {
	routes := []struct {
		pattern string
		body    string
	}{
		{"example.com/", "OK"},
		{"example.com:9000/", "uh-oh!"},
	}

	mux := NewServeMux()
	for _, route := range routes {
		body := route.body
		mux.HandleFunc(route.pattern, func(w ResponseWriter, r *Request) {
			fmt.Fprint(w, body)
		})
	}

	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, httptest.NewRequest("GET", "http://example.com:9000/", nil))

	if got := rec.Body.String(); got != "OK" {
		t.Errorf("Response gotten was %q", got)
	}
}
6060
6061 func TestServerContexts(t *testing.T) {
6062 setParallel(t)
6063 defer afterTest(t)
6064 type baseKey struct{}
6065 type connKey struct{}
6066 ch := make(chan context.Context, 1)
6067 ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
6068 ch <- r.Context()
6069 }))
6070 ts.Config.BaseContext = func(ln net.Listener) context.Context {
6071 if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
6072 t.Errorf("unexpected onceClose listener type %T", ln)
6073 }
6074 return context.WithValue(context.Background(), baseKey{}, "base")
6075 }
6076 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6077 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6078 t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
6079 }
6080 return context.WithValue(ctx, connKey{}, "conn")
6081 }
6082 ts.Start()
6083 defer ts.Close()
6084 res, err := ts.Client().Get(ts.URL)
6085 if err != nil {
6086 t.Fatal(err)
6087 }
6088 res.Body.Close()
6089 ctx := <-ch
6090 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6091 t.Errorf("base context key = %#v; want %q", got, want)
6092 }
6093 if got, want := ctx.Value(connKey{}), "conn"; got != want {
6094 t.Errorf("conn context key = %#v; want %q", got, want)
6095 }
6096 }
6097
6098 func TestServerContextsHTTP2(t *testing.T) {
6099 setParallel(t)
6100 defer afterTest(t)
6101 type baseKey struct{}
6102 type connKey struct{}
6103 ch := make(chan context.Context, 1)
6104 ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
6105 if r.ProtoMajor != 2 {
6106 t.Errorf("unexpected HTTP/1.x request")
6107 }
6108 ch <- r.Context()
6109 }))
6110 ts.Config.BaseContext = func(ln net.Listener) context.Context {
6111 if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
6112 t.Errorf("unexpected onceClose listener type %T", ln)
6113 }
6114 return context.WithValue(context.Background(), baseKey{}, "base")
6115 }
6116 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6117 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6118 t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
6119 }
6120 return context.WithValue(ctx, connKey{}, "conn")
6121 }
6122 ts.TLS = &tls.Config{
6123 NextProtos: []string{"h2", "http/1.1"},
6124 }
6125 ts.StartTLS()
6126 defer ts.Close()
6127 ts.Client().Transport.(*Transport).ForceAttemptHTTP2 = true
6128 res, err := ts.Client().Get(ts.URL)
6129 if err != nil {
6130 t.Fatal(err)
6131 }
6132 res.Body.Close()
6133 ctx := <-ch
6134 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6135 t.Errorf("base context key = %#v; want %q", got, want)
6136 }
6137 if got, want := ctx.Value(connKey{}), "conn"; got != want {
6138 t.Errorf("conn context key = %#v; want %q", got, want)
6139 }
6140 }
6141
6142
6143 func TestConnContextNotModifyingAllContexts(t *testing.T) {
6144 setParallel(t)
6145 defer afterTest(t)
6146 type connKey struct{}
6147 ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
6148 rw.Header().Set("Connection", "close")
6149 }))
6150 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6151 if got := ctx.Value(connKey{}); got != nil {
6152 t.Errorf("in ConnContext, unexpected context key = %#v", got)
6153 }
6154 return context.WithValue(ctx, connKey{}, "conn")
6155 }
6156 ts.Start()
6157 defer ts.Close()
6158
6159 var res *Response
6160 var err error
6161
6162 res, err = ts.Client().Get(ts.URL)
6163 if err != nil {
6164 t.Fatal(err)
6165 }
6166 res.Body.Close()
6167
6168 res, err = ts.Client().Get(ts.URL)
6169 if err != nil {
6170 t.Fatal(err)
6171 }
6172 res.Body.Close()
6173 }
6174
6175
6176
6177 func TestUnsupportedTransferEncodingsReturn501(t *testing.T) {
6178 cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
6179 w.Write([]byte("Hello, World!"))
6180 }))
6181 defer cst.Close()
6182
6183 serverURL, err := url.Parse(cst.URL)
6184 if err != nil {
6185 t.Fatalf("Failed to parse server URL: %v", err)
6186 }
6187
6188 unsupportedTEs := []string{
6189 "fugazi",
6190 "foo-bar",
6191 "unknown",
6192 "\rchunked",
6193 }
6194
6195 for _, badTE := range unsupportedTEs {
6196 http1ReqBody := fmt.Sprintf(""+
6197 "POST / HTTP/1.1\r\nConnection: close\r\n"+
6198 "Host: localhost\r\nTransfer-Encoding: %s\r\n\r\n", badTE)
6199
6200 gotBody, err := fetchWireResponse(serverURL.Host, []byte(http1ReqBody))
6201 if err != nil {
6202 t.Errorf("%q. unexpected error: %v", badTE, err)
6203 continue
6204 }
6205
6206 wantBody := fmt.Sprintf("" +
6207 "HTTP/1.1 501 Not Implemented\r\nContent-Type: text/plain; charset=utf-8\r\n" +
6208 "Connection: close\r\n\r\nUnsupported transfer encoding")
6209
6210 if string(gotBody) != wantBody {
6211 t.Errorf("%q. body\ngot\n%q\nwant\n%q", badTE, gotBody, wantBody)
6212 }
6213 }
6214 }
6215
// TestContentEncodingNoSniffing_h1 runs testContentEncodingNoSniffing with
// h1Mode.
func TestContentEncodingNoSniffing_h1(t *testing.T) {
	testContentEncodingNoSniffing(t, h1Mode)
}
6219
// TestContentEncodingNoSniffing_h2 runs testContentEncodingNoSniffing with
// h2Mode.
func TestContentEncodingNoSniffing_h2(t *testing.T) {
	testContentEncodingNoSniffing(t, h2Mode)
}
6223
6224
6225 func testContentEncodingNoSniffing(t *testing.T, h2 bool) {
6226 setParallel(t)
6227 defer afterTest(t)
6228
6229 type setting struct {
6230 name string
6231 body []byte
6232
6233
6234
6235
6236 contentEncoding interface{}
6237 wantContentType string
6238 }
6239
6240 settings := []*setting{
6241 {
6242 name: "gzip content-encoding, gzipped",
6243 contentEncoding: "application/gzip",
6244 wantContentType: "",
6245 body: func() []byte {
6246 buf := new(bytes.Buffer)
6247 gzw := gzip.NewWriter(buf)
6248 gzw.Write([]byte("doctype html><p>Hello</p>"))
6249 gzw.Close()
6250 return buf.Bytes()
6251 }(),
6252 },
6253 {
6254 name: "zlib content-encoding, zlibbed",
6255 contentEncoding: "application/zlib",
6256 wantContentType: "",
6257 body: func() []byte {
6258 buf := new(bytes.Buffer)
6259 zw := zlib.NewWriter(buf)
6260 zw.Write([]byte("doctype html><p>Hello</p>"))
6261 zw.Close()
6262 return buf.Bytes()
6263 }(),
6264 },
6265 {
6266 name: "no content-encoding",
6267 wantContentType: "application/x-gzip",
6268 body: func() []byte {
6269 buf := new(bytes.Buffer)
6270 gzw := gzip.NewWriter(buf)
6271 gzw.Write([]byte("doctype html><p>Hello</p>"))
6272 gzw.Close()
6273 return buf.Bytes()
6274 }(),
6275 },
6276 {
6277 name: "phony content-encoding",
6278 contentEncoding: "foo/bar",
6279 body: []byte("doctype html><p>Hello</p>"),
6280 },
6281 {
6282 name: "empty but set content-encoding",
6283 contentEncoding: "",
6284 wantContentType: "audio/mpeg",
6285 body: []byte("ID3"),
6286 },
6287 }
6288
6289 for _, tt := range settings {
6290 t.Run(tt.name, func(t *testing.T) {
6291 cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) {
6292 if tt.contentEncoding != nil {
6293 rw.Header().Set("Content-Encoding", tt.contentEncoding.(string))
6294 }
6295 rw.Write(tt.body)
6296 }))
6297 defer cst.close()
6298
6299 res, err := cst.c.Get(cst.ts.URL)
6300 if err != nil {
6301 t.Fatalf("Failed to fetch URL: %v", err)
6302 }
6303 defer res.Body.Close()
6304
6305 if g, w := res.Header.Get("Content-Encoding"), tt.contentEncoding; g != w {
6306 if w != nil {
6307 t.Errorf("Content-Encoding mismatch\n\tgot: %q\n\twant: %q", g, w)
6308 } else if g != "" {
6309 t.Errorf("Unexpected Content-Encoding %q", g)
6310 }
6311 }
6312
6313 if g, w := res.Header.Get("Content-Type"), tt.wantContentType; g != w {
6314 t.Errorf("Content-Type mismatch\n\tgot: %q\n\twant: %q", g, w)
6315 }
6316 })
6317 }
6318 }
6319
6320
6321
// TestTimeoutHandlerSuperfluousLogs wraps a handler that calls WriteHeader
// four times in a TimeoutHandler and checks two things, both when the handler
// returns before the timeout and when it outlives it: the response the client
// receives, and that the server log contains exactly three "superfluous
// response.WriteHeader call" entries pointing at the right source lines.
//
// The log check relies on source-line arithmetic relative to runtime.Caller
// and on how the nested closures are named, so statements inside the handler
// must not be moved or separated by new lines.
func TestTimeoutHandlerSuperfluousLogs(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	setParallel(t)
	defer afterTest(t)

	// Capture this file's base name and this function's name; both appear in
	// the log lines matched below.
	pc, curFile, _, _ := runtime.Caller(0)
	curFileBaseName := filepath.Base(curFile)
	testFuncName := runtime.FuncForPC(pc).Name()

	timeoutMsg := "timed out here!"

	tests := []struct {
		name        string
		mustTimeout bool
		wantResp    string
	}{
		{
			name:     "return before timeout",
			wantResp: "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n",
		},
		{
			name:        "return after timeout",
			mustTimeout: true,
			wantResp: fmt.Sprintf("HTTP/1.1 503 Service Unavailable\r\nContent-Length: %d\r\n\r\n%s",
				len(timeoutMsg), timeoutMsg),
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			// exitHandler releases the handler: either via the value sent
			// below (no-timeout case) or via the deferred close.
			exitHandler := make(chan bool, 1)
			defer close(exitHandler)
			// lastLine carries the line number of the runtime.Caller call
			// inside the handler back to the test.
			lastLine := make(chan int, 1)

			// Of the four WriteHeader calls, the last three are the ones
			// expected in the log. They sit on the three lines directly
			// above the runtime.Caller call; keep them adjacent.
			sh := HandlerFunc(func(w ResponseWriter, r *Request) {
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				_, _, line, _ := runtime.Caller(0)
				lastLine <- line
				<-exitHandler
			})

			// In the no-timeout case, pre-fill the channel so the handler
			// returns without blocking.
			if !tt.mustTimeout {
				exitHandler <- true
			}

			logBuf := new(bytes.Buffer)
			srvLog := log.New(logBuf, "", 0)

			// Short timeout when the handler is meant to time out ...
			dur := 20 * time.Millisecond
			if !tt.mustTimeout {
				// ... and a long one otherwise, so it does not fire.
				dur = 10 * time.Second
			}
			th := TimeoutHandler(sh, dur, timeoutMsg)
			cst := newClientServerTest(t, h1Mode, th, optWithServerLog(srvLog))
			defer cst.close()

			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			// Drop the Date and Content-Type headers before dumping; they
			// are not part of wantResp.
			res.Header.Del("Date")
			res.Header.Del("Content-Type")

			// Compare the dumped response against the expected wire form.
			blob, _ := httputil.DumpResponse(res, true)
			if g, w := string(blob), tt.wantResp; g != w {
				t.Errorf("Response mismatch\nGot\n%q\n\nWant\n%q", g, w)
			}

			// Exactly three log lines are expected, one per superfluous
			// WriteHeader call.
			logEntries := strings.Split(strings.TrimSpace(logBuf.String()), "\n")
			if g, w := len(logEntries), 3; g != w {
				blob, _ := json.MarshalIndent(logEntries, "", " ")
				t.Fatalf("Server logs count mismatch\ngot %d, want %d\n\nGot\n%s\n", g, w, blob)
			}

			// The runtime.Caller line minus three is the line of the first
			// superfluous call.
			lastSpuriousLine := <-lastLine
			firstSpuriousLine := lastSpuriousLine - 3

			// Each entry must name this test's nested closure and the
			// expected file:line, in order.
			for i, logEntry := range logEntries {
				wantLine := firstSpuriousLine + i
				pat := fmt.Sprintf("^http: superfluous response.WriteHeader call from %s.func\\d+.\\d+ \\(%s:%d\\)$",
					testFuncName, curFileBaseName, wantLine)
				re := regexp.MustCompile(pat)
				if !re.MatchString(logEntry) {
					t.Errorf("Log entry mismatch\n\t%s\ndoes not match\n\t%s", logEntry, pat)
				}
			}
		})
	}
}
6426
6427
6428
6429
// fetchWireResponse opens a TCP connection to host, writes http1ReqBody to it
// verbatim, and returns everything read from the connection until EOF. Any
// dial, write or read error is returned to the caller.
func fetchWireResponse(host string, http1ReqBody []byte) ([]byte, error) {
	c, dialErr := net.Dial("tcp", host)
	if dialErr != nil {
		return nil, dialErr
	}
	defer c.Close()

	_, writeErr := c.Write(http1ReqBody)
	if writeErr != nil {
		return nil, writeErr
	}

	// Reads until the peer closes the connection.
	return io.ReadAll(c)
}
6442
6443 func BenchmarkResponseStatusLine(b *testing.B) {
6444 b.ReportAllocs()
6445 b.RunParallel(func(pb *testing.PB) {
6446 bw := bufio.NewWriter(io.Discard)
6447 var buf3 [3]byte
6448 for pb.Next() {
6449 Export_writeStatusLine(bw, true, 200, buf3[:])
6450 }
6451 })
6452 }
6453 func TestDisableKeepAliveUpgrade(t *testing.T) {
6454 if testing.Short() {
6455 t.Skip("skipping in short mode")
6456 }
6457
6458 setParallel(t)
6459 defer afterTest(t)
6460
6461 s := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
6462 w.Header().Set("Connection", "Upgrade")
6463 w.Header().Set("Upgrade", "someProto")
6464 w.WriteHeader(StatusSwitchingProtocols)
6465 c, buf, err := w.(Hijacker).Hijack()
6466 if err != nil {
6467 return
6468 }
6469 defer c.Close()
6470
6471
6472
6473 io.Copy(c, buf)
6474 }))
6475 s.Config.SetKeepAlivesEnabled(false)
6476 s.Start()
6477 defer s.Close()
6478
6479 cl := s.Client()
6480 cl.Transport.(*Transport).DisableKeepAlives = true
6481
6482 resp, err := cl.Get(s.URL)
6483 if err != nil {
6484 t.Fatalf("failed to perform request: %v", err)
6485 }
6486 defer resp.Body.Close()
6487
6488 if resp.StatusCode != StatusSwitchingProtocols {
6489 t.Fatalf("unexpected status code: %v", resp.StatusCode)
6490 }
6491
6492 rwc, ok := resp.Body.(io.ReadWriteCloser)
6493 if !ok {
6494 t.Fatalf("Response.Body is not a io.ReadWriteCloser: %T", resp.Body)
6495 }
6496
6497 _, err = rwc.Write([]byte("hello"))
6498 if err != nil {
6499 t.Fatalf("failed to write to body: %v", err)
6500 }
6501
6502 b := make([]byte, 5)
6503 _, err = io.ReadFull(rwc, b)
6504 if err != nil {
6505 t.Fatalf("failed to read from body: %v", err)
6506 }
6507
6508 if string(b) != "hello" {
6509 t.Fatalf("unexpected value read from body:\ngot: %q\nwant: %q", b, "hello")
6510 }
6511 }
6512
6513 func TestMuxRedirectRelative(t *testing.T) {
6514 setParallel(t)
6515 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET http://example.com HTTP/1.1\r\nHost: test\r\n\r\n")))
6516 if err != nil {
6517 t.Errorf("%s", err)
6518 }
6519 mux := NewServeMux()
6520 resp := httptest.NewRecorder()
6521 mux.ServeHTTP(resp, req)
6522 if got, want := resp.Header().Get("Location"), "/"; got != want {
6523 t.Errorf("Location header expected %q; got %q", want, got)
6524 }
6525 if got, want := resp.Code, StatusMovedPermanently; got != want {
6526 t.Errorf("Expected response code %d; got %d", want, got)
6527 }
6528 }
6529
6530
6531 func TestQuerySemicolon(t *testing.T) {
6532 t.Cleanup(func() { afterTest(t) })
6533
6534 tests := []struct {
6535 query string
6536 xNoSemicolons string
6537 xWithSemicolons string
6538 warning bool
6539 }{
6540 {"?a=1;x=bad&x=good", "good", "bad", true},
6541 {"?a=1;b=bad&x=good", "good", "good", true},
6542 {"?a=1%3Bx=bad&x=good%3B", "good;", "good;", false},
6543 {"?a=1;x=good;x=bad", "", "good", true},
6544 }
6545
6546 for _, tt := range tests {
6547 t.Run(tt.query+"/allow=false", func(t *testing.T) {
6548 allowSemicolons := false
6549 testQuerySemicolon(t, tt.query, tt.xNoSemicolons, allowSemicolons, tt.warning)
6550 })
6551 t.Run(tt.query+"/allow=true", func(t *testing.T) {
6552 allowSemicolons, expectWarning := true, false
6553 testQuerySemicolon(t, tt.query, tt.xWithSemicolons, allowSemicolons, expectWarning)
6554 })
6555 }
6556 }
6557
6558 func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolons, expectWarning bool) {
6559 setParallel(t)
6560
6561 writeBackX := func(w ResponseWriter, r *Request) {
6562 x := r.URL.Query().Get("x")
6563 if expectWarning {
6564 if err := r.ParseForm(); err == nil || !strings.Contains(err.Error(), "semicolon") {
6565 t.Errorf("expected error mentioning semicolons from ParseForm, got %v", err)
6566 }
6567 } else {
6568 if err := r.ParseForm(); err != nil {
6569 t.Errorf("expected no error from ParseForm, got %v", err)
6570 }
6571 }
6572 if got := r.FormValue("x"); x != got {
6573 t.Errorf("got %q from FormValue, want %q", got, x)
6574 }
6575 fmt.Fprintf(w, "%s", x)
6576 }
6577
6578 h := Handler(HandlerFunc(writeBackX))
6579 if allowSemicolons {
6580 h = AllowQuerySemicolons(h)
6581 }
6582
6583 ts := httptest.NewUnstartedServer(h)
6584 logBuf := &bytes.Buffer{}
6585 ts.Config.ErrorLog = log.New(logBuf, "", 0)
6586 ts.Start()
6587 defer ts.Close()
6588
6589 req, _ := NewRequest("GET", ts.URL+query, nil)
6590 res, err := ts.Client().Do(req)
6591 if err != nil {
6592 t.Fatal(err)
6593 }
6594 slurp, _ := io.ReadAll(res.Body)
6595 res.Body.Close()
6596 if got, want := res.StatusCode, 200; got != want {
6597 t.Errorf("Status = %d; want = %d", got, want)
6598 }
6599 if got, want := string(slurp), wantX; got != want {
6600 t.Errorf("Body = %q; want = %q", got, want)
6601 }
6602
6603 if expectWarning {
6604 if !strings.Contains(logBuf.String(), "semicolon") {
6605 t.Errorf("got %q from ErrorLog, expected a mention of semicolons", logBuf.String())
6606 }
6607 } else {
6608 if strings.Contains(logBuf.String(), "semicolon") {
6609 t.Errorf("got %q from ErrorLog, expected no mention of semicolons", logBuf.String())
6610 }
6611 }
6612 }
6613
View as plain text