Source file
src/crypto/tls/handshake_client_test.go
1
2
3
4
5 package tls
6
7 import (
8 "bytes"
9 "context"
10 "crypto/rsa"
11 "crypto/x509"
12 "encoding/base64"
13 "encoding/binary"
14 "encoding/pem"
15 "errors"
16 "fmt"
17 "io"
18 "math/big"
19 "net"
20 "os"
21 "os/exec"
22 "path/filepath"
23 "reflect"
24 "runtime"
25 "strconv"
26 "strings"
27 "testing"
28 "time"
29 )
30
31
32
33
34
35
// opensslInputEvent enumerates the commands the test driver can feed to the
// OpenSSL s_server child process through its standard input.
type opensslInputEvent int

const (
	// opensslRenegotiate is written to s_server's stdin as "R\n", the
	// interactive command used here to request a renegotiation.
	opensslRenegotiate opensslInputEvent = iota

	// opensslSendSentinel is written to s_server's stdin as opensslSentinel,
	// giving the client application data to read.
	opensslSendSentinel

	// opensslKeyUpdate is written to s_server's stdin as "K\n", the
	// interactive command used here to request a TLS 1.3 KeyUpdate.
	opensslKeyUpdate
)

// opensslSentinel is the application-data line the client expects to read
// after a renegotiation or KeyUpdate.
const opensslSentinel = "SENTINEL\n"

// opensslInput is a channel of events that doubles as the child's stdin.
type opensslInput chan opensslInputEvent

// Read implements io.Reader. It blocks for the next event and copies the
// corresponding command line into buf (truncated to len(buf)); once the
// channel is closed it returns 0, io.EOF. Unknown events panic.
func (i opensslInput) Read(buf []byte) (n int, err error) {
	event, open := <-i
	if !open {
		return 0, io.EOF
	}

	var line string
	switch event {
	case opensslRenegotiate:
		line = "R\n"
	case opensslKeyUpdate:
		line = "K\n"
	case opensslSendSentinel:
		line = opensslSentinel
	default:
		panic("unknown event")
	}
	return copy(buf, line), nil
}
72
73
74
75
// opensslOutputSink collects everything the s_server child writes to stdout
// and stderr. It scans the output line by line for the state messages the
// test driver synchronizes on.
type opensslOutputSink struct {
	// handshakeComplete receives a value each time a complete line equal to
	// opensslEndOfHandshake is written.
	handshakeComplete chan struct{}
	// readKeyUpdate receives a value each time a complete line equal to
	// opensslReadKeyUpdate is written.
	readKeyUpdate chan struct{}
	// all accumulates the entire output, for diagnostics.
	all []byte
	// line holds output after the last newline that has not been matched yet.
	line []byte
}

// newOpensslOutputSink returns a sink with unbuffered notification channels.
// Write therefore blocks until each notification is received.
func newOpensslOutputSink() *opensslOutputSink {
	return &opensslOutputSink{
		handshakeComplete: make(chan struct{}),
		readKeyUpdate:     make(chan struct{}),
	}
}

// opensslEndOfHandshake is the state line the server prints when it finishes
// a handshake. It is only printed when s_server runs with "-state".
const opensslEndOfHandshake = "SSL_accept:SSLv3/TLS write finished"

// opensslReadKeyUpdate is the state line the server prints when it reads a
// KeyUpdate from the client. It is only printed when s_server runs with "-state".
const opensslReadKeyUpdate = "SSL_accept:TLSv1.3 read client key update"

// Write implements io.Writer. It records data and signals the matching
// channel for every complete line equal to one of the state messages.
func (o *opensslOutputSink) Write(data []byte) (n int, err error) {
	o.all = append(o.all, data...)
	o.line = append(o.line, data...)

	for {
		end := bytes.IndexByte(o.line, '\n')
		if end == -1 {
			return len(data), nil
		}
		switch string(o.line[:end]) {
		case opensslEndOfHandshake:
			o.handshakeComplete <- struct{}{}
		case opensslReadKeyUpdate:
			o.readKeyUpdate <- struct{}{}
		}
		o.line = o.line[end+1:]
	}
}

// String returns everything written to the sink so far.
func (o *opensslOutputSink) String() string {
	return string(o.all)
}
120
121
122
// clientTest describes one test of the TLS client handshake. With -update it
// runs against an OpenSSL s_server reference implementation (see
// connFromCommand); otherwise it replays the transcript stored at dataPath.
type clientTest struct {
	// name identifies the test. It is also used (with a version prefix added
	// by runClientTestForVersion) to name the transcript file under testdata.
	name string
	// args are extra command-line arguments appended to serverCommand for the
	// reference server.
	args []string
	// config, if not nil, is the client Config to use; otherwise testConfig is used.
	config *Config
	// cert, if not empty, is the DER-encoded certificate for the reference
	// server (passed with "-certform DER"); otherwise testRSACertificate is used.
	cert []byte
	// key, if not nil, is the reference server's private key, which must be
	// accepted by x509.MarshalPKCS8PrivateKey; otherwise testRSAPrivateKey is used.
	key interface{}
	// extensions, if not empty, are written to a "-serverinfo" file for the
	// reference server. The first two bytes of each entry are read big-endian
	// as the extension type. NOTE(review): the remainder presumably follows
	// OpenSSL's serverinfo layout (length, then data) — confirm.
	extensions [][]byte
	// validate, if not nil, is called with the final ConnectionState. A
	// non-nil error fails the test.
	validate func(ConnectionState) error
	// numRenegotiations is how many renegotiations the server is asked to
	// perform after the initial handshake.
	numRenegotiations int
	// renegotiationExpectedToFail, if not zero, is the 1-based number of the
	// renegotiation that is expected to fail. The driver then does not wait
	// for that handshake to complete and skips the final peekError check.
	renegotiationExpectedToFail int
	// checkRenegotiationError, if not nil, is called with the Read error after
	// each renegotiation. Returning nil for a non-nil error marks the error as
	// expected; returning an error reports it as a failure.
	checkRenegotiationError func(renegotiationNum int, err error) error
	// sendKeyUpdate makes the server send a KeyUpdate after the handshake.
	sendKeyUpdate bool
}

