1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package runner
6
7import (
8	"crypto"
9	"crypto/hkdf"
10	"crypto/hmac"
11	"crypto/md5"
12	"crypto/sha1"
13	"crypto/sha256"
14	"encoding"
15	"hash"
16
17	"golang.org/x/crypto/cryptobyte"
18)
19
// copyHash returns an independent copy of |h|, which must be an instance of
// |hashType|. Writes to the copy do not affect |h|, and vice versa. It panics
// if the state of |h| cannot be serialized or restored into a fresh
// |hashType| instance.
func copyHash(h hash.Hash, hashType crypto.Hash) hash.Hash {
	// While hash.Hash is not copyable, the documentation says all standard
	// library hash.Hash implementations implement BinaryMarshaler and
	// BinaryUnmarshaler interfaces.
	m, ok := h.(encoding.BinaryMarshaler)
	if !ok {
		panic("hash did not implement encoding.BinaryMarshaler")
	}
	data, err := m.MarshalBinary()
	if err != nil {
		panic(err)
	}
	ret := hashType.New()
	u, ok := ret.(encoding.BinaryUnmarshaler)
	if !ok {
		panic("hash did not implement BinaryUnmarshaler")
	}
	if err := u.UnmarshalBinary(data); err != nil {
		panic(err)
	}
	return ret
}
43
// splitPreMasterSecret splits secret into the two halves S1 and S2 described in
// RFC 4346, section 5. When len(secret) is odd, the halves overlap: both
// contain the middle byte. The returned slices alias secret.
func splitPreMasterSecret(secret []byte) (s1, s2 []byte) {
	n := len(secret)
	return secret[:(n+1)/2], secret[n/2:]
}
50
// pHash implements the P_hash function, as defined in RFC 4346, section 5. It
// fills result with HMAC(secret, A(1) + seed) || HMAC(secret, A(2) + seed) ||
// ..., where A(0) = seed and A(i) = HMAC(secret, A(i-1)). The final block is
// truncated to fit result.
func pHash(result, secret, seed []byte, hash func() hash.Hash) {
	mac := hmac.New(hash, secret)
	mac.Write(seed)
	a := mac.Sum(nil) // A(1)

	for out := result; len(out) > 0; {
		mac.Reset()
		mac.Write(a)
		mac.Write(seed)
		// copy stops at the shorter of the two, which truncates the last block.
		out = out[copy(out, mac.Sum(nil)):]

		mac.Reset()
		mac.Write(a)
		a = mac.Sum(nil) // A(i+1)
	}
}
75
76// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, section 5.
77func prf10(result, secret, label, seed []byte) {
78	hashSHA1 := sha1.New
79	hashMD5 := md5.New
80
81	labelAndSeed := make([]byte, len(label)+len(seed))
82	copy(labelAndSeed, label)
83	copy(labelAndSeed[len(label):], seed)
84
85	s1, s2 := splitPreMasterSecret(secret)
86	pHash(result, s1, labelAndSeed, hashMD5)
87	result2 := make([]byte, len(result))
88	pHash(result2, s2, labelAndSeed, hashSHA1)
89
90	for i, b := range result2 {
91		result[i] ^= b
92	}
93}
94
95// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, section 5.
96func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) {
97	return func(result, secret, label, seed []byte) {
98		labelAndSeed := make([]byte, len(label)+len(seed))
99		copy(labelAndSeed, label)
100		copy(labelAndSeed[len(label):], seed)
101
102		pHash(result, secret, labelAndSeed, hashFunc)
103	}
104}
105
// Fixed sizes used by the pre-TLS 1.3 key derivation and Finished messages.
const (
	// tlsRandomLength is the length of the client and server random values.
	tlsRandomLength = 32

	// masterSecretLength is the length of a master secret.
	masterSecretLength = 48

	// finishedVerifyLength is the length of verify_data in a Finished message.
	finishedVerifyLength = 12
)
111
// PRF labels for the TLS 1.0-1.2 master secret, key block and Finished
// verify_data computations.
var (
	masterSecretLabel         = []byte("master secret")
	extendedMasterSecretLabel = []byte("extended master secret")
	keyExpansionLabel         = []byte("key expansion")
	clientFinishedLabel       = []byte("client finished")
	serverFinishedLabel       = []byte("server finished")
)

// finishedLabel is the HKDF-Expand-Label label used to derive the TLS 1.3
// finished key.
var finishedLabel = []byte("finished")

