aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/crypto/x509/cert_pool.go28
-rw-r--r--src/crypto/x509/verify.go86
-rw-r--r--src/crypto/x509/verify_test.go119
3 files changed, 176 insertions, 57 deletions
diff --git a/src/crypto/x509/cert_pool.go b/src/crypto/x509/cert_pool.go
index 71ffbdf0e0..86e8cbe869 100644
--- a/src/crypto/x509/cert_pool.go
+++ b/src/crypto/x509/cert_pool.go
@@ -38,32 +38,16 @@ func SystemCertPool() (*CertPool, error) {
return loadSystemRoots()
}
-// findVerifiedParents attempts to find certificates in s which have signed the
-// given certificate. If any candidates were rejected then errCert will be set
-// to one of them, arbitrarily, and err will contain the reason that it was
-// rejected.
-func (s *CertPool) findVerifiedParents(cert *Certificate) (parents []int, errCert *Certificate, err error) {
+// findPotentialParents returns the indexes of certificates in s which might
+// have signed cert. The caller must not modify the returned slice.
+func (s *CertPool) findPotentialParents(cert *Certificate) []int {
if s == nil {
- return
+ return nil
}
- var candidates []int
-
if len(cert.AuthorityKeyId) > 0 {
- candidates = s.bySubjectKeyId[string(cert.AuthorityKeyId)]
- }
- if len(candidates) == 0 {
- candidates = s.byName[string(cert.RawIssuer)]
+ return s.bySubjectKeyId[string(cert.AuthorityKeyId)]
}
-
- for _, c := range candidates {
- if err = cert.CheckSignatureFrom(s.certs[c]); err == nil {
- parents = append(parents, c)
- } else {
- errCert = s.certs[c]
- }
- }
-
- return
+ return s.byName[string(cert.RawIssuer)]
}
func (s *CertPool) contains(cert *Certificate) bool {
diff --git a/src/crypto/x509/verify.go b/src/crypto/x509/verify.go
index 60e415b7ec..b70a97c952 100644
--- a/src/crypto/x509/verify.go
+++ b/src/crypto/x509/verify.go
@@ -765,7 +765,7 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e
if opts.Roots.contains(c) {
candidateChains = append(candidateChains, []*Certificate{c})
} else {
- if candidateChains, err = c.buildChains(make(map[int][][]*Certificate), []*Certificate{c}, &opts); err != nil {
+ if candidateChains, err = c.buildChains(nil, []*Certificate{c}, nil, &opts); err != nil {
return nil, err
}
}
@@ -802,58 +802,74 @@ func appendToFreshChain(chain []*Certificate, cert *Certificate) []*Certificate
return n
}
-func (c *Certificate) buildChains(cache map[int][][]*Certificate, currentChain []*Certificate, opts *VerifyOptions) (chains [][]*Certificate, err error) {
- possibleRoots, failedRoot, rootErr := opts.Roots.findVerifiedParents(c)
-nextRoot:
- for _, rootNum := range possibleRoots {
- root := opts.Roots.certs[rootNum]
+// maxChainSignatureChecks is the maximum number of CheckSignatureFrom calls
+// that an invocation of buildChains will (tranistively) make. Most chains are
+// less than 15 certificates long, so this leaves space for multiple chains and
+// for failed checks due to different intermediates having the same Subject.
+const maxChainSignatureChecks = 100
+func (c *Certificate) buildChains(cache map[*Certificate][][]*Certificate, currentChain []*Certificate, sigChecks *int, opts *VerifyOptions) (chains [][]*Certificate, err error) {
+ var (
+ hintErr error
+ hintCert *Certificate
+ )
+
+ considerCandidate := func(certType int, candidate *Certificate) {
for _, cert := range currentChain {
- if cert.Equal(root) {
- continue nextRoot
+ if cert.Equal(candidate) {
+ return
}
}
- err = root.isValid(rootCertificate, currentChain, opts)
- if err != nil {
- continue
+ if sigChecks == nil {
+ sigChecks = new(int)
+ }
+ *sigChecks++
+ if *sigChecks > maxChainSignatureChecks {
+ err = errors.New("x509: signature check attempts limit reached while verifying certificate chain")
+ return
}
- chains = append(chains, appendToFreshChain(currentChain, root))
- }
- possibleIntermediates, failedIntermediate, intermediateErr := opts.Intermediates.findVerifiedParents(c)
-nextIntermediate:
- for _, intermediateNum := range possibleIntermediates {
- intermediate := opts.Intermediates.certs[intermediateNum]
- for _, cert := range currentChain {
- if cert.Equal(intermediate) {
- continue nextIntermediate
+ if err := c.CheckSignatureFrom(candidate); err != nil {
+ if hintErr == nil {
+ hintErr = err
+ hintCert = candidate
}
+ return
}
- err = intermediate.isValid(intermediateCertificate, currentChain, opts)
+
+ err = candidate.isValid(certType, currentChain, opts)
if err != nil {
- continue
+ return
}
- var childChains [][]*Certificate
- childChains, ok := cache[intermediateNum]
- if !ok {
- childChains, err = intermediate.buildChains(cache, appendToFreshChain(currentChain, intermediate), opts)
- cache[intermediateNum] = childChains
+
+ switch certType {
+ case rootCertificate:
+ chains = append(chains, appendToFreshChain(currentChain, candidate))
+ case intermediateCertificate:
+ if cache == nil {
+ cache = make(map[*Certificate][][]*Certificate)
+ }
+ childChains, ok := cache[candidate]
+ if !ok {
+ childChains, err = candidate.buildChains(cache, appendToFreshChain(currentChain, candidate), sigChecks, opts)
+ cache[candidate] = childChains
+ }
+ chains = append(chains, childChains...)
}
- chains = append(chains, childChains...)
+ }
+
+ for _, rootNum := range opts.Roots.findPotentialParents(c) {
+ considerCandidate(rootCertificate, opts.Roots.certs[rootNum])
+ }
+ for _, intermediateNum := range opts.Intermediates.findPotentialParents(c) {
+ considerCandidate(intermediateCertificate, opts.Intermediates.certs[intermediateNum])
}
if len(chains) > 0 {
err = nil
}
-
if len(chains) == 0 && err == nil {
- hintErr := rootErr
- hintCert := failedRoot
- if hintErr == nil {
- hintErr = intermediateErr
- hintCert = failedIntermediate
- }
err = UnknownAuthorityError{c, hintErr, hintCert}
}
diff --git a/src/crypto/x509/verify_test.go b/src/crypto/x509/verify_test.go
index bd3df47907..2fc760d812 100644
--- a/src/crypto/x509/verify_test.go
+++ b/src/crypto/x509/verify_test.go
@@ -5,10 +5,15 @@
package x509
import (
+ "crypto"
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rand"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"fmt"
+ "math/big"
"runtime"
"strings"
"testing"
@@ -1707,3 +1712,117 @@ UNhY4JhezH9gQYqvDMWrWDAbBgNVHSMEFDASgBArF29S5Bnqw7de8GzGA1nfMAoG
CCqGSM49BAMCA0gAMEUCIQClA3d4tdrDu9Eb5ZBpgyC+fU1xTZB0dKQHz6M5fPZA
2AIgN96lM+CPGicwhN24uQI6flOsO3H0TJ5lNzBYLtnQtlc=
-----END CERTIFICATE-----`
+
+func generateCert(cn string, isCA bool, issuer *Certificate, issuerKey crypto.PrivateKey) (*Certificate, crypto.PrivateKey, error) {
+ priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
+ serialNumber, _ := rand.Int(rand.Reader, serialNumberLimit)
+
+ template := &Certificate{
+ SerialNumber: serialNumber,
+ Subject: pkix.Name{CommonName: cn},
+ NotBefore: time.Now().Add(-1 * time.Hour),
+ NotAfter: time.Now().Add(24 * time.Hour),
+
+ KeyUsage: KeyUsageKeyEncipherment | KeyUsageDigitalSignature | KeyUsageCertSign,
+ ExtKeyUsage: []ExtKeyUsage{ExtKeyUsageServerAuth},
+ BasicConstraintsValid: true,
+ IsCA: isCA,
+ }
+ if issuer == nil {
+ issuer = template
+ issuerKey = priv
+ }
+
+ derBytes, err := CreateCertificate(rand.Reader, template, issuer, priv.Public(), issuerKey)
+ if err != nil {
+ return nil, nil, err
+ }
+ cert, err := ParseCertificate(derBytes)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return cert, priv, nil
+}
+
+func TestPathologicalChain(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping generation of a long chain of certificates in short mode")
+ }
+
+ // Build a chain where all intermediates share the same subject, to hit the
+ // path building worst behavior.
+ roots, intermediates := NewCertPool(), NewCertPool()
+
+ parent, parentKey, err := generateCert("Root CA", true, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ roots.AddCert(parent)
+
+ for i := 1; i < 100; i++ {
+ parent, parentKey, err = generateCert("Intermediate CA", true, parent, parentKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+ intermediates.AddCert(parent)
+ }
+
+ leaf, _, err := generateCert("Leaf", false, parent, parentKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ start := time.Now()
+ _, err = leaf.Verify(VerifyOptions{
+ Roots: roots,
+ Intermediates: intermediates,
+ })
+ t.Logf("verification took %v", time.Since(start))
+
+ if err == nil || !strings.Contains(err.Error(), "signature check attempts limit") {
+ t.Errorf("expected verification to fail with a signature checks limit error; got %v", err)
+ }
+}
+
+func TestLongChain(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping generation of a long chain of certificates in short mode")
+ }
+
+ roots, intermediates := NewCertPool(), NewCertPool()
+
+ parent, parentKey, err := generateCert("Root CA", true, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ roots.AddCert(parent)
+
+ for i := 1; i < 15; i++ {
+ name := fmt.Sprintf("Intermediate CA #%d", i)
+ parent, parentKey, err = generateCert(name, true, parent, parentKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+ intermediates.AddCert(parent)
+ }
+
+ leaf, _, err := generateCert("Leaf", false, parent, parentKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ start := time.Now()
+ if _, err := leaf.Verify(VerifyOptions{
+ Roots: roots,
+ Intermediates: intermediates,
+ }); err != nil {
+ t.Error(err)
+ }
+ t.Logf("verification took %v", time.Since(start))
+}