...

Source file src/net/http/transport_internal_test.go

Documentation: net/http

		 1  // Copyright 2016 The Go Authors. All rights reserved.
		 2  // Use of this source code is governed by a BSD-style
		 3  // license that can be found in the LICENSE file.
		 4  
		 5  // White-box tests for transport.go (in package http instead of http_test).
		 6  
		 7  package http
		 8  
		 9  import (
		10  	"bytes"
		11  	"crypto/tls"
		12  	"errors"
		13  	"io"
		14  	"net"
		15  	"net/http/internal/testcert"
		16  	"strings"
		17  	"testing"
		18  )
		19  
		20  // Issue 15446: incorrect wrapping of errors when server closes an idle connection.
		21  func TestTransportPersistConnReadLoopEOF(t *testing.T) {
		22  	ln := newLocalListener(t)
		23  	defer ln.Close()
		24  
		25  	connc := make(chan net.Conn, 1)
		26  	go func() {
		27  		defer close(connc)
		28  		c, err := ln.Accept()
		29  		if err != nil {
		30  			t.Error(err)
		31  			return
		32  		}
		33  		connc <- c
		34  	}()
		35  
		36  	tr := new(Transport)
		37  	req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil)
		38  	req = req.WithT(t)
		39  	treq := &transportRequest{Request: req}
		40  	cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()}
		41  	pc, err := tr.getConn(treq, cm)
		42  	if err != nil {
		43  		t.Fatal(err)
		44  	}
		45  	defer pc.close(errors.New("test over"))
		46  
		47  	conn := <-connc
		48  	if conn == nil {
		49  		// Already called t.Error in the accept goroutine.
		50  		return
		51  	}
		52  	conn.Close() // simulate the server hanging up on the client
		53  
		54  	_, err = pc.roundTrip(treq)
		55  	if !isTransportReadFromServerError(err) && err != errServerClosedIdle {
		56  		t.Errorf("roundTrip = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err)
		57  	}
		58  
		59  	<-pc.closech
		60  	err = pc.closed
		61  	if !isTransportReadFromServerError(err) && err != errServerClosedIdle {
		62  		t.Errorf("pc.closed = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err)
		63  	}
		64  }
		65  
		66  func isTransportReadFromServerError(err error) bool {
		67  	_, ok := err.(transportReadFromServerError)
		68  	return ok
		69  }
		70  
		71  func newLocalListener(t *testing.T) net.Listener {
		72  	ln, err := net.Listen("tcp", "127.0.0.1:0")
		73  	if err != nil {
		74  		ln, err = net.Listen("tcp6", "[::1]:0")
		75  	}
		76  	if err != nil {
		77  		t.Fatal(err)
		78  	}
		79  	return ln
		80  }
		81  
		82  func dummyRequest(method string) *Request {
		83  	req, err := NewRequest(method, "http://fake.tld/", nil)
		84  	if err != nil {
		85  		panic(err)
		86  	}
		87  	return req
		88  }
		89  func dummyRequestWithBody(method string) *Request {
		90  	req, err := NewRequest(method, "http://fake.tld/", strings.NewReader("foo"))
		91  	if err != nil {
		92  		panic(err)
		93  	}
		94  	return req
		95  }
		96  
		97  func dummyRequestWithBodyNoGetBody(method string) *Request {
		98  	req := dummyRequestWithBody(method)
		99  	req.GetBody = nil
	 100  	return req
	 101  }
	 102  
	 103  // issue22091Error acts like a golang.org/x/net/http2.ErrNoCachedConn.
	 104  type issue22091Error struct{}
	 105  
	 106  func (issue22091Error) IsHTTP2NoCachedConnError() {}
	 107  func (issue22091Error) Error() string						 { return "issue22091Error" }
	 108  
	 109  func TestTransportShouldRetryRequest(t *testing.T) {
	 110  	tests := []struct {
	 111  		pc	*persistConn
	 112  		req *Request
	 113  
	 114  		err	error
	 115  		want bool
	 116  	}{
	 117  		0: {
	 118  			pc:	 &persistConn{reused: false},
	 119  			req:	dummyRequest("POST"),
	 120  			err:	nothingWrittenError{},
	 121  			want: false,
	 122  		},
	 123  		1: {
	 124  			pc:	 &persistConn{reused: true},
	 125  			req:	dummyRequest("POST"),
	 126  			err:	nothingWrittenError{},
	 127  			want: true,
	 128  		},
	 129  		2: {
	 130  			pc:	 &persistConn{reused: true},
	 131  			req:	dummyRequest("POST"),
	 132  			err:	http2ErrNoCachedConn,
	 133  			want: true,
	 134  		},
	 135  		3: {
	 136  			pc:	 nil,
	 137  			req:	nil,
	 138  			err:	issue22091Error{}, // like an external http2ErrNoCachedConn
	 139  			want: true,
	 140  		},
	 141  		4: {
	 142  			pc:	 &persistConn{reused: true},
	 143  			req:	dummyRequest("POST"),
	 144  			err:	errMissingHost,
	 145  			want: false,
	 146  		},
	 147  		5: {
	 148  			pc:	 &persistConn{reused: true},
	 149  			req:	dummyRequest("POST"),
	 150  			err:	transportReadFromServerError{},
	 151  			want: false,
	 152  		},
	 153  		6: {
	 154  			pc:	 &persistConn{reused: true},
	 155  			req:	dummyRequest("GET"),
	 156  			err:	transportReadFromServerError{},
	 157  			want: true,
	 158  		},
	 159  		7: {
	 160  			pc:	 &persistConn{reused: true},
	 161  			req:	dummyRequest("GET"),
	 162  			err:	errServerClosedIdle,
	 163  			want: true,
	 164  		},
	 165  		8: {
	 166  			pc:	 &persistConn{reused: true},
	 167  			req:	dummyRequestWithBody("POST"),
	 168  			err:	nothingWrittenError{},
	 169  			want: true,
	 170  		},
	 171  		9: {
	 172  			pc:	 &persistConn{reused: true},
	 173  			req:	dummyRequestWithBodyNoGetBody("POST"),
	 174  			err:	nothingWrittenError{},
	 175  			want: false,
	 176  		},
	 177  	}
	 178  	for i, tt := range tests {
	 179  		got := tt.pc.shouldRetryRequest(tt.req, tt.err)
	 180  		if got != tt.want {
	 181  			t.Errorf("%d. shouldRetryRequest = %v; want %v", i, got, tt.want)
	 182  		}
	 183  	}
	 184  }
	 185  
	 186  type roundTripFunc func(r *Request) (*Response, error)
	 187  
	 188  func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) {
	 189  	return f(r)
	 190  }
	 191  
	 192  // Issue 25009
	 193  func TestTransportBodyAltRewind(t *testing.T) {
	 194  	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	 195  	if err != nil {
	 196  		t.Fatal(err)
	 197  	}
	 198  	ln := newLocalListener(t)
	 199  	defer ln.Close()
	 200  
	 201  	go func() {
	 202  		tln := tls.NewListener(ln, &tls.Config{
	 203  			NextProtos:	 []string{"foo"},
	 204  			Certificates: []tls.Certificate{cert},
	 205  		})
	 206  		for i := 0; i < 2; i++ {
	 207  			sc, err := tln.Accept()
	 208  			if err != nil {
	 209  				t.Error(err)
	 210  				return
	 211  			}
	 212  			if err := sc.(*tls.Conn).Handshake(); err != nil {
	 213  				t.Error(err)
	 214  				return
	 215  			}
	 216  			sc.Close()
	 217  		}
	 218  	}()
	 219  
	 220  	addr := ln.Addr().String()
	 221  	req, _ := NewRequest("POST", "https://example.org/", bytes.NewBufferString("request"))
	 222  	roundTripped := false
	 223  	tr := &Transport{
	 224  		DisableKeepAlives: true,
	 225  		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
	 226  			"foo": func(authority string, c *tls.Conn) RoundTripper {
	 227  				return roundTripFunc(func(r *Request) (*Response, error) {
	 228  					n, _ := io.Copy(io.Discard, r.Body)
	 229  					if n == 0 {
	 230  						t.Error("body length is zero")
	 231  					}
	 232  					if roundTripped {
	 233  						return &Response{
	 234  							Body:			 NoBody,
	 235  							StatusCode: 200,
	 236  						}, nil
	 237  					}
	 238  					roundTripped = true
	 239  					return nil, http2noCachedConnError{}
	 240  				})
	 241  			},
	 242  		},
	 243  		DialTLS: func(_, _ string) (net.Conn, error) {
	 244  			tc, err := tls.Dial("tcp", addr, &tls.Config{
	 245  				InsecureSkipVerify: true,
	 246  				NextProtos:				 []string{"foo"},
	 247  			})
	 248  			if err != nil {
	 249  				return nil, err
	 250  			}
	 251  			if err := tc.Handshake(); err != nil {
	 252  				return nil, err
	 253  			}
	 254  			return tc, nil
	 255  		},
	 256  	}
	 257  	c := &Client{Transport: tr}
	 258  	_, err = c.Do(req)
	 259  	if err != nil {
	 260  		t.Error(err)
	 261  	}
	 262  }
	 263  

View as plain text