1// Copyright 2016 The BoringSSL Authors 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// https://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15//go:build ignore 16 17package main 18 19import ( 20 "bufio" 21 "errors" 22 "fmt" 23 "io" 24 "math/big" 25 "os" 26 "strings" 27) 28 29type test struct { 30 LineNumber int 31 Type string 32 Values map[string]*big.Int 33} 34 35type testScanner struct { 36 scanner *bufio.Scanner 37 lineNo int 38 err error 39 test test 40} 41 42func newTestScanner(r io.Reader) *testScanner { 43 return &testScanner{scanner: bufio.NewScanner(r)} 44} 45 46func (s *testScanner) scanLine() bool { 47 if !s.scanner.Scan() { 48 return false 49 } 50 s.lineNo++ 51 return true 52} 53 54func (s *testScanner) addAttribute(line string) (key string, ok bool) { 55 fields := strings.SplitN(line, "=", 2) 56 if len(fields) != 2 { 57 s.setError(errors.New("invalid syntax")) 58 return "", false 59 } 60 61 key = strings.TrimSpace(fields[0]) 62 value := strings.TrimSpace(fields[1]) 63 64 valueInt, ok := new(big.Int).SetString(value, 16) 65 if !ok { 66 s.setError(fmt.Errorf("could not parse %q", value)) 67 return "", false 68 } 69 if _, dup := s.test.Values[key]; dup { 70 s.setError(fmt.Errorf("duplicate key %q", key)) 71 return "", false 72 } 73 s.test.Values[key] = valueInt 74 return key, true 75} 76 77func (s *testScanner) Scan() bool { 78 s.test = test{ 79 Values: make(map[string]*big.Int), 80 } 81 82 // Scan until the first attribute. 83 for { 84 if !s.scanLine() { 85 return false 86 } 87 if len(s.scanner.Text()) != 0 && s.scanner.Text()[0] != '#' { 88 break 89 } 90 } 91 92 var ok bool 93 s.test.Type, ok = s.addAttribute(s.scanner.Text()) 94 if !ok { 95 return false 96 } 97 s.test.LineNumber = s.lineNo 98 99 for s.scanLine() { 100 if len(s.scanner.Text()) == 0 { 101 break 102 } 103 104 if s.scanner.Text()[0] == '#' { 105 continue 106 } 107 108 if _, ok := s.addAttribute(s.scanner.Text()); !ok { 109 return false 110 } 111 } 112 return s.scanner.Err() == nil 113} 114 115func (s *testScanner) Test() test { 116 return s.test 117} 118 119func (s *testScanner) Err() error { 120 if s.err != nil { 121 return s.err 122 } 123 return s.scanner.Err() 124} 125 126func (s *testScanner) setError(err error) { 127 s.err = fmt.Errorf("line %d: %s", s.lineNo, err) 128} 129 130func checkKeys(t test, keys ...string) bool { 131 var foundErrors bool 132 133 for _, k := range keys { 134 if _, ok := t.Values[k]; !ok { 135 fmt.Fprintf(os.Stderr, "Line %d: missing key %q.\n", t.LineNumber, k) 136 foundErrors = true 137 } 138 } 139 140 for k := range t.Values { 141 var found bool 142 for _, k2 := range keys { 143 if k == k2 { 144 found = true 145 break 146 } 147 } 148 if !found { 149 fmt.Fprintf(os.Stderr, "Line %d: unexpected key %q.\n", t.LineNumber, k) 150 foundErrors = true 151 } 152 } 153 154 return !foundErrors 155} 156 157func checkResult(t test, expr, key string, r *big.Int) { 158 if t.Values[key].Cmp(r) != 0 { 159 fmt.Fprintf(os.Stderr, "Line %d: %s did not match %s.\n\tGot %s\n", t.LineNumber, expr, key, r.Text(16)) 160 } 161} 162 163func main() { 164 if len(os.Args) != 2 { 165 fmt.Fprintf(os.Stderr, "Usage: %s bn_tests.txt\n", os.Args[0]) 166 os.Exit(1) 167 } 168 169 in, err := os.Open(os.Args[1]) 170 if err != nil { 171 fmt.Fprintf(os.Stderr, "Error opening %s: %s.\n", os.Args[0], err) 172 os.Exit(1) 173 } 174 defer in.Close() 175 176 scanner := newTestScanner(in) 177 for scanner.Scan() { 178 test := scanner.Test() 179 switch test.Type { 180 case "Sum": 181 if checkKeys(test, "A", "B", "Sum") { 182 r := new(big.Int).Add(test.Values["A"], test.Values["B"]) 183 checkResult(test, "A + B", "Sum", r) 184 } 185 case "LShift1": 186 if checkKeys(test, "A", "LShift1") { 187 r := new(big.Int).Add(test.Values["A"], test.Values["A"]) 188 checkResult(test, "A + A", "LShift1", r) 189 } 190 case "LShift": 191 if checkKeys(test, "A", "N", "LShift") { 192 r := new(big.Int).Lsh(test.Values["A"], uint(test.Values["N"].Uint64())) 193 checkResult(test, "A << N", "LShift", r) 194 } 195 case "RShift": 196 if checkKeys(test, "A", "N", "RShift") { 197 r := new(big.Int).Rsh(test.Values["A"], uint(test.Values["N"].Uint64())) 198 checkResult(test, "A >> N", "RShift", r) 199 } 200 case "Square": 201 if checkKeys(test, "A", "Square") { 202 r := new(big.Int).Mul(test.Values["A"], test.Values["A"]) 203 checkResult(test, "A * A", "Square", r) 204 } 205 case "Product": 206 if checkKeys(test, "A", "B", "Product") { 207 r := new(big.Int).Mul(test.Values["A"], test.Values["B"]) 208 checkResult(test, "A * B", "Product", r) 209 } 210 case "Quotient": 211 if checkKeys(test, "A", "B", "Quotient", "Remainder") { 212 q, r := new(big.Int).QuoRem(test.Values["A"], test.Values["B"], new(big.Int)) 213 checkResult(test, "A / B", "Quotient", q) 214 checkResult(test, "A % B", "Remainder", r) 215 } 216 case "ModMul": 217 if checkKeys(test, "A", "B", "M", "ModMul") { 218 r := new(big.Int).Mul(test.Values["A"], test.Values["B"]) 219 r = r.Mod(r, test.Values["M"]) 220 checkResult(test, "A * B (mod M)", "ModMul", r) 221 } 222 case "ModExp": 223 if checkKeys(test, "A", "E", "M", "ModExp") { 224 r := new(big.Int).Exp(test.Values["A"], test.Values["E"], test.Values["M"]) 225 checkResult(test, "A ^ E (mod M)", "ModExp", r) 226 } 227 case "Exp": 228 if checkKeys(test, "A", "E", "Exp") { 229 r := new(big.Int).Exp(test.Values["A"], test.Values["E"], nil) 230 checkResult(test, "A ^ E", "Exp", r) 231 } 232 case "ModSqrt": 233 bigOne := big.NewInt(1) 234 bigTwo := big.NewInt(2) 235 236 if checkKeys(test, "A", "P", "ModSqrt") { 237 test.Values["A"].Mod(test.Values["A"], test.Values["P"]) 238 239 r := new(big.Int).Mul(test.Values["ModSqrt"], test.Values["ModSqrt"]) 240 r = r.Mod(r, test.Values["P"]) 241 checkResult(test, "ModSqrt ^ 2 (mod P)", "A", r) 242 243 if test.Values["P"].Cmp(bigTwo) > 0 { 244 pMinus1Over2 := new(big.Int).Sub(test.Values["P"], bigOne) 245 pMinus1Over2.Rsh(pMinus1Over2, 1) 246 247 if test.Values["ModSqrt"].Cmp(pMinus1Over2) > 0 { 248 fmt.Fprintf(os.Stderr, "Line %d: ModSqrt should be minimal.\n", test.LineNumber) 249 } 250 } 251 } 252 case "ModInv": 253 if checkKeys(test, "A", "M", "ModInv") { 254 a := test.Values["A"] 255 m := test.Values["M"] 256 var r *big.Int 257 if a.Sign() == 0 && m.IsInt64() && m.Int64() == 1 { 258 // OpenSSL says 0^(-1) mod (1) is 0, while Go says the 259 // inverse does not exist. 260 r = big.NewInt(0) 261 } else { 262 r = new(big.Int).ModInverse(a, m) 263 } 264 if r == nil { 265 fmt.Fprintf(os.Stderr, "Line %d: A has no inverse mod M.\n", test.LineNumber) 266 } else { 267 checkResult(test, "A ^ -1 (mod M)", "ModInv", r) 268 } 269 } 270 case "ModSquare": 271 if checkKeys(test, "A", "M", "ModSquare") { 272 r := new(big.Int).Mul(test.Values["A"], test.Values["A"]) 273 r = r.Mod(r, test.Values["M"]) 274 checkResult(test, "A * A (mod M)", "ModSquare", r) 275 } 276 case "NotModSquare": 277 if checkKeys(test, "P", "NotModSquare") { 278 if new(big.Int).ModSqrt(test.Values["NotModSquare"], test.Values["P"]) != nil { 279 fmt.Fprintf(os.Stderr, "Line %d: value was a square.\n", test.LineNumber) 280 } 281 } 282 case "GCD": 283 if checkKeys(test, "A", "B", "GCD", "LCM") { 284 a := test.Values["A"] 285 b := test.Values["B"] 286 // Go's GCD function does not accept zero, unlike OpenSSL. 287 var g *big.Int 288 if a.Sign() == 0 { 289 g = b 290 } else if b.Sign() == 0 { 291 g = a 292 } else { 293 g = new(big.Int).GCD(nil, nil, a, b) 294 } 295 checkResult(test, "GCD(A, B)", "GCD", g) 296 if g.Sign() != 0 { 297 lcm := new(big.Int).Mul(a, b) 298 lcm = lcm.Div(lcm, g) 299 checkResult(test, "LCM(A, B)", "LCM", lcm) 300 } 301 } 302 default: 303 fmt.Fprintf(os.Stderr, "Line %d: unknown test type %q.\n", test.LineNumber, test.Type) 304 } 305 } 306 if scanner.Err() != nil { 307 fmt.Fprintf(os.Stderr, "Error reading tests: %s.\n", scanner.Err()) 308 } 309} 310