1 /*
2  * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9 
#include <assert.h>
#include <string.h>
#include <openssl/crypto.h>
#include "ml_dsa_poly.h"
12 
/*
 * A vector of polynomials: |num_poly| consecutive POLY elements starting at
 * |poly|. Ownership of |poly| depends on how the vector was set up:
 * vector_alloc()/vector_secure_alloc() allocate it, while vector_init()
 * only records caller-provided storage that the vector does not own.
 */
struct vector_st {
    POLY *poly;         /* first polynomial of the array; NULL when empty */
    size_t num_poly;    /* number of polynomials in the vector (k or l) */
};
17 
18 /**
19  * @brief Initialize a Vector object.
20  *
21  * @param v The vector to initialize.
22  * @param polys Preallocated storage for an array of Polynomials blocks. |v|
23  *              does not own/free this.
24  * @param num_polys The number of |polys| blocks (k or l)
25  */
26 static ossl_inline ossl_unused
vector_init(VECTOR * v,POLY * polys,size_t num_polys)27 void vector_init(VECTOR *v, POLY *polys, size_t num_polys)
28 {
29     v->poly = polys;
30     v->num_poly = num_polys;
31 }
32 
33 static ossl_inline ossl_unused
vector_alloc(VECTOR * v,size_t num_polys)34 int vector_alloc(VECTOR *v, size_t num_polys)
35 {
36     v->poly = OPENSSL_malloc_array(num_polys, sizeof(POLY));
37     if (v->poly == NULL)
38         return 0;
39     v->num_poly = num_polys;
40     return 1;
41 }
42 
43 static ossl_inline ossl_unused
vector_secure_alloc(VECTOR * v,size_t num_polys)44 int vector_secure_alloc(VECTOR *v, size_t num_polys)
45 {
46     v->poly = OPENSSL_secure_malloc_array(num_polys, sizeof(POLY));
47     if (v->poly == NULL)
48         return 0;
49     v->num_poly = num_polys;
50     return 1;
51 }
52 
53 static ossl_inline ossl_unused
vector_free(VECTOR * v)54 void vector_free(VECTOR *v)
55 {
56     OPENSSL_free(v->poly);
57     v->poly = NULL;
58     v->num_poly = 0;
59 }
60 
61 static ossl_inline ossl_unused
vector_secure_free(VECTOR * v,size_t rank)62 void vector_secure_free(VECTOR *v, size_t rank)
63 {
64     OPENSSL_secure_clear_free(v->poly, rank * sizeof(POLY));
65     v->poly = NULL;
66     v->num_poly = 0;
67 }
68 
69 /* @brief zeroize a vectors polynomial coefficients */
70 static ossl_inline ossl_unused
vector_zero(VECTOR * va)71 void vector_zero(VECTOR *va)
72 {
73     if (va->poly != NULL)
74         memset(va->poly, 0, va->num_poly * sizeof(va->poly[0]));
75 }
76 
77 /*
78  * @brief copy a vector
79  * The assumption is that |dst| has already been initialized
80  */
81 static ossl_inline ossl_unused void
vector_copy(VECTOR * dst,const VECTOR * src)82 vector_copy(VECTOR *dst, const VECTOR *src)
83 {
84     assert(dst->num_poly == src->num_poly);
85     memcpy(dst->poly, src->poly, src->num_poly * sizeof(src->poly[0]));
86 }
87 
88 /* @brief return 1 if 2 vectors are equal, or 0 otherwise */
89 static ossl_inline ossl_unused int
vector_equal(const VECTOR * a,const VECTOR * b)90 vector_equal(const VECTOR *a, const VECTOR *b)
91 {
92     size_t i;
93 
94     if (a->num_poly != b->num_poly)
95         return 0;
96     for (i = 0; i < a->num_poly; ++i) {
97         if (!poly_equal(a->poly + i, b->poly + i))
98             return 0;
99     }
100     return 1;
101 }
102 
103 /* @brief add 2 vectors */
104 static ossl_inline ossl_unused void
vector_add(const VECTOR * lhs,const VECTOR * rhs,VECTOR * out)105 vector_add(const VECTOR *lhs, const VECTOR *rhs, VECTOR *out)
106 {
107     size_t i;
108 
109     for (i = 0; i < lhs->num_poly; i++)
110         poly_add(lhs->poly + i, rhs->poly + i, out->poly + i);
111 }
112 
113 /* @brief subtract 2 vectors */
114 static ossl_inline ossl_unused void
vector_sub(const VECTOR * lhs,const VECTOR * rhs,VECTOR * out)115 vector_sub(const VECTOR *lhs, const VECTOR *rhs, VECTOR *out)
116 {
117     size_t i;
118 
119     for (i = 0; i < lhs->num_poly; i++)
120         poly_sub(lhs->poly + i, rhs->poly + i, out->poly + i);
121 }
122 
123 /* @brief convert a vector in place into NTT form */
124 static ossl_inline ossl_unused void
vector_ntt(VECTOR * va)125 vector_ntt(VECTOR *va)
126 {
127     size_t i;
128 
129     for (i = 0; i < va->num_poly; i++)
130         ossl_ml_dsa_poly_ntt(va->poly + i);
131 }
132 
133 /* @brief convert a vector in place into inverse NTT form */
134 static ossl_inline ossl_unused void
vector_ntt_inverse(VECTOR * va)135 vector_ntt_inverse(VECTOR *va)
136 {
137     size_t i;
138 
139     for (i = 0; i < va->num_poly; i++)
140         ossl_ml_dsa_poly_ntt_inverse(va->poly + i);
141 }
142 
143 /* @brief multiply a vector by a SCALAR polynomial */
144 static ossl_inline ossl_unused void
vector_mult_scalar(const VECTOR * lhs,const POLY * rhs,VECTOR * out)145 vector_mult_scalar(const VECTOR *lhs, const POLY *rhs, VECTOR *out)
146 {
147     size_t i;
148 
149     for (i = 0; i < lhs->num_poly; i++)
150         ossl_ml_dsa_poly_ntt_mult(lhs->poly + i, rhs, out->poly + i);
151 }
152 
153 static ossl_inline ossl_unused int
vector_expand_S(EVP_MD_CTX * h_ctx,const EVP_MD * md,int eta,const uint8_t * seed,VECTOR * s1,VECTOR * s2)154 vector_expand_S(EVP_MD_CTX *h_ctx, const EVP_MD *md, int eta,
155                 const uint8_t *seed, VECTOR *s1, VECTOR *s2)
156 {
157     return ossl_ml_dsa_vector_expand_S(h_ctx, md, eta, seed, s1, s2);
158 }
159 
160 static ossl_inline ossl_unused void
vector_expand_mask(VECTOR * out,const uint8_t * rho_prime,size_t rho_prime_len,uint32_t kappa,uint32_t gamma1,EVP_MD_CTX * h_ctx,const EVP_MD * md)161 vector_expand_mask(VECTOR *out, const uint8_t *rho_prime, size_t rho_prime_len,
162                    uint32_t kappa, uint32_t gamma1,
163                    EVP_MD_CTX *h_ctx, const EVP_MD *md)
164 {
165     size_t i;
166     uint8_t derived_seed[ML_DSA_RHO_PRIME_BYTES + 2];
167 
168     memcpy(derived_seed, rho_prime, ML_DSA_RHO_PRIME_BYTES);
169 
170     for (i = 0; i < out->num_poly; i++) {
171         size_t index = kappa + i;
172 
173         derived_seed[ML_DSA_RHO_PRIME_BYTES] = index & 0xFF;
174         derived_seed[ML_DSA_RHO_PRIME_BYTES + 1] = (index >> 8) & 0xFF;
175         poly_expand_mask(out->poly + i, derived_seed, sizeof(derived_seed),
176                          gamma1, h_ctx, md);
177     }
178 }
179 
180 /* Scale back previously rounded value */
181 static ossl_inline ossl_unused void
vector_scale_power2_round_ntt(const VECTOR * in,VECTOR * out)182 vector_scale_power2_round_ntt(const VECTOR *in, VECTOR *out)
183 {
184     size_t i;
185 
186     for (i = 0; i < in->num_poly; i++)
187         poly_scale_power2_round(in->poly + i, out->poly + i);
188     vector_ntt(out);
189 }
190 
191 /*
192  * @brief Decompose all polynomial coefficients of a vector into (t1, t0) such
193  * that coeff[i] == t1[i] * 2^13 + t0[i] mod q.
194  * See FIPS 204, Algorithm 35, Power2Round()
195  */
196 static ossl_inline ossl_unused void
vector_power2_round(const VECTOR * t,VECTOR * t1,VECTOR * t0)197 vector_power2_round(const VECTOR *t, VECTOR *t1, VECTOR *t0)
198 {
199     size_t i;
200 
201     for (i = 0; i < t->num_poly; i++)
202         poly_power2_round(t->poly + i, t1->poly + i, t0->poly + i);
203 }
204 
205 static ossl_inline ossl_unused void
vector_high_bits(const VECTOR * in,uint32_t gamma2,VECTOR * out)206 vector_high_bits(const VECTOR *in, uint32_t gamma2, VECTOR *out)
207 {
208     size_t i;
209 
210     for (i = 0; i < out->num_poly; i++)
211         poly_high_bits(in->poly + i, gamma2, out->poly + i);
212 }
213 
214 static ossl_inline ossl_unused void
vector_low_bits(const VECTOR * in,uint32_t gamma2,VECTOR * out)215 vector_low_bits(const VECTOR *in, uint32_t gamma2, VECTOR *out)
216 {
217     size_t i;
218 
219     for (i = 0; i < out->num_poly; i++)
220         poly_low_bits(in->poly + i, gamma2, out->poly + i);
221 }
222 
223 static ossl_inline ossl_unused uint32_t
vector_max(const VECTOR * v)224 vector_max(const VECTOR *v)
225 {
226     size_t i;
227     uint32_t mx = 0;
228 
229     for (i = 0; i < v->num_poly; i++)
230         poly_max(v->poly + i, &mx);
231     return mx;
232 }
233 
234 static ossl_inline ossl_unused uint32_t
vector_max_signed(const VECTOR * v)235 vector_max_signed(const VECTOR *v)
236 {
237     size_t i;
238     uint32_t mx = 0;
239 
240     for (i = 0; i < v->num_poly; i++)
241         poly_max_signed(v->poly + i, &mx);
242     return mx;
243 }
244 
245 static ossl_inline ossl_unused size_t
vector_count_ones(const VECTOR * v)246 vector_count_ones(const VECTOR *v)
247 {
248     int j;
249     size_t i, count = 0;
250 
251     for (i = 0; i < v->num_poly; i++)
252         for (j = 0; j < ML_DSA_NUM_POLY_COEFFICIENTS; j++)
253             count += v->poly[i].coeff[j];
254     return count;
255 }
256 
257 static ossl_inline ossl_unused void
vector_make_hint(const VECTOR * ct0,const VECTOR * cs2,const VECTOR * w,uint32_t gamma2,VECTOR * out)258 vector_make_hint(const VECTOR *ct0, const VECTOR *cs2, const VECTOR *w,
259                  uint32_t gamma2, VECTOR *out)
260 {
261     size_t i;
262 
263     for (i = 0; i < out->num_poly; i++)
264         poly_make_hint(ct0->poly + i, cs2->poly + i, w->poly + i, gamma2,
265                        out->poly + i);
266 }
267 
268 static ossl_inline ossl_unused void
vector_use_hint(const VECTOR * h,const VECTOR * r,uint32_t gamma2,VECTOR * out)269 vector_use_hint(const VECTOR *h, const VECTOR *r, uint32_t gamma2, VECTOR *out)
270 {
271     size_t i;
272 
273     for (i = 0; i < out->num_poly; i++)
274         poly_use_hint(h->poly + i, r->poly + i, gamma2, out->poly + i);
275 }
276