1 // Copyright 2024 The BoringSSL Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <openssl/mldsa.h>
16 
17 #include <memory>
18 #include <vector>
19 
20 #include <gtest/gtest.h>
21 
22 #include <openssl/bytestring.h>
23 #include <openssl/mem.h>
24 #include <openssl/span.h>
25 
26 #include "../fipsmodule/bcm_interface.h"
27 #include "../internal.h"
28 #include "../test/file_test.h"
29 #include "../test/test_util.h"
30 
31 
32 namespace {
33 
34 template <typename T>
Marshal(bcm_status (* marshal_func)(CBB *,const T *),const T * t)35 std::vector<uint8_t> Marshal(bcm_status (*marshal_func)(CBB *, const T *),
36                              const T *t) {
37   bssl::ScopedCBB cbb;
38   uint8_t *encoded;
39   size_t encoded_len;
40   if (!CBB_init(cbb.get(), 1) ||                             //
41       marshal_func(cbb.get(), t) != bcm_status::approved ||  //
42       !CBB_finish(cbb.get(), &encoded, &encoded_len)) {
43     abort();
44   }
45 
46   std::vector<uint8_t> ret(encoded, encoded + encoded_len);
47   OPENSSL_free(encoded);
48   return ret;
49 }
50 
51 // This test is very slow, so it is disabled by default.
TEST(MLDSATest,DISABLED_BitFlips)52 TEST(MLDSATest, DISABLED_BitFlips) {
53   std::vector<uint8_t> encoded_public_key(MLDSA65_PUBLIC_KEY_BYTES);
54   auto priv = std::make_unique<MLDSA65_private_key>();
55   uint8_t seed[MLDSA_SEED_BYTES];
56   EXPECT_TRUE(
57       MLDSA65_generate_key(encoded_public_key.data(), seed, priv.get()));
58 
59   std::vector<uint8_t> encoded_signature(MLDSA65_SIGNATURE_BYTES);
60   static const uint8_t kMessage[] = {'H', 'e', 'l', 'l', 'o', ' ',
61                                      'w', 'o', 'r', 'l', 'd'};
62   EXPECT_TRUE(MLDSA65_sign(encoded_signature.data(), priv.get(), kMessage,
63                            sizeof(kMessage), nullptr, 0));
64 
65   auto pub = std::make_unique<MLDSA65_public_key>();
66   CBS cbs = CBS(encoded_public_key);
67   ASSERT_TRUE(MLDSA65_parse_public_key(pub.get(), &cbs));
68 
69   EXPECT_EQ(MLDSA65_verify(pub.get(), encoded_signature.data(),
70                            encoded_signature.size(), kMessage, sizeof(kMessage),
71                            nullptr, 0),
72             1);
73 
74   for (size_t i = 0; i < MLDSA65_SIGNATURE_BYTES; i++) {
75     for (int j = 0; j < 8; j++) {
76       encoded_signature[i] ^= 1 << j;
77       EXPECT_EQ(MLDSA65_verify(pub.get(), encoded_signature.data(),
78                                encoded_signature.size(), kMessage,
79                                sizeof(kMessage), nullptr, 0),
80                 0)
81           << "Bit flip in signature at byte " << i << " bit " << j
82           << " didn't cause a verification failure";
83       encoded_signature[i] ^= 1 << j;
84     }
85   }
86 }
87 
88 template <
89     typename PrivateKey, typename PublicKey, size_t PublicKeyBytes,
90     size_t SignatureBytes, int (*Generate)(uint8_t *, uint8_t *, PrivateKey *),
91     int (*Sign)(uint8_t *, const PrivateKey *, const uint8_t *, size_t,
92                 const uint8_t *, size_t),
93     int (*ParsePublicKey)(PublicKey *, CBS *),
94     int (*Verify)(const PublicKey *, const uint8_t *, size_t, const uint8_t *,
95                   size_t, const uint8_t *, size_t),
96     int (*PrivateKeyFromSeed)(PrivateKey *, const uint8_t *, size_t),
97     typename BCMPrivateKey, bcm_status (*ParsePrivate)(BCMPrivateKey *, CBS *),
98     bcm_status (*MarshalPrivate)(CBB *, const BCMPrivateKey *)>
MLDSABasicTest()99 static void MLDSABasicTest() {
100   std::vector<uint8_t> encoded_public_key(PublicKeyBytes);
101   auto priv = std::make_unique<PrivateKey>();
102   uint8_t seed[MLDSA_SEED_BYTES];
103   EXPECT_TRUE(Generate(encoded_public_key.data(), seed, priv.get()));
104 
105   const std::vector<uint8_t> encoded_private_key =
106       Marshal(MarshalPrivate, reinterpret_cast<BCMPrivateKey *>(priv.get()));
107   CBS cbs = CBS(encoded_private_key);
108   EXPECT_TRUE(bcm_success(
109       ParsePrivate(reinterpret_cast<BCMPrivateKey *>(priv.get()), &cbs)));
110 
111   std::vector<uint8_t> encoded_signature(SignatureBytes);
112   static const uint8_t kMessage[] = {'H', 'e', 'l', 'l', 'o', ' ',
113                                      'w', 'o', 'r', 'l', 'd'};
114   static const uint8_t kContext[] = {'c', 't', 'x'};
115   EXPECT_TRUE(Sign(encoded_signature.data(), priv.get(), kMessage,
116                    sizeof(kMessage), kContext, sizeof(kContext)));
117 
118   auto pub = std::make_unique<PublicKey>();
119   cbs = CBS(encoded_public_key);
120   ASSERT_TRUE(ParsePublicKey(pub.get(), &cbs));
121 
122   EXPECT_EQ(
123       Verify(pub.get(), encoded_signature.data(), encoded_signature.size(),
124              kMessage, sizeof(kMessage), kContext, sizeof(kContext)),
125       1);
126 
127   auto priv2 = std::make_unique<PrivateKey>();
128   EXPECT_TRUE(PrivateKeyFromSeed(priv2.get(), seed, sizeof(seed)));
129 
130   EXPECT_EQ(
131       Bytes(Declassified(Marshal(
132           MarshalPrivate, reinterpret_cast<BCMPrivateKey *>(priv.get())))),
133       Bytes(Declassified(Marshal(
134           MarshalPrivate, reinterpret_cast<BCMPrivateKey *>(priv2.get())))));
135 }
136 
// Key generation, private-key round trip, sign/verify and seed
// reconstruction for ML-DSA-65.
TEST(MLDSATest, Basic65) {
  MLDSABasicTest<
      MLDSA65_private_key, MLDSA65_public_key, MLDSA65_PUBLIC_KEY_BYTES,
      MLDSA65_SIGNATURE_BYTES, MLDSA65_generate_key, MLDSA65_sign,
      MLDSA65_parse_public_key, MLDSA65_verify, MLDSA65_private_key_from_seed,
      BCM_mldsa65_private_key, BCM_mldsa65_parse_private_key,
      BCM_mldsa65_marshal_private_key>();
}
145 
// Key generation, private-key round trip, sign/verify and seed
// reconstruction for ML-DSA-87.
TEST(MLDSATest, Basic87) {
  MLDSABasicTest<
      MLDSA87_private_key, MLDSA87_public_key, BCM_MLDSA87_PUBLIC_KEY_BYTES,
      BCM_MLDSA87_SIGNATURE_BYTES, MLDSA87_generate_key, MLDSA87_sign,
      MLDSA87_parse_public_key, MLDSA87_verify, MLDSA87_private_key_from_seed,
      BCM_mldsa87_private_key, BCM_mldsa87_parse_private_key,
      BCM_mldsa87_marshal_private_key>();
}
154 
// Key generation, private-key round trip, sign/verify and seed
// reconstruction for ML-DSA-44.
TEST(MLDSATest, Basic44) {
  MLDSABasicTest<
      MLDSA44_private_key, MLDSA44_public_key, BCM_MLDSA44_PUBLIC_KEY_BYTES,
      BCM_MLDSA44_SIGNATURE_BYTES, MLDSA44_generate_key, MLDSA44_sign,
      MLDSA44_parse_public_key, MLDSA44_verify, MLDSA44_private_key_from_seed,
      BCM_mldsa44_private_key, BCM_mldsa44_parse_private_key,
      BCM_mldsa44_marshal_private_key>();
}
163 
// Signing the same message twice with the same ML-DSA-65 key via
// MLDSA65_sign must produce two different signatures, and both must verify.
TEST(MLDSATest, SignatureIsRandomized) {
  static const uint8_t kMessage[] = {'H', 'e', 'l', 'l', 'o', ' ',
                                     'w', 'o', 'r', 'l', 'd'};

  std::vector<uint8_t> public_key_bytes(MLDSA65_PUBLIC_KEY_BYTES);
  uint8_t seed[MLDSA_SEED_BYTES];
  auto private_key = std::make_unique<MLDSA65_private_key>();
  EXPECT_TRUE(
      MLDSA65_generate_key(public_key_bytes.data(), seed, private_key.get()));

  auto public_key = std::make_unique<MLDSA65_public_key>();
  CBS cbs = CBS(public_key_bytes);
  ASSERT_TRUE(MLDSA65_parse_public_key(public_key.get(), &cbs));

  std::vector<uint8_t> first(MLDSA65_SIGNATURE_BYTES);
  std::vector<uint8_t> second(MLDSA65_SIGNATURE_BYTES);
  for (std::vector<uint8_t> *sig : {&first, &second}) {
    EXPECT_TRUE(MLDSA65_sign(sig->data(), private_key.get(), kMessage,
                             sizeof(kMessage), nullptr, 0));
  }

  EXPECT_NE(Bytes(first), Bytes(second));

  // Even though the signatures are different, they both verify.
  for (const std::vector<uint8_t> *sig : {&first, &second}) {
    EXPECT_EQ(MLDSA65_verify(public_key.get(), sig->data(), sig->size(),
                             kMessage, sizeof(kMessage), nullptr, 0),
              1);
  }
}
196 
// Signs through the streaming pre-hash interface (prehash_init / update /
// finalize, then sign_message_representative) and checks that the result
// verifies with the one-shot MLDSA65_verify. This is checked both for a single
// update and for the message split into three chunks at every pair of offsets.
TEST(MLDSATest, PrehashedSignatureVerifies) {
  static const uint8_t kMessage[] = {'H', 'e', 'l', 'l', 'o', ' ',
                                     'w', 'o', 'r', 'l', 'd'};
  constexpr size_t kLen = sizeof(kMessage);

  std::vector<uint8_t> public_key_bytes(MLDSA65_PUBLIC_KEY_BYTES);
  uint8_t seed[MLDSA_SEED_BYTES];
  auto private_key = std::make_unique<MLDSA65_private_key>();
  EXPECT_TRUE(
      MLDSA65_generate_key(public_key_bytes.data(), seed, private_key.get()));

  auto public_key = std::make_unique<MLDSA65_public_key>();
  CBS cbs = CBS(public_key_bytes);
  ASSERT_TRUE(MLDSA65_parse_public_key(public_key.get(), &cbs));

  std::vector<uint8_t> signature(MLDSA65_SIGNATURE_BYTES);
  MLDSA65_prehash state;
  uint8_t representative[MLDSA_MU_BYTES];

  // One-shot verification of the current |signature| over the whole message.
  auto verifies = [&] {
    return MLDSA65_verify(public_key.get(), signature.data(), signature.size(),
                          kMessage, kLen, nullptr, 0);
  };

  EXPECT_TRUE(MLDSA65_prehash_init(&state, public_key.get(), nullptr, 0));
  MLDSA65_prehash_update(&state, kMessage, kLen);
  MLDSA65_prehash_finalize(representative, &state);
  EXPECT_TRUE(MLDSA65_sign_message_representative(
      signature.data(), private_key.get(), representative));
  EXPECT_EQ(verifies(), 1);

  // Updating in multiple chunks also works: split the message into
  // [0, a), [a, b) and [b, kLen) for every 0 <= a <= b <= kLen.
  for (size_t a = 0; a <= kLen; ++a) {
    for (size_t b = a; b <= kLen; ++b) {
      EXPECT_TRUE(MLDSA65_prehash_init(&state, public_key.get(), nullptr, 0));
      MLDSA65_prehash_update(&state, kMessage, a);
      MLDSA65_prehash_update(&state, kMessage + a, b - a);
      MLDSA65_prehash_update(&state, kMessage + b, kLen - b);
      MLDSA65_prehash_finalize(representative, &state);
      EXPECT_TRUE(MLDSA65_sign_message_representative(
          signature.data(), private_key.get(), representative));
      EXPECT_EQ(verifies(), 1);
    }
  }
}
244 
// Deriving the public key from the private key and re-encoding it must
// reproduce exactly the public key emitted by MLDSA65_generate_key.
TEST(MLDSATest, PublicFromPrivateIsConsistent) {
  std::vector<uint8_t> encoded_public_key(MLDSA65_PUBLIC_KEY_BYTES);
  auto priv = std::make_unique<MLDSA65_private_key>();
  uint8_t seed[MLDSA_SEED_BYTES];
  // Later steps are meaningless without a key, so stop on failure.
  ASSERT_TRUE(
      MLDSA65_generate_key(encoded_public_key.data(), seed, priv.get()));

  auto pub = std::make_unique<MLDSA65_public_key>();
  ASSERT_TRUE(MLDSA65_public_from_private(pub.get(), priv.get()));

  std::vector<uint8_t> encoded_public_key2(MLDSA65_PUBLIC_KEY_BYTES);

  CBB cbb;
  ASSERT_TRUE(CBB_init_fixed(&cbb, encoded_public_key2.data(),
                             encoded_public_key2.size()));
  ASSERT_TRUE(MLDSA65_marshal_public_key(&cbb, pub.get()));
  // The fixed buffer is exactly one public key long; make sure the marshaller
  // filled all of it instead of leaving trailing zero bytes.
  EXPECT_EQ(CBB_len(&cbb), encoded_public_key2.size());

  EXPECT_EQ(Bytes(encoded_public_key2), Bytes(encoded_public_key));
}
263 
// MLDSA65_parse_public_key must accept an input of exactly
// MLDSA65_PUBLIC_KEY_BYTES bytes and reject inputs one byte shorter or longer.
TEST(MLDSATest, InvalidPublicKeyEncodingLength) {
  // Generate the public key into a buffer one byte longer than needed. The
  // vector is zero-initialized, so the extra byte acts as a trailing 0.
  std::vector<uint8_t> buf(MLDSA65_PUBLIC_KEY_BYTES + 1);
  uint8_t seed[MLDSA_SEED_BYTES];
  auto private_key = std::make_unique<MLDSA65_private_key>();
  EXPECT_TRUE(MLDSA65_generate_key(buf.data(), seed, private_key.get()));

  auto parsed = std::make_unique<MLDSA65_public_key>();

  // One byte too short.
  CBS cbs = CBS(bssl::Span(buf).first(MLDSA65_PUBLIC_KEY_BYTES - 1));
  EXPECT_FALSE(MLDSA65_parse_public_key(parsed.get(), &cbs));

  // Exactly the right length.
  cbs = CBS(bssl::Span(buf).first(MLDSA65_PUBLIC_KEY_BYTES));
  EXPECT_TRUE(MLDSA65_parse_public_key(parsed.get(), &cbs));

  // One byte too long.
  cbs = CBS(buf);
  EXPECT_FALSE(MLDSA65_parse_public_key(parsed.get(), &cbs));
}
286 
// BCM_mldsa65_parse_private_key must accept an input of exactly
// MLDSA65_PRIVATE_KEY_BYTES bytes and reject inputs one byte shorter or
// longer.
TEST(MLDSATest, InvalidPrivateKeyEncodingLength) {
  std::vector<uint8_t> encoded_public_key(MLDSA65_PUBLIC_KEY_BYTES);
  auto priv = std::make_unique<BCM_mldsa65_private_key>();
  uint8_t seed[MLDSA_SEED_BYTES];
  ASSERT_TRUE(bcm_success(
      BCM_mldsa65_generate_key(encoded_public_key.data(), seed, priv.get())));

  // Marshal the private key into the first MLDSA65_PRIVATE_KEY_BYTES bytes of
  // a zeroed buffer that is one byte longer, so the last byte is a trailing 0.
  std::vector<uint8_t> malformed_private_key(MLDSA65_PRIVATE_KEY_BYTES + 1, 0);
  CBB cbb;
  ASSERT_TRUE(CBB_init_fixed(&cbb, malformed_private_key.data(),
                             MLDSA65_PRIVATE_KEY_BYTES));
  // |priv| already has type BCM_mldsa65_private_key; no cast is needed.
  ASSERT_TRUE(bcm_success(BCM_mldsa65_marshal_private_key(&cbb, priv.get())));
  // The length checks below assume the whole key was written.
  ASSERT_EQ(CBB_len(&cbb), static_cast<size_t>(MLDSA65_PRIVATE_KEY_BYTES));

  CBS cbs;
  auto parsed_priv = std::make_unique<BCM_mldsa65_private_key>();

  // Private key is 1 byte too short.
  CBS_init(&cbs, malformed_private_key.data(), MLDSA65_PRIVATE_KEY_BYTES - 1);
  EXPECT_FALSE(
      bcm_success(BCM_mldsa65_parse_private_key(parsed_priv.get(), &cbs)));

  // Private key has the correct length.
  CBS_init(&cbs, malformed_private_key.data(), MLDSA65_PRIVATE_KEY_BYTES);
  EXPECT_TRUE(
      bcm_success(BCM_mldsa65_parse_private_key(parsed_priv.get(), &cbs)));

  // Private key is 1 byte too long.
  CBS_init(&cbs, malformed_private_key.data(), MLDSA65_PRIVATE_KEY_BYTES + 1);
  EXPECT_FALSE(
      bcm_success(BCM_mldsa65_parse_private_key(parsed_priv.get(), &cbs)));
}
318 
319 template <typename PrivateKey, typename PublicKey, size_t SignatureBytes,
320           bcm_status (*ParsePrivateKey)(PrivateKey *, CBS *),
321           bcm_status (*SignInternal)(uint8_t *, const PrivateKey *,
322                                      const uint8_t *, size_t, const uint8_t *,
323                                      size_t, const uint8_t *, size_t,
324                                      const uint8_t *),
325           bcm_status (*PublicFromPrivate)(PublicKey *, const PrivateKey *),
326           bcm_status (*VerifyInternal)(const PublicKey *, const uint8_t *,
327                                        const uint8_t *, size_t, const uint8_t *,
328                                        size_t, const uint8_t *, size_t)>
MLDSASigGenTest(FileTest * t)329 static void MLDSASigGenTest(FileTest *t) {
330   std::vector<uint8_t> private_key_bytes, msg, expected_signature;
331   ASSERT_TRUE(t->GetBytes(&private_key_bytes, "sk"));
332   ASSERT_TRUE(t->GetBytes(&msg, "message"));
333   ASSERT_TRUE(t->GetBytes(&expected_signature, "signature"));
334 
335   auto priv = std::make_unique<PrivateKey>();
336   CBS cbs;
337   CBS_init(&cbs, private_key_bytes.data(), private_key_bytes.size());
338   EXPECT_TRUE(bcm_success(ParsePrivateKey(priv.get(), &cbs)));
339 
340   const uint8_t zero_randomizer[BCM_MLDSA_SIGNATURE_RANDOMIZER_BYTES] = {0};
341   std::vector<uint8_t> signature(SignatureBytes);
342   EXPECT_TRUE(bcm_success(SignInternal(signature.data(), priv.get(), msg.data(),
343                                        msg.size(), nullptr, 0, nullptr, 0,
344                                        zero_randomizer)));
345 
346   EXPECT_EQ(Bytes(signature), Bytes(expected_signature));
347 
348   auto pub = std::make_unique<PublicKey>();
349   ASSERT_TRUE(bcm_success(PublicFromPrivate(pub.get(), priv.get())));
350   EXPECT_TRUE(
351       bcm_success(VerifyInternal(pub.get(), signature.data(), msg.data(),
352                                  msg.size(), nullptr, 0, nullptr, 0)));
353 }
354 
// NIST ML-DSA-65 signature-generation vectors.
TEST(MLDSATest, SigGenTests65) {
  FileTestGTest("crypto/mldsa/mldsa_nist_siggen_65_tests.txt",
                MLDSASigGenTest<
                    BCM_mldsa65_private_key, BCM_mldsa65_public_key,
                    MLDSA65_SIGNATURE_BYTES, BCM_mldsa65_parse_private_key,
                    BCM_mldsa65_sign_internal, BCM_mldsa65_public_from_private,
                    BCM_mldsa65_verify_internal>);
}
364 
// NIST ML-DSA-87 signature-generation vectors.
TEST(MLDSATest, SigGenTests87) {
  FileTestGTest("crypto/mldsa/mldsa_nist_siggen_87_tests.txt",
                MLDSASigGenTest<
                    BCM_mldsa87_private_key, BCM_mldsa87_public_key,
                    BCM_MLDSA87_SIGNATURE_BYTES, BCM_mldsa87_parse_private_key,
                    BCM_mldsa87_sign_internal, BCM_mldsa87_public_from_private,
                    BCM_mldsa87_verify_internal>);
}
374 
// NIST ML-DSA-44 signature-generation vectors.
TEST(MLDSATest, SigGenTests44) {
  FileTestGTest("crypto/mldsa/mldsa_nist_siggen_44_tests.txt",
                MLDSASigGenTest<
                    BCM_mldsa44_private_key, BCM_mldsa44_public_key,
                    BCM_MLDSA44_SIGNATURE_BYTES, BCM_mldsa44_parse_private_key,
                    BCM_mldsa44_sign_internal, BCM_mldsa44_public_from_private,
                    BCM_mldsa44_verify_internal>);
}
384 
385 template <typename PrivateKey, size_t PublicKeyBytes,
386           bcm_status (*Generate)(uint8_t *, PrivateKey *, const uint8_t *),
387           bcm_status (*MarshalPrivate)(CBB *, const PrivateKey *)>
MLDSAKeyGenTest(FileTest * t)388 static void MLDSAKeyGenTest(FileTest *t) {
389   std::vector<uint8_t> seed, expected_public_key, expected_private_key;
390   ASSERT_TRUE(t->GetBytes(&seed, "seed"));
391   CONSTTIME_SECRET(seed.data(), seed.size());
392   ASSERT_TRUE(t->GetBytes(&expected_public_key, "pub"));
393   ASSERT_TRUE(t->GetBytes(&expected_private_key, "priv"));
394 
395   std::vector<uint8_t> encoded_public_key(PublicKeyBytes);
396   auto priv = std::make_unique<PrivateKey>();
397   ASSERT_TRUE(bcm_success(
398       Generate(encoded_public_key.data(), priv.get(), seed.data())));
399 
400   const std::vector<uint8_t> encoded_private_key =
401       Marshal(MarshalPrivate, priv.get());
402 
403   EXPECT_EQ(Bytes(encoded_public_key), Bytes(expected_public_key));
404   EXPECT_EQ(Bytes(Declassified(encoded_private_key)),
405             Bytes(expected_private_key));
406 }
407 
// NIST ML-DSA-65 key-generation vectors.
TEST(MLDSATest, KeyGenTests65) {
  FileTestGTest("crypto/mldsa/mldsa_nist_keygen_65_tests.txt",
                MLDSAKeyGenTest<BCM_mldsa65_private_key,
                                MLDSA65_PUBLIC_KEY_BYTES,
                                BCM_mldsa65_generate_key_external_entropy,
                                BCM_mldsa65_marshal_private_key>);
}
415 
// NIST ML-DSA-87 key-generation vectors.
TEST(MLDSATest, KeyGenTests87) {
  FileTestGTest("crypto/mldsa/mldsa_nist_keygen_87_tests.txt",
                MLDSAKeyGenTest<BCM_mldsa87_private_key,
                                BCM_MLDSA87_PUBLIC_KEY_BYTES,
                                BCM_mldsa87_generate_key_external_entropy,
                                BCM_mldsa87_marshal_private_key>);
}
423 
// NIST ML-DSA-44 key-generation vectors.
TEST(MLDSATest, KeyGenTests44) {
  FileTestGTest("crypto/mldsa/mldsa_nist_keygen_44_tests.txt",
                MLDSAKeyGenTest<BCM_mldsa44_private_key,
                                BCM_MLDSA44_PUBLIC_KEY_BYTES,
                                BCM_mldsa44_generate_key_external_entropy,
                                BCM_mldsa44_marshal_private_key>);
}
431 
432 template <
433     typename PrivateKey, bcm_status_t (*ParsePrivateKey)(PrivateKey *, CBS *),
434     size_t SignatureBytes,
435     bcm_status_t (*SignInternal)(uint8_t *, const PrivateKey *, const uint8_t *,
436                                  size_t, const uint8_t *, size_t,
437                                  const uint8_t *, size_t, const uint8_t *)>
MLDSAWycheproofSignTest(FileTest * t)438 static void MLDSAWycheproofSignTest(FileTest *t) {
439   std::vector<uint8_t> private_key_bytes, msg, expected_signature, context;
440   ASSERT_TRUE(t->GetInstructionBytes(&private_key_bytes, "privateKey"));
441   ASSERT_TRUE(t->GetBytes(&msg, "msg"));
442   ASSERT_TRUE(t->GetBytes(&expected_signature, "sig"));
443   if (t->HasAttribute("ctx")) {
444     t->GetBytes(&context, "ctx");
445   }
446   std::string result;
447   ASSERT_TRUE(t->GetAttribute(&result, "result"));
448   t->IgnoreAttribute("flags");
449 
450   CBS cbs;
451   CBS_init(&cbs, private_key_bytes.data(), private_key_bytes.size());
452   auto priv = std::make_unique<PrivateKey>();
453   const int priv_ok = bcm_success(ParsePrivateKey(priv.get(), &cbs));
454 
455   if (!priv_ok) {
456     ASSERT_TRUE(result != "valid");
457     return;
458   }
459 
460   // Unfortunately we need to reimplement the context length check here because
461   // we are using the internal function in order to pass in an all-zero
462   // randomizer.
463   if (context.size() > 255) {
464     ASSERT_TRUE(result != "valid");
465     return;
466   }
467 
468   const uint8_t zero_randomizer[BCM_MLDSA_SIGNATURE_RANDOMIZER_BYTES] = {0};
469   std::vector<uint8_t> signature(SignatureBytes);
470   const uint8_t context_prefix[2] = {0, static_cast<uint8_t>(context.size())};
471   EXPECT_TRUE(bcm_success(SignInternal(signature.data(), priv.get(), msg.data(),
472                                        msg.size(), context_prefix,
473                                        sizeof(context_prefix), context.data(),
474                                        context.size(), zero_randomizer)));
475 
476   EXPECT_EQ(Bytes(signature), Bytes(expected_signature));
477 }
478 
// Wycheproof ML-DSA-65 signing vectors.
TEST(MLDSATest, WycheproofSignTests65) {
  FileTestGTest(
      "third_party/wycheproof_testvectors/mldsa_65_standard_sign_test.txt",
      MLDSAWycheproofSignTest<BCM_mldsa65_private_key,
                              BCM_mldsa65_parse_private_key,
                              MLDSA65_SIGNATURE_BYTES,
                              BCM_mldsa65_sign_internal>);
}
486 
// Wycheproof ML-DSA-87 signing vectors.
TEST(MLDSATest, WycheproofSignTests87) {
  FileTestGTest(
      "third_party/wycheproof_testvectors/mldsa_87_standard_sign_test.txt",
      MLDSAWycheproofSignTest<BCM_mldsa87_private_key,
                              BCM_mldsa87_parse_private_key,
                              BCM_MLDSA87_SIGNATURE_BYTES,
                              BCM_mldsa87_sign_internal>);
}
494 
// Wycheproof ML-DSA-44 signing vectors.
TEST(MLDSATest, WycheproofSignTests44) {
  FileTestGTest(
      "third_party/wycheproof_testvectors/mldsa_44_standard_sign_test.txt",
      MLDSAWycheproofSignTest<BCM_mldsa44_private_key,
                              BCM_mldsa44_parse_private_key,
                              BCM_MLDSA44_SIGNATURE_BYTES,
                              BCM_mldsa44_sign_internal>);
}
502 
503 template <typename PublicKey, size_t SignatureLength,
504           bcm_status_t (*ParsePublicKey)(PublicKey *, CBS *),
505           bcm_status_t (*Verify)(const PublicKey *, const uint8_t *,
506                                  const uint8_t *, size_t, const uint8_t *,
507                                  size_t)>
MLDSAWycheproofVerifyTest(FileTest * t)508 static void MLDSAWycheproofVerifyTest(FileTest *t) {
509   std::vector<uint8_t> public_key_bytes, msg, signature, context;
510   ASSERT_TRUE(t->GetInstructionBytes(&public_key_bytes, "publicKey"));
511   ASSERT_TRUE(t->GetBytes(&msg, "msg"));
512   ASSERT_TRUE(t->GetBytes(&signature, "sig"));
513   if (t->HasAttribute("ctx")) {
514     t->GetBytes(&context, "ctx");
515   }
516   std::string result, flags;
517   ASSERT_TRUE(t->GetAttribute(&result, "result"));
518   ASSERT_TRUE(t->GetAttribute(&flags, "flags"));
519 
520   CBS cbs;
521   CBS_init(&cbs, public_key_bytes.data(), public_key_bytes.size());
522   auto pub = std::make_unique<PublicKey>();
523   const int pub_ok = bcm_success(ParsePublicKey(pub.get(), &cbs));
524 
525   if (!pub_ok) {
526     EXPECT_EQ(flags, "IncorrectPublicKeyLength");
527     return;
528   }
529 
530   const int sig_ok =
531       signature.size() == SignatureLength && context.size() <= 255 &&
532       bcm_success(Verify(pub.get(), signature.data(), msg.data(), msg.size(),
533                          context.data(), context.size()));
534   if (!sig_ok) {
535     EXPECT_EQ(result, "invalid");
536   } else {
537     EXPECT_EQ(result, "valid");
538   }
539 }
540 
// Wycheproof ML-DSA-65 verification vectors.
TEST(MLDSATest, WycheproofVerifyTests65) {
  FileTestGTest(
      "third_party/wycheproof_testvectors/mldsa_65_standard_verify_test.txt",
      MLDSAWycheproofVerifyTest<BCM_mldsa65_public_key,
                                BCM_MLDSA65_SIGNATURE_BYTES,
                                BCM_mldsa65_parse_public_key,
                                BCM_mldsa65_verify>);
}
548 
// Wycheproof ML-DSA-87 verification vectors.
TEST(MLDSATest, WycheproofVerifyTests87) {
  FileTestGTest(
      "third_party/wycheproof_testvectors/mldsa_87_standard_verify_test.txt",
      MLDSAWycheproofVerifyTest<BCM_mldsa87_public_key,
                                BCM_MLDSA87_SIGNATURE_BYTES,
                                BCM_mldsa87_parse_public_key,
                                BCM_mldsa87_verify>);
}
556 
// Wycheproof ML-DSA-44 verification vectors.
TEST(MLDSATest, WycheproofVerifyTests44) {
  FileTestGTest(
      "third_party/wycheproof_testvectors/mldsa_44_standard_verify_test.txt",
      MLDSAWycheproofVerifyTest<BCM_mldsa44_public_key,
                                BCM_MLDSA44_SIGNATURE_BYTES,
                                BCM_mldsa44_parse_public_key,
                                BCM_mldsa44_verify>);
}
564 
// The module's built-in ML-DSA self test must pass.
TEST(MLDSATest, Self) {
  ASSERT_TRUE(boringssl_self_test_mldsa());
}
566 
// Exercises the *_generate_key_fips entry points for every parameter set and
// requires each to report bcm_status::approved. Per the test name, these
// presumably run a pairwise consistency test; confirm in the BCM sources.
TEST(MLDSATest, PWCT) {
  // Output-only buffer, reused across parameter sets; its contents are unused.
  uint8_t seed[BCM_MLDSA_SEED_BYTES];

  // ML-DSA-65.
  auto public_key_65 = std::make_unique<uint8_t[]>(BCM_MLDSA65_PUBLIC_KEY_BYTES);
  auto private_key_65 = std::make_unique<BCM_mldsa65_private_key>();
  ASSERT_EQ(BCM_mldsa65_generate_key_fips(public_key_65.get(), seed,
                                          private_key_65.get()),
            bcm_status::approved);

  // ML-DSA-87.
  auto public_key_87 = std::make_unique<uint8_t[]>(BCM_MLDSA87_PUBLIC_KEY_BYTES);
  auto private_key_87 = std::make_unique<BCM_mldsa87_private_key>();
  ASSERT_EQ(BCM_mldsa87_generate_key_fips(public_key_87.get(), seed,
                                          private_key_87.get()),
            bcm_status::approved);

  // ML-DSA-44.
  auto public_key_44 = std::make_unique<uint8_t[]>(BCM_MLDSA44_PUBLIC_KEY_BYTES);
  auto private_key_44 = std::make_unique<BCM_mldsa44_private_key>();
  ASSERT_EQ(BCM_mldsa44_generate_key_fips(public_key_44.get(), seed,
                                          private_key_44.get()),
            bcm_status::approved);
}
585 
// The FIPS key-generation wrappers must return bcm_status::failure, rather
// than crash, when every argument is null.
TEST(MLDSATest, NullptrArgumentsToCreate) {
  // For FIPS reasons, this should fail rather than crash.

  // Randomized key generation.
  ASSERT_EQ(BCM_mldsa65_generate_key_fips(nullptr, nullptr, nullptr),
            bcm_status::failure);
  ASSERT_EQ(BCM_mldsa87_generate_key_fips(nullptr, nullptr, nullptr),
            bcm_status::failure);
  ASSERT_EQ(BCM_mldsa44_generate_key_fips(nullptr, nullptr, nullptr),
            bcm_status::failure);

  // Key generation from caller-supplied entropy.
  ASSERT_EQ(BCM_mldsa65_generate_key_external_entropy_fips(nullptr, nullptr,
                                                           nullptr),
            bcm_status::failure);
  ASSERT_EQ(BCM_mldsa87_generate_key_external_entropy_fips(nullptr, nullptr,
                                                           nullptr),
            bcm_status::failure);
  ASSERT_EQ(BCM_mldsa44_generate_key_external_entropy_fips(nullptr, nullptr,
                                                           nullptr),
            bcm_status::failure);
}
604 
605 }  // namespace
606