...

Source file src/math/big/nat.go

Documentation: math/big

		 1  // Copyright 2009 The Go Authors. All rights reserved.
		 2  // Use of this source code is governed by a BSD-style
		 3  // license that can be found in the LICENSE file.
		 4  
		 5  // This file implements unsigned multi-precision integers (natural
		 6  // numbers). They are the building blocks for the implementation
		 7  // of signed integers, rationals, and floating-point numbers.
		 8  //
		 9  // Caution: This implementation relies on the function "alias"
		10  //					which assumes that (nat) slice capacities are never
		11  //					changed (no 3-operand slice expressions). If that
		12  //					changes, alias needs to be updated for correctness.
		13  
		14  package big
		15  
		16  import (
		17  	"encoding/binary"
		18  	"math/bits"
		19  	"math/rand"
		20  	"sync"
		21  )
		22  
// An unsigned integer x of the form
//
//	x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0]
//
// with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n,
// with the digits x[i] as the slice elements. Digits are therefore stored
// least significant first: x[0] is the coefficient of _B^0.
//
// A number is normalized if the slice contains no leading 0 digits.
// During arithmetic operations, denormalized values may occur but are
// always normalized before returning the final result. The normalized
// representation of 0 is the empty or nil slice (length = 0).
type nat []Word
		36  
		37  var (
		38  	natOne	= nat{1}
		39  	natTwo	= nat{2}
		40  	natFive = nat{5}
		41  	natTen	= nat{10}
		42  )
		43  
		44  func (z nat) clear() {
		45  	for i := range z {
		46  		z[i] = 0
		47  	}
		48  }
		49  
		50  func (z nat) norm() nat {
		51  	i := len(z)
		52  	for i > 0 && z[i-1] == 0 {
		53  		i--
		54  	}
		55  	return z[0:i]
		56  }
		57  
		58  func (z nat) make(n int) nat {
		59  	if n <= cap(z) {
		60  		return z[:n] // reuse z
		61  	}
		62  	if n == 1 {
		63  		// Most nats start small and stay that way; don't over-allocate.
		64  		return make(nat, 1)
		65  	}
		66  	// Choosing a good value for e has significant performance impact
		67  	// because it increases the chance that a value can be reused.
		68  	const e = 4 // extra capacity
		69  	return make(nat, n, n+e)
		70  }
		71  
		72  func (z nat) setWord(x Word) nat {
		73  	if x == 0 {
		74  		return z[:0]
		75  	}
		76  	z = z.make(1)
		77  	z[0] = x
		78  	return z
		79  }
		80  
		81  func (z nat) setUint64(x uint64) nat {
		82  	// single-word value
		83  	if w := Word(x); uint64(w) == x {
		84  		return z.setWord(w)
		85  	}
		86  	// 2-word value
		87  	z = z.make(2)
		88  	z[1] = Word(x >> 32)
		89  	z[0] = Word(x)
		90  	return z
		91  }
		92  
		93  func (z nat) set(x nat) nat {
		94  	z = z.make(len(x))
		95  	copy(z, x)
		96  	return z
		97  }
		98  
		99  func (z nat) add(x, y nat) nat {
	 100  	m := len(x)
	 101  	n := len(y)
	 102  
	 103  	switch {
	 104  	case m < n:
	 105  		return z.add(y, x)
	 106  	case m == 0:
	 107  		// n == 0 because m >= n; result is 0
	 108  		return z[:0]
	 109  	case n == 0:
	 110  		// result is x
	 111  		return z.set(x)
	 112  	}
	 113  	// m > 0
	 114  
	 115  	z = z.make(m + 1)
	 116  	c := addVV(z[0:n], x, y)
	 117  	if m > n {
	 118  		c = addVW(z[n:m], x[n:], c)
	 119  	}
	 120  	z[m] = c
	 121  
	 122  	return z.norm()
	 123  }
	 124  
	 125  func (z nat) sub(x, y nat) nat {
	 126  	m := len(x)
	 127  	n := len(y)
	 128  
	 129  	switch {
	 130  	case m < n:
	 131  		panic("underflow")
	 132  	case m == 0:
	 133  		// n == 0 because m >= n; result is 0
	 134  		return z[:0]
	 135  	case n == 0:
	 136  		// result is x
	 137  		return z.set(x)
	 138  	}
	 139  	// m > 0
	 140  
	 141  	z = z.make(m)
	 142  	c := subVV(z[0:n], x, y)
	 143  	if m > n {
	 144  		c = subVW(z[n:], x[n:], c)
	 145  	}
	 146  	if c != 0 {
	 147  		panic("underflow")
	 148  	}
	 149  
	 150  	return z.norm()
	 151  }
	 152  
	 153  func (x nat) cmp(y nat) (r int) {
	 154  	m := len(x)
	 155  	n := len(y)
	 156  	if m != n || m == 0 {
	 157  		switch {
	 158  		case m < n:
	 159  			r = -1
	 160  		case m > n:
	 161  			r = 1
	 162  		}
	 163  		return
	 164  	}
	 165  
	 166  	i := m - 1
	 167  	for i > 0 && x[i] == y[i] {
	 168  		i--
	 169  	}
	 170  
	 171  	switch {
	 172  	case x[i] < y[i]:
	 173  		r = -1
	 174  	case x[i] > y[i]:
	 175  		r = 1
	 176  	}
	 177  	return
	 178  }
	 179  
	 180  func (z nat) mulAddWW(x nat, y, r Word) nat {
	 181  	m := len(x)
	 182  	if m == 0 || y == 0 {
	 183  		return z.setWord(r) // result is r
	 184  	}
	 185  	// m > 0
	 186  
	 187  	z = z.make(m + 1)
	 188  	z[m] = mulAddVWW(z[0:m], x, y, r)
	 189  
	 190  	return z.norm()
	 191  }
	 192  
	 193  // basicMul multiplies x and y and leaves the result in z.
	 194  // The (non-normalized) result is placed in z[0 : len(x) + len(y)].
	 195  func basicMul(z, x, y nat) {
	 196  	z[0 : len(x)+len(y)].clear() // initialize z
	 197  	for i, d := range y {
	 198  		if d != 0 {
	 199  			z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
	 200  		}
	 201  	}
	 202  }
	 203  
	 204  // montgomery computes z mod m = x*y*2**(-n*_W) mod m,
	 205  // assuming k = -1/m mod 2**_W.
	 206  // z is used for storing the result which is returned;
	 207  // z must not alias x, y or m.
	 208  // See Gueron, "Efficient Software Implementations of Modular Exponentiation".
	 209  // https://eprint.iacr.org/2011/239.pdf
	 210  // In the terminology of that paper, this is an "Almost Montgomery Multiplication":
	 211  // x and y are required to satisfy 0 <= z < 2**(n*_W) and then the result
	 212  // z is guaranteed to satisfy 0 <= z < 2**(n*_W), but it may not be < m.
	 213  func (z nat) montgomery(x, y, m nat, k Word, n int) nat {
	 214  	// This code assumes x, y, m are all the same length, n.
	 215  	// (required by addMulVVW and the for loop).
	 216  	// It also assumes that x, y are already reduced mod m,
	 217  	// or else the result will not be properly reduced.
	 218  	if len(x) != n || len(y) != n || len(m) != n {
	 219  		panic("math/big: mismatched montgomery number lengths")
	 220  	}
	 221  	z = z.make(n * 2)
	 222  	z.clear()
	 223  	var c Word
	 224  	for i := 0; i < n; i++ {
	 225  		d := y[i]
	 226  		c2 := addMulVVW(z[i:n+i], x, d)
	 227  		t := z[i] * k
	 228  		c3 := addMulVVW(z[i:n+i], m, t)
	 229  		cx := c + c2
	 230  		cy := cx + c3
	 231  		z[n+i] = cy
	 232  		if cx < c2 || cy < c3 {
	 233  			c = 1
	 234  		} else {
	 235  			c = 0
	 236  		}
	 237  	}
	 238  	if c != 0 {
	 239  		subVV(z[:n], z[n:], m)
	 240  	} else {
	 241  		copy(z[:n], z[n:])
	 242  	}
	 243  	return z[:n]
	 244  }
	 245  
	 246  // Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks.
	 247  // Factored out for readability - do not use outside karatsuba.
	 248  func karatsubaAdd(z, x nat, n int) {
	 249  	if c := addVV(z[0:n], z, x); c != 0 {
	 250  		addVW(z[n:n+n>>1], z[n:], c)
	 251  	}
	 252  }
	 253  
	 254  // Like karatsubaAdd, but does subtract.
	 255  func karatsubaSub(z, x nat, n int) {
	 256  	if c := subVV(z[0:n], z, x); c != 0 {
	 257  		subVW(z[n:n+n>>1], z[n:], c)
	 258  	}
	 259  }
	 260  
	 261  // Operands that are shorter than karatsubaThreshold are multiplied using
	 262  // "grade school" multiplication; for longer operands the Karatsuba algorithm
	 263  // is used.
	 264  var karatsubaThreshold = 40 // computed by calibrate_test.go
	 265  
	 266  // karatsuba multiplies x and y and leaves the result in z.
	 267  // Both x and y must have the same length n and n must be a
	 268  // power of 2. The result vector z must have len(z) >= 6*n.
	 269  // The (non-normalized) result is placed in z[0 : 2*n].
	 270  func karatsuba(z, x, y nat) {
	 271  	n := len(y)
	 272  
	 273  	// Switch to basic multiplication if numbers are odd or small.
	 274  	// (n is always even if karatsubaThreshold is even, but be
	 275  	// conservative)
	 276  	if n&1 != 0 || n < karatsubaThreshold || n < 2 {
	 277  		basicMul(z, x, y)
	 278  		return
	 279  	}
	 280  	// n&1 == 0 && n >= karatsubaThreshold && n >= 2
	 281  
	 282  	// Karatsuba multiplication is based on the observation that
	 283  	// for two numbers x and y with:
	 284  	//
	 285  	//	 x = x1*b + x0
	 286  	//	 y = y1*b + y0
	 287  	//
	 288  	// the product x*y can be obtained with 3 products z2, z1, z0
	 289  	// instead of 4:
	 290  	//
	 291  	//	 x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0
	 292  	//			 =		z2*b*b +							z1*b +		z0
	 293  	//
	 294  	// with:
	 295  	//
	 296  	//	 xd = x1 - x0
	 297  	//	 yd = y0 - y1
	 298  	//
	 299  	//	 z1 =			xd*yd										+ z2 + z0
	 300  	//			= (x1-x0)*(y0 - y1)						 + z2 + z0
	 301  	//			= x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0
	 302  	//			= x1*y0 -		z2 -		z0 + x0*y1 + z2 + z0
	 303  	//			= x1*y0								 + x0*y1
	 304  
	 305  	// split x, y into "digits"
	 306  	n2 := n >> 1							// n2 >= 1
	 307  	x1, x0 := x[n2:], x[0:n2] // x = x1*b + y0
	 308  	y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0
	 309  
	 310  	// z is used for the result and temporary storage:
	 311  	//
	 312  	//	 6*n		 5*n		 4*n		 3*n		 2*n		 1*n		 0*n
	 313  	// z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
	 314  	//
	 315  	// For each recursive call of karatsuba, an unused slice of
	 316  	// z is passed in that has (at least) half the length of the
	 317  	// caller's z.
	 318  
	 319  	// compute z0 and z2 with the result "in place" in z
	 320  	karatsuba(z, x0, y0)		 // z0 = x0*y0
	 321  	karatsuba(z[n:], x1, y1) // z2 = x1*y1
	 322  
	 323  	// compute xd (or the negative value if underflow occurs)
	 324  	s := 1 // sign of product xd*yd
	 325  	xd := z[2*n : 2*n+n2]
	 326  	if subVV(xd, x1, x0) != 0 { // x1-x0
	 327  		s = -s
	 328  		subVV(xd, x0, x1) // x0-x1
	 329  	}
	 330  
	 331  	// compute yd (or the negative value if underflow occurs)
	 332  	yd := z[2*n+n2 : 3*n]
	 333  	if subVV(yd, y0, y1) != 0 { // y0-y1
	 334  		s = -s
	 335  		subVV(yd, y1, y0) // y1-y0
	 336  	}
	 337  
	 338  	// p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
	 339  	// p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
	 340  	p := z[n*3:]
	 341  	karatsuba(p, xd, yd)
	 342  
	 343  	// save original z2:z0
	 344  	// (ok to use upper half of z since we're done recursing)
	 345  	r := z[n*4:]
	 346  	copy(r, z[:n*2])
	 347  
	 348  	// add up all partial products
	 349  	//
	 350  	//	 2*n		 n		 0
	 351  	// z = [ z2	| z0	]
	 352  	//	 +		[ z0	]
	 353  	//	 +		[ z2	]
	 354  	//	 +		[	p	]
	 355  	//
	 356  	karatsubaAdd(z[n2:], r, n)
	 357  	karatsubaAdd(z[n2:], r[n:], n)
	 358  	if s > 0 {
	 359  		karatsubaAdd(z[n2:], p, n)
	 360  	} else {
	 361  		karatsubaSub(z[n2:], p, n)
	 362  	}
	 363  }
	 364  
	 365  // alias reports whether x and y share the same base array.
	 366  // Note: alias assumes that the capacity of underlying arrays
	 367  //			 is never changed for nat values; i.e. that there are
	 368  //			 no 3-operand slice expressions in this code (or worse,
	 369  //			 reflect-based operations to the same effect).
	 370  func alias(x, y nat) bool {
	 371  	return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
	 372  }
	 373  
	 374  // addAt implements z += x<<(_W*i); z must be long enough.
	 375  // (we don't use nat.add because we need z to stay the same
	 376  // slice, and we don't need to normalize z after each addition)
	 377  func addAt(z, x nat, i int) {
	 378  	if n := len(x); n > 0 {
	 379  		if c := addVV(z[i:i+n], z[i:], x); c != 0 {
	 380  			j := i + n
	 381  			if j < len(z) {
	 382  				addVW(z[j:], z[j:], c)
	 383  			}
	 384  		}
	 385  	}
	 386  }
	 387  
	 388  func max(x, y int) int {
	 389  	if x > y {
	 390  		return x
	 391  	}
	 392  	return y
	 393  }
	 394  
	 395  // karatsubaLen computes an approximation to the maximum k <= n such that
	 396  // k = p<<i for a number p <= threshold and an i >= 0. Thus, the
	 397  // result is the largest number that can be divided repeatedly by 2 before
	 398  // becoming about the value of threshold.
	 399  func karatsubaLen(n, threshold int) int {
	 400  	i := uint(0)
	 401  	for n > threshold {
	 402  		n >>= 1
	 403  		i++
	 404  	}
	 405  	return n << i
	 406  }
	 407  