// Prefixes hashed into the TLS Channel ID signature input.
var (
	channelIDLabel       = []byte("TLS Channel ID signature\x00")
	channelIDResumeLabel = []byte("Resumption\x00")
)
120
121func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) {
122	switch version {
123	case VersionTLS10, VersionTLS11:
124		return prf10
125	case VersionTLS12:
126		return prf12(suite.hash().New)
127	}
128	panic("unknown version")
129}
130
131// masterFromPreMasterSecret generates the master secret from the pre-master
132// secret. See http://tools.ietf.org/html/rfc5246#section-8.1
133func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte {
134	var seed [tlsRandomLength * 2]byte
135	copy(seed[0:len(clientRandom)], clientRandom)
136	copy(seed[len(clientRandom):], serverRandom)
137	masterSecret := make([]byte, masterSecretLength)
138	prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed[0:])
139	return masterSecret
140}
141
142// extendedMasterFromPreMasterSecret generates the master secret from the
143// pre-master secret when the Triple Handshake fix is in effect. See
144// https://tools.ietf.org/html/rfc7627
145func extendedMasterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret []byte, h finishedHash) []byte {
146	masterSecret := make([]byte, masterSecretLength)
147	prfForVersion(version, suite)(masterSecret, preMasterSecret, extendedMasterSecretLabel, h.Sum())
148	return masterSecret
149}
150
151// keysFromMasterSecret generates the connection keys from the master
152// secret, given the lengths of the MAC key, cipher key and IV, as defined in
153// RFC 2246, section 6.3.
154func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) {
155	var seed [tlsRandomLength * 2]byte
156	copy(seed[0:len(clientRandom)], serverRandom)
157	copy(seed[len(serverRandom):], clientRandom)
158
159	n := 2*macLen + 2*keyLen + 2*ivLen
160	keyMaterial := make([]byte, n)
161	prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed[0:])
162	clientMAC = keyMaterial[:macLen]
163	keyMaterial = keyMaterial[macLen:]
164	serverMAC = keyMaterial[:macLen]
165	keyMaterial = keyMaterial[macLen:]
166	clientKey = keyMaterial[:keyLen]
167	keyMaterial = keyMaterial[keyLen:]
168	serverKey = keyMaterial[:keyLen]
169	keyMaterial = keyMaterial[keyLen:]
170	clientIV = keyMaterial[:ivLen]
171	keyMaterial = keyMaterial[ivLen:]
172	serverIV = keyMaterial[:ivLen]
173	return
174}
175
176func newFinishedHash(wireVersion uint16, isDTLS bool, cipherSuite *cipherSuite) finishedHash {
177	version, ok := wireToVersion(wireVersion, isDTLS)
178	if !ok {
179		panic("unknown version")
180	}
181
182	var ret finishedHash
183	if version >= VersionTLS12 {
184		ret.hash = cipherSuite.hash().New()
185
186		if version == VersionTLS12 {
187			ret.prf = prf12(cipherSuite.hash().New)
188		} else {
189			ret.secret = make([]byte, ret.hash.Size())
190		}
191	} else {
192		ret.hash = sha1.New()
193		ret.md5 = md5.New()
194
195		ret.prf = prf10
196	}
197
198	ret.suite = cipherSuite
199	ret.buffer = []byte{}
200	ret.version = version
201	ret.wireVersion = wireVersion
202	ret.isDTLS = isDTLS
203	return ret
204}
205
// A finishedHash calculates the hash of a set of handshake messages suitable
// for including in a Finished message. It also carries the state used to
// derive values from that transcript: the PRF before TLS 1.3, and the running
// key-schedule secret in TLS 1.3.
type finishedHash struct {
	// suite is the cipher suite whose hash() selects the transcript hash in
	// TLS 1.2 and later.
	suite *cipherSuite

	// hash maintains a running hash of handshake messages. In TLS 1.2 and up,
	// the hash is determined from suite.hash(). In TLS 1.0 and 1.1, this is the
	// SHA-1 half of the MD5/SHA-1 concatenation.
	hash hash.Hash

	// md5 is the MD5 half of the TLS 1.0 and 1.1 MD5/SHA1 concatenation.
	// newFinishedHash leaves it nil for TLS 1.2 and later.
	md5 hash.Hash

	// buffer, when non-nil, accumulates every byte passed to Write.
	// newFinishedHash starts it empty (non-nil), and discardHandshakeBuffer
	// sets it to nil to stop buffering. In TLS 1.2, a full buffer is required.
	buffer []byte

	// version is the normalized protocol version from wireToVersion, and
	// wireVersion is the on-the-wire value it came from. isDTLS selects the
	// DTLS transcript format in WriteHandshake and the DTLS label prefix in
	// hkdfExpandLabel. prf computes pre-TLS 1.3 verify_data;
	// newFinishedHash leaves it nil for TLS 1.3.
	version     uint16
	wireVersion uint16
	isDTLS      bool
	prf         func(result, secret, label, seed []byte)

	// secret, in TLS 1.3, is the running input secret.
	secret []byte
}
230
231func (h *finishedHash) UpdateForHelloRetryRequest() {
232	data := cryptobyte.NewBuilder(nil)
233	data.AddUint8(typeMessageHash)
234	data.AddUint24(uint32(h.hash.Size()))
235	data.AddBytes(h.Sum())
236	h.hash = h.suite.hash().New()
237	if h.buffer != nil {
238		h.buffer = []byte{}
239	}
240	h.Write(data.BytesOrPanic())
241}
242
243func (h *finishedHash) Write(msg []byte) (n int, err error) {
244	h.hash.Write(msg)
245
246	if h.version < VersionTLS12 {
247		h.md5.Write(msg)
248	}
249
250	if h.buffer != nil {
251		h.buffer = append(h.buffer, msg...)
252	}
253
254	return len(msg), nil
255}
256
257// WriteHandshake appends |msg| to the hash, which must be a serialized
258// handshake message with a TLS header. In DTLS, the header is rewritten to a
259// DTLS header with |seqno| as the sequence number.
260func (h *finishedHash) WriteHandshake(msg []byte, seqno uint16) {
261	if h.isDTLS && h.version <= VersionTLS12 {
262		// This is somewhat hacky. DTLS <= 1.2 hashes a slightly different format. (DTLS 1.3 uses the same format as TLS.)
263		// First, the TLS header.
264		h.Write(msg[:4])
265		// Then the sequence number and reassembled fragment offset (always 0).
266		h.Write([]byte{byte(seqno >> 8), byte(seqno), 0, 0, 0})
267		// Then the reassembled fragment (always equal to the message length).
268		h.Write(msg[1:4])
269		// And then the message body.
270		h.Write(msg[4:])
271	} else {
272		h.Write(msg)
273	}
274}
275
276func (h finishedHash) Sum() []byte {
277	if h.version >= VersionTLS12 {
278		return h.hash.Sum(nil)
279	}
280
281	out := make([]byte, 0, md5.Size+sha1.Size)
282	out = h.md5.Sum(out)
283	return h.hash.Sum(out)
284}
285
286// clientSum returns the contents of the verify_data member of a client's
287// Finished message.
288func (h finishedHash) clientSum(baseKey []byte) []byte {
289	if h.version < VersionTLS13 {
290		out := make([]byte, finishedVerifyLength)
291		h.prf(out, baseKey, clientFinishedLabel, h.Sum())
292		return out
293	}
294
295	clientFinishedKey := hkdfExpandLabel(h.suite.hash(), baseKey, finishedLabel, nil, h.hash.Size(), h.isDTLS)
296	finishedHMAC := hmac.New(h.suite.hash().New, clientFinishedKey)
297	finishedHMAC.Write(h.appendContextHashes(nil))
298	return finishedHMAC.Sum(nil)
299}
300
301// serverSum returns the contents of the verify_data member of a server's
302// Finished message.
303func (h finishedHash) serverSum(baseKey []byte) []byte {
304	if h.version < VersionTLS13 {
305		out := make([]byte, finishedVerifyLength)
306		h.prf(out, baseKey, serverFinishedLabel, h.Sum())
307		return out
308	}
309
310	serverFinishedKey := hkdfExpandLabel(h.suite.hash(), baseKey, finishedLabel, nil, h.hash.Size(), h.isDTLS)
311	finishedHMAC := hmac.New(h.suite.hash().New, serverFinishedKey)
312	finishedHMAC.Write(h.appendContextHashes(nil))
313	return finishedHMAC.Sum(nil)
314}
315
316// hashForChannelID returns the hash to be signed for TLS Channel
317// ID. If a resumption, resumeHash has the previous handshake
318// hash. Otherwise, it is nil.
319func (h finishedHash) hashForChannelID(resumeHash []byte) []byte {
320	hash := sha256.New()
321	hash.Write(channelIDLabel)
322	if resumeHash != nil {
323		hash.Write(channelIDResumeLabel)
324		hash.Write(resumeHash)
325	}
326	hash.Write(h.Sum())
327	return hash.Sum(nil)
328}
329
330// discardHandshakeBuffer is called when there is no more need to
331// buffer the entirety of the handshake messages.
332func (h *finishedHash) discardHandshakeBuffer() {
333	h.buffer = nil
334}
335
336// zeroSecretTLS13 returns the default all zeros secret for TLS 1.3, used when a
337// given secret is not available in the handshake. See RFC 8446, section 7.1.
338func (h *finishedHash) zeroSecret() []byte {
339	return make([]byte, h.hash.Size())
340}
341
342// addEntropy incorporates ikm into the running TLS 1.3 secret with HKDF-Expand.
343func (h *finishedHash) addEntropy(ikm []byte) {
344	var err error
345	h.secret, err = hkdf.Extract(h.suite.hash().New, ikm, h.secret)
346	if err != nil {
347		panic(err)
348	}
349}
350
351func (h *finishedHash) nextSecret() {
352	h.secret = hkdfExpandLabel(h.suite.hash(), h.secret, []byte("derived"), h.suite.hash().New().Sum(nil), h.hash.Size(), h.isDTLS)
353}
354
// hkdfExpandLabel implements TLS 1.3's HKDF-Expand-Label function, as defined
// in section 7.1 of RFC 8446:
//
//	struct {
//	    uint16 length = Length;
//	    opaque label<7..255> = "tls13 " + Label;
//	    opaque context<0..255> = Context;
//	} HkdfLabel;
//
// When isDTLS is set, the prefix "dtls13" is used instead of "tls13 ".
//
// It panics in three cases: the prefixed label or hashValue does not fit in
// its one-byte length prefix, length does not fit in a uint16, or HKDF-Expand
// itself fails.
func hkdfExpandLabel(hash crypto.Hash, secret, label, hashValue []byte, length int, isDTLS bool) []byte {
	versionLabel := []byte("tls13 ")
	if isDTLS {
		versionLabel = []byte("dtls13")
	}

	// The label length byte covers the prefix and the label together, so it is
	// the combined length that must fit in a byte. Checking len(label) alone
	// would silently wrap the length byte for labels of 250-255 bytes.
	if len(versionLabel)+len(label) > 255 || len(hashValue) > 255 {
		panic("hkdfExpandLabel: label or hashValue too long")
	}
	if length < 0 || length > 0xffff {
		panic("hkdfExpandLabel: length out of range")
	}

	hkdfLabel := make([]byte, 0, 2+1+len(versionLabel)+len(label)+1+len(hashValue))
	hkdfLabel = append(hkdfLabel, byte(length>>8), byte(length))
	hkdfLabel = append(hkdfLabel, byte(len(versionLabel)+len(label)))
	hkdfLabel = append(hkdfLabel, versionLabel...)
	hkdfLabel = append(hkdfLabel, label...)
	hkdfLabel = append(hkdfLabel, byte(len(hashValue)))
	hkdfLabel = append(hkdfLabel, hashValue...)

	ret, err := hkdf.Expand(hash.New, secret, string(hkdfLabel), length)
	if err != nil {
		panic(err)
	}
	return ret
}
384
385// appendContextHashes returns the concatenation of the handshake hash and the
386// resumption context hash, as used in TLS 1.3.
387func (h *finishedHash) appendContextHashes(b []byte) []byte {
388	b = h.hash.Sum(b)
389	return b
390}
391
// TLS 1.3 PSK binder key labels (RFC 8446, section 7.1).
var (
	externalPSKBinderLabel   = []byte("ext binder")
	resumptionPSKBinderLabel = []byte("res binder")
)