// serverCommand is the base command line for the reference server.
// NOTE(review): "-no_ticket" and "-num_tickets 0" presumably disable session
// tickets so recorded transcripts contain no ticket messages — confirm.
var serverCommand = []string{"openssl", "s_server", "-no_ticket", "-num_tickets", "0"}
161
162
163
164
165
166 func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin opensslInput, stdout *opensslOutputSink, err error) {
167 cert := testRSACertificate
168 if len(test.cert) > 0 {
169 cert = test.cert
170 }
171 certPath := tempFile(string(cert))
172 defer os.Remove(certPath)
173
174 var key interface{} = testRSAPrivateKey
175 if test.key != nil {
176 key = test.key
177 }
178 derBytes, err := x509.MarshalPKCS8PrivateKey(key)
179 if err != nil {
180 panic(err)
181 }
182
183 var pemOut bytes.Buffer
184 pem.Encode(&pemOut, &pem.Block{Type: "PRIVATE KEY", Bytes: derBytes})
185
186 keyPath := tempFile(pemOut.String())
187 defer os.Remove(keyPath)
188
189 var command []string
190 command = append(command, serverCommand...)
191 command = append(command, test.args...)
192 command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath)
193
194
195
196
197
198 const serverPort = 24323
199 command = append(command, "-accept", strconv.Itoa(serverPort))
200
201 if len(test.extensions) > 0 {
202 var serverInfo bytes.Buffer
203 for _, ext := range test.extensions {
204 pem.Encode(&serverInfo, &pem.Block{
205 Type: fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)),
206 Bytes: ext,
207 })
208 }
209 serverInfoPath := tempFile(serverInfo.String())
210 defer os.Remove(serverInfoPath)
211 command = append(command, "-serverinfo", serverInfoPath)
212 }
213
214 if test.numRenegotiations > 0 || test.sendKeyUpdate {
215 found := false
216 for _, flag := range command[1:] {
217 if flag == "-state" {
218 found = true
219 break
220 }
221 }
222
223 if !found {
224 panic("-state flag missing to OpenSSL, you need this if testing renegotiation or KeyUpdate")
225 }
226 }
227
228 cmd := exec.Command(command[0], command[1:]...)
229 stdin = opensslInput(make(chan opensslInputEvent))
230 cmd.Stdin = stdin
231 out := newOpensslOutputSink()
232 cmd.Stdout = out
233 cmd.Stderr = out
234 if err := cmd.Start(); err != nil {
235 return nil, nil, nil, nil, err
236 }
237
238
239
240
241
242 var tcpConn net.Conn
243 for i := uint(0); i < 5; i++ {
244 tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{
245 IP: net.IPv4(127, 0, 0, 1),
246 Port: serverPort,
247 })
248 if err == nil {
249 break
250 }
251 time.Sleep((1 << i) * 5 * time.Millisecond)
252 }
253 if err != nil {
254 close(stdin)
255 cmd.Process.Kill()
256 err = fmt.Errorf("error connecting to the OpenSSL server: %v (%v)\n\n%s", err, cmd.Wait(), out)
257 return nil, nil, nil, nil, err
258 }
259
260 record := &recordingConn{
261 Conn: tcpConn,
262 }
263
264 return record, cmd, stdin, out, nil
265 }
266
267 func (test *clientTest) dataPath() string {
268 return filepath.Join("testdata", "Client-"+test.name)
269 }
270
271 func (test *clientTest) loadData() (flows [][]byte, err error) {
272 in, err := os.Open(test.dataPath())
273 if err != nil {
274 return nil, err
275 }
276 defer in.Close()
277 return parseTestData(in)
278 }
279
280 func (test *clientTest) run(t *testing.T, write bool) {
281 var clientConn, serverConn net.Conn
282 var recordingConn *recordingConn
283 var childProcess *exec.Cmd
284 var stdin opensslInput
285 var stdout *opensslOutputSink
286
287 if write {
288 var err error
289 recordingConn, childProcess, stdin, stdout, err = test.connFromCommand()
290 if err != nil {
291 t.Fatalf("Failed to start subcommand: %s", err)
292 }
293 clientConn = recordingConn
294 defer func() {
295 if t.Failed() {
296 t.Logf("OpenSSL output:\n\n%s", stdout.all)
297 }
298 }()
299 } else {
300 clientConn, serverConn = localPipe(t)
301 }
302
303 doneChan := make(chan bool)
304 defer func() {
305 clientConn.Close()
306 <-doneChan
307 }()
308 go func() {
309 defer close(doneChan)
310
311 config := test.config
312 if config == nil {
313 config = testConfig
314 }
315 client := Client(clientConn, config)
316 defer client.Close()
317
318 if _, err := client.Write([]byte("hello\n")); err != nil {
319 t.Errorf("Client.Write failed: %s", err)
320 return
321 }
322
323 for i := 1; i <= test.numRenegotiations; i++ {
324
325
326 if i == 1 && write {
327 <-stdout.handshakeComplete
328 }
329
330
331
332
333
334
335
336
337
338 if write {
339 stdin <- opensslRenegotiate
340 }
341
342 signalChan := make(chan struct{})
343
344 go func() {
345 defer close(signalChan)
346
347 buf := make([]byte, 256)
348 n, err := client.Read(buf)
349
350 if test.checkRenegotiationError != nil {
351 newErr := test.checkRenegotiationError(i, err)
352 if err != nil && newErr == nil {
353 return
354 }
355 err = newErr
356 }
357
358 if err != nil {
359 t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
360 return
361 }
362
363 buf = buf[:n]
364 if !bytes.Equal([]byte(opensslSentinel), buf) {
365 t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
366 }
367
368 if expected := i + 1; client.handshakes != expected {
369 t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes)
370 }
371 }()
372
373 if write && test.renegotiationExpectedToFail != i {
374 <-stdout.handshakeComplete
375 stdin <- opensslSendSentinel
376 }
377 <-signalChan
378 }
379
380 if test.sendKeyUpdate {
381 if write {
382 <-stdout.handshakeComplete
383 stdin <- opensslKeyUpdate
384 }
385
386 doneRead := make(chan struct{})
387
388 go func() {
389 defer close(doneRead)
390
391 buf := make([]byte, 256)
392 n, err := client.Read(buf)
393
394 if err != nil {
395 t.Errorf("Client.Read failed after KeyUpdate: %s", err)
396 return
397 }
398
399 buf = buf[:n]
400 if !bytes.Equal([]byte(opensslSentinel), buf) {
401 t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
402 }
403 }()
404
405 if write {
406
407
408
409 <-stdout.readKeyUpdate
410 stdin <- opensslSendSentinel
411 }
412 <-doneRead
413
414 if _, err := client.Write([]byte("hello again\n")); err != nil {
415 t.Errorf("Client.Write failed: %s", err)
416 return
417 }
418 }
419
420 if test.validate != nil {
421 if err := test.validate(client.ConnectionState()); err != nil {
422 t.Errorf("validate callback returned error: %s", err)
423 }
424 }
425
426
427
428 if write && test.renegotiationExpectedToFail == 0 {
429 if err := peekError(client); err != nil {
430 t.Errorf("final Read returned an error: %s", err)
431 }
432 }
433 }()
434
435 if !write {
436 flows, err := test.loadData()
437 if err != nil {
438 t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err)
439 }
440 for i, b := range flows {
441 if i%2 == 1 {
442 if *fast {
443 serverConn.SetWriteDeadline(time.Now().Add(1 * time.Second))
444 } else {
445 serverConn.SetWriteDeadline(time.Now().Add(1 * time.Minute))
446 }
447 serverConn.Write(b)
448 continue
449 }
450 bb := make([]byte, len(b))
451 if *fast {
452 serverConn.SetReadDeadline(time.Now().Add(1 * time.Second))
453 } else {
454 serverConn.SetReadDeadline(time.Now().Add(1 * time.Minute))
455 }
456 _, err := io.ReadFull(serverConn, bb)
457 if err != nil {
458 t.Fatalf("%s, flow %d: %s", test.name, i+1, err)
459 }
460 if !bytes.Equal(b, bb) {
461 t.Fatalf("%s, flow %d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b)
462 }
463 }
464 }
465
466 <-doneChan
467 if !write {
468 serverConn.Close()
469 }
470
471 if write {
472 path := test.dataPath()
473 out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
474 if err != nil {
475 t.Fatalf("Failed to create output file: %s", err)
476 }
477 defer out.Close()
478 recordingConn.Close()
479 close(stdin)
480 childProcess.Process.Kill()
481 childProcess.Wait()
482 if len(recordingConn.flows) < 3 {
483 t.Fatalf("Client connection didn't work")
484 }
485 recordingConn.WriteTo(out)
486 t.Logf("Wrote %s\n", path)
487 }
488 }
489
490
491
// peekError reads from conn with a 100ms deadline to see whether the next read
// would fail. It returns nil if the read times out with no data, an error if
// any data was read, and the read error itself for any non-timeout failure.
// The read deadline is left set on conn.
func peekError(conn net.Conn) error {
	conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
	n, err := conn.Read(make([]byte, 1))
	switch {
	case n != 0:
		return errors.New("unexpectedly read data")
	case err == nil:
		return nil
	}
	if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
		return nil
	}
	return err
}
503
504 func runClientTestForVersion(t *testing.T, template *clientTest, version, option string) {
505
506 test := *template
507 if template.config != nil {
508 test.config = template.config.Clone()
509 }
510 test.name = version + "-" + test.name
511 test.args = append([]string{option}, test.args...)
512
513 runTestAndUpdateIfNeeded(t, version, test.run, false)
514 }
515
516 func runClientTestTLS10(t *testing.T, template *clientTest) {
517 runClientTestForVersion(t, template, "TLSv10", "-tls1")
518 }
519
520 func runClientTestTLS11(t *testing.T, template *clientTest) {
521 runClientTestForVersion(t, template, "TLSv11", "-tls1_1")
522 }
523
524 func runClientTestTLS12(t *testing.T, template *clientTest) {
525 runClientTestForVersion(t, template, "TLSv12", "-tls1_2")
526 }
527
528 func runClientTestTLS13(t *testing.T, template *clientTest) {
529 runClientTestForVersion(t, template, "TLSv13", "-tls1_3")
530 }
531
532 func TestHandshakeClientRSARC4(t *testing.T) {
533 test := &clientTest{
534 name: "RSA-RC4",
535 args: []string{"-cipher", "RC4-SHA"},
536 }
537 runClientTestTLS10(t, test)
538 runClientTestTLS11(t, test)
539 runClientTestTLS12(t, test)
540 }
541
542 func TestHandshakeClientRSAAES128GCM(t *testing.T) {
543 test := &clientTest{
544 name: "AES128-GCM-SHA256",
545 args: []string{"-cipher", "AES128-GCM-SHA256"},
546 }
547 runClientTestTLS12(t, test)
548 }
549
550 func TestHandshakeClientRSAAES256GCM(t *testing.T) {
551 test := &clientTest{
552 name: "AES256-GCM-SHA384",
553 args: []string{"-cipher", "AES256-GCM-SHA384"},
554 }
555 runClientTestTLS12(t, test)
556 }
557
558 func TestHandshakeClientECDHERSAAES(t *testing.T) {
559 test := &clientTest{
560 name: "ECDHE-RSA-AES",
561 args: []string{"-cipher", "ECDHE-RSA-AES128-SHA"},
562 }
563 runClientTestTLS10(t, test)
564 runClientTestTLS11(t, test)
565 runClientTestTLS12(t, test)
566 }
567
568 func TestHandshakeClientECDHEECDSAAES(t *testing.T) {
569 test := &clientTest{
570 name: "ECDHE-ECDSA-AES",
571 args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA"},
572 cert: testECDSACertificate,
573 key: testECDSAPrivateKey,
574 }
575 runClientTestTLS10(t, test)
576 runClientTestTLS11(t, test)
577 runClientTestTLS12(t, test)
578 }
579
580 func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) {
581 test := &clientTest{
582 name: "ECDHE-ECDSA-AES-GCM",
583 args: []string{"-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"},
584 cert: testECDSACertificate,
585 key: testECDSAPrivateKey,
586 }
587 runClientTestTLS12(t, test)
588 }
589
590 func TestHandshakeClientAES256GCMSHA384(t *testing.T) {
591 test := &clientTest{
592 name: "ECDHE-ECDSA-AES256-GCM-SHA384",
593 args: []string{"-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"},
594 cert: testECDSACertificate,
595 key: testECDSAPrivateKey,
596 }
597 runClientTestTLS12(t, test)
598 }
599
600 func TestHandshakeClientAES128CBCSHA256(t *testing.T) {
601 test := &clientTest{
602 name: "AES128-SHA256",
603 args: []string{"-cipher", "AES128-SHA256"},
604 }
605 runClientTestTLS12(t, test)
606 }
607
608 func TestHandshakeClientECDHERSAAES128CBCSHA256(t *testing.T) {
609 test := &clientTest{
610 name: "ECDHE-RSA-AES128-SHA256",
611 args: []string{"-cipher", "ECDHE-RSA-AES128-SHA256"},
612 }
613 runClientTestTLS12(t, test)
614 }
615
616 func TestHandshakeClientECDHEECDSAAES128CBCSHA256(t *testing.T) {
617 test := &clientTest{
618 name: "ECDHE-ECDSA-AES128-SHA256",
619 args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA256"},
620 cert: testECDSACertificate,
621 key: testECDSAPrivateKey,
622 }
623 runClientTestTLS12(t, test)
624 }
625
626 func TestHandshakeClientX25519(t *testing.T) {
627 config := testConfig.Clone()
628 config.CurvePreferences = []CurveID{X25519}
629
630 test := &clientTest{
631 name: "X25519-ECDHE",
632 args: []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "X25519"},
633 config: config,
634 }
635
636 runClientTestTLS12(t, test)
637 runClientTestTLS13(t, test)
638 }
639
640 func TestHandshakeClientP256(t *testing.T) {
641 config := testConfig.Clone()
642 config.CurvePreferences = []CurveID{CurveP256}
643
644 test := &clientTest{
645 name: "P256-ECDHE",
646 args: []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"},
647 config: config,
648 }
649
650 runClientTestTLS12(t, test)
651 runClientTestTLS13(t, test)
652 }
653
654 func TestHandshakeClientHelloRetryRequest(t *testing.T) {
655 config := testConfig.Clone()
656 config.CurvePreferences = []CurveID{X25519, CurveP256}
657
658 test := &clientTest{
659 name: "HelloRetryRequest",
660 args: []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"},
661 config: config,
662 }
663
664 runClientTestTLS13(t, test)
665 }
666
667 func TestHandshakeClientECDHERSAChaCha20(t *testing.T) {
668 config := testConfig.Clone()
669 config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}
670
671 test := &clientTest{
672 name: "ECDHE-RSA-CHACHA20-POLY1305",
673 args: []string{"-cipher", "ECDHE-RSA-CHACHA20-POLY1305"},
674 config: config,
675 }
676
677 runClientTestTLS12(t, test)
678 }
679
680 func TestHandshakeClientECDHEECDSAChaCha20(t *testing.T) {
681 config := testConfig.Clone()
682 config.CipherSuites = []uint16{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305}
683
684 test := &clientTest{
685 name: "ECDHE-ECDSA-CHACHA20-POLY1305",
686 args: []string{"-cipher", "ECDHE-ECDSA-CHACHA20-POLY1305"},
687 config: config,
688 cert: testECDSACertificate,
689 key: testECDSAPrivateKey,
690 }
691
692 runClientTestTLS12(t, test)
693 }
694
695 func TestHandshakeClientAES128SHA256(t *testing.T) {
696 test := &clientTest{
697 name: "AES128-SHA256",
698 args: []string{"-ciphersuites", "TLS_AES_128_GCM_SHA256"},
699 }
700 runClientTestTLS13(t, test)
701 }
702 func TestHandshakeClientAES256SHA384(t *testing.T) {
703 test := &clientTest{
704 name: "AES256-SHA384",
705 args: []string{"-ciphersuites", "TLS_AES_256_GCM_SHA384"},
706 }
707 runClientTestTLS13(t, test)
708 }
709 func TestHandshakeClientCHACHA20SHA256(t *testing.T) {
710 test := &clientTest{
711 name: "CHACHA20-SHA256",
712 args: []string{"-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256"},
713 }
714 runClientTestTLS13(t, test)
715 }
716
717 func TestHandshakeClientECDSATLS13(t *testing.T) {
718 test := &clientTest{
719 name: "ECDSA",
720 cert: testECDSACertificate,
721 key: testECDSAPrivateKey,
722 }
723 runClientTestTLS13(t, test)
724 }
725
726 func TestHandshakeClientEd25519(t *testing.T) {
727 test := &clientTest{
728 name: "Ed25519",
729 cert: testEd25519Certificate,
730 key: testEd25519PrivateKey,
731 }
732 runClientTestTLS12(t, test)
733 runClientTestTLS13(t, test)
734
735 config := testConfig.Clone()
736 cert, _ := X509KeyPair([]byte(clientEd25519CertificatePEM), []byte(clientEd25519KeyPEM))
737 config.Certificates = []Certificate{cert}
738
739 test = &clientTest{
740 name: "ClientCert-Ed25519",
741 args: []string{"-Verify", "1"},
742 config: config,
743 }
744
745 runClientTestTLS12(t, test)
746 runClientTestTLS13(t, test)
747 }
748
749 func TestHandshakeClientCertRSA(t *testing.T) {
750 config := testConfig.Clone()
751 cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
752 config.Certificates = []Certificate{cert}
753
754 test := &clientTest{
755 name: "ClientCert-RSA-RSA",
756 args: []string{"-cipher", "AES128", "-Verify", "1"},
757 config: config,
758 }
759
760 runClientTestTLS10(t, test)
761 runClientTestTLS12(t, test)
762
763 test = &clientTest{
764 name: "ClientCert-RSA-ECDSA",
765 args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"},
766 config: config,
767 cert: testECDSACertificate,
768 key: testECDSAPrivateKey,
769 }
770
771 runClientTestTLS10(t, test)
772 runClientTestTLS12(t, test)
773 runClientTestTLS13(t, test)
774
775 test = &clientTest{
776 name: "ClientCert-RSA-AES256-GCM-SHA384",
777 args: []string{"-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-Verify", "1"},
778 config: config,
779 cert: testRSACertificate,
780 key: testRSAPrivateKey,
781 }
782
783 runClientTestTLS12(t, test)
784 }
785
786 func TestHandshakeClientCertECDSA(t *testing.T) {
787 config := testConfig.Clone()
788 cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM))
789 config.Certificates = []Certificate{cert}
790
791 test := &clientTest{
792 name: "ClientCert-ECDSA-RSA",
793 args: []string{"-cipher", "AES128", "-Verify", "1"},
794 config: config,
795 }
796
797 runClientTestTLS10(t, test)
798 runClientTestTLS12(t, test)
799 runClientTestTLS13(t, test)
800
801 test = &clientTest{
802 name: "ClientCert-ECDSA-ECDSA",
803 args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"},
804 config: config,
805 cert: testECDSACertificate,
806 key: testECDSAPrivateKey,
807 }
808
809 runClientTestTLS10(t, test)
810 runClientTestTLS12(t, test)
811 }
812
813
814
815
816
817 func TestHandshakeClientCertRSAPSS(t *testing.T) {
818 cert, err := x509.ParseCertificate(testRSAPSSCertificate)
819 if err != nil {
820 panic(err)
821 }
822 rootCAs := x509.NewCertPool()
823 rootCAs.AddCert(cert)
824
825 config := testConfig.Clone()
826
827 config.GetClientCertificate = func(*CertificateRequestInfo) (*Certificate, error) {
828 return &Certificate{
829 Certificate: [][]byte{testRSAPSSCertificate},
830 PrivateKey: testRSAPrivateKey,
831 }, nil
832 }
833 config.RootCAs = rootCAs
834
835 test := &clientTest{
836 name: "ClientCert-RSA-RSAPSS",
837 args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs",
838 "rsa_pss_rsae_sha256", "-sigalgs", "rsa_pss_rsae_sha256"},
839 config: config,
840 cert: testRSAPSSCertificate,
841 key: testRSAPrivateKey,
842 }
843 runClientTestTLS12(t, test)
844 runClientTestTLS13(t, test)
845 }
846
847 func TestHandshakeClientCertRSAPKCS1v15(t *testing.T) {
848 config := testConfig.Clone()
849 cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
850 config.Certificates = []Certificate{cert}
851
852 test := &clientTest{
853 name: "ClientCert-RSA-RSAPKCS1v15",
854 args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs",
855 "rsa_pkcs1_sha256", "-sigalgs", "rsa_pkcs1_sha256"},
856 config: config,
857 }
858
859 runClientTestTLS12(t, test)
860 }
861
862 func TestClientKeyUpdate(t *testing.T) {
863 test := &clientTest{
864 name: "KeyUpdate",
865 args: []string{"-state"},
866 sendKeyUpdate: true,
867 }
868 runClientTestTLS13(t, test)
869 }
870
871 func TestResumption(t *testing.T) {
872 t.Run("TLSv12", func(t *testing.T) { testResumption(t, VersionTLS12) })
873 t.Run("TLSv13", func(t *testing.T) { testResumption(t, VersionTLS13) })
874 }
875
876 func testResumption(t *testing.T, version uint16) {
877 if testing.Short() {
878 t.Skip("skipping in -short mode")
879 }
880 serverConfig := &Config{
881 MaxVersion: version,
882 CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
883 Certificates: testConfig.Certificates,
884 }
885
886 issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
887 if err != nil {
888 panic(err)
889 }
890
891 rootCAs := x509.NewCertPool()
892 rootCAs.AddCert(issuer)
893
894 clientConfig := &Config{
895 MaxVersion: version,
896 CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
897 ClientSessionCache: NewLRUClientSessionCache(32),
898 RootCAs: rootCAs,
899 ServerName: "example.golang",
900 }
901
902 testResumeState := func(test string, didResume bool) {
903 _, hs, err := testHandshake(t, clientConfig, serverConfig)
904 if err != nil {
905 t.Fatalf("%s: handshake failed: %s", test, err)
906 }
907 if hs.DidResume != didResume {
908 t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume)
909 }
910 if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) {
911 t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains)
912 }
913 if got, want := hs.ServerName, clientConfig.ServerName; got != want {
914 t.Errorf("%s: server name %s, want %s", test, got, want)
915 }
916 }
917
918 getTicket := func() []byte {
919 return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.sessionTicket
920 }
921 deleteTicket := func() {
922 ticketKey := clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).sessionKey
923 clientConfig.ClientSessionCache.Put(ticketKey, nil)
924 }
925 corruptTicket := func() {
926 clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.masterSecret[0] ^= 0xff
927 }
928 randomKey := func() [32]byte {
929 var k [32]byte
930 if _, err := io.ReadFull(serverConfig.rand(), k[:]); err != nil {
931 t.Fatalf("Failed to read new SessionTicketKey: %s", err)
932 }
933 return k
934 }
935
936 testResumeState("Handshake", false)
937 ticket := getTicket()
938 testResumeState("Resume", true)
939 if !bytes.Equal(ticket, getTicket()) && version != VersionTLS13 {
940 t.Fatal("first ticket doesn't match ticket after resumption")
941 }
942 if bytes.Equal(ticket, getTicket()) && version == VersionTLS13 {
943 t.Fatal("ticket didn't change after resumption")
944 }
945
946
947 serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) }
948 testResumeState("ResumeWithOldTicket", true)
949 if bytes.Equal(ticket[:ticketKeyNameLen], getTicket()[:ticketKeyNameLen]) {
950 t.Fatal("old first ticket matches the fresh one")
951 }
952
953
954 serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) }
955 testResumeState("ResumeWithExpiredTicket", false)
956 if bytes.Equal(ticket, getTicket()) {
957 t.Fatal("expired first ticket matches the fresh one")
958 }
959
960 serverConfig.Time = func() time.Time { return time.Now() }
961 key1 := randomKey()
962 serverConfig.SetSessionTicketKeys([][32]byte{key1})
963
964 testResumeState("InvalidSessionTicketKey", false)
965 testResumeState("ResumeAfterInvalidSessionTicketKey", true)
966
967 key2 := randomKey()
968 serverConfig.SetSessionTicketKeys([][32]byte{key2, key1})
969 ticket = getTicket()
970 testResumeState("KeyChange", true)
971 if bytes.Equal(ticket, getTicket()) {
972 t.Fatal("new ticket wasn't included while resuming")
973 }
974 testResumeState("KeyChangeFinish", true)
975
976
977 serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) }
978 testResumeState("OldSessionTicket", true)
979 ticket = getTicket()
980
981 serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) }
982 testResumeState("ExpiredSessionTicket", false)
983 if bytes.Equal(ticket, getTicket()) {
984 t.Fatal("new ticket wasn't provided after old ticket expired")
985 }
986
987
988 d := 0 * time.Hour
989 for i := 0; i < 13; i++ {
990 d += 12 * time.Hour
991 serverConfig.Time = func() time.Time { return time.Now().Add(d) }
992 testResumeState("OldSessionTicket", true)
993 }
994
995
996
997
998 d += 12 * time.Hour
999 serverConfig.Time = func() time.Time { return time.Now().Add(d) }
1000 if version == VersionTLS13 {
1001 testResumeState("ExpiredSessionTicket", true)
1002 } else {
1003 testResumeState("ExpiredSessionTicket", false)
1004 }
1005 if bytes.Equal(ticket, getTicket()) {
1006 t.Fatal("new ticket wasn't provided after old ticket expired")
1007 }
1008
1009
1010
1011 serverConfig = &Config{
1012 MaxVersion: version,
1013 CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
1014 Certificates: testConfig.Certificates,
1015 }
1016 serverConfig.SetSessionTicketKeys([][32]byte{key2})
1017
1018 testResumeState("FreshConfig", true)
1019
1020
1021
1022 if version != VersionTLS13 {
1023 clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
1024 testResumeState("DifferentCipherSuite", false)
1025 testResumeState("DifferentCipherSuiteRecovers", true)
1026 }
1027
1028 deleteTicket()
1029 testResumeState("WithoutSessionTicket", false)
1030
1031
1032 deleteTicket()
1033 serverConfig.ClientCAs = rootCAs
1034 serverConfig.ClientAuth = RequireAndVerifyClientCert
1035 clientConfig.Certificates = serverConfig.Certificates
1036 testResumeState("InitialHandshake", false)
1037 testResumeState("WithClientCertificates", true)
1038 serverConfig.ClientAuth = NoClientCert
1039
1040
1041
1042 testResumeState("FetchTicketToCorrupt", false)
1043 corruptTicket()
1044 _, _, err = testHandshake(t, clientConfig, serverConfig)
1045 if err == nil {
1046 t.Fatalf("handshake did not fail with a corrupted client secret")
1047 }
1048 testResumeState("AfterHandshakeFailure", false)
1049
1050 clientConfig.ClientSessionCache = nil
1051 testResumeState("WithoutSessionCache", false)
1052 }
1053
1054 func TestLRUClientSessionCache(t *testing.T) {
1055
1056 cache := NewLRUClientSessionCache(4)
1057 cs := make([]ClientSessionState, 6)
1058 keys := []string{"0", "1", "2", "3", "4", "5", "6"}
1059
1060
1061 for i := 0; i < 4; i++ {
1062 cache.Put(keys[i], &cs[i])
1063 }
1064 for i := 0; i < 4; i++ {
1065 if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
1066 t.Fatalf("session cache failed lookup for added key: %s", keys[i])
1067 }
1068 }
1069
1070
1071 for i := 4; i < 6; i++ {
1072 cache.Put(keys[i], &cs[i])
1073 }
1074 for i := 0; i < 2; i++ {
1075 if s, ok := cache.Get(keys[i]); ok || s != nil {
1076 t.Fatalf("session cache should have evicted key: %s", keys[i])
1077 }
1078 }
1079
1080
1081 cache.Get(keys[2])
1082 cache.Put(keys[0], &cs[0])
1083 if s, ok := cache.Get(keys[3]); ok || s != nil {
1084 t.Fatalf("session cache should have evicted key 3")
1085 }
1086
1087
1088 cache.Put(keys[0], &cs[3])
1089 if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] {
1090 t.Fatalf("session cache failed update for key 0")
1091 }
1092
1093
1094 cache.Put(keys[0], nil)
1095 if _, ok := cache.Get(keys[0]); ok {
1096 t.Fatalf("session cache failed to delete key 0")
1097 }
1098
1099
1100 cache.Put(keys[2], nil)
1101 if _, ok := cache.Get(keys[2]); ok {
1102 t.Fatalf("session cache failed to delete key 4")
1103 }
1104 for i := 4; i < 6; i++ {
1105 if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
1106 t.Fatalf("session cache should not have deleted key: %s", keys[i])
1107 }
1108 }
1109 }
1110
1111 func TestKeyLogTLS12(t *testing.T) {
1112 var serverBuf, clientBuf bytes.Buffer
1113
1114 clientConfig := testConfig.Clone()
1115 clientConfig.KeyLogWriter = &clientBuf
1116 clientConfig.MaxVersion = VersionTLS12
1117
1118 serverConfig := testConfig.Clone()
1119 serverConfig.KeyLogWriter = &serverBuf
1120 serverConfig.MaxVersion = VersionTLS12
1121
1122 c, s := localPipe(t)
1123 done := make(chan bool)
1124
1125 go func() {
1126 defer close(done)
1127
1128 if err := Server(s, serverConfig).Handshake(); err != nil {
1129 t.Errorf("server: %s", err)
1130 return
1131 }
1132 s.Close()
1133 }()
1134
1135 if err := Client(c, clientConfig).Handshake(); err != nil {
1136 t.Fatalf("client: %s", err)
1137 }
1138
1139 c.Close()
1140 <-done
1141
1142 checkKeylogLine := func(side, loggedLine string) {
1143 if len(loggedLine) == 0 {
1144 t.Fatalf("%s: no keylog line was produced", side)
1145 }
1146 const expectedLen = 13 +
1147 1 +
1148 32*2 +
1149 1 +
1150 48*2 +
1151 1
1152 if len(loggedLine) != expectedLen {
1153 t.Fatalf("%s: keylog line has incorrect length (want %d, got %d): %q", side, expectedLen, len(loggedLine), loggedLine)
1154 }
1155 if !strings.HasPrefix(loggedLine, "CLIENT_RANDOM "+strings.Repeat("0", 64)+" ") {
1156 t.Fatalf("%s: keylog line has incorrect structure or nonce: %q", side, loggedLine)
1157 }
1158 }
1159
1160 checkKeylogLine("client", clientBuf.String())
1161 checkKeylogLine("server", serverBuf.String())
1162 }
1163
1164 func TestKeyLogTLS13(t *testing.T) {
1165 var serverBuf, clientBuf bytes.Buffer
1166
1167 clientConfig := testConfig.Clone()
1168 clientConfig.KeyLogWriter = &clientBuf
1169
1170 serverConfig := testConfig.Clone()
1171 serverConfig.KeyLogWriter = &serverBuf
1172
1173 c, s := localPipe(t)
1174 done := make(chan bool)
1175
1176 go func() {
1177 defer close(done)
1178
1179 if err := Server(s, serverConfig).Handshake(); err != nil {
1180 t.Errorf("server: %s", err)
1181 return
1182 }
1183 s.Close()
1184 }()
1185
1186 if err := Client(c, clientConfig).Handshake(); err != nil {
1187 t.Fatalf("client: %s", err)
1188 }
1189
1190 c.Close()
1191 <-done
1192
1193 checkKeylogLines := func(side, loggedLines string) {
1194 loggedLines = strings.TrimSpace(loggedLines)
1195 lines := strings.Split(loggedLines, "\n")
1196 if len(lines) != 4 {
1197 t.Errorf("Expected the %s to log 4 lines, got %d", side, len(lines))
1198 }
1199 }
1200
1201 checkKeylogLines("client", clientBuf.String())
1202 checkKeylogLines("server", serverBuf.String())
1203 }
1204
1205 func TestHandshakeClientALPNMatch(t *testing.T) {
1206 config := testConfig.Clone()
1207 config.NextProtos = []string{"proto2", "proto1"}
1208
1209 test := &clientTest{
1210 name: "ALPN",
1211
1212
1213 args: []string{"-alpn", "proto1,proto2"},
1214 config: config,
1215 validate: func(state ConnectionState) error {
1216
1217 if state.NegotiatedProtocol != "proto1" {
1218 return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol)
1219 }
1220 return nil
1221 },
1222 }
1223 runClientTestTLS12(t, test)
1224 runClientTestTLS13(t, test)
1225 }
1226
1227 func TestServerSelectingUnconfiguredApplicationProtocol(t *testing.T) {
1228
1229
1230
1231 c, s := localPipe(t)
1232 errChan := make(chan error, 1)
1233
1234 go func() {
1235 client := Client(c, &Config{
1236 ServerName: "foo",
1237 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1238 NextProtos: []string{"http", "something-else"},
1239 })
1240 errChan <- client.Handshake()
1241 }()
1242
1243 var header [5]byte
1244 if _, err := io.ReadFull(s, header[:]); err != nil {
1245 t.Fatal(err)
1246 }
1247 recordLen := int(header[3])<<8 | int(header[4])
1248
1249 record := make([]byte, recordLen)
1250 if _, err := io.ReadFull(s, record); err != nil {
1251 t.Fatal(err)
1252 }
1253
1254 serverHello := &serverHelloMsg{
1255 vers: VersionTLS12,
1256 random: make([]byte, 32),
1257 cipherSuite: TLS_RSA_WITH_AES_128_GCM_SHA256,
1258 alpnProtocol: "how-about-this",
1259 }
1260 serverHelloBytes := serverHello.marshal()
1261
1262 s.Write([]byte{
1263 byte(recordTypeHandshake),
1264 byte(VersionTLS12 >> 8),
1265 byte(VersionTLS12 & 0xff),
1266 byte(len(serverHelloBytes) >> 8),
1267 byte(len(serverHelloBytes)),
1268 })
1269 s.Write(serverHelloBytes)
1270 s.Close()
1271
1272 if err := <-errChan; !strings.Contains(err.Error(), "server selected unadvertised ALPN protocol") {
1273 t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
1274 }
1275 }
1276
1277
// sctsBase64 is the base64 encoding of a complete TLS extension carrying a
// SignedCertificateTimestampList. TestHandshakClientSCTs sends it as a raw
// extension. The decoded bytes have this layout:
//   - 2-byte extension type 0x0012, the signed_certificate_timestamp
//     extension.
//   - 2-byte extension length.
//   - 2-byte list length.
//   - Three SCTs, each prefixed by its own 2-byte length. They occupy decoded
//     bytes [8:125], [127:245] and [247:].
const sctsBase64 = "ABIBaQFnAHUApLkJkLQYWBSHuxOizGdwCjw1mAT5G9+443fNDsgN3BAAAAFHl5nuFgAABAMARjBEAiAcS4JdlW5nW9sElUv2zvQyPoZ6ejKrGGB03gjaBZFMLwIgc1Qbbn+hsH0RvObzhS+XZhr3iuQQJY8S9G85D9KeGPAAdgBo9pj4H2SCvjqM7rkoHUz8cVFdZ5PURNEKZ6y7T0/7xAAAAUeX4bVwAAAEAwBHMEUCIDIhFDgG2HIuADBkGuLobU5a4dlCHoJLliWJ1SYT05z6AiEAjxIoZFFPRNWMGGIjskOTMwXzQ1Wh2e7NxXE1kd1J0QsAdgDuS723dc5guuFCaR+r4Z5mow9+X7By2IMAxHuJeqj9ywAAAUhcZIqHAAAEAwBHMEUCICmJ1rBT09LpkbzxtUC+Hi7nXLR0J+2PmwLp+sJMuqK+AiEAr0NkUnEVKVhAkccIFpYDqHOlZaBsuEhWWrYpg2RtKp0="
1279
1280 func TestHandshakClientSCTs(t *testing.T) {
1281 config := testConfig.Clone()
1282
1283 scts, err := base64.StdEncoding.DecodeString(sctsBase64)
1284 if err != nil {
1285 t.Fatal(err)
1286 }
1287
1288
1289
1290 test := &clientTest{
1291 name: "SCT",
1292 config: config,
1293 extensions: [][]byte{scts},
1294 validate: func(state ConnectionState) error {
1295 expectedSCTs := [][]byte{
1296 scts[8:125],
1297 scts[127:245],
1298 scts[247:],
1299 }
1300 if n := len(state.SignedCertificateTimestamps); n != len(expectedSCTs) {
1301 return fmt.Errorf("Got %d scts, wanted %d", n, len(expectedSCTs))
1302 }
1303 for i, expected := range expectedSCTs {
1304 if sct := state.SignedCertificateTimestamps[i]; !bytes.Equal(sct, expected) {
1305 return fmt.Errorf("SCT #%d contained %x, expected %x", i, sct, expected)
1306 }
1307 }
1308 return nil
1309 },
1310 }
1311 runClientTestTLS12(t, test)
1312
1313
1314
1315 }
1316
1317 func TestRenegotiationRejected(t *testing.T) {
1318 config := testConfig.Clone()
1319 test := &clientTest{
1320 name: "RenegotiationRejected",
1321 args: []string{"-state"},
1322 config: config,
1323 numRenegotiations: 1,
1324 renegotiationExpectedToFail: 1,
1325 checkRenegotiationError: func(renegotiationNum int, err error) error {
1326 if err == nil {
1327 return errors.New("expected error from renegotiation but got nil")
1328 }
1329 if !strings.Contains(err.Error(), "no renegotiation") {
1330 return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
1331 }
1332 return nil
1333 },
1334 }
1335 runClientTestTLS12(t, test)
1336 }
1337
1338 func TestRenegotiateOnce(t *testing.T) {
1339 config := testConfig.Clone()
1340 config.Renegotiation = RenegotiateOnceAsClient
1341
1342 test := &clientTest{
1343 name: "RenegotiateOnce",
1344 args: []string{"-state"},
1345 config: config,
1346 numRenegotiations: 1,
1347 }
1348
1349 runClientTestTLS12(t, test)
1350 }
1351
1352 func TestRenegotiateTwice(t *testing.T) {
1353 config := testConfig.Clone()
1354 config.Renegotiation = RenegotiateFreelyAsClient
1355
1356 test := &clientTest{
1357 name: "RenegotiateTwice",
1358 args: []string{"-state"},
1359 config: config,
1360 numRenegotiations: 2,
1361 }
1362
1363 runClientTestTLS12(t, test)
1364 }
1365
1366 func TestRenegotiateTwiceRejected(t *testing.T) {
1367 config := testConfig.Clone()
1368 config.Renegotiation = RenegotiateOnceAsClient
1369
1370 test := &clientTest{
1371 name: "RenegotiateTwiceRejected",
1372 args: []string{"-state"},
1373 config: config,
1374 numRenegotiations: 2,
1375 renegotiationExpectedToFail: 2,
1376 checkRenegotiationError: func(renegotiationNum int, err error) error {
1377 if renegotiationNum == 1 {
1378 return err
1379 }
1380
1381 if err == nil {
1382 return errors.New("expected error from renegotiation but got nil")
1383 }
1384 if !strings.Contains(err.Error(), "no renegotiation") {
1385 return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
1386 }
1387 return nil
1388 },
1389 }
1390
1391 runClientTestTLS12(t, test)
1392 }
1393
1394 func TestHandshakeClientExportKeyingMaterial(t *testing.T) {
1395 test := &clientTest{
1396 name: "ExportKeyingMaterial",
1397 config: testConfig.Clone(),
1398 validate: func(state ConnectionState) error {
1399 if km, err := state.ExportKeyingMaterial("test", nil, 42); err != nil {
1400 return fmt.Errorf("ExportKeyingMaterial failed: %v", err)
1401 } else if len(km) != 42 {
1402 return fmt.Errorf("Got %d bytes from ExportKeyingMaterial, wanted %d", len(km), 42)
1403 }
1404 return nil
1405 },
1406 }
1407 runClientTestTLS10(t, test)
1408 runClientTestTLS12(t, test)
1409 runClientTestTLS13(t, test)
1410 }
1411
// hostnameInSNITests pairs a Config.ServerName value (in) with the
// server_name the client is expected to put in its ClientHello (out).
// An empty out means the ClientHello is expected to carry an empty server
// name. That is the expectation for IP address literals, with or without
// brackets or an IPv6 zone.
var hostnameInSNITests = []struct {
	in, out string
}{
	// Arbitrary (opaque) strings pass through unchanged.
	{in: "", out: ""},
	{in: "localhost", out: "localhost"},
	{in: "foo, bar, baz and qux", out: "foo, bar, baz and qux"},

	// DNS host names; a trailing root dot is expected to be dropped.
	{in: "golang.org", out: "golang.org"},
	{in: "golang.org.", out: "golang.org"},

	// IPv4 address literal.
	{in: "1.2.3.4", out: ""},

	// IPv6 address literals, bare or bracketed, with or without a zone.
	{in: "::1", out: ""},
	{in: "::1%lo0", out: ""},
	{in: "[::1]", out: ""},
	{in: "[::1%lo0]", out: ""},
}
1433
1434 func TestHostnameInSNI(t *testing.T) {
1435 for _, tt := range hostnameInSNITests {
1436 c, s := localPipe(t)
1437
1438 go func(host string) {
1439 Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake()
1440 }(tt.in)
1441
1442 var header [5]byte
1443 if _, err := io.ReadFull(s, header[:]); err != nil {
1444 t.Fatal(err)
1445 }
1446 recordLen := int(header[3])<<8 | int(header[4])
1447
1448 record := make([]byte, recordLen)
1449 if _, err := io.ReadFull(s, record[:]); err != nil {
1450 t.Fatal(err)
1451 }
1452
1453 c.Close()
1454 s.Close()
1455
1456 var m clientHelloMsg
1457 if !m.unmarshal(record) {
1458 t.Errorf("unmarshaling ClientHello for %q failed", tt.in)
1459 continue
1460 }
1461 if tt.in != tt.out && m.serverName == tt.in {
1462 t.Errorf("prohibited %q found in ClientHello: %x", tt.in, record)
1463 }
1464 if m.serverName != tt.out {
1465 t.Errorf("expected %q not found in ClientHello: %x", tt.out, record)
1466 }
1467 }
1468 }
1469
1470 func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
1471
1472
1473
1474 c, s := localPipe(t)
1475 errChan := make(chan error, 1)
1476
1477 go func() {
1478 client := Client(c, &Config{
1479 ServerName: "foo",
1480 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1481 })
1482 errChan <- client.Handshake()
1483 }()
1484
1485 var header [5]byte
1486 if _, err := io.ReadFull(s, header[:]); err != nil {
1487 t.Fatal(err)
1488 }
1489 recordLen := int(header[3])<<8 | int(header[4])
1490
1491 record := make([]byte, recordLen)
1492 if _, err := io.ReadFull(s, record); err != nil {
1493 t.Fatal(err)
1494 }
1495
1496
1497
1498 serverHello := &serverHelloMsg{
1499 vers: VersionTLS12,
1500 random: make([]byte, 32),
1501 cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384,
1502 }
1503 serverHelloBytes := serverHello.marshal()
1504
1505 s.Write([]byte{
1506 byte(recordTypeHandshake),
1507 byte(VersionTLS12 >> 8),
1508 byte(VersionTLS12 & 0xff),
1509 byte(len(serverHelloBytes) >> 8),
1510 byte(len(serverHelloBytes)),
1511 })
1512 s.Write(serverHelloBytes)
1513 s.Close()
1514
1515 if err := <-errChan; !strings.Contains(err.Error(), "unconfigured cipher") {
1516 t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
1517 }
1518 }
1519
1520 func TestVerifyConnection(t *testing.T) {
1521 t.Run("TLSv12", func(t *testing.T) { testVerifyConnection(t, VersionTLS12) })
1522 t.Run("TLSv13", func(t *testing.T) { testVerifyConnection(t, VersionTLS13) })
1523 }
1524
1525 func testVerifyConnection(t *testing.T, version uint16) {
1526 checkFields := func(c ConnectionState, called *int, errorType string) error {
1527 if c.Version != version {
1528 return fmt.Errorf("%s: got Version %v, want %v", errorType, c.Version, version)
1529 }
1530 if c.HandshakeComplete {
1531 return fmt.Errorf("%s: got HandshakeComplete, want false", errorType)
1532 }
1533 if c.ServerName != "example.golang" {
1534 return fmt.Errorf("%s: got ServerName %s, want %s", errorType, c.ServerName, "example.golang")
1535 }
1536 if c.NegotiatedProtocol != "protocol1" {
1537 return fmt.Errorf("%s: got NegotiatedProtocol %s, want %s", errorType, c.NegotiatedProtocol, "protocol1")
1538 }
1539 if c.CipherSuite == 0 {
1540 return fmt.Errorf("%s: got CipherSuite 0, want non-zero", errorType)
1541 }
1542 wantDidResume := false
1543 if *called == 2 {
1544 wantDidResume = true
1545 }
1546 if c.DidResume != wantDidResume {
1547 return fmt.Errorf("%s: got DidResume %t, want %t", errorType, c.DidResume, wantDidResume)
1548 }
1549 return nil
1550 }
1551
1552 tests := []struct {
1553 name string
1554 configureServer func(*Config, *int)
1555 configureClient func(*Config, *int)
1556 }{
1557 {
1558 name: "RequireAndVerifyClientCert",
1559 configureServer: func(config *Config, called *int) {
1560 config.ClientAuth = RequireAndVerifyClientCert
1561 config.VerifyConnection = func(c ConnectionState) error {
1562 *called++
1563 if l := len(c.PeerCertificates); l != 1 {
1564 return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
1565 }
1566 if len(c.VerifiedChains) == 0 {
1567 return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero")
1568 }
1569 return checkFields(c, called, "server")
1570 }
1571 },
1572 configureClient: func(config *Config, called *int) {
1573 config.VerifyConnection = func(c ConnectionState) error {
1574 *called++
1575 if l := len(c.PeerCertificates); l != 1 {
1576 return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
1577 }
1578 if len(c.VerifiedChains) == 0 {
1579 return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
1580 }
1581 if c.DidResume {
1582 return nil
1583
1584
1585 }
1586 if len(c.OCSPResponse) == 0 {
1587 return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
1588 }
1589 if len(c.SignedCertificateTimestamps) == 0 {
1590 return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
1591 }
1592 return checkFields(c, called, "client")
1593 }
1594 },
1595 },
1596 {
1597 name: "InsecureSkipVerify",
1598 configureServer: func(config *Config, called *int) {
1599 config.ClientAuth = RequireAnyClientCert
1600 config.InsecureSkipVerify = true
1601 config.VerifyConnection = func(c ConnectionState) error {
1602 *called++
1603 if l := len(c.PeerCertificates); l != 1 {
1604 return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
1605 }
1606 if c.VerifiedChains != nil {
1607 return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
1608 }
1609 return checkFields(c, called, "server")
1610 }
1611 },
1612 configureClient: func(config *Config, called *int) {
1613 config.InsecureSkipVerify = true
1614 config.VerifyConnection = func(c ConnectionState) error {
1615 *called++
1616 if l := len(c.PeerCertificates); l != 1 {
1617 return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
1618 }
1619 if c.VerifiedChains != nil {
1620 return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
1621 }
1622 if c.DidResume {
1623 return nil
1624
1625
1626 }
1627 if len(c.OCSPResponse) == 0 {
1628 return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
1629 }
1630 if len(c.SignedCertificateTimestamps) == 0 {
1631 return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
1632 }
1633 return checkFields(c, called, "client")
1634 }
1635 },
1636 },
1637 {
1638 name: "NoClientCert",
1639 configureServer: func(config *Config, called *int) {
1640 config.ClientAuth = NoClientCert
1641 config.VerifyConnection = func(c ConnectionState) error {
1642 *called++
1643 return checkFields(c, called, "server")
1644 }
1645 },
1646 configureClient: func(config *Config, called *int) {
1647 config.VerifyConnection = func(c ConnectionState) error {
1648 *called++
1649 return checkFields(c, called, "client")
1650 }
1651 },
1652 },
1653 {
1654 name: "RequestClientCert",
1655 configureServer: func(config *Config, called *int) {
1656 config.ClientAuth = RequestClientCert
1657 config.VerifyConnection = func(c ConnectionState) error {
1658 *called++
1659 return checkFields(c, called, "server")
1660 }
1661 },
1662 configureClient: func(config *Config, called *int) {
1663 config.Certificates = nil
1664 config.VerifyConnection = func(c ConnectionState) error {
1665 *called++
1666 if l := len(c.PeerCertificates); l != 1 {
1667 return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
1668 }
1669 if len(c.VerifiedChains) == 0 {
1670 return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
1671 }
1672 if c.DidResume {
1673 return nil
1674
1675
1676 }
1677 if len(c.OCSPResponse) == 0 {
1678 return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
1679 }
1680 if len(c.SignedCertificateTimestamps) == 0 {
1681 return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
1682 }
1683 return checkFields(c, called, "client")
1684 }
1685 },
1686 },
1687 }
1688 for _, test := range tests {
1689 issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
1690 if err != nil {
1691 panic(err)
1692 }
1693 rootCAs := x509.NewCertPool()
1694 rootCAs.AddCert(issuer)
1695
1696 var serverCalled, clientCalled int
1697
1698 serverConfig := &Config{
1699 MaxVersion: version,
1700 Certificates: []Certificate{testConfig.Certificates[0]},
1701 ClientCAs: rootCAs,
1702 NextProtos: []string{"protocol1"},
1703 }
1704 serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
1705 serverConfig.Certificates[0].OCSPStaple = []byte("dummy ocsp")
1706 test.configureServer(serverConfig, &serverCalled)
1707
1708 clientConfig := &Config{
1709 MaxVersion: version,
1710 ClientSessionCache: NewLRUClientSessionCache(32),
1711 RootCAs: rootCAs,
1712 ServerName: "example.golang",
1713 Certificates: []Certificate{testConfig.Certificates[0]},
1714 NextProtos: []string{"protocol1"},
1715 }
1716 test.configureClient(clientConfig, &clientCalled)
1717
1718 testHandshakeState := func(name string, didResume bool) {
1719 _, hs, err := testHandshake(t, clientConfig, serverConfig)
1720 if err != nil {
1721 t.Fatalf("%s: handshake failed: %s", name, err)
1722 }
1723 if hs.DidResume != didResume {
1724 t.Errorf("%s: resumed: %v, expected: %v", name, hs.DidResume, didResume)
1725 }
1726 wantCalled := 1
1727 if didResume {
1728 wantCalled = 2
1729 }
1730 if clientCalled != wantCalled {
1731 t.Errorf("%s: expected client VerifyConnection called %d times, did %d times", name, wantCalled, clientCalled)
1732 }
1733 if serverCalled != wantCalled {
1734 t.Errorf("%s: expected server VerifyConnection called %d times, did %d times", name, wantCalled, serverCalled)
1735 }
1736 }
1737 testHandshakeState(fmt.Sprintf("%s-FullHandshake", test.name), false)
1738 testHandshakeState(fmt.Sprintf("%s-Resumption", test.name), true)
1739 }
1740 }
1741
1742 func TestVerifyPeerCertificate(t *testing.T) {
1743 t.Run("TLSv12", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS12) })
1744 t.Run("TLSv13", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS13) })
1745 }
1746
1747 func testVerifyPeerCertificate(t *testing.T, version uint16) {
1748 issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
1749 if err != nil {
1750 panic(err)
1751 }
1752
1753 rootCAs := x509.NewCertPool()
1754 rootCAs.AddCert(issuer)
1755
1756 now := func() time.Time { return time.Unix(1476984729, 0) }
1757
1758 sentinelErr := errors.New("TestVerifyPeerCertificate")
1759
1760 verifyPeerCertificateCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1761 if l := len(rawCerts); l != 1 {
1762 return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
1763 }
1764 if len(validatedChains) == 0 {
1765 return errors.New("got len(validatedChains) = 0, wanted non-zero")
1766 }
1767 *called = true
1768 return nil
1769 }
1770 verifyConnectionCallback := func(called *bool, isClient bool, c ConnectionState) error {
1771 if l := len(c.PeerCertificates); l != 1 {
1772 return fmt.Errorf("got len(PeerCertificates) = %d, wanted 1", l)
1773 }
1774 if len(c.VerifiedChains) == 0 {
1775 return fmt.Errorf("got len(VerifiedChains) = 0, wanted non-zero")
1776 }
1777 if isClient && len(c.OCSPResponse) == 0 {
1778 return fmt.Errorf("got len(OCSPResponse) = 0, wanted non-zero")
1779 }
1780 *called = true
1781 return nil
1782 }
1783
1784 tests := []struct {
1785 configureServer func(*Config, *bool)
1786 configureClient func(*Config, *bool)
1787 validate func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error)
1788 }{
1789 {
1790 configureServer: func(config *Config, called *bool) {
1791 config.InsecureSkipVerify = false
1792 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1793 return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
1794 }
1795 },
1796 configureClient: func(config *Config, called *bool) {
1797 config.InsecureSkipVerify = false
1798 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1799 return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
1800 }
1801 },
1802 validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1803 if clientErr != nil {
1804 t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
1805 }
1806 if serverErr != nil {
1807 t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
1808 }
1809 if !clientCalled {
1810 t.Errorf("test[%d]: client did not call callback", testNo)
1811 }
1812 if !serverCalled {
1813 t.Errorf("test[%d]: server did not call callback", testNo)
1814 }
1815 },
1816 },
1817 {
1818 configureServer: func(config *Config, called *bool) {
1819 config.InsecureSkipVerify = false
1820 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1821 return sentinelErr
1822 }
1823 },
1824 configureClient: func(config *Config, called *bool) {
1825 config.VerifyPeerCertificate = nil
1826 },
1827 validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1828 if serverErr != sentinelErr {
1829 t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
1830 }
1831 },
1832 },
1833 {
1834 configureServer: func(config *Config, called *bool) {
1835 config.InsecureSkipVerify = false
1836 },
1837 configureClient: func(config *Config, called *bool) {
1838 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1839 return sentinelErr
1840 }
1841 },
1842 validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1843 if clientErr != sentinelErr {
1844 t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
1845 }
1846 },
1847 },
1848 {
1849 configureServer: func(config *Config, called *bool) {
1850 config.InsecureSkipVerify = false
1851 },
1852 configureClient: func(config *Config, called *bool) {
1853 config.InsecureSkipVerify = true
1854 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1855 if l := len(rawCerts); l != 1 {
1856 return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
1857 }
1858
1859
1860
1861 if l := len(validatedChains); l != 0 {
1862 return fmt.Errorf("got len(validatedChains) = %d, wanted zero", l)
1863 }
1864 *called = true
1865 return nil
1866 }
1867 },
1868 validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1869 if clientErr != nil {
1870 t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
1871 }
1872 if serverErr != nil {
1873 t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
1874 }
1875 if !clientCalled {
1876 t.Errorf("test[%d]: client did not call callback", testNo)
1877 }
1878 },
1879 },
1880 {
1881 configureServer: func(config *Config, called *bool) {
1882 config.InsecureSkipVerify = false
1883 config.VerifyConnection = func(c ConnectionState) error {
1884 return verifyConnectionCallback(called, false, c)
1885 }
1886 },
1887 configureClient: func(config *Config, called *bool) {
1888 config.InsecureSkipVerify = false
1889 config.VerifyConnection = func(c ConnectionState) error {
1890 return verifyConnectionCallback(called, true, c)
1891 }
1892 },
1893 validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1894 if clientErr != nil {
1895 t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
1896 }
1897 if serverErr != nil {
1898 t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
1899 }
1900 if !clientCalled {
1901 t.Errorf("test[%d]: client did not call callback", testNo)
1902 }
1903 if !serverCalled {
1904 t.Errorf("test[%d]: server did not call callback", testNo)
1905 }
1906 },
1907 },
1908 {
1909 configureServer: func(config *Config, called *bool) {
1910 config.InsecureSkipVerify = false
1911 config.VerifyConnection = func(c ConnectionState) error {
1912 return sentinelErr
1913 }
1914 },
1915 configureClient: func(config *Config, called *bool) {
1916 config.InsecureSkipVerify = false
1917 config.VerifyConnection = nil
1918 },
1919 validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1920 if serverErr != sentinelErr {
1921 t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
1922 }
1923 },
1924 },
1925 {
1926 configureServer: func(config *Config, called *bool) {
1927 config.InsecureSkipVerify = false
1928 config.VerifyConnection = nil
1929 },
1930 configureClient: func(config *Config, called *bool) {
1931 config.InsecureSkipVerify = false
1932 config.VerifyConnection = func(c ConnectionState) error {
1933 return sentinelErr
1934 }
1935 },
1936 validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1937 if clientErr != sentinelErr {
1938 t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
1939 }
1940 },
1941 },
1942 {
1943 configureServer: func(config *Config, called *bool) {
1944 config.InsecureSkipVerify = false
1945 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1946 return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
1947 }
1948 config.VerifyConnection = func(c ConnectionState) error {
1949 return sentinelErr
1950 }
1951 },
1952 configureClient: func(config *Config, called *bool) {
1953 config.InsecureSkipVerify = false
1954 config.VerifyPeerCertificate = nil
1955 config.VerifyConnection = nil
1956 },
1957 validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1958 if serverErr != sentinelErr {
1959 t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
1960 }
1961 if !serverCalled {
1962 t.Errorf("test[%d]: server did not call callback", testNo)
1963 }
1964 },
1965 },
1966 {
1967 configureServer: func(config *Config, called *bool) {
1968 config.InsecureSkipVerify = false
1969 config.VerifyPeerCertificate = nil
1970 config.VerifyConnection = nil
1971 },
1972 configureClient: func(config *Config, called *bool) {
1973 config.InsecureSkipVerify = false
1974 config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1975 return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
1976 }
1977 config.VerifyConnection = func(c ConnectionState) error {
1978 return sentinelErr
1979 }
1980 },
1981 validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1982 if clientErr != sentinelErr {
1983 t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
1984 }
1985 if !clientCalled {
1986 t.Errorf("test[%d]: client did not call callback", testNo)
1987 }
1988 },
1989 },
1990 }
1991
1992 for i, test := range tests {
1993 c, s := localPipe(t)
1994 done := make(chan error)
1995
1996 var clientCalled, serverCalled bool
1997
1998 go func() {
1999 config := testConfig.Clone()
2000 config.ServerName = "example.golang"
2001 config.ClientAuth = RequireAndVerifyClientCert
2002 config.ClientCAs = rootCAs
2003 config.Time = now
2004 config.MaxVersion = version
2005 config.Certificates = make([]Certificate, 1)
2006 config.Certificates[0].Certificate = [][]byte{testRSACertificate}
2007 config.Certificates[0].PrivateKey = testRSAPrivateKey
2008 config.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
2009 config.Certificates[0].OCSPStaple = []byte("dummy ocsp")
2010 test.configureServer(config, &serverCalled)
2011
2012 err = Server(s, config).Handshake()
2013 s.Close()
2014 done <- err
2015 }()
2016
2017 config := testConfig.Clone()
2018 config.ServerName = "example.golang"
2019 config.RootCAs = rootCAs
2020 config.Time = now
2021 config.MaxVersion = version
2022 test.configureClient(config, &clientCalled)
2023 clientErr := Client(c, config).Handshake()
2024 c.Close()
2025 serverErr := <-done
2026
2027 test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr)
2028 }
2029 }
2030
2031
2032
// brokenConn wraps a net.Conn and lets only the first breakAfter calls to
// Write through to it. Every later Write fails with brokenConnErr without
// touching the wrapped connection.
type brokenConn struct {
	net.Conn

	// breakAfter is how many Write calls are forwarded before writes start
	// failing.
	breakAfter int

	// numWrites counts the Write calls forwarded so far.
	numWrites int
}

