aboutsummaryrefslogtreecommitdiff
path: root/src/cmd/compile/internal/ssa/loopbce.go
blob: fd03efb417db73d916657e1a2f58902e6e3c63c3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ssa

import (
	"fmt"
	"math"
)

// indVarFlags is a bit set recording, for each bound of an indVar, whether
// that bound deviates from the default (inclusive minimum, exclusive maximum).
type indVarFlags uint8

const (
	indVarMinExc indVarFlags = 1 << iota // minimum value is exclusive (default: inclusive)
	indVarMaxInc                         // maximum value is inclusive (default: exclusive)
)

// indVar describes one induction variable recognized by findIndVar,
// together with the bounds that hold for it inside the loop body.
type indVar struct {
	ind   *Value // induction variable
	min   *Value // minimum value, inclusive/exclusive depends on flags
	max   *Value // maximum value, inclusive/exclusive depends on flags
	entry *Block // entry block in the loop.
	flags indVarFlags
	// Invariant: for all blocks strictly dominated by entry:
	//	min <= ind <  max    [if flags == 0]
	//	min <  ind <  max    [if flags == indVarMinExc]
	//	min <= ind <= max    [if flags == indVarMaxInc]
	//	min <  ind <= max    [if flags == indVarMinExc|indVarMaxInc]
}

// parseIndVar checks whether the SSA value passed as argument is a valid induction
// variable, and, if so, extracts:
//   * the minimum bound
//   * the increment value
//   * the "next" value (SSA value that is Phi'd into the induction variable every loop)
// Currently, we detect induction variables that match (Phi min nxt),
// with nxt being (Add inc ind). Only 64-bit additions (OpAdd64) are recognized.
//
// Only phis with exactly two inputs are accepted: a phi with any additional
// input could take a value that is neither min nor nxt, so the bounds derived
// from min would not be sound for it.
//
// If it can't parse the induction variable correctly, it returns (nil, nil, nil).
func parseIndVar(ind *Value) (min, inc, nxt *Value) {
	if ind.Op != OpPhi {
		return nil, nil, nil
	}
	// Guard the Args[0]/Args[1] accesses below and reject phis that merge
	// more than the (min, nxt) pair.
	if len(ind.Args) != 2 {
		return nil, nil, nil
	}

	// isNext reports whether n has the shape (Add64 ind x) or (Add64 x ind).
	isNext := func(n *Value) bool {
		return n.Op == OpAdd64 && (n.Args[0] == ind || n.Args[1] == ind)
	}

	// Args[0] is tried first, matching the historical matching order.
	switch a0, a1 := ind.Args[0], ind.Args[1]; {
	case isNext(a0):
		min, nxt = a1, a0
	case isNext(a1):
		min, nxt = a0, a1
	default:
		// Not a recognized induction variable.
		return nil, nil, nil
	}

	// isNext guarantees that at least one operand of nxt is ind; the other
	// operand is the increment. When both operands are ind, Args[1] is used.
	if nxt.Args[0] == ind { // nxt = ind + inc
		inc = nxt.Args[1]
	} else { // nxt = inc + ind
		inc = nxt.Args[0]
	}

	return min, inc, nxt
}

