1
2
3
4
5
6
7 package httptest
8
9 import (
10 "crypto/tls"
11 "crypto/x509"
12 "flag"
13 "fmt"
14 "log"
15 "net"
16 "net/http"
17 "net/http/internal/testcert"
18 "os"
19 "strings"
20 "sync"
21 "time"
22 )
23
24
25
// Server is an HTTP server listening on a system-chosen port on the local
// loopback interface (or on the -httptest.serve address, when that flag is
// set), for use in end-to-end HTTP tests.
type Server struct {
	// URL is the base URL of the running server, of the form
	// "http://host:port" or "https://host:port", with no trailing slash.
	// It is empty until Start or StartTLS runs; a non-empty URL is how both
	// methods detect a second start.
	URL string
	// Listener is the listener the server accepts on. StartTLS replaces it
	// with a TLS listener that wraps the original one.
	Listener net.Listener

	// EnableHTTP2 controls whether HTTP/2 is offered by a server started
	// with StartTLS: it selects the default NextProtos ("h2" instead of
	// "http/1.1") and sets ForceAttemptHTTP2 on the client's transport.
	// It is only read by StartTLS, so it must be set before that call.
	EnableHTTP2 bool

	// TLS is the optional TLS configuration. If set before StartTLS, StartTLS
	// clones it and fills in defaults on the clone, leaving the caller's
	// value untouched; if nil, StartTLS allocates a new config. After
	// StartTLS it holds the config actually in use.
	TLS *tls.Config

	// Config is the underlying http.Server. It may be changed after
	// NewUnstartedServer and before Start or StartTLS, which install a
	// ConnState hook on it and call its Serve method.
	Config *http.Server

	// certificate is the parsed leaf of the first TLS certificate; it is
	// nil unless StartTLS has run.
	certificate *x509.Certificate

	// wg counts the goroutine running Config.Serve plus every connection
	// currently tracked in conns; Close waits on it.
	wg sync.WaitGroup

	// mu guards closed and conns.
	mu sync.Mutex
	// closed is set by the first call to Close.
	closed bool
	// conns maps each tracked connection to its last reported state;
	// entries are removed on StateHijacked and StateClosed.
	conns map[net.Conn]http.ConnState

	// client is created by Start or StartTLS for talking to this server.
	// Close closes its transport's idle connections.
	client *http.Client
}
59
60 func newLocalListener() net.Listener {
61 if serveFlag != "" {
62 l, err := net.Listen("tcp", serveFlag)
63 if err != nil {
64 panic(fmt.Sprintf("httptest: failed to listen on %v: %v", serveFlag, err))
65 }
66 return l
67 }
68 l, err := net.Listen("tcp", "127.0.0.1:0")
69 if err != nil {
70 if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
71 panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
72 }
73 }
74 return l
75 }
76
77
78
79
80
81
82
83
84 var serveFlag string
85
86 func init() {
87 if strSliceContainsPrefix(os.Args, "-httptest.serve=") || strSliceContainsPrefix(os.Args, "--httptest.serve=") {
88 flag.StringVar(&serveFlag, "httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks.")
89 }
90 }
91
// strSliceContainsPrefix reports whether at least one element of v begins
// with pre. It returns false for an empty or nil slice.
func strSliceContainsPrefix(v []string, pre string) bool {
	found := false
	for i := 0; i < len(v) && !found; i++ {
		found = strings.HasPrefix(v[i], pre)
	}
	return found
}
100
101
102
103 func NewServer(handler http.Handler) *Server {
104 ts := NewUnstartedServer(handler)
105 ts.Start()
106 return ts
107 }
108
109
110
111
112
113
114
115 func NewUnstartedServer(handler http.Handler) *Server {
116 return &Server{
117 Listener: newLocalListener(),
118 Config: &http.Server{Handler: handler},
119 }
120 }
121
122
123 func (s *Server) Start() {
124 if s.URL != "" {
125 panic("Server already started")
126 }
127 if s.client == nil {
128 s.client = &http.Client{Transport: &http.Transport{}}
129 }
130 s.URL = "http://" + s.Listener.Addr().String()
131 s.wrap()
132 s.goServe()
133 if serveFlag != "" {
134 fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
135 select {}
136 }
137 }
138
139
140 func (s *Server) StartTLS() {
141 if s.URL != "" {
142 panic("Server already started")
143 }
144 if s.client == nil {
145 s.client = &http.Client{Transport: &http.Transport{}}
146 }
147 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
148 if err != nil {
149 panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
150 }
151
152 existingConfig := s.TLS
153 if existingConfig != nil {
154 s.TLS = existingConfig.Clone()
155 } else {
156 s.TLS = new(tls.Config)
157 }
158 if s.TLS.NextProtos == nil {
159 nextProtos := []string{"http/1.1"}
160 if s.EnableHTTP2 {
161 nextProtos = []string{"h2"}
162 }
163 s.TLS.NextProtos = nextProtos
164 }
165 if len(s.TLS.Certificates) == 0 {
166 s.TLS.Certificates = []tls.Certificate{cert}
167 }
168 s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0])
169 if err != nil {
170 panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
171 }
172 certpool := x509.NewCertPool()
173 certpool.AddCert(s.certificate)
174 s.client.Transport = &http.Transport{
175 TLSClientConfig: &tls.Config{
176 RootCAs: certpool,
177 },
178 ForceAttemptHTTP2: s.EnableHTTP2,
179 }
180 s.Listener = tls.NewListener(s.Listener, s.TLS)
181 s.URL = "https://" + s.Listener.Addr().String()
182 s.wrap()
183 s.goServe()
184 }
185
186
187
188 func NewTLSServer(handler http.Handler) *Server {
189 ts := NewUnstartedServer(handler)
190 ts.StartTLS()
191 return ts
192 }
193
// closeIdleTransport is implemented by transports that can drop their idle
// connections (for example *http.Transport). Close uses it, via type
// assertions, on http.DefaultTransport and on the server client's transport.
type closeIdleTransport interface {
	CloseIdleConnections()
}
197
198
199
// Close shuts down the server and blocks until all outstanding requests on
// this server have completed.
//
// Only the first call closes the listener, disables keep-alives and
// force-closes idle and new connections. Every call closes idle connections
// of http.DefaultTransport and of the server's client, then waits for the
// serve goroutine and all tracked connections to finish.
func (s *Server) Close() {
	s.mu.Lock()
	if !s.closed {
		s.closed = true
		s.Listener.Close()
		s.Config.SetKeepAlivesEnabled(false)
		for c, st := range s.conns {
			// Force-close only connections that are not handling a request:
			// idle ones (between requests) and new ones (accepted, no request
			// seen yet). Active connections are left open so in-flight
			// requests can finish; s.wg.Wait below waits for them, and the
			// ConnState hook closes them once they turn idle.
			if st == http.StateIdle || st == http.StateNew {
				s.closeConn(c)
			}
		}
		// If shutdown takes longer than 5 seconds, log the connections still
		// being waited on. The defer belongs to Close itself (not this if),
		// so the timer stays armed through s.wg.Wait and is stopped only
		// when Close returns.
		t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo)
		defer t.Stop()
	}
	s.mu.Unlock()

	// Not required for this server's correctness: a convenience for callers
	// that use http.DefaultTransport, whose idle connections to this server
	// would otherwise linger.
	if t, ok := http.DefaultTransport.(closeIdleTransport); ok {
		t.CloseIdleConnections()
	}

	// Likewise drop idle connections held by the server's own client.
	if s.client != nil {
		if t, ok := s.client.Transport.(closeIdleTransport); ok {
			t.CloseIdleConnections()
		}
	}

	// Wait for the Serve goroutine (counted in goServe) and for every tracked
	// connection (counted on StateNew, released on StateClosed/StateHijacked).
	s.wg.Wait()
}
251
252 func (s *Server) logCloseHangDebugInfo() {
253 s.mu.Lock()
254 defer s.mu.Unlock()
255 var buf strings.Builder
256 buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n")
257 for c, st := range s.conns {
258 fmt.Fprintf(&buf, " %T %p %v in state %v\n", c, c, c.RemoteAddr(), st)
259 }
260 log.Print(buf.String())
261 }
262
263
264 func (s *Server) CloseClientConnections() {
265 s.mu.Lock()
266 nconn := len(s.conns)
267 ch := make(chan struct{}, nconn)
268 for c := range s.conns {
269 go s.closeConnChan(c, ch)
270 }
271 s.mu.Unlock()
272
273
274
275
276
277
278
279 timer := time.NewTimer(5 * time.Second)
280 defer timer.Stop()
281 for i := 0; i < nconn; i++ {
282 select {
283 case <-ch:
284 case <-timer.C:
285
286 return
287 }
288 }
289 }
290
291
292
293 func (s *Server) Certificate() *x509.Certificate {
294 return s.certificate
295 }
296
297
298
299
300 func (s *Server) Client() *http.Client {
301 return s.client
302 }
303
304 func (s *Server) goServe() {
305 s.wg.Add(1)
306 go func() {
307 defer s.wg.Done()
308 s.Config.Serve(s.Listener)
309 }()
310 }
311
312
313
// wrap installs a ConnState hook on s.Config that records every connection
// and its latest state in s.conns and counts it in s.wg, so Close can find
// idle connections to close and wait for the rest. A hook already present in
// s.Config.ConnState is chained: it runs after this bookkeeping, while s.mu
// is still held. The hook panics on state transitions it considers invalid.
func (s *Server) wrap() {
	oldHook := s.Config.ConnState
	s.Config.ConnState = func(c net.Conn, cs http.ConnState) {
		s.mu.Lock()
		defer s.mu.Unlock()

		switch cs {
		case http.StateNew:
			// A connection must not be reported as new twice.
			if _, exists := s.conns[c]; exists {
				panic("invalid state transition")
			}
			// conns is allocated lazily on the first connection.
			if s.conns == nil {
				s.conns = make(map[net.Conn]http.ConnState)
			}
			// Start tracking c; the matching wg.Done happens when c reaches
			// StateHijacked or StateClosed.
			s.wg.Add(1)
			s.conns[c] = cs
			if s.closed {
				// The server is already closed: close this late-arriving
				// connection instead of letting it wait to be served.
				s.closeConn(c)
			}
		case http.StateActive:
			// Only tracked connections are updated; a request may start on
			// a new or an idle connection.
			if oldState, ok := s.conns[c]; ok {
				if oldState != http.StateNew && oldState != http.StateIdle {
					panic("invalid state transition")
				}
				s.conns[c] = cs
			}
		case http.StateIdle:
			// A tracked connection can only become idle after being active.
			if oldState, ok := s.conns[c]; ok {
				if oldState != http.StateActive {
					panic("invalid state transition")
				}
				s.conns[c] = cs
			}
			// After Close, connections are closed as soon as they go idle.
			if s.closed {
				s.closeConn(c)
			}
		case http.StateHijacked, http.StateClosed:
			// Terminal states: stop tracking c and release its wg count,
			// unless it was already removed.
			if _, ok := s.conns[c]; ok {
				delete(s.conns, c)
				// Deferred so Done runs after oldHook below (defers run LIFO,
				// so still before s.mu is released): Close's wg.Wait cannot
				// return until the chained hook has finished for c.
				defer s.wg.Done()
			}
		}
		if oldHook != nil {
			oldHook(c, cs)
		}
	}
}
371
372
373
374 func (s *Server) closeConn(c net.Conn) { s.closeConnChan(c, nil) }
375
376
377
378 func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) {
379 c.Close()
380 if done != nil {
381 done <- struct{}{}
382 }
383 }
384
View as plain text