diff options
author | Robert Griesemer <gri@golang.org> | 2010-04-30 21:25:48 -0700 |
---|---|---|
committer | Robert Griesemer <gri@golang.org> | 2010-04-30 21:25:48 -0700 |
commit | 58e77990ba20494b8a4837f3668f8f03281e5406 (patch) | |
tree | b4dc96c2c2b707063c3557bc6b2e92645dfb6f62 | |
parent | 161b44c76a78ea2b2f04d8c3ba8c8292edac54b4 (diff) | |
download | go-58e77990ba20494b8a4837f3668f8f03281e5406.tar.gz go-58e77990ba20494b8a4837f3668f8f03281e5406.zip |
big: use fast shift routines
- fixed a couple of bugs in the process
(shift right was incorrect for negative numbers)
- added more tests and made some tests more robust
- changed pidigits back to using shifts to multiply
by 2 instead of add
This improves pidigit -s -n 10000 by approx. 5%:
user 0m6.496s (old)
user 0m6.156s (new)
R=rsc
CC=golang-dev
https://golang.org/cl/963044
-rw-r--r-- | src/pkg/big/int.go | 31 | ||||
-rw-r--r-- | src/pkg/big/int_test.go | 8 | ||||
-rw-r--r-- | src/pkg/big/nat.go | 62 | ||||
-rw-r--r-- | src/pkg/big/nat_test.go | 36 | ||||
-rw-r--r-- | test/bench/pidigits.go | 4 |
5 files changed, 95 insertions, 46 deletions
diff --git a/src/pkg/big/int.go b/src/pkg/big/int.go index e5e589a852..2b7a628052 100644 --- a/src/pkg/big/int.go +++ b/src/pkg/big/int.go @@ -216,13 +216,16 @@ func (z *Int) SetString(s string, base int) (*Int, bool) { if scanned != len(s) { goto Error } + if len(z.abs) == 0 { + z.neg = false // 0 has no sign + } return z, true Error: z.neg = false z.abs = nil - return nil, false + return z, false } @@ -384,26 +387,24 @@ func ProbablyPrime(z *Int, n int) bool { return !z.neg && z.abs.probablyPrime(n) // Lsh sets z = x << n and returns z. func (z *Int) Lsh(x *Int, n uint) *Int { - addedWords := int(n) / _W - // Don't assign z.abs yet, in case z == x - znew := z.abs.make(len(x.abs) + addedWords + 1) z.neg = x.neg - znew[addedWords:].shiftLeft(x.abs, n%_W) - for i := range znew[0:addedWords] { - znew[i] = 0 - } - z.abs = znew.norm() + z.abs = z.abs.shl(x.abs, n) return z } // Rsh sets z = x >> n and returns z. func (z *Int) Rsh(x *Int, n uint) *Int { - removedWords := int(n) / _W - // Don't assign z.abs yet, in case z == x - znew := z.abs.make(len(x.abs) - removedWords) - z.neg = x.neg - znew.shiftRight(x.abs[removedWords:], n%_W) - z.abs = znew.norm() + if x.neg { + // (-x) >> s == ^(x-1) >> s == ^((x-1) >> s) == -(((x-1) >> s) + 1) + z.neg = true + t := z.abs.sub(x.abs, natOne) // no underflow because |x| > 0 + t = t.shr(t, n) + z.abs = t.add(t, natOne) + return z + } + + z.neg = false + z.abs = z.abs.shr(x.abs, n) return z } diff --git a/src/pkg/big/int_test.go b/src/pkg/big/int_test.go index cdcd28eac7..ceb31e069e 100644 --- a/src/pkg/big/int_test.go +++ b/src/pkg/big/int_test.go @@ -562,6 +562,7 @@ type intShiftTest struct { var rshTests = []intShiftTest{ intShiftTest{"0", 0, "0"}, + intShiftTest{"-0", 0, "0"}, intShiftTest{"0", 1, "0"}, intShiftTest{"0", 2, "0"}, intShiftTest{"1", 0, "1"}, @@ -569,7 +570,12 @@ var rshTests = []intShiftTest{ intShiftTest{"1", 2, "0"}, intShiftTest{"2", 0, "2"}, intShiftTest{"2", 1, "1"}, - intShiftTest{"2", 2, "0"}, + intShiftTest{"-1", 0, "-1"}, + intShiftTest{"-1", 1, "-1"}, + intShiftTest{"-1", 10, "-1"}, + intShiftTest{"-100", 2, "-25"}, + intShiftTest{"-100", 3, "-13"}, + intShiftTest{"-100", 100, "-1"}, intShiftTest{"4294967296", 0, "4294967296"}, intShiftTest{"4294967296", 1, "2147483648"}, intShiftTest{"4294967296", 2, "1073741824"}, diff --git a/src/pkg/big/nat.go b/src/pkg/big/nat.go index 2db9e59f8e..ff8e806b24 100644 --- a/src/pkg/big/nat.go +++ b/src/pkg/big/nat.go @@ -554,8 +554,8 @@ func (z nat) divLarge(z2, uIn, v nat) (q, r nat) { // D1. shift := uint(leadingZeroBits(v[n-1])) - v.shiftLeft(v, shift) - u.shiftLeft(uIn, shift) + v.shiftLeftDeprecated(v, shift) + u.shiftLeftDeprecated(uIn, shift) u[len(uIn)] = uIn[len(uIn)-1] >> (_W - uint(shift)) // D2. @@ -597,8 +597,8 @@ func (z nat) divLarge(z2, uIn, v nat) (q, r nat) { } q = q.norm() - u.shiftRight(u, shift) - v.shiftRight(v, shift) + u.shiftRightDeprecated(u, shift) + v.shiftRightDeprecated(v, shift) r = u.norm() return q, r @@ -780,12 +780,56 @@ func trailingZeroBits(x Word) int { } -// TODO(gri) Make the shift routines faster. -// Use pidigits.go benchmark as a test case. +// z = x << s +func (z nat) shl(x nat, s uint) nat { + m := len(x) + if m == 0 { + return z.make(0) + } + // m > 0 + + // determine if z can be reused + // TODO(gri) change shlVW so we don't need this + if len(z) > 0 && alias(z, x) { + z = nil // z is an alias for x - cannot reuse + } + + n := m + int(s/_W) + z = z.make(n + 1) + z[n] = shlVW(&z[n-m], &x[0], Word(s%_W), m) + + return z.norm() +} + + +// z = x >> s +func (z nat) shr(x nat, s uint) nat { + m := len(x) + n := m - int(s/_W) + if n <= 0 { + return z.make(0) + } + // n > 0 + // determine if z can be reused + // TODO(gri) change shrVW so we don't need this + if len(z) > 0 && alias(z, x) { + z = nil // z is an alias for x - cannot reuse + } + + z = z.make(n) + shrVW(&z[0], &x[m-n], Word(s%_W), m) + + return z.norm() +} + + +// TODO(gri) Remove these shift functions once shlVW and shrVW can be +// used directly in divLarge and powersOfTwoDecompose +// // To avoid losing the top n bits, z should be sized so that // len(z) == len(x) + 1. -func (z nat) shiftLeft(x nat, n uint) nat { +func (z nat) shiftLeftDeprecated(x nat, n uint) nat { if len(x) == 0 { return x } @@ -805,7 +849,7 @@ func (z nat) shiftLeft(x nat, n uint) nat { } -func (z nat) shiftRight(x nat, n uint) nat { +func (z nat) shiftRightDeprecated(x nat, n uint) nat { if len(x) == 0 { return x } @@ -850,7 +894,7 @@ func (n nat) powersOfTwoDecompose() (q nat, k Word) { x := trailingZeroBits(n[zeroWords]) q = q.make(len(n) - zeroWords) - q.shiftRight(n[zeroWords:], uint(x)) + q.shiftRightDeprecated(n[zeroWords:], uint(x)) q = q.norm() k = Word(_W*zeroWords + x) diff --git a/src/pkg/big/nat_test.go b/src/pkg/big/nat_test.go index e1039c48a1..bf637b0daa 100644 --- a/src/pkg/big/nat_test.go +++ b/src/pkg/big/nat_test.go @@ -230,9 +230,8 @@ type shiftTest struct { var leftShiftTests = []shiftTest{ shiftTest{nil, 0, nil}, shiftTest{nil, 1, nil}, - shiftTest{nat{0}, 0, nat{0}}, - shiftTest{nat{1}, 0, nat{1}}, - shiftTest{nat{1}, 1, nat{2}}, + shiftTest{natOne, 0, natOne}, + shiftTest{natOne, 1, natTwo}, shiftTest{nat{1 << (_W - 1)}, 1, nat{0}}, shiftTest{nat{1 << (_W - 1), 0}, 1, nat{0, 1}}, } @@ -240,11 +239,11 @@ var leftShiftTests = []shiftTest{ func TestShiftLeft(t *testing.T) { for i, test := range leftShiftTests { - dst := make(nat, len(test.out)) - dst.shiftLeft(test.in, test.shift) - for j, v := range dst { - if test.out[j] != v { - t.Errorf("#%d: got: %v want: %v", i, dst, test.out) + var z nat + z = z.shl(test.in, test.shift) + for j, d := range test.out { + if j >= len(z) || z[j] != d { + t.Errorf("#%d: got: %v want: %v", i, z, test.out) break } } @@ -255,22 +254,21 @@ func TestShiftLeft(t *testing.T) { var rightShiftTests = []shiftTest{ shiftTest{nil, 0, nil}, shiftTest{nil, 1, nil}, - shiftTest{nat{0}, 0, nat{0}}, - shiftTest{nat{1}, 0, nat{1}}, - shiftTest{nat{1}, 1, nat{0}}, - shiftTest{nat{2}, 1, nat{1}}, - shiftTest{nat{0, 1}, 1, nat{1 << (_W - 1), 0}}, - shiftTest{nat{2, 1, 1}, 1, nat{1<<(_W-1) + 1, 1 << (_W - 1), 0}}, + shiftTest{natOne, 0, natOne}, + shiftTest{natOne, 1, nil}, + shiftTest{natTwo, 1, natOne}, + shiftTest{nat{0, 1}, 1, nat{1 << (_W - 1)}}, + shiftTest{nat{2, 1, 1}, 1, nat{1<<(_W-1) + 1, 1 << (_W - 1)}}, } func TestShiftRight(t *testing.T) { for i, test := range rightShiftTests { - dst := make(nat, len(test.out)) - dst.shiftRight(test.in, test.shift) - for j, v := range dst { - if test.out[j] != v { - t.Errorf("#%d: got: %v want: %v", i, dst, test.out) + var z nat + z = z.shr(test.in, test.shift) + for j, d := range test.out { + if j >= len(z) || z[j] != d { + t.Errorf("#%d: got: %v want: %v", i, z, test.out) break } } diff --git a/test/bench/pidigits.go b/test/bench/pidigits.go index 3e455dc838..a05515028a 100644 --- a/test/bench/pidigits.go +++ b/test/bench/pidigits.go @@ -63,7 +63,7 @@ func extract_digit() int64 { } // Compute (numer * 3 + accum) / denom - tmp1.Add(numer, numer) // tmp1.Lsh(numer, 1) + tmp1.Lsh(numer, 1) tmp1.Add(tmp1, numer) tmp1.Add(tmp1, accum) tmp1.DivMod(tmp1, denom, tmp2) @@ -84,7 +84,7 @@ func next_term(k int64) { y2.New(k*2 + 1) bigk.New(k) - tmp1.Add(numer, numer) // tmp1.Lsh(numer, 1) + tmp1.Lsh(numer, 1) accum.Add(accum, tmp1) accum.Mul(accum, y2) numer.Mul(numer, bigk) |