1
2
3
4
5
6
7 package tls
8
9 import (
10 "bytes"
11 "context"
12 "crypto/cipher"
13 "crypto/subtle"
14 "crypto/x509"
15 "errors"
16 "fmt"
17 "hash"
18 "io"
19 "net"
20 "sync"
21 "sync/atomic"
22 "time"
23 )
24
25
26
// A Conn represents a secured connection layered over an underlying
// net.Conn. It provides the net.Conn method set (Read, Write, Close and the
// address/deadline methods) and tracks both the handshake results and the
// per-direction record-layer state.
type Conn struct {
	// conn is the underlying transport that records are read from and
	// written to.
	conn net.Conn
	// isClient reports whether this is the client side of the connection.
	isClient bool
	// handshakeFn performs the handshake; handshakeContext invokes it.
	// NOTE(review): assigned outside this chunk — presumably the client or
	// server handshake routine.
	handshakeFn func(context.Context) error

	// handshakeStatus is 1 once a handshake has completed and is reset to 0
	// when a renegotiation starts. It is only accessed with sync/atomic, so
	// handshakeComplete can be called without holding handshakeMutex.
	handshakeStatus uint32
	// handshakeMutex serializes handshakes and guards the handshake results
	// read by ConnectionState, OCSPResponse and VerifyHostname.
	handshakeMutex sync.Mutex
	// handshakeErr is the error from the most recent handshake attempt.
	handshakeErr error
	// vers is the negotiated protocol version; haveVers reports whether it
	// has been set yet.
	vers     uint16
	haveVers bool
	// config is the configuration for this connection.
	config *Config

	// handshakes counts successful handshakes, including renegotiations.
	handshakes int
	// didResume, cipherSuite, ocspResponse, scts and peerCertificates are
	// reported through ConnectionState.
	didResume    bool
	cipherSuite  uint16
	ocspResponse []byte
	scts         [][]byte
	// peerCertificates are the certificates presented by the peer;
	// VerifyHostname checks peerCertificates[0].
	peerCertificates []*x509.Certificate

	// verifiedChains holds the verified certificate chains; VerifyHostname
	// refuses to run when it is empty.
	verifiedChains [][]*x509.Certificate

	// serverName is reported as ConnectionState.ServerName.
	serverName string

	// secureRenegotiation is not used in this chunk. NOTE(review): presumably
	// whether the peer supports RFC 5746 secure renegotiation — confirm.
	secureRenegotiation bool

	// ekm exports keying material. ConnectionState substitutes
	// noExportedKeyingMaterial for it when renegotiation is enabled.
	ekm func(label string, context []byte, length int) ([]byte, error)

	// resumptionSecret is not used in this chunk. NOTE(review): presumably
	// the TLS 1.3 resumption secret — confirm.
	resumptionSecret []byte

	// ticketKeys is not used in this chunk. NOTE(review): presumably the
	// session ticket keys used on the server side — confirm.
	ticketKeys []ticketKey

	// clientFinishedIsFirst selects which Finished value is reported as
	// ConnectionState.TLSUnique.
	clientFinishedIsFirst bool

	// closeNotifyErr is the cached result of the close_notify attempt made
	// by closeNotify.
	closeNotifyErr error

	// closeNotifySent records that closeNotify has run; Write then fails
	// with errShutdown.
	closeNotifySent bool

	// clientFinished and serverFinished hold the Finished verify data used
	// for ConnectionState.TLSUnique.
	clientFinished [12]byte
	serverFinished [12]byte

	// clientProtocol is reported as ConnectionState.NegotiatedProtocol.
	clientProtocol string

	// Record layer state.
	in, out   halfConn     // read and write halves, each with its own lock
	rawInput  bytes.Buffer // bytes read from conn not yet consumed as records
	input     bytes.Reader // decrypted application data awaiting Read; aliases rawInput's memory
	hand      bytes.Buffer // handshake record payloads awaiting readHandshake
	buffering bool         // when set, write appends to sendBuf instead of writing to conn
	sendBuf   []byte       // output buffered while buffering; written out by flush

	// bytesSent counts bytes written to conn and packetsSent counts records
	// sized by maxPayloadSizeForWrite; both drive dynamic record sizing.
	bytesSent   int64
	packetsSent int64

	// retryCount counts consecutive records that did not advance the
	// connection state. readRecordOrCCS resets it; retryReadRecord and
	// handlePostHandshakeMessage fail once it exceeds maxUselessRecords.
	retryCount int

	// activeCall is updated with sync/atomic: bit 0 is set by Close and the
	// remaining bits count in-flight Write calls in steps of 2.
	activeCall int32

	// tmp is scratch space; sendAlertLocked builds alert bodies in it.
	tmp [16]byte
}
120
121
122
123
124
125
126 func (c *Conn) LocalAddr() net.Addr {
127 return c.conn.LocalAddr()
128 }
129
130
131 func (c *Conn) RemoteAddr() net.Addr {
132 return c.conn.RemoteAddr()
133 }
134
135
136
137
138 func (c *Conn) SetDeadline(t time.Time) error {
139 return c.conn.SetDeadline(t)
140 }
141
142
143
144 func (c *Conn) SetReadDeadline(t time.Time) error {
145 return c.conn.SetReadDeadline(t)
146 }
147
148
149
150
151 func (c *Conn) SetWriteDeadline(t time.Time) error {
152 return c.conn.SetWriteDeadline(t)
153 }
154
155
156
// halfConn holds the record-layer state for one direction (read or write)
// of a Conn: the active protection, the record sequence number, and the
// cipher/MAC pair staged by prepareCipherSpec for the next changeCipherSpec.
type halfConn struct {
	sync.Mutex

	err     error       // sticky error; readRecordOrCCS and Write return it once set
	version uint16      // protocol version recorded by prepareCipherSpec
	cipher  interface{} // cipher.Stream, aead or cbcMode; nil means records are unprotected
	mac     hash.Hash   // MAC for stream and CBC ciphers; nil otherwise
	seq     [8]byte     // 64-bit big-endian record sequence number

	scratchBuf [13]byte // reusable space for AEAD additional data / MAC input

	nextCipher interface{} // cipher installed by the next changeCipherSpec
	nextMac    hash.Hash   // MAC installed by the next changeCipherSpec

	trafficSecret []byte // current TLS 1.3 traffic secret (see setTrafficSecret)
}
173
// permanentError wraps a net.Error. It keeps the wrapped error's message and
// timeout status but never reports itself as temporary, so a halfConn can
// latch a transport failure that callers will not retry.
type permanentError struct {
	err net.Error
}