// TLS 1.3 traffic secret labels. applicationTrafficLabel derives the next
// application traffic secret on KeyUpdate.
var (
	earlyTrafficLabel             = []byte("c e traffic")
	clientHandshakeTrafficLabel   = []byte("c hs traffic")
	serverHandshakeTrafficLabel   = []byte("s hs traffic")
	clientApplicationTrafficLabel = []byte("c ap traffic")
	serverApplicationTrafficLabel = []byte("s ap traffic")
	applicationTrafficLabel       = []byte("traffic upd")
)

// TLS 1.3 exporter and resumption master secret labels.
var (
	earlyExporterLabel = []byte("e exp master")
	exporterLabel      = []byte("exp master")
	resumptionLabel    = []byte("res master")
)

// resumptionPSKLabel derives a session PSK from the resumption master secret.
var resumptionPSKLabel = []byte("resumption")

// Encrypted ClientHello accept confirmation labels, for ServerHello and
// HelloRetryRequest respectively.
var (
	echAcceptConfirmationLabel    = []byte("ech accept confirmation")
	echAcceptConfirmationHRRLabel = []byte("hrr ech accept confirmation")
)
410
411// deriveSecret implements TLS 1.3's Derive-Secret function, as defined in
412// section 7.1 of RFC8446.
413func (h *finishedHash) deriveSecret(label []byte) []byte {
414	return hkdfExpandLabel(h.suite.hash(), h.secret, label, h.appendContextHashes(nil), h.hash.Size(), h.isDTLS)
415}
416
417// echAcceptConfirmation computes the ECH accept confirmation signal, as defined
418// in sections 7.2 and 7.2.1 of draft-ietf-tls-esni-13. The transcript hash is
419// computed by concatenating |h| with |extraMessages|.
420func (h *finishedHash) echAcceptConfirmation(clientRandom, label, extraMessages []byte) []byte {
421	secret, err := hkdf.Extract(h.suite.hash().New, clientRandom, h.zeroSecret())
422	if err != nil {
423		panic(err)
424	}
425	hashCopy := copyHash(h.hash, h.suite.hash())
426	hashCopy.Write(extraMessages)
427	return hkdfExpandLabel(h.suite.hash(), secret, label, hashCopy.Sum(nil), echAcceptConfirmationLength, h.isDTLS)
428}
429
// Context strings for TLS 1.3 signatures: the client and server
// CertificateVerify contexts, and the one for Channel ID.
var clientCertificateVerifyContextTLS13 = []byte("TLS 1.3, client CertificateVerify")
var serverCertificateVerifyContextTLS13 = []byte("TLS 1.3, server CertificateVerify")
var channelIDContextTLS13 = []byte("TLS 1.3, Channel ID")
436
437// certificateVerifyMessage returns the input to be signed for CertificateVerify
438// in TLS 1.3.
439func (h *finishedHash) certificateVerifyInput(context []byte) []byte {
440	const paddingLen = 64
441	b := make([]byte, paddingLen, paddingLen+len(context)+1+2*h.hash.Size())
442	for i := 0; i < paddingLen; i++ {
443		b[i] = 32
444	}
445	b = append(b, context...)
446	b = append(b, 0)
447	b = h.appendContextHashes(b)
448	return b
449}
450
// trafficDirection distinguishes the client-write and server-write directions
// of a connection's traffic keys.
type trafficDirection int