// findIndVar finds induction variables in a function.
//
// Look for variables and blocks that satisfy the following
//
// loop:
//   ind = (Phi min nxt),
//   if ind < max
//     then goto enter_loop
//     else goto exit_loop
//
//   enter_loop:
//	do something
//      nxt = inc + ind
//	goto loop
//
// exit_loop:
//
// Only signed 64-bit comparisons (OpLess64, OpLeq64) with a nonzero constant
// increment are recognized. Each accepted variable is returned as an indVar
// whose entry is the first successor of the loop header block.
//
// TODO: handle 32 bit operations
func findIndVar(f *Func) []indVar {
	var iv []indVar
	sdom := f.Sdom()

	for _, b := range f.Blocks {
		// Candidate loop headers are If blocks with exactly two predecessors.
		if b.Kind != BlockIf || len(b.Preds) != 2 {
			continue
		}

		var flags indVarFlags
		var ind, max *Value // induction, and maximum

		// Check that the control is either ind </<= max or max >/>= ind.
		// TODO: Handle 32-bit comparisons.
		// TODO: Handle unsigned comparisons?
		c := b.Controls[0]
		switch c.Op {
		case OpLeq64:
			flags |= indVarMaxInc
			fallthrough
		case OpLess64:
			ind, max = c.Args[0], c.Args[1]
		default:
			continue
		}

		// See if this is really an induction variable
		less := true
		min, inc, nxt := parseIndVar(ind)
		if min == nil {
			// We failed to parse the induction variable. Before punting, we want to check
			// whether the control op was written with arguments in non-idiomatic order,
			// so that what we believe to be "max" (the upper bound) is actually the induction
			// variable itself. This would happen for code like:
			//     for i := 0; len(n) > i; i++
			min, inc, nxt = parseIndVar(max)
			if min == nil {
				// No recognized induction variable on either operand
				continue
			}

			// Ok, the arguments were reversed. Swap them, and remember that we're
			// looking at a ind >/>= loop (so the induction must be decrementing).
			ind, max = max, ind
			less = false
		}

		// Expect the increment to be a nonzero constant.
		if inc.Op != OpConst64 {
			continue
		}
		step := inc.AuxInt
		if step == 0 {
			continue
		}

		// Increment sign must match comparison direction.
		// When incrementing, the termination comparison must be ind </<= max.
		// When decrementing, the termination comparison must be ind >/>= max.
		// See issue 26116.
		if step > 0 && !less {
			continue
		}
		if step < 0 && less {
			continue
		}

		// If the increment is negative, swap min/max and their flags.
		// The initial value becomes the (always inclusive) maximum, and the
		// comparison limit becomes the minimum, exclusive unless the
		// comparison itself was inclusive. From here on step is positive.
		if step < 0 {
			min, max = max, min
			oldf := flags
			flags = indVarMaxInc
			if oldf&indVarMaxInc == 0 {
				flags |= indVarMinExc
			}
			step = -step
		}

		if flags&indVarMaxInc != 0 && max.Op == OpConst64 && max.AuxInt+step < max.AuxInt {
			// For a <= comparison, we need to make sure that a value equal to
			// max can be incremented without overflowing.
			// (For a < comparison, the %step check below ensures no overflow.)
			continue
		}

		// Up to now we extracted the induction variable (ind),
		// the increment delta (inc), the temporary sum (nxt),
		// the minimum value (min) and the maximum value (max).
		//
		// We also know that ind has the form (Phi min nxt) where
		// nxt is (Add inc ind) which means: 1) ind dominates nxt
		// and 2) there is a loop starting at ind and containing nxt.
		//
		// We need to prove that the induction variable is incremented
		// only when it's smaller than the maximum value.
		// Two conditions must happen listed below to accept ind
		// as an induction variable.

		// First condition: loop entry has a single predecessor, which
		// is the header block.  This implies that b.Succs[0] is
		// reached iff ind < max.
		if len(b.Succs[0].b.Preds) != 1 {
			// b.Succs[1] must exit the loop.
			continue
		}

		// Second condition: b.Succs[0] dominates nxt so that
		// nxt is computed when ind < max, meaning nxt <= max.
		if !sdom.IsAncestorEq(b.Succs[0].b, nxt.Block) {
			// inc+ind can only be reached through the branch that enters the loop.
			continue
		}

		// We can only guarantee that the loop runs within limits of induction variable
		// if (one of)
		// (1) the increment is ±1
		// (2) the limits are constants
		// (3) loop is of the form k0 upto Known_not_negative-k inclusive, step <= k
		// (4) loop is of the form k0 upto Known_not_negative-k exclusive, step <= k+1
		// (5) loop is of the form Known_not_negative downto k0, minint+step < k0
		if step > 1 {
			ok := false
			// Case (2): both limits are constants and congruent modulo step.
			if min.Op == OpConst64 && max.Op == OpConst64 {
				if max.AuxInt > min.AuxInt && max.AuxInt%step == min.AuxInt%step { // handle overflow
					ok = true
				}
			}
			// Handle induction variables of these forms.
			// KNN is known-not-negative.
			// SIGNED ARITHMETIC ONLY. (see switch on c above)
			// Possibilities for KNN are len and cap; perhaps we can infer others.
			// for i := 0; i <= KNN-k    ; i += k
			// for i := 0; i <  KNN-(k-1); i += k
			// Also handle decreasing.

			// "Proof" copied from https://go-review.googlesource.com/c/go/+/104041/10/src/cmd/compile/internal/ssa/loopbce.go#164
			//
			//	In the case of
			//	// PC is Positive Constant
			//	L := len(A)-PC
			//	for i := 0; i < L; i = i+PC
			//
			//	we know:
			//
			//	0 + PC does not over/underflow.
			//	len(A)-PC does not over/underflow
			//	maximum value for L is MaxInt-PC
			//	i < L <= MaxInt-PC means i + PC < MaxInt hence no overflow.

			// To match in SSA:
			// if  (a) min.Op == OpConst64(k0)
			// and (b) k0 >= MININT + step
			// and (c) max.Op == OpSubtract(Op{StringLen,SliceLen,SliceCap}, k)
			// or  (c) max.Op == OpAdd(Op{StringLen,SliceLen,SliceCap}, -k)
			// or  (c) max.Op == Op{StringLen,SliceLen,SliceCap}
			// and (d) if upto loop, require indVarMaxInc && step <= k or !indVarMaxInc && step-1 <= k

			if min.Op == OpConst64 && min.AuxInt >= step+math.MinInt64 {
				// Decompose max into knn (the len/cap value) and the constant k
				// subtracted from it; a bare len/cap leaves k == 0.
				knn := max
				k := int64(0)
				var kArg *Value

				switch max.Op {
				case OpSub64:
					knn = max.Args[0]
					kArg = max.Args[1]

				case OpAdd64:
					knn = max.Args[0]
					kArg = max.Args[1]
					// Addition is commutative: the constant may be either operand.
					if knn.Op == OpConst64 {
						knn, kArg = kArg, knn
					}
				}
				switch knn.Op {
				case OpSliceLen, OpStringLen, OpSliceCap:
				default:
					knn = nil
				}

				if kArg != nil && kArg.Op == OpConst64 {
					k = kArg.AuxInt
					// knn + c is knn - (-c), so negate to get the subtracted amount.
					if max.Op == OpAdd64 {
						k = -k
					}
				}
				if k >= 0 && knn != nil {
					if inc.AuxInt > 0 { // increasing iteration
						// The concern for the relation between step and k is to ensure that iv never exceeds knn
						// i.e., iv < knn-(K-1) ==> iv + K <= knn; iv <= knn-K ==> iv +K < knn
						if step <= k || flags&indVarMaxInc == 0 && step-1 == k {
							ok = true
						}
					} else { // decreasing iteration
						// Will be decrementing from max towards min; max is knn-k; will only attempt decrement if
						// knn-k >[=] min; underflow is only a concern if min-step is not smaller than min.
						// This all assumes signed integer arithmetic
						// This is already assured by the test above: min.AuxInt >= step+math.MinInt64
						ok = true
					}
				}
			}

			// TODO: other unrolling idioms
			// for i := 0; i < KNN - KNN % k ; i += k
			// for i := 0; i < KNN&^(k-1) ; i += k // k a power of 2
			// for i := 0; i < KNN&(-k) ; i += k // k a power of 2

			if !ok {
				continue
			}
		}

		if f.pass.debug >= 1 {
			printIndVar(b, ind, min, max, step, flags)
		}

		iv = append(iv, indVar{
			ind:   ind,
			min:   min,
			max:   max,
			entry: b.Succs[0].b,
			flags: flags,
		})
		b.Logf("found induction variable %v (inc = %v, min = %v, max = %v)\n", ind, inc, min, max)
	}

	return iv
}

