1// Copyright 2023 The BoringSSL Authors 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// https://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package kyber 16 17// This code is ported from kyber.c. 18 19import ( 20 "crypto/sha3" 21 "crypto/subtle" 22 "io" 23) 24 25const ( 26 CiphertextSize = 1088 27 PublicKeySize = 1184 28 PrivateKeySize = 2400 29) 30 31const ( 32 degree = 256 33 rank = 3 34 prime = 3329 35 log2Prime = 12 36 halfPrime = (prime - 1) / 2 37 du = 10 38 dv = 4 39 inverseDegree = 3303 40 encodedVectorSize = log2Prime * degree / 8 * rank 41 compressedVectorSize = du * rank * degree / 8 42 barrettMultiplier = 5039 43 barrettShift = 24 44) 45 46func reduceOnce(x uint16) uint16 { 47 if x >= 2*prime { 48 panic("reduce_once: value out of range") 49 } 50 subtracted := x - prime 51 mask := 0 - (subtracted >> 15) 52 return (mask & x) | (^mask & subtracted) 53} 54 55func reduce(x uint32) uint16 { 56 if x >= prime+2*prime*prime { 57 panic("reduce: value out of range") 58 } 59 product := uint64(x) * barrettMultiplier 60 quotient := uint32(product >> barrettShift) 61 remainder := uint32(x) - quotient*prime 62 return reduceOnce(uint16(remainder)) 63} 64 65// lt returns 0xff..f if a < b and 0 otherwise 66func lt(a, b uint32) uint32 { 67 return uint32(0 - int32(a^((a^b)|((a-b)^a)))>>31) 68} 69 70// Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping 71// numbers close to each other together. The formula used is 72// round(2^|bits|/prime*x) mod 2^|bits|. 73// Uses Barrett reduction to achieve constant time. Since we need both the 74// remainder (for rounding) and the quotient (as the result), we cannot use 75// |reduce| here, but need to do the Barrett reduction directly. 76func compress(x uint16, bits int) uint16 { 77 product := uint32(x) << bits 78 quotient := uint32((uint64(product) * barrettMultiplier) >> barrettShift) 79 remainder := product - quotient*prime 80 81 // Adjust the quotient to round correctly: 82 // 0 <= remainder <= halfPrime round to 0 83 // halfPrime < remainder <= prime + halfPrime round to 1 84 // prime + halfPrime < remainder < 2 * prime round to 2 85 quotient += 1 & lt(halfPrime, remainder) 86 quotient += 1 & lt(prime+halfPrime, remainder) 87 return uint16(quotient) & ((1 << bits) - 1) 88} 89 90func decompress(x uint16, bits int) uint16 { 91 product := uint32(x) * prime 92 power := uint32(1) << bits 93 // This is |product| % power, since |power| is a power of 2. 94 remainder := product & (power - 1) 95 // This is |product| / power, since |power| is a power of 2. 96 lower := product >> bits 97 // The rounding logic works since the first half of numbers mod |power| have a 98 // 0 as first bit, and the second half has a 1 as first bit, since |power| is 99 // a power of 2. As a 12 bit number, |remainder| is always positive, so we 100 // will shift in 0s for a right shift. 101 return uint16(lower + (remainder >> (bits - 1))) 102} 103 104type scalar [degree]uint16 105 106func (s *scalar) zero() { 107 clear(s[:]) 108} 109 110// This bit of Python will be referenced in some of the following comments: 111// 112// p = 3329 113// 114// def bitreverse(i): 115// ret = 0 116// for n in range(7): 117// bit = i & 1 118// ret <<= 1 119// ret |= bit 120// i >>= 1 121// return ret 122 123// kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)] 124var nttRoots = [128]uint16{ 125 1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 126 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, 127 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, 128 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915, 129 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, 130 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100, 131 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789, 132 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, 133 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, 134 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, 135 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154, 136} 137 138func (s *scalar) ntt() { 139 offset := degree 140 for step := 1; step < degree/2; step <<= 1 { 141 offset >>= 1 142 k := 0 143 for i := 0; i < step; i++ { 144 stepRoot := uint32(nttRoots[i+step]) 145 for j := k; j < k+offset; j++ { 146 odd := reduce(stepRoot * uint32(s[j+offset])) 147 even := s[j] 148 s[j] = reduceOnce(odd + even) 149 s[j+offset] = reduceOnce(even - odd + prime) 150 } 151 k += 2 * offset 152 } 153 } 154} 155 156// kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] 157var inverseNTTRoots = [128]uint16{ 158 1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543, 159 2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447, 2794, 1235, 1903, 160 1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855, 161 2681, 1848, 712, 682, 927, 1795, 461, 1891, 2877, 2522, 1894, 1010, 162 1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132, 163 1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607, 164 2443, 554, 1179, 2186, 2303, 2926, 2237, 525, 735, 863, 2768, 1230, 165 2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745, 166 2688, 3061, 992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482, 167 1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096, 48, 667, 1920, 168 2229, 1041, 2606, 1692, 680, 2746, 568, 3312, 169} 170 171func (s *scalar) inverseNTT() { 172 step := degree / 2 173 for offset := 2; offset < degree; offset <<= 1 { 174 step >>= 1 175 k := 0 176 for i := 0; i < step; i++ { 177 stepRoot := uint32(inverseNTTRoots[i+step]) 178 for j := k; j < k+offset; j++ { 179 odd := s[j+offset] 180 even := s[j] 181 s[j] = reduceOnce(odd + even) 182 s[j+offset] = reduce(stepRoot * uint32(even-odd+prime)) 183 } 184 k += 2 * offset 185 } 186 } 187 for i := range s { 188 s[i] = reduce(uint32(s[i]) * inverseDegree) 189 } 190} 191 192func (s *scalar) add(b *scalar) { 193 for i := range s { 194 s[i] = reduceOnce(s[i] + b[i]) 195 } 196} 197 198func (s *scalar) sub(b *scalar) { 199 for i := range s { 200 s[i] = reduceOnce(s[i] - b[i] + prime) 201 } 202} 203 204// kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)] 205var modRoots = [128]uint16{ 206 17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, 207 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096, 208 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678, 209 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642, 210 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992, 211 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109, 212 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010, 213 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735, 214 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179, 215 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300, 216 2110, 1219, 2935, 394, 885, 2444, 2154, 1175, 217} 218 219func (s *scalar) mult(a, b *scalar) { 220 for i := 0; i < degree/2; i++ { 221 realReal := uint32(a[2*i]) * uint32(b[2*i]) 222 imgImg := uint32(a[2*i+1]) * uint32(b[2*i+1]) 223 realImg := uint32(a[2*i]) * uint32(b[2*i+1]) 224 imgReal := uint32(a[2*i+1]) * uint32(b[2*i]) 225 s[2*i] = reduce(realReal + uint32(reduce(imgImg))*uint32(modRoots[i])) 226 s[2*i+1] = reduce(imgReal + realImg) 227 } 228} 229 230func (s *scalar) innerProduct(left, right *vector) { 231 s.zero() 232 var product scalar 233 for i := range left { 234 product.mult(&left[i], &right[i]) 235 s.add(&product) 236 } 237} 238 239func (s *scalar) fromKeccakVartime(keccak io.Reader) { 240 var buf [3]byte 241 for i := 0; i < len(s); { 242 keccak.Read(buf[:]) 243 d1 := uint16(buf[0]) + 256*uint16(buf[1]%16) 244 d2 := uint16(buf[1])/16 + 16*uint16(buf[2]) 245 if d1 < prime { 246 s[i] = d1 247 i++ 248 } 249 if d2 < prime && i < len(s) { 250 s[i] = d2 251 i++ 252 } 253 } 254} 255 256func (s *scalar) centeredBinomialEta2(input *[33]byte) { 257 entropy := sha3.SumSHAKE256(input[:], 128) 258 259 for i := 0; i < len(s); i += 2 { 260 b := uint16(entropy[i/2]) 261 262 value := uint16(prime) 263 value += (b & 1) + ((b >> 1) & 1) 264 value -= ((b >> 2) & 1) + ((b >> 3) & 1) 265 s[i] = reduceOnce(value) 266 267 b >>= 4 268 value = prime 269 value += (b & 1) + ((b >> 1) & 1) 270 value -= ((b >> 2) & 1) + ((b >> 3) & 1) 271 s[i+1] = reduceOnce(value) 272 } 273} 274 275var masks = [8]uint16{0x01, 0x03, 0x07, 0x0f, 0x1f, 0x3f, 0x7f, 0xff} 276 277func (s *scalar) encode(out []byte, bits int) []byte { 278 var outByte byte 279 outByteBits := 0 280 281 for i := range s { 282 element := s[i] 283 elementBitsDone := 0 284 285 for elementBitsDone < bits { 286 chunkBits := bits - elementBitsDone 287 outBitsRemaining := 8 - outByteBits 288 if chunkBits >= outBitsRemaining { 289 chunkBits = outBitsRemaining 290 outByte |= byte(element&masks[chunkBits-1]) << outByteBits 291 out[0] = outByte 292 out = out[1:] 293 outByteBits = 0 294 outByte = 0 295 } else { 296 outByte |= byte(element&masks[chunkBits-1]) << outByteBits 297 outByteBits += chunkBits 298 } 299 300 elementBitsDone += chunkBits 301 element >>= chunkBits 302 } 303 } 304 305 if outByteBits > 0 { 306 out[0] = outByte 307 out = out[1:] 308 } 309 310 return out 311} 312 313func (s *scalar) decode(in []byte, bits int) ([]byte, bool) { 314 var inByte byte 315 inByteBitsLeft := 0 316 317 for i := range s { 318 var element uint16 319 elementBitsDone := 0 320 321 for elementBitsDone < bits { 322 if inByteBitsLeft == 0 { 323 inByte = in[0] 324 in = in[1:] 325 inByteBitsLeft = 8 326 } 327 328 chunkBits := bits - elementBitsDone 329 if chunkBits > inByteBitsLeft { 330 chunkBits = inByteBitsLeft 331 } 332 333 element |= (uint16(inByte) & masks[chunkBits-1]) << elementBitsDone 334 inByteBitsLeft -= chunkBits 335 inByte >>= chunkBits 336 337 elementBitsDone += chunkBits 338 } 339 340 if element >= prime { 341 return nil, false 342 } 343 s[i] = element 344 } 345 346 return in, true 347} 348 349func (s *scalar) compress(bits int) { 350 for i := range s { 351 s[i] = compress(s[i], bits) 352 } 353} 354 355func (s *scalar) decompress(bits int) { 356 for i := range s { 357 s[i] = decompress(s[i], bits) 358 } 359} 360 361type vector [rank]scalar 362 363func (v *vector) zero() { 364 for i := range v { 365 v[i].zero() 366 } 367} 368 369func (v *vector) ntt() { 370 for i := range v { 371 v[i].ntt() 372 } 373} 374 375func (v *vector) inverseNTT() { 376 for i := range v { 377 v[i].inverseNTT() 378 } 379} 380 381func (v *vector) add(b *vector) { 382 for i := range v { 383 v[i].add(&b[i]) 384 } 385} 386 387func (out *vector) mult(m *matrix, v *vector) { 388 out.zero() 389 var product scalar 390 for i := 0; i < rank; i++ { 391 for j := 0; j < rank; j++ { 392 product.mult(&m[i][j], &v[j]) 393 out[i].add(&product) 394 } 395 } 396} 397 398func (out *vector) multTranspose(m *matrix, v *vector) { 399 out.zero() 400 var product scalar 401 for i := 0; i < rank; i++ { 402 for j := 0; j < rank; j++ { 403 product.mult(&m[j][i], &v[j]) 404 out[i].add(&product) 405 } 406 } 407} 408 409func (v *vector) generateSecretEta2(counter *byte, seed *[32]byte) { 410 var input [33]byte 411 copy(input[:], seed[:]) 412 for i := range v { 413 input[32] = *counter 414 *counter++ 415 v[i].centeredBinomialEta2(&input) 416 } 417} 418 419func (v *vector) encode(out []byte, bits int) []byte { 420 for i := range v { 421 out = v[i].encode(out, bits) 422 } 423 return out 424} 425 426func (v *vector) decode(out []byte, bits int) ([]byte, bool) { 427 var ok bool 428 for i := range v { 429 out, ok = v[i].decode(out, bits) 430 if !ok { 431 return nil, false 432 } 433 } 434 435 return out, true 436} 437 438func (v *vector) compress(bits int) { 439 for i := range v { 440 v[i].compress(bits) 441 } 442} 443 444func (v *vector) decompress(bits int) { 445 for i := range v { 446 v[i].decompress(bits) 447 } 448} 449 450type matrix [rank][rank]scalar 451 452func (m *matrix) expand(rho *[32]byte) { 453 shake := sha3.NewSHAKE128() 454 455 var input [34]byte 456 copy(input[:], rho[:]) 457 458 for i := 0; i < rank; i++ { 459 for j := 0; j < rank; j++ { 460 input[32] = byte(i) 461 input[33] = byte(j) 462 463 shake.Reset() 464 shake.Write(input[:]) 465 m[i][j].fromKeccakVartime(shake) 466 } 467 } 468} 469 470type PublicKey struct { 471 t vector 472 rho [32]byte 473 publicKeyHash [32]byte 474 m matrix 475} 476 477func UnmarshalPublicKey(data *[PublicKeySize]byte) (*PublicKey, bool) { 478 var ret PublicKey 479 ret.publicKeyHash = sha3.Sum256(data[:]) 480 in, ok := ret.t.decode(data[:], log2Prime) 481 if !ok { 482 return nil, false 483 } 484 copy(ret.rho[:], in) 485 ret.m.expand(&ret.rho) 486 return &ret, true 487} 488 489func (pub *PublicKey) Marshal() *[PublicKeySize]byte { 490 var ret [PublicKeySize]byte 491 out := pub.t.encode(ret[:], log2Prime) 492 copy(out, pub.rho[:]) 493 return &ret 494} 495 496func (pub *PublicKey) encryptCPA(message, entropy *[32]byte) *[CiphertextSize]byte { 497 var counter uint8 498 var secret, error vector 499 secret.generateSecretEta2(&counter, entropy) 500 error.generateSecretEta2(&counter, entropy) 501 secret.ntt() 502 503 var input [33]byte 504 copy(input[:], entropy[:]) 505 input[32] = counter 506 var scalarError scalar 507 scalarError.centeredBinomialEta2(&input) 508 509 var u vector 510 u.mult(&pub.m, &secret) 511 u.inverseNTT() 512 u.add(&error) 513 514 var v scalar 515 v.innerProduct(&pub.t, &secret) 516 v.inverseNTT() 517 v.add(&scalarError) 518 519 out := make([]byte, CiphertextSize) 520 var expandedMessage scalar 521 expandedMessage.decode(message[:], 1) 522 expandedMessage.decompress(1) 523 v.add(&expandedMessage) 524 u.compress(du) 525 it := u.encode(out, du) 526 v.compress(dv) 527 v.encode(it, dv) 528 return (*[CiphertextSize]byte)(out) 529} 530 531func (pub *PublicKey) Encap(outSharedSecret []byte, entropy *[32]byte) *[CiphertextSize]byte { 532 var input [64]byte 533 copy(input[:], entropy[:]) 534 copy(input[32:], pub.publicKeyHash[:]) 535 prekeyAndRandomness := sha3.Sum512(input[:]) 536 ciphertext := pub.encryptCPA(entropy, (*[32]byte)(prekeyAndRandomness[32:])) 537 ciphertextHash := sha3.Sum256(ciphertext[:]) 538 copy(prekeyAndRandomness[32:], ciphertextHash[:]) 539 copy(outSharedSecret, sha3.SumSHAKE256(prekeyAndRandomness[:], len(outSharedSecret))) 540 return ciphertext 541} 542 543type PrivateKey struct { 544 PublicKey 545 s vector 546 foFailureSecret [32]byte 547} 548 549func NewPrivateKey(entropy *[64]byte) (*PrivateKey, *[PublicKeySize]byte) { 550 hashed := sha3.Sum512(entropy[:32]) 551 rho := (*[32]byte)(hashed[:32]) 552 sigma := (*[32]byte)(hashed[32:]) 553 ret := new(PrivateKey) 554 copy(ret.foFailureSecret[:], entropy[32:]) 555 copy(ret.rho[:], rho[:]) 556 ret.m.expand(rho) 557 counter := uint8(0) 558 ret.s.generateSecretEta2(&counter, sigma) 559 ret.s.ntt() 560 var error vector 561 error.generateSecretEta2(&counter, sigma) 562 error.ntt() 563 ret.t.multTranspose(&ret.m, &ret.s) 564 ret.t.add(&error) 565 566 marshalledPublicKey := ret.PublicKey.Marshal() 567 ret.publicKeyHash = sha3.Sum256(marshalledPublicKey[:]) 568 569 return ret, marshalledPublicKey 570} 571 572func (priv *PrivateKey) decryptCPA(ciphertext *[CiphertextSize]byte) [32]byte { 573 var u vector 574 u.decode(ciphertext[:], du) 575 u.decompress(du) 576 u.ntt() 577 578 var v scalar 579 v.decode(ciphertext[compressedVectorSize:], dv) 580 v.decompress(dv) 581 582 var mask scalar 583 mask.innerProduct(&priv.s, &u) 584 mask.inverseNTT() 585 v.sub(&mask) 586 v.compress(1) 587 var out [32]byte 588 v.encode(out[:], 1) 589 return out 590} 591 592func (priv *PrivateKey) Decap(outSharedSecret []byte, ciphertext *[CiphertextSize]byte) { 593 decrypted := priv.decryptCPA(ciphertext) 594 h := sha3.New512() 595 h.Write(decrypted[:]) 596 h.Write(priv.publicKeyHash[:]) 597 prekeyAndRandomness := h.Sum(nil) 598 expectedCiphertext := priv.encryptCPA(&decrypted, (*[32]byte)(prekeyAndRandomness[32:])) 599 equal := subtle.ConstantTimeCompare(ciphertext[:], expectedCiphertext[:]) 600 var secret [32]byte 601 for i := range secret { 602 secret[i] = byte(subtle.ConstantTimeSelect(equal, int(prekeyAndRandomness[i]), int(priv.foFailureSecret[i]))) 603 } 604 ciphertextHash := sha3.Sum256(ciphertext[:]) 605 606 shake := sha3.NewSHAKE256() 607 shake.Write(secret[:]) 608 shake.Write(ciphertextHash[:]) 609 shake.Read(outSharedSecret) 610} 611 612func (priv *PrivateKey) Marshal() *[PrivateKeySize]byte { 613 var ret [PrivateKeySize]byte 614 out := priv.s.encode(ret[:], log2Prime) 615 publicKey := priv.PublicKey.Marshal() 616 n := copy(out, publicKey[:]) 617 out = out[n:] 618 n = copy(out, priv.publicKeyHash[:]) 619 out = out[n:] 620 copy(out, priv.foFailureSecret[:]) 621 return &ret 622} 623