aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRobert Griesemer <gri@golang.org>2010-04-30 21:25:48 -0700
committerRobert Griesemer <gri@golang.org>2010-04-30 21:25:48 -0700
commit58e77990ba20494b8a4837f3668f8f03281e5406 (patch)
treeb4dc96c2c2b707063c3557bc6b2e92645dfb6f62
parent161b44c76a78ea2b2f04d8c3ba8c8292edac54b4 (diff)
downloadgo-58e77990ba20494b8a4837f3668f8f03281e5406.tar.gz
go-58e77990ba20494b8a4837f3668f8f03281e5406.zip
big: use fast shift routines
- fixed a couple of bugs in the process (shift right was incorrect for negative numbers) - added more tests and made some tests more robust - changed pidigits back to using shifts to multiply by 2 instead of add This improves pidigit -s -n 10000 by approx. 5%: user 0m6.496s (old) user 0m6.156s (new) R=rsc CC=golang-dev https://golang.org/cl/963044
-rw-r--r--src/pkg/big/int.go31
-rw-r--r--src/pkg/big/int_test.go8
-rw-r--r--src/pkg/big/nat.go62
-rw-r--r--src/pkg/big/nat_test.go36
-rw-r--r--test/bench/pidigits.go4
5 files changed, 95 insertions, 46 deletions
diff --git a/src/pkg/big/int.go b/src/pkg/big/int.go
index e5e589a852..2b7a628052 100644
--- a/src/pkg/big/int.go
+++ b/src/pkg/big/int.go
@@ -216,13 +216,16 @@ func (z *Int) SetString(s string, base int) (*Int, bool) {
if scanned != len(s) {
goto Error
}
+ if len(z.abs) == 0 {
+ z.neg = false // 0 has no sign
+ }
return z, true
Error:
z.neg = false
z.abs = nil
- return nil, false
+ return z, false
}
@@ -384,26 +387,24 @@ func ProbablyPrime(z *Int, n int) bool { return !z.neg && z.abs.probablyPrime(n)
// Lsh sets z = x << n and returns z.
func (z *Int) Lsh(x *Int, n uint) *Int {
- addedWords := int(n) / _W
- // Don't assign z.abs yet, in case z == x
- znew := z.abs.make(len(x.abs) + addedWords + 1)
z.neg = x.neg
- znew[addedWords:].shiftLeft(x.abs, n%_W)
- for i := range znew[0:addedWords] {
- znew[i] = 0
- }
- z.abs = znew.norm()
+ z.abs = z.abs.shl(x.abs, n)
return z
}
// Rsh sets z = x >> n and returns z.
func (z *Int) Rsh(x *Int, n uint) *Int {
- removedWords := int(n) / _W
- // Don't assign z.abs yet, in case z == x
- znew := z.abs.make(len(x.abs) - removedWords)
- z.neg = x.neg
- znew.shiftRight(x.abs[removedWords:], n%_W)
- z.abs = znew.norm()
+ if x.neg {
+ // (-x) >> s == ^(x-1) >> s == ^((x-1) >> s) == -(((x-1) >> s) + 1)
+ z.neg = true
+ t := z.abs.sub(x.abs, natOne) // no underflow because |x| > 0
+ t = t.shr(t, n)
+ z.abs = t.add(t, natOne)
+ return z
+ }
+
+ z.neg = false
+ z.abs = z.abs.shr(x.abs, n)
return z
}
diff --git a/src/pkg/big/int_test.go b/src/pkg/big/int_test.go
index cdcd28eac7..ceb31e069e 100644
--- a/src/pkg/big/int_test.go
+++ b/src/pkg/big/int_test.go
@@ -562,6 +562,7 @@ type intShiftTest struct {
var rshTests = []intShiftTest{
intShiftTest{"0", 0, "0"},
+ intShiftTest{"-0", 0, "0"},
intShiftTest{"0", 1, "0"},
intShiftTest{"0", 2, "0"},
intShiftTest{"1", 0, "1"},
@@ -569,7 +570,12 @@ var rshTests = []intShiftTest{
intShiftTest{"1", 2, "0"},
intShiftTest{"2", 0, "2"},
intShiftTest{"2", 1, "1"},
- intShiftTest{"2", 2, "0"},
+ intShiftTest{"-1", 0, "-1"},
+ intShiftTest{"-1", 1, "-1"},
+ intShiftTest{"-1", 10, "-1"},
+ intShiftTest{"-100", 2, "-25"},
+ intShiftTest{"-100", 3, "-13"},
+ intShiftTest{"-100", 100, "-1"},
intShiftTest{"4294967296", 0, "4294967296"},
intShiftTest{"4294967296", 1, "2147483648"},
intShiftTest{"4294967296", 2, "1073741824"},
diff --git a/src/pkg/big/nat.go b/src/pkg/big/nat.go
index 2db9e59f8e..ff8e806b24 100644
--- a/src/pkg/big/nat.go
+++ b/src/pkg/big/nat.go
@@ -554,8 +554,8 @@ func (z nat) divLarge(z2, uIn, v nat) (q, r nat) {
// D1.
shift := uint(leadingZeroBits(v[n-1]))
- v.shiftLeft(v, shift)
- u.shiftLeft(uIn, shift)
+ v.shiftLeftDeprecated(v, shift)
+ u.shiftLeftDeprecated(uIn, shift)
u[len(uIn)] = uIn[len(uIn)-1] >> (_W - uint(shift))
// D2.
@@ -597,8 +597,8 @@ func (z nat) divLarge(z2, uIn, v nat) (q, r nat) {
}
q = q.norm()
- u.shiftRight(u, shift)
- v.shiftRight(v, shift)
+ u.shiftRightDeprecated(u, shift)
+ v.shiftRightDeprecated(v, shift)
r = u.norm()
return q, r
@@ -780,12 +780,56 @@ func trailingZeroBits(x Word) int {
}
-// TODO(gri) Make the shift routines faster.
-// Use pidigits.go benchmark as a test case.
+// z = x << s
+func (z nat) shl(x nat, s uint) nat {
+ m := len(x)
+ if m == 0 {
+ return z.make(0)
+ }
+ // m > 0
+
+ // determine if z can be reused
+ // TODO(gri) change shlVW so we don't need this
+ if len(z) > 0 && alias(z, x) {
+ z = nil // z is an alias for x - cannot reuse
+ }
+
+ n := m + int(s/_W)
+ z = z.make(n + 1)
+ z[n] = shlVW(&z[n-m], &x[0], Word(s%_W), m)
+
+ return z.norm()
+}
+
+
+// z = x >> s
+func (z nat) shr(x nat, s uint) nat {
+ m := len(x)
+ n := m - int(s/_W)
+ if n <= 0 {
+ return z.make(0)
+ }
+ // n > 0
+ // determine if z can be reused
+ // TODO(gri) change shrVW so we don't need this
+ if len(z) > 0 && alias(z, x) {
+ z = nil // z is an alias for x - cannot reuse
+ }
+
+ z = z.make(n)
+ shrVW(&z[0], &x[m-n], Word(s%_W), m)
+
+ return z.norm()
+}
+
+
+// TODO(gri) Remove these shift functions once shlVW and shrVW can be
+// used directly in divLarge and powersOfTwoDecompose
+//
// To avoid losing the top n bits, z should be sized so that
// len(z) == len(x) + 1.
-func (z nat) shiftLeft(x nat, n uint) nat {
+func (z nat) shiftLeftDeprecated(x nat, n uint) nat {
if len(x) == 0 {
return x
}
@@ -805,7 +849,7 @@ func (z nat) shiftLeft(x nat, n uint) nat {
}
-func (z nat) shiftRight(x nat, n uint) nat {
+func (z nat) shiftRightDeprecated(x nat, n uint) nat {
if len(x) == 0 {
return x
}
@@ -850,7 +894,7 @@ func (n nat) powersOfTwoDecompose() (q nat, k Word) {
x := trailingZeroBits(n[zeroWords])
q = q.make(len(n) - zeroWords)
- q.shiftRight(n[zeroWords:], uint(x))
+ q.shiftRightDeprecated(n[zeroWords:], uint(x))
q = q.norm()
k = Word(_W*zeroWords + x)
diff --git a/src/pkg/big/nat_test.go b/src/pkg/big/nat_test.go
index e1039c48a1..bf637b0daa 100644
--- a/src/pkg/big/nat_test.go
+++ b/src/pkg/big/nat_test.go
@@ -230,9 +230,8 @@ type shiftTest struct {
var leftShiftTests = []shiftTest{
shiftTest{nil, 0, nil},
shiftTest{nil, 1, nil},
- shiftTest{nat{0}, 0, nat{0}},
- shiftTest{nat{1}, 0, nat{1}},
- shiftTest{nat{1}, 1, nat{2}},
+ shiftTest{natOne, 0, natOne},
+ shiftTest{natOne, 1, natTwo},
shiftTest{nat{1 << (_W - 1)}, 1, nat{0}},
shiftTest{nat{1 << (_W - 1), 0}, 1, nat{0, 1}},
}
@@ -240,11 +239,11 @@ var leftShiftTests = []shiftTest{
func TestShiftLeft(t *testing.T) {
for i, test := range leftShiftTests {
- dst := make(nat, len(test.out))
- dst.shiftLeft(test.in, test.shift)
- for j, v := range dst {
- if test.out[j] != v {
- t.Errorf("#%d: got: %v want: %v", i, dst, test.out)
+ var z nat
+ z = z.shl(test.in, test.shift)
+ for j, d := range test.out {
+ if j >= len(z) || z[j] != d {
+ t.Errorf("#%d: got: %v want: %v", i, z, test.out)
break
}
}
@@ -255,22 +254,21 @@ func TestShiftLeft(t *testing.T) {
var rightShiftTests = []shiftTest{
shiftTest{nil, 0, nil},
shiftTest{nil, 1, nil},
- shiftTest{nat{0}, 0, nat{0}},
- shiftTest{nat{1}, 0, nat{1}},
- shiftTest{nat{1}, 1, nat{0}},
- shiftTest{nat{2}, 1, nat{1}},
- shiftTest{nat{0, 1}, 1, nat{1 << (_W - 1), 0}},
- shiftTest{nat{2, 1, 1}, 1, nat{1<<(_W-1) + 1, 1 << (_W - 1), 0}},
+ shiftTest{natOne, 0, natOne},
+ shiftTest{natOne, 1, nil},
+ shiftTest{natTwo, 1, natOne},
+ shiftTest{nat{0, 1}, 1, nat{1 << (_W - 1)}},
+ shiftTest{nat{2, 1, 1}, 1, nat{1<<(_W-1) + 1, 1 << (_W - 1)}},
}
func TestShiftRight(t *testing.T) {
for i, test := range rightShiftTests {
- dst := make(nat, len(test.out))
- dst.shiftRight(test.in, test.shift)
- for j, v := range dst {
- if test.out[j] != v {
- t.Errorf("#%d: got: %v want: %v", i, dst, test.out)
+ var z nat
+ z = z.shr(test.in, test.shift)
+ for j, d := range test.out {
+ if j >= len(z) || z[j] != d {
+ t.Errorf("#%d: got: %v want: %v", i, z, test.out)
break
}
}
diff --git a/test/bench/pidigits.go b/test/bench/pidigits.go
index 3e455dc838..a05515028a 100644
--- a/test/bench/pidigits.go
+++ b/test/bench/pidigits.go
@@ -63,7 +63,7 @@ func extract_digit() int64 {
}
// Compute (numer * 3 + accum) / denom
- tmp1.Add(numer, numer) // tmp1.Lsh(numer, 1)
+ tmp1.Lsh(numer, 1)
tmp1.Add(tmp1, numer)
tmp1.Add(tmp1, accum)
tmp1.DivMod(tmp1, denom, tmp2)
@@ -84,7 +84,7 @@ func next_term(k int64) {
y2.New(k*2 + 1)
bigk.New(k)
- tmp1.Add(numer, numer) // tmp1.Lsh(numer, 1)
+ tmp1.Lsh(numer, 1)
accum.Add(accum, tmp1)
accum.Mul(accum, y2)
numer.Mul(numer, bigk)