Source file
src/net/http/transport_internal_test.go
1
2
3
4
5
6
7 package http
8
9 import (
10 "bytes"
11 "crypto/tls"
12 "errors"
13 "io"
14 "net"
15 "net/http/internal/testcert"
16 "strings"
17 "testing"
18 )
19
20
21 func TestTransportPersistConnReadLoopEOF(t *testing.T) {
22 ln := newLocalListener(t)
23 defer ln.Close()
24
25 connc := make(chan net.Conn, 1)
26 go func() {
27 defer close(connc)
28 c, err := ln.Accept()
29 if err != nil {
30 t.Error(err)
31 return
32 }
33 connc <- c
34 }()
35
36 tr := new(Transport)
37 req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil)
38 req = req.WithT(t)
39 treq := &transportRequest{Request: req}
40 cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()}
41 pc, err := tr.getConn(treq, cm)
42 if err != nil {
43 t.Fatal(err)
44 }
45 defer pc.close(errors.New("test over"))
46
47 conn := <-connc
48 if conn == nil {
49
50 return
51 }
52 conn.Close()
53
54 _, err = pc.roundTrip(treq)
55 if !isTransportReadFromServerError(err) && err != errServerClosedIdle {
56 t.Errorf("roundTrip = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err)
57 }
58
59 <-pc.closech
60 err = pc.closed
61 if !isTransportReadFromServerError(err) && err != errServerClosedIdle {
62 t.Errorf("pc.closed = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err)
63 }
64 }
65
66 func isTransportReadFromServerError(err error) bool {
67 _, ok := err.(transportReadFromServerError)
68 return ok
69 }
70
// newLocalListener returns a TCP listener on an ephemeral loopback port.
// It tries IPv4 (127.0.0.1) first and falls back to IPv6 ([::1]). If neither
// can be bound, it fails the test and reports both errors.
func newLocalListener(t *testing.T) net.Listener {
	t.Helper()
	ln, err4 := net.Listen("tcp", "127.0.0.1:0")
	if err4 == nil {
		return ln
	}
	ln, err6 := net.Listen("tcp6", "[::1]:0")
	if err6 != nil {
		t.Fatalf("cannot listen on loopback: 127.0.0.1: %v; [::1]: %v", err4, err6)
	}
	return ln
}
81
82 func dummyRequest(method string) *Request {
83 req, err := NewRequest(method, "http://fake.tld/", nil)
84 if err != nil {
85 panic(err)
86 }
87 return req
88 }
89 func dummyRequestWithBody(method string) *Request {
90 req, err := NewRequest(method, "http://fake.tld/", strings.NewReader("foo"))
91 if err != nil {
92 panic(err)
93 }
94 return req
95 }
96
97 func dummyRequestWithBodyNoGetBody(method string) *Request {
98 req := dummyRequestWithBody(method)
99 req.GetBody = nil
100 return req
101 }
102
103
// issue22091Error is a stand-in error that is not any of the package's
// concrete HTTP/2 error types but does have an IsHTTP2NoCachedConnError
// method. TestTransportShouldRetryRequest expects it to be retried, so
// presumably shouldRetryRequest recognizes that error kind by method set.
// TODO confirm against shouldRetryRequest.
type issue22091Error struct{}

// IsHTTP2NoCachedConnError is a marker method with no behavior.
func (e issue22091Error) IsHTTP2NoCachedConnError() {}

// Error implements the error interface.
func (e issue22091Error) Error() string {
	return "issue22091Error"
}
108
109 func TestTransportShouldRetryRequest(t *testing.T) {
110 tests := []struct {
111 pc *persistConn
112 req *Request
113
114 err error
115 want bool
116 }{
117 0: {
118 pc: &persistConn{reused: false},
119 req: dummyRequest("POST"),
120 err: nothingWrittenError{},
121 want: false,
122 },
123 1: {
124 pc: &persistConn{reused: true},
125 req: dummyRequest("POST"),
126 err: nothingWrittenError{},
127 want: true,
128 },
129 2: {
130 pc: &persistConn{reused: true},
131 req: dummyRequest("POST"),
132 err: http2ErrNoCachedConn,
133 want: true,
134 },
135 3: {
136 pc: nil,
137 req: nil,
138 err: issue22091Error{},
139 want: true,
140 },
141 4: {
142 pc: &persistConn{reused: true},
143 req: dummyRequest("POST"),
144 err: errMissingHost,
145 want: false,
146 },
147 5: {
148 pc: &persistConn{reused: true},
149 req: dummyRequest("POST"),
150 err: transportReadFromServerError{},
151 want: false,
152 },
153 6: {
154 pc: &persistConn{reused: true},
155 req: dummyRequest("GET"),
156 err: transportReadFromServerError{},
157 want: true,
158 },
159 7: {
160 pc: &persistConn{reused: true},
161 req: dummyRequest("GET"),
162 err: errServerClosedIdle,
163 want: true,
164 },
165 8: {
166 pc: &persistConn{reused: true},
167 req: dummyRequestWithBody("POST"),
168 err: nothingWrittenError{},
169 want: true,
170 },
171 9: {
172 pc: &persistConn{reused: true},
173 req: dummyRequestWithBodyNoGetBody("POST"),
174 err: nothingWrittenError{},
175 want: false,
176 },
177 }
178 for i, tt := range tests {
179 got := tt.pc.shouldRetryRequest(tt.req, tt.err)
180 if got != tt.want {
181 t.Errorf("%d. shouldRetryRequest = %v; want %v", i, got, tt.want)
182 }
183 }
184 }
185
186 type roundTripFunc func(r *Request) (*Response, error)
187
188 func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) {
189 return f(r)
190 }
191
192
193 func TestTransportBodyAltRewind(t *testing.T) {
194 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
195 if err != nil {
196 t.Fatal(err)
197 }
198 ln := newLocalListener(t)
199 defer ln.Close()
200
201 go func() {
202 tln := tls.NewListener(ln, &tls.Config{
203 NextProtos: []string{"foo"},
204 Certificates: []tls.Certificate{cert},
205 })
206 for i := 0; i < 2; i++ {
207 sc, err := tln.Accept()
208 if err != nil {
209 t.Error(err)
210 return
211 }
212 if err := sc.(*tls.Conn).Handshake(); err != nil {
213 t.Error(err)
214 return
215 }
216 sc.Close()
217 }
218 }()
219
220 addr := ln.Addr().String()
221 req, _ := NewRequest("POST", "https://example.org/", bytes.NewBufferString("request"))
222 roundTripped := false
223 tr := &Transport{
224 DisableKeepAlives: true,
225 TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
226 "foo": func(authority string, c *tls.Conn) RoundTripper {
227 return roundTripFunc(func(r *Request) (*Response, error) {
228 n, _ := io.Copy(io.Discard, r.Body)
229 if n == 0 {
230 t.Error("body length is zero")
231 }
232 if roundTripped {
233 return &Response{
234 Body: NoBody,
235 StatusCode: 200,
236 }, nil
237 }
238 roundTripped = true
239 return nil, http2noCachedConnError{}
240 })
241 },
242 },
243 DialTLS: func(_, _ string) (net.Conn, error) {
244 tc, err := tls.Dial("tcp", addr, &tls.Config{
245 InsecureSkipVerify: true,
246 NextProtos: []string{"foo"},
247 })
248 if err != nil {
249 return nil, err
250 }
251 if err := tc.Handshake(); err != nil {
252 return nil, err
253 }
254 return tc, nil
255 },
256 }
257 c := &Client{Transport: tr}
258 _, err = c.Do(req)
259 if err != nil {
260 t.Error(err)
261 }
262 }
263
View as plain text