// mul sets z = x*y and returns the normalized result. z's storage is
// reused unless it aliases x or y.
//
// Operands shorter than karatsubaThreshold are multiplied with basicMul.
// Otherwise the low k digits of x and y (k chosen by karatsubaLen) are
// multiplied with karatsuba, and the remaining partial products are
// computed with recursive calls to mul and added in with addAt.
func (z nat) mul(x, y nat) nat {
	m := len(x)
	n := len(y)

	switch {
	case m < n:
		return z.mul(y, x)
	case m == 0 || n == 0:
		return z[:0]
	case n == 1:
		return z.mulAddWW(x, y[0], 0)
	}
	// m >= n > 1

	// determine if z can be reused
	if alias(z, x) || alias(z, y) {
		z = nil // z is an alias for x or y - cannot reuse
	}

	// use basic multiplication if the numbers are small
	if n < karatsubaThreshold {
		z = z.make(m + n)
		basicMul(z, x, y)
		return z.norm()
	}
	// m >= n && n >= karatsubaThreshold && n >= 2

	// determine Karatsuba length k such that
	//
	//   x = xh*b + x0  (0 <= x0 < b)
	//   y = yh*b + y0  (0 <= y0 < b)
	//   b = 1<<(_W*k)  ("base" of digits xi, yi)
	//
	k := karatsubaLen(n, karatsubaThreshold)
	// k <= n

	// multiply x0 and y0 via Karatsuba
	x0 := x[0:k]              // x0 is not normalized
	y0 := y[0:k]              // y0 is not normalized
	z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
	karatsuba(z, x0, y0)
	z = z[0 : m+n]  // z has final length but may be incomplete
	z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)

	// If xh != 0 or yh != 0, add the missing terms to z. For
	//
	//   xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
	//   yh =                         y1*b (0 <= y1 < b)
	//
	// the missing terms are
	//
	//   x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
	//
	// since all the yi for i > 1 are 0 by choice of k: If any of them
	// were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
	// be a larger valid threshold contradicting the assumption about k.
	//
	if k < n || m != n {
		// t is pooled scratch space for the partial products. If t.mul has
		// to grow t, the new array is kept in t; putNat returns the
		// original pooled *nat (tp) unchanged.
		tp := getNat(3 * k)
		t := *tp

		// add x0*y1*b
		x0 := x0.norm()
		y1 := y[k:]       // y1 is normalized because y is
		t = t.mul(x0, y1) // update t so we don't lose t's underlying array
		addAt(z, t, k)

		// add xi*y0<<i, xi*y1*b<<(i+k)
		y0 := y0.norm()
		for i := k; i < len(x); i += k {
			// xi is the next chunk of (at most) k digits of x.
			xi := x[i:]
			if len(xi) > k {
				xi = xi[:k]
			}
			xi = xi.norm()
			t = t.mul(xi, y0)
			addAt(z, t, i)
			t = t.mul(xi, y1)
			addAt(z, t, i+k)
		}

		putNat(tp)
	}

	return z.norm()
}
	 494  
