...

Source file src/crypto/rsa/pss.go

Documentation: crypto/rsa

		 1  // Copyright 2013 The Go Authors. All rights reserved.
		 2  // Use of this source code is governed by a BSD-style
		 3  // license that can be found in the LICENSE file.
		 4  
		 5  package rsa
		 6  
		 7  // This file implements the RSASSA-PSS signature scheme according to RFC 8017.
		 8  
		 9  import (
		10  	"bytes"
		11  	"crypto"
		12  	"errors"
		13  	"hash"
		14  	"io"
		15  	"math/big"
		16  )
		17  
		18  // Per RFC 8017, Section 9.1
		19  //
		20  //		 EM = MGF1 xor DB || H( 8*0x00 || mHash || salt ) || 0xbc
		21  //
		22  // where
		23  //
		24  //		 DB = PS || 0x01 || salt
		25  //
		26  // and PS can be empty so
		27  //
		28  //		 emLen = dbLen + hLen + 1 = psLen + sLen + hLen + 2
		29  //
		30  
		31  func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) {
		32  	// See RFC 8017, Section 9.1.1.
		33  
		34  	hLen := hash.Size()
		35  	sLen := len(salt)
		36  	emLen := (emBits + 7) / 8
		37  
		38  	// 1.	If the length of M is greater than the input limitation for the
		39  	//		 hash function (2^61 - 1 octets for SHA-1), output "message too
		40  	//		 long" and stop.
		41  	//
		42  	// 2.	Let mHash = Hash(M), an octet string of length hLen.
		43  
		44  	if len(mHash) != hLen {
		45  		return nil, errors.New("crypto/rsa: input must be hashed with given hash")
		46  	}
		47  
		48  	// 3.	If emLen < hLen + sLen + 2, output "encoding error" and stop.
		49  
		50  	if emLen < hLen+sLen+2 {
		51  		return nil, errors.New("crypto/rsa: key size too small for PSS signature")
		52  	}
		53  
		54  	em := make([]byte, emLen)
		55  	psLen := emLen - sLen - hLen - 2
		56  	db := em[:psLen+1+sLen]
		57  	h := em[psLen+1+sLen : emLen-1]
		58  
		59  	// 4.	Generate a random octet string salt of length sLen; if sLen = 0,
		60  	//		 then salt is the empty string.
		61  	//
		62  	// 5.	Let
		63  	//			 M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt;
		64  	//
		65  	//		 M' is an octet string of length 8 + hLen + sLen with eight
		66  	//		 initial zero octets.
		67  	//
		68  	// 6.	Let H = Hash(M'), an octet string of length hLen.
		69  
		70  	var prefix [8]byte
		71  
		72  	hash.Write(prefix[:])
		73  	hash.Write(mHash)
		74  	hash.Write(salt)
		75  
		76  	h = hash.Sum(h[:0])
		77  	hash.Reset()
		78  
		79  	// 7.	Generate an octet string PS consisting of emLen - sLen - hLen - 2
		80  	//		 zero octets. The length of PS may be 0.
		81  	//
		82  	// 8.	Let DB = PS || 0x01 || salt; DB is an octet string of length
		83  	//		 emLen - hLen - 1.
		84  
		85  	db[psLen] = 0x01
		86  	copy(db[psLen+1:], salt)
		87  
		88  	// 9.	Let dbMask = MGF(H, emLen - hLen - 1).
		89  	//
		90  	// 10. Let maskedDB = DB \xor dbMask.
		91  
		92  	mgf1XOR(db, hash, h)
		93  
		94  	// 11. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in
		95  	//		 maskedDB to zero.
		96  
		97  	db[0] &= 0xff >> (8*emLen - emBits)
		98  
		99  	// 12. Let EM = maskedDB || H || 0xbc.
	 100  	em[emLen-1] = 0xbc
	 101  
	 102  	// 13. Output EM.
	 103  	return em, nil
	 104  }
	 105  
	 106  func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
	 107  	// See RFC 8017, Section 9.1.2.
	 108  
	 109  	hLen := hash.Size()
	 110  	if sLen == PSSSaltLengthEqualsHash {
	 111  		sLen = hLen
	 112  	}
	 113  	emLen := (emBits + 7) / 8
	 114  	if emLen != len(em) {
	 115  		return errors.New("rsa: internal error: inconsistent length")
	 116  	}
	 117  
	 118  	// 1.	If the length of M is greater than the input limitation for the
	 119  	//		 hash function (2^61 - 1 octets for SHA-1), output "inconsistent"
	 120  	//		 and stop.
	 121  	//
	 122  	// 2.	Let mHash = Hash(M), an octet string of length hLen.
	 123  	if hLen != len(mHash) {
	 124  		return ErrVerification
	 125  	}
	 126  
	 127  	// 3.	If emLen < hLen + sLen + 2, output "inconsistent" and stop.
	 128  	if emLen < hLen+sLen+2 {
	 129  		return ErrVerification
	 130  	}
	 131  
	 132  	// 4.	If the rightmost octet of EM does not have hexadecimal value
	 133  	//		 0xbc, output "inconsistent" and stop.
	 134  	if em[emLen-1] != 0xbc {
	 135  		return ErrVerification
	 136  	}
	 137  
	 138  	// 5.	Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and
	 139  	//		 let H be the next hLen octets.
	 140  	db := em[:emLen-hLen-1]
	 141  	h := em[emLen-hLen-1 : emLen-1]
	 142  
	 143  	// 6.	If the leftmost 8 * emLen - emBits bits of the leftmost octet in
	 144  	//		 maskedDB are not all equal to zero, output "inconsistent" and
	 145  	//		 stop.
	 146  	var bitMask byte = 0xff >> (8*emLen - emBits)
	 147  	if em[0] & ^bitMask != 0 {
	 148  		return ErrVerification
	 149  	}
	 150  
	 151  	// 7.	Let dbMask = MGF(H, emLen - hLen - 1).
	 152  	//
	 153  	// 8.	Let DB = maskedDB \xor dbMask.
	 154  	mgf1XOR(db, hash, h)
	 155  
	 156  	// 9.	Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
	 157  	//		 to zero.
	 158  	db[0] &= bitMask
	 159  
	 160  	// If we don't know the salt length, look for the 0x01 delimiter.
	 161  	if sLen == PSSSaltLengthAuto {
	 162  		psLen := bytes.IndexByte(db, 0x01)
	 163  		if psLen < 0 {
	 164  			return ErrVerification
	 165  		}
	 166  		sLen = len(db) - psLen - 1
	 167  	}
	 168  
	 169  	// 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
	 170  	//		 or if the octet at position emLen - hLen - sLen - 1 (the leftmost
	 171  	//		 position is "position 1") does not have hexadecimal value 0x01,
	 172  	//		 output "inconsistent" and stop.
	 173  	psLen := emLen - hLen - sLen - 2
	 174  	for _, e := range db[:psLen] {
	 175  		if e != 0x00 {
	 176  			return ErrVerification
	 177  		}
	 178  	}
	 179  	if db[psLen] != 0x01 {
	 180  		return ErrVerification
	 181  	}
	 182  
	 183  	// 11.	Let salt be the last sLen octets of DB.
	 184  	salt := db[len(db)-sLen:]
	 185  
	 186  	// 12.	Let
	 187  	//					M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
	 188  	//		 M' is an octet string of length 8 + hLen + sLen with eight
	 189  	//		 initial zero octets.
	 190  	//
	 191  	// 13. Let H' = Hash(M'), an octet string of length hLen.
	 192  	var prefix [8]byte
	 193  	hash.Write(prefix[:])
	 194  	hash.Write(mHash)
	 195  	hash.Write(salt)
	 196  
	 197  	h0 := hash.Sum(nil)
	 198  
	 199  	// 14. If H = H', output "consistent." Otherwise, output "inconsistent."
	 200  	if !bytes.Equal(h0, h) { // TODO: constant time?
	 201  		return ErrVerification
	 202  	}
	 203  	return nil
	 204  }
	 205  
	 206  // signPSSWithSalt calculates the signature of hashed using PSS with specified salt.
	 207  // Note that hashed must be the result of hashing the input message using the
	 208  // given hash function. salt is a random sequence of bytes whose length will be
	 209  // later used to verify the signature.
	 210  func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) {
	 211  	emBits := priv.N.BitLen() - 1
	 212  	em, err := emsaPSSEncode(hashed, emBits, salt, hash.New())
	 213  	if err != nil {
	 214  		return nil, err
	 215  	}
	 216  	m := new(big.Int).SetBytes(em)
	 217  	c, err := decryptAndCheck(rand, priv, m)
	 218  	if err != nil {
	 219  		return nil, err
	 220  	}
	 221  	s := make([]byte, priv.Size())
	 222  	return c.FillBytes(s), nil
	 223  }
	 224  
	 225  const (
	 226  	// PSSSaltLengthAuto causes the salt in a PSS signature to be as large
	 227  	// as possible when signing, and to be auto-detected when verifying.
	 228  	PSSSaltLengthAuto = 0
	 229  	// PSSSaltLengthEqualsHash causes the salt length to equal the length
	 230  	// of the hash used in the signature.
	 231  	PSSSaltLengthEqualsHash = -1
	 232  )
	 233  