// Error returns the wrapped error's message unchanged.
func (e *permanentError) Error() string {
	inner := e.err
	return inner.Error()
}

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (e *permanentError) Unwrap() error {
	inner := e.err
	return inner
}

// Timeout reports whether the wrapped error was a timeout.
func (e *permanentError) Timeout() bool {
	inner := e.err
	return inner.Timeout()
}

// Temporary always reports false: the failure is latched permanently.
func (e *permanentError) Temporary() bool {
	const retryable = false
	return retryable
}
182
183 func (hc *halfConn) setErrorLocked(err error) error {
184 if e, ok := err.(net.Error); ok {
185 hc.err = &permanentError{err: e}
186 } else {
187 hc.err = err
188 }
189 return hc.err
190 }
191
192
193
194 func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac hash.Hash) {
195 hc.version = version
196 hc.nextCipher = cipher
197 hc.nextMac = mac
198 }
199
200
201
202 func (hc *halfConn) changeCipherSpec() error {
203 if hc.nextCipher == nil || hc.version == VersionTLS13 {
204 return alertInternalError
205 }
206 hc.cipher = hc.nextCipher
207 hc.mac = hc.nextMac
208 hc.nextCipher = nil
209 hc.nextMac = nil
210 for i := range hc.seq {
211 hc.seq[i] = 0
212 }
213 return nil
214 }
215
216 func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) {
217 hc.trafficSecret = secret
218 key, iv := suite.trafficKey(secret)
219 hc.cipher = suite.aead(key, iv)
220 for i := range hc.seq {
221 hc.seq[i] = 0
222 }
223 }
224
225
226 func (hc *halfConn) incSeq() {
227 for i := 7; i >= 0; i-- {
228 hc.seq[i]++
229 if hc.seq[i] != 0 {
230 return
231 }
232 }
233
234
235
236
237 panic("TLS: sequence number wraparound")
238 }
239
240
241
242
243 func (hc *halfConn) explicitNonceLen() int {
244 if hc.cipher == nil {
245 return 0
246 }
247
248 switch c := hc.cipher.(type) {
249 case cipher.Stream:
250 return 0
251 case aead:
252 return c.explicitNonceLen()
253 case cbcMode:
254
255 if hc.version >= VersionTLS11 {
256 return c.BlockSize()
257 }
258 return 0
259 default:
260 panic("unknown cipher type")
261 }
262 }
263
264
265
266
// extractPadding inspects the block-cipher padding at the end of a decrypted
// CBC payload (see RFC 2246, Section 6.2.3.2). The last byte is the padding
// length p, and the p bytes before it must each equal p.
//
// It returns:
//   - toRemove: the number of trailing bytes to strip — p+1 when the padding
//     is valid, 1 otherwise.
//   - good: 255 when the padding is valid, 0 otherwise.
//
// An empty payload yields (0, 0).
//
// The checks use masks instead of branches on the padding bytes. The loop
// bound depends only on len(payload). The intent is that running time does
// not reveal whether the padding was valid. NOTE(review): the generated
// machine code has not been checked for this property.
func extractPadding(payload []byte) (toRemove int, good byte) {
	if len(payload) < 1 {
		return 0, 0
	}

	paddingLen := payload[len(payload)-1]
	// t wraps around (its top bits become set) exactly when paddingLen
	// exceeds len(payload)-1, i.e. when the claimed padding does not fit.
	t := uint(len(payload)-1) - uint(paddingLen)
	// good = 255 if the padding fits in the payload, 0 otherwise.
	good = byte(int32(^t) >> 31)

	// Up to 255 padding bytes plus the length byte itself.
	toCheck := 256
	// len(payload) is public, so branching on it is fine.
	if toCheck > len(payload) {
		toCheck = len(payload)
	}

	for i := 0; i < toCheck; i++ {
		t := uint(paddingLen) - uint(i)
		// mask = 255 if i <= paddingLen (this byte is part of the padding),
		// 0 otherwise.
		mask := byte(int32(^t) >> 31)
		b := payload[len(payload)-1-i]
		// Clear every bit of good where a padding byte differs from
		// paddingLen; bytes outside the padding are masked out.
		good &^= mask&paddingLen ^ mask&b
	}

	// Fold all eight bits of good into bit 7 (AND of all bits)...
	good &= good << 4
	good &= good << 2
	good &= good << 1
	// ...and smear bit 7 across the byte: 255 if all bits were set, else 0.
	good = uint8(int8(good) >> 7)

	// On failure, treat the padding length as 0 so only the length byte is
	// removed. decrypt folds good into its MAC verdict, so a padding failure
	// and a MAC failure produce the same alertBadRecordMAC.
	paddingLen &= good

	toRemove = int(paddingLen) + 1
	return
}
313
// roundUp returns the smallest multiple of b that is >= a, for a >= 0 and
// b > 0. For example, roundUp(17, 16) == 32 and roundUp(16, 16) == 16.
func roundUp(a, b int) int {
	shortfall := (b - a%b) % b
	return a + shortfall
}
317
318
// cbcMode is a cipher.BlockMode whose IV can be replaced. decrypt and
// encrypt call SetIV with a record's explicit IV when the record carries
// one.
type cbcMode interface {
	cipher.BlockMode
	SetIV([]byte)
}
323
324
325
326 func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
327 var plaintext []byte
328 typ := recordType(record[0])
329 payload := record[recordHeaderLen:]
330
331
332
333 if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec {
334 return payload, typ, nil
335 }
336
337 paddingGood := byte(255)
338 paddingLen := 0
339
340 explicitNonceLen := hc.explicitNonceLen()
341
342 if hc.cipher != nil {
343 switch c := hc.cipher.(type) {
344 case cipher.Stream:
345 c.XORKeyStream(payload, payload)
346 case aead:
347 if len(payload) < explicitNonceLen {
348 return nil, 0, alertBadRecordMAC
349 }
350 nonce := payload[:explicitNonceLen]
351 if len(nonce) == 0 {
352 nonce = hc.seq[:]
353 }
354 payload = payload[explicitNonceLen:]
355
356 var additionalData []byte
357 if hc.version == VersionTLS13 {
358 additionalData = record[:recordHeaderLen]
359 } else {
360 additionalData = append(hc.scratchBuf[:0], hc.seq[:]...)
361 additionalData = append(additionalData, record[:3]...)
362 n := len(payload) - c.Overhead()
363 additionalData = append(additionalData, byte(n>>8), byte(n))
364 }
365
366 var err error
367 plaintext, err = c.Open(payload[:0], nonce, payload, additionalData)
368 if err != nil {
369 return nil, 0, alertBadRecordMAC
370 }
371 case cbcMode:
372 blockSize := c.BlockSize()
373 minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize)
374 if len(payload)%blockSize != 0 || len(payload) < minPayload {
375 return nil, 0, alertBadRecordMAC
376 }
377
378 if explicitNonceLen > 0 {
379 c.SetIV(payload[:explicitNonceLen])
380 payload = payload[explicitNonceLen:]
381 }
382 c.CryptBlocks(payload, payload)
383
384
385
386
387
388
389
390 paddingLen, paddingGood = extractPadding(payload)
391 default:
392 panic("unknown cipher type")
393 }
394
395 if hc.version == VersionTLS13 {
396 if typ != recordTypeApplicationData {
397 return nil, 0, alertUnexpectedMessage
398 }
399 if len(plaintext) > maxPlaintext+1 {
400 return nil, 0, alertRecordOverflow
401 }
402
403 for i := len(plaintext) - 1; i >= 0; i-- {
404 if plaintext[i] != 0 {
405 typ = recordType(plaintext[i])
406 plaintext = plaintext[:i]
407 break
408 }
409 if i == 0 {
410 return nil, 0, alertUnexpectedMessage
411 }
412 }
413 }
414 } else {
415 plaintext = payload
416 }
417
418 if hc.mac != nil {
419 macSize := hc.mac.Size()
420 if len(payload) < macSize {
421 return nil, 0, alertBadRecordMAC
422 }
423
424 n := len(payload) - macSize - paddingLen
425 n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n)
426 record[3] = byte(n >> 8)
427 record[4] = byte(n)
428 remoteMAC := payload[n : n+macSize]
429 localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])
430
431
432
433
434
435
436
437
438 macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood)
439 if macAndPaddingGood != 1 {
440 return nil, 0, alertBadRecordMAC
441 }
442
443 plaintext = payload[:n]
444 }
445
446 hc.incSeq()
447 return plaintext, typ, nil
448 }
449
450
451
452
// sliceForAppend extends in by n bytes. It returns the whole extended slice
// (head) and the newly added n-byte region (tail).
//
// If in has enough spare capacity, head shares in's backing array and tail
// holds whatever bytes were already there. Otherwise in is copied into a
// fresh, zeroed allocation of exactly len(in)+n bytes.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
	total := len(in) + n
	if cap(in) < total {
		head = make([]byte, total)
		copy(head, in)
	} else {
		head = in[:total]
	}
	return head, head[len(in):]
}
463
464
465
// encrypt protects payload and appends the result to record, then advances
// the sequence number. record must already contain the recordHeaderLen-byte
// header; its length field is rewritten at the end to cover everything
// appended. Any explicit nonce/IV comes from rand when it is not derived
// from the sequence number.
//
// With no cipher installed, payload is appended unchanged and the sequence
// number is not advanced. It panics on an unrecognized cipher type.
func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
	if hc.cipher == nil {
		return append(record, payload...), nil
	}

	var explicitNonce []byte
	if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 {
		record, explicitNonce = sliceForAppend(record, explicitNonceLen)
		if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 {
			// Non-CBC ciphers with a short explicit nonce use the record
			// sequence number as the nonce. incSeq never lets it repeat
			// under one key, so no randomness is needed. (Note that copy
			// only fills min(explicitNonceLen, 8) bytes.)
			copy(explicitNonce, hc.seq[:])
		} else {
			// CBC IVs (and larger nonces) are drawn from rand.
			if _, err := io.ReadFull(rand, explicitNonce); err != nil {
				return nil, err
			}
		}
	}

	var dst []byte
	switch c := hc.cipher.(type) {
	case cipher.Stream:
		// MAC-then-encrypt: MAC the header and payload, append the MAC, and
		// run both through the keystream.
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		record, dst = sliceForAppend(record, len(payload)+len(mac))
		c.XORKeyStream(dst[:len(payload)], payload)
		c.XORKeyStream(dst[len(payload):], mac)
	case aead:
		nonce := explicitNonce
		if len(nonce) == 0 {
			nonce = hc.seq[:]
		}

		if hc.version == VersionTLS13 {
			record = append(record, payload...)

			// TLS 1.3 inner plaintext: payload followed by the real content
			// type. The outer header always says application_data.
			record = append(record, record[0])
			record[0] = byte(recordTypeApplicationData)

			// The header is the additional data, so its length field must
			// already hold the final ciphertext length before sealing.
			n := len(payload) + 1 + c.Overhead()
			record[3] = byte(n >> 8)
			record[4] = byte(n)

			record = c.Seal(record[:recordHeaderLen],
				nonce, record[recordHeaderLen:], record[:recordHeaderLen])
		} else {
			// Additional data: seq_num || header (whose length field still
			// holds the plaintext length set by the caller).
			additionalData := append(hc.scratchBuf[:0], hc.seq[:]...)
			additionalData = append(additionalData, record[:recordHeaderLen]...)
			record = c.Seal(record, nonce, payload, additionalData)
		}
	case cbcMode:
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		blockSize := c.BlockSize()
		plaintextLen := len(payload) + len(mac)
		// 1..blockSize bytes of padding bring the total to a block multiple.
		paddingLen := blockSize - plaintextLen%blockSize
		record, dst = sliceForAppend(record, plaintextLen+paddingLen)
		copy(dst, payload)
		copy(dst[len(payload):], mac)
		// Every padding byte, including the final length byte, is
		// paddingLen-1.
		for i := plaintextLen; i < len(dst); i++ {
			dst[i] = byte(paddingLen - 1)
		}
		if len(explicitNonce) > 0 {
			c.SetIV(explicitNonce)
		}
		c.CryptBlocks(dst, dst)
	default:
		panic("unknown cipher type")
	}

	// Rewrite the length to cover the explicit nonce, MAC/tag and padding.
	n := len(record) - recordHeaderLen
	record[3] = byte(n >> 8)
	record[4] = byte(n)
	hc.incSeq()

	return record, nil
}
550
551
// RecordHeaderError is returned when a received record header is rejected
// before the record body is processed. Examples are a header that does not
// look like TLS, an unexpected record version, or an oversized length.
type RecordHeaderError struct {
	// Msg is a human-readable description of the problem.
	Msg string
	// RecordHeader holds the first bytes of buffered input (up to five) that
	// were interpreted as the record header.
	RecordHeader [5]byte
	// Conn is the underlying connection, when the reader chose to expose it.
	// readRecordOrCCS sets it only when the first record does not look like
	// a TLS handshake and passes nil for its other header errors.
	Conn net.Conn
}

