1
2
3
4
5 package elliptic
6
7 import (
8 "bytes"
9 "crypto/rand"
10 "encoding/hex"
11 "math/big"
12 "testing"
13 )
14
15
16
17
18
19
20 func genericParamsForCurve(c Curve) *CurveParams {
21 d := *(c.Params())
22 return &d
23 }
24
25 func testAllCurves(t *testing.T, f func(*testing.T, Curve)) {
26 tests := []struct {
27 name string
28 curve Curve
29 }{
30 {"P256", P256()},
31 {"P256/Params", genericParamsForCurve(P256())},
32 {"P224", P224()},
33 {"P224/Params", genericParamsForCurve(P224())},
34 {"P384", P384()},
35 {"P384/Params", genericParamsForCurve(P384())},
36 {"P521", P521()},
37 {"P521/Params", genericParamsForCurve(P521())},
38 }
39 if testing.Short() {
40 tests = tests[:1]
41 }
42 for _, test := range tests {
43 curve := test.curve
44 t.Run(test.name, func(t *testing.T) {
45 t.Parallel()
46 f(t, curve)
47 })
48 }
49 }
50
51 func TestOnCurve(t *testing.T) {
52 testAllCurves(t, func(t *testing.T, curve Curve) {
53 if !curve.IsOnCurve(curve.Params().Gx, curve.Params().Gy) {
54 t.Error("basepoint is not on the curve")
55 }
56 })
57 }
58
59 func TestOffCurve(t *testing.T) {
60 testAllCurves(t, func(t *testing.T, curve Curve) {
61 x, y := new(big.Int).SetInt64(1), new(big.Int).SetInt64(1)
62 if curve.IsOnCurve(x, y) {
63 t.Errorf("point off curve is claimed to be on the curve")
64 }
65 b := Marshal(curve, x, y)
66 x1, y1 := Unmarshal(curve, b)
67 if x1 != nil || y1 != nil {
68 t.Errorf("unmarshaling a point not on the curve succeeded")
69 }
70 })
71 }
72
73 func TestInfinity(t *testing.T) {
74 testAllCurves(t, testInfinity)
75 }
76
77 func testInfinity(t *testing.T, curve Curve) {
78 _, x, y, _ := GenerateKey(curve, rand.Reader)
79 x, y = curve.ScalarMult(x, y, curve.Params().N.Bytes())
80 if x.Sign() != 0 || y.Sign() != 0 {
81 t.Errorf("x^q != ∞")
82 }
83
84 x, y = curve.ScalarBaseMult([]byte{0})
85 if x.Sign() != 0 || y.Sign() != 0 {
86 t.Errorf("b^0 != ∞")
87 x.SetInt64(0)
88 y.SetInt64(0)
89 }
90
91 x2, y2 := curve.Double(x, y)
92 if x2.Sign() != 0 || y2.Sign() != 0 {
93 t.Errorf("2∞ != ∞")
94 }
95
96 baseX := curve.Params().Gx
97 baseY := curve.Params().Gy
98
99 x3, y3 := curve.Add(baseX, baseY, x, y)
100 if x3.Cmp(baseX) != 0 || y3.Cmp(baseY) != 0 {
101 t.Errorf("x+∞ != x")
102 }
103
104 x4, y4 := curve.Add(x, y, baseX, baseY)
105 if x4.Cmp(baseX) != 0 || y4.Cmp(baseY) != 0 {
106 t.Errorf("∞+x != x")
107 }
108
109 if curve.IsOnCurve(x, y) {
110 t.Errorf("IsOnCurve(∞) == true")
111 }
112 }
113
114 func TestMarshal(t *testing.T) {
115 testAllCurves(t, func(t *testing.T, curve Curve) {
116 _, x, y, err := GenerateKey(curve, rand.Reader)
117 if err != nil {
118 t.Fatal(err)
119 }
120 serialized := Marshal(curve, x, y)
121 xx, yy := Unmarshal(curve, serialized)
122 if xx == nil {
123 t.Fatal("failed to unmarshal")
124 }
125 if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
126 t.Fatal("unmarshal returned different values")
127 }
128 })
129 }
130
131 func TestUnmarshalToLargeCoordinates(t *testing.T) {
132
133 testAllCurves(t, testUnmarshalToLargeCoordinates)
134 }
135
136 func testUnmarshalToLargeCoordinates(t *testing.T, curve Curve) {
137 p := curve.Params().P
138 byteLen := (p.BitLen() + 7) / 8
139
140
141
142
143 x := new(big.Int).Add(p, big.NewInt(5))
144 y := curve.Params().polynomial(x)
145 y.ModSqrt(y, p)
146
147 invalid := make([]byte, byteLen*2+1)
148 invalid[0] = 4
149 x.FillBytes(invalid[1 : 1+byteLen])
150 y.FillBytes(invalid[1+byteLen:])
151
152 if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil {
153 t.Errorf("Unmarshal accepts invalid X coordinate")
154 }
155
156 if curve == p256 {
157
158
159 x, _ = new(big.Int).SetString("31931927535157963707678568152204072984517581467226068221761862915403492091210", 10)
160 y, _ = new(big.Int).SetString("5208467867388784005506817585327037698770365050895731383201516607147", 10)
161 y.Add(y, p)
162
163 if p.Cmp(y) > 0 || y.BitLen() != 256 {
164 t.Fatal("y not within expected range")
165 }
166
167
168 x.FillBytes(invalid[1 : 1+byteLen])
169 y.FillBytes(invalid[1+byteLen:])
170
171 if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil {
172 t.Errorf("Unmarshal accepts invalid Y coordinate")
173 }
174 }
175 }
176
177
178
179
180 func TestInvalidCoordinates(t *testing.T) {
181 testAllCurves(t, testInvalidCoordinates)
182 }
183
184 func testInvalidCoordinates(t *testing.T, curve Curve) {
185 checkIsOnCurveFalse := func(name string, x, y *big.Int) {
186 if curve.IsOnCurve(x, y) {
187 t.Errorf("IsOnCurve(%s) unexpectedly returned true", name)
188 }
189 }
190
191 p := curve.Params().P
192 _, x, y, _ := GenerateKey(curve, rand.Reader)
193 xx, yy := new(big.Int), new(big.Int)
194
195
196 xx.Neg(x)
197 checkIsOnCurveFalse("-x, y", xx, y)
198 yy.Neg(y)
199 checkIsOnCurveFalse("x, -y", x, yy)
200
201
202 xx.Sub(x, p)
203 checkIsOnCurveFalse("x-P, y", xx, y)
204 yy.Sub(y, p)
205 checkIsOnCurveFalse("x, y-P", x, yy)
206
207
208 xx.Add(x, p)
209 checkIsOnCurveFalse("x+P, y", xx, y)
210 yy.Add(y, p)
211 checkIsOnCurveFalse("x, y+P", x, yy)
212
213
214 xx.Add(x, new(big.Int).Lsh(big.NewInt(1), 535))
215 checkIsOnCurveFalse("x+2⁵³⁵, y", xx, y)
216 yy.Add(y, new(big.Int).Lsh(big.NewInt(1), 535))
217 checkIsOnCurveFalse("x, y+2⁵³⁵", x, yy)
218
219
220
221
222
223
224 if yy := new(big.Int).ModSqrt(curve.Params().B, p); yy != nil {
225 if !curve.IsOnCurve(big.NewInt(0), yy) {
226 t.Fatal("(0, mod_sqrt(B)) is not on the curve?")
227 }
228 checkIsOnCurveFalse("P, y", p, yy)
229 }
230 }
231
232 func TestMarshalCompressed(t *testing.T) {
233 t.Run("P-256/03", func(t *testing.T) {
234 data, _ := hex.DecodeString("031e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79")
235 x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10)
236 y, _ := new(big.Int).SetString("66200849279091436748794323380043701364391950689352563629885086590854940586447", 10)
237 testMarshalCompressed(t, P256(), x, y, data)
238 })
239 t.Run("P-256/02", func(t *testing.T) {
240 data, _ := hex.DecodeString("021e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79")
241 x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10)
242 y, _ := new(big.Int).SetString("49591239931264812013903123569363872165694192725937750565648544718012157267504", 10)
243 testMarshalCompressed(t, P256(), x, y, data)
244 })
245
246 t.Run("Invalid", func(t *testing.T) {
247 data, _ := hex.DecodeString("02fd4bf61763b46581fd9174d623516cf3c81edd40e29ffa2777fb6cb0ae3ce535")
248 X, Y := UnmarshalCompressed(P256(), data)
249 if X != nil || Y != nil {
250 t.Error("expected an error for invalid encoding")
251 }
252 })
253
254 if testing.Short() {
255 t.Skip("skipping other curves on short test")
256 }
257
258 testAllCurves(t, func(t *testing.T, curve Curve) {
259 _, x, y, err := GenerateKey(curve, rand.Reader)
260 if err != nil {
261 t.Fatal(err)
262 }
263 testMarshalCompressed(t, curve, x, y, nil)
264 })
265
266 }
267
268 func testMarshalCompressed(t *testing.T, curve Curve, x, y *big.Int, want []byte) {
269 if !curve.IsOnCurve(x, y) {
270 t.Fatal("invalid test point")
271 }
272 got := MarshalCompressed(curve, x, y)
273 if want != nil && !bytes.Equal(got, want) {
274 t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want)
275 }
276
277 X, Y := UnmarshalCompressed(curve, got)
278 if X == nil || Y == nil {
279 t.Fatalf("UnmarshalCompressed failed unexpectedly")
280 }
281
282 if !curve.IsOnCurve(X, Y) {
283 t.Error("UnmarshalCompressed returned a point not on the curve")
284 }
285 if X.Cmp(x) != 0 || Y.Cmp(y) != 0 {
286 t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y)
287 }
288 }
289
290 func benchmarkAllCurves(t *testing.B, f func(*testing.B, Curve)) {
291 tests := []struct {
292 name string
293 curve Curve
294 }{
295 {"P256", P256()},
296 {"P224", P224()},
297 {"P384", P384()},
298 {"P521", P521()},
299 }
300 for _, test := range tests {
301 curve := test.curve
302 t.Run(test.name, func(t *testing.B) {
303 f(t, curve)
304 })
305 }
306 }
307
308 func BenchmarkScalarBaseMult(b *testing.B) {
309 benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
310 priv, _, _, _ := GenerateKey(curve, rand.Reader)
311 b.ReportAllocs()
312 b.ResetTimer()
313 for i := 0; i < b.N; i++ {
314 x, _ := curve.ScalarBaseMult(priv)
315
316 priv[0] ^= byte(x.Bits()[0])
317 }
318 })
319 }
320
321 func BenchmarkScalarMult(b *testing.B) {
322 benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
323 _, x, y, _ := GenerateKey(curve, rand.Reader)
324 priv, _, _, _ := GenerateKey(curve, rand.Reader)
325 b.ReportAllocs()
326 b.ResetTimer()
327 for i := 0; i < b.N; i++ {
328 x, y = curve.ScalarMult(x, y, priv)
329 }
330 })
331 }
332
View as plain text