// basicSqr sets z = x*x and is asymptotically faster than basicMul
// by about a factor of 2, but slower for small arguments due to overhead.
// Requirements: len(x) > 0, len(z) == 2*len(x)
// The (non-normalized) result is placed in z.
//
// It relies on x*x = sum(x[i]*x[i]*B^(2i)) + 2*sum_{j<i}(x[i]*x[j]*B^(i+j)):
// the squares are written directly into z, while the cross products are
// accumulated in a pooled scratch nat t, doubled by a 1-bit shift, and
// then added to z.
func basicSqr(z, x nat) {
	n := len(x)
	tp := getNat(2 * n)
	t := *tp // temporary variable to hold the products
	t.clear()
	z[1], z[0] = mulWW(x[0], x[0]) // the initial square
	for i := 1; i < n; i++ {
		d := x[i]
		// z collects the squares x[i] * x[i]
		z[2*i+1], z[2*i] = mulWW(d, d)
		// t collects the products x[i] * x[j] where j < i
		t[2*i] = addMulVVW(t[i:2*i], x[0:i], d)
	}
	t[2*n-1] = shlVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products
	// The carry out of addVV is ignored: x*x always fits in 2*n digits.
	addVV(z, z, t) // combine the result
	putNat(tp)
}
	 516  