// Error formats the error as "tls: " followed by Msg.
func (e RecordHeaderError) Error() string {
	const prefix = "tls: "
	return prefix + e.Msg
}
566
567 func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) {
568 err.Msg = msg
569 err.Conn = conn
570 copy(err.RecordHeader[:], c.rawInput.Bytes())
571 return err
572 }
573
574 func (c *Conn) readRecord() error {
575 return c.readRecordOrCCS(false)
576 }
577
578 func (c *Conn) readChangeCipherSpec() error {
579 return c.readRecordOrCCS(true)
580 }
581
582
583
584
585
586
587
588
589
590
591
592
593
// readRecordOrCCS reads one TLS record from the connection, decrypts it and
// dispatches it by type. It reads further records when the current one is
// ignorable (warning alerts, empty application data, TLS 1.3
// ChangeCipherSpec).
//
// Preconditions: c.input must be empty (checked below). Callers are
// expected to hold c.in's lock; Read and handshakeContext lock it first.
//
// On success exactly one of the following has happened:
//   - handshake data was appended to c.hand,
//   - c.in.changeCipherSpec activated the pending read cipher
//     (expectChangeCipherSpec, below TLS 1.3), or
//   - c.input was set to application data.
//
// Fatal errors are latched in c.in.err and returned on every later call.
// Temporary network errors are returned without being latched.
func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
	if c.in.err != nil {
		return c.in.err
	}
	handshakeComplete := c.handshakeComplete()

	// c.input aliases c.rawInput's memory (see the Reset below), so raw
	// input must not be consumed while application data is still pending.
	if c.input.Len() != 0 {
		return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data"))
	}
	c.input.Reset(nil)

	// Read the header.
	if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
		// A clean EOF exactly at a record boundary is reported as io.EOF;
		// an EOF in the middle of a header stays io.ErrUnexpectedEOF.
		if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 {
			err = io.EOF
		}
		// Temporary errors are not latched, so the read can be retried.
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}
	hdr := c.rawInput.Bytes()[:recordHeaderLen]
	typ := recordType(hdr[0])

	// A first byte of 0x80 before the handshake completes is taken as the
	// start of an SSLv2-format handshake, which is not supported.
	if !handshakeComplete && typ == 0x80 {
		c.sendAlert(alertProtocolVersion)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received"))
	}

	vers := uint16(hdr[1])<<8 | uint16(hdr[2])
	n := int(hdr[3])<<8 | int(hdr[4])
	// Once a pre-1.3 version is negotiated, every record must carry it.
	if c.haveVers && c.vers != VersionTLS13 && vers != c.vers {
		c.sendAlert(alertProtocolVersion)
		msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	if !c.haveVers {
		// First record: only handshake or alert records with a plausible
		// version are accepted. Bail out before reading a body, and hand the
		// connection back in the error.
		if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 {
			return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
		}
	}
	if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
		c.sendAlert(alertRecordOverflow)
		msg := fmt.Sprintf("oversized record received with length %d", n)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	// Read the body.
	if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}

	// Process the record. record (and therefore data) points into
	// c.rawInput's buffer; decryption happens in place.
	record := c.rawInput.Next(recordHeaderLen + n)
	data, typ, err := c.in.decrypt(record)
	if err != nil {
		return c.in.setErrorLocked(c.sendAlert(err.(alert)))
	}
	if len(data) > maxPlaintext {
		return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow))
	}

	// Application data must never arrive unprotected.
	if c.in.cipher == nil && typ == recordTypeApplicationData {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
		// A state-advancing record: reset the count of ignored records.
		c.retryCount = 0
	}

	// In TLS 1.3, a partially buffered handshake message must not be
	// interleaved with other record types.
	if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	switch typ {
	default:
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))

	case recordTypeAlert:
		if len(data) != 2 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if alert(data[1]) == alertCloseNotify {
			return c.in.setErrorLocked(io.EOF)
		}
		// In TLS 1.3 every alert other than close_notify is fatal.
		if c.vers == VersionTLS13 {
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		}
		switch data[0] {
		case alertLevelWarning:
			// Ignore the warning and read the next record.
			return c.retryReadRecord(expectChangeCipherSpec)
		case alertLevelError:
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		default:
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}

	case recordTypeChangeCipherSpec:
		if len(data) != 1 || data[0] != 1 {
			return c.in.setErrorLocked(c.sendAlert(alertDecodeError))
		}
		// A handshake message must not straddle the ChangeCipherSpec.
		if c.hand.Len() > 0 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// In TLS 1.3, ChangeCipherSpec records carry no meaning and are
		// skipped (RFC 8446, Appendix D.4).
		if c.vers == VersionTLS13 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		if !expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if err := c.in.changeCipherSpec(); err != nil {
			return c.in.setErrorLocked(c.sendAlert(err.(alert)))
		}

	case recordTypeApplicationData:
		if !handshakeComplete || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// Empty application data records are skipped, but count toward
		// maxUselessRecords.
		if len(data) == 0 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		// data is owned by c.rawInput (no copy is made); this is why
		// c.input must be drained before the next record is read.
		c.input.Reset(data)

	case recordTypeHandshake:
		if len(data) == 0 || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		c.hand.Write(data)
	}

	return nil
}
754
755
756
757 func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error {
758 c.retryCount++
759 if c.retryCount > maxUselessRecords {
760 c.sendAlert(alertUnexpectedMessage)
761 return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
762 }
763 return c.readRecordOrCCS(expectChangeCipherSpec)
764 }
765
766
767
768
// atLeastReader wraps R and reports io.EOF as soon as at least N bytes have
// been read through it. A single Read may return more than N bytes if R
// supplies them. If R ends before N bytes have been read, io.ErrUnexpectedEOF
// is returned instead of io.EOF.
type atLeastReader struct {
	R io.Reader
	N int64
}