// dropAdd64 strips a constant addend from v. If v is an OpAdd64 with an
// OpConst64 operand, it returns the other operand together with the
// constant's AuxInt; the first operand is checked before the second.
// For any other value it returns v itself and 0.
func dropAdd64(v *Value) (*Value, int64) {
	if v.Op != OpAdd64 {
		return v, 0
	}

	lhs, rhs := v.Args[0], v.Args[1]
	switch {
	case lhs.Op == OpConst64:
		return rhs, lhs.AuxInt
	case rhs.Op == OpConst64:
		return lhs, rhs.AuxInt
	}
	return v, 0
}

// printIndVar reports an induction variable through b.Func.Warnl at b.Pos.
// The message shows the limits in interval notation: "(" for an exclusive
// minimum, "[" for an inclusive one, "]" for an inclusive maximum and ")"
// for an exclusive one, followed by the increment inc.
//
// A limit that is a generic integer constant is printed as its AuxInt.
// Any other limit is printed as "?", or as the value itself when the pass
// debug level is at least 2; at that level the induction variable i is
// appended to the message as well.
func printIndVar(b *Block, i, min, max *Value, inc int64, flags indVarFlags) {
	verbose := b.Func.pass.debug >= 2

	// describe renders a single limit according to the rules above.
	describe := func(lim *Value) string {
		switch {
		case lim.isGenericIntConst():
			return fmt.Sprint(lim.AuxInt)
		case verbose:
			return fmt.Sprint(lim)
		default:
			return "?"
		}
	}

	opening := "["
	if flags&indVarMinExc != 0 {
		opening = "("
	}
	closing := ")"
	if flags&indVarMaxInc != 0 {
		closing = "]"
	}

	lo := describe(min)
	hi := describe(max)

	var extra string
	if verbose {
		extra = fmt.Sprintf(" (%s)", i)
	}
	b.Func.Warnl(b.Pos, "Induction variable: limits %v%v,%v%v, increment %d%s", opening, lo, hi, closing, inc, extra)
}