// karatsubaSqr squares x and leaves the result in z.
// len(z) must be >= 6*len(x). len(x) is normally chosen by karatsubaLen;
// odd lengths and lengths below karatsubaSqrThreshold fall back to basicSqr.
// The (non-normalized) result is placed in z[0 : 2*len(x)].
//
// The algorithm and the layout of z are the same as for karatsuba.
func karatsubaSqr(z, x nat) {
	n := len(x)

	if n&1 != 0 || n < karatsubaSqrThreshold || n < 2 {
		basicSqr(z[:2*n], x)
		return
	}

	n2 := n >> 1
	x1, x0 := x[n2:], x[0:n2]

	karatsubaSqr(z, x0)
	karatsubaSqr(z[n:], x1)

	// s = sign(xd*yd) == -1 for xd != 0; s == 1 for xd == 0
	// (for squaring, yd = x0 - x1 = -xd, so xd*yd = -(x1-x0)^2 <= 0;
	// only |xd| is needed, and p = |xd|^2 is always subtracted below)
	xd := z[2*n : 2*n+n2]
	if subVV(xd, x1, x0) != 0 {
		subVV(xd, x0, x1)
	}

	p := z[n*3:]
	karatsubaSqr(p, xd)

	// save original z2:z0 in the upper part of z
	r := z[n*4:]
	copy(r, z[:n*2])

	karatsubaAdd(z[n2:], r, n)
	karatsubaAdd(z[n2:], r[n:], n)
	karatsubaSub(z[n2:], p, n) // s == -1 for p != 0; s == 1 for p == 0
}
	 552  
	 553  // Operands that are shorter than basicSqrThreshold are squared using
	 554  // "grade school" multiplication; for operands longer than karatsubaSqrThreshold
	 555  // we use the Karatsuba algorithm optimized for x == y.
	 556  var basicSqrThreshold = 20			// computed by calibrate_test.go
	 557  var karatsubaSqrThreshold = 260 // computed by calibrate_test.go
	 558  