// PSSOptions contains options for creating and verifying PSS signatures.
// SignPSS and VerifyPSS also accept a nil *PSSOptions, which selects the
// defaults (a PSSSaltLengthAuto salt length and no hash override).
type PSSOptions struct {
	// SaltLength controls the length of the salt used in the PSS
	// signature. It can either be a positive number of bytes, or one of the
	// special PSSSaltLength constants (PSSSaltLengthAuto or
	// PSSSaltLengthEqualsHash). Other negative values are not valid.
	SaltLength int

	// Hash is the hash function used to generate the message digest. If not
	// zero, it overrides the hash function passed to SignPSS. It's required
	// when using PrivateKey.Sign. VerifyPSS does not read it.
	Hash crypto.Hash
}
	 246  
	 247  // HashFunc returns opts.Hash so that PSSOptions implements crypto.SignerOpts.
	 248  func (opts *PSSOptions) HashFunc() crypto.Hash {
	 249  	return opts.Hash
	 250  }
	 251  
	 252  func (opts *PSSOptions) saltLength() int {
	 253  	if opts == nil {
	 254  		return PSSSaltLengthAuto
	 255  	}
	 256  	return opts.SaltLength
	 257  }
	 258  
	 259  // SignPSS calculates the signature of digest using PSS.
	 260  //
	 261  // digest must be the result of hashing the input message using the given hash
	 262  // function. The opts argument may be nil, in which case sensible defaults are
	 263  // used. If opts.Hash is set, it overrides hash.
	 264  func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, digest []byte, opts *PSSOptions) ([]byte, error) {
	 265  	if opts != nil && opts.Hash != 0 {
	 266  		hash = opts.Hash
	 267  	}
	 268  
	 269  	saltLength := opts.saltLength()
	 270  	switch saltLength {
	 271  	case PSSSaltLengthAuto:
	 272  		saltLength = (priv.N.BitLen()-1+7)/8 - 2 - hash.Size()
	 273  	case PSSSaltLengthEqualsHash:
	 274  		saltLength = hash.Size()
	 275  	}
	 276  
	 277  	salt := make([]byte, saltLength)
	 278  	if _, err := io.ReadFull(rand, salt); err != nil {
	 279  		return nil, err
	 280  	}
	 281  	return signPSSWithSalt(rand, priv, hash, digest, salt)
	 282  }
	 283  
	 284  // VerifyPSS verifies a PSS signature.
	 285  //
	 286  // A valid signature is indicated by returning a nil error. digest must be the
	 287  // result of hashing the input message using the given hash function. The opts
	 288  // argument may be nil, in which case sensible defaults are used. opts.Hash is
	 289  // ignored.
	 290  func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts *PSSOptions) error {
	 291  	if len(sig) != pub.Size() {
	 292  		return ErrVerification
	 293  	}
	 294  	s := new(big.Int).SetBytes(sig)
	 295  	m := encrypt(new(big.Int), pub, s)
	 296  	emBits := pub.N.BitLen() - 1
	 297  	emLen := (emBits + 7) / 8
	 298  	if m.BitLen() > emLen*8 {
	 299  		return ErrVerification
	 300  	}
	 301  	em := m.FillBytes(make([]byte, emLen))
	 302  	return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New())
	 303  }
	 304  

View as plain text