aboutsummaryrefslogtreecommitdiff
path: root/src/crypto/rsa/pss.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/crypto/rsa/pss.go')
-rw-r--r--src/crypto/rsa/pss.go35
1 files changed, 35 insertions, 0 deletions
diff --git a/src/crypto/rsa/pss.go b/src/crypto/rsa/pss.go
index 3ff0c2f4d0..e32cb7e0a7 100644
--- a/src/crypto/rsa/pss.go
+++ b/src/crypto/rsa/pss.go
@@ -17,6 +17,8 @@ import (
"math/big"
)
+import "crypto/internal/boring"
+
func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) {
// See [1], section 9.1.1
hLen := hash.Size()
@@ -197,6 +199,21 @@ func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed,
if err != nil {
return
}
+
+ if boring.Enabled {
+ bkey, err := boringPrivateKey(priv)
+ if err != nil {
+ return nil, err
+ }
+ // Note: BoringCrypto takes care of the "AndCheck" part of "decryptAndCheck".
+ // (It's not just decrypt.)
+ s, err := boring.DecryptRSANoPadding(bkey, em)
+ if err != nil {
+ return nil, err
+ }
+ return s, nil
+ }
+
m := new(big.Int).SetBytes(em)
c, err := decryptAndCheck(rand, priv, m)
if err != nil {
@@ -259,6 +276,14 @@ func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []byte,
hash = opts.Hash
}
+ if boring.Enabled && rand == boring.RandReader {
+ bkey, err := boringPrivateKey(priv)
+ if err != nil {
+ return nil, err
+ }
+ return boring.SignRSAPSS(bkey, hash, hashed, saltLength)
+ }
+
salt := make([]byte, saltLength)
if _, err := io.ReadFull(rand, salt); err != nil {
return nil, err
@@ -277,6 +302,16 @@ func VerifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, opts
// verifyPSS verifies a PSS signature with the given salt length.
func verifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, saltLen int) error {
+ if boring.Enabled {
+ bkey, err := boringPublicKey(pub)
+ if err != nil {
+ return err
+ }
+ if err := boring.VerifyRSAPSS(bkey, hash, hashed, sig, saltLen); err != nil {
+ return ErrVerification
+ }
+ return nil
+ }
nBits := pub.N.BitLen()
if len(sig) != (nBits+7)/8 {
return ErrVerification