1// Copyright 2016 The BoringSSL Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//go:build ignore
16
17package main
18
19import (
20	"bufio"
21	"errors"
22	"fmt"
23	"io"
24	"math/big"
25	"os"
26	"strings"
27)
28
29type test struct {
30	LineNumber int
31	Type       string
32	Values     map[string]*big.Int
33}
34
35type testScanner struct {
36	scanner *bufio.Scanner
37	lineNo  int
38	err     error
39	test    test
40}
41
42func newTestScanner(r io.Reader) *testScanner {
43	return &testScanner{scanner: bufio.NewScanner(r)}
44}
45
46func (s *testScanner) scanLine() bool {
47	if !s.scanner.Scan() {
48		return false
49	}
50	s.lineNo++
51	return true
52}
53
54func (s *testScanner) addAttribute(line string) (key string, ok bool) {
55	fields := strings.SplitN(line, "=", 2)
56	if len(fields) != 2 {
57		s.setError(errors.New("invalid syntax"))
58		return "", false
59	}
60
61	key = strings.TrimSpace(fields[0])
62	value := strings.TrimSpace(fields[1])
63
64	valueInt, ok := new(big.Int).SetString(value, 16)
65	if !ok {
66		s.setError(fmt.Errorf("could not parse %q", value))
67		return "", false
68	}
69	if _, dup := s.test.Values[key]; dup {
70		s.setError(fmt.Errorf("duplicate key %q", key))
71		return "", false
72	}
73	s.test.Values[key] = valueInt
74	return key, true
75}
76
77func (s *testScanner) Scan() bool {
78	s.test = test{
79		Values: make(map[string]*big.Int),
80	}
81
82	// Scan until the first attribute.
83	for {
84		if !s.scanLine() {
85			return false
86		}
87		if len(s.scanner.Text()) != 0 && s.scanner.Text()[0] != '#' {
88			break
89		}
90	}
91
92	var ok bool
93	s.test.Type, ok = s.addAttribute(s.scanner.Text())
94	if !ok {
95		return false
96	}
97	s.test.LineNumber = s.lineNo
98
99	for s.scanLine() {
100		if len(s.scanner.Text()) == 0 {
101			break
102		}
103
104		if s.scanner.Text()[0] == '#' {
105			continue
106		}
107
108		if _, ok := s.addAttribute(s.scanner.Text()); !ok {
109			return false
110		}
111	}
112	return s.scanner.Err() == nil
113}
114
115func (s *testScanner) Test() test {
116	return s.test
117}
118
119func (s *testScanner) Err() error {
120	if s.err != nil {
121		return s.err
122	}
123	return s.scanner.Err()
124}
125
126func (s *testScanner) setError(err error) {
127	s.err = fmt.Errorf("line %d: %s", s.lineNo, err)
128}
129
130func checkKeys(t test, keys ...string) bool {
131	var foundErrors bool
132
133	for _, k := range keys {
134		if _, ok := t.Values[k]; !ok {
135			fmt.Fprintf(os.Stderr, "Line %d: missing key %q.\n", t.LineNumber, k)
136			foundErrors = true
137		}
138	}
139
140	for k := range t.Values {
141		var found bool
142		for _, k2 := range keys {
143			if k == k2 {
144				found = true
145				break
146			}
147		}
148		if !found {
149			fmt.Fprintf(os.Stderr, "Line %d: unexpected key %q.\n", t.LineNumber, k)
150			foundErrors = true
151		}
152	}
153
154	return !foundErrors
155}
156
157func checkResult(t test, expr, key string, r *big.Int) {
158	if t.Values[key].Cmp(r) != 0 {
159		fmt.Fprintf(os.Stderr, "Line %d: %s did not match %s.\n\tGot %s\n", t.LineNumber, expr, key, r.Text(16))
160	}
161}
162
163func main() {
164	if len(os.Args) != 2 {
165		fmt.Fprintf(os.Stderr, "Usage: %s bn_tests.txt\n", os.Args[0])
166		os.Exit(1)
167	}
168
169	in, err := os.Open(os.Args[1])
170	if err != nil {
171		fmt.Fprintf(os.Stderr, "Error opening %s: %s.\n", os.Args[0], err)
172		os.Exit(1)
173	}
174	defer in.Close()
175
176	scanner := newTestScanner(in)
177	for scanner.Scan() {
178		test := scanner.Test()
179		switch test.Type {
180		case "Sum":
181			if checkKeys(test, "A", "B", "Sum") {
182				r := new(big.Int).Add(test.Values["A"], test.Values["B"])
183				checkResult(test, "A + B", "Sum", r)
184			}
185		case "LShift1":
186			if checkKeys(test, "A", "LShift1") {
187				r := new(big.Int).Add(test.Values["A"], test.Values["A"])
188				checkResult(test, "A + A", "LShift1", r)
189			}
190		case "LShift":
191			if checkKeys(test, "A", "N", "LShift") {
192				r := new(big.Int).Lsh(test.Values["A"], uint(test.Values["N"].Uint64()))
193				checkResult(test, "A << N", "LShift", r)
194			}
195		case "RShift":
196			if checkKeys(test, "A", "N", "RShift") {
197				r := new(big.Int).Rsh(test.Values["A"], uint(test.Values["N"].Uint64()))
198				checkResult(test, "A >> N", "RShift", r)
199			}
200		case "Square":
201			if checkKeys(test, "A", "Square") {
202				r := new(big.Int).Mul(test.Values["A"], test.Values["A"])
203				checkResult(test, "A * A", "Square", r)
204			}
205		case "Product":
206			if checkKeys(test, "A", "B", "Product") {
207				r := new(big.Int).Mul(test.Values["A"], test.Values["B"])
208				checkResult(test, "A * B", "Product", r)
209			}
210		case "Quotient":
211			if checkKeys(test, "A", "B", "Quotient", "Remainder") {
212				q, r := new(big.Int).QuoRem(test.Values["A"], test.Values["B"], new(big.Int))
213				checkResult(test, "A / B", "Quotient", q)
214				checkResult(test, "A % B", "Remainder", r)
215			}
216		case "ModMul":
217			if checkKeys(test, "A", "B", "M", "ModMul") {
218				r := new(big.Int).Mul(test.Values["A"], test.Values["B"])
219				r = r.Mod(r, test.Values["M"])
220				checkResult(test, "A * B (mod M)", "ModMul", r)
221			}
222		case "ModExp":
223			if checkKeys(test, "A", "E", "M", "ModExp") {
224				r := new(big.Int).Exp(test.Values["A"], test.Values["E"], test.Values["M"])
225				checkResult(test, "A ^ E (mod M)", "ModExp", r)
226			}
227		case "Exp":
228			if checkKeys(test, "A", "E", "Exp") {
229				r := new(big.Int).Exp(test.Values["A"], test.Values["E"], nil)
230				checkResult(test, "A ^ E", "Exp", r)
231			}
232		case "ModSqrt":
233			bigOne := big.NewInt(1)
234			bigTwo := big.NewInt(2)
235
236			if checkKeys(test, "A", "P", "ModSqrt") {
237				test.Values["A"].Mod(test.Values["A"], test.Values["P"])
238
239				r := new(big.Int).Mul(test.Values["ModSqrt"], test.Values["ModSqrt"])
240				r = r.Mod(r, test.Values["P"])
241				checkResult(test, "ModSqrt ^ 2 (mod P)", "A", r)
242
243				if test.Values["P"].Cmp(bigTwo) > 0 {
244					pMinus1Over2 := new(big.Int).Sub(test.Values["P"], bigOne)
245					pMinus1Over2.Rsh(pMinus1Over2, 1)
246
247					if test.Values["ModSqrt"].Cmp(pMinus1Over2) > 0 {
248						fmt.Fprintf(os.Stderr, "Line %d: ModSqrt should be minimal.\n", test.LineNumber)
249					}
250				}
251			}
252		case "ModInv":
253			if checkKeys(test, "A", "M", "ModInv") {
254				a := test.Values["A"]
255				m := test.Values["M"]
256				var r *big.Int
257				if a.Sign() == 0 && m.IsInt64() && m.Int64() == 1 {
258					// OpenSSL says 0^(-1) mod (1) is 0, while Go says the
259					// inverse does not exist.
260					r = big.NewInt(0)
261				} else {
262					r = new(big.Int).ModInverse(a, m)
263				}
264				if r == nil {
265					fmt.Fprintf(os.Stderr, "Line %d: A has no inverse mod M.\n", test.LineNumber)
266				} else {
267					checkResult(test, "A ^ -1 (mod M)", "ModInv", r)
268				}
269			}
270		case "ModSquare":
271			if checkKeys(test, "A", "M", "ModSquare") {
272				r := new(big.Int).Mul(test.Values["A"], test.Values["A"])
273				r = r.Mod(r, test.Values["M"])
274				checkResult(test, "A * A (mod M)", "ModSquare", r)
275			}
276		case "NotModSquare":
277			if checkKeys(test, "P", "NotModSquare") {
278				if new(big.Int).ModSqrt(test.Values["NotModSquare"], test.Values["P"]) != nil {
279					fmt.Fprintf(os.Stderr, "Line %d: value was a square.\n", test.LineNumber)
280				}
281			}
282		case "GCD":
283			if checkKeys(test, "A", "B", "GCD", "LCM") {
284				a := test.Values["A"]
285				b := test.Values["B"]
286				// Go's GCD function does not accept zero, unlike OpenSSL.
287				var g *big.Int
288				if a.Sign() == 0 {
289					g = b
290				} else if b.Sign() == 0 {
291					g = a
292				} else {
293					g = new(big.Int).GCD(nil, nil, a, b)
294				}
295				checkResult(test, "GCD(A, B)", "GCD", g)
296				if g.Sign() != 0 {
297					lcm := new(big.Int).Mul(a, b)
298					lcm = lcm.Div(lcm, g)
299					checkResult(test, "LCM(A, B)", "LCM", lcm)
300				}
301			}
302		default:
303			fmt.Fprintf(os.Stderr, "Line %d: unknown test type %q.\n", test.LineNumber, test.Type)
304		}
305	}
306	if scanner.Err() != nil {
307		fmt.Fprintf(os.Stderr, "Error reading tests: %s.\n", scanner.Err())
308	}
309}
310