1// Copyright 1995-2016 The OpenSSL Project Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include <openssl/bn.h>
16
17#include <assert.h>
18
19#include <openssl/mem.h>
20
21#include "internal.h"
22#include "../../internal.h"
23
24
25static int bn_cmp_words_consttime(const BN_ULONG *a, size_t a_len,
26                                  const BN_ULONG *b, size_t b_len) {
27  static_assert(sizeof(BN_ULONG) <= sizeof(crypto_word_t),
28                "crypto_word_t is too small");
29  int ret = 0;
30  // Process the common words in little-endian order.
31  size_t min = a_len < b_len ? a_len : b_len;
32  for (size_t i = 0; i < min; i++) {
33    crypto_word_t eq = constant_time_eq_w(a[i], b[i]);
34    crypto_word_t lt = constant_time_lt_w(a[i], b[i]);
35    ret =
36        constant_time_select_int(eq, ret, constant_time_select_int(lt, -1, 1));
37  }
38
39  // If |a| or |b| has non-zero words beyond |min|, they take precedence.
40  if (a_len < b_len) {
41    crypto_word_t mask = 0;
42    for (size_t i = a_len; i < b_len; i++) {
43      mask |= b[i];
44    }
45    ret = constant_time_select_int(constant_time_is_zero_w(mask), ret, -1);
46  } else if (b_len < a_len) {
47    crypto_word_t mask = 0;
48    for (size_t i = b_len; i < a_len; i++) {
49      mask |= a[i];
50    }
51    ret = constant_time_select_int(constant_time_is_zero_w(mask), ret, 1);
52  }
53
54  return ret;
55}
56
57int BN_ucmp(const BIGNUM *a, const BIGNUM *b) {
58  return bn_cmp_words_consttime(a->d, a->width, b->d, b->width);
59}
60
61int BN_cmp(const BIGNUM *a, const BIGNUM *b) {
62  if ((a == NULL) || (b == NULL)) {
63    if (a != NULL) {
64      return -1;
65    } else if (b != NULL) {
66      return 1;
67    } else {
68      return 0;
69    }
70  }
71
72  // We do not attempt to process the sign bit in constant time. Negative
73  // |BIGNUM|s should never occur in crypto, only calculators.
74  if (a->neg != b->neg) {
75    if (a->neg) {
76      return -1;
77    }
78    return 1;
79  }
80
81  int ret = BN_ucmp(a, b);
82  return a->neg ? -ret : ret;
83}
84
85int bn_less_than_words(const BN_ULONG *a, const BN_ULONG *b, size_t len) {
86  return bn_cmp_words_consttime(a, len, b, len) < 0;
87}
88
89int BN_abs_is_word(const BIGNUM *bn, BN_ULONG w) {
90  if (bn->width == 0) {
91    return w == 0;
92  }
93  BN_ULONG mask = bn->d[0] ^ w;
94  for (int i = 1; i < bn->width; i++) {
95    mask |= bn->d[i];
96  }
97  return mask == 0;
98}
99
100int BN_cmp_word(const BIGNUM *a, BN_ULONG b) {
101  BIGNUM b_bn;
102  BN_init(&b_bn);
103
104  b_bn.d = &b;
105  b_bn.width = b > 0;
106  b_bn.dmax = 1;
107  b_bn.flags = BN_FLG_STATIC_DATA;
108  return BN_cmp(a, &b_bn);
109}
110
111int BN_is_zero(const BIGNUM *bn) {
112  return bn_fits_in_words(bn, 0);
113}
114
115int BN_is_one(const BIGNUM *bn) {
116  return bn->neg == 0 && BN_abs_is_word(bn, 1);
117}
118
119int BN_is_word(const BIGNUM *bn, BN_ULONG w) {
120  return BN_abs_is_word(bn, w) && (w == 0 || bn->neg == 0);
121}
122
123int BN_is_odd(const BIGNUM *bn) {
124  return bn->width > 0 && (bn->d[0] & 1) == 1;
125}
126
127int BN_is_pow2(const BIGNUM *bn) {
128  int width = bn_minimal_width(bn);
129  if (width == 0 || bn->neg) {
130    return 0;
131  }
132
133  for (int i = 0; i < width - 1; i++) {
134    if (bn->d[i] != 0) {
135      return 0;
136    }
137  }
138
139  return 0 == (bn->d[width-1] & (bn->d[width-1] - 1));
140}
141
142int BN_equal_consttime(const BIGNUM *a, const BIGNUM *b) {
143  BN_ULONG mask = 0;
144  // If |a| or |b| has more words than the other, all those words must be zero.
145  for (int i = a->width; i < b->width; i++) {
146    mask |= b->d[i];
147  }
148  for (int i = b->width; i < a->width; i++) {
149    mask |= a->d[i];
150  }
151  // Common words must match.
152  int min = a->width < b->width ? a->width : b->width;
153  for (int i = 0; i < min; i++) {
154    mask |= (a->d[i] ^ b->d[i]);
155  }
156  // The sign bit must match.
157  mask |= (a->neg ^ b->neg);
158  return mask == 0;
159}
160