1package subprocess
2
3import (
4	"encoding/hex"
5	"encoding/json"
6	"fmt"
7	"strings"
8)
9
const (
	// MLDSARandomizerLength is the length, in bytes, of the signing
	// randomizer handed to ML-DSA signature generation. Non-deterministic
	// sigGen cases must supply a randomizer of exactly this length;
	// deterministic ones are signed with this many zero bytes instead.
	MLDSARandomizerLength = 32
)
11
// mldsaTestVectorSet holds only the header fields shared by every ML-DSA
// vector set. Process decodes a vector set into it first and switches on
// Mode to pick the mode-specific structure used for the full decode.
type (
	mldsaTestVectorSet struct {
		Algorithm string `json:"algorithm"`
		Mode      string `json:"mode"`
		Revision  string `json:"revision"`
	}
)
18
// Wire types for ML-DSA key generation ("keyGen" mode): the request as
// decoded from the vector-set JSON, and the per-group response produced by
// processKeyGen. Field order and JSON tags define the wire format.
type (
	// mldsaKeyGenTestVectorSet is a complete keyGen vector set.
	mldsaKeyGenTestVectorSet struct {
		Algorithm string                 `json:"algorithm"`
		Mode      string                 `json:"mode"`
		Revision  string                 `json:"revision"`
		Groups    []mldsaKeyGenTestGroup `json:"testGroups"`
	}

	// mldsaKeyGenTestGroup is one keyGen test group. ParameterSet must begin
	// with "ML-DSA-" and is used as the prefix of the command name.
	mldsaKeyGenTestGroup struct {
		ID           uint64            `json:"tgId"`
		TestType     string            `json:"testType"`
		ParameterSet string            `json:"parameterSet"`
		Tests        []mldsaKeyGenTest `json:"tests"`
	}

	// mldsaKeyGenTest is a single keyGen case; Seed is a hex string.
	mldsaKeyGenTest struct {
		ID   uint64 `json:"tcId"`
		Seed string `json:"seed"`
	}

	// mldsaKeyGenTestGroupResponse collects the results for one test group.
	mldsaKeyGenTestGroupResponse struct {
		ID    uint64                    `json:"tgId"`
		Tests []mldsaKeyGenTestResponse `json:"tests"`
	}

	// mldsaKeyGenTestResponse carries the hex-encoded key pair for one case.
	mldsaKeyGenTestResponse struct {
		ID         uint64 `json:"tcId"`
		PublicKey  string `json:"pk"`
		PrivateKey string `json:"sk"`
	}
)
49
// Wire types for ML-DSA signature generation ("sigGen" mode): the request as
// decoded from the vector-set JSON, and the per-group response produced by
// processSigGen. Field order and JSON tags define the wire format.
type (
	// mldsaSigGenTestVectorSet is a complete sigGen vector set.
	mldsaSigGenTestVectorSet struct {
		Algorithm string                 `json:"algorithm"`
		Mode      string                 `json:"mode"`
		Revision  string                 `json:"revision"`
		Groups    []mldsaSigGenTestGroup `json:"testGroups"`
	}

	// mldsaSigGenTestGroup is one sigGen test group. When Deterministic is
	// true, processSigGen ignores each case's Randomizer and signs with an
	// all-zero randomizer instead.
	mldsaSigGenTestGroup struct {
		ID            uint64            `json:"tgId"`
		TestType      string            `json:"testType"`
		ParameterSet  string            `json:"parameterSet"`
		Deterministic bool              `json:"deterministic"`
		Tests         []mldsaSigGenTest `json:"tests"`
	}

	// mldsaSigGenTest is a single sigGen case; Message, PrivateKey and
	// Randomizer are hex strings.
	mldsaSigGenTest struct {
		ID         uint64 `json:"tcId"`
		Message    string `json:"message"`
		PrivateKey string `json:"sk"`
		Randomizer string `json:"rnd"`
	}

	// mldsaSigGenTestGroupResponse collects the results for one test group.
	mldsaSigGenTestGroupResponse struct {
		ID    uint64                    `json:"tgId"`
		Tests []mldsaSigGenTestResponse `json:"tests"`
	}

	// mldsaSigGenTestResponse carries the hex-encoded signature for one case.
	mldsaSigGenTestResponse struct {
		ID        uint64 `json:"tcId"`
		Signature string `json:"signature"`
	}
)
82
// Wire types for ML-DSA signature verification ("sigVer" mode): the request
// as decoded from the vector-set JSON, and the per-group response produced by
// processSigVer. Field order and JSON tags define the wire format.
type (
	// mldsaSigVerTestVectorSet is a complete sigVer vector set.
	mldsaSigVerTestVectorSet struct {
		Algorithm string                 `json:"algorithm"`
		Mode      string                 `json:"mode"`
		Revision  string                 `json:"revision"`
		Groups    []mldsaSigVerTestGroup `json:"testGroups"`
	}

	// mldsaSigVerTestGroup is one sigVer test group; ParameterSet must begin
	// with "ML-DSA-" and is used as the prefix of the command name.
	mldsaSigVerTestGroup struct {
		ID           uint64            `json:"tgId"`
		TestType     string            `json:"testType"`
		ParameterSet string            `json:"parameterSet"`
		Tests        []mldsaSigVerTest `json:"tests"`
	}

	// mldsaSigVerTest is a single sigVer case; PublicKey, Message and
	// Signature are hex strings.
	mldsaSigVerTest struct {
		ID        uint64 `json:"tcId"`
		PublicKey string `json:"pk"`
		Message   string `json:"message"`
		Signature string `json:"signature"`
	}

	// mldsaSigVerTestGroupResponse collects the results for one test group.
	mldsaSigVerTestGroupResponse struct {
		ID    uint64                    `json:"tgId"`
		Tests []mldsaSigVerTestResponse `json:"tests"`
	}

	// mldsaSigVerTestResponse records whether the signature for one case was
	// accepted.
	mldsaSigVerTestResponse struct {
		ID         uint64 `json:"tcId"`
		TestPassed bool   `json:"testPassed"`
	}
)
114
115type mldsa struct{}
116
117func (m *mldsa) Process(vectorSet []byte, t Transactable) (any, error) {
118	// First parse just the common fields to get the mode
119	var common mldsaTestVectorSet
120	if err := json.Unmarshal(vectorSet, &common); err != nil {
121		return nil, fmt.Errorf("failed to unmarshal vector set: %v", err)
122	}
123
124	switch common.Mode {
125	case "keyGen":
126		return m.processKeyGen(vectorSet, t)
127	case "sigGen":
128		return m.processSigGen(vectorSet, t)
129	case "sigVer":
130		return m.processSigVer(vectorSet, t)
131	default:
132		return nil, fmt.Errorf("unsupported ML-DSA mode: %s", common.Mode)
133	}
134}
135
136func (m *mldsa) processKeyGen(vectorSet []byte, t Transactable) (any, error) {
137	var parsed mldsaKeyGenTestVectorSet
138	if err := json.Unmarshal(vectorSet, &parsed); err != nil {
139		return nil, fmt.Errorf("failed to unmarshal keyGen vector set: %v", err)
140	}
141
142	var ret []mldsaKeyGenTestGroupResponse
143
144	for _, group := range parsed.Groups {
145		response := mldsaKeyGenTestGroupResponse{
146			ID: group.ID,
147		}
148
149		if !strings.HasPrefix(group.ParameterSet, "ML-DSA-") {
150			return nil, fmt.Errorf("invalid parameter set: %s", group.ParameterSet)
151		}
152		cmdName := group.ParameterSet + "/keyGen"
153
154		for _, test := range group.Tests {
155			seed, err := hex.DecodeString(test.Seed)
156			if err != nil {
157				return nil, fmt.Errorf("failed to decode seed in test case %d/%d: %s",
158					group.ID, test.ID, err)
159			}
160
161			result, err := t.Transact(cmdName, 2, seed)
162			if err != nil {
163				return nil, fmt.Errorf("key generation failed for test case %d/%d: %s",
164					group.ID, test.ID, err)
165			}
166
167			response.Tests = append(response.Tests, mldsaKeyGenTestResponse{
168				ID:         test.ID,
169				PublicKey:  hex.EncodeToString(result[0]),
170				PrivateKey: hex.EncodeToString(result[1]),
171			})
172		}
173
174		ret = append(ret, response)
175	}
176
177	return ret, nil
178}
179
180func (m *mldsa) processSigGen(vectorSet []byte, t Transactable) (any, error) {
181	var parsed mldsaSigGenTestVectorSet
182	if err := json.Unmarshal(vectorSet, &parsed); err != nil {
183		return nil, fmt.Errorf("failed to unmarshal sigGen vector set: %v", err)
184	}
185
186	var ret []mldsaSigGenTestGroupResponse
187
188	for _, group := range parsed.Groups {
189		response := mldsaSigGenTestGroupResponse{
190			ID: group.ID,
191		}
192
193		if !strings.HasPrefix(group.ParameterSet, "ML-DSA-") {
194			return nil, fmt.Errorf("invalid parameter set: %s", group.ParameterSet)
195		}
196		cmdName := group.ParameterSet + "/sigGen"
197
198		for _, test := range group.Tests {
199			sk, err := hex.DecodeString(test.PrivateKey)
200			if err != nil {
201				return nil, fmt.Errorf("failed to decode private key in test case %d/%d: %s",
202					group.ID, test.ID, err)
203			}
204
205			msg, err := hex.DecodeString(test.Message)
206			if err != nil {
207				return nil, fmt.Errorf("failed to decode message in test case %d/%d: %s",
208					group.ID, test.ID, err)
209			}
210
211			var randomizer []byte
212			if group.Deterministic {
213				randomizer = make([]byte, MLDSARandomizerLength)
214			} else {
215				randomizer, err = hex.DecodeString(test.Randomizer)
216				if err != nil || len(randomizer) != MLDSARandomizerLength {
217					return nil, fmt.Errorf("failed to parse randomizer in test case %d/%d: %s", group.ID, test.ID, err)
218				}
219			}
220
221			result, err := t.Transact(cmdName, 1, sk, msg, randomizer)
222			if err != nil {
223				return nil, fmt.Errorf("signature generation failed for test case %d/%d: %s",
224					group.ID, test.ID, err)
225			}
226
227			response.Tests = append(response.Tests, mldsaSigGenTestResponse{
228				ID:        test.ID,
229				Signature: hex.EncodeToString(result[0]),
230			})
231		}
232
233		ret = append(ret, response)
234	}
235
236	return ret, nil
237}
238
239func (m *mldsa) processSigVer(vectorSet []byte, t Transactable) (any, error) {
240	var parsed mldsaSigVerTestVectorSet
241	if err := json.Unmarshal(vectorSet, &parsed); err != nil {
242		return nil, fmt.Errorf("failed to unmarshal sigVer vector set: %v", err)
243	}
244
245	var ret []mldsaSigVerTestGroupResponse
246
247	for _, group := range parsed.Groups {
248		response := mldsaSigVerTestGroupResponse{
249			ID: group.ID,
250		}
251
252		if !strings.HasPrefix(group.ParameterSet, "ML-DSA-") {
253			return nil, fmt.Errorf("invalid parameter set: %s", group.ParameterSet)
254		}
255		cmdName := group.ParameterSet + "/sigVer"
256
257		for _, test := range group.Tests {
258			pk, err := hex.DecodeString(test.PublicKey)
259			if err != nil || len(pk) == 0 {
260				return nil, fmt.Errorf("failed to decode public key in test case %d/%d: %s",
261					group.ID, test.ID, err)
262			}
263
264			msg, err := hex.DecodeString(test.Message)
265			if err != nil {
266				return nil, fmt.Errorf("failed to decode message in test case %d/%d: %s",
267					group.ID, test.ID, err)
268			}
269
270			sig, err := hex.DecodeString(test.Signature)
271			if err != nil {
272				return nil, fmt.Errorf("failed to decode signature in test case %d/%d: %s",
273					group.ID, test.ID, err)
274			}
275
276			result, err := t.Transact(cmdName, 1, pk, msg, sig)
277			if err != nil {
278				return nil, fmt.Errorf("signature verification failed for test case %d/%d: %s",
279					group.ID, test.ID, err)
280			}
281
282			// Result is a single byte: 0 for false, non-zero for true
283			testPassed := result[0][0] != 0
284			response.Tests = append(response.Tests, mldsaSigVerTestResponse{
285				ID:         test.ID,
286				TestPassed: testPassed,
287			})
288		}
289
290		ret = append(ret, response)
291	}
292
293	return ret, nil
294}
295