1// Copyright (c) 2025, Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package subprocess
16
17import (
18	"encoding/binary"
19	"encoding/hex"
20	"encoding/json"
21	"fmt"
22)
23
24// The following structures reflect the JSON of ACVP XOF cSHAKE tests. See
25// https://pages.nist.gov/ACVP/draft-celi-acvp-xof.html#name-test-vectors
26
// cShakeTestVectorSet is the top-level object of a cSHAKE vector set. Only the
// "testGroups" array is decoded; any other members of the JSON are ignored.
type cShakeTestVectorSet struct {
	Groups []cShakeTestGroup `json:"testGroups"`
}
30
// cShakeTestGroup is one entry of the "testGroups" array.
//
// Type selects how Process handles the group: "AFT" and "MCT" are supported
// and any other value is reported as an error. Process also rejects a group
// whose HexCustomization is true.
//
// MinOutLenBits, MaxOutLenBits and OutLenIncrementBits are only consulted for
// MCT groups. They are expressed in bits and Process requires each of them to
// be a multiple of eight.
//
// For each entry of Tests:
//   - MsgHex is the hex-encoded message and BitLength must equal four times
//     the number of hex characters in MsgHex.
//   - FunctionName and Customization are passed to the module as raw bytes
//     for AFT tests; they are not used for MCT tests.
//   - CustomizationHex must be empty, otherwise Process returns an error.
//   - OutLenBits is the requested output length in bits and must be a
//     multiple of eight.
type cShakeTestGroup struct {
	ID                  uint64 `json:"tgId"`
	Type                string `json:"testType"`
	HexCustomization    bool   `json:"hexCustomization"`
	MaxOutLenBits       uint32 `json:"maxOutLen"`
	MinOutLenBits       uint32 `json:"minOutLen"`
	OutLenIncrementBits uint32 `json:"outLenIncrement"`
	Tests               []struct {
		ID               uint64 `json:"tcId"`
		MsgHex           string `json:"msg"`
		BitLength        uint64 `json:"len"`
		FunctionName     string `json:"functionName"`
		Customization    string `json:"customization"`
		CustomizationHex string `json:"customizationHex"`
		OutLenBits       uint32 `json:"outLen"`
	} `json:"tests"`
}
48
// cShakeTestGroupResponse is the result for a single test group. ID is copied
// from the request group's "tgId" and Tests holds one response per test case
// in that group.
type cShakeTestGroupResponse struct {
	ID    uint64               `json:"tgId"`
	Tests []cShakeTestResponse `json:"tests"`
}
53
// cShakeTestResponse is the result for a single test case.
//
// For AFT tests, DigestHex and OutLenBits are filled in and MCTResults is left
// empty. For MCT tests only MCTResults is filled in. Fields left at their zero
// value are omitted from the JSON because of the omitempty tags.
type cShakeTestResponse struct {
	ID         uint64            `json:"tcId"`
	DigestHex  string            `json:"md,omitempty"`
	OutLenBits uint32            `json:"outLen,omitempty"`
	MCTResults []cShakeMCTResult `json:"resultsArray,omitempty"`
}
60
// cShakeMCTResult is one element of an MCT test's "resultsArray". DigestHex is
// the hex encoding of the output returned by the module for that iteration and
// OutLenBits is the length of that output in bits.
type cShakeMCTResult struct {
	DigestHex  string `json:"md"`
	OutLenBits uint32 `json:"outLen,omitempty"`
}
65
// cShake processes cSHAKE vector sets. algo is the command name sent to the
// module for AFT tests; MCT tests use the same name with "/MCT" appended.
type cShake struct {
	algo string
}
69
70func (h *cShake) Process(vectorSet []byte, m Transactable) (any, error) {
71	var parsed cShakeTestVectorSet
72	if err := json.Unmarshal(vectorSet, &parsed); err != nil {
73		return nil, err
74	}
75
76	// See
77	// https://pages.nist.gov/ACVP/draft-celi-acvp-xof.html#name-test-types
78	// for details about the tests.
79	var ret []cShakeTestGroupResponse
80	for _, group := range parsed.Groups {
81		group := group
82		response := cShakeTestGroupResponse{
83			ID: group.ID,
84		}
85
86		if group.HexCustomization {
87			return nil, fmt.Errorf("test group %d has unsupported hex customization", group.ID)
88		}
89
90		for _, test := range group.Tests {
91			test := test
92
93			if test.CustomizationHex != "" {
94				return nil, fmt.Errorf("test case %d/%d has unsupported hex customization", group.ID, test.ID)
95			}
96
97			if uint64(len(test.MsgHex))*4 != test.BitLength {
98				return nil, fmt.Errorf("test case %d/%d contains hex message of length %d but specifies a bit length of %d", group.ID, test.ID, len(test.MsgHex), test.BitLength)
99			}
100			msg, err := hex.DecodeString(test.MsgHex)
101			if err != nil {
102				return nil, fmt.Errorf("failed to decode hex in test case %d/%d: %s", group.ID, test.ID, err)
103			}
104
105			if test.OutLenBits%8 != 0 {
106				return nil, fmt.Errorf("test case %d/%d has bit length %d - fractional bytes not supported", group.ID, test.ID, test.OutLenBits)
107			}
108
109			switch group.Type {
110			case "AFT":
111				args := [][]byte{msg, uint32le(test.OutLenBits / 8), []byte(test.FunctionName), []byte(test.Customization)}
112				m.TransactAsync(h.algo, 1, args, func(result [][]byte) error {
113					response.Tests = append(response.Tests, cShakeTestResponse{
114						ID:         test.ID,
115						DigestHex:  hex.EncodeToString(result[0]),
116						OutLenBits: test.OutLenBits,
117					})
118					return nil
119				})
120			case "MCT":
121				testResponse := cShakeTestResponse{ID: test.ID}
122
123				if group.MinOutLenBits%8 != 0 {
124					return nil, fmt.Errorf("MCT test group %d has min output length %d - fractional bytes not supported", group.ID, group.MinOutLenBits)
125				}
126				if group.MaxOutLenBits%8 != 0 {
127					return nil, fmt.Errorf("MCT test group %d has max output length %d - fractional bytes not supported", group.ID, group.MaxOutLenBits)
128				}
129				if group.OutLenIncrementBits%8 != 0 {
130					return nil, fmt.Errorf("MCT test group %d has output length increment %d - fractional bytes not supported", group.ID, group.OutLenIncrementBits)
131				}
132
133				minOutLenBytes := uint32le(group.MinOutLenBits / 8)
134				maxOutLenBytes := uint32le(group.MaxOutLenBits / 8)
135				outputLenBytes := uint32le(group.MaxOutLenBits / 8)
136				incrementBytes := uint32le(group.OutLenIncrementBits / 8)
137				var mctCustomization []byte
138
139				for i := 0; i < 100; i++ {
140					args := [][]byte{msg, minOutLenBytes, maxOutLenBytes, outputLenBytes, incrementBytes, mctCustomization}
141					result, err := m.Transact(h.algo+"/MCT", 3, args...)
142					if err != nil {
143						panic(h.algo + " mct operation failed: " + err.Error())
144					}
145
146					msg = result[0]
147					outputLenBytes = uint32le(binary.LittleEndian.Uint32(result[1]))
148					mctCustomization = result[2]
149
150					mctResult := cShakeMCTResult{
151						DigestHex:  hex.EncodeToString(msg),
152						OutLenBits: uint32(len(msg) * 8),
153					}
154					testResponse.MCTResults = append(testResponse.MCTResults, mctResult)
155				}
156
157				response.Tests = append(response.Tests, testResponse)
158			default:
159				return nil, fmt.Errorf("test group %d has unknown type %q", group.ID, group.Type)
160			}
161		}
162
163		m.Barrier(func() {
164			ret = append(ret, response)
165		})
166	}
167
168	if err := m.Flush(); err != nil {
169		return nil, err
170	}
171
172	return ret, nil
173}
174