aboutsummaryrefslogtreecommitdiff
path: root/src/math/big/calibrate_test.go
blob: 4fa663ff08331e850b0412edd336b388c7f35809 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Calibration used to determine thresholds for using
// different algorithms. Ideally, this would be converted
// to use go generate to create thresholds.go.

// This file prints execution times for the Mul benchmark
// given different Karatsuba thresholds. The result may be
// used to manually fine-tune the threshold constant. The
// results are somewhat fragile; use repeated runs to get
// a clear picture.

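// It also calculates the lower and upper thresholds of the range
// in which basicSqr is the fastest way to square: the lower one is
// where basicSqr starts to beat standard multiplication (mul(x, x)),
// the upper one is where karatsubaSqr starts to beat basicSqr.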

// Usage: go test -run=TestCalibrate -v -calibrate

package big

import (
	"flag"
	"fmt"
	"testing"
	"time"
)

var (
	// calibrate enables TestCalibrate. It defaults to false so that a plain
	// "go test" run returns from TestCalibrate immediately.
	calibrate = flag.Bool("calibrate", false, "run calibration test")
)

// Squaring modes understood by measureSqr. The string values double as the
// human-readable labels printed in the calibration output.
const (
	sqrModeMul       = "mul(x, x)"
	sqrModeBasic     = "basicSqr(x)"
	sqrModeKaratsuba = "karatsubaSqr(x)"
)

// TestCalibrate does nothing unless the -calibrate flag is given. When it is
// enabled, it first sweeps the Karatsuba multiplication threshold. It then
// searches two size ranges for the crossover points that define
// basicSqrThreshold and karatsubaSqrThreshold, printing the results to stdout.
func TestCalibrate(t *testing.T) {
	if !*calibrate {
		return
	}

	computeKaratsubaThresholds()

	// basicSqrThreshold: first size in [10, 30] where the faster of
	// mul(x, x) and basicSqr changes (the point where basicSqr's overhead
	// becomes negligible).
	basicTh := computeSqrThreshold(10, 30, 1, 3, sqrModeMul, sqrModeBasic)
	// karatsubaSqrThreshold: first size in [200, 500] (step 10) where the
	// faster of basicSqr and karatsubaSqr changes.
	karatsubaTh := computeSqrThreshold(200, 500, 10, 3, sqrModeBasic, sqrModeKaratsuba)

	// computeSqrThreshold reports "no crossover observed" as 0.
	switch basicTh {
	case 0:
		fmt.Println("no basicSqrThreshold found")
	default:
		fmt.Printf("found basicSqrThreshold = %d\n", basicTh)
	}

	switch karatsubaTh {
	case 0:
		fmt.Println("no karatsubaSqrThreshold found")
	default:
		fmt.Printf("found karatsubaSqrThreshold = %d\n", karatsubaTh)
	}
}

// karatsubaLoad is the workload timed by measureKaratsuba. It just runs
// BenchmarkMul with the supplied *testing.B.
func karatsubaLoad(b *testing.B) { BenchmarkMul(b) }

// measureKaratsuba benchmarks karatsubaLoad with karatsubaThreshold
// temporarily set to th and returns the measured time per operation.
// The previous value of karatsubaThreshold is restored before returning.
func measureKaratsuba(th int) time.Duration {
	saved := karatsubaThreshold
	karatsubaThreshold = th
	defer func() { karatsubaThreshold = saved }()

	return time.Duration(testing.Benchmark(karatsubaLoad).NsPerOp())
}

// computeKaratsubaThresholds prints the Mul benchmark time per operation for
// Karatsuba thresholds starting at 4, going up to 127 at most.
//
// Each line shows the improvement over plain (Karatsuba-disabled)
// multiplication, Tb, in percent. Two thresholds are marked:
//   - "break-even point": the first threshold whose time beats Tb;
//   - "diminishing return": the first threshold whose positive improvement
//     is smaller than the improvement at the previous threshold.
//
// Once both have been seen, extraRuns further thresholds are measured and
// the sweep stops.
func computeKaratsubaThresholds() {
	fmt.Printf("Multiplication times for varying Karatsuba thresholds\n")
	fmt.Printf("(run repeatedly for good results)\n")

	// Determine Tb, the work load execution time using basic multiplication.
	Tb := measureKaratsuba(1e9) // th == 1e9 => Karatsuba multiplication disabled
	fmt.Printf("Tb = %10s\n", Tb)
	if Tb <= 0 {
		// The relative improvement below divides by Tb.
		fmt.Println("Tb is not positive; cannot compute relative improvements")
		return
	}

	// Number of measurements to take after both thresholds were found.
	const extraRuns = 10

	th1 := -1       // break-even threshold; -1 while not yet found
	th2 := -1       // diminishing-return threshold; -1 while not yet found
	remaining := -1 // extra measurements left; -1 while still searching

	var deltaOld time.Duration
	for th := 4; th < 128 && remaining != 0; th++ {
		// Determine Tk, the work load execution time using Karatsuba multiplication.
		Tk := measureKaratsuba(th)

		// Improvement over Tb, in percent.
		delta := (Tb - Tk) * 100 / Tb

		fmt.Printf("th = %3d  Tk = %10s  %4d%%", th, Tk, delta)

		// Determine the break-even point.
		if Tk < Tb && th1 < 0 {
			th1 = th
			fmt.Print("  break-even point")
		}

		// Determine the point of diminishing return.
		if 0 < delta && delta < deltaOld && th2 < 0 {
			th2 = th
			fmt.Print("  diminishing return")
		}
		deltaOld = delta

		fmt.Println()

		// Once both thresholds are known, take exactly extraRuns more
		// measurements. The original post-decrement scheme only took 9.
		switch {
		case remaining > 0:
			remaining--
		case th1 >= 0 && th2 >= 0:
			remaining = extraRuns
		}
	}
}

// measureSqr returns the average time per operation of benchmarkNatSqr on
// operands of the given number of words.
//
// Before benchmarking, basicSqrThreshold and/or karatsubaSqrThreshold are
// moved around words according to mode (sqrModeMul, sqrModeBasic or
// sqrModeKaratsuba). The intent is to select the squaring algorithm named
// by mode. NOTE(review): this relies on how nat squaring compares operand
// length against these thresholds, which is defined elsewhere.
//
// The benchmark is repeated nruns times (at least once) and averaged. Both
// thresholds are restored before returning, even on panic. An unknown mode
// panics, since it would otherwise silently measure an arbitrary algorithm
// under the wrong label.
func measureSqr(words, nruns int, mode string) time.Duration {
	initBasicSqr, initKaratsubaSqr := basicSqrThreshold, karatsubaSqrThreshold
	defer func() {
		basicSqrThreshold, karatsubaSqrThreshold = initBasicSqr, initKaratsubaSqr
	}()

	switch mode {
	case sqrModeMul:
		basicSqrThreshold = words + 1
	case sqrModeBasic:
		basicSqrThreshold, karatsubaSqrThreshold = words-1, words+1
	case sqrModeKaratsuba:
		karatsubaSqrThreshold = words - 1
	default:
		panic("measureSqr: unknown mode " + mode)
	}

	// More runs give better statistics. Always do at least one so the
	// average below is well defined (nruns == 0 would divide by zero).
	if nruns < 1 {
		nruns = 1
	}
	var total int64
	for i := 0; i < nruns; i++ {
		res := testing.Benchmark(func(b *testing.B) { benchmarkNatSqr(b, words) })
		total += res.NsPerOp()
	}

	return time.Duration(total / int64(nruns))
}

// computeSqrThreshold scans operand sizes from..to (inclusive) in increments
// of step. At each size it times the lower and upper squaring modes with
// measureSqr, using nruns runs each.
//
// It prints one line per size. It returns the first size at which "upper is
// strictly faster than lower" differs from its value at size from, or 0 if
// no such change is observed. A step below 1 is treated as 1 so the scan
// always terminates.
func computeSqrThreshold(from, to, step, nruns int, lower, upper string) int {
	if step < 1 {
		step = 1 // a non-positive step would never advance past to
	}

	fmt.Printf("Calibrating threshold between %s and %s\n", lower, upper)
	fmt.Printf("Looking for a timing difference for x between %d - %d words by %d step\n", from, to, step)

	var initPos bool
	var threshold int
	for i := from; i <= to; i += step {
		baseline := measureSqr(i, nruns, lower)
		testval := measureSqr(i, nruns, upper)
		pos := baseline > testval
		delta := baseline - testval

		// Relative difference in percent. It is left at 0 when the baseline
		// time is 0 to avoid an integer divide-by-zero panic.
		var percent time.Duration
		if baseline != 0 {
			percent = delta * 100 / baseline
		}

		fmt.Printf("words = %3d deltaT = %10s (%4d%%) is %s better: %v", i, delta, percent, upper, pos)
		if i == from {
			initPos = pos
		}
		if threshold == 0 && pos != initPos {
			threshold = i
			fmt.Printf("  threshold  found")
		}
		fmt.Println()
	}

	if threshold != 0 {
		fmt.Printf("Found threshold = %d between %d - %d\n", threshold, from, to)
	} else {
		fmt.Printf("Found NO threshold between %d - %d\n", from, to)
	}
	return threshold
}