1// Copyright 2019 The BoringSSL Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Package subprocess contains functionality to talk to a modulewrapper for
16// testing of various algorithm implementations.
17package subprocess
18
19import (
20	"encoding/binary"
21	"encoding/json"
22	"errors"
23	"fmt"
24	"io"
25	"os"
26	"os/exec"
27)
28
// Transactable provides an interface to allow test injection of transactions
// that don't call a server.
//
// The method descriptions below reflect how *Subprocess implements the
// interface; other implementations (e.g. test fakes) may differ.
type Transactable interface {
	// Transact issues cmd with the given args and returns the results.
	// expectedResults is the number of results the caller expects back.
	Transact(cmd string, expectedResults int, args ...[]byte) ([][]byte, error)
	// TransactAsync issues cmd with the given args and arranges for callback
	// to be invoked with the results instead of returning them.
	TransactAsync(cmd string, expectedResults int, args [][]byte, callback func([][]byte) error)
	// Barrier arranges for callback to run after the callbacks of all
	// previously issued TransactAsync calls have run.
	Barrier(callback func()) error
	// Flush waits for all outstanding TransactAsync callbacks to complete.
	Flush() error
}
37
// Subprocess is a "middle" layer that interacts with a FIPS module via running
// a command and speaking a simple protocol over stdin/stdout.
type Subprocess struct {
	// cmd is the child process; Close calls Wait on it.
	// stdin is the pipe that requests are written to and stdout is the pipe
	// that results are read from.
	// primitives maps an algorithm name, as passed to Process, to the handler
	// that processes vector sets for that algorithm.
	cmd        *exec.Cmd
	stdin      io.WriteCloser
	stdout     io.ReadCloser
	primitives map[string]primitive
	// supportsFlush is true if the modulewrapper indicated that it wants to receive flush commands.
	// It is set by Config and is false until Config has been called.
	supportsFlush bool
	// pendingReads is a queue of expected responses. `readerRoutine` reads each response and calls the callback in the matching pendingRead.
	pendingReads chan pendingRead
	// readerFinished is a channel that is closed if `readerRoutine` has finished (e.g. because of a read error).
	readerFinished chan struct{}
}
52
// pendingRead represents an expected response from the modulewrapper.
//
// NOTE: TransactAsync constructs this struct with an unkeyed (positional)
// literal, so the order of the fields must not be changed.
type pendingRead struct {
	// barrierCallback is called as soon as this pendingRead is the next in the queue, before any read from the modulewrapper.
	barrierCallback func()

	// callback is called with the result from the modulewrapper. If this is nil then no read is performed.
	callback func(result [][]byte) error
	// cmd is the command that requested this read for logging purposes.
	// expectedNumResults is the number of results that the response must
	// contain; readResult reports an error if the modulewrapper returns a
	// different number.
	cmd                string
	expectedNumResults int
}
64
65// New returns a new Subprocess middle layer that runs the given binary.
66func New(path string) (*Subprocess, error) {
67	cmd := exec.Command(path)
68	cmd.Stderr = os.Stderr
69	stdin, err := cmd.StdinPipe()
70	if err != nil {
71		return nil, err
72	}
73	stdout, err := cmd.StdoutPipe()
74	if err != nil {
75		return nil, err
76	}
77
78	if err := cmd.Start(); err != nil {
79		return nil, err
80	}
81
82	return NewWithIO(cmd, stdin, stdout), nil
83}
84
// maxPending is the maximum number of requests that can be in the pipeline.
// It is used as the buffer size of the Subprocess.pendingReads channel.
const maxPending = 4096
87
// NewWithIO returns a new Subprocess middle layer with the given ReadCloser and
// WriteCloser. The returned Subprocess will call Wait on the Cmd when closed.
//
// Requests are written to in and results are read from out. A reader
// goroutine (readerRoutine) is started before returning; it runs until Close
// is called.
func NewWithIO(cmd *exec.Cmd, in io.WriteCloser, out io.ReadCloser) *Subprocess {
	m := &Subprocess{
		cmd:            cmd,
		stdin:          in,
		stdout:         out,
		pendingReads:   make(chan pendingRead, maxPending),
		readerFinished: make(chan struct{}),
	}

	// Table of supported algorithms, keyed by the algorithm name that is
	// passed to Process.
	m.primitives = map[string]primitive{
		"SHA-1":             &hashPrimitive{"SHA-1", 20},
		"SHA2-224":          &hashPrimitive{"SHA2-224", 28},
		"SHA2-256":          &hashPrimitive{"SHA2-256", 32},
		"SHA2-384":          &hashPrimitive{"SHA2-384", 48},
		"SHA2-512":          &hashPrimitive{"SHA2-512", 64},
		"SHA2-512/224":      &hashPrimitive{"SHA2-512/224", 28},
		"SHA2-512/256":      &hashPrimitive{"SHA2-512/256", 32},
		"SHA3-224":          &hashPrimitive{"SHA3-224", 28},
		"SHA3-256":          &hashPrimitive{"SHA3-256", 32},
		"SHA3-384":          &hashPrimitive{"SHA3-384", 48},
		"SHA3-512":          &hashPrimitive{"SHA3-512", 64},
		"SHAKE-128":         &shake{"SHAKE-128", 16},
		"SHAKE-256":         &shake{"SHAKE-256", 32},
		"cSHAKE-128":        &cShake{"cSHAKE-128"},
		"cSHAKE-256":        &cShake{"cSHAKE-256"},
		"ACVP-AES-ECB":      &blockCipher{"AES", 16, 2, true, false, iterateAES},
		"ACVP-AES-CBC":      &blockCipher{"AES-CBC", 16, 2, true, true, iterateAESCBC},
		"ACVP-AES-CBC-CS3":  &blockCipher{"AES-CBC-CS3", 16, 1, false, true, iterateAESCBC},
		"ACVP-AES-CTR":      &blockCipher{"AES-CTR", 16, 1, false, true, nil},
		"ACVP-TDES-ECB":     &blockCipher{"3DES-ECB", 8, 3, true, false, iterate3DES},
		"ACVP-TDES-CBC":     &blockCipher{"3DES-CBC", 8, 3, true, true, iterate3DESCBC},
		"ACVP-AES-XTS":      &xts{},
		"ACVP-AES-GCM":      &aead{"AES-GCM", false},
		"ACVP-AES-GMAC":     &aead{"AES-GCM", false},
		"ACVP-AES-CCM":      &aead{"AES-CCM", true},
		"ACVP-AES-KW":       &aead{"AES-KW", false},
		"ACVP-AES-KWP":      &aead{"AES-KWP", false},
		"HMAC-SHA-1":        &hmacPrimitive{"HMAC-SHA-1", 20},
		"HMAC-SHA2-224":     &hmacPrimitive{"HMAC-SHA2-224", 28},
		"HMAC-SHA2-256":     &hmacPrimitive{"HMAC-SHA2-256", 32},
		"HMAC-SHA2-384":     &hmacPrimitive{"HMAC-SHA2-384", 48},
		"HMAC-SHA2-512":     &hmacPrimitive{"HMAC-SHA2-512", 64},
		"HMAC-SHA2-512/224": &hmacPrimitive{"HMAC-SHA2-512/224", 28},
		"HMAC-SHA2-512/256": &hmacPrimitive{"HMAC-SHA2-512/256", 32},
		"HMAC-SHA3-224":     &hmacPrimitive{"HMAC-SHA3-224", 28},
		"HMAC-SHA3-256":     &hmacPrimitive{"HMAC-SHA3-256", 32},
		"HMAC-SHA3-384":     &hmacPrimitive{"HMAC-SHA3-384", 48},
		"HMAC-SHA3-512":     &hmacPrimitive{"HMAC-SHA3-512", 64},
		"ctrDRBG":           &drbg{"ctrDRBG", map[string]bool{"AES-128": true, "AES-192": true, "AES-256": true}},
		"hmacDRBG":          &drbg{"hmacDRBG", map[string]bool{"SHA-1": true, "SHA2-224": true, "SHA2-256": true, "SHA2-384": true, "SHA2-512": true, "SHA2-512/224": true, "SHA2-512/256": true, "SHA3-224": true, "SHA3-256": true, "SHA3-384": true, "SHA3-512": true}},
		"KDF":               &kdfPrimitive{},
		"KDA":               &multiModeKda{modes: map[string]primitive{"HKDF": &hkdf{}, "OneStepNoCounter": &oneStepNoCounter{}}},
		"TLS-v1.2":          &tlsKDF{},
		"TLS-v1.3":          &tls13{},
		"CMAC-AES":          &keyedMACPrimitive{"CMAC-AES"},
		"RSA":               &rsa{},
		"KAS-ECC-SSC":       &kas{},
		"KAS-FFC-SSC":       &kasDH{},
		"PBKDF":             &pbkdf{},
		"ML-DSA":            &mldsa{},
		"ML-KEM":            &mlkem{},
		"SLH-DSA":           &slhdsa{},
		"kdf-components":    &ssh{},
		"KTS-IFC":           &kts{map[string]bool{"SHA-1": true, "SHA2-224": true, "SHA2-256": true, "SHA2-384": true, "SHA2-512": true, "SHA2-512/224": true, "SHA2-512/256": true, "SHA3-224": true, "SHA3-256": true, "SHA3-384": true, "SHA3-512": true}},
	}
	// These entries are added after the literal because the ECDSA primitives
	// are given a reference to the primitives map itself.
	m.primitives["ECDSA"] = &ecdsa{"ECDSA", map[string]bool{"P-224": true, "P-256": true, "P-384": true, "P-521": true}, m.primitives}
	m.primitives["DetECDSA"] = &ecdsa{"DetECDSA", map[string]bool{"P-224": true, "P-256": true, "P-384": true, "P-521": true}, m.primitives}
	m.primitives["EDDSA"] = &eddsa{"EDDSA", map[string]bool{"ED-25519": true}}

	// Start the goroutine that reads responses and dispatches them to the
	// callbacks queued in pendingReads.
	go m.readerRoutine()
	return m
}
162
// Close signals the child process to exit and waits for it to complete.
//
// Errors from closing the pipes and from waiting on the child are discarded.
// No requests may be issued after Close has been called.
func (m *Subprocess) Close() {
	// Closing the pipes is what signals the child to exit.
	m.stdout.Close()
	m.stdin.Close()
	m.cmd.Wait()
	// Closing the queue ends the range loop in readerRoutine, which then
	// closes readerFinished on its way out.
	close(m.pendingReads)
	<-m.readerFinished
}
171
172func (m *Subprocess) flush() error {
173	if !m.supportsFlush {
174		return nil
175	}
176
177	const cmd = "flush"
178	buf := make([]byte, 8, 8+len(cmd))
179	binary.LittleEndian.PutUint32(buf, 1)
180	binary.LittleEndian.PutUint32(buf[4:], uint32(len(cmd)))
181	buf = append(buf, []byte(cmd)...)
182
183	if _, err := m.stdin.Write(buf); err != nil {
184		return err
185	}
186	return nil
187}
188
189func (m *Subprocess) enqueueRead(pending pendingRead) error {
190	select {
191	case <-m.readerFinished:
192		panic("attempted to enqueue request after the reader failed")
193	default:
194	}
195
196	select {
197	case m.pendingReads <- pending:
198		break
199	default:
200		// `pendingReads` is full. Ensure that the modulewrapper will process
201		// some outstanding requests to free up space in the queue.
202		if err := m.flush(); err != nil {
203			return err
204		}
205		m.pendingReads <- pending
206	}
207
208	return nil
209}
210
211// TransactAsync performs a single request--response pair with the subprocess.
212// The callback will run at some future point, in a separate goroutine. All
213// callbacks will, however, be run in the order that TransactAsync was called.
214// Use Flush to wait for all outstanding callbacks.
215func (m *Subprocess) TransactAsync(cmd string, expectedNumResults int, args [][]byte, callback func(result [][]byte) error) {
216	if err := m.enqueueRead(pendingRead{nil, callback, cmd, expectedNumResults}); err != nil {
217		panic(err)
218	}
219
220	argLength := len(cmd)
221	for _, arg := range args {
222		argLength += len(arg)
223	}
224
225	buf := make([]byte, 4*(2+len(args)), 4*(2+len(args))+argLength)
226	binary.LittleEndian.PutUint32(buf, uint32(1+len(args)))
227	binary.LittleEndian.PutUint32(buf[4:], uint32(len(cmd)))
228	for i, arg := range args {
229		binary.LittleEndian.PutUint32(buf[4*(i+2):], uint32(len(arg)))
230	}
231	buf = append(buf, []byte(cmd)...)
232	for _, arg := range args {
233		buf = append(buf, arg...)
234	}
235
236	if _, err := m.stdin.Write(buf); err != nil {
237		panic(err)
238	}
239}
240
241// Flush tells the subprocess to complete all outstanding requests and waits
242// for all outstanding TransactAsync callbacks to complete.
243func (m *Subprocess) Flush() error {
244	if m.supportsFlush {
245		m.flush()
246	}
247
248	done := make(chan struct{})
249	if err := m.enqueueRead(pendingRead{barrierCallback: func() {
250		close(done)
251	}}); err != nil {
252		return err
253	}
254
255	<-done
256	return nil
257}
258
259// Barrier runs callback after all outstanding TransactAsync callbacks have
260// been run.
261func (m *Subprocess) Barrier(callback func()) error {
262	return m.enqueueRead(pendingRead{barrierCallback: callback})
263}
264
265func (m *Subprocess) Transact(cmd string, expectedNumResults int, args ...[]byte) ([][]byte, error) {
266	done := make(chan struct{})
267	var result [][]byte
268	m.TransactAsync(cmd, expectedNumResults, args, func(r [][]byte) error {
269		result = r
270		close(done)
271		return nil
272	})
273
274	if err := m.flush(); err != nil {
275		return nil, err
276	}
277
278	select {
279	case <-done:
280		return result, nil
281	case <-m.readerFinished:
282		panic("was still waiting for a result when the reader finished")
283	}
284}
285
286func (m *Subprocess) readerRoutine() {
287	defer close(m.readerFinished)
288
289	for pendingRead := range m.pendingReads {
290		if pendingRead.barrierCallback != nil {
291			pendingRead.barrierCallback()
292		}
293
294		if pendingRead.callback == nil {
295			continue
296		}
297
298		result, err := m.readResult(pendingRead.cmd, pendingRead.expectedNumResults)
299		if err != nil {
300			panic(fmt.Errorf("failed to read from subprocess: %w", err))
301		}
302
303		if err := pendingRead.callback(result); err != nil {
304			panic(fmt.Errorf("result from subprocess was rejected: %w", err))
305		}
306	}
307}
308
309func (m *Subprocess) readResult(cmd string, expectedNumResults int) ([][]byte, error) {
310	buf := make([]byte, 4)
311
312	if _, err := io.ReadFull(m.stdout, buf); err != nil {
313		return nil, err
314	}
315
316	numResults := binary.LittleEndian.Uint32(buf)
317	if int(numResults) != expectedNumResults {
318		return nil, fmt.Errorf("expected %d results from %q but got %d", expectedNumResults, cmd, numResults)
319	}
320
321	buf = make([]byte, 4*numResults)
322	if _, err := io.ReadFull(m.stdout, buf); err != nil {
323		return nil, err
324	}
325
326	var resultsLength uint64
327	for i := uint32(0); i < numResults; i++ {
328		resultsLength += uint64(binary.LittleEndian.Uint32(buf[4*i:]))
329	}
330
331	if resultsLength > (1 << 30) {
332		return nil, fmt.Errorf("results too large (%d bytes)", resultsLength)
333	}
334
335	results := make([]byte, resultsLength)
336	if _, err := io.ReadFull(m.stdout, results); err != nil {
337		return nil, err
338	}
339
340	ret := make([][]byte, 0, numResults)
341	var offset int
342	for i := uint32(0); i < numResults; i++ {
343		length := binary.LittleEndian.Uint32(buf[4*i:])
344		ret = append(ret, results[offset:offset+int(length)])
345		offset += int(length)
346	}
347
348	return ret, nil
349}
350
351// Config returns a JSON blob that describes the supported primitives. The
352// format of the blob is defined by ACVP. See
353// http://usnistgov.github.io/ACVP/artifacts/draft-fussell-acvp-spec-00.html#rfc.section.11.15.2.1
354func (m *Subprocess) Config() ([]byte, error) {
355	results, err := m.Transact("getConfig", 1)
356	if err != nil {
357		return nil, err
358	}
359	var config []struct {
360		Algorithm string   `json:"algorithm"`
361		Features  []string `json:"features"`
362	}
363	if err := json.Unmarshal(results[0], &config); err != nil {
364		return nil, errors.New("failed to parse config response from wrapper: " + err.Error())
365	}
366	for _, algo := range config {
367		if algo.Algorithm == "acvptool" {
368			for _, feature := range algo.Features {
369				switch feature {
370				case "batch":
371					m.supportsFlush = true
372				}
373			}
374		}
375	}
376
377	return results[0], nil
378}
379
380// Process runs a set of test vectors and returns the result.
381func (m *Subprocess) Process(algorithm string, vectorSet []byte) (any, error) {
382	prim, ok := m.primitives[algorithm]
383	if !ok {
384		return nil, fmt.Errorf("unknown algorithm %q", algorithm)
385	}
386	ret, err := prim.Process(vectorSet, m)
387	if err != nil {
388		return nil, err
389	}
390	return ret, nil
391}
392
// primitive is implemented by the handler for each supported algorithm.
type primitive interface {
	// Process handles the given vector set, using t to perform any
	// transactions that it needs, and returns the result.
	Process(vectorSet []byte, t Transactable) (any, error)
}
396
// uint32le returns the four-byte little-endian encoding of n in a newly
// allocated slice.
func uint32le(n uint32) []byte {
	encoded := make([]byte, 4)
	binary.LittleEndian.PutUint32(encoded, n)
	return encoded
}
402