...
Source file
src/net/http/alpn_test.go
1
2
3
4
5 package http_test
6
7 import (
8 "bufio"
9 "bytes"
10 "crypto/tls"
11 "crypto/x509"
12 "fmt"
13 "io"
14 . "net/http"
15 "net/http/httptest"
16 "strings"
17 "testing"
18 )
19
20 func TestNextProtoUpgrade(t *testing.T) {
21 setParallel(t)
22 defer afterTest(t)
23 ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
24 fmt.Fprintf(w, "path=%s,proto=", r.URL.Path)
25 if r.TLS != nil {
26 w.Write([]byte(r.TLS.NegotiatedProtocol))
27 }
28 if r.RemoteAddr == "" {
29 t.Error("request with no RemoteAddr")
30 }
31 if r.Body == nil {
32 t.Errorf("request with nil Body")
33 }
34 }))
35 ts.TLS = &tls.Config{
36 NextProtos: []string{"unhandled-proto", "tls-0.9"},
37 }
38 ts.Config.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){
39 "tls-0.9": handleTLSProtocol09,
40 }
41 ts.StartTLS()
42 defer ts.Close()
43
44
45 {
46 c := ts.Client()
47 res, err := c.Get(ts.URL)
48 if err != nil {
49 t.Fatal(err)
50 }
51 body, err := io.ReadAll(res.Body)
52 if err != nil {
53 t.Fatal(err)
54 }
55 if want := "path=/,proto="; string(body) != want {
56 t.Errorf("plain request = %q; want %q", body, want)
57 }
58 }
59
60
61
62 {
63 certPool := x509.NewCertPool()
64 certPool.AddCert(ts.Certificate())
65 tr := &Transport{
66 TLSClientConfig: &tls.Config{
67 RootCAs: certPool,
68 NextProtos: []string{"unhandled-proto"},
69 },
70 }
71 defer tr.CloseIdleConnections()
72 c := &Client{
73 Transport: tr,
74 }
75 res, err := c.Get(ts.URL)
76 if err == nil {
77 defer res.Body.Close()
78 var buf bytes.Buffer
79 res.Write(&buf)
80 t.Errorf("expected error on unhandled-proto request; got: %s", buf.Bytes())
81 }
82 }
83
84
85
86 {
87 c := ts.Client()
88 tlsConfig := c.Transport.(*Transport).TLSClientConfig
89 tlsConfig.NextProtos = []string{"tls-0.9"}
90 conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
91 if err != nil {
92 t.Fatal(err)
93 }
94 conn.Write([]byte("GET /foo\n"))
95 body, err := io.ReadAll(conn)
96 if err != nil {
97 t.Fatal(err)
98 }
99 if want := "path=/foo,proto=tls-0.9"; string(body) != want {
100 t.Errorf("plain request = %q; want %q", body, want)
101 }
102 }
103 }
104
105
106
107 func handleTLSProtocol09(srv *Server, conn *tls.Conn, h Handler) {
108 br := bufio.NewReader(conn)
109 line, err := br.ReadString('\n')
110 if err != nil {
111 return
112 }
113 line = strings.TrimSpace(line)
114 path := strings.TrimPrefix(line, "GET ")
115 if path == line {
116 return
117 }
118 req, _ := NewRequest("GET", path, nil)
119 req.Proto = "HTTP/0.9"
120 req.ProtoMajor = 0
121 req.ProtoMinor = 9
122 rw := &http09Writer{conn, make(Header)}
123 h.ServeHTTP(rw, req)
124 }
125
// http09Writer is a bare-bones response writer: Write is promoted from the
// embedded io.Writer, Header hands back the stored header map, and
// WriteHeader is a no-op.
type http09Writer struct {
	// Writer receives every byte passed to Write.
	io.Writer
	// h is the map returned by Header.
	h Header
}

// Header returns the header map stored in the writer.
func (rw http09Writer) Header() Header {
	return rw.h
}

// WriteHeader accepts a status code and does nothing with it.
func (rw http09Writer) WriteHeader(_ int) {
}
133
View as plain text