// sqr sets z = x*x and returns the normalized result. z's storage is
// reused unless it aliases x (the single-digit case reads x[0] before z
// is resized).
//
// Depending on len(x) it uses basicMul, basicSqr, or karatsubaSqr on the
// low k digits followed by the remaining partial products.
func (z nat) sqr(x nat) nat {
	n := len(x)
	switch {
	case n == 0:
		return z[:0]
	case n == 1:
		d := x[0]
		z = z.make(2)
		z[1], z[0] = mulWW(d, d)
		return z.norm()
	}

	if alias(z, x) {
		z = nil // z is an alias for x - cannot reuse
	}

	if n < basicSqrThreshold {
		z = z.make(2 * n)
		basicMul(z, x, x)
		return z.norm()
	}
	if n < karatsubaSqrThreshold {
		z = z.make(2 * n)
		basicSqr(z, x)
		return z.norm()
	}

	// Use Karatsuba multiplication optimized for x == y.
	// The algorithm and layout of z are the same as for mul.

	// z = (x1*b + x0)^2 = x1^2*b^2 + 2*x1*x0*b + x0^2

	k := karatsubaLen(n, karatsubaSqrThreshold)

	x0 := x[0:k]
	z = z.make(max(6*k, 2*n))
	karatsubaSqr(z, x0) // z = x0^2
	z = z[0 : 2*n]
	z[2*k:].clear() // upper portion of z is scratch garbage from karatsubaSqr

	if k < n {
		// t is pooled scratch space; t.mul/t.sqr may grow it, in which case
		// putNat returns the original pooled *nat (tp) unchanged.
		tp := getNat(2 * k)
		t := *tp
		x0 := x0.norm()
		x1 := x[k:]
		t = t.mul(x0, x1)
		addAt(z, t, k)
		addAt(z, t, k) // z = 2*x1*x0*b + x0^2
		t = t.sqr(x1)
		addAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2
		putNat(tp)
	}

	return z.norm()
}
	 615  
	 616  // mulRange computes the product of all the unsigned integers in the
	 617  // range [a, b] inclusively. If a > b (empty range), the result is 1.
	 618  func (z nat) mulRange(a, b uint64) nat {
	 619  	switch {
	 620  	case a == 0:
	 621  		// cut long ranges short (optimization)
	 622  		return z.setUint64(0)
	 623  	case a > b:
	 624  		return z.setUint64(1)
	 625  	case a == b:
	 626  		return z.setUint64(a)
	 627  	case a+1 == b:
	 628  		return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b))
	 629  	}
	 630  	m := (a + b) / 2
	 631  	return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
	 632  }
	 633  
	 634  // getNat returns a *nat of len n. The contents may not be zero.
	 635  // The pool holds *nat to avoid allocation when converting to interface{}.
	 636  func getNat(n int) *nat {
	 637  	var z *nat
	 638  	if v := natPool.Get(); v != nil {
	 639  		z = v.(*nat)
	 640  	}
	 641  	if z == nil {
	 642  		z = new(nat)
	 643  	}
	 644  	*z = z.make(n)
	 645  	return z
	 646  }
	 647  
	 648  func putNat(x *nat) {
	 649  	natPool.Put(x)
	 650  }
	 651  
	 652  var natPool sync.Pool
	 653  
	 654  // Length of x in bits. x must be normalized.
	 655  func (x nat) bitLen() int {
	 656  	if i := len(x) - 1; i >= 0 {
	 657  		return i*_W + bits.Len(uint(x[i]))
	 658  	}
	 659  	return 0
	 660  }
	 661  
	 662  // trailingZeroBits returns the number of consecutive least significant zero
	 663  // bits of x.
	 664  func (x nat) trailingZeroBits() uint {
	 665  	if len(x) == 0 {
	 666  		return 0
	 667  	}
	 668  	var i uint
	 669  	for x[i] == 0 {
	 670  		i++
	 671  	}
	 672  	// x[i] != 0
	 673  	return i*_W + uint(bits.TrailingZeros(uint(x[i])))
	 674  }
	 675  
	 676  func same(x, y nat) bool {
	 677  	return len(x) == len(y) && len(x) > 0 && &x[0] == &y[0]
	 678  }
	 679  
	 680  // z = x << s
	 681  func (z nat) shl(x nat, s uint) nat {
	 682  	if s == 0 {
	 683  		if same(z, x) {
	 684  			return z
	 685  		}
	 686  		if !alias(z, x) {
	 687  			return z.set(x)
	 688  		}
	 689  	}
	 690  
	 691  	m := len(x)
	 692  	if m == 0 {
	 693  		return z[:0]
	 694  	}
	 695  	// m > 0
	 696  
	 697  	n := m + int(s/_W)
	 698  	z = z.make(n + 1)
	 699  	z[n] = shlVU(z[n-m:n], x, s%_W)
	 700  	z[0 : n-m].clear()
	 701  
	 702  	return z.norm()
	 703  }
	 704  
	 705  // z = x >> s
	 706  func (z nat) shr(x nat, s uint) nat {
	 707  	if s == 0 {
	 708  		if same(z, x) {
	 709  			return z
	 710  		}
	 711  		if !alias(z, x) {
	 712  			return z.set(x)
	 713  		}
	 714  	}
	 715  
	 716  	m := len(x)
	 717  	n := m - int(s/_W)
	 718  	if n <= 0 {
	 719  		return z[:0]
	 720  	}
	 721  	// n > 0
	 722  
	 723  	z = z.make(n)
	 724  	shrVU(z, x[m-n:], s%_W)
	 725  
	 726  	return z.norm()
	 727  }
	 728  
	 729  func (z nat) setBit(x nat, i uint, b uint) nat {
	 730  	j := int(i / _W)
	 731  	m := Word(1) << (i % _W)
	 732  	n := len(x)
	 733  	switch b {
	 734  	case 0:
	 735  		z = z.make(n)
	 736  		copy(z, x)
	 737  		if j >= n {
	 738  			// no need to grow
	 739  			return z
	 740  		}
	 741  		z[j] &^= m
	 742  		return z.norm()
	 743  	case 1:
	 744  		if j >= n {
	 745  			z = z.make(j + 1)
	 746  			z[n:].clear()
	 747  		} else {
	 748  			z = z.make(n)
	 749  		}
	 750  		copy(z, x)
	 751  		z[j] |= m
	 752  		// no need to normalize
	 753  		return z
	 754  	}
	 755  	panic("set bit is not 0 or 1")
	 756  }
	 757  
	 758  // bit returns the value of the i'th bit, with lsb == bit 0.
	 759  func (x nat) bit(i uint) uint {
	 760  	j := i / _W
	 761  	if j >= uint(len(x)) {
	 762  		return 0
	 763  	}
	 764  	// 0 <= j < len(x)
	 765  	return uint(x[j] >> (i % _W) & 1)
	 766  }
	 767  
	 768  // sticky returns 1 if there's a 1 bit within the
	 769  // i least significant bits, otherwise it returns 0.
	 770  func (x nat) sticky(i uint) uint {
	 771  	j := i / _W
	 772  	if j >= uint(len(x)) {
	 773  		if len(x) == 0 {
	 774  			return 0
	 775  		}
	 776  		return 1
	 777  	}
	 778  	// 0 <= j < len(x)
	 779  	for _, x := range x[:j] {
	 780  		if x != 0 {
	 781  			return 1
	 782  		}
	 783  	}
	 784  	if x[j]<<(_W-i%_W) != 0 {
	 785  		return 1
	 786  	}
	 787  	return 0
	 788  }
	 789  
	 790  func (z nat) and(x, y nat) nat {
	 791  	m := len(x)
	 792  	n := len(y)
	 793  	if m > n {
	 794  		m = n
	 795  	}
	 796  	// m <= n
	 797  
	 798  	z = z.make(m)
	 799  	for i := 0; i < m; i++ {
	 800  		z[i] = x[i] & y[i]
	 801  	}
	 802  
	 803  	return z.norm()
	 804  }
	 805  
	 806  func (z nat) andNot(x, y nat) nat {
	 807  	m := len(x)
	 808  	n := len(y)
	 809  	if n > m {
	 810  		n = m
	 811  	}
	 812  	// m >= n
	 813  
	 814  	z = z.make(m)
	 815  	for i := 0; i < n; i++ {
	 816  		z[i] = x[i] &^ y[i]
	 817  	}
	 818  	copy(z[n:m], x[n:m])
	 819  
	 820  	return z.norm()
	 821  }
	 822  
	 823  func (z nat) or(x, y nat) nat {
	 824  	m := len(x)
	 825  	n := len(y)
	 826  	s := x
	 827  	if m < n {
	 828  		n, m = m, n
	 829  		s = y
	 830  	}
	 831  	// m >= n
	 832  
	 833  	z = z.make(m)
	 834  	for i := 0; i < n; i++ {
	 835  		z[i] = x[i] | y[i]
	 836  	}
	 837  	copy(z[n:m], s[n:m])
	 838  
	 839  	return z.norm()
	 840  }
	 841  
	 842  func (z nat) xor(x, y nat) nat {
	 843  	m := len(x)
	 844  	n := len(y)
	 845  	s := x
	 846  	if m < n {
	 847  		n, m = m, n
	 848  		s = y
	 849  	}
	 850  	// m >= n
	 851  
	 852  	z = z.make(m)
	 853  	for i := 0; i < n; i++ {
	 854  		z[i] = x[i] ^ y[i]
	 855  	}
	 856  	copy(z[n:m], s[n:m])
	 857  
	 858  	return z.norm()
	 859  }
	 860  
	 861  // random creates a random integer in [0..limit), using the space in z if
	 862  // possible. n is the bit length of limit.
	 863  func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
	 864  	if alias(z, limit) {
	 865  		z = nil // z is an alias for limit - cannot reuse
	 866  	}
	 867  	z = z.make(len(limit))
	 868  
	 869  	bitLengthOfMSW := uint(n % _W)
	 870  	if bitLengthOfMSW == 0 {
	 871  		bitLengthOfMSW = _W
	 872  	}
	 873  	mask := Word((1 << bitLengthOfMSW) - 1)
	 874  
	 875  	for {
	 876  		switch _W {
	 877  		case 32:
	 878  			for i := range z {
	 879  				z[i] = Word(rand.Uint32())
	 880  			}
	 881  		case 64:
	 882  			for i := range z {
	 883  				z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32
	 884  			}
	 885  		default:
	 886  			panic("unknown word size")
	 887  		}
	 888  		z[len(limit)-1] &= mask
	 889  		if z.cmp(limit) < 0 {
	 890  			break
	 891  		}
	 892  	}
	 893  
	 894  	return z.norm()
	 895  }
	 896  
// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
// otherwise it sets z to x**y. The result is the value of z.
//
// For a base x > 1 with a multi-word exponent and m != 0, the work is
// delegated to expNNMontgomery (odd m) or expNNWindowed (even m).
// Otherwise the exponent bits are processed from most to least significant
// (square-and-multiply), reducing mod m after every step when m != 0.
func (z nat) expNN(x, y, m nat) nat {
	if alias(z, x) || alias(z, y) {
		// We cannot allow in-place modification of x or y.
		z = nil
	}

	// x**y mod 1 == 0
	if len(m) == 1 && m[0] == 1 {
		return z.setWord(0)
	}
	// m == 0 || m > 1

	// x**0 == 1
	if len(y) == 0 {
		return z.setWord(1)
	}
	// y > 0

	// x**1 mod m == x mod m
	if len(y) == 1 && y[0] == 1 && len(m) != 0 {
		_, z = nat(nil).div(z, x, m)
		return z
	}
	// y > 1, or y == 1 with m == 0

	if len(m) != 0 {
		// We likely end up being as long as the modulus.
		z = z.make(len(m))
	}
	z = z.set(x)

	// If the base is non-trivial and the exponent is large, we use
	// 4-bit, windowed exponentiation. This involves precomputing 14 values
	// (x^2...x^15) but then reduces the number of multiply-reduces by a
	// third. Even for a 32-bit exponent, this reduces the number of
	// operations. Uses Montgomery method for odd moduli.
	if x.cmp(natOne) > 0 && len(y) > 1 && len(m) > 0 {
		if m[0]&1 == 1 {
			return z.expNNMontgomery(x, y, m)
		}
		return z.expNNWindowed(x, y, m)
	}

	// Shift out the leading zeros and the leading 1 bit of the top word of
	// y: that 1 bit is already accounted for by starting with z = x.
	v := y[len(y)-1] // v > 0 because y is normalized and y > 0
	shift := nlz(v) + 1
	v <<= shift
	var q nat

	const mask = 1 << (_W - 1)

	// We walk through the bits of the exponent one by one. Each time we
	// see a bit, we square, thus doubling the power. If the bit is a one,
	// we also multiply by x, thus adding one to the power.

	w := _W - int(shift)
	// zz and r are used to avoid allocating in mul and div as
	// otherwise the arguments would alias.
	var zz, r nat
	for j := 0; j < w; j++ {
		zz = zz.sqr(z)
		zz, z = z, zz

		if v&mask != 0 {
			zz = zz.mul(z, x)
			zz, z = z, zz
		}

		if len(m) != 0 {
			// div leaves the quotient in zz and z mod m in r; rotate the
			// buffers so z holds the reduced value and the rest are
			// recycled as scratch space.
			zz, r = zz.div(r, z, m)
			zz, r, q, z = q, z, zz, r
		}

		v <<= 1
	}

	// Remaining (full) words of the exponent, most significant first.
	for i := len(y) - 2; i >= 0; i-- {
		v = y[i]

		for j := 0; j < _W; j++ {
			zz = zz.sqr(z)
			zz, z = z, zz

			if v&mask != 0 {
				zz = zz.mul(z, x)
				zz, z = z, zz
			}

			if len(m) != 0 {
				zz, r = zz.div(r, z, m)
				zz, r, q, z = q, z, zz, r
			}

			v <<= 1
		}
	}

	return z.norm()
}
	 997  
