1// Copyright 2021 The BoringSSL Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package subprocess
16
17import (
18	"bytes"
19	"encoding/hex"
20	"encoding/json"
21	"fmt"
22)
23
24// The following structures reflect the JSON of ACVP KAS KDF tests. See
25// https://pages.nist.gov/ACVP/draft-hammett-acvp-kas-kdf-hkdf.html
26// https://pages.nist.gov/ACVP/draft-hammett-acvp-kas-kdf-onestepnocounter.html
27
// multiModeKda handles KDA vector sets that may arrive in more than one mode.
// modes maps the value of the vector set's top-level "mode" JSON field to the
// primitive that processes vector sets of that mode. (In this file, hkdf
// expects mode "HKDF" and oneStepNoCounter expects "OneStepNoCounter"; which
// modes are actually registered is decided by whoever builds the map.)
type multiModeKda struct {
	modes map[string]primitive
}
31
32func (k multiModeKda) Process(vectorSet []byte, m Transactable) (any, error) {
33	var vector struct {
34		Mode string `json:"mode"`
35	}
36	if err := json.Unmarshal(vectorSet, &vector); err != nil {
37		return nil, fmt.Errorf("invalid KDA test vector: %w", err)
38	}
39	mode, ok := k.modes[vector.Mode]
40	if !ok {
41		return nil, fmt.Errorf("unsupported KDA mode %q", vector.Mode)
42	}
43	return mode.Process(vectorSet, m)
44}
45
// kdaPartyInfo is one party's contribution to the KDA fixed info, decoded from
// a test's "fixedInfoPartyU" or "fixedInfoPartyV" JSON object. Both fields are
// hex strings. ephemeralData may be absent or empty, in which case data()
// returns just the decoded partyId.
type kdaPartyInfo struct {
	IDHex    string `json:"partyId"`
	ExtraHex string `json:"ephemeralData"`
}
50
51func (p *kdaPartyInfo) data() ([]byte, error) {
52	ret, err := hex.DecodeString(p.IDHex)
53	if err != nil {
54		return nil, err
55	}
56	if len(p.ExtraHex) > 0 {
57		extra, err := hex.DecodeString(p.ExtraHex)
58		if err != nil {
59			return nil, err
60		}
61		ret = append(ret, extra...)
62	}
63	return ret, nil
64}
65
// hkdfTestVectorSet is the top-level JSON object of a KDA vector set handled
// by hkdf.Process, which requires Mode to be "HKDF".
type hkdfTestVectorSet struct {
	Mode   string          `json:"mode"`
	Groups []hkdfTestGroup `json:"testGroups"`
}
70
// hkdfTestGroup is one HKDF test group. All tests in the group share Config.
// For "AFT" groups the module's derived key is reported back. For "VAL"
// groups it is compared against each test's expected dkm.
type hkdfTestGroup struct {
	ID     uint64            `json:"tgId"`
	Type   string            `json:"testType"` // AFT or VAL
	Config hkdfConfiguration `json:"kdfConfiguration"`
	Tests  []hkdfTest        `json:"tests"`
}
77
// hkdfTest is one HKDF test case. It holds:
//   - the salt and shared secret z (Params);
//   - each party's fixed-info contribution;
//   - for VAL tests, the expected derived key material as hex (ExpectedHex,
//     only decoded for VAL groups).
type hkdfTest struct {
	ID          uint64         `json:"tcId"`
	Params      hkdfParameters `json:"kdfParameter"`
	PartyU      kdaPartyInfo   `json:"fixedInfoPartyU"`
	PartyV      kdaPartyInfo   `json:"fixedInfoPartyV"`
	ExpectedHex string         `json:"dkm"`
}
85
// hkdfConfiguration is an HKDF group's "kdfConfiguration" object. OutputBits
// is the requested output length l in bits. HashName (hmacAlg) is appended to
// the "HKDF/" command sent to the module. See extract for the only
// combination of values that is accepted.
//
// Field order matters beyond JSON: extract prints the struct with %#v in its
// error message.
type hkdfConfiguration struct {
	Type               string `json:"kdfType"`
	OutputBits         uint32 `json:"l"`
	HashName           string `json:"hmacAlg"`
	FixedInfoPattern   string `json:"fixedInfoPattern"`
	FixedInputEncoding string `json:"fixedInfoEncoding"`
}
93
94func (c *hkdfConfiguration) extract() (outBytes uint32, hashName string, err error) {
95	if c.Type != "hkdf" ||
96		c.FixedInfoPattern != "uPartyInfo||vPartyInfo" ||
97		c.FixedInputEncoding != "concatenation" ||
98		c.OutputBits%8 != 0 {
99		return 0, "", fmt.Errorf("KDA not configured for HKDF: %#v", c)
100	}
101
102	return c.OutputBits / 8, c.HashName, nil
103}
104
// hkdfParameters is an HKDF test's "kdfParameter" object: the hex-encoded
// salt and the hex-encoded shared secret z (KeyHex).
type hkdfParameters struct {
	SaltHex string `json:"salt"`
	KeyHex  string `json:"z"`
}
109
110func (p *hkdfParameters) extract() (key, salt []byte, err error) {
111	salt, err = hex.DecodeString(p.SaltHex)
112	if err != nil {
113		return nil, nil, err
114	}
115
116	key, err = hex.DecodeString(p.KeyHex)
117	if err != nil {
118		return nil, nil, err
119	}
120
121	return key, salt, nil
122}
123
// hkdfTestGroupResponse is the response for one HKDF test group. Tests is
// filled in by the asynchronous transaction callbacks in hkdf.Process. It stays
// nil (encoded as JSON null) for a group with no tests.
type hkdfTestGroupResponse struct {
	ID    uint64             `json:"tgId"`
	Tests []hkdfTestResponse `json:"tests"`
}
128
// hkdfTestResponse is the response for one HKDF test case. hkdf.Process sets
// KeyOut (hex-encoded derived key) for AFT tests and Passed for VAL tests. The
// field left unset is dropped from the JSON by omitempty.
type hkdfTestResponse struct {
	ID     uint64 `json:"tcId"`
	KeyOut string `json:"dkm,omitempty"`
	Passed *bool  `json:"testPassed,omitempty"`
}
134
// hkdf processes KDA vector sets whose mode is "HKDF"; it holds no state.
type hkdf struct{}
136
137func (k *hkdf) Process(vectorSet []byte, m Transactable) (any, error) {
138	var parsed hkdfTestVectorSet
139	if err := json.Unmarshal(vectorSet, &parsed); err != nil {
140		return nil, err
141	}
142
143	if parsed.Mode != "HKDF" {
144		return nil, fmt.Errorf("unexpected KDA mode %q", parsed.Mode)
145	}
146
147	var respGroups []hkdfTestGroupResponse
148	for _, group := range parsed.Groups {
149		group := group
150		groupResp := hkdfTestGroupResponse{ID: group.ID}
151
152		var isValidationTest bool
153		switch group.Type {
154		case "VAL":
155			isValidationTest = true
156		case "AFT":
157			isValidationTest = false
158		default:
159			return nil, fmt.Errorf("unknown test type %q", group.Type)
160		}
161
162		outBytes, hashName, err := group.Config.extract()
163		if err != nil {
164			return nil, err
165		}
166
167		for _, test := range group.Tests {
168			test := test
169			testResp := hkdfTestResponse{ID: test.ID}
170
171			key, salt, err := test.Params.extract()
172			if err != nil {
173				return nil, err
174			}
175			uData, err := test.PartyU.data()
176			if err != nil {
177				return nil, err
178			}
179			vData, err := test.PartyV.data()
180			if err != nil {
181				return nil, err
182			}
183
184			var expected []byte
185			if isValidationTest {
186				expected, err = hex.DecodeString(test.ExpectedHex)
187				if err != nil {
188					return nil, err
189				}
190			}
191
192			info := make([]byte, 0, len(uData)+len(vData))
193			info = append(info, uData...)
194			info = append(info, vData...)
195
196			m.TransactAsync("HKDF/"+hashName, 1, [][]byte{key, salt, info, uint32le(outBytes)}, func(result [][]byte) error {
197				if len(result[0]) != int(outBytes) {
198					return fmt.Errorf("HKDF operation resulted in %d bytes but wanted %d", len(result[0]), outBytes)
199				}
200				if isValidationTest {
201					passed := bytes.Equal(expected, result[0])
202					testResp.Passed = &passed
203				} else {
204					testResp.KeyOut = hex.EncodeToString(result[0])
205				}
206
207				groupResp.Tests = append(groupResp.Tests, testResp)
208				return nil
209			})
210		}
211
212		m.Barrier(func() {
213			respGroups = append(respGroups, groupResp)
214		})
215	}
216
217	if err := m.Flush(); err != nil {
218		return nil, err
219	}
220
221	return respGroups, nil
222}
223
// oneStepTestVectorSet is the top-level JSON object of a KDA vector set handled
// by oneStepNoCounter.Process, which requires Mode to be "OneStepNoCounter".
type oneStepTestVectorSet struct {
	Mode   string             `json:"mode"`
	Groups []oneStepTestGroup `json:"testGroups"`
}
228
// oneStepTestGroup is one OneStepNoCounter test group. All tests in the group
// share Config. For "AFT" groups the module's derived key is reported back. For
// "VAL" groups it is compared against each test's expected dkm.
type oneStepTestGroup struct {
	ID     uint64               `json:"tgId"`
	Type   string               `json:"testType"` // AFT or VAL
	Config oneStepConfiguration `json:"kdfConfiguration"`
	Tests  []oneStepTest        `json:"tests"`
}
235
// oneStepConfiguration is a OneStepNoCounter group's "kdfConfiguration"
// object. OutputBits is the requested output length l in bits. AuxFunction is
// appended to the "OneStepNoCounter/" command sent to the module.
//
// SaltMethod is decoded but not checked by extract. Field order also affects
// the %#v text in extract's error message.
type oneStepConfiguration struct {
	Type               string `json:"kdfType"`
	SaltMethod         string `json:"saltMethod"`
	FixedInfoPattern   string `json:"fixedInfoPattern"`
	FixedInputEncoding string `json:"fixedInfoEncoding"`
	AuxFunction        string `json:"auxFunction"`
	OutputBits         uint32 `json:"l"`
}
244
245func (c *oneStepConfiguration) extract() (outBytes uint32, auxFunction string, err error) {
246	if c.Type != "oneStepNoCounter" ||
247		c.FixedInfoPattern != "uPartyInfo||vPartyInfo" ||
248		c.FixedInputEncoding != "concatenation" ||
249		c.OutputBits%8 != 0 {
250		return 0, "", fmt.Errorf("KDA not configured for OneStepNoCounter: %#v", c)
251	}
252	return c.OutputBits / 8, c.AuxFunction, nil
253}
254
// oneStepTest is one OneStepNoCounter test case. It holds:
//   - the per-test derivation parameters (salt, z, l);
//   - each party's fixed-info contribution;
//   - for VAL tests, the expected derived key material as hex.
type oneStepTest struct {
	ID              uint64                `json:"tcId"`
	Params          oneStepTestParameters `json:"kdfParameter"`
	FixedInfoPartyU kdaPartyInfo          `json:"fixedInfoPartyU"`
	FixedInfoPartyV kdaPartyInfo          `json:"fixedInfoPartyV"`
	DerivedKeyHex   string                `json:"dkm,omitempty"` // For VAL tests only.
}
262
// oneStepTestParameters is a OneStepNoCounter test's "kdfParameter" object.
// It repeats kdfType and l (in bits), which extract re-validates;
// oneStepNoCounter.Process also requires l to match the group configuration.
// SaltHex and ZHex are hex strings. Field order affects the %#v text in
// extract's error message.
type oneStepTestParameters struct {
	KdfType    string `json:"kdfType"`
	SaltHex    string `json:"salt"`
	ZHex       string `json:"z"`
	OutputBits uint32 `json:"l"`
}
269
270func (p oneStepTestParameters) extract() (key []byte, salt []byte, outLen uint32, err error) {
271	if p.KdfType != "oneStepNoCounter" ||
272		p.OutputBits%8 != 0 {
273		return nil, nil, 0, fmt.Errorf("KDA not configured for OneStepNoCounter: %#v", p)
274	}
275	outLen = p.OutputBits / 8
276	salt, err = hex.DecodeString(p.SaltHex)
277	if err != nil {
278		return
279	}
280	key, err = hex.DecodeString(p.ZHex)
281	if err != nil {
282		return
283	}
284	return
285}
286
// oneStepTestGroupResponse is the response for one OneStepNoCounter test group.
// Tests is filled in by the asynchronous transaction callbacks in
// oneStepNoCounter.Process. It stays nil (encoded as JSON null) for a group
// with no tests.
type oneStepTestGroupResponse struct {
	ID    uint64                `json:"tgId"`
	Tests []oneStepTestResponse `json:"tests"`
}
291
// oneStepTestResponse is the response for one OneStepNoCounter test case.
// Exactly one of KeyOut (hex-encoded derived key, AFT) or Passed (VAL) is set.
// The other is dropped from the JSON by omitempty.
type oneStepTestResponse struct {
	ID     uint64 `json:"tcId"`
	KeyOut string `json:"dkm,omitempty"`        // For AFT
	Passed *bool  `json:"testPassed,omitempty"` // For VAL
}
297
// oneStepNoCounter processes KDA vector sets whose mode is "OneStepNoCounter";
// it holds no state.
type oneStepNoCounter struct{}
299
300func (k oneStepNoCounter) Process(vectorSet []byte, m Transactable) (any, error) {
301	var parsed oneStepTestVectorSet
302	if err := json.Unmarshal(vectorSet, &parsed); err != nil {
303		return nil, err
304	}
305
306	if parsed.Mode != "OneStepNoCounter" {
307		return nil, fmt.Errorf("unexpected KDA mode %q", parsed.Mode)
308	}
309
310	var respGroups []oneStepTestGroupResponse
311	for _, group := range parsed.Groups {
312		group := group
313
314		groupResp := oneStepTestGroupResponse{ID: group.ID}
315		outBytes, hashName, err := group.Config.extract()
316		if err != nil {
317			return nil, err
318		}
319
320		var isValidationTest bool
321		switch group.Type {
322		case "VAL":
323			isValidationTest = true
324		case "AFT":
325			isValidationTest = false
326		default:
327			return nil, fmt.Errorf("unknown test type %q", group.Type)
328		}
329
330		for _, test := range group.Tests {
331			test := test
332			testResp := oneStepTestResponse{ID: test.ID}
333
334			key, salt, paramsOutBytes, err := test.Params.extract()
335			if err != nil {
336				return nil, err
337			}
338			if paramsOutBytes != outBytes {
339				return nil, fmt.Errorf("test %d in group %d: output length mismatch: %d != %d", test.ID, group.ID, paramsOutBytes, outBytes)
340			}
341
342			uData, err := test.FixedInfoPartyU.data()
343			if err != nil {
344				return nil, err
345			}
346			vData, err := test.FixedInfoPartyV.data()
347			if err != nil {
348				return nil, err
349			}
350
351			info := make([]byte, 0, len(uData)+len(vData))
352			info = append(info, uData...)
353			info = append(info, vData...)
354			var expected []byte
355			if isValidationTest {
356				expected, err = hex.DecodeString(test.DerivedKeyHex)
357				if err != nil {
358					return nil, fmt.Errorf("test %d in group %d: invalid DerivedKeyHex: %w", test.ID, group.ID, err)
359				}
360			}
361
362			cmd := "OneStepNoCounter/" + hashName
363			m.TransactAsync(cmd, 1, [][]byte{key, info, salt, uint32le(outBytes)}, func(result [][]byte) error {
364				if len(result[0]) != int(outBytes) {
365					return fmt.Errorf("OneStepNoCounter operation resulted in %d bytes but wanted %d", len(result[0]), outBytes)
366				}
367
368				if isValidationTest {
369					passed := bytes.Equal(expected, result[0])
370					testResp.Passed = &passed
371				} else {
372					testResp.KeyOut = hex.EncodeToString(result[0])
373				}
374
375				groupResp.Tests = append(groupResp.Tests, testResp)
376				return nil
377			})
378		}
379
380		m.Barrier(func() {
381			respGroups = append(respGroups, groupResp)
382		})
383	}
384
385	if err := m.Flush(); err != nil {
386		return nil, err
387	}
388
389	return respGroups, nil
390}
391