aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/crypto/rsa/pss.go173
-rw-r--r--src/crypto/rsa/rsa.go9
2 files changed, 96 insertions, 86 deletions
diff --git a/src/crypto/rsa/pss.go b/src/crypto/rsa/pss.go
index 3ff0c2f4d0..f9844d8732 100644
--- a/src/crypto/rsa/pss.go
+++ b/src/crypto/rsa/pss.go
@@ -4,9 +4,7 @@
package rsa
-// This file implements the PSS signature scheme [1].
-//
-// [1] https://www.emc.com/collateral/white-papers/h11300-pkcs-1v2-2-rsa-cryptography-standard-wp.pdf
+// This file implements the RSASSA-PSS signature scheme according to RFC 8017.
import (
"bytes"
@@ -17,8 +15,22 @@ import (
"math/big"
)
+// Per RFC 8017, Section 9.1
+//
+// EM = MGF1 xor DB || H( 8*0x00 || mHash || salt ) || 0xbc
+//
+// where
+//
+// DB = PS || 0x01 || salt
+//
+// and PS can be empty so
+//
+// emLen = dbLen + hLen + 1 = psLen + sLen + hLen + 2
+//
+
func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) {
- // See [1], section 9.1.1
+ // See RFC 8017, Section 9.1.1.
+
hLen := hash.Size()
sLen := len(salt)
emLen := (emBits + 7) / 8
@@ -30,7 +42,7 @@ func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byt
// 2. Let mHash = Hash(M), an octet string of length hLen.
if len(mHash) != hLen {
- return nil, errors.New("crypto/rsa: input must be hashed message")
+ return nil, errors.New("crypto/rsa: input must be hashed with given hash")
}
// 3. If emLen < hLen + sLen + 2, output "encoding error" and stop.
@@ -40,8 +52,9 @@ func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byt
}
em := make([]byte, emLen)
- db := em[:emLen-sLen-hLen-2+1+sLen]
- h := em[emLen-sLen-hLen-2+1+sLen : emLen-1]
+ psLen := emLen - sLen - hLen - 2
+ db := em[:psLen+1+sLen]
+ h := em[psLen+1+sLen : emLen-1]
// 4. Generate a random octet string salt of length sLen; if sLen = 0,
// then salt is the empty string.
@@ -69,8 +82,8 @@ func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byt
// 8. Let DB = PS || 0x01 || salt; DB is an octet string of length
// emLen - hLen - 1.
- db[emLen-sLen-hLen-2] = 0x01
- copy(db[emLen-sLen-hLen-1:], salt)
+ db[psLen] = 0x01
+ copy(db[psLen+1:], salt)
// 9. Let dbMask = MGF(H, emLen - hLen - 1).
//
@@ -81,47 +94,57 @@ func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byt
// 11. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in
// maskedDB to zero.
- db[0] &= (0xFF >> uint(8*emLen-emBits))
+ db[0] &= 0xff >> (8*emLen - emBits)
// 12. Let EM = maskedDB || H || 0xbc.
- em[emLen-1] = 0xBC
+ em[emLen-1] = 0xbc
// 13. Output EM.
return em, nil
}
func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
+ // See RFC 8017, Section 9.1.2.
+
+ hLen := hash.Size()
+ if sLen == PSSSaltLengthEqualsHash {
+ sLen = hLen
+ }
+ emLen := (emBits + 7) / 8
+ if emLen != len(em) {
+ return errors.New("rsa: internal error: inconsistent length")
+ }
+
// 1. If the length of M is greater than the input limitation for the
// hash function (2^61 - 1 octets for SHA-1), output "inconsistent"
// and stop.
//
// 2. Let mHash = Hash(M), an octet string of length hLen.
- hLen := hash.Size()
if hLen != len(mHash) {
return ErrVerification
}
// 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop.
- emLen := (emBits + 7) / 8
if emLen < hLen+sLen+2 {
return ErrVerification
}
// 4. If the rightmost octet of EM does not have hexadecimal value
// 0xbc, output "inconsistent" and stop.
- if em[len(em)-1] != 0xBC {
+ if em[emLen-1] != 0xbc {
return ErrVerification
}
// 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and
// let H be the next hLen octets.
db := em[:emLen-hLen-1]
- h := em[emLen-hLen-1 : len(em)-1]
+ h := em[emLen-hLen-1 : emLen-1]
// 6. If the leftmost 8 * emLen - emBits bits of the leftmost octet in
// maskedDB are not all equal to zero, output "inconsistent" and
// stop.
- if em[0]&(0xFF<<uint(8-(8*emLen-emBits))) != 0 {
+ var bitMask byte = 0xff >> (8*emLen - emBits)
+ if em[0] & ^bitMask != 0 {
return ErrVerification
}
@@ -132,37 +155,30 @@ func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
// 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
// to zero.
- db[0] &= (0xFF >> uint(8*emLen-emBits))
+ db[0] &= bitMask
+ // If we don't know the salt length, look for the 0x01 delimiter.
if sLen == PSSSaltLengthAuto {
- FindSaltLength:
- for sLen = emLen - (hLen + 2); sLen >= 0; sLen-- {
- switch db[emLen-hLen-sLen-2] {
- case 1:
- break FindSaltLength
- case 0:
- continue
- default:
- return ErrVerification
- }
- }
- if sLen < 0 {
+ psLen := bytes.IndexByte(db, 0x01)
+ if psLen < 0 {
return ErrVerification
}
- } else {
- // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
- // or if the octet at position emLen - hLen - sLen - 1 (the leftmost
- // position is "position 1") does not have hexadecimal value 0x01,
- // output "inconsistent" and stop.
- for _, e := range db[:emLen-hLen-sLen-2] {
- if e != 0x00 {
- return ErrVerification
- }
- }
- if db[emLen-hLen-sLen-2] != 0x01 {
+ sLen = len(db) - psLen - 1
+ }
+
+ // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
+ // or if the octet at position emLen - hLen - sLen - 1 (the leftmost
+ // position is "position 1") does not have hexadecimal value 0x01,
+ // output "inconsistent" and stop.
+ psLen := emLen - hLen - sLen - 2
+ for _, e := range db[:psLen] {
+ if e != 0x00 {
return ErrVerification
}
}
+ if db[psLen] != 0x01 {
+ return ErrVerification
+ }
// 11. Let salt be the last sLen octets of DB.
salt := db[len(db)-sLen:]
@@ -181,19 +197,19 @@ func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
h0 := hash.Sum(nil)
// 14. If H = H', output "consistent." Otherwise, output "inconsistent."
- if !bytes.Equal(h0, h) {
+ if !bytes.Equal(h0, h) { // TODO: constant time?
return ErrVerification
}
return nil
}
-// signPSSWithSalt calculates the signature of hashed using PSS [1] with specified salt.
+// signPSSWithSalt calculates the signature of hashed using PSS with specified salt.
// Note that hashed must be the result of hashing the input message using the
// given hash function. salt is a random sequence of bytes whose length will be
// later used to verify the signature.
func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) (s []byte, err error) {
- nBits := priv.N.BitLen()
- em, err := emsaPSSEncode(hashed, nBits-1, salt, hash.New())
+ emBits := priv.N.BitLen() - 1
+ em, err := emsaPSSEncode(hashed, emBits, salt, hash.New())
if err != nil {
return
}
@@ -202,7 +218,7 @@ func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed,
if err != nil {
return
}
- s = make([]byte, (nBits+7)/8)
+ s = make([]byte, priv.Size())
copyWithLeftPad(s, c.Bytes())
return
}
@@ -223,16 +239,15 @@ type PSSOptions struct {
// PSSSaltLength constants.
SaltLength int
- // Hash, if not zero, overrides the hash function passed to SignPSS.
- // This is the only way to specify the hash function when using the
- // crypto.Signer interface.
+ // Hash is the hash function used to generate the message digest. If not
+ // zero, it overrides the hash function passed to SignPSS. It's required
+ // when using PrivateKey.Sign.
Hash crypto.Hash
}
-// HashFunc returns pssOpts.Hash so that PSSOptions implements
-// crypto.SignerOpts.
-func (pssOpts *PSSOptions) HashFunc() crypto.Hash {
- return pssOpts.Hash
+// HashFunc returns opts.Hash so that PSSOptions implements crypto.SignerOpts.
+func (opts *PSSOptions) HashFunc() crypto.Hash {
+ return opts.Hash
}
func (opts *PSSOptions) saltLength() int {
@@ -242,56 +257,50 @@ func (opts *PSSOptions) saltLength() int {
return opts.SaltLength
}
-// SignPSS calculates the signature of hashed using RSASSA-PSS [1].
-// Note that hashed must be the result of hashing the input message using the
-// given hash function. The opts argument may be nil, in which case sensible
-// defaults are used.
-func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []byte, opts *PSSOptions) ([]byte, error) {
+// SignPSS calculates the signature of digest using PSS.
+//
+// digest must be the result of hashing the input message using the given hash
+// function. The opts argument may be nil, in which case sensible defaults are
+// used. If opts.Hash is set, it overrides hash.
+func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, digest []byte, opts *PSSOptions) ([]byte, error) {
+ if opts != nil && opts.Hash != 0 {
+ hash = opts.Hash
+ }
+
saltLength := opts.saltLength()
switch saltLength {
case PSSSaltLengthAuto:
- saltLength = (priv.N.BitLen()+7)/8 - 2 - hash.Size()
+ saltLength = priv.Size() - 2 - hash.Size()
case PSSSaltLengthEqualsHash:
saltLength = hash.Size()
}
- if opts != nil && opts.Hash != 0 {
- hash = opts.Hash
- }
-
salt := make([]byte, saltLength)
if _, err := io.ReadFull(rand, salt); err != nil {
return nil, err
}
- return signPSSWithSalt(rand, priv, hash, hashed, salt)
+ return signPSSWithSalt(rand, priv, hash, digest, salt)
}
// VerifyPSS verifies a PSS signature.
-// hashed is the result of hashing the input message using the given hash
-// function and sig is the signature. A valid signature is indicated by
-// returning a nil error. The opts argument may be nil, in which case sensible
-// defaults are used.
-func VerifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, opts *PSSOptions) error {
- return verifyPSS(pub, hash, hashed, sig, opts.saltLength())
-}
-
-// verifyPSS verifies a PSS signature with the given salt length.
-func verifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, saltLen int) error {
- nBits := pub.N.BitLen()
- if len(sig) != (nBits+7)/8 {
+//
+// A valid signature is indicated by returning a nil error. digest must be the
+// result of hashing the input message using the given hash function. The opts
+// argument may be nil, in which case sensible defaults are used. opts.Hash is
+// ignored.
+func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts *PSSOptions) error {
+ if len(sig) != pub.Size() {
return ErrVerification
}
s := new(big.Int).SetBytes(sig)
m := encrypt(new(big.Int), pub, s)
- emBits := nBits - 1
+ emBits := pub.N.BitLen() - 1
emLen := (emBits + 7) / 8
- if emLen < len(m.Bytes()) {
+ emBytes := m.Bytes()
+ if emLen < len(emBytes) {
return ErrVerification
}
em := make([]byte, emLen)
- copyWithLeftPad(em, m.Bytes())
- if saltLen == PSSSaltLengthEqualsHash {
- saltLen = hash.Size()
- }
- return emsaPSSVerify(hashed, em, emBits, saltLen, hash.New())
+ copyWithLeftPad(em, emBytes)
+ return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New())
}
diff --git a/src/crypto/rsa/rsa.go b/src/crypto/rsa/rsa.go
index 5a42990640..b4bfa13def 100644
--- a/src/crypto/rsa/rsa.go
+++ b/src/crypto/rsa/rsa.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Package rsa implements RSA encryption as specified in PKCS#1.
+// Package rsa implements RSA encryption as specified in PKCS#1 and RFC 8017.
//
// RSA is a single, fundamental operation that is used in this package to
// implement either public-key encryption or public-key signatures.
@@ -10,13 +10,13 @@
// The original specification for encryption and signatures with RSA is PKCS#1
// and the terms "RSA encryption" and "RSA signatures" by default refer to
// PKCS#1 version 1.5. However, that specification has flaws and new designs
-// should use version two, usually called by just OAEP and PSS, where
+// should use version 2, usually called by just OAEP and PSS, where
// possible.
//
// Two sets of interfaces are included in this package. When a more abstract
// interface isn't necessary, there are functions for encrypting/decrypting
// with v1.5/OAEP and signing/verifying with v1.5/PSS. If one needs to abstract
-// over the public-key primitive, the PrivateKey struct implements the
+// over the public key primitive, the PrivateKey type implements the
// Decrypter and Signer interfaces from the crypto package.
//
// The RSA operations in this package are not implemented using constant-time algorithms.
@@ -111,7 +111,8 @@ func (priv *PrivateKey) Public() crypto.PublicKey {
// Sign signs digest with priv, reading randomness from rand. If opts is a
// *PSSOptions then the PSS algorithm will be used, otherwise PKCS#1 v1.5 will
-// be used.
+// be used. digest must be the result of hashing the input message using
+// opts.HashFunc().
//
// This method implements crypto.Signer, which is an interface to support keys
// where the private part is kept in, for example, a hardware module. Common