1package subprocess
2
3import (
4	"encoding/hex"
5	"encoding/json"
6	"fmt"
7	"strings"
8)
9
// JSON wire types for ML-KEM "keyGen" vector sets and their responses. The
// field order fixes the key order of marshaled responses and the tags fix the
// key names, so neither may change without changing the encoded output.
type (
	// mlkemTestVectorSet holds only the top-level fields shared by every
	// ML-KEM vector set. Process decodes into it first so it can switch on
	// Mode before parsing the mode-specific structure.
	mlkemTestVectorSet struct {
		Algorithm string `json:"algorithm"`
		Mode      string `json:"mode"`
		Revision  string `json:"revision"`
	}

	// mlkemKeyGenTestVectorSet is a complete vector set in "keyGen" mode.
	mlkemKeyGenTestVectorSet struct {
		Algorithm string                 `json:"algorithm"`
		Mode      string                 `json:"mode"`
		Revision  string                 `json:"revision"`
		Groups    []mlkemKeyGenTestGroup `json:"testGroups"`
	}

	// mlkemKeyGenTestGroup is one keyGen test group. ParameterSet must start
	// with "ML-KEM-" and is used to build the module command name.
	mlkemKeyGenTestGroup struct {
		ID           uint64            `json:"tgId"`
		TestType     string            `json:"testType"`
		ParameterSet string            `json:"parameterSet"`
		Tests        []mlkemKeyGenTest `json:"tests"`
	}

	// mlkemKeyGenTest is a single keyGen test case. Z and D are hex strings;
	// they are decoded and concatenated as d||z to form the module's seed.
	mlkemKeyGenTest struct {
		ID uint64 `json:"tcId"`
		Z  string `json:"z"`
		D  string `json:"d"`
	}

	// mlkemKeyGenTestGroupResponse collects the results for one keyGen group.
	mlkemKeyGenTestGroupResponse struct {
		ID    uint64                    `json:"tgId"`
		Tests []mlkemKeyGenTestResponse `json:"tests"`
	}

	// mlkemKeyGenTestResponse is the result of one keyGen test case: the two
	// values returned by the module, hex-encoded as ek and dk.
	mlkemKeyGenTestResponse struct {
		ID uint64 `json:"tcId"`
		EK string `json:"ek"`
		DK string `json:"dk"`
	}
)
48
// JSON wire types for ML-KEM "encapDecap" vector sets and their responses.
// The field order fixes the key order of marshaled responses, and the
// omitempty tags let one struct describe both encapsulation and
// decapsulation cases.
type (
	// mlkemEncapDecapTestVectorSet is a complete vector set in "encapDecap" mode.
	mlkemEncapDecapTestVectorSet struct {
		Algorithm string                     `json:"algorithm"`
		Mode      string                     `json:"mode"`
		Revision  string                     `json:"revision"`
		Groups    []mlkemEncapDecapTestGroup `json:"testGroups"`
	}

	// mlkemEncapDecapTestGroup is one encapDecap test group. Function selects
	// "encapsulation" or "decapsulation". DK (hex) is read only by
	// decapsulation groups, where it is shared by every test case.
	mlkemEncapDecapTestGroup struct {
		ID           uint64                `json:"tgId"`
		TestType     string                `json:"testType"`
		ParameterSet string                `json:"parameterSet"`
		Function     string                `json:"function"`
		DK           string                `json:"dk,omitempty"`
		Tests        []mlkemEncapDecapTest `json:"tests"`
	}

	// mlkemEncapDecapTest is a single test case. Encapsulation cases use EK
	// and M; decapsulation cases use C. All values are hex strings.
	mlkemEncapDecapTest struct {
		ID uint64 `json:"tcId"`
		EK string `json:"ek,omitempty"`
		M  string `json:"m,omitempty"`
		C  string `json:"c,omitempty"`
	}

	// mlkemEncapDecapTestGroupResponse collects the results for one group.
	mlkemEncapDecapTestGroupResponse struct {
		ID    uint64                        `json:"tgId"`
		Tests []mlkemEncapDecapTestResponse `json:"tests"`
	}

	// mlkemEncapDecapTestResponse is the hex-encoded result of one test case.
	// Encapsulation fills both C and K. Decapsulation fills only K, so the
	// empty C is dropped from its JSON.
	mlkemEncapDecapTestResponse struct {
		ID uint64 `json:"tcId"`
		C  string `json:"c,omitempty"`
		K  string `json:"k,omitempty"`
	}
)
82
83type mlkem struct{}
84
85func (m *mlkem) Process(vectorSet []byte, t Transactable) (any, error) {
86	var common mlkemTestVectorSet
87	if err := json.Unmarshal(vectorSet, &common); err != nil {
88		return nil, fmt.Errorf("failed to unmarshal vector set: %v", err)
89	}
90
91	switch common.Mode {
92	case "keyGen":
93		return m.processKeyGen(vectorSet, t)
94	case "encapDecap":
95		return m.processEncapDecap(vectorSet, t)
96	default:
97		return nil, fmt.Errorf("unsupported ML-KEM mode: %q", common.Mode)
98	}
99}
100
101func (m *mlkem) processKeyGen(vectorSet []byte, t Transactable) (any, error) {
102	var parsed mlkemKeyGenTestVectorSet
103	if err := json.Unmarshal(vectorSet, &parsed); err != nil {
104		return nil, fmt.Errorf("failed to unmarshal keyGen vector set: %v", err)
105	}
106
107	var ret []mlkemKeyGenTestGroupResponse
108
109	for _, group := range parsed.Groups {
110		response := mlkemKeyGenTestGroupResponse{
111			ID: group.ID,
112		}
113
114		if !strings.HasPrefix(group.ParameterSet, "ML-KEM-") {
115			return nil, fmt.Errorf("invalid parameter set: %s", group.ParameterSet)
116		}
117		cmdName := group.ParameterSet + "/keyGen"
118
119		for _, test := range group.Tests {
120			// Concatenate d and z to form the seed
121			dBytes, err := hex.DecodeString(test.D)
122			if err != nil {
123				return nil, fmt.Errorf("failed to decode d in test case %d/%d: %s",
124					group.ID, test.ID, err)
125			}
126			zBytes, err := hex.DecodeString(test.Z)
127			if err != nil {
128				return nil, fmt.Errorf("failed to decode z in test case %d/%d: %s",
129					group.ID, test.ID, err)
130			}
131
132			seed := make([]byte, len(dBytes)+len(zBytes))
133			copy(seed, dBytes)
134			copy(seed[len(dBytes):], zBytes)
135
136			result, err := t.Transact(cmdName, 2, seed)
137			if err != nil {
138				return nil, fmt.Errorf("key generation failed for test case %d/%d: %s",
139					group.ID, test.ID, err)
140			}
141
142			response.Tests = append(response.Tests, mlkemKeyGenTestResponse{
143				ID: test.ID,
144				EK: hex.EncodeToString(result[0]),
145				DK: hex.EncodeToString(result[1]),
146			})
147		}
148
149		ret = append(ret, response)
150	}
151
152	return ret, nil
153}
154
155func (m *mlkem) processEncapDecap(vectorSet []byte, t Transactable) (any, error) {
156	var parsed mlkemEncapDecapTestVectorSet
157	if err := json.Unmarshal(vectorSet, &parsed); err != nil {
158		return nil, fmt.Errorf("failed to unmarshal encapDecap vector set: %v", err)
159	}
160
161	var ret []mlkemEncapDecapTestGroupResponse
162
163	for _, group := range parsed.Groups {
164		response := mlkemEncapDecapTestGroupResponse{
165			ID: group.ID,
166		}
167
168		if !strings.HasPrefix(group.ParameterSet, "ML-KEM-") {
169			return nil, fmt.Errorf("invalid parameter set: %s", group.ParameterSet)
170		}
171
172		switch group.Function {
173		case "encapsulation":
174			cmdName := group.ParameterSet + "/encap"
175			for _, test := range group.Tests {
176				ek, err := hex.DecodeString(test.EK)
177				if err != nil {
178					return nil, fmt.Errorf("failed to decode ek in test case %d/%d: %s",
179						group.ID, test.ID, err)
180				}
181
182				m, err := hex.DecodeString(test.M)
183				if err != nil {
184					return nil, fmt.Errorf("failed to decode m in test case %d/%d: %s",
185						group.ID, test.ID, err)
186				}
187
188				result, err := t.Transact(cmdName, 2, ek, m)
189				if err != nil {
190					return nil, fmt.Errorf("encapsulation failed for test case %d/%d: %s",
191						group.ID, test.ID, err)
192				}
193
194				response.Tests = append(response.Tests, mlkemEncapDecapTestResponse{
195					ID: test.ID,
196					C:  hex.EncodeToString(result[0]),
197					K:  hex.EncodeToString(result[1]),
198				})
199			}
200
201		case "decapsulation":
202			cmdName := group.ParameterSet + "/decap"
203			dk, err := hex.DecodeString(group.DK)
204			if err != nil {
205				return nil, fmt.Errorf("failed to decode dk in group %d: %s",
206					group.ID, err)
207			}
208
209			for _, test := range group.Tests {
210				c, err := hex.DecodeString(test.C)
211				if err != nil {
212					return nil, fmt.Errorf("failed to decode c in test case %d/%d: %s",
213						group.ID, test.ID, err)
214				}
215
216				result, err := t.Transact(cmdName, 1, dk, c)
217				if err != nil {
218					return nil, fmt.Errorf("decapsulation failed for test case %d/%d: %s",
219						group.ID, test.ID, err)
220				}
221
222				response.Tests = append(response.Tests, mlkemEncapDecapTestResponse{
223					ID: test.ID,
224					K:  hex.EncodeToString(result[0]),
225				})
226			}
227
228		default:
229			return nil, fmt.Errorf("unsupported function: %s", group.Function)
230		}
231
232		ret = append(ret, response)
233	}
234
235	return ret, nil
236}
237