// brokenConnErr is returned by brokenConn.Write once the write budget is used
// up.
var brokenConnErr = errors.New("too many writes to brokenConn")

// Write forwards data to the wrapped connection while the budget lasts, and
// returns (0, brokenConnErr) afterwards.
func (b *brokenConn) Write(data []byte) (int, error) {
	if b.numWrites < b.breakAfter {
		b.numWrites++
		return b.Conn.Write(data)
	}
	return 0, brokenConnErr
}
2055
2056 func TestFailedWrite(t *testing.T) {
2057
2058 for _, breakAfter := range []int{0, 1} {
2059 c, s := localPipe(t)
2060 done := make(chan bool)
2061
2062 go func() {
2063 Server(s, testConfig).Handshake()
2064 s.Close()
2065 done <- true
2066 }()
2067
2068 brokenC := &brokenConn{Conn: c, breakAfter: breakAfter}
2069 err := Client(brokenC, testConfig).Handshake()
2070 if err != brokenConnErr {
2071 t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err)
2072 }
2073 brokenC.Close()
2074
2075 <-done
2076 }
2077 }
2078
2079
// writeCountingConn wraps a net.Conn and records how many times Write is
// called on it. Tests use the count to check how handshake output is
// batched.
type writeCountingConn struct {
	net.Conn

	// numWrites is the number of Write calls seen so far.
	numWrites int
}

