...

Source file src/crypto/elliptic/elliptic_test.go

Documentation: crypto/elliptic

		 1  // Copyright 2010 The Go Authors. All rights reserved.
		 2  // Use of this source code is governed by a BSD-style
		 3  // license that can be found in the LICENSE file.
		 4  
		 5  package elliptic
		 6  
		 7  import (
		 8  	"bytes"
		 9  	"crypto/rand"
		10  	"encoding/hex"
		11  	"math/big"
		12  	"testing"
		13  )
		14  
		15  // genericParamsForCurve returns the dereferenced CurveParams for
		16  // the specified curve. This is used to avoid the logic for
		17  // upgrading a curve to it's specific implementation, forcing
		18  // usage of the generic implementation. This is only relevant
		19  // for the P224, P256, and P521 curves.
		20  func genericParamsForCurve(c Curve) *CurveParams {
		21  	d := *(c.Params())
		22  	return &d
		23  }
		24  
		25  func testAllCurves(t *testing.T, f func(*testing.T, Curve)) {
		26  	tests := []struct {
		27  		name	string
		28  		curve Curve
		29  	}{
		30  		{"P256", P256()},
		31  		{"P256/Params", genericParamsForCurve(P256())},
		32  		{"P224", P224()},
		33  		{"P224/Params", genericParamsForCurve(P224())},
		34  		{"P384", P384()},
		35  		{"P384/Params", genericParamsForCurve(P384())},
		36  		{"P521", P521()},
		37  		{"P521/Params", genericParamsForCurve(P521())},
		38  	}
		39  	if testing.Short() {
		40  		tests = tests[:1]
		41  	}
		42  	for _, test := range tests {
		43  		curve := test.curve
		44  		t.Run(test.name, func(t *testing.T) {
		45  			t.Parallel()
		46  			f(t, curve)
		47  		})
		48  	}
		49  }
		50  
		51  func TestOnCurve(t *testing.T) {
		52  	testAllCurves(t, func(t *testing.T, curve Curve) {
		53  		if !curve.IsOnCurve(curve.Params().Gx, curve.Params().Gy) {
		54  			t.Error("basepoint is not on the curve")
		55  		}
		56  	})
		57  }
		58  
		59  func TestOffCurve(t *testing.T) {
		60  	testAllCurves(t, func(t *testing.T, curve Curve) {
		61  		x, y := new(big.Int).SetInt64(1), new(big.Int).SetInt64(1)
		62  		if curve.IsOnCurve(x, y) {
		63  			t.Errorf("point off curve is claimed to be on the curve")
		64  		}
		65  		b := Marshal(curve, x, y)
		66  		x1, y1 := Unmarshal(curve, b)
		67  		if x1 != nil || y1 != nil {
		68  			t.Errorf("unmarshaling a point not on the curve succeeded")
		69  		}
		70  	})
		71  }
		72  
		73  func TestInfinity(t *testing.T) {
		74  	testAllCurves(t, testInfinity)
		75  }
		76  
		77  func testInfinity(t *testing.T, curve Curve) {
		78  	_, x, y, _ := GenerateKey(curve, rand.Reader)
		79  	x, y = curve.ScalarMult(x, y, curve.Params().N.Bytes())
		80  	if x.Sign() != 0 || y.Sign() != 0 {
		81  		t.Errorf("x^q != ∞")
		82  	}
		83  
		84  	x, y = curve.ScalarBaseMult([]byte{0})
		85  	if x.Sign() != 0 || y.Sign() != 0 {
		86  		t.Errorf("b^0 != ∞")
		87  		x.SetInt64(0)
		88  		y.SetInt64(0)
		89  	}
		90  
		91  	x2, y2 := curve.Double(x, y)
		92  	if x2.Sign() != 0 || y2.Sign() != 0 {
		93  		t.Errorf("2∞ != ∞")
		94  	}
		95  
		96  	baseX := curve.Params().Gx
		97  	baseY := curve.Params().Gy
		98  
		99  	x3, y3 := curve.Add(baseX, baseY, x, y)
	 100  	if x3.Cmp(baseX) != 0 || y3.Cmp(baseY) != 0 {
	 101  		t.Errorf("x+∞ != x")
	 102  	}
	 103  
	 104  	x4, y4 := curve.Add(x, y, baseX, baseY)
	 105  	if x4.Cmp(baseX) != 0 || y4.Cmp(baseY) != 0 {
	 106  		t.Errorf("∞+x != x")
	 107  	}
	 108  
	 109  	if curve.IsOnCurve(x, y) {
	 110  		t.Errorf("IsOnCurve(∞) == true")
	 111  	}
	 112  }
	 113  
	 114  func TestMarshal(t *testing.T) {
	 115  	testAllCurves(t, func(t *testing.T, curve Curve) {
	 116  		_, x, y, err := GenerateKey(curve, rand.Reader)
	 117  		if err != nil {
	 118  			t.Fatal(err)
	 119  		}
	 120  		serialized := Marshal(curve, x, y)
	 121  		xx, yy := Unmarshal(curve, serialized)
	 122  		if xx == nil {
	 123  			t.Fatal("failed to unmarshal")
	 124  		}
	 125  		if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
	 126  			t.Fatal("unmarshal returned different values")
	 127  		}
	 128  	})
	 129  }
	 130  
	 131  func TestUnmarshalToLargeCoordinates(t *testing.T) {
	 132  	// See https://golang.org/issues/20482.
	 133  	testAllCurves(t, testUnmarshalToLargeCoordinates)
	 134  }
	 135  
	 136  func testUnmarshalToLargeCoordinates(t *testing.T, curve Curve) {
	 137  	p := curve.Params().P
	 138  	byteLen := (p.BitLen() + 7) / 8
	 139  
	 140  	// Set x to be greater than curve's parameter P – specifically, to P+5.
	 141  	// Set y to mod_sqrt(x^3 - 3x + B)) so that (x mod P = 5 , y) is on the
	 142  	// curve.
	 143  	x := new(big.Int).Add(p, big.NewInt(5))
	 144  	y := curve.Params().polynomial(x)
	 145  	y.ModSqrt(y, p)
	 146  
	 147  	invalid := make([]byte, byteLen*2+1)
	 148  	invalid[0] = 4 // uncompressed encoding
	 149  	x.FillBytes(invalid[1 : 1+byteLen])
	 150  	y.FillBytes(invalid[1+byteLen:])
	 151  
	 152  	if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil {
	 153  		t.Errorf("Unmarshal accepts invalid X coordinate")
	 154  	}
	 155  
	 156  	if curve == p256 {
	 157  		// This is a point on the curve with a small y value, small enough that
	 158  		// we can add p and still be within 32 bytes.
	 159  		x, _ = new(big.Int).SetString("31931927535157963707678568152204072984517581467226068221761862915403492091210", 10)
	 160  		y, _ = new(big.Int).SetString("5208467867388784005506817585327037698770365050895731383201516607147", 10)
	 161  		y.Add(y, p)
	 162  
	 163  		if p.Cmp(y) > 0 || y.BitLen() != 256 {
	 164  			t.Fatal("y not within expected range")
	 165  		}
	 166  
	 167  		// marshal
	 168  		x.FillBytes(invalid[1 : 1+byteLen])
	 169  		y.FillBytes(invalid[1+byteLen:])
	 170  
	 171  		if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil {
	 172  			t.Errorf("Unmarshal accepts invalid Y coordinate")
	 173  		}
	 174  	}
	 175  }
	 176  
	 177  // TestInvalidCoordinates tests big.Int values that are not valid field elements
	 178  // (negative or bigger than P). They are expected to return false from
	 179  // IsOnCurve, all other behavior is undefined.
	 180  func TestInvalidCoordinates(t *testing.T) {
	 181  	testAllCurves(t, testInvalidCoordinates)
	 182  }
	 183  
	 184  func testInvalidCoordinates(t *testing.T, curve Curve) {
	 185  	checkIsOnCurveFalse := func(name string, x, y *big.Int) {
	 186  		if curve.IsOnCurve(x, y) {
	 187  			t.Errorf("IsOnCurve(%s) unexpectedly returned true", name)
	 188  		}
	 189  	}
	 190  
	 191  	p := curve.Params().P
	 192  	_, x, y, _ := GenerateKey(curve, rand.Reader)
	 193  	xx, yy := new(big.Int), new(big.Int)
	 194  
	 195  	// Check if the sign is getting dropped.
	 196  	xx.Neg(x)
	 197  	checkIsOnCurveFalse("-x, y", xx, y)
	 198  	yy.Neg(y)
	 199  	checkIsOnCurveFalse("x, -y", x, yy)
	 200  
	 201  	// Check if negative values are reduced modulo P.
	 202  	xx.Sub(x, p)
	 203  	checkIsOnCurveFalse("x-P, y", xx, y)
	 204  	yy.Sub(y, p)
	 205  	checkIsOnCurveFalse("x, y-P", x, yy)
	 206  
	 207  	// Check if positive values are reduced modulo P.
	 208  	xx.Add(x, p)
	 209  	checkIsOnCurveFalse("x+P, y", xx, y)
	 210  	yy.Add(y, p)
	 211  	checkIsOnCurveFalse("x, y+P", x, yy)
	 212  
	 213  	// Check if the overflow is dropped.
	 214  	xx.Add(x, new(big.Int).Lsh(big.NewInt(1), 535))
	 215  	checkIsOnCurveFalse("x+2⁵³⁵, y", xx, y)
	 216  	yy.Add(y, new(big.Int).Lsh(big.NewInt(1), 535))
	 217  	checkIsOnCurveFalse("x, y+2⁵³⁵", x, yy)
	 218  
	 219  	// Check if P is treated like zero (if possible).
	 220  	// y^2 = x^3 - 3x + B
	 221  	// y = mod_sqrt(x^3 - 3x + B)
	 222  	// y = mod_sqrt(B) if x = 0
	 223  	// If there is no modsqrt, there is no point with x = 0, can't test x = P.
	 224  	if yy := new(big.Int).ModSqrt(curve.Params().B, p); yy != nil {
	 225  		if !curve.IsOnCurve(big.NewInt(0), yy) {
	 226  			t.Fatal("(0, mod_sqrt(B)) is not on the curve?")
	 227  		}
	 228  		checkIsOnCurveFalse("P, y", p, yy)
	 229  	}
	 230  }
	 231  
	 232  func TestMarshalCompressed(t *testing.T) {
	 233  	t.Run("P-256/03", func(t *testing.T) {
	 234  		data, _ := hex.DecodeString("031e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79")
	 235  		x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10)
	 236  		y, _ := new(big.Int).SetString("66200849279091436748794323380043701364391950689352563629885086590854940586447", 10)
	 237  		testMarshalCompressed(t, P256(), x, y, data)
	 238  	})
	 239  	t.Run("P-256/02", func(t *testing.T) {
	 240  		data, _ := hex.DecodeString("021e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79")
	 241  		x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10)
	 242  		y, _ := new(big.Int).SetString("49591239931264812013903123569363872165694192725937750565648544718012157267504", 10)
	 243  		testMarshalCompressed(t, P256(), x, y, data)
	 244  	})
	 245  
	 246  	t.Run("Invalid", func(t *testing.T) {
	 247  		data, _ := hex.DecodeString("02fd4bf61763b46581fd9174d623516cf3c81edd40e29ffa2777fb6cb0ae3ce535")
	 248  		X, Y := UnmarshalCompressed(P256(), data)
	 249  		if X != nil || Y != nil {
	 250  			t.Error("expected an error for invalid encoding")
	 251  		}
	 252  	})
	 253  
	 254  	if testing.Short() {
	 255  		t.Skip("skipping other curves on short test")
	 256  	}
	 257  
	 258  	testAllCurves(t, func(t *testing.T, curve Curve) {
	 259  		_, x, y, err := GenerateKey(curve, rand.Reader)
	 260  		if err != nil {
	 261  			t.Fatal(err)
	 262  		}
	 263  		testMarshalCompressed(t, curve, x, y, nil)
	 264  	})
	 265  
	 266  }
	 267  
	 268  func testMarshalCompressed(t *testing.T, curve Curve, x, y *big.Int, want []byte) {
	 269  	if !curve.IsOnCurve(x, y) {
	 270  		t.Fatal("invalid test point")
	 271  	}
	 272  	got := MarshalCompressed(curve, x, y)
	 273  	if want != nil && !bytes.Equal(got, want) {
	 274  		t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want)
	 275  	}
	 276  
	 277  	X, Y := UnmarshalCompressed(curve, got)
	 278  	if X == nil || Y == nil {
	 279  		t.Fatalf("UnmarshalCompressed failed unexpectedly")
	 280  	}
	 281  
	 282  	if !curve.IsOnCurve(X, Y) {
	 283  		t.Error("UnmarshalCompressed returned a point not on the curve")
	 284  	}
	 285  	if X.Cmp(x) != 0 || Y.Cmp(y) != 0 {
	 286  		t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y)
	 287  	}
	 288  }
	 289  
	 290  func benchmarkAllCurves(t *testing.B, f func(*testing.B, Curve)) {
	 291  	tests := []struct {
	 292  		name	string
	 293  		curve Curve
	 294  	}{
	 295  		{"P256", P256()},
	 296  		{"P224", P224()},
	 297  		{"P384", P384()},
	 298  		{"P521", P521()},
	 299  	}
	 300  	for _, test := range tests {
	 301  		curve := test.curve
	 302  		t.Run(test.name, func(t *testing.B) {
	 303  			f(t, curve)
	 304  		})
	 305  	}
	 306  }
	 307  
	 308  func BenchmarkScalarBaseMult(b *testing.B) {
	 309  	benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
	 310  		priv, _, _, _ := GenerateKey(curve, rand.Reader)
	 311  		b.ReportAllocs()
	 312  		b.ResetTimer()
	 313  		for i := 0; i < b.N; i++ {
	 314  			x, _ := curve.ScalarBaseMult(priv)
	 315  			// Prevent the compiler from optimizing out the operation.
	 316  			priv[0] ^= byte(x.Bits()[0])
	 317  		}
	 318  	})
	 319  }
	 320  
	 321  func BenchmarkScalarMult(b *testing.B) {
	 322  	benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
	 323  		_, x, y, _ := GenerateKey(curve, rand.Reader)
	 324  		priv, _, _, _ := GenerateKey(curve, rand.Reader)
	 325  		b.ReportAllocs()
	 326  		b.ResetTimer()
	 327  		for i := 0; i < b.N; i++ {
	 328  			x, y = curve.ScalarMult(x, y, priv)
	 329  		}
	 330  	})
	 331  }
	 332  

View as plain text