1// Copyright 2020 The BoringSSL Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package hpke
16
17import (
18	"bytes"
19	_ "crypto/sha256"
20	_ "crypto/sha512"
21	_ "embed"
22	"encoding/hex"
23	"encoding/json"
24	"errors"
25	"fmt"
26	"testing"
27)
28
const (
	// exportOnlyAEAD is the AEAD identifier carried by export-only test
	// vectors (RFC 9180 assigns 0xFFFF to "Export-only"). TestVectors skips
	// any vector that uses it.
	exportOnlyAEAD uint16 = 0xffff
)

// testVectorsJSON holds the contents of testdata/test-vectors.json, embedded
// into the test binary at build time. TestVectors decodes it as a list of
// HpkeTestVector.
//
//go:embed testdata/test-vectors.json
var testVectorsJSON []byte
35
36// Simple round-trip test for fixed inputs.
37func TestRoundTrip(t *testing.T) {
38	publicKeyR, secretKeyR, err := GenerateKeyPairX25519()
39	if err != nil {
40		t.Errorf("failed to generate key pair: %s", err)
41		return
42	}
43
44	// Set up the sender and receiver contexts.
45	senderContext, enc, err := SetupBaseSenderX25519(HKDFSHA256, AES256GCM, publicKeyR, nil, nil)
46	if err != nil {
47		t.Errorf("failed to set up sender: %s", err)
48		return
49	}
50	receiverContext, err := SetupBaseReceiverX25519(HKDFSHA256, AES256GCM, enc, secretKeyR, nil)
51	if err != nil {
52		t.Errorf("failed to set up receiver: %s", err)
53		return
54	}
55
56	// Seal() our plaintext with the sender context, then Open() the
57	// ciphertext with the receiver context.
58	plaintext := []byte("foobar")
59	ciphertext := senderContext.Seal(plaintext, nil)
60	decrypted, err := receiverContext.Open(ciphertext, nil)
61	if err != nil {
62		t.Errorf("encryption round trip failed: %s", err)
63		return
64	}
65	checkBytesEqual(t, "decrypted", decrypted, plaintext)
66}
67
68// HpkeTestVector defines the subset of test-vectors.json that we read.
69type HpkeTestVector struct {
70	KEM         uint16                 `json:"kem_id"`
71	Mode        uint8                  `json:"mode"`
72	KDF         uint16                 `json:"kdf_id"`
73	AEAD        uint16                 `json:"aead_id"`
74	Info        HexString              `json:"info"`
75	PSK         HexString              `json:"psk"`
76	PSKID       HexString              `json:"psk_id"`
77	SecretKeyR  HexString              `json:"skRm"`
78	SecretKeyE  HexString              `json:"skEm"`
79	PublicKeyR  HexString              `json:"pkRm"`
80	PublicKeyE  HexString              `json:"pkEm"`
81	Enc         HexString              `json:"enc"`
82	Encryptions []EncryptionTestVector `json:"encryptions"`
83	Exports     []ExportTestVector     `json:"exports"`
84}
85type EncryptionTestVector struct {
86	Plaintext      HexString `json:"pt"`
87	AdditionalData HexString `json:"aad"`
88	Ciphertext     HexString `json:"ct"`
89}
90type ExportTestVector struct {
91	ExportContext HexString `json:"exporter_context"`
92	ExportLength  int       `json:"L"`
93	ExportValue   HexString `json:"exported_value"`
94}
95
96// TestVectors checks all relevant test vectors in test-vectors.json.
97func TestVectors(t *testing.T) {
98	var testVectors []HpkeTestVector
99	if err := json.Unmarshal(testVectorsJSON, &testVectors); err != nil {
100		t.Errorf("error parsing test vectors: %s", err)
101		return
102	}
103
104	var numSkippedTests = 0
105
106	for testNum, testVec := range testVectors {
107		// Skip this vector if it specifies an unsupported parameter.
108		if testVec.KEM != X25519WithHKDFSHA256 ||
109			(testVec.Mode != hpkeModeBase && testVec.Mode != hpkeModePSK) ||
110			testVec.AEAD == exportOnlyAEAD {
111			numSkippedTests++
112			continue
113		}
114
115		testVec := testVec // capture the range variable
116		t.Run(fmt.Sprintf("test%d,Mode=%d,KDF=%d,AEAD=%d", testNum, testVec.Mode, testVec.KDF, testVec.AEAD), func(t *testing.T) {
117			var senderContext *Context
118			var receiverContext *Context
119			var enc []byte
120			var err error
121
122			switch testVec.Mode {
123			case hpkeModeBase:
124				senderContext, enc, err = SetupBaseSenderX25519(testVec.KDF, testVec.AEAD, testVec.PublicKeyR, testVec.Info,
125					func() ([]byte, []byte, error) {
126						return testVec.PublicKeyE, testVec.SecretKeyE, nil
127					})
128				if err != nil {
129					t.Errorf("failed to set up sender: %s", err)
130					return
131				}
132				checkBytesEqual(t, "sender enc", enc, testVec.Enc)
133
134				receiverContext, err = SetupBaseReceiverX25519(testVec.KDF, testVec.AEAD, enc, testVec.SecretKeyR, testVec.Info)
135				if err != nil {
136					t.Errorf("failed to set up receiver: %s", err)
137					return
138				}
139			case hpkeModePSK:
140				senderContext, enc, err = SetupPSKSenderX25519(testVec.KDF, testVec.AEAD, testVec.PublicKeyR, testVec.Info, testVec.PSK, testVec.PSKID,
141					func() ([]byte, []byte, error) {
142						return testVec.PublicKeyE, testVec.SecretKeyE, nil
143					})
144				if err != nil {
145					t.Errorf("failed to set up sender: %s", err)
146					return
147				}
148				checkBytesEqual(t, "sender enc", enc, testVec.Enc)
149
150				receiverContext, err = SetupPSKReceiverX25519(testVec.KDF, testVec.AEAD, enc, testVec.SecretKeyR, testVec.Info, testVec.PSK, testVec.PSKID)
151				if err != nil {
152					t.Errorf("failed to set up receiver: %s", err)
153					return
154				}
155			default:
156				panic("unsupported mode")
157			}
158
159			for encryptionNum, e := range testVec.Encryptions {
160				ciphertext := senderContext.Seal(e.Plaintext, e.AdditionalData)
161				checkBytesEqual(t, "ciphertext", ciphertext, e.Ciphertext)
162
163				decrypted, err := receiverContext.Open(ciphertext, e.AdditionalData)
164				if err != nil {
165					t.Errorf("decryption %d failed: %s", encryptionNum, err)
166					return
167				}
168				checkBytesEqual(t, "decrypted plaintext", decrypted, e.Plaintext)
169			}
170
171			for _, ex := range testVec.Exports {
172				exportValue := senderContext.Export(ex.ExportContext, ex.ExportLength)
173				checkBytesEqual(t, "exportValue", exportValue, ex.ExportValue)
174
175				exportValue = receiverContext.Export(ex.ExportContext, ex.ExportLength)
176				checkBytesEqual(t, "exportValue", exportValue, ex.ExportValue)
177			}
178		})
179	}
180
181	if numSkippedTests == len(testVectors) {
182		panic("no test vectors were used")
183	}
184}
185
// HexString is a byte string that unmarshals from a JSON string of hex
// digits.
type HexString []byte

// UnmarshalJSON decodes a quoted JSON string of hex digits (either case) into
// h. Following the encoding/json convention for Unmarshalers, a JSON null is a
// no-op. On any error h is left unchanged.
func (h *HexString) UnmarshalJSON(data []byte) error {
	if string(data) == "null" {
		return nil
	}
	if len(data) < 2 || data[0] != '"' || data[len(data)-1] != '"' {
		return errors.New("missing double quotes")
	}
	// Decode into a temporary so a malformed string cannot leave a partial
	// result in *h.
	decoded, err := hex.DecodeString(string(data[1 : len(data)-1]))
	if err != nil {
		return fmt.Errorf("decoding hex string: %w", err)
	}
	*h = decoded
	return nil
}
197
// checkBytesEqual reports a non-fatal test error if actual differs from
// expected, using name to identify the compared value in the message. Because
// it uses bytes.Equal, nil and empty slices compare equal.
func checkBytesEqual(t *testing.T, name string, actual, expected []byte) {
	// Attribute failures to the caller's line rather than to this helper.
	t.Helper()
	if !bytes.Equal(actual, expected) {
		t.Errorf("%s = %x; want %x", name, actual, expected)
	}
}
203