// Write increments the counter and then delegates to the wrapped connection.
func (w *writeCountingConn) Write(data []byte) (int, error) {
	w.numWrites += 1
	return w.Conn.Write(data)
}
2091
2092 func TestBuffering(t *testing.T) {
2093 t.Run("TLSv12", func(t *testing.T) { testBuffering(t, VersionTLS12) })
2094 t.Run("TLSv13", func(t *testing.T) { testBuffering(t, VersionTLS13) })
2095 }
2096
2097 func testBuffering(t *testing.T, version uint16) {
2098 c, s := localPipe(t)
2099 done := make(chan bool)
2100
2101 clientWCC := &writeCountingConn{Conn: c}
2102 serverWCC := &writeCountingConn{Conn: s}
2103
2104 go func() {
2105 config := testConfig.Clone()
2106 config.MaxVersion = version
2107 Server(serverWCC, config).Handshake()
2108 serverWCC.Close()
2109 done <- true
2110 }()
2111
2112 err := Client(clientWCC, testConfig).Handshake()
2113 if err != nil {
2114 t.Fatal(err)
2115 }
2116 clientWCC.Close()
2117 <-done
2118
2119 var expectedClient, expectedServer int
2120 if version == VersionTLS13 {
2121 expectedClient = 2
2122 expectedServer = 1
2123 } else {
2124 expectedClient = 2
2125 expectedServer = 2
2126 }
2127
2128 if n := clientWCC.numWrites; n != expectedClient {
2129 t.Errorf("expected client handshake to complete with %d writes, but saw %d", expectedClient, n)
2130 }
2131
2132 if n := serverWCC.numWrites; n != expectedServer {
2133 t.Errorf("expected server handshake to complete with %d writes, but saw %d", expectedServer, n)
2134 }
2135 }
2136
2137 func TestAlertFlushing(t *testing.T) {
2138 c, s := localPipe(t)
2139 done := make(chan bool)
2140
2141 clientWCC := &writeCountingConn{Conn: c}
2142 serverWCC := &writeCountingConn{Conn: s}
2143
2144 serverConfig := testConfig.Clone()
2145
2146
2147 brokenKey := rsa.PrivateKey{PublicKey: testRSAPrivateKey.PublicKey}
2148 brokenKey.D = big.NewInt(42)
2149 serverConfig.Certificates = []Certificate{{
2150 Certificate: [][]byte{testRSACertificate},
2151 PrivateKey: &brokenKey,
2152 }}
2153
2154 go func() {
2155 Server(serverWCC, serverConfig).Handshake()
2156 serverWCC.Close()
2157 done <- true
2158 }()
2159
2160 err := Client(clientWCC, testConfig).Handshake()
2161 if err == nil {
2162 t.Fatal("client unexpectedly returned no error")
2163 }
2164
2165 const expectedError = "remote error: tls: internal error"
2166 if e := err.Error(); !strings.Contains(e, expectedError) {
2167 t.Fatalf("expected to find %q in error but error was %q", expectedError, e)
2168 }
2169 clientWCC.Close()
2170 <-done
2171
2172 if n := serverWCC.numWrites; n != 1 {
2173 t.Errorf("expected server handshake to complete with one write, but saw %d", n)
2174 }
2175 }
2176
2177 func TestHandshakeRace(t *testing.T) {
2178 if testing.Short() {
2179 t.Skip("skipping in -short mode")
2180 }
2181 t.Parallel()
2182
2183
2184
2185 for i := 0; i < 32; i++ {
2186 c, s := localPipe(t)
2187
2188 go func() {
2189 server := Server(s, testConfig)
2190 if err := server.Handshake(); err != nil {
2191 panic(err)
2192 }
2193
2194 var request [1]byte
2195 if n, err := server.Read(request[:]); err != nil || n != 1 {
2196 panic(err)
2197 }
2198
2199 server.Write(request[:])
2200 server.Close()
2201 }()
2202
2203 startWrite := make(chan struct{})
2204 startRead := make(chan struct{})
2205 readDone := make(chan struct{}, 1)
2206
2207 client := Client(c, testConfig)
2208 go func() {
2209 <-startWrite
2210 var request [1]byte
2211 client.Write(request[:])
2212 }()
2213
2214 go func() {
2215 <-startRead
2216 var reply [1]byte
2217 if _, err := io.ReadFull(client, reply[:]); err != nil {
2218 panic(err)
2219 }
2220 c.Close()
2221 readDone <- struct{}{}
2222 }()
2223
2224 if i&1 == 1 {
2225 startWrite <- struct{}{}
2226 startRead <- struct{}{}
2227 } else {
2228 startRead <- struct{}{}
2229 startWrite <- struct{}{}
2230 }
2231 <-readDone
2232 }
2233 }
2234
2235 var getClientCertificateTests = []struct {
2236 setup func(*Config, *Config)
2237 expectedClientError string
2238 verify func(*testing.T, int, *ConnectionState)
2239 }{
2240 {
2241 func(clientConfig, serverConfig *Config) {
2242
2243
2244
2245 serverConfig.ClientCAs = nil
2246 clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
2247 if len(cri.SignatureSchemes) == 0 {
2248 panic("empty SignatureSchemes")
2249 }
2250 if len(cri.AcceptableCAs) != 0 {
2251 panic("AcceptableCAs should have been empty")
2252 }
2253 return new(Certificate), nil
2254 }
2255 },
2256 "",
2257 func(t *testing.T, testNum int, cs *ConnectionState) {
2258 if l := len(cs.PeerCertificates); l != 0 {
2259 t.Errorf("#%d: expected no certificates but got %d", testNum, l)
2260 }
2261 },
2262 },
2263 {
2264 func(clientConfig, serverConfig *Config) {
2265
2266
2267 clientConfig.MaxVersion = VersionTLS11
2268 clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
2269 if len(cri.SignatureSchemes) == 0 {
2270 panic("empty SignatureSchemes")
2271 }
2272 return new(Certificate), nil
2273 }
2274 },
2275 "",
2276 func(t *testing.T, testNum int, cs *ConnectionState) {
2277 if l := len(cs.PeerCertificates); l != 0 {
2278 t.Errorf("#%d: expected no certificates but got %d", testNum, l)
2279 }
2280 },
2281 },
2282 {
2283 func(clientConfig, serverConfig *Config) {
2284
2285
2286 clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
2287 return nil, errors.New("GetClientCertificate")
2288 }
2289 },
2290 "GetClientCertificate",
2291 func(t *testing.T, testNum int, cs *ConnectionState) {
2292 },
2293 },
2294 {
2295 func(clientConfig, serverConfig *Config) {
2296 clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
2297 if len(cri.AcceptableCAs) == 0 {
2298 panic("empty AcceptableCAs")
2299 }
2300 cert := &Certificate{
2301 Certificate: [][]byte{testRSACertificate},
2302 PrivateKey: testRSAPrivateKey,
2303 }
2304 return cert, nil
2305 }
2306 },
2307 "",
2308 func(t *testing.T, testNum int, cs *ConnectionState) {
2309 if len(cs.VerifiedChains) == 0 {
2310 t.Errorf("#%d: expected some verified chains, but found none", testNum)
2311 }
2312 },
2313 },
2314 }
2315
2316 func TestGetClientCertificate(t *testing.T) {
2317 t.Run("TLSv12", func(t *testing.T) { testGetClientCertificate(t, VersionTLS12) })
2318 t.Run("TLSv13", func(t *testing.T) { testGetClientCertificate(t, VersionTLS13) })
2319 }
2320
2321 func testGetClientCertificate(t *testing.T, version uint16) {
2322 issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
2323 if err != nil {
2324 panic(err)
2325 }
2326
2327 for i, test := range getClientCertificateTests {
2328 serverConfig := testConfig.Clone()
2329 serverConfig.ClientAuth = VerifyClientCertIfGiven
2330 serverConfig.RootCAs = x509.NewCertPool()
2331 serverConfig.RootCAs.AddCert(issuer)
2332 serverConfig.ClientCAs = serverConfig.RootCAs
2333 serverConfig.Time = func() time.Time { return time.Unix(1476984729, 0) }
2334 serverConfig.MaxVersion = version
2335
2336 clientConfig := testConfig.Clone()
2337 clientConfig.MaxVersion = version
2338
2339 test.setup(clientConfig, serverConfig)
2340
2341 type serverResult struct {
2342 cs ConnectionState
2343 err error
2344 }
2345
2346 c, s := localPipe(t)
2347 done := make(chan serverResult)
2348
2349 go func() {
2350 defer s.Close()
2351 server := Server(s, serverConfig)
2352 err := server.Handshake()
2353
2354 var cs ConnectionState
2355 if err == nil {
2356 cs = server.ConnectionState()
2357 }
2358 done <- serverResult{cs, err}
2359 }()
2360
2361 clientErr := Client(c, clientConfig).Handshake()
2362 c.Close()
2363
2364 result := <-done
2365
2366 if clientErr != nil {
2367 if len(test.expectedClientError) == 0 {
2368 t.Errorf("#%d: client error: %v", i, clientErr)
2369 } else if got := clientErr.Error(); got != test.expectedClientError {
2370 t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got)
2371 } else {
2372 test.verify(t, i, &result.cs)
2373 }
2374 } else if len(test.expectedClientError) > 0 {
2375 t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError)
2376 } else if err := result.err; err != nil {
2377 t.Errorf("#%d: server error: %v", i, err)
2378 } else {
2379 test.verify(t, i, &result.cs)
2380 }
2381 }
2382 }
2383
// TestRSAPSSKeyError checks that a certificate with an RSASSA-PSS public key
// is never surfaced by crypto/x509 as a plain *rsa.PublicKey.
func TestRSAPSSKeyError(t *testing.T) {
	// The certificate below carries an RSASSA-PSS public key. Two outcomes are
	// acceptable: crypto/x509 refuses to parse it, or it parses it without
	// exposing the key as *rsa.PublicKey. According to the failure message
	// below, exposing it as *rsa.PublicKey would let the key be mistakenly
	// used with the rsa_pss_rsae_* signature algorithms.
	b, _ := pem.Decode([]byte(`
-----BEGIN CERTIFICATE-----
MIIDZTCCAhygAwIBAgIUCF2x0FyTgZG0CC9QTDjGWkB5vgEwPgYJKoZIhvcNAQEK
MDGgDTALBglghkgBZQMEAgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQC
AgDeMBIxEDAOBgNVBAMMB1JTQS1QU1MwHhcNMTgwNjI3MjI0NDM2WhcNMTgwNzI3
MjI0NDM2WjASMRAwDgYDVQQDDAdSU0EtUFNTMIIBIDALBgkqhkiG9w0BAQoDggEP
ADCCAQoCggEBANxDm0f76JdI06YzsjB3AmmjIYkwUEGxePlafmIASFjDZl/elD0Z
/a7xLX468b0qGxLS5al7XCcEprSdsDR6DF5L520+pCbpfLyPOjuOvGmk9KzVX4x5
b05YXYuXdsQ0Kjxcx2i3jjCday6scIhMJVgBZxTEyMj1thPQM14SHzKCd/m6HmCL
QmswpH2yMAAcBRWzRpp/vdH5DeOJEB3aelq7094no731mrLUCHRiZ1htq8BDB3ou
czwqgwspbqZ4dnMXl2MvfySQ5wJUxQwILbiuAKO2lVVPUbFXHE9pgtznNoPvKwQT
JNcX8ee8WIZc2SEGzofjk3NpjR+2ADB2u3sCAwEAAaNTMFEwHQYDVR0OBBYEFNEz
AdyJ2f+fU+vSCS6QzohnOnprMB8GA1UdIwQYMBaAFNEzAdyJ2f+fU+vSCS6Qzohn
OnprMA8GA1UdEwEB/wQFMAMBAf8wPgYJKoZIhvcNAQEKMDGgDTALBglghkgBZQME
AgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQCAgDeA4IBAQCjEdrR5aab
sZmCwrMeKidXgfkmWvfuLDE+TCbaqDZp7BMWcMQXT9O0UoUT5kqgKj2ARm2pEW0Z
H3Z1vj3bbds72qcDIJXp+l0fekyLGeCrX/CbgnMZXEP7+/+P416p34ChR1Wz4dU1
KD3gdsUuTKKeMUog3plxlxQDhRQmiL25ygH1LmjLd6dtIt0GVRGr8lj3euVeprqZ
bZ3Uq5eLfsn8oPgfC57gpO6yiN+UURRTlK3bgYvLh4VWB3XXk9UaQZ7Mq1tpXjoD
HYFybkWzibkZp4WRo+Fa28rirH+/wHt0vfeN7UCceURZEx4JaxIIfe4ku7uDRhJi
RwBA9Xk1KBNF
-----END CERTIFICATE-----`))
	if b == nil {
		t.Fatal("Failed to decode certificate")
	}
	cert, err := x509.ParseCertificate(b.Bytes)
	if err != nil {
		// Refusing to parse an RSASSA-PSS certificate is an acceptable outcome.
		return
	}
	if _, ok := cert.PublicKey.(*rsa.PublicKey); ok {
		t.Error("A RSASSA-PSS certificate was parsed like a PKCS#1 v1.5 one, and it will be mistakenly used with rsa_pss_rsae_* signature algorithms")
	}
}
2422
2423 func TestCloseClientConnectionOnIdleServer(t *testing.T) {
2424 clientConn, serverConn := localPipe(t)
2425 client := Client(clientConn, testConfig.Clone())
2426 go func() {
2427 var b [1]byte
2428 serverConn.Read(b[:])
2429 client.Close()
2430 }()
2431 client.SetWriteDeadline(time.Now().Add(time.Minute))
2432 err := client.Handshake()
2433 if err != nil {
2434 if err, ok := err.(net.Error); ok && err.Timeout() {
2435 t.Errorf("Expected a closed network connection error but got '%s'", err.Error())
2436 }
2437 } else {
2438 t.Errorf("Error expected, but no error returned")
2439 }
2440 }
2441
2442 func testDowngradeCanary(t *testing.T, clientVersion, serverVersion uint16) error {
2443 defer func() { testingOnlyForceDowngradeCanary = false }()
2444 testingOnlyForceDowngradeCanary = true
2445
2446 clientConfig := testConfig.Clone()
2447 clientConfig.MaxVersion = clientVersion
2448 serverConfig := testConfig.Clone()
2449 serverConfig.MaxVersion = serverVersion
2450 _, _, err := testHandshake(t, clientConfig, serverConfig)
2451 return err
2452 }
2453
2454 func TestDowngradeCanary(t *testing.T) {
2455 if err := testDowngradeCanary(t, VersionTLS13, VersionTLS12); err == nil {
2456 t.Errorf("downgrade from TLS 1.3 to TLS 1.2 was not detected")
2457 }
2458 if testing.Short() {
2459 t.Skip("skipping the rest of the checks in short mode")
2460 }
2461 if err := testDowngradeCanary(t, VersionTLS13, VersionTLS11); err == nil {
2462 t.Errorf("downgrade from TLS 1.3 to TLS 1.1 was not detected")
2463 }
2464 if err := testDowngradeCanary(t, VersionTLS13, VersionTLS10); err == nil {
2465 t.Errorf("downgrade from TLS 1.3 to TLS 1.0 was not detected")
2466 }
2467 if err := testDowngradeCanary(t, VersionTLS12, VersionTLS11); err == nil {
2468 t.Errorf("downgrade from TLS 1.2 to TLS 1.1 was not detected")
2469 }
2470 if err := testDowngradeCanary(t, VersionTLS12, VersionTLS10); err == nil {
2471 t.Errorf("downgrade from TLS 1.2 to TLS 1.0 was not detected")
2472 }
2473 if err := testDowngradeCanary(t, VersionTLS13, VersionTLS13); err != nil {
2474 t.Errorf("server unexpectedly sent downgrade canary for TLS 1.3")
2475 }
2476 if err := testDowngradeCanary(t, VersionTLS12, VersionTLS12); err != nil {
2477 t.Errorf("client didn't ignore expected TLS 1.2 canary")
2478 }
2479 if err := testDowngradeCanary(t, VersionTLS11, VersionTLS11); err != nil {
2480 t.Errorf("client unexpectedly reacted to a canary in TLS 1.1")
2481 }
2482 if err := testDowngradeCanary(t, VersionTLS10, VersionTLS10); err != nil {
2483 t.Errorf("client unexpectedly reacted to a canary in TLS 1.0")
2484 }
2485 }
2486
2487 func TestResumptionKeepsOCSPAndSCT(t *testing.T) {
2488 t.Run("TLSv12", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS12) })
2489 t.Run("TLSv13", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS13) })
2490 }
2491
2492 func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) {
2493 issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
2494 if err != nil {
2495 t.Fatalf("failed to parse test issuer")
2496 }
2497 roots := x509.NewCertPool()
2498 roots.AddCert(issuer)
2499 clientConfig := &Config{
2500 MaxVersion: ver,
2501 ClientSessionCache: NewLRUClientSessionCache(32),
2502 ServerName: "example.golang",
2503 RootCAs: roots,
2504 }
2505 serverConfig := testConfig.Clone()
2506 serverConfig.MaxVersion = ver
2507 serverConfig.Certificates[0].OCSPStaple = []byte{1, 2, 3}
2508 serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{4, 5, 6}}
2509
2510 _, ccs, err := testHandshake(t, clientConfig, serverConfig)
2511 if err != nil {
2512 t.Fatalf("handshake failed: %s", err)
2513 }
2514
2515
2516 if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) {
2517 t.Errorf("client ConnectionState contained unexpected OCSPResponse: wanted %v, got %v",
2518 serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse)
2519 }
2520 if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) {
2521 t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps: wanted %v, got %v",
2522 serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
2523 }
2524
2525
2526 oldSCTs := serverConfig.Certificates[0].SignedCertificateTimestamps
2527 serverConfig.Certificates[0].SignedCertificateTimestamps = nil
2528 _, ccs, err = testHandshake(t, clientConfig, serverConfig)
2529 if err != nil {
2530 t.Fatalf("handshake failed: %s", err)
2531 }
2532 if !ccs.DidResume {
2533 t.Fatalf("expected session to be resumed")
2534 }
2535
2536
2537 if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) {
2538 t.Errorf("client ConnectionState contained unexpected OCSPResponse after resumption: wanted %v, got %v",
2539 serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse)
2540 }
2541 if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, oldSCTs) {
2542 t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v",
2543 oldSCTs, ccs.SignedCertificateTimestamps)
2544 }
2545
2546
2547
2548 if ver == VersionTLS13 {
2549 return
2550 }
2551
2552
2553 serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{7, 8, 9}}
2554 _, ccs, err = testHandshake(t, clientConfig, serverConfig)
2555 if err != nil {
2556 t.Fatalf("handshake failed: %s", err)
2557 }
2558 if !ccs.DidResume {
2559 t.Fatalf("expected session to be resumed")
2560 }
2561 if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) {
2562 t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v",
2563 serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
2564 }
2565 }
2566
2567
2568
2569
2570 func TestClientHandshakeContextCancellation(t *testing.T) {
2571 c, s := localPipe(t)
2572 ctx, cancel := context.WithCancel(context.Background())
2573 unblockServer := make(chan struct{})
2574 defer close(unblockServer)
2575 go func() {
2576 cancel()
2577 <-unblockServer
2578 _ = s.Close()
2579 }()
2580 cli := Client(c, testConfig)
2581
2582
2583 err := cli.HandshakeContext(ctx)
2584 if err == nil {
2585 t.Fatal("Client handshake did not error when the context was canceled")
2586 }
2587 if err != context.Canceled {
2588 t.Errorf("Unexpected client handshake error: %v", err)
2589 }
2590 if runtime.GOARCH == "wasm" {
2591 t.Skip("conn.Close does not error as expected when called multiple times on WASM")
2592 }
2593 err = cli.Close()
2594 if err == nil {
2595 t.Error("Client connection was not closed when the context was canceled")
2596 }
2597 }
2598
View as plain text