diff options
-rw-r--r-- | src/cmd/compile/internal/arm64/ssa.go | 20 | ||||
-rw-r--r-- | src/cmd/compile/internal/ssa/gen/ARM64Ops.go | 32 | ||||
-rw-r--r-- | src/cmd/compile/internal/ssa/opGen.go | 40 | ||||
-rw-r--r-- | src/cmd/compile/internal/ssa/rewrite.go | 31 | ||||
-rw-r--r-- | test/fixedbugs/issue43619.go | 119 |
5 files changed, 215 insertions, 27 deletions
diff --git a/src/cmd/compile/internal/arm64/ssa.go b/src/cmd/compile/internal/arm64/ssa.go index 22b28a9308..43588511ab 100644 --- a/src/cmd/compile/internal/arm64/ssa.go +++ b/src/cmd/compile/internal/arm64/ssa.go @@ -1054,7 +1054,11 @@ func ssaGenValue(s *gc.SSAGenState, v *ssa.Value) { ssa.OpARM64LessThanF, ssa.OpARM64LessEqualF, ssa.OpARM64GreaterThanF, - ssa.OpARM64GreaterEqualF: + ssa.OpARM64GreaterEqualF, + ssa.OpARM64NotLessThanF, + ssa.OpARM64NotLessEqualF, + ssa.OpARM64NotGreaterThanF, + ssa.OpARM64NotGreaterEqualF: // generate boolean values using CSET p := s.Prog(arm64.ACSET) p.From.Type = obj.TYPE_REG // assembler encodes conditional bits in Reg @@ -1098,10 +1102,16 @@ var condBits = map[ssa.Op]int16{ ssa.OpARM64GreaterThanU: arm64.COND_HI, ssa.OpARM64GreaterEqual: arm64.COND_GE, ssa.OpARM64GreaterEqualU: arm64.COND_HS, - ssa.OpARM64LessThanF: arm64.COND_MI, - ssa.OpARM64LessEqualF: arm64.COND_LS, - ssa.OpARM64GreaterThanF: arm64.COND_GT, - ssa.OpARM64GreaterEqualF: arm64.COND_GE, + ssa.OpARM64LessThanF: arm64.COND_MI, // Less than + ssa.OpARM64LessEqualF: arm64.COND_LS, // Less than or equal to + ssa.OpARM64GreaterThanF: arm64.COND_GT, // Greater than + ssa.OpARM64GreaterEqualF: arm64.COND_GE, // Greater than or equal to + + // The following condition codes have unordered to handle comparisons related to NaN. + ssa.OpARM64NotLessThanF: arm64.COND_PL, // Greater than, equal to, or unordered + ssa.OpARM64NotLessEqualF: arm64.COND_HI, // Greater than or unordered + ssa.OpARM64NotGreaterThanF: arm64.COND_LE, // Less than, equal to or unordered + ssa.OpARM64NotGreaterEqualF: arm64.COND_LT, // Less than or unordered } var blockJump = map[ssa.BlockKind]struct { diff --git a/src/cmd/compile/internal/ssa/gen/ARM64Ops.go b/src/cmd/compile/internal/ssa/gen/ARM64Ops.go index 87db2b7c9d..b0bc9c78ff 100644 --- a/src/cmd/compile/internal/ssa/gen/ARM64Ops.go +++ b/src/cmd/compile/internal/ssa/gen/ARM64Ops.go @@ -478,20 +478,24 @@ func init() { // pseudo-ops {name: "LoweredNilCheck", argLength: 2, reg: regInfo{inputs: []regMask{gpg}}, nilCheck: true, faultOnNilArg0: true}, // panic if arg0 is nil. arg1=mem. - {name: "Equal", argLength: 1, reg: readflags}, // bool, true flags encode x==y false otherwise. - {name: "NotEqual", argLength: 1, reg: readflags}, // bool, true flags encode x!=y false otherwise. - {name: "LessThan", argLength: 1, reg: readflags}, // bool, true flags encode signed x<y false otherwise. - {name: "LessEqual", argLength: 1, reg: readflags}, // bool, true flags encode signed x<=y false otherwise. - {name: "GreaterThan", argLength: 1, reg: readflags}, // bool, true flags encode signed x>y false otherwise. - {name: "GreaterEqual", argLength: 1, reg: readflags}, // bool, true flags encode signed x>=y false otherwise. - {name: "LessThanU", argLength: 1, reg: readflags}, // bool, true flags encode unsigned x<y false otherwise. - {name: "LessEqualU", argLength: 1, reg: readflags}, // bool, true flags encode unsigned x<=y false otherwise. - {name: "GreaterThanU", argLength: 1, reg: readflags}, // bool, true flags encode unsigned x>y false otherwise. - {name: "GreaterEqualU", argLength: 1, reg: readflags}, // bool, true flags encode unsigned x>=y false otherwise. - {name: "LessThanF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x<y false otherwise. - {name: "LessEqualF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x<=y false otherwise. - {name: "GreaterThanF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x>y false otherwise. - {name: "GreaterEqualF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x>=y false otherwise. + {name: "Equal", argLength: 1, reg: readflags}, // bool, true flags encode x==y false otherwise. + {name: "NotEqual", argLength: 1, reg: readflags}, // bool, true flags encode x!=y false otherwise. + {name: "LessThan", argLength: 1, reg: readflags}, // bool, true flags encode signed x<y false otherwise. + {name: "LessEqual", argLength: 1, reg: readflags}, // bool, true flags encode signed x<=y false otherwise. + {name: "GreaterThan", argLength: 1, reg: readflags}, // bool, true flags encode signed x>y false otherwise. + {name: "GreaterEqual", argLength: 1, reg: readflags}, // bool, true flags encode signed x>=y false otherwise. + {name: "LessThanU", argLength: 1, reg: readflags}, // bool, true flags encode unsigned x<y false otherwise. + {name: "LessEqualU", argLength: 1, reg: readflags}, // bool, true flags encode unsigned x<=y false otherwise. + {name: "GreaterThanU", argLength: 1, reg: readflags}, // bool, true flags encode unsigned x>y false otherwise. + {name: "GreaterEqualU", argLength: 1, reg: readflags}, // bool, true flags encode unsigned x>=y false otherwise. + {name: "LessThanF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x<y false otherwise. + {name: "LessEqualF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x<=y false otherwise. + {name: "GreaterThanF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x>y false otherwise. + {name: "GreaterEqualF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x>=y false otherwise. + {name: "NotLessThanF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x>=y || x is unordered with y, false otherwise. + {name: "NotLessEqualF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x>y || x is unordered with y, false otherwise. + {name: "NotGreaterThanF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x<=y || x is unordered with y, false otherwise. + {name: "NotGreaterEqualF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x<y || x is unordered with y, false otherwise. // duffzero // arg0 = address of memory to zero // arg1 = mem diff --git a/src/cmd/compile/internal/ssa/opGen.go b/src/cmd/compile/internal/ssa/opGen.go index 83d35cf7e1..e590f6ba5d 100644 --- a/src/cmd/compile/internal/ssa/opGen.go +++ b/src/cmd/compile/internal/ssa/opGen.go @@ -1564,6 +1564,10 @@ const ( OpARM64LessEqualF OpARM64GreaterThanF OpARM64GreaterEqualF + OpARM64NotLessThanF + OpARM64NotLessEqualF + OpARM64NotGreaterThanF + OpARM64NotGreaterEqualF OpARM64DUFFZERO OpARM64LoweredZero OpARM64DUFFCOPY @@ -20799,6 +20803,42 @@ var opcodeTable = [...]opInfo{ }, }, { + name: "NotLessThanF", + argLen: 1, + reg: regInfo{ + outputs: []outputInfo{ + {0, 670826495}, // R0 R1 R2 R3 R4 R5 R6 R7 R8 R9 R10 R11 R12 R13 R14 R15 R16 R17 R19 R20 R21 R22 R23 R24 R25 R26 R30 + }, + }, + }, + { + name: "NotLessEqualF", + argLen: 1, + reg: regInfo{ + outputs: []outputInfo{ + {0, 670826495}, // R0 R1 R2 R3 R4 R5 R6 R7 R8 R9 R10 R11 R12 R13 R14 R15 R16 R17 R19 R20 R21 R22 R23 R24 R25 R26 R30 + }, + }, + }, + { + name: "NotGreaterThanF", + argLen: 1, + reg: regInfo{ + outputs: []outputInfo{ + {0, 670826495}, // R0 R1 R2 R3 R4 R5 R6 R7 R8 R9 R10 R11 R12 R13 R14 R15 R16 R17 R19 R20 R21 R22 R23 R24 R25 R26 R30 + }, + }, + }, + { + name: "NotGreaterEqualF", + argLen: 1, + reg: regInfo{ + outputs: []outputInfo{ + {0, 670826495}, // R0 R1 R2 R3 R4 R5 R6 R7 R8 R9 R10 R11 R12 R13 R14 R15 R16 R17 R19 R20 R21 R22 R23 R24 R25 R26 R30 + }, + }, + }, + { name: "DUFFZERO", auxType: auxInt64, argLen: 2, diff --git a/src/cmd/compile/internal/ssa/rewrite.go b/src/cmd/compile/internal/ssa/rewrite.go index 24efd38fb7..f5d1a7889f 100644 --- a/src/cmd/compile/internal/ssa/rewrite.go +++ b/src/cmd/compile/internal/ssa/rewrite.go @@ -974,9 +974,10 @@ func flagArg(v *Value) *Value { } // arm64Negate finds the complement to an ARM64 condition code, -// for example Equal -> NotEqual or LessThan -> GreaterEqual +// for example !Equal -> NotEqual or !LessThan -> GreaterEqual // -// TODO: add floating-point conditions +// For floating point, it's more subtle because NaN is unordered. We do +// !LessThanF -> NotLessThanF, the latter takes care of NaNs. func arm64Negate(op Op) Op { switch op { case OpARM64LessThan: @@ -1000,13 +1001,21 @@ func arm64Negate(op Op) Op { case OpARM64NotEqual: return OpARM64Equal case OpARM64LessThanF: - return OpARM64GreaterEqualF - case OpARM64GreaterThanF: - return OpARM64LessEqualF + return OpARM64NotLessThanF + case OpARM64NotLessThanF: + return OpARM64LessThanF case OpARM64LessEqualF: + return OpARM64NotLessEqualF + case OpARM64NotLessEqualF: + return OpARM64LessEqualF + case OpARM64GreaterThanF: + return OpARM64NotGreaterThanF + case OpARM64NotGreaterThanF: return OpARM64GreaterThanF case OpARM64GreaterEqualF: - return OpARM64LessThanF + return OpARM64NotGreaterEqualF + case OpARM64NotGreaterEqualF: + return OpARM64GreaterEqualF default: panic("unreachable") } @@ -1017,8 +1026,6 @@ func arm64Negate(op Op) Op { // that the same result would be produced if the arguments // to the flag-generating instruction were reversed, e.g. // (InvertFlags (CMP x y)) -> (CMP y x) -// -// TODO: add floating-point conditions func arm64Invert(op Op) Op { switch op { case OpARM64LessThan: @@ -1047,6 +1054,14 @@ func arm64Invert(op Op) Op { return OpARM64GreaterEqualF case OpARM64GreaterEqualF: return OpARM64LessEqualF + case OpARM64NotLessThanF: + return OpARM64NotGreaterThanF + case OpARM64NotGreaterThanF: + return OpARM64NotLessThanF + case OpARM64NotLessEqualF: + return OpARM64NotGreaterEqualF + case OpARM64NotGreaterEqualF: + return OpARM64NotLessEqualF default: panic("unreachable") } diff --git a/test/fixedbugs/issue43619.go b/test/fixedbugs/issue43619.go new file mode 100644 index 0000000000..3e667851a4 --- /dev/null +++ b/test/fixedbugs/issue43619.go @@ -0,0 +1,119 @@ +// run + +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + "math" +) + +//go:noinline +func fcmplt(a, b float64, x uint64) uint64 { + if a < b { + x = 0 + } + return x +} + +//go:noinline +func fcmple(a, b float64, x uint64) uint64 { + if a <= b { + x = 0 + } + return x +} + +//go:noinline +func fcmpgt(a, b float64, x uint64) uint64 { + if a > b { + x = 0 + } + return x +} + +//go:noinline +func fcmpge(a, b float64, x uint64) uint64 { + if a >= b { + x = 0 + } + return x +} + +//go:noinline +func fcmpeq(a, b float64, x uint64) uint64 { + if a == b { + x = 0 + } + return x +} + +//go:noinline +func fcmpne(a, b float64, x uint64) uint64 { + if a != b { + x = 0 + } + return x +} + +func main() { + type fn func(a, b float64, x uint64) uint64 + + type testCase struct { + f fn + a, b float64 + x, want uint64 + } + NaN := math.NaN() + for _, t := range []testCase{ + {fcmplt, 1.0, 1.0, 123, 123}, + {fcmple, 1.0, 1.0, 123, 0}, + {fcmpgt, 1.0, 1.0, 123, 123}, + {fcmpge, 1.0, 1.0, 123, 0}, + {fcmpeq, 1.0, 1.0, 123, 0}, + {fcmpne, 1.0, 1.0, 123, 123}, + + {fcmplt, 1.0, 2.0, 123, 0}, + {fcmple, 1.0, 2.0, 123, 0}, + {fcmpgt, 1.0, 2.0, 123, 123}, + {fcmpge, 1.0, 2.0, 123, 123}, + {fcmpeq, 1.0, 2.0, 123, 123}, + {fcmpne, 1.0, 2.0, 123, 0}, + + {fcmplt, 2.0, 1.0, 123, 123}, + {fcmple, 2.0, 1.0, 123, 123}, + {fcmpgt, 2.0, 1.0, 123, 0}, + {fcmpge, 2.0, 1.0, 123, 0}, + {fcmpeq, 2.0, 1.0, 123, 123}, + {fcmpne, 2.0, 1.0, 123, 0}, + + {fcmplt, 1.0, NaN, 123, 123}, + {fcmple, 1.0, NaN, 123, 123}, + {fcmpgt, 1.0, NaN, 123, 123}, + {fcmpge, 1.0, NaN, 123, 123}, + {fcmpeq, 1.0, NaN, 123, 123}, + {fcmpne, 1.0, NaN, 123, 0}, + + {fcmplt, NaN, 1.0, 123, 123}, + {fcmple, NaN, 1.0, 123, 123}, + {fcmpgt, NaN, 1.0, 123, 123}, + {fcmpge, NaN, 1.0, 123, 123}, + {fcmpeq, NaN, 1.0, 123, 123}, + {fcmpne, NaN, 1.0, 123, 0}, + + {fcmplt, NaN, NaN, 123, 123}, + {fcmple, NaN, NaN, 123, 123}, + {fcmpgt, NaN, NaN, 123, 123}, + {fcmpge, NaN, NaN, 123, 123}, + {fcmpeq, NaN, NaN, 123, 123}, + {fcmpne, NaN, NaN, 123, 0}, + } { + got := t.f(t.a, t.b, t.x) + if got != t.want { + panic(fmt.Sprintf("want %v, got %v", t.want, got)) + } + } +} |