1// Copyright 2023 The BoringSSL Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package kyber
16
17// This code is ported from kyber.c.
18
19import (
20	"crypto/sha3"
21	"crypto/subtle"
22	"io"
23)
24
// Sizes, in bytes, of the serialized objects handled by this package.
const (
	// CiphertextSize is the length of the ciphertext returned by Encap and
	// consumed by Decap.
	CiphertextSize = 1088
	// PublicKeySize is the length of a marshalled PublicKey.
	PublicKeySize  = 1184
	// PrivateKeySize is the length of a marshalled PrivateKey.
	PrivateKeySize = 2400
)
30
// Internal parameters of the scheme and of the modular arithmetic helpers.
//
//   - degree is the number of coefficients in a scalar.
//   - rank is the number of scalars in a vector (and the matrix dimension).
//   - prime is the modulus for all coefficient arithmetic; log2Prime is the
//     number of bits used to encode an uncompressed coefficient.
//   - halfPrime is (prime-1)/2 = 1664, the rounding threshold used by compress.
//   - du and dv are the bit widths used when compressing the vector and scalar
//     parts of a ciphertext.
//   - inverseDegree is 128^-1 mod prime (128 * 3303 = 127 * 3329 + 1); it is
//     the final scaling factor applied by inverseNTT.
//   - encodedVectorSize and compressedVectorSize are the byte lengths of a
//     vector encoded with log2Prime and du bits per coefficient respectively.
//   - barrettMultiplier is floor(2^barrettShift / prime), the constant used to
//     estimate quotients by prime without a division.
const (
	degree               = 256
	rank                 = 3
	prime                = 3329
	log2Prime            = 12
	halfPrime            = (prime - 1) / 2
	du                   = 10
	dv                   = 4
	inverseDegree        = 3303
	encodedVectorSize    = log2Prime * degree / 8 * rank
	compressedVectorSize = du * rank * degree / 8
	barrettMultiplier    = 5039
	barrettShift         = 24
)
45
46func reduceOnce(x uint16) uint16 {
47	if x >= 2*prime {
48		panic("reduce_once: value out of range")
49	}
50	subtracted := x - prime
51	mask := 0 - (subtracted >> 15)
52	return (mask & x) | (^mask & subtracted)
53}
54
55func reduce(x uint32) uint16 {
56	if x >= prime+2*prime*prime {
57		panic("reduce: value out of range")
58	}
59	product := uint64(x) * barrettMultiplier
60	quotient := uint32(product >> barrettShift)
61	remainder := uint32(x) - quotient*prime
62	return reduceOnce(uint16(remainder))
63}
64
// lt returns 0xff..f if a < b and 0 otherwise, comparing a and b as unsigned
// values without branching.
//
// The most significant bit of a^((a^b)|((a-b)^a)) is set exactly when a < b.
// Reinterpreting that word as signed and shifting right arithmetically by 31
// copies the bit into every position, which yields the documented mask.
// (Negating the shifted value, as an earlier version did, produced 1 rather
// than the all-ones mask.)
func lt(a, b uint32) uint32 {
	msb := a ^ ((a ^ b) | ((a - b) ^ a))
	return uint32(int32(msb) >> 31)
}
69
70// Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
71// numbers close to each other together. The formula used is
72// round(2^|bits|/prime*x) mod 2^|bits|.
73// Uses Barrett reduction to achieve constant time. Since we need both the
74// remainder (for rounding) and the quotient (as the result), we cannot use
75// |reduce| here, but need to do the Barrett reduction directly.
76func compress(x uint16, bits int) uint16 {
77	product := uint32(x) << bits
78	quotient := uint32((uint64(product) * barrettMultiplier) >> barrettShift)
79	remainder := product - quotient*prime
80
81	// Adjust the quotient to round correctly:
82	//   0 <= remainder <= halfPrime round to 0
83	//   halfPrime < remainder <= prime + halfPrime round to 1
84	//   prime + halfPrime < remainder < 2 * prime round to 2
85	quotient += 1 & lt(halfPrime, remainder)
86	quotient += 1 & lt(prime+halfPrime, remainder)
87	return uint16(quotient) & ((1 << bits) - 1)
88}
89
90func decompress(x uint16, bits int) uint16 {
91	product := uint32(x) * prime
92	power := uint32(1) << bits
93	// This is |product| % power, since |power| is a power of 2.
94	remainder := product & (power - 1)
95	// This is |product| / power, since |power| is a power of 2.
96	lower := product >> bits
97	// The rounding logic works since the first half of numbers mod |power| have a
98	// 0 as first bit, and the second half has a 1 as first bit, since |power| is
99	// a power of 2. As a 12 bit number, |remainder| is always positive, so we
100	// will shift in 0s for a right shift.
101	return uint16(lower + (remainder >> (bits - 1)))
102}
103
// scalar is an array of degree (256) 16-bit coefficients. The methods below
// treat it as a polynomial, either in plain or NTT form depending on which
// transforms have been applied.
type scalar [degree]uint16
105
106func (s *scalar) zero() {
107	clear(s[:])
108}
109
110// This bit of Python will be referenced in some of the following comments:
111//
112// p = 3329
113//
114// def bitreverse(i):
115//     ret = 0
116//     for n in range(7):
117//         bit = i & 1
118//         ret <<= 1
119//         ret |= bit
120//         i >>= 1
121//     return ret
122
// nttRoots holds the per-group multipliers used by scalar.ntt. Per the
// formula from the original C source (using the Python helper above):
//
// kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
var nttRoots = [128]uint16{
	1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797,
	2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
	1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756,
	1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915,
	2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
	2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100,
	1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789,
	1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641,
	1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757,
	2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
	1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
}
137
138func (s *scalar) ntt() {
139	offset := degree
140	for step := 1; step < degree/2; step <<= 1 {
141		offset >>= 1
142		k := 0
143		for i := 0; i < step; i++ {
144			stepRoot := uint32(nttRoots[i+step])
145			for j := k; j < k+offset; j++ {
146				odd := reduce(stepRoot * uint32(s[j+offset]))
147				even := s[j]
148				s[j] = reduceOnce(odd + even)
149				s[j+offset] = reduceOnce(even - odd + prime)
150			}
151			k += 2 * offset
152		}
153	}
154}
155
// inverseNTTRoots holds the per-group multipliers used by scalar.inverseNTT.
// Per the formula from the original C source (bitreverse and p are defined in
// the Python snippet earlier in this file):
//
// kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
var inverseNTTRoots = [128]uint16{
	1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543,
	2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447, 2794, 1235, 1903,
	1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855,
	2681, 1848, 712, 682, 927, 1795, 461, 1891, 2877, 2522, 1894, 1010,
	1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132,
	1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607,
	2443, 554, 1179, 2186, 2303, 2926, 2237, 525, 735, 863, 2768, 1230,
	2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745,
	2688, 3061, 992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482,
	1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096, 48, 667, 1920,
	2229, 1041, 2606, 1692, 680, 2746, 568, 3312,
}
170
171func (s *scalar) inverseNTT() {
172	step := degree / 2
173	for offset := 2; offset < degree; offset <<= 1 {
174		step >>= 1
175		k := 0
176		for i := 0; i < step; i++ {
177			stepRoot := uint32(inverseNTTRoots[i+step])
178			for j := k; j < k+offset; j++ {
179				odd := s[j+offset]
180				even := s[j]
181				s[j] = reduceOnce(odd + even)
182				s[j+offset] = reduce(stepRoot * uint32(even-odd+prime))
183			}
184			k += 2 * offset
185		}
186	}
187	for i := range s {
188		s[i] = reduce(uint32(s[i]) * inverseDegree)
189	}
190}
191
192func (s *scalar) add(b *scalar) {
193	for i := range s {
194		s[i] = reduceOnce(s[i] + b[i])
195	}
196}
197
198func (s *scalar) sub(b *scalar) {
199	for i := range s {
200		s[i] = reduceOnce(s[i] - b[i] + prime)
201	}
202}
203
// modRoots holds, for each pair of coefficients, the constant that scalar.mult
// multiplies into the product of the two odd-index coefficients. Per the
// formula from the original C source (bitreverse and p are defined in the
// Python snippet earlier in this file):
//
// kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
var modRoots = [128]uint16{
	17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606,
	2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096,
	756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678,
	2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
	939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992,
	268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
	375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010,
	2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735,
	2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179,
	2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
	2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
}
218
219func (s *scalar) mult(a, b *scalar) {
220	for i := 0; i < degree/2; i++ {
221		realReal := uint32(a[2*i]) * uint32(b[2*i])
222		imgImg := uint32(a[2*i+1]) * uint32(b[2*i+1])
223		realImg := uint32(a[2*i]) * uint32(b[2*i+1])
224		imgReal := uint32(a[2*i+1]) * uint32(b[2*i])
225		s[2*i] = reduce(realReal + uint32(reduce(imgImg))*uint32(modRoots[i]))
226		s[2*i+1] = reduce(imgReal + realImg)
227	}
228}
229
230func (s *scalar) innerProduct(left, right *vector) {
231	s.zero()
232	var product scalar
233	for i := range left {
234		product.mult(&left[i], &right[i])
235		s.add(&product)
236	}
237}
238
239func (s *scalar) fromKeccakVartime(keccak io.Reader) {
240	var buf [3]byte
241	for i := 0; i < len(s); {
242		keccak.Read(buf[:])
243		d1 := uint16(buf[0]) + 256*uint16(buf[1]%16)
244		d2 := uint16(buf[1])/16 + 16*uint16(buf[2])
245		if d1 < prime {
246			s[i] = d1
247			i++
248		}
249		if d2 < prime && i < len(s) {
250			s[i] = d2
251			i++
252		}
253	}
254}
255
256func (s *scalar) centeredBinomialEta2(input *[33]byte) {
257	entropy := sha3.SumSHAKE256(input[:], 128)
258
259	for i := 0; i < len(s); i += 2 {
260		b := uint16(entropy[i/2])
261
262		value := uint16(prime)
263		value += (b & 1) + ((b >> 1) & 1)
264		value -= ((b >> 2) & 1) + ((b >> 3) & 1)
265		s[i] = reduceOnce(value)
266
267		b >>= 4
268		value = prime
269		value += (b & 1) + ((b >> 1) & 1)
270		value -= ((b >> 2) & 1) + ((b >> 3) & 1)
271		s[i+1] = reduceOnce(value)
272	}
273}
274
// masks[n-1] has the low n bits set, for n from 1 to 8. encode and decode use
// it to select the bits moved between a coefficient and a byte.
var masks = [8]uint16{0x01, 0x03, 0x07, 0x0f, 0x1f, 0x3f, 0x7f, 0xff}
276
277func (s *scalar) encode(out []byte, bits int) []byte {
278	var outByte byte
279	outByteBits := 0
280
281	for i := range s {
282		element := s[i]
283		elementBitsDone := 0
284
285		for elementBitsDone < bits {
286			chunkBits := bits - elementBitsDone
287			outBitsRemaining := 8 - outByteBits
288			if chunkBits >= outBitsRemaining {
289				chunkBits = outBitsRemaining
290				outByte |= byte(element&masks[chunkBits-1]) << outByteBits
291				out[0] = outByte
292				out = out[1:]
293				outByteBits = 0
294				outByte = 0
295			} else {
296				outByte |= byte(element&masks[chunkBits-1]) << outByteBits
297				outByteBits += chunkBits
298			}
299
300			elementBitsDone += chunkBits
301			element >>= chunkBits
302		}
303	}
304
305	if outByteBits > 0 {
306		out[0] = outByte
307		out = out[1:]
308	}
309
310	return out
311}
312
313func (s *scalar) decode(in []byte, bits int) ([]byte, bool) {
314	var inByte byte
315	inByteBitsLeft := 0
316
317	for i := range s {
318		var element uint16
319		elementBitsDone := 0
320
321		for elementBitsDone < bits {
322			if inByteBitsLeft == 0 {
323				inByte = in[0]
324				in = in[1:]
325				inByteBitsLeft = 8
326			}
327
328			chunkBits := bits - elementBitsDone
329			if chunkBits > inByteBitsLeft {
330				chunkBits = inByteBitsLeft
331			}
332
333			element |= (uint16(inByte) & masks[chunkBits-1]) << elementBitsDone
334			inByteBitsLeft -= chunkBits
335			inByte >>= chunkBits
336
337			elementBitsDone += chunkBits
338		}
339
340		if element >= prime {
341			return nil, false
342		}
343		s[i] = element
344	}
345
346	return in, true
347}
348
349func (s *scalar) compress(bits int) {
350	for i := range s {
351		s[i] = compress(s[i], bits)
352	}
353}
354
355func (s *scalar) decompress(bits int) {
356	for i := range s {
357		s[i] = decompress(s[i], bits)
358	}
359}
360
// vector is a list of rank (3) scalars. Its methods apply the corresponding
// scalar operation to every entry.
type vector [rank]scalar
362
363func (v *vector) zero() {
364	for i := range v {
365		v[i].zero()
366	}
367}
368
369func (v *vector) ntt() {
370	for i := range v {
371		v[i].ntt()
372	}
373}
374
375func (v *vector) inverseNTT() {
376	for i := range v {
377		v[i].inverseNTT()
378	}
379}
380
381func (v *vector) add(b *vector) {
382	for i := range v {
383		v[i].add(&b[i])
384	}
385}
386
387func (out *vector) mult(m *matrix, v *vector) {
388	out.zero()
389	var product scalar
390	for i := 0; i < rank; i++ {
391		for j := 0; j < rank; j++ {
392			product.mult(&m[i][j], &v[j])
393			out[i].add(&product)
394		}
395	}
396}
397
398func (out *vector) multTranspose(m *matrix, v *vector) {
399	out.zero()
400	var product scalar
401	for i := 0; i < rank; i++ {
402		for j := 0; j < rank; j++ {
403			product.mult(&m[j][i], &v[j])
404			out[i].add(&product)
405		}
406	}
407}
408
409func (v *vector) generateSecretEta2(counter *byte, seed *[32]byte) {
410	var input [33]byte
411	copy(input[:], seed[:])
412	for i := range v {
413		input[32] = *counter
414		*counter++
415		v[i].centeredBinomialEta2(&input)
416	}
417}
418
419func (v *vector) encode(out []byte, bits int) []byte {
420	for i := range v {
421		out = v[i].encode(out, bits)
422	}
423	return out
424}
425
426func (v *vector) decode(out []byte, bits int) ([]byte, bool) {
427	var ok bool
428	for i := range v {
429		out, ok = v[i].decode(out, bits)
430		if !ok {
431			return nil, false
432		}
433	}
434
435	return out, true
436}
437
438func (v *vector) compress(bits int) {
439	for i := range v {
440		v[i].compress(bits)
441	}
442}
443
444func (v *vector) decompress(bits int) {
445	for i := range v {
446		v[i].decompress(bits)
447	}
448}
449
// matrix is a rank-by-rank (3x3) grid of scalars, indexed as m[row][column].
type matrix [rank][rank]scalar
451
452func (m *matrix) expand(rho *[32]byte) {
453	shake := sha3.NewSHAKE128()
454
455	var input [34]byte
456	copy(input[:], rho[:])
457
458	for i := 0; i < rank; i++ {
459		for j := 0; j < rank; j++ {
460			input[32] = byte(i)
461			input[33] = byte(j)
462
463			shake.Reset()
464			shake.Write(input[:])
465			m[i][j].fromKeccakVartime(shake)
466		}
467	}
468}
469
// PublicKey is a parsed public key, ready for use with Encap.
type PublicKey struct {
	// t is the vector encoded in the leading bytes of the marshalled key.
	t             vector
	// rho is the 32-byte seed stored after t, from which m is expanded.
	rho           [32]byte
	// publicKeyHash is the SHA3-256 hash of the marshalled public key.
	publicKeyHash [32]byte
	// m is the matrix produced by expanding rho.
	m             matrix
}
476
477func UnmarshalPublicKey(data *[PublicKeySize]byte) (*PublicKey, bool) {
478	var ret PublicKey
479	ret.publicKeyHash = sha3.Sum256(data[:])
480	in, ok := ret.t.decode(data[:], log2Prime)
481	if !ok {
482		return nil, false
483	}
484	copy(ret.rho[:], in)
485	ret.m.expand(&ret.rho)
486	return &ret, true
487}
488
489func (pub *PublicKey) Marshal() *[PublicKeySize]byte {
490	var ret [PublicKeySize]byte
491	out := pub.t.encode(ret[:], log2Prime)
492	copy(out, pub.rho[:])
493	return &ret
494}
495
496func (pub *PublicKey) encryptCPA(message, entropy *[32]byte) *[CiphertextSize]byte {
497	var counter uint8
498	var secret, error vector
499	secret.generateSecretEta2(&counter, entropy)
500	error.generateSecretEta2(&counter, entropy)
501	secret.ntt()
502
503	var input [33]byte
504	copy(input[:], entropy[:])
505	input[32] = counter
506	var scalarError scalar
507	scalarError.centeredBinomialEta2(&input)
508
509	var u vector
510	u.mult(&pub.m, &secret)
511	u.inverseNTT()
512	u.add(&error)
513
514	var v scalar
515	v.innerProduct(&pub.t, &secret)
516	v.inverseNTT()
517	v.add(&scalarError)
518
519	out := make([]byte, CiphertextSize)
520	var expandedMessage scalar
521	expandedMessage.decode(message[:], 1)
522	expandedMessage.decompress(1)
523	v.add(&expandedMessage)
524	u.compress(du)
525	it := u.encode(out, du)
526	v.compress(dv)
527	v.encode(it, dv)
528	return (*[CiphertextSize]byte)(out)
529}
530
531func (pub *PublicKey) Encap(outSharedSecret []byte, entropy *[32]byte) *[CiphertextSize]byte {
532	var input [64]byte
533	copy(input[:], entropy[:])
534	copy(input[32:], pub.publicKeyHash[:])
535	prekeyAndRandomness := sha3.Sum512(input[:])
536	ciphertext := pub.encryptCPA(entropy, (*[32]byte)(prekeyAndRandomness[32:]))
537	ciphertextHash := sha3.Sum256(ciphertext[:])
538	copy(prekeyAndRandomness[32:], ciphertextHash[:])
539	copy(outSharedSecret, sha3.SumSHAKE256(prekeyAndRandomness[:], len(outSharedSecret)))
540	return ciphertext
541}
542
// PrivateKey is a private key together with its embedded public key.
type PrivateKey struct {
	PublicKey
	// s is the secret vector, kept in NTT form.
	s               vector
	// foFailureSecret replaces the derived pre-key in Decap when the
	// re-encrypted ciphertext does not match the one received.
	foFailureSecret [32]byte
}
548
549func NewPrivateKey(entropy *[64]byte) (*PrivateKey, *[PublicKeySize]byte) {
550	hashed := sha3.Sum512(entropy[:32])
551	rho := (*[32]byte)(hashed[:32])
552	sigma := (*[32]byte)(hashed[32:])
553	ret := new(PrivateKey)
554	copy(ret.foFailureSecret[:], entropy[32:])
555	copy(ret.rho[:], rho[:])
556	ret.m.expand(rho)
557	counter := uint8(0)
558	ret.s.generateSecretEta2(&counter, sigma)
559	ret.s.ntt()
560	var error vector
561	error.generateSecretEta2(&counter, sigma)
562	error.ntt()
563	ret.t.multTranspose(&ret.m, &ret.s)
564	ret.t.add(&error)
565
566	marshalledPublicKey := ret.PublicKey.Marshal()
567	ret.publicKeyHash = sha3.Sum256(marshalledPublicKey[:])
568
569	return ret, marshalledPublicKey
570}
571
572func (priv *PrivateKey) decryptCPA(ciphertext *[CiphertextSize]byte) [32]byte {
573	var u vector
574	u.decode(ciphertext[:], du)
575	u.decompress(du)
576	u.ntt()
577
578	var v scalar
579	v.decode(ciphertext[compressedVectorSize:], dv)
580	v.decompress(dv)
581
582	var mask scalar
583	mask.innerProduct(&priv.s, &u)
584	mask.inverseNTT()
585	v.sub(&mask)
586	v.compress(1)
587	var out [32]byte
588	v.encode(out[:], 1)
589	return out
590}
591
592func (priv *PrivateKey) Decap(outSharedSecret []byte, ciphertext *[CiphertextSize]byte) {
593	decrypted := priv.decryptCPA(ciphertext)
594	h := sha3.New512()
595	h.Write(decrypted[:])
596	h.Write(priv.publicKeyHash[:])
597	prekeyAndRandomness := h.Sum(nil)
598	expectedCiphertext := priv.encryptCPA(&decrypted, (*[32]byte)(prekeyAndRandomness[32:]))
599	equal := subtle.ConstantTimeCompare(ciphertext[:], expectedCiphertext[:])
600	var secret [32]byte
601	for i := range secret {
602		secret[i] = byte(subtle.ConstantTimeSelect(equal, int(prekeyAndRandomness[i]), int(priv.foFailureSecret[i])))
603	}
604	ciphertextHash := sha3.Sum256(ciphertext[:])
605
606	shake := sha3.NewSHAKE256()
607	shake.Write(secret[:])
608	shake.Write(ciphertextHash[:])
609	shake.Read(outSharedSecret)
610}
611
612func (priv *PrivateKey) Marshal() *[PrivateKeySize]byte {
613	var ret [PrivateKeySize]byte
614	out := priv.s.encode(ret[:], log2Prime)
615	publicKey := priv.PublicKey.Marshal()
616	n := copy(out, publicKey[:])
617	out = out[n:]
618	n = copy(out, priv.publicKeyHash[:])
619	out = out[n:]
620	copy(out, priv.foFailureSecret[:])
621	return &ret
622}
623