From 0027ed1872cdec08defe3b097c7123eaaf149e30 Mon Sep 17 00:00:00 2001 From: Russ Cox Date: Wed, 9 Dec 2015 11:49:53 -0500 Subject: [release-branch.go1.5] math/big: fix carry propagation in Int.Exp Montgomery code Fixes #13515. Change-Id: I7dd5fbc816e5ea135f7d81f6735e7601f636fe4f Reviewed-on: https://go-review.googlesource.com/17672 Reviewed-by: Robert Griesemer Reviewed-on: https://go-review.googlesource.com/18585 --- src/math/big/nat.go | 29 +++++++++++---- src/math/big/nat_test.go | 95 +++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 103 insertions(+), 21 deletions(-) diff --git a/src/math/big/nat.go b/src/math/big/nat.go index 6545bc17ed..c7362e679b 100644 --- a/src/math/big/nat.go +++ b/src/math/big/nat.go @@ -216,23 +216,36 @@ func basicMul(z, x, y nat) { } } -// montgomery computes x*y*2^(-n*_W) mod m, -// assuming k = -1/m mod 2^_W. +// montgomery computes z mod m = x*y*2**(-n*_W) mod m, +// assuming k = -1/m mod 2**_W. // z is used for storing the result which is returned; // z must not alias x, y or m. +// See Gueron, "Efficient Software Implementations of Modular Exponentiation". +// https://eprint.iacr.org/2011/239.pdf +// In the terminology of that paper, this is an "Almost Montgomery Multiplication": +// x and y are required to satisfy 0 <= z < 2**(n*_W) and then the result +// z is guaranteed to satisfy 0 <= z < 2**(n*_W), but it may not be < m. func (z nat) montgomery(x, y, m nat, k Word, n int) nat { - var c1, c2 Word + // This code assumes x, y, m are all the same length, n. + // (required by addMulVVW and the for loop). + // It also assumes that x, y are already reduced mod m, + // or else the result will not be properly reduced. + if len(x) != n || len(y) != n || len(m) != n { + panic("math/big: mismatched montgomery number lengths") + } + var c1, c2, c3 Word z = z.make(n) z.clear() for i := 0; i < n; i++ { d := y[i] - c1 += addMulVVW(z, x, d) + c2 = addMulVVW(z, x, d) t := z[0] * k - c2 = addMulVVW(z, m, t) - + c3 = addMulVVW(z, m, t) copy(z, z[1:]) - z[n-1] = c1 + c2 - if z[n-1] < c1 { + cx := c1 + c2 + cy := cx + c3 + z[n-1] = cy + if cx < c2 || cy < c3 { c1 = 1 } else { c1 = 0 diff --git a/src/math/big/nat_test.go b/src/math/big/nat_test.go index 7ac3cb8a84..dce7de5dad 100644 --- a/src/math/big/nat_test.go +++ b/src/math/big/nat_test.go @@ -341,25 +341,57 @@ var montgomeryTests = []struct { "0xffffffffffffffffffffffffffffffffffffffffffffffffe", "0xffffffffffffffffffffffffffffffffffffffffffffffffe", "0xfffffffffffffffffffffffffffffffffffffffffffffffff", - 0x0000000000000000, - "0xffffffffffffffffffffffffffffffffffffffffff", - "0xffffffffffffffffffffffffffffffffff", + 1, + "0x1000000000000000000000000000000000000000000", + "0x10000000000000000000000000000000000", }, { - "0x0000000080000000", - "0x00000000ffffffff", + "0x000000000ffffff5", + "0x000000000ffffff0", "0x0000000010000001", 0xff0000000fffffff, - "0x0000000088000000", - "0x0000000007800001", + "0x000000000bfffff4", + "0x0000000003400001", + }, + { + "0x0000000080000000", + "0x00000000ffffffff", + "0x1000000000000001", + 0xfffffffffffffff, + "0x0800000008000001", + "0x0800000008000001", }, { - "0xffffffffffffffffffffffffffffffff00000000000022222223333333333444444444", - "0xffffffffffffffffffffffffffffffff999999999999999aaabbbbbbbbcccccccccccc", + "0x0000000080000000", + "0x0000000080000000", + "0xffffffff00000001", + 0xfffffffeffffffff, + "0xbfffffff40000001", + "0xbfffffff40000001", + }, + { + "0x0000000080000000", + "0x0000000080000000", + "0x00ffffff00000001", + 0xfffffeffffffff, + "0xbfffff40000001", + "0xbfffff40000001", + }, + { + "0x0000000080000000", + "0x0000000080000000", + "0x0000ffff00000001", + 0xfffeffffffff, + "0xbfff40000001", + "0xbfff40000001", + }, + { + "0x3321ffffffffffffffffffffffffffff00000000000022222623333333332bbbb888c0", + "0x3321ffffffffffffffffffffffffffff00000000000022222623333333332bbbb888c0", "0x33377fffffffffffffffffffffffffffffffffffffffffffff0000000000022222eee1", 0xdecc8f1249812adf, - "0x22bb05b6d95eaaeca2bb7c05e51f807bce9064b5fbad177161695e4558f9474e91cd79", - "0x14beb58d230f85b6d95eaaeca2bb7c05e51f807bce9064b5fb45669afa695f228e48cd", + "0x04eb0e11d72329dc0915f86784820fc403275bf2f6620a20e0dd344c5cd0875e50deb5", + "0x0d7144739a7d8e11d72329dc0915f86784820fc403275bf2f61ed96f35dd34dbb3d6a0", }, { "0x10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000ffffffffffffffffffffffffffffffff00000000000022222223333333333444444444", @@ -372,10 +404,27 @@ var montgomeryTests = []struct { } func TestMontgomery(t *testing.T) { + one := NewInt(1) + _B := new(Int).Lsh(one, _W) for i, test := range montgomeryTests { x := natFromString(test.x) y := natFromString(test.y) m := natFromString(test.m) + for len(x) < len(m) { + x = append(x, 0) + } + for len(y) < len(m) { + y = append(y, 0) + } + + if x.cmp(m) > 0 { + _, r := nat(nil).div(nil, x, m) + t.Errorf("#%d: x > m (0x%s > 0x%s; use 0x%s)", i, x.utoa(16), m.utoa(16), r.utoa(16)) + } + if y.cmp(m) > 0 { + _, r := nat(nil).div(nil, x, m) + t.Errorf("#%d: y > m (0x%s > 0x%s; use 0x%s)", i, y.utoa(16), m.utoa(16), r.utoa(16)) + } var out nat if _W == 32 { @@ -384,11 +433,31 @@ func TestMontgomery(t *testing.T) { out = natFromString(test.out64) } - k0 := Word(test.k0 & _M) // mask k0 to ensure that it fits for 32-bit systems. + // t.Logf("#%d: len=%d\n", i, len(m)) + + // check output in table + xi := &Int{abs: x} + yi := &Int{abs: y} + mi := &Int{abs: m} + p := new(Int).Mod(new(Int).Mul(xi, new(Int).Mul(yi, new(Int).ModInverse(new(Int).Lsh(one, uint(len(m))*_W), mi))), mi) + if out.cmp(p.abs.norm()) != 0 { + t.Errorf("#%d: out in table=0x%s, computed=0x%s", i, out.utoa(16), p.abs.norm().utoa(16)) + } + + // check k0 in table + k := new(Int).Mod(&Int{abs: m}, _B) + k = new(Int).Sub(_B, k) + k = new(Int).Mod(k, _B) + k0 := Word(new(Int).ModInverse(k, _B).Uint64()) + if k0 != Word(test.k0) { + t.Errorf("#%d: k0 in table=%#x, computed=%#x\n", i, test.k0, k0) + } + + // check montgomery with correct k0 produces correct output z := nat(nil).montgomery(x, y, m, k0, len(m)) z = z.norm() if z.cmp(out) != 0 { - t.Errorf("#%d got %s want %s", i, z.decimalString(), out.decimalString()) + t.Errorf("#%d: got 0x%s want 0x%s", i, z.utoa(16), out.utoa(16)) } } } -- cgit v1.2.3-54-g00ecf