...

Source file src/crypto/tls/handshake_messages_test.go

Documentation: crypto/tls

		 1  // Copyright 2009 The Go Authors. All rights reserved.
		 2  // Use of this source code is governed by a BSD-style
		 3  // license that can be found in the LICENSE file.
		 4  
		 5  package tls
		 6  
		 7  import (
		 8  	"bytes"
		 9  	"math/rand"
		10  	"reflect"
		11  	"strings"
		12  	"testing"
		13  	"testing/quick"
		14  	"time"
		15  )
		16  
		17  var tests = []interface{}{
		18  	&clientHelloMsg{},
		19  	&serverHelloMsg{},
		20  	&finishedMsg{},
		21  
		22  	&certificateMsg{},
		23  	&certificateRequestMsg{},
		24  	&certificateVerifyMsg{
		25  		hasSignatureAlgorithm: true,
		26  	},
		27  	&certificateStatusMsg{},
		28  	&clientKeyExchangeMsg{},
		29  	&newSessionTicketMsg{},
		30  	&sessionState{},
		31  	&sessionStateTLS13{},
		32  	&encryptedExtensionsMsg{},
		33  	&endOfEarlyDataMsg{},
		34  	&keyUpdateMsg{},
		35  	&newSessionTicketMsgTLS13{},
		36  	&certificateRequestMsgTLS13{},
		37  	&certificateMsgTLS13{},
		38  }
		39  
		40  func TestMarshalUnmarshal(t *testing.T) {
		41  	rand := rand.New(rand.NewSource(time.Now().UnixNano()))
		42  
		43  	for i, iface := range tests {
		44  		ty := reflect.ValueOf(iface).Type()
		45  
		46  		n := 100
		47  		if testing.Short() {
		48  			n = 5
		49  		}
		50  		for j := 0; j < n; j++ {
		51  			v, ok := quick.Value(ty, rand)
		52  			if !ok {
		53  				t.Errorf("#%d: failed to create value", i)
		54  				break
		55  			}
		56  
		57  			m1 := v.Interface().(handshakeMessage)
		58  			marshaled := m1.marshal()
		59  			m2 := iface.(handshakeMessage)
		60  			if !m2.unmarshal(marshaled) {
		61  				t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
		62  				break
		63  			}
		64  			m2.marshal() // to fill any marshal cache in the message
		65  
		66  			if !reflect.DeepEqual(m1, m2) {
		67  				t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
		68  				break
		69  			}
		70  
		71  			if i >= 3 {
		72  				// The first three message types (ClientHello,
		73  				// ServerHello and Finished) are allowed to
		74  				// have parsable prefixes because the extension
		75  				// data is optional and the length of the
		76  				// Finished varies across versions.
		77  				for j := 0; j < len(marshaled); j++ {
		78  					if m2.unmarshal(marshaled[0:j]) {
		79  						t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
		80  						break
		81  					}
		82  				}
		83  			}
		84  		}
		85  	}
		86  }
		87  
		88  func TestFuzz(t *testing.T) {
		89  	rand := rand.New(rand.NewSource(0))
		90  	for _, iface := range tests {
		91  		m := iface.(handshakeMessage)
		92  
		93  		for j := 0; j < 1000; j++ {
		94  			len := rand.Intn(100)
		95  			bytes := randomBytes(len, rand)
		96  			// This just looks for crashes due to bounds errors etc.
		97  			m.unmarshal(bytes)
		98  		}
		99  	}
	 100  }
	 101  
	 102  func randomBytes(n int, rand *rand.Rand) []byte {
	 103  	r := make([]byte, n)
	 104  	if _, err := rand.Read(r); err != nil {
	 105  		panic("rand.Read failed: " + err.Error())
	 106  	}
	 107  	return r
	 108  }
	 109  
	 110  func randomString(n int, rand *rand.Rand) string {
	 111  	b := randomBytes(n, rand)
	 112  	return string(b)
	 113  }
	 114  
	 115  func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	 116  	m := &clientHelloMsg{}
	 117  	m.vers = uint16(rand.Intn(65536))
	 118  	m.random = randomBytes(32, rand)
	 119  	m.sessionId = randomBytes(rand.Intn(32), rand)
	 120  	m.cipherSuites = make([]uint16, rand.Intn(63)+1)
	 121  	for i := 0; i < len(m.cipherSuites); i++ {
	 122  		cs := uint16(rand.Int31())
	 123  		if cs == scsvRenegotiation {
	 124  			cs += 1
	 125  		}
	 126  		m.cipherSuites[i] = cs
	 127  	}
	 128  	m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
	 129  	if rand.Intn(10) > 5 {
	 130  		m.serverName = randomString(rand.Intn(255), rand)
	 131  		for strings.HasSuffix(m.serverName, ".") {
	 132  			m.serverName = m.serverName[:len(m.serverName)-1]
	 133  		}
	 134  	}
	 135  	m.ocspStapling = rand.Intn(10) > 5
	 136  	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
	 137  	m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
	 138  	for i := range m.supportedCurves {
	 139  		m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1)
	 140  	}
	 141  	if rand.Intn(10) > 5 {
	 142  		m.ticketSupported = true
	 143  		if rand.Intn(10) > 5 {
	 144  			m.sessionTicket = randomBytes(rand.Intn(300), rand)
	 145  		} else {
	 146  			m.sessionTicket = make([]byte, 0)
	 147  		}
	 148  	}
	 149  	if rand.Intn(10) > 5 {
	 150  		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
	 151  	}
	 152  	if rand.Intn(10) > 5 {
	 153  		m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms
	 154  	}
	 155  	for i := 0; i < rand.Intn(5); i++ {
	 156  		m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
	 157  	}
	 158  	if rand.Intn(10) > 5 {
	 159  		m.scts = true
	 160  	}
	 161  	if rand.Intn(10) > 5 {
	 162  		m.secureRenegotiationSupported = true
	 163  		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
	 164  	}
	 165  	for i := 0; i < rand.Intn(5); i++ {
	 166  		m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1))
	 167  	}
	 168  	if rand.Intn(10) > 5 {
	 169  		m.cookie = randomBytes(rand.Intn(500)+1, rand)
	 170  	}
	 171  	for i := 0; i < rand.Intn(5); i++ {
	 172  		var ks keyShare
	 173  		ks.group = CurveID(rand.Intn(30000) + 1)
	 174  		ks.data = randomBytes(rand.Intn(200)+1, rand)
	 175  		m.keyShares = append(m.keyShares, ks)
	 176  	}
	 177  	switch rand.Intn(3) {
	 178  	case 1:
	 179  		m.pskModes = []uint8{pskModeDHE}
	 180  	case 2:
	 181  		m.pskModes = []uint8{pskModeDHE, pskModePlain}
	 182  	}
	 183  	for i := 0; i < rand.Intn(5); i++ {
	 184  		var psk pskIdentity
	 185  		psk.obfuscatedTicketAge = uint32(rand.Intn(500000))
	 186  		psk.label = randomBytes(rand.Intn(500)+1, rand)
	 187  		m.pskIdentities = append(m.pskIdentities, psk)
	 188  		m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
	 189  	}
	 190  	if rand.Intn(10) > 5 {
	 191  		m.earlyData = true
	 192  	}
	 193  
	 194  	return reflect.ValueOf(m)
	 195  }
	 196  
	 197  func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	 198  	m := &serverHelloMsg{}
	 199  	m.vers = uint16(rand.Intn(65536))
	 200  	m.random = randomBytes(32, rand)
	 201  	m.sessionId = randomBytes(rand.Intn(32), rand)
	 202  	m.cipherSuite = uint16(rand.Int31())
	 203  	m.compressionMethod = uint8(rand.Intn(256))
	 204  	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
	 205  
	 206  	if rand.Intn(10) > 5 {
	 207  		m.ocspStapling = true
	 208  	}
	 209  	if rand.Intn(10) > 5 {
	 210  		m.ticketSupported = true
	 211  	}
	 212  	if rand.Intn(10) > 5 {
	 213  		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
	 214  	}
	 215  
	 216  	for i := 0; i < rand.Intn(4); i++ {
	 217  		m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
	 218  	}
	 219  
	 220  	if rand.Intn(10) > 5 {
	 221  		m.secureRenegotiationSupported = true
	 222  		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
	 223  	}
	 224  	if rand.Intn(10) > 5 {
	 225  		m.supportedVersion = uint16(rand.Intn(0xffff) + 1)
	 226  	}
	 227  	if rand.Intn(10) > 5 {
	 228  		m.cookie = randomBytes(rand.Intn(500)+1, rand)
	 229  	}
	 230  	if rand.Intn(10) > 5 {
	 231  		for i := 0; i < rand.Intn(5); i++ {
	 232  			m.serverShare.group = CurveID(rand.Intn(30000) + 1)
	 233  			m.serverShare.data = randomBytes(rand.Intn(200)+1, rand)
	 234  		}
	 235  	} else if rand.Intn(10) > 5 {
	 236  		m.selectedGroup = CurveID(rand.Intn(30000) + 1)
	 237  	}
	 238  	if rand.Intn(10) > 5 {
	 239  		m.selectedIdentityPresent = true
	 240  		m.selectedIdentity = uint16(rand.Intn(0xffff))
	 241  	}
	 242  
	 243  	return reflect.ValueOf(m)
	 244  }
	 245  
	 246  func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	 247  	m := &encryptedExtensionsMsg{}
	 248  
	 249  	if rand.Intn(10) > 5 {
	 250  		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
	 251  	}
	 252  
	 253  	return reflect.ValueOf(m)
	 254  }
	 255  
	 256  func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	 257  	m := &certificateMsg{}
	 258  	numCerts := rand.Intn(20)
	 259  	m.certificates = make([][]byte, numCerts)
	 260  	for i := 0; i < numCerts; i++ {
	 261  		m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
	 262  	}
	 263  	return reflect.ValueOf(m)
	 264  }
	 265  
	 266  func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	 267  	m := &certificateRequestMsg{}
	 268  	m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
	 269  	for i := 0; i < rand.Intn(100); i++ {
	 270  		m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
	 271  	}
	 272  	return reflect.ValueOf(m)
	 273  }
	 274  
	 275  func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	 276  	m := &certificateVerifyMsg{}
	 277  	m.hasSignatureAlgorithm = true
	 278  	m.signatureAlgorithm = SignatureScheme(rand.Intn(30000))
	 279  	m.signature = randomBytes(rand.Intn(15)+1, rand)
	 280  	return reflect.ValueOf(m)
	 281  }
	 282  
	 283  func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	 284  	m := &certificateStatusMsg{}
	 285  	m.response = randomBytes(rand.Intn(10)+1, rand)
	 286  	return reflect.ValueOf(m)
	 287  }
	 288  
	 289  func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	 290  	m := &clientKeyExchangeMsg{}
	 291  	m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
	 292  	return reflect.ValueOf(m)
	 293  }
	 294  
	 295  func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	 296  	m := &finishedMsg{}
	 297  	m.verifyData = randomBytes(12, rand)
	 298  	return reflect.ValueOf(m)
	 299  }
	 300  
	 301  func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	 302  	m := &newSessionTicketMsg{}
	 303  	m.ticket = randomBytes(rand.Intn(4), rand)
	 304  	return reflect.ValueOf(m)
	 305  }
	 306  
	 307  func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
	 308  	s := &sessionState{}
	 309  	s.vers = uint16(rand.Intn(10000))
	 310  	s.cipherSuite = uint16(rand.Intn(10000))
	 311  	s.masterSecret = randomBytes(rand.Intn(100)+1, rand)
	 312  	s.createdAt = uint64(rand.Int63())
	 313  	for i := 0; i < rand.Intn(20); i++ {
	 314  		s.certificates = append(s.certificates, randomBytes(rand.Intn(500)+1, rand))
	 315  	}
	 316  	return reflect.ValueOf(s)
	 317  }
	 318  
	 319  func (*sessionStateTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
	 320  	s := &sessionStateTLS13{}
	 321  	s.cipherSuite = uint16(rand.Intn(10000))
	 322  	s.resumptionSecret = randomBytes(rand.Intn(100)+1, rand)
	 323  	s.createdAt = uint64(rand.Int63())
	 324  	for i := 0; i < rand.Intn(2)+1; i++ {
	 325  		s.certificate.Certificate = append(
	 326  			s.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
	 327  	}
	 328  	if rand.Intn(10) > 5 {
	 329  		s.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
	 330  	}
	 331  	if rand.Intn(10) > 5 {
	 332  		for i := 0; i < rand.Intn(2)+1; i++ {
	 333  			s.certificate.SignedCertificateTimestamps = append(
	 334  				s.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
	 335  		}
	 336  	}
	 337  	return reflect.ValueOf(s)
	 338  }
	 339  
	 340  func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	 341  	m := &endOfEarlyDataMsg{}
	 342  	return reflect.ValueOf(m)
	 343  }
	 344  
	 345  func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	 346  	m := &keyUpdateMsg{}
	 347  	m.updateRequested = rand.Intn(10) > 5
	 348  	return reflect.ValueOf(m)
	 349  }
	 350  
	 351  func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
	 352  	m := &newSessionTicketMsgTLS13{}
	 353  	m.lifetime = uint32(rand.Intn(500000))
	 354  	m.ageAdd = uint32(rand.Intn(500000))
	 355  	m.nonce = randomBytes(rand.Intn(100), rand)
	 356  	m.label = randomBytes(rand.Intn(1000), rand)
	 357  	if rand.Intn(10) > 5 {
	 358  		m.maxEarlyData = uint32(rand.Intn(500000))
	 359  	}
	 360  	return reflect.ValueOf(m)
	 361  }
	 362  
	 363  func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
	 364  	m := &certificateRequestMsgTLS13{}
	 365  	if rand.Intn(10) > 5 {
	 366  		m.ocspStapling = true
	 367  	}
	 368  	if rand.Intn(10) > 5 {
	 369  		m.scts = true
	 370  	}
	 371  	if rand.Intn(10) > 5 {
	 372  		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
	 373  	}
	 374  	if rand.Intn(10) > 5 {
	 375  		m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms
	 376  	}
	 377  	if rand.Intn(10) > 5 {
	 378  		m.certificateAuthorities = make([][]byte, 3)
	 379  		for i := 0; i < 3; i++ {
	 380  			m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand)
	 381  		}
	 382  	}
	 383  	return reflect.ValueOf(m)
	 384  }
	 385  
	 386  func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
	 387  	m := &certificateMsgTLS13{}
	 388  	for i := 0; i < rand.Intn(2)+1; i++ {
	 389  		m.certificate.Certificate = append(
	 390  			m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
	 391  	}
	 392  	if rand.Intn(10) > 5 {
	 393  		m.ocspStapling = true
	 394  		m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
	 395  	}
	 396  	if rand.Intn(10) > 5 {
	 397  		m.scts = true
	 398  		for i := 0; i < rand.Intn(2)+1; i++ {
	 399  			m.certificate.SignedCertificateTimestamps = append(
	 400  				m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
	 401  		}
	 402  	}
	 403  	return reflect.ValueOf(m)
	 404  }
	 405  
	 406  func TestRejectEmptySCTList(t *testing.T) {
	 407  	// RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid.
	 408  
	 409  	var random [32]byte
	 410  	sct := []byte{0x42, 0x42, 0x42, 0x42}
	 411  	serverHello := serverHelloMsg{
	 412  		vers:	 VersionTLS12,
	 413  		random: random[:],
	 414  		scts:	 [][]byte{sct},
	 415  	}
	 416  	serverHelloBytes := serverHello.marshal()
	 417  
	 418  	var serverHelloCopy serverHelloMsg
	 419  	if !serverHelloCopy.unmarshal(serverHelloBytes) {
	 420  		t.Fatal("Failed to unmarshal initial message")
	 421  	}
	 422  
	 423  	// Change serverHelloBytes so that the SCT list is empty
	 424  	i := bytes.Index(serverHelloBytes, sct)
	 425  	if i < 0 {
	 426  		t.Fatal("Cannot find SCT in ServerHello")
	 427  	}
	 428  
	 429  	var serverHelloEmptySCT []byte
	 430  	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
	 431  	// Append the extension length and SCT list length for an empty list.
	 432  	serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
	 433  	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
	 434  
	 435  	// Update the handshake message length.
	 436  	serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
	 437  	serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
	 438  	serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
	 439  
	 440  	// Update the extensions length
	 441  	serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
	 442  	serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
	 443  
	 444  	if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
	 445  		t.Fatal("Unmarshaled ServerHello with empty SCT list")
	 446  	}
	 447  }
	 448  
	 449  func TestRejectEmptySCT(t *testing.T) {
	 450  	// Not only must the SCT list be non-empty, but the SCT elements must
	 451  	// not be zero length.
	 452  
	 453  	var random [32]byte
	 454  	serverHello := serverHelloMsg{
	 455  		vers:	 VersionTLS12,
	 456  		random: random[:],
	 457  		scts:	 [][]byte{nil},
	 458  	}
	 459  	serverHelloBytes := serverHello.marshal()
	 460  
	 461  	var serverHelloCopy serverHelloMsg
	 462  	if serverHelloCopy.unmarshal(serverHelloBytes) {
	 463  		t.Fatal("Unmarshaled ServerHello with zero-length SCT")
	 464  	}
	 465  }
	 466  

View as plain text