// Read reads from the wrapped reader and decrements the remaining count N.
func (r *atLeastReader) Read(p []byte) (int, error) {
	if r.N <= 0 {
		return 0, io.EOF
	}
	got, rerr := r.R.Read(p)
	r.N -= int64(got)
	switch {
	case r.N > 0 && rerr == io.EOF:
		// The source ran dry before the requested minimum.
		return got, io.ErrUnexpectedEOF
	case r.N <= 0 && rerr == nil:
		// Minimum reached: signal the caller to stop reading.
		return got, io.EOF
	default:
		return got, rerr
	}
}
788
789
790
791 func (c *Conn) readFromUntil(r io.Reader, n int) error {
792 if c.rawInput.Len() >= n {
793 return nil
794 }
795 needs := n - c.rawInput.Len()
796
797
798
799 c.rawInput.Grow(needs + bytes.MinRead)
800 _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)})
801 return err
802 }
803
804
805 func (c *Conn) sendAlertLocked(err alert) error {
806 switch err {
807 case alertNoRenegotiation, alertCloseNotify:
808 c.tmp[0] = alertLevelWarning
809 default:
810 c.tmp[0] = alertLevelError
811 }
812 c.tmp[1] = byte(err)
813
814 _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
815 if err == alertCloseNotify {
816
817 return writeErr
818 }
819
820 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
821 }
822
823
824 func (c *Conn) sendAlert(err alert) error {
825 c.out.Lock()
826 defer c.out.Unlock()
827 return c.sendAlertLocked(err)
828 }
829
const (
	// tcpMSSEstimate is an estimate of the TCP maximum segment size.
	// maxPayloadSizeForWrite uses it to size the first application data
	// records.
	// NOTE(review): 1208 equals 1280 (IPv6 minimum MTU) - 40 (IPv6 header)
	// - 32 (TCP header with timestamps); presumably that is the
	// derivation — confirm.
	tcpMSSEstimate = 1208

	// recordSizeBoostThreshold is the number of bytes that must have been
	// sent (c.bytesSent) before maxPayloadSizeForWrite always returns
	// maxPlaintext.
	recordSizeBoostThreshold = 128 * 1024
)
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
// maxPayloadSizeForWrite returns the largest plaintext payload to put in the
// next record of type typ.
//
// Non-application-data records, and all records when
// c.config.DynamicRecordSizingDisabled is set, use maxPlaintext.
//
// Otherwise application data records start small: the payload is sized so
// that the record, including header, explicit nonce and cipher overhead,
// fits in about tcpMSSEstimate bytes. Each later record grows by that amount
// (the k-th record gets k times the base size, capped at maxPlaintext).
// Full-size records are used once recordSizeBoostThreshold bytes have been
// sent, or after more than 1000 records.
//
// Each call on the growth path increments c.packetsSent.
func (c *Conn) maxPayloadSizeForWrite(typ recordType) int {
	if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
		return maxPlaintext
	}

	if c.bytesSent >= recordSizeBoostThreshold {
		return maxPlaintext
	}

	// Subtract the record overheads from the segment estimate.
	payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen()
	if c.out.cipher != nil {
		switch ciph := c.out.cipher.(type) {
		case cipher.Stream:
			payloadBytes -= c.out.mac.Size()
		case cipher.AEAD:
			payloadBytes -= ciph.Overhead()
		case cbcMode:
			blockSize := ciph.BlockSize()
			// Round down to a whole number of blocks (assumes blockSize is
			// a power of two) and reserve one byte for the padding length.
			payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
			// The MAC is encrypted together with the payload, so it
			// comes out of the same budget.
			payloadBytes -= c.out.mac.Size()
		default:
			panic("unknown cipher type")
		}
	}
	if c.vers == VersionTLS13 {
		payloadBytes-- // the inner content type byte appended by encrypt
	}

	// Grow the record size linearly with the number of records sent.
	pkt := c.packetsSent
	c.packetsSent++
	if pkt > 1000 {
		return maxPlaintext
	}

	n := payloadBytes * int(pkt+1)
	if n > maxPlaintext {
		n = maxPlaintext
	}
	return n
}
906
907 func (c *Conn) write(data []byte) (int, error) {
908 if c.buffering {
909 c.sendBuf = append(c.sendBuf, data...)
910 return len(data), nil
911 }
912
913 n, err := c.conn.Write(data)
914 c.bytesSent += int64(n)
915 return n, err
916 }
917
918 func (c *Conn) flush() (int, error) {
919 if len(c.sendBuf) == 0 {
920 return 0, nil
921 }
922
923 n, err := c.conn.Write(c.sendBuf)
924 c.bytesSent += int64(n)
925 c.sendBuf = nil
926 c.buffering = false
927 return n, err
928 }
929
930
// outBufPool holds *[]byte scratch buffers that writeRecordLocked reuses to
// assemble outgoing records. Pointers to slices are pooled so a grown buffer
// can be stored back through the same pointer.
var outBufPool = sync.Pool{
	New: func() interface{} {
		return new([]byte)
	},
}
936
937
938
// writeRecordLocked writes data as one or more records of type typ. Each
// record carries at most maxPayloadSizeForWrite(typ) bytes of data. It
// returns how many bytes of data were written before any error.
//
// After a ChangeCipherSpec record below TLS 1.3, the pending write cipher is
// activated. If that fails, an alert is sent.
//
// The caller must hold c.out's lock.
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
	outBufPtr := outBufPool.Get().(*[]byte)
	outBuf := *outBufPtr
	defer func() {
		// Store the possibly regrown buffer back through the pooled pointer
		// so its capacity is reused. NOTE(review): presumably done through
		// the pointer (instead of Put(&outBuf)) to keep the local slice
		// header from escaping to the heap.
		*outBufPtr = outBuf
		outBufPool.Put(outBufPtr)
	}()

	var n int
	for len(data) > 0 {
		m := len(data)
		if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
			m = maxPayload
		}

		_, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
		outBuf[0] = byte(typ)
		vers := c.vers
		if vers == 0 {
			// No version negotiated yet: label the record as TLS 1.0.
			vers = VersionTLS10
		} else if vers == VersionTLS13 {
			// TLS 1.3 records carry the legacy version TLS 1.2 on the wire
			// (RFC 8446, Section 5.1).
			vers = VersionTLS12
		}
		outBuf[1] = byte(vers >> 8)
		outBuf[2] = byte(vers)
		// Plaintext length for now; encrypt rewrites it with the final size.
		outBuf[3] = byte(m >> 8)
		outBuf[4] = byte(m)

		var err error
		outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand())
		if err != nil {
			return n, err
		}
		if _, err := c.write(outBuf); err != nil {
			return n, err
		}
		n += m
		data = data[m:]
	}

	if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 {
		if err := c.out.changeCipherSpec(); err != nil {
			return n, c.sendAlertLocked(err.(alert))
		}
	}

	return n, nil
}
996
997
998
999 func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) {
1000 c.out.Lock()
1001 defer c.out.Unlock()
1002
1003 return c.writeRecordLocked(typ, data)
1004 }
1005
1006
1007
1008 func (c *Conn) readHandshake() (interface{}, error) {
1009 for c.hand.Len() < 4 {
1010 if err := c.readRecord(); err != nil {
1011 return nil, err
1012 }
1013 }
1014
1015 data := c.hand.Bytes()
1016 n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1017 if n > maxHandshake {
1018 c.sendAlertLocked(alertInternalError)
1019 return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
1020 }
1021 for c.hand.Len() < 4+n {
1022 if err := c.readRecord(); err != nil {
1023 return nil, err
1024 }
1025 }
1026 data = c.hand.Next(4 + n)
1027 var m handshakeMessage
1028 switch data[0] {
1029 case typeHelloRequest:
1030 m = new(helloRequestMsg)
1031 case typeClientHello:
1032 m = new(clientHelloMsg)
1033 case typeServerHello:
1034 m = new(serverHelloMsg)
1035 case typeNewSessionTicket:
1036 if c.vers == VersionTLS13 {
1037 m = new(newSessionTicketMsgTLS13)
1038 } else {
1039 m = new(newSessionTicketMsg)
1040 }
1041 case typeCertificate:
1042 if c.vers == VersionTLS13 {
1043 m = new(certificateMsgTLS13)
1044 } else {
1045 m = new(certificateMsg)
1046 }
1047 case typeCertificateRequest:
1048 if c.vers == VersionTLS13 {
1049 m = new(certificateRequestMsgTLS13)
1050 } else {
1051 m = &certificateRequestMsg{
1052 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1053 }
1054 }
1055 case typeCertificateStatus:
1056 m = new(certificateStatusMsg)
1057 case typeServerKeyExchange:
1058 m = new(serverKeyExchangeMsg)
1059 case typeServerHelloDone:
1060 m = new(serverHelloDoneMsg)
1061 case typeClientKeyExchange:
1062 m = new(clientKeyExchangeMsg)
1063 case typeCertificateVerify:
1064 m = &certificateVerifyMsg{
1065 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1066 }
1067 case typeFinished:
1068 m = new(finishedMsg)
1069 case typeEncryptedExtensions:
1070 m = new(encryptedExtensionsMsg)
1071 case typeEndOfEarlyData:
1072 m = new(endOfEarlyDataMsg)
1073 case typeKeyUpdate:
1074 m = new(keyUpdateMsg)
1075 default:
1076 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1077 }
1078
1079
1080
1081
1082 data = append([]byte(nil), data...)
1083
1084 if !m.unmarshal(data) {
1085 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1086 }
1087 return m, nil
1088 }
1089
var (
	// errShutdown is returned by Write once a close_notify alert has been
	// sent on the connection.
	errShutdown = errors.New("tls: protocol is shutdown")
)
1093
1094
1095
1096
1097
1098
1099
// Write writes data to the connection, running the handshake first if it
// has not completed.
//
// Because Write may run the handshake, which also reads from the
// connection, both a read and a write deadline may be needed to bound how
// long it blocks. See SetDeadline, SetReadDeadline and SetWriteDeadline.
//
// Once Close has been called, Write returns net.ErrClosed. After a
// close_notify has been sent, it returns errShutdown. A write error is
// latched on the write side and returned by later Writes.
func (c *Conn) Write(b []byte) (int, error) {
	// Register this call in activeCall (in steps of 2) unless Close has
	// already set bit 0. Close uses the count to detect in-flight Writes.
	for {
		x := atomic.LoadInt32(&c.activeCall)
		if x&1 != 0 {
			return 0, net.ErrClosed
		}
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) {
			break
		}
	}
	defer atomic.AddInt32(&c.activeCall, -2)

	if err := c.Handshake(); err != nil {
		return 0, err
	}

	c.out.Lock()
	defer c.out.Unlock()

	if err := c.out.err; err != nil {
		return 0, err
	}

	if !c.handshakeComplete() {
		return 0, alertInternalError
	}

	if c.closeNotifySent {
		return 0, errShutdown
	}

	// TLS 1.0 with a CBC cipher: send the first byte in a record of its own
	// and the rest in a second record. This is the "1/n-1 split", a
	// mitigation for the predictable-IV weakness of TLS 1.0 CBC (BEAST).
	// m counts the byte already written by that first record.
	var m int
	if len(b) > 1 && c.vers == VersionTLS10 {
		if _, ok := c.out.cipher.(cipher.BlockMode); ok {
			n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
			if err != nil {
				return n, c.out.setErrorLocked(err)
			}
			m, b = 1, b[1:]
		}
	}

	// setErrorLocked(nil) stores nil (c.out.err was nil on entry) and
	// returns nil, so success is reported unchanged.
	n, err := c.writeRecordLocked(recordTypeApplicationData, b)
	return n + m, c.out.setErrorLocked(err)
}
1155
1156
// handleRenegotiation processes a post-handshake message on a pre-TLS 1.3
// connection. The message must be a HelloRequest.
//
// Servers always refuse with a no_renegotiation alert. Clients follow
// c.config.Renegotiation: never, at most once, or freely. If renegotiation
// is allowed, a new client handshake runs under handshakeMutex, with
// handshakeStatus reset to 0 while it is in progress.
func (c *Conn) handleRenegotiation() error {
	if c.vers == VersionTLS13 {
		return errors.New("tls: internal error: unexpected renegotiation")
	}

	msg, err := c.readHandshake()
	if err != nil {
		return err
	}

	helloReq, ok := msg.(*helloRequestMsg)
	if !ok {
		c.sendAlert(alertUnexpectedMessage)
		return unexpectedMessageError(helloReq, msg)
	}

	if !c.isClient {
		return c.sendAlert(alertNoRenegotiation)
	}

	switch c.config.Renegotiation {
	case RenegotiateNever:
		return c.sendAlert(alertNoRenegotiation)
	case RenegotiateOnceAsClient:
		// Only the initial handshake may have happened so far.
		if c.handshakes > 1 {
			return c.sendAlert(alertNoRenegotiation)
		}
	case RenegotiateFreelyAsClient:
		// Always allowed.
	default:
		c.sendAlert(alertInternalError)
		return errors.New("tls: unknown Renegotiation value")
	}

	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	// Mark the connection as not handshaken while the new handshake runs.
	atomic.StoreUint32(&c.handshakeStatus, 0)
	if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
		c.handshakes++
	}
	return c.handshakeErr
}
1200
1201
1202
// handlePostHandshakeMessage processes one handshake message that arrives
// after the handshake has completed.
//
// Below TLS 1.3 this is delegated to handleRenegotiation. In TLS 1.3 only
// NewSessionTicket and KeyUpdate are accepted; anything else triggers an
// unexpected_message alert.
func (c *Conn) handlePostHandshakeMessage() error {
	if c.vers != VersionTLS13 {
		return c.handleRenegotiation()
	}

	msg, err := c.readHandshake()
	if err != nil {
		return err
	}

	// Count the message toward maxUselessRecords. Presumably this stops a
	// peer from keeping Read busy with post-handshake messages alone.
	c.retryCount++
	if c.retryCount > maxUselessRecords {
		c.sendAlert(alertUnexpectedMessage)
		return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
	}

	switch msg := msg.(type) {
	case *newSessionTicketMsgTLS13:
		return c.handleNewSessionTicket(msg)
	case *keyUpdateMsg:
		return c.handleKeyUpdate(msg)
	default:
		c.sendAlert(alertUnexpectedMessage)
		return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
	}
}
1229
1230 func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
1231 cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
1232 if cipherSuite == nil {
1233 return c.in.setErrorLocked(c.sendAlert(alertInternalError))
1234 }
1235
1236 newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
1237 c.in.setTrafficSecret(cipherSuite, newSecret)
1238
1239 if keyUpdate.updateRequested {
1240 c.out.Lock()
1241 defer c.out.Unlock()
1242
1243 msg := &keyUpdateMsg{}
1244 _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal())
1245 if err != nil {
1246
1247 c.out.setErrorLocked(err)
1248 return nil
1249 }
1250
1251 newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
1252 c.out.setTrafficSecret(cipherSuite, newSecret)
1253 }
1254
1255 return nil
1256 }
1257
1258
1259
1260
1261
1262
1263
// Read reads application data from the connection, running the handshake
// first if it has not completed. Post-handshake messages received along
// the way are handled internally: session tickets, key updates, and
// renegotiation requests.
//
// Because Read may run the handshake, which also writes to the connection,
// both a read and a write deadline may be needed to bound how long it
// blocks. See SetDeadline, SetReadDeadline and SetWriteDeadline.
func (c *Conn) Read(b []byte) (int, error) {
	if err := c.Handshake(); err != nil {
		return 0, err
	}
	if len(b) == 0 {
		// Checked after Handshake, so Read(nil) still triggers the
		// handshake.
		return 0, nil
	}

	c.in.Lock()
	defer c.in.Unlock()

	// Read records until some application data is buffered, handling any
	// handshake messages that arrive in between.
	for c.input.Len() == 0 {
		if err := c.readRecord(); err != nil {
			return 0, err
		}
		for c.hand.Len() > 0 {
			if err := c.handlePostHandshakeMessage(); err != nil {
				return 0, err
			}
		}
	}

	// c.input is non-empty here, so bytes.Reader.Read cannot fail.
	n, _ := c.input.Read(b)

	// If this Read drained the buffered data and the next buffered raw
	// record is an alert, process it now. That way, e.g., a close_notify
	// produces (n, io.EOF) right away rather than only on the next Read.
	if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
		recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
		if err := c.readRecord(); err != nil {
			return n, err // io.EOF for a close_notify
		}
	}

	return n, nil
}
1306
1307
// Close closes the connection. A second Close returns net.ErrClosed.
//
// If the handshake has completed and no Write is in flight, a close_notify
// alert is sent first. A failure to send it is returned (wrapped) only if
// closing the underlying connection succeeds.
func (c *Conn) Close() error {
	// Set bit 0 of activeCall (once) so new Writes fail with net.ErrClosed.
	// Keep the previous value to see whether Writes are in flight.
	var x int32
	for {
		x = atomic.LoadInt32(&c.activeCall)
		if x&1 != 0 {
			return net.ErrClosed
		}
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) {
			break
		}
	}
	if x != 0 {
		// A Write is in progress and holds c.out's lock, which closeNotify
		// would need. Skip the close_notify and just close the transport,
		// which also unblocks that Write.
		return c.conn.Close()
	}

	var alertErr error
	if c.handshakeComplete() {
		if err := c.closeNotify(); err != nil {
			alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
		}
	}

	if err := c.conn.Close(); err != nil {
		return err
	}
	return alertErr
}
1342
1343 var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
1344
1345
1346
1347
1348 func (c *Conn) CloseWrite() error {
1349 if !c.handshakeComplete() {
1350 return errEarlyCloseWrite
1351 }
1352
1353 return c.closeNotify()
1354 }
1355
// closeNotify sends a close_notify alert at most once and returns the
// cached result of that attempt on every call.
//
// The send is bounded by a 5-second write deadline. Afterwards the
// underlying connection's write deadline is set to the current time, so any
// later write on it fails immediately. Write also refuses to run once
// closeNotifySent is set.
func (c *Conn) closeNotify() error {
	c.out.Lock()
	defer c.out.Unlock()

	if !c.closeNotifySent {
		// Bound the alert write so a stalled peer cannot block forever.
		c.SetWriteDeadline(time.Now().Add(time.Second * 5))
		c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
		c.closeNotifySent = true
		// Make any subsequent writes fail.
		c.SetWriteDeadline(time.Now())
	}
	return c.closeNotifyErr
}
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379 func (c *Conn) Handshake() error {
1380 return c.HandshakeContext(context.Background())
1381 }
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393 func (c *Conn) HandshakeContext(ctx context.Context) error {
1394
1395
1396 return c.handshakeContext(ctx)
1397 }
1398
// handshakeContext implements HandshakeContext. It runs c.handshakeFn at
// most until it succeeds, holding handshakeMutex and c.in's lock. The result
// is stored in c.handshakeErr, so a failed handshake keeps returning the
// same error.
//
// If ctx can be canceled, a helper goroutine closes the underlying
// connection when ctx is done before this function returns. In that case
// the context's error replaces the returned error.
func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
	// Fast path: after a completed handshake, skip the context setup and
	// handshakeMutex entirely.
	if c.handshakeComplete() {
		return nil
	}

	handshakeCtx, cancel := context.WithCancel(ctx)
	// Deferred first, so it runs last: only after the interrupter's result
	// has been collected below. The interrupter therefore sees handshakeCtx
	// done only if ctx itself was done.
	defer cancel()

	// Start the "interrupter" goroutine only if ctx can be canceled. It
	// closes the connection (unblocking any handshake I/O) if the context
	// ends first. The deferred func stops it and, if it fired, returns the
	// context's error.
	if ctx.Done() != nil {
		done := make(chan struct{})
		interruptRes := make(chan error, 1)
		defer func() {
			close(done)
			if ctxErr := <-interruptRes; ctxErr != nil {
				// Report the context error to the caller.
				ret = ctxErr
			}
		}()
		go func() {
			select {
			case <-handshakeCtx.Done():
				// Close the connection, discarding the error.
				_ = c.conn.Close()
				interruptRes <- handshakeCtx.Err()
			case <-done:
				interruptRes <- nil
			}
		}()
	}

	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	// Re-check under the mutex: another goroutine may have finished (or
	// failed) the handshake while this one waited.
	if err := c.handshakeErr; err != nil {
		return err
	}
	if c.handshakeComplete() {
		return nil
	}

	c.in.Lock()
	defer c.in.Unlock()

	c.handshakeErr = c.handshakeFn(handshakeCtx)
	if c.handshakeErr == nil {
		c.handshakes++
	} else {
		// On failure, try to send any output still buffered (for example an
		// alert). NOTE(review): assumes the handshake code buffers writes via
		// c.buffering — set outside this chunk.
		c.flush()
	}

	// Sanity-check that handshakeFn's result agrees with handshakeStatus.
	if c.handshakeErr == nil && !c.handshakeComplete() {
		c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
	}
	if c.handshakeErr != nil && c.handshakeComplete() {
		panic("tls: internal error: handshake returned an error but is marked successful")
	}

	return c.handshakeErr
}
1471
1472
1473 func (c *Conn) ConnectionState() ConnectionState {
1474 c.handshakeMutex.Lock()
1475 defer c.handshakeMutex.Unlock()
1476 return c.connectionStateLocked()
1477 }
1478
1479 func (c *Conn) connectionStateLocked() ConnectionState {
1480 var state ConnectionState
1481 state.HandshakeComplete = c.handshakeComplete()
1482 state.Version = c.vers
1483 state.NegotiatedProtocol = c.clientProtocol
1484 state.DidResume = c.didResume
1485 state.NegotiatedProtocolIsMutual = true
1486 state.ServerName = c.serverName
1487 state.CipherSuite = c.cipherSuite
1488 state.PeerCertificates = c.peerCertificates
1489 state.VerifiedChains = c.verifiedChains
1490 state.SignedCertificateTimestamps = c.scts
1491 state.OCSPResponse = c.ocspResponse
1492 if !c.didResume && c.vers != VersionTLS13 {
1493 if c.clientFinishedIsFirst {
1494 state.TLSUnique = c.clientFinished[:]
1495 } else {
1496 state.TLSUnique = c.serverFinished[:]
1497 }
1498 }
1499 if c.config.Renegotiation != RenegotiateNever {
1500 state.ekm = noExportedKeyingMaterial
1501 } else {
1502 state.ekm = c.ekm
1503 }
1504 return state
1505 }
1506
1507
1508
1509 func (c *Conn) OCSPResponse() []byte {
1510 c.handshakeMutex.Lock()
1511 defer c.handshakeMutex.Unlock()
1512
1513 return c.ocspResponse
1514 }
1515
1516
1517
1518
1519 func (c *Conn) VerifyHostname(host string) error {
1520 c.handshakeMutex.Lock()
1521 defer c.handshakeMutex.Unlock()
1522 if !c.isClient {
1523 return errors.New("tls: VerifyHostname called on TLS server connection")
1524 }
1525 if !c.handshakeComplete() {
1526 return errors.New("tls: handshake has not yet been performed")
1527 }
1528 if len(c.verifiedChains) == 0 {
1529 return errors.New("tls: handshake did not verify certificate chain")
1530 }
1531 return c.peerCertificates[0].VerifyHostname(host)
1532 }
1533
1534 func (c *Conn) handshakeComplete() bool {
1535 return atomic.LoadUint32(&c.handshakeStatus) == 1
1536 }
1537
View as plain text