1// Copyright 2021 The BoringSSL Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
// testmodulewrapper is a modulewrapper binary that works with acvptool and
// implements the primitives that BoringSSL's modulewrapper doesn't, so that
// we have something that can exercise all the code in acvptool.
18
19package main
20
21import (
22	"bytes"
23	"crypto"
24	"crypto/aes"
25	"crypto/cipher"
26	"crypto/ed25519"
27	"crypto/hkdf"
28	"crypto/hmac"
29	"crypto/pbkdf2"
30	"crypto/rand"
31	"crypto/sha256"
32	"crypto/sha3"
33	"crypto/sha512"
34	"encoding/binary"
35	"errors"
36	"fmt"
37	"hash"
38	"io"
39	"os"
40
41	"filippo.io/edwards25519"
42
43	"golang.org/x/crypto/xts"
44)
45
// output is the destination that reply writes responses to. do points it at
// outputBuffer so that responses accumulate until flush is called.
var output io.Writer

// outputBuffer holds responses that have not yet been copied to stdout.
var outputBuffer *bytes.Buffer
50
51var handlers = map[string]func([][]byte) error{
52	"flush":                    flush,
53	"getConfig":                getConfig,
54	"KDF-counter":              kdfCounter,
55	"AES-XTS/encrypt":          xtsEncrypt,
56	"AES-XTS/decrypt":          xtsDecrypt,
57	"HKDF/SHA2-256":            hkdfMAC,
58	"hmacDRBG-reseed/SHA2-256": hmacDRBGReseed,
59	"hmacDRBG-pr/SHA2-256":     hmacDRBGPredictionResistance,
60	"AES-CBC-CS3/encrypt":      ctsEncrypt,
61	"AES-CBC-CS3/decrypt":      ctsDecrypt,
62	"PBKDF":                    pbkdf,
63	"EDDSA/keyGen":             eddsaKeyGen,
64	"EDDSA/keyVer":             eddsaKeyVer,
65	"EDDSA/sigGen":             eddsaSigGen,
66	"EDDSA/sigVer":             eddsaSigVer,
67	"SHAKE-128":                shakeAftVot(sha3.NewSHAKE128),
68	"SHAKE-128/VOT":            shakeAftVot(sha3.NewSHAKE128),
69	"SHAKE-128/MCT":            shakeMct(sha3.NewSHAKE128),
70	"SHAKE-256":                shakeAftVot(sha3.NewSHAKE256),
71	"SHAKE-256/VOT":            shakeAftVot(sha3.NewSHAKE256),
72	"SHAKE-256/MCT":            shakeMct(sha3.NewSHAKE256),
73	"cSHAKE-128":               cShakeAft(sha3.NewCSHAKE128),
74	"cSHAKE-128/MCT":           cShakeMct(sha3.NewCSHAKE128),
75	"cSHAKE-256":               cShakeAft(sha3.NewCSHAKE256),
76	"cSHAKE-256/MCT":           cShakeMct(sha3.NewCSHAKE256),
77}
78
79func flush(args [][]byte) error {
80	if outputBuffer == nil {
81		return nil
82	}
83
84	if _, err := os.Stdout.Write(outputBuffer.Bytes()); err != nil {
85		return err
86	}
87	outputBuffer = new(bytes.Buffer)
88	output = outputBuffer
89	return nil
90}
91
// getConfig replies with the JSON capabilities document listing the
// algorithms and parameters this wrapper supports, then flushes it to stdout
// immediately. It takes no arguments.
func getConfig(args [][]byte) error {
	if len(args) != 0 {
		return fmt.Errorf("getConfig received %d args", len(args))
	}

	if err := reply([]byte(`[
	{
		"algorithm": "acvptool",
		"features": ["batch"]
	}, {
		"algorithm": "KDF",
		"revision": "1.0",
		"capabilities": [{
			"kdfMode": "counter",
			"macMode": [
				"HMAC-SHA2-256"
			],
			"supportedLengths": [{
				"min": 8,
				"max": 4096,
				"increment": 8
			}],
			"fixedDataOrder": [
				"before fixed data"
			],
			"counterLength": [
				32
			]
		}]
	}, {
		"algorithm": "ACVP-AES-XTS",
		"revision": "1.0",
		"direction": [
		  "encrypt",
		  "decrypt"
		],
		"keyLen": [
		  128,
		  256
		],
		"payloadLen": [
		  1024
		],
		"tweakMode": [
		  "number"
		]
	}, {
		"algorithm": "KDA",
		"mode": "HKDF",
		"revision": "Sp800-56Cr1",
		"fixedInfoPattern": "uPartyInfo||vPartyInfo",
		"encoding": [
			"concatenation"
		],
		"hmacAlg": [
			"SHA2-256"
		],
		"macSaltMethods": [
			"default",
			"random"
		],
		"l": 256,
		"z": [256, 384]
	}, {
		"algorithm": "hmacDRBG",
		"revision": "1.0",
		"predResistanceEnabled": [false, true],
		"reseedImplemented": true,
		"capabilities": [{
			"mode": "SHA2-256",
			"derFuncEnabled": false,
			"entropyInputLen": [
				256
			],
			"nonceLen": [
				128
			],
			"persoStringLen": [
				256
			],
			"additionalInputLen": [
				256
			],
			"returnedBitsLen": 256
		}]
	}, {
		"algorithm": "ACVP-AES-CBC-CS3",
		"revision": "1.0",
		"payloadLen": [{
			"min": 128,
			"max": 2048,
			"increment": 8
		}],
		"direction": [
		  "encrypt",
		  "decrypt"
		],
		"keyLen": [
		  128,
		  256
		]
	}, {
		"algorithm": "PBKDF",
		"revision":"1.0",
		"capabilities": [{
			"iterationCount":[{
				"min":1,
				"max":10000,
				"increment":1
			}],
			"keyLen": [{
				"min":112,
				"max":4096,
				"increment":8
			}],
			"passwordLen":[{
				"min":8,
				"max":64,
				"increment":1
			}],
			"saltLen":[{
				"min":128,
				"max":512,
				"increment":8
			}],
			"hmacAlg":[
				"SHA2-224",
				"SHA2-256",
				"SHA2-384",
				"SHA2-512",
				"SHA2-512/224",
				"SHA2-512/256",
				"SHA3-224",
				"SHA3-256",
				"SHA3-384",
				"SHA3-512"
			]
		}]
	}, {
		"algorithm": "EDDSA",
		"mode": "keyVer",
		"revision": "1.0",
		"curve": ["ED-25519"]
	}, {
		"algorithm": "EDDSA",
		"mode": "sigVer",
		"revision": "1.0",
		"pure": true,
		"preHash": true,
		"curve": ["ED-25519"]
	}, {
		"algorithm": "SHAKE-128",
		"inBit": false,
		"outBit": false,
		"inEmpty": false,
		"outputLen": [{
			"min": 128,
			"max": 4096,
			"increment": 8
		}],
		"revision": "1.0"
	}, {
		"algorithm": "SHAKE-256",
		"inBit": false,
		"outBit": false,
		"inEmpty": false,
		"outputLen": [{
			"min": 128,
			"max": 4096,
			"increment": 8
		}],
		"revision": "1.0"
	}, {
		"algorithm": "cSHAKE-128",
		"hexCustomization": false,
		"outputLen": [{
			"min": 16,
			"max": 65536,
			"increment": 8
		}],
		"msgLen": [{
			"min": 0,
			"max": 65536,
			"increment": 8
		}],
		"revision": "1.0"
	}, {
		"algorithm": "cSHAKE-256",
		"hexCustomization": false,
		"outputLen": [{
			"min": 16,
			"max": 65536,
			"increment": 8
		}],
		"msgLen": [{
			"min": 0,
			"max": 65536,
			"increment": 8
		}],
		"revision": "1.0"
	}
]`)); err != nil {
		return err
	}

	// Send the configuration right away rather than waiting for a "flush"
	// request.
	return flush(nil)
}
299
300func kdfCounter(args [][]byte) error {
301	if len(args) != 5 {
302		return fmt.Errorf("KDF received %d args", len(args))
303	}
304
305	outputBytes32, prf, counterLocation, key, counterBits32 := args[0], args[1], args[2], args[3], args[4]
306	outputBytes := binary.LittleEndian.Uint32(outputBytes32)
307	counterBits := binary.LittleEndian.Uint32(counterBits32)
308
309	if !bytes.Equal(prf, []byte("HMAC-SHA2-256")) {
310		return fmt.Errorf("KDF received unsupported PRF %q", string(prf))
311	}
312	if !bytes.Equal(counterLocation, []byte("before fixed data")) {
313		return fmt.Errorf("KDF received unsupported counter location %q", counterLocation)
314	}
315	if counterBits != 32 {
316		return fmt.Errorf("KDF received unsupported counter length %d", counterBits)
317	}
318
319	if len(key) == 0 {
320		key = make([]byte, 32)
321		rand.Reader.Read(key)
322	}
323
324	// See https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-108.pdf section 5.1
325	if outputBytes+31 < outputBytes {
326		return fmt.Errorf("KDF received excessive output length %d", outputBytes)
327	}
328
329	n := (outputBytes + 31) / 32
330	result := make([]byte, 0, 32*n)
331	mac := hmac.New(sha256.New, key)
332	var input [4 + 8]byte
333	var digest []byte
334	rand.Reader.Read(input[4:])
335	for i := uint32(1); i <= n; i++ {
336		mac.Reset()
337		binary.BigEndian.PutUint32(input[:4], i)
338		mac.Write(input[:])
339		digest = mac.Sum(digest[:0])
340		result = append(result, digest...)
341	}
342
343	return reply(key, input[4:], result[:outputBytes])
344}
345
346func reply(responses ...[]byte) error {
347	if len(responses) > maxArgs {
348		return fmt.Errorf("%d responses is too many", len(responses))
349	}
350
351	var lengths [4 * (1 + maxArgs)]byte
352	binary.LittleEndian.PutUint32(lengths[:4], uint32(len(responses)))
353	for i, response := range responses {
354		binary.LittleEndian.PutUint32(lengths[4*(i+1):4*(i+2)], uint32(len(response)))
355	}
356
357	lengthsLength := (1 + len(responses)) * 4
358	if n, err := output.Write(lengths[:lengthsLength]); n != lengthsLength || err != nil {
359		return fmt.Errorf("write failed: %s", err)
360	}
361
362	for _, response := range responses {
363		if n, err := output.Write(response); n != len(response) || err != nil {
364			return fmt.Errorf("write failed: %s", err)
365		}
366	}
367
368	return nil
369}
370
371func xtsEncrypt(args [][]byte) error {
372	return doXTS(args, false)
373}
374
375func xtsDecrypt(args [][]byte) error {
376	return doXTS(args, true)
377}
378
379func doXTS(args [][]byte, decrypt bool) error {
380	if len(args) != 3 {
381		return fmt.Errorf("XTS received %d args, wanted 3", len(args))
382	}
383	key := args[0]
384	msg := args[1]
385	tweak := args[2]
386
387	if len(msg)%16 != 0 {
388		return fmt.Errorf("XTS received %d-byte msg, need multiple of 16", len(msg))
389	}
390	if len(tweak) != 16 {
391		return fmt.Errorf("XTS received %d-byte tweak, wanted 16", len(tweak))
392	}
393
394	var zeros [8]byte
395	if !bytes.Equal(tweak[8:], zeros[:]) {
396		return errors.New("XTS received tweak with invalid structure. Ensure that configuration specifies a 'number' tweak")
397	}
398
399	sectorNum := binary.LittleEndian.Uint64(tweak[:8])
400
401	c, err := xts.NewCipher(aes.NewCipher, key)
402	if err != nil {
403		return err
404	}
405
406	if decrypt {
407		c.Decrypt(msg, msg, sectorNum)
408	} else {
409		c.Encrypt(msg, msg, sectorNum)
410	}
411
412	return reply(msg)
413}
414
415func hkdfMAC(args [][]byte) error {
416	if len(args) != 4 {
417		return fmt.Errorf("HKDF received %d args, wanted 4", len(args))
418	}
419
420	key := args[0]
421	salt := args[1]
422	info := args[2]
423	lengthBytes := args[3]
424
425	if len(lengthBytes) != 4 {
426		return fmt.Errorf("uint32 length was %d bytes long", len(lengthBytes))
427	}
428
429	length := binary.LittleEndian.Uint32(lengthBytes)
430
431	ret, err := hkdf.Key(sha256.New, key, salt, string(info), int(length))
432	if err != nil {
433		return err
434	}
435
436	return reply(ret)
437}
438
439func hmacDRBGReseed(args [][]byte) error {
440	if len(args) != 8 {
441		return fmt.Errorf("hmacDRBG received %d args, wanted 8", len(args))
442	}
443
444	outLenBytes, entropy, personalisation, reseedAdditionalData, reseedEntropy, additionalData1, additionalData2, nonce := args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]
445
446	if len(outLenBytes) != 4 {
447		return fmt.Errorf("uint32 length was %d bytes long", len(outLenBytes))
448	}
449	outLen := binary.LittleEndian.Uint32(outLenBytes)
450	out := make([]byte, outLen)
451
452	drbg := NewHMACDRBG(entropy, nonce, personalisation)
453	drbg.Reseed(reseedEntropy, reseedAdditionalData)
454	drbg.Generate(out, additionalData1)
455	drbg.Generate(out, additionalData2)
456
457	return reply(out)
458}
459
460func hmacDRBGPredictionResistance(args [][]byte) error {
461	if len(args) != 8 {
462		return fmt.Errorf("hmacDRBG received %d args, wanted 8", len(args))
463	}
464
465	outLenBytes, entropy, personalisation, additionalData1, entropy1, additionalData2, entropy2, nonce := args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]
466
467	if len(outLenBytes) != 4 {
468		return fmt.Errorf("uint32 length was %d bytes long", len(outLenBytes))
469	}
470	outLen := binary.LittleEndian.Uint32(outLenBytes)
471	out := make([]byte, outLen)
472
473	drbg := NewHMACDRBG(entropy, nonce, personalisation)
474	drbg.Reseed(entropy1, additionalData1)
475	drbg.Generate(out, nil)
476	drbg.Reseed(entropy2, additionalData2)
477	drbg.Generate(out, nil)
478
479	return reply(out)
480}
481
// swapFinalTwoAESBlocks exchanges, in place, the last two aes.BlockSize-byte
// blocks of d. It panics if d is shorter than two blocks.
func swapFinalTwoAESBlocks(d []byte) {
	end := len(d)
	last := d[end-aes.BlockSize:]
	penultimate := d[end-2*aes.BlockSize : end-aes.BlockSize]
	for i := range last {
		penultimate[i], last[i] = last[i], penultimate[i]
	}
}
488
// roundUp rounds n up to the next multiple of m (n itself if it is already a
// multiple). Callers in this file pass a non-negative n and m = aes.BlockSize.
func roundUp(n, m int) int {
	remainder := n % m
	padding := (m - remainder) % m
	return n + padding
}
492
493func doCTSEncrypt(key, origPlaintext, iv []byte) []byte {
494	// https://nvlpubs.nist.gov/nistpubs/legacy/sp/nistspecialpublication800-38a-add.pdf
495	if len(origPlaintext) < aes.BlockSize {
496		panic("input too small")
497	}
498
499	plaintext := make([]byte, roundUp(len(origPlaintext), aes.BlockSize))
500	copy(plaintext, origPlaintext)
501
502	block, err := aes.NewCipher(key)
503	if err != nil {
504		panic(err)
505	}
506	cbcEncryptor := cipher.NewCBCEncrypter(block, iv)
507	cbcEncryptor.CryptBlocks(plaintext, plaintext)
508	ciphertext := plaintext
509
510	if len(origPlaintext) > aes.BlockSize {
511		swapFinalTwoAESBlocks(ciphertext)
512
513		if len(origPlaintext)%16 != 0 {
514			// Truncate the ciphertext
515			ciphertext = ciphertext[:len(ciphertext)-aes.BlockSize+(len(origPlaintext)%aes.BlockSize)]
516		}
517	}
518
519	if len(ciphertext) != len(origPlaintext) {
520		panic("internal error")
521	}
522
523	return ciphertext
524}
525
// doCTSDecrypt decrypts origCiphertext produced by AES in CBC mode with the
// CS3 ciphertext-stealing variant (the inverse of doCTSEncrypt) and returns a
// plaintext of the same length. origCiphertext is not modified.
//
// It panics if origCiphertext is shorter than one block or key is not a valid
// AES key. cipher.NewCBCDecrypter also panics if iv is not one block long.
func doCTSDecrypt(key, origCiphertext, iv []byte) []byte {
	if len(origCiphertext) < aes.BlockSize {
		panic("input too small")
	}

	// Zero-pad to whole blocks so a partial final block occupies a full slot.
	ciphertext := make([]byte, roundUp(len(origCiphertext), aes.BlockSize))
	copy(ciphertext, origCiphertext)

	// CS3 emits the final two ciphertext blocks swapped. Undo that so the
	// buffer is back in CBC order, with the zero-padded partial block (if
	// any) now in the penultimate slot.
	if len(ciphertext) > aes.BlockSize {
		swapFinalTwoAESBlocks(ciphertext)
	}

	block, err := aes.NewCipher(key)
	if err != nil {
		panic(err)
	}
	cbcDecrypter := cipher.NewCBCDecrypter(block, iv)

	var plaintext []byte
	if overhang := len(origCiphertext) % aes.BlockSize; overhang == 0 {
		// No bytes were stolen: after the swap this is plain CBC.
		cbcDecrypter.CryptBlocks(ciphertext, ciphertext)
		plaintext = ciphertext
	} else {
		// Split off the last full ciphertext block. Note that this shadows
		// ciphertext with the slice of all preceding blocks.
		ciphertext, finalBlock := ciphertext[:len(ciphertext)-aes.BlockSize], ciphertext[len(ciphertext)-aes.BlockSize:]
		// The raw block decryption of the last block is the zero-padded final
		// plaintext XORed with the previous ciphertext block. Its bytes past
		// overhang are therefore the bytes that encryption truncated off that
		// previous block.
		var plaintextFinalBlock [aes.BlockSize]byte
		block.Decrypt(plaintextFinalBlock[:], finalBlock)
		// Restore those stolen bytes to complete the previous ciphertext block.
		copy(ciphertext[len(ciphertext)-aes.BlockSize+overhang:], plaintextFinalBlock[overhang:])
		plaintext = make([]byte, len(origCiphertext))
		// Decrypt all full blocks; this fills the first len(ciphertext) bytes.
		cbcDecrypter.CryptBlocks(plaintext, ciphertext)
		// The final partial plaintext is the leading overhang bytes of the raw
		// decryption XORed with the now-complete previous ciphertext block.
		for i := 0; i < overhang; i++ {
			plaintextFinalBlock[i] ^= ciphertext[len(ciphertext)-aes.BlockSize+i]
		}
		copy(plaintext[len(ciphertext):], plaintextFinalBlock[:overhang])
	}

	return plaintext
}
563
564func ctsEncrypt(args [][]byte) error {
565	if len(args) != 4 {
566		return fmt.Errorf("ctsEncrypt received %d args, wanted 4", len(args))
567	}
568
569	key, plaintext, iv, numIterations32 := args[0], args[1], args[2], args[3]
570	if len(numIterations32) != 4 || binary.LittleEndian.Uint32(numIterations32) != 1 {
571		return errors.New("only a single iteration supported for ctsEncrypt")
572	}
573
574	if len(plaintext) < aes.BlockSize {
575		return fmt.Errorf("ctsEncrypt plaintext too short: %d bytes", len(plaintext))
576	}
577
578	return reply(doCTSEncrypt(key, plaintext, iv))
579}
580
581func ctsDecrypt(args [][]byte) error {
582	if len(args) != 4 {
583		return fmt.Errorf("ctsDecrypt received %d args, wanted 4", len(args))
584	}
585
586	key, ciphertext, iv, numIterations32 := args[0], args[1], args[2], args[3]
587	if len(numIterations32) != 4 || binary.LittleEndian.Uint32(numIterations32) != 1 {
588		return errors.New("only a single iteration supported for ctsDecrypt")
589	}
590
591	if len(ciphertext) < aes.BlockSize {
592		return errors.New("ctsDecrypt ciphertext too short")
593	}
594
595	return reply(doCTSDecrypt(key, ciphertext, iv))
596}
597
598func pbkdf(args [][]byte) error {
599	if len(args) != 5 {
600		return fmt.Errorf("pbkdf received %d args, wanted 5", len(args))
601	}
602
603	hmacName := args[0]
604	var h func() hash.Hash
605	switch string(hmacName) {
606	case "SHA2-224":
607		h = sha256.New224
608	case "SHA2-256":
609		h = sha256.New
610	case "SHA2-384":
611		h = sha512.New384
612	case "SHA2-512":
613		h = sha512.New
614	case "SHA2-512/224":
615		h = sha512.New512_224
616	case "SHA2-512/256":
617		h = sha512.New512_256
618	case "SHA3-224":
619		h = func() hash.Hash { return sha3.New224() }
620	case "SHA3-256":
621		h = func() hash.Hash { return sha3.New256() }
622	case "SHA3-384":
623		h = func() hash.Hash { return sha3.New384() }
624	case "SHA3-512":
625		h = func() hash.Hash { return sha3.New512() }
626	default:
627		return fmt.Errorf("pbkdf unknown HMAC algorithm: %q", hmacName)
628	}
629	keyLen := binary.LittleEndian.Uint32(args[1]) / 8
630	salt, password := args[2], args[3]
631	iterationCount := binary.LittleEndian.Uint32(args[4])
632
633	derivedKey, err := pbkdf2.Key(h, string(password), salt, int(iterationCount), int(keyLen))
634	if err != nil {
635		return err
636	}
637
638	return reply(derivedKey)
639}
640
641func eddsaKeyGen(args [][]byte) error {
642	if string(args[0]) != "ED-25519" {
643		return fmt.Errorf("unsupported EDDSA curve: %q", args[0])
644	}
645
646	pk, sk, err := ed25519.GenerateKey(nil)
647	if err != nil {
648		return fmt.Errorf("generating EDDSA keypair: %w", err)
649	}
650
651	// EDDSA/keyGen/AFT responses are d & q, described[0] as:
652	//   d	The encoded private key point
653	//   q	The encoded public key point
654	//
655	// Contrary to the description of a "point", d is the private key
656	// seed bytes per FIPS.186-5[1] A.2.3.
657	//
658	// [0]: https://pages.nist.gov/ACVP/draft-celi-acvp-eddsa.html#section-9.1
659	// [1]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.186-5.pdf
660	return reply(sk.Seed(), pk)
661}
662
663func eddsaKeyVer(args [][]byte) error {
664	if string(args[0]) != "ED-25519" {
665		return fmt.Errorf("unsupported EDDSA curve: %q", args[0])
666	}
667
668	if len(args[1]) != ed25519.PublicKeySize {
669		return reply([]byte{0})
670	}
671
672	// Verify the point is on the curve. The higher-level ed25519 API does
673	// this at signature verification time so we have to use the lower-level
674	// edwards25519 package to do it here in absence of a signature to verify.
675	if _, err := new(edwards25519.Point).SetBytes(args[1]); err != nil {
676		return reply([]byte{0})
677	}
678
679	return reply([]byte{1})
680}
681
682func eddsaSigGen(args [][]byte) error {
683	if string(args[0]) != "ED-25519" {
684		return fmt.Errorf("unsupported EDDSA curve: %q", args[0])
685	}
686
687	sk := ed25519.NewKeyFromSeed(args[1])
688	msg := args[2]
689	prehash := args[3]
690	context := string(args[4])
691
692	var opts ed25519.Options
693	if prehash[0] == 1 {
694		opts.Hash = crypto.SHA512
695		h := sha512.New()
696		h.Write(msg)
697		msg = h.Sum(nil)
698		// With ed25519 the context is only specified for sigGen tests when using prehashing.
699		// See https://pages.nist.gov/ACVP/draft-celi-acvp-eddsa.html#section-8.6
700		opts.Context = context
701	}
702
703	sig, err := sk.Sign(nil, msg, &opts)
704	if err != nil {
705		return fmt.Errorf("error signing message: %w", err)
706	}
707
708	return reply(sig)
709}
710
711func eddsaSigVer(args [][]byte) error {
712	if string(args[0]) != "ED-25519" {
713		return fmt.Errorf("unsupported EDDSA curve: %q", args[0])
714	}
715
716	msg := args[1]
717	pk := ed25519.PublicKey(args[2])
718	sig := args[3]
719	prehash := args[4]
720
721	var opts ed25519.Options
722	if prehash[0] == 1 {
723		opts.Hash = crypto.SHA512
724		h := sha512.New()
725		h.Write(msg)
726		msg = h.Sum(nil)
727		// Context is only specified for sigGen, not sigVer.
728		// See https://pages.nist.gov/ACVP/draft-celi-acvp-eddsa.html#section-8.6
729	}
730
731	if err := ed25519.VerifyWithOptions(pk, msg, sig, &opts); err != nil {
732		return reply([]byte{0})
733	}
734
735	return reply([]byte{1})
736}
737
738func shakeAftVot(digestFn func() *sha3.SHAKE) func([][]byte) error {
739	return func(args [][]byte) error {
740		if len(args) != 2 {
741			return fmt.Errorf("shakeAftVot received %d args, wanted 2", len(args))
742		}
743
744		msg := args[0]
745		outLenBytes := binary.LittleEndian.Uint32(args[1])
746
747		h := digestFn()
748		h.Write(msg)
749		digest := make([]byte, outLenBytes)
750		h.Read(digest)
751
752		return reply(digest)
753	}
754}
755
756func shakeMct(digestFn func() *sha3.SHAKE) func([][]byte) error {
757	return func(args [][]byte) error {
758		if len(args) != 4 {
759			return fmt.Errorf("shakeMct received %d args, wanted 4", len(args))
760		}
761
762		md := args[0]
763		minOutBytes := binary.LittleEndian.Uint32(args[1])
764		maxOutBytes := binary.LittleEndian.Uint32(args[2])
765
766		outputLenBytes := binary.LittleEndian.Uint32(args[3])
767		if outputLenBytes < 2 {
768			return fmt.Errorf("invalid output length: %d", outputLenBytes)
769		}
770
771		if maxOutBytes < minOutBytes {
772			return fmt.Errorf("invalid maxOutBytes and minOutBytes: %d, %d", maxOutBytes, minOutBytes)
773		}
774
775		rangeBytes := maxOutBytes - minOutBytes + 1
776
777		for i := 0; i < 1000; i++ {
778			// "The MSG[i] input to SHAKE MUST always contain at least 128 bits. If this is not the case
779			// as the previous digest was too short, append empty bits to the rightmost side of the digest."
780			boundary := min(len(md), 16)
781			msg := make([]byte, 16)
782			copy(msg, md[:boundary])
783
784			//  MD[i] = SHAKE(MSG[i], OutputLen * 8)
785			h := digestFn()
786			h.Write(msg)
787			digest := make([]byte, outputLenBytes)
788			h.Read(digest)
789			md = digest
790
791			// RightmostOutputBits = 16 rightmost bits of MD[i] as an integer
792			// OutputLen = minOutBytes + (RightmostOutputBits % Range)
793			rightmostOutput := uint32(md[outputLenBytes-2])<<8 | uint32(md[outputLenBytes-1])
794			outputLenBytes = minOutBytes + (rightmostOutput % rangeBytes)
795		}
796
797		encodedOutputLenBytes := make([]byte, 4)
798		binary.LittleEndian.PutUint32(encodedOutputLenBytes, outputLenBytes)
799
800		return reply(md, encodedOutputLenBytes)
801	}
802}
803
804func cShakeAft(hFn func(N, S []byte) *sha3.SHAKE) func([][]byte) error {
805	return func(args [][]byte) error {
806		if len(args) != 4 {
807			return fmt.Errorf("cShakeAft received %d args, wanted 4", len(args))
808		}
809
810		msg := args[0]
811		outLenBytes := binary.LittleEndian.Uint32(args[1])
812		functionName := args[2]
813		customization := args[3]
814
815		h := hFn(functionName, customization)
816		h.Write(msg)
817		digest := make([]byte, outLenBytes)
818		h.Read(digest)
819
820		return reply(digest)
821	}
822}
823
824func cShakeMct(hFn func(N, S []byte) *sha3.SHAKE) func([][]byte) error {
825	return func(args [][]byte) error {
826		if len(args) != 6 {
827			return fmt.Errorf("cShakeMct received %d args, wanted 6", len(args))
828		}
829
830		message := args[0]
831		minOutLenBytes := binary.LittleEndian.Uint32(args[1])
832		maxOutLenBytes := binary.LittleEndian.Uint32(args[2])
833		outputLenBytes := binary.LittleEndian.Uint32(args[3])
834		incrementBytes := binary.LittleEndian.Uint32(args[4])
835		customization := args[5]
836
837		if outputLenBytes < 2 {
838			return fmt.Errorf("invalid output length: %d", outputLenBytes)
839		}
840
841		rangeBits := (maxOutLenBytes*8 - minOutLenBytes*8) + 1
842		if rangeBits == 0 {
843			return fmt.Errorf("invalid maxOutLenBytes and minOutLenBytes: %d, %d", maxOutLenBytes, minOutLenBytes)
844		}
845
846		// cSHAKE Monte Carlo test inner loop:
847		//   https://pages.nist.gov/ACVP/draft-celi-acvp-xof.html#section-6.2.1
848		for i := 0; i < 1000; i++ {
849			// InnerMsg = Left(Output[i-1] || ZeroBits(128), 128);
850			boundary := min(len(message), 16)
851			innerMsg := make([]byte, 16)
852			copy(innerMsg, message[:boundary])
853
854			// Output[i] = CSHAKE(InnerMsg, OutputLen, FunctionName, Customization);
855			h := hFn(nil, customization) // Note: function name fixed to "" for MCT.
856			h.Write(innerMsg)
857			digest := make([]byte, outputLenBytes)
858			h.Read(digest)
859			message = digest
860
861			// Rightmost_Output_bits = Right(Output[i], 16);
862			rightmostOutput := digest[outputLenBytes-2:]
863			// IMPORTANT: the specification says:
864			//   NOTE: For the "Rightmost_Output_bits % Range" operation, the Rightmost_Output_bits bit string
865			//   should be interpreted as a little endian-encoded number.
866			// This is **a lie**! It has to be interpreted as a big-endian number.
867			rightmostOutputBE := binary.BigEndian.Uint16(rightmostOutput)
868
869			// OutputLen = MinOutLen + (floor((Rightmost_Output_bits % Range) / OutLenIncrement) * OutLenIncrement);
870			incrementBits := incrementBytes * 8
871			outputLenBits := (minOutLenBytes * 8) + (((uint32)(rightmostOutputBE)%rangeBits)/incrementBits)*incrementBits
872			outputLenBytes = outputLenBits / 8
873
874			// Customization = BitsToString(InnerMsg || Rightmost_Output_bits);
875			msgWithBits := append(innerMsg, rightmostOutput...)
876			customization = make([]byte, len(msgWithBits))
877			for i, b := range msgWithBits {
878				customization[i] = (b % 26) + 65
879			}
880		}
881
882		encodedOutputLenBytes := make([]byte, 4)
883		binary.LittleEndian.PutUint32(encodedOutputLenBytes, outputLenBytes)
884
885		return reply(message, encodedOutputLenBytes, customization)
886	}
887}
888
// Limits on the requests read by do (maxArgs also bounds how many responses
// reply will send in one message).
const (
	// maxArgs is the most arguments, including the operation name, that a
	// single request may carry.
	maxArgs = 9
	// maxArgLength is the largest permitted argument, in bytes.
	maxArgLength = 1024 * 1024
	// maxNameLength is the longest permitted operation name, in bytes.
	maxNameLength = 30
)
894
895func main() {
896	if err := do(); err != nil {
897		fmt.Fprintf(os.Stderr, "%s.\n", err)
898		os.Exit(1)
899	}
900}
901
902func do() error {
903	// In order to exercise pipelining, all output is buffered until a "flush".
904	outputBuffer = new(bytes.Buffer)
905	output = outputBuffer
906
907	var nums [4 * (1 + maxArgs)]byte
908	var argLengths [maxArgs]uint32
909	var args [maxArgs][]byte
910	var argsData []byte
911
912	for {
913		if _, err := io.ReadFull(os.Stdin, nums[:8]); err != nil {
914			return err
915		}
916
917		numArgs := binary.LittleEndian.Uint32(nums[:4])
918		if numArgs == 0 {
919			return errors.New("Invalid, zero-argument operation requested")
920		} else if numArgs > maxArgs {
921			return fmt.Errorf("Operation requested with %d args, but %d is the limit", numArgs, maxArgs)
922		}
923
924		if numArgs > 1 {
925			if _, err := io.ReadFull(os.Stdin, nums[8:4+4*numArgs]); err != nil {
926				return err
927			}
928		}
929
930		input := nums[4:]
931		var need uint64
932		for i := uint32(0); i < numArgs; i++ {
933			argLength := binary.LittleEndian.Uint32(input[:4])
934			if i == 0 && argLength > maxNameLength {
935				return fmt.Errorf("Operation with name of length %d exceeded limit of %d", argLength, maxNameLength)
936			} else if argLength > maxArgLength {
937				return fmt.Errorf("Operation with argument of length %d exceeded limit of %d", argLength, maxArgLength)
938			}
939			need += uint64(argLength)
940			argLengths[i] = argLength
941			input = input[4:]
942		}
943
944		if need > uint64(cap(argsData)) {
945			argsData = make([]byte, need)
946		} else {
947			argsData = argsData[:need]
948		}
949
950		if _, err := io.ReadFull(os.Stdin, argsData); err != nil {
951			return err
952		}
953
954		input = argsData
955		for i := uint32(0); i < numArgs; i++ {
956			args[i] = input[:argLengths[i]]
957			input = input[argLengths[i]:]
958		}
959
960		name := string(args[0])
961		if handler, ok := handlers[name]; !ok {
962			return fmt.Errorf("unknown operation %q", name)
963		} else {
964			if err := handler(args[1:numArgs]); err != nil {
965				return err
966			}
967		}
968	}
969}
970