1 // Copyright 2024 The BoringSSL Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <openssl/base.h>
16 
#include <optional>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
21 
22 #include <stdint.h>
23 #include <stdio.h>
24 #include <string.h>
25 
26 #include <openssl/bn.h>
27 #include <openssl/span.h>
28 
29 #include <gtest/gtest.h>
30 
31 #include "../internal.h"
32 #include "../test/test_util.h"
33 #include "./internal.h"
34 
35 
36 BSSL_NAMESPACE_BEGIN
37 
38 namespace {
39 
40 using namespace spake2plus;
41 
HexToBytes(const char * str)42 std::vector<uint8_t> HexToBytes(const char *str) {
43   std::vector<uint8_t> ret;
44   if (!DecodeHex(&ret, str)) {
45     abort();
46   }
47   return ret;
48 }
49 
50 class RegistrationCache {
51  public:
52   struct Result {
53     std::vector<uint8_t> w0, w1, record;
54   };
55 
Get(const std::pair<std::string,std::string> & names,const std::string & pw)56   const Result &Get(const std::pair<std::string, std::string> &names,
57                     const std::string &pw) {
58     CacheKey key{names.first, names.second, pw};
59 
60     auto it = cache.find(key);
61     if (it != cache.end()) {
62       return it->second;
63     }
64 
65     Result output;
66     output.w0.resize(kVerifierSize);
67     output.w1.resize(kVerifierSize);
68     output.record.resize(kRegistrationRecordSize);
69 
70     if (!Register(Span(output.w0), Span(output.w1), Span(output.record),
71                   StringAsBytes(pw), StringAsBytes(names.first),
72                   StringAsBytes(names.second))) {
73       abort();
74     }
75 
76     return cache.emplace(std::move(key), std::move(output)).first->second;
77   }
78 
79  private:
80   struct CacheKey {
81     std::string id_prover, id_verifier, password;
82 
operator ==__anone741afaf0111::RegistrationCache::CacheKey83     bool operator==(const CacheKey &other) const {
84       return std::tie(id_prover, id_verifier, password) ==
85              std::tie(other.id_prover, other.id_verifier, other.password);
86     }
87   };
88 
89   struct KeyHash {
operator ()__anone741afaf0111::RegistrationCache::KeyHash90     std::size_t operator()(const CacheKey &k) const {
91       return std::hash<std::string>()(k.id_prover) ^
92              std::hash<std::string>()(k.id_verifier) ^
93              std::hash<std::string>()(k.password);
94     }
95   };
96 
97   std::unordered_map<CacheKey, Result, KeyHash> cache;
98 };
99 
GlobalRegistrationCache()100 RegistrationCache &GlobalRegistrationCache() {
101   static RegistrationCache cache;
102   return cache;
103 }
104 
105 struct SPAKEPLUSRun {
Run__anone741afaf0111::SPAKEPLUSRun106   bool Run() {
107     const RegistrationCache::Result &registration =
108         GlobalRegistrationCache().Get(prover_names, pw);
109 
110     Prover prover;
111     if (!prover.Init(StringAsBytes(context), StringAsBytes(prover_names.first),
112                      StringAsBytes(prover_names.second), registration.w0,
113                      registration.w1)) {
114       return false;
115     }
116 
117     std::vector<uint8_t> verifier_registration_record = registration.record;
118     if (verifier_corrupt_record) {
119       verifier_registration_record[verifier_registration_record.size() - 1] ^=
120           0xFF;
121     }
122 
123     Verifier verifier;
124     if (!verifier.Init(StringAsBytes(context),
125                        StringAsBytes(verifier_names.first),
126                        StringAsBytes(verifier_names.second), registration.w0,
127                        verifier_registration_record)) {
128       return false;
129     }
130 
131     uint8_t prover_share[kShareSize];
132     if (!prover.GenerateShare(prover_share)) {
133       return false;
134     }
135 
136     if (repeat_invocations && prover.GenerateShare(prover_share)) {
137       return false;
138     }
139 
140     if (prover_corrupt_msg_bit &&
141         *prover_corrupt_msg_bit < 8 * sizeof(prover_share)) {
142       prover_share[*prover_corrupt_msg_bit / 8] ^=
143           1 << (*prover_corrupt_msg_bit & 7);
144     }
145 
146     uint8_t verifier_share[kShareSize];
147     uint8_t verifier_confirm[kConfirmSize];
148     uint8_t verifier_secret[kSecretSize];
149     if (!verifier.ProcessProverShare(verifier_share, verifier_confirm,
150                                      verifier_secret, prover_share)) {
151       return false;
152     }
153 
154     if (repeat_invocations &&
155         verifier.ProcessProverShare(verifier_share, verifier_confirm,
156                                     verifier_secret, prover_share)) {
157       return false;
158     }
159 
160     uint8_t prover_confirm[kConfirmSize];
161     uint8_t prover_secret[kSecretSize];
162     if (!prover.ComputeConfirmation(prover_confirm, prover_secret,
163                                     verifier_share, verifier_confirm)) {
164       return false;
165     }
166 
167     if (repeat_invocations &&  //
168         prover.ComputeConfirmation(prover_confirm, prover_secret,
169                                    verifier_share, verifier_confirm)) {
170       return false;
171     }
172 
173     if (prover_corrupt_confirm_bit &&
174         *prover_corrupt_confirm_bit < 8 * sizeof(prover_confirm)) {
175       prover_confirm[*prover_corrupt_confirm_bit / 8] ^=
176           1 << (*prover_corrupt_confirm_bit & 7);
177     }
178 
179     if (!verifier.VerifyProverConfirmation(prover_confirm)) {
180       return false;
181     }
182 
183     if (repeat_invocations &&
184         verifier.VerifyProverConfirmation(prover_confirm)) {
185       return false;
186     }
187 
188     key_matches_ = Span(prover_secret) == Span(verifier_secret);
189     return true;
190   }
191 
key_matches__anone741afaf0111::SPAKEPLUSRun192   bool key_matches() const { return key_matches_; }
193 
194   std::string context =
195       "SPAKE2+-P256-SHA256-HKDF-SHA256-HMAC-SHA256 Test Vectors";
196   std::string pw = "password";
197   std::pair<std::string, std::string> prover_names = {"client", "server"};
198   std::pair<std::string, std::string> verifier_names = {"client", "server"};
199   bool verifier_corrupt_record = false;
200   bool repeat_invocations = false;
201   std::optional<size_t> prover_corrupt_msg_bit;
202   std::optional<size_t> prover_corrupt_confirm_bit;
203 
204  private:
205   bool key_matches_ = false;
206 };
207 
TEST(SPAKEPLUSTest,TestVectors)208 TEST(SPAKEPLUSTest, TestVectors) {
209   // https://datatracker.ietf.org/doc/html/rfc9383#appendix-C
210   // SPAKE2+-P256-SHA256-HKDF-SHA256-HMAC-SHA256 Test Vectors
211   const char w0_str[] =
212       "bb8e1bbcf3c48f62c08db243652ae55d3e5586053fca77102994f23ad95491b3";
213   const char w1_str[] =
214       "7e945f34d78785b8a3ef44d0df5a1a97d6b3b460409a345ca7830387a74b1dba";
215   const char L_str[] =
216       "04eb7c9db3d9a9eb1f8adab81b5794c1f13ae3e225efbe91ea487425854c7fc00f00bfed"
217       "cbd09b2400142d40a14f2064ef31dfaa903b91d1faea7093d835966efd";
218   const char x_str[] =
219       "d1232c8e8693d02368976c174e2088851b8365d0d79a9eee709c6a05a2fad539";
220   const char share_p_str[] =
221       "04ef3bd051bf78a2234ec0df197f7828060fe9856503579bb1733009042c15c0c1de1277"
222       "27f418b5966afadfdd95a6e4591d171056b333dab97a79c7193e341727";
223   const char y_str[] =
224       "717a72348a182085109c8d3917d6c43d59b224dc6a7fc4f0483232fa6516d8b3";
225   const char share_v_str[] =
226       "04c0f65da0d11927bdf5d560c69e1d7d939a05b0e88291887d679fcadea75810fb5cc1ca"
227       "7494db39e82ff2f50665255d76173e09986ab46742c798a9a68437b048";
228   const char confirm_p_str[] =
229       "926cc713504b9b4d76c9162ded04b5493e89109f6d89462cd33adc46fda27527";
230   const char confirm_v_str[] =
231       "9747bcc4f8fe9f63defee53ac9b07876d907d55047e6ff2def2e7529089d3e68";
232   const char secret_str[] =
233       "0c5f8ccd1413423a54f6c1fb26ff01534a87f893779c6e68666d772bfd91f3e7";
234   const std::string context =
235       "SPAKE2+-P256-SHA256-HKDF-SHA256-HMAC-SHA256 Test Vectors";
236   const std::pair<std::string, std::string> prover_names = {"client", "server"};
237   const std::pair<std::string, std::string> verifier_names = {"client",
238                                                               "server"};
239 
240   std::vector<uint8_t> w0 = HexToBytes(w0_str);
241   std::vector<uint8_t> w1 = HexToBytes(w1_str);
242   std::vector<uint8_t> registration_record = HexToBytes(L_str);
243   std::vector<uint8_t> x = HexToBytes(x_str);
244   std::vector<uint8_t> y = HexToBytes(y_str);
245 
246   Prover prover;
247   ASSERT_TRUE(prover.Init(StringAsBytes(context),
248                           StringAsBytes(prover_names.first),
249                           StringAsBytes(prover_names.second), MakeConstSpan(w0),
250                           MakeConstSpan(w1), x));
251 
252   Verifier verifier;
253   ASSERT_TRUE(
254       verifier.Init(StringAsBytes(context), StringAsBytes(prover_names.first),
255                     StringAsBytes(prover_names.second), MakeConstSpan(w0),
256                     MakeConstSpan(registration_record), y));
257 
258   uint8_t prover_share[kShareSize];
259   ASSERT_TRUE(prover.GenerateShare(prover_share));
260 
261   std::vector<uint8_t> share_p = HexToBytes(share_p_str);
262   ASSERT_TRUE(
263       OPENSSL_memcmp(share_p.data(), prover_share, sizeof(prover_share)) == 0);
264 
265   uint8_t verifier_share[kShareSize];
266   uint8_t verifier_confirm[kConfirmSize];
267   uint8_t verifier_secret[kSecretSize];
268   ASSERT_TRUE(verifier.ProcessProverShare(verifier_share, verifier_confirm,
269                                           verifier_secret, prover_share));
270 
271   std::vector<uint8_t> share_v = HexToBytes(share_v_str);
272   ASSERT_TRUE(OPENSSL_memcmp(share_v.data(), verifier_share,
273                              sizeof(verifier_share)) == 0);
274   std::vector<uint8_t> confirm_v = HexToBytes(confirm_v_str);
275   ASSERT_TRUE(OPENSSL_memcmp(confirm_v.data(), verifier_confirm,
276                              sizeof(verifier_confirm)) == 0);
277 
278   uint8_t prover_confirm[kConfirmSize];
279   uint8_t prover_secret[kSecretSize];
280   ASSERT_TRUE(prover.ComputeConfirmation(prover_confirm, prover_secret,
281                                          verifier_share, verifier_confirm));
282 
283   std::vector<uint8_t> confirm_p = HexToBytes(confirm_p_str);
284   ASSERT_TRUE(OPENSSL_memcmp(confirm_p.data(), prover_confirm,
285                              sizeof(prover_confirm)) == 0);
286 
287   ASSERT_TRUE(verifier.VerifyProverConfirmation(prover_confirm));
288 
289   std::vector<uint8_t> expected_secret = HexToBytes(secret_str);
290   static_assert(sizeof(verifier_secret) == sizeof(prover_secret));
291   ASSERT_TRUE(OPENSSL_memcmp(prover_secret, verifier_secret,
292                              sizeof(prover_secret)) == 0);
293   ASSERT_TRUE(OPENSSL_memcmp(expected_secret.data(), verifier_secret,
294                              sizeof(verifier_secret)) == 0);
295 }
296 
TEST(SPAKEPLUSTest, SPAKEPLUS) {
  // Repeats an honest exchange with the default parameters. Every run must
  // succeed, and both parties must derive the same key.
  constexpr unsigned kIterations = 20;
  for (unsigned iteration = 0; iteration < kIterations; ++iteration) {
    SPAKEPLUSRun spake2;
    ASSERT_TRUE(spake2.Run());
    EXPECT_TRUE(spake2.key_matches());
  }
}
304 
TEST(SPAKEPLUSTest, WrongPassword) {
  // Corrupts the last byte of the registration record handed to the verifier.
  // Per the test name, this stands in for a verifier provisioned with a
  // different password. The exchange must not complete.
  SPAKEPLUSRun spake2;
  spake2.verifier_corrupt_record = true;

  ASSERT_FALSE(spake2.Run());
}
310 
TEST(SPAKEPLUSTest, WrongNames) {
  // The prover is set up with the verifier identity "alice" while the verifier
  // uses "bob". Because the parties disagree on the identities, the exchange
  // must fail.
  SPAKEPLUSRun spake2;
  spake2.verifier_names.second = "bob";
  spake2.prover_names.second = "alice";

  ASSERT_FALSE(spake2.Run());
}
317 
TEST(SPAKEPLUSTest, CorruptMessages) {
  // Flipping any single bit of the prover's key share must make the exchange
  // fail.
  for (size_t i = 0; i < 8 * kShareSize; i++) {
    SPAKEPLUSRun spake2;
    spake2.prover_corrupt_msg_bit = i;
    EXPECT_FALSE(spake2.Run())
        << "Passed after corrupting Prover's key share message, bit " << i;
  }

  // Likewise for every bit of the prover's confirmation message, which is
  // checked by the verifier. (The failure message previously misattributed
  // this message to the Verifier.)
  for (size_t i = 0; i < 8 * kConfirmSize; i++) {
    SPAKEPLUSRun spake2;
    spake2.prover_corrupt_confirm_bit = i;
    EXPECT_FALSE(spake2.Run())
        << "Passed after corrupting Prover's confirmation message, bit " << i;
  }
}
333 
TEST(SPAKEPLUSTest, StateMachine) {
  // With |repeat_invocations| set, Run calls every protocol step a second time
  // and fails if any repeated call succeeds. This checks that each step can
  // only be performed once per object.
  SPAKEPLUSRun spake2;
  spake2.repeat_invocations = true;

  ASSERT_TRUE(spake2.Run());
}
339 
340 }  // namespace
341 
342 BSSL_NAMESPACE_END
343