// expNNWindowed calculates x**y mod m using a fixed, 4-bit window.
//
// It precomputes powers[i] = x**i (reduced mod m for i >= 2), then scans y
// from its most significant word down, 4 bits at a time. For every window
// it squares z four times (skipped for the very first window, where z is
// still 1) and multiplies by the power selected by the window's bits,
// reducing mod m after every operation. m is presumably non-zero here
// (expNN only calls this with len(m) > 0).
func (z nat) expNNWindowed(x, y, m nat) nat {
	// zz and r are used to avoid allocating in mul and div as otherwise
	// the arguments would alias.
	var zz, r nat

	const n = 4
	// powers[i] contains x^i.
	var powers [1 << n]nat
	powers[0] = natOne
	powers[1] = x
	// Fill powers in pairs: x^i = (x^(i/2))^2 and x^(i+1) = x^i * x.
	for i := 2; i < 1<<n; i += 2 {
		p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1]
		*p = p.sqr(*p2)
		zz, r = zz.div(r, *p, m)
		*p, r = r, *p
		*p1 = p1.mul(*p, x)
		zz, r = zz.div(r, *p1, m)
		*p1, r = r, *p1
	}

	z = z.setWord(1)

	for i := len(y) - 1; i >= 0; i-- {
		yi := y[i]
		for j := 0; j < _W; j += n {
			if i != len(y)-1 || j != 0 {
				// Unrolled loop for significant performance
				// gain. Use go test -bench=".*" in crypto/rsa
				// to check performance before making changes.
				zz = zz.sqr(z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.sqr(z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.sqr(z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.sqr(z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z
			}

			// The top n bits of yi select the precomputed power.
			zz = zz.mul(z, powers[yi>>(_W-n)])
			zz, z = z, zz
			zz, r = zz.div(r, z, m)
			z, r = r, z

			yi <<= n
		}
	}

	return z.norm()
}
	1060  
// expNNMontgomery calculates x**y mod m using a fixed, 4-bit window.
// Uses Montgomery representation.
//
// m is presumably odd (expNN only calls this when m[0]&1 == 1); the
// constant k0 = -m**-1 mod 2**_W computed below only exists for odd m.
func (z nat) expNNMontgomery(x, y, m nat) nat {
	numWords := len(m)

	// We want the lengths of x and m to be equal.
	// It is OK if x >= m as long as len(x) == len(m).
	if len(x) > numWords {
		_, x = nat(nil).div(nil, x, m)
		// Note: now len(x) <= numWords, not guaranteed ==.
	}
	if len(x) < numWords {
		rr := make(nat, numWords)
		copy(rr, x)
		x = rr
	}

	// Ideally the precomputations would be performed outside, and reused.
	// k0 = -m**-1 mod 2**_W. Algorithm from: Dumas, J.G. "On Newton–Raphson
	// Iteration for Multiplicative Inverses Modulo Prime Powers".
	k0 := 2 - m[0]
	t := m[0] - 1
	for i := 1; i < _W; i <<= 1 {
		t *= t
		k0 *= (t + 1)
	}
	k0 = -k0

	// RR = 2**(2*_W*len(m)) mod m
	RR := nat(nil).setWord(1)
	zz := nat(nil).shl(RR, uint(2*numWords*_W))
	_, RR = nat(nil).div(RR, zz, m)
	if len(RR) < numWords {
		zz = zz.make(numWords)
		copy(zz, RR)
		RR = zz
	}
	// one = 1, with equal length to that of m
	one := make(nat, numWords)
	one[0] = 1

	const n = 4
	// powers[i] contains x^i in Montgomery form, i.e. x^i * 2**(numWords*_W)
	// modulo m (not necessarily fully reduced): montgomery(v, RR) = v*RR/2**(numWords*_W).
	var powers [1 << n]nat
	powers[0] = powers[0].montgomery(one, RR, m, k0, numWords)
	powers[1] = powers[1].montgomery(x, RR, m, k0, numWords)
	for i := 2; i < 1<<n; i++ {
		powers[i] = powers[i].montgomery(powers[i-1], powers[1], m, k0, numWords)
	}

	// initialize z = 1 (Montgomery 1)
	z = z.make(numWords)
	copy(z, powers[0])

	zz = zz.make(numWords)

	// same windowed exponent, but with Montgomery multiplications
	// (z and zz alternate as destination because montgomery's result must
	// not alias its inputs)
	for i := len(y) - 1; i >= 0; i-- {
		yi := y[i]
		for j := 0; j < _W; j += n {
			if i != len(y)-1 || j != 0 {
				zz = zz.montgomery(z, z, m, k0, numWords)
				z = z.montgomery(zz, zz, m, k0, numWords)
				zz = zz.montgomery(z, z, m, k0, numWords)
				z = z.montgomery(zz, zz, m, k0, numWords)
			}
			zz = zz.montgomery(z, powers[yi>>(_W-n)], m, k0, numWords)
			z, zz = zz, z
			yi <<= n
		}
	}
	// convert to regular number
	zz = zz.montgomery(z, one, m, k0, numWords)

	// One last reduction, just in case.
	// See golang.org/issue/13907.
	if zz.cmp(m) >= 0 {
		// Common case is m has high bit set; in that case,
		// since zz is the same length as m, there can be just
		// one multiple of m to remove. Just subtract.
		// We think that the subtract should be sufficient in general,
		// so do that unconditionally, but double-check,
		// in case our beliefs are wrong.
		// The div is not expected to be reached.
		zz = zz.sub(zz, m)
		if zz.cmp(m) >= 0 {
			_, zz = nat(nil).div(nil, zz, m)
		}
	}

	return zz.norm()
}
	1153  
	1154  // bytes writes the value of z into buf using big-endian encoding.
	1155  // The value of z is encoded in the slice buf[i:]. If the value of z
	1156  // cannot be represented in buf, bytes panics. The number i of unused
	1157  // bytes at the beginning of buf is returned as result.
	1158  func (z nat) bytes(buf []byte) (i int) {
	1159  	i = len(buf)
	1160  	for _, d := range z {
	1161  		for j := 0; j < _S; j++ {
	1162  			i--
	1163  			if i >= 0 {
	1164  				buf[i] = byte(d)
	1165  			} else if byte(d) != 0 {
	1166  				panic("math/big: buffer too small to fit value")
	1167  			}
	1168  			d >>= 8
	1169  		}
	1170  	}
	1171  
	1172  	if i < 0 {
	1173  		i = 0
	1174  	}
	1175  	for i < len(buf) && buf[i] == 0 {
	1176  		i++
	1177  	}
	1178  
	1179  	return
	1180  }
	1181  
	1182  // bigEndianWord returns the contents of buf interpreted as a big-endian encoded Word value.
	1183  func bigEndianWord(buf []byte) Word {
	1184  	if _W == 64 {
	1185  		return Word(binary.BigEndian.Uint64(buf))
	1186  	}
	1187  	return Word(binary.BigEndian.Uint32(buf))
	1188  }
	1189  
	1190  // setBytes interprets buf as the bytes of a big-endian unsigned
	1191  // integer, sets z to that value, and returns z.
	1192  func (z nat) setBytes(buf []byte) nat {
	1193  	z = z.make((len(buf) + _S - 1) / _S)
	1194  
	1195  	i := len(buf)
	1196  	for k := 0; i >= _S; k++ {
	1197  		z[k] = bigEndianWord(buf[i-_S : i])
	1198  		i -= _S
	1199  	}
	1200  	if i > 0 {
	1201  		var d Word
	1202  		for s := uint(0); i > 0; s += 8 {
	1203  			d |= Word(buf[i-1]) << s
	1204  			i--
	1205  		}
	1206  		z[len(z)-1] = d
	1207  	}
	1208  
	1209  	return z.norm()
	1210  }
	1211  
// sqrt sets z = ⌊√x⌋ and returns z. z's storage is reused unless it
// aliases x.
func (z nat) sqrt(x nat) nat {
	if x.cmp(natOne) <= 0 {
		return z.set(x)
	}
	if alias(z, x) {
		z = nil
	}

	// Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller.
	// See Brent and Zimmermann, Modern Computer Arithmetic, Algorithm 1.13 (SqrtInt).
	// https://members.loria.fr/PZimmermann/mca/pub226.html
	// If x is one less than a perfect square, the sequence oscillates between the correct z and z+1;
	// otherwise it converges to the correct z and stays there.
	//
	// z1 starts out in z's storage; z1 and z2 swap buffers on every
	// iteration, so z1 is back in z's storage exactly when n is even.
	var z1, z2 nat
	z1 = z
	z1 = z1.setUint64(1)
	z1 = z1.shl(z1, uint(x.bitLen()+1)/2) // must be ≥ √x
	for n := 0; ; n++ {
		z2, _ = z2.div(nil, x, z1)
		z2 = z2.add(z2, z1)
		z2 = z2.shr(z2, 1)
		if z2.cmp(z1) >= 0 {
			// z1 is answer.
			// Figure out whether z1 or z2 is currently aliased to z by looking at loop count.
			if n&1 == 0 {
				return z1
			}
			return z.set(z1)
		}
		z1, z2 = z2, z1
	}
}
	1245  

View as plain text