const (
	// clientWrite is the client-to-server direction.
	clientWrite trafficDirection = 0

	// serverWrite is the server-to-client direction.
	serverWrite trafficDirection = 1
)
457
// keyTLS13 and ivTLS13 are the HKDF-Expand-Label labels that deriveTrafficAEAD
// uses to derive the traffic key and IV from a TLS 1.3 traffic secret.
var keyTLS13 = []byte("key")
var ivTLS13 = []byte("iv")
462
463// deriveTrafficAEAD derives traffic keys and constructs an AEAD given a traffic
464// secret.
465func deriveTrafficAEAD(version uint16, suite *cipherSuite, secret []byte, side trafficDirection, isDTLS bool) any {
466	key := hkdfExpandLabel(suite.hash(), secret, keyTLS13, nil, suite.keyLen, isDTLS)
467	iv := hkdfExpandLabel(suite.hash(), secret, ivTLS13, nil, suite.ivLen(version), isDTLS)
468
469	return suite.aead(version, key, iv)
470}
471
472func updateTrafficSecret(hash crypto.Hash, version uint16, secret []byte, isDTLS bool) []byte {
473	return hkdfExpandLabel(hash, secret, applicationTrafficLabel, nil, hash.Size(), isDTLS)
474}
475
476func computePSKBinder(psk []byte, version uint16, isDTLS bool, label []byte, cipherSuite *cipherSuite, clientHello, helloRetryRequest, truncatedHello []byte) []byte {
477	finishedHash := newFinishedHash(version, isDTLS, cipherSuite)
478	finishedHash.addEntropy(psk)
479	binderKey := finishedHash.deriveSecret(label)
480	finishedHash.Write(clientHello)
481	if len(helloRetryRequest) != 0 {
482		finishedHash.UpdateForHelloRetryRequest()
483	}
484	finishedHash.Write(helloRetryRequest)
485	finishedHash.Write(truncatedHello)
486	return finishedHash.clientSum(binderKey)
487}
488
489func deriveSessionPSK(suite *cipherSuite, version uint16, masterSecret []byte, nonce []byte, isDTLS bool) []byte {
490	hash := suite.hash()
491	return hkdfExpandLabel(hash, masterSecret, resumptionPSKLabel, nonce, hash.Size(), isDTLS)
492}
493