aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDan Scales <danscales@google.com>2021-07-26 19:25:40 -0700
committerDan Scales <danscales@google.com>2021-07-29 21:58:06 +0000
commit600b7b431bd546841c04a27d4ac73af1e4f2fcc4 (patch)
tree2f3b007590862af07f7eedc69e237d9111fb1b61
parent5ecbd811b54f478244b7e54a621f32b5b8e3ea95 (diff)
downloadgo-600b7b431bd546841c04a27d4ac73af1e4f2fcc4.tar.gz
go-600b7b431bd546841c04a27d4ac73af1e4f2fcc4.zip
[dev.typeparams] cmd/compile: handle meth expressions on typeparams
Rewrite a method expression such as 'T.String' (where T is type param and String is part of its type bound Stringer) as: func(rcvr T, other params...) { return Stringer(rcvr).String(other params...) } New function buildClosure2 to create the needed closure. The conversion Stringer(rcvr) uses the dictionary in the outer function. For a method expression like 'Test[T].finish' (where finish is a method of Test[T]), we can already deal with this in buildClosure(). We just need fix transformDot() to allow the method lookup to fail, since shapes have no methods on them. That's fine, since for any instantiated receiver type, we always use the methods on the generic base type. Also removed the OMETHEXPR case in the main switch of node(), which isn't needed any (and removes one more potential unshapify). Also, fixed two small bugs with handling closures that have generic params or generic captured variables. Need to set the instInfo for the closure in the subst struct when descending into a closure during genericSubst() and was missing initializing the startItabConv and gfInfo fields in the closure info. Change-Id: I6dadedd1378477936a27c9c544c014cd2083cfb7 Reviewed-on: https://go-review.googlesource.com/c/go/+/338129 Trust: Dan Scales <danscales@google.com> Run-TryBot: Dan Scales <danscales@google.com> TryBot-Result: Go Bot <gobot@golang.org> Reviewed-by: Keith Randall <khr@golang.org>
-rw-r--r--src/cmd/compile/internal/noder/stencil.go214
-rw-r--r--src/cmd/compile/internal/noder/transform.go6
-rw-r--r--test/typeparam/boundmethod.go54
3 files changed, 198 insertions, 76 deletions
diff --git a/src/cmd/compile/internal/noder/stencil.go b/src/cmd/compile/internal/noder/stencil.go
index 70a2c7b97f..83abee1dd2 100644
--- a/src/cmd/compile/internal/noder/stencil.go
+++ b/src/cmd/compile/internal/noder/stencil.go
@@ -357,8 +357,7 @@ func (g *irgen) buildClosure(outer *ir.Func, x ir.Node) ir.Node {
// }
// Make a new internal function.
- fn := ir.NewClosureFunc(pos, outer != nil)
- ir.NameClosure(fn.OClosure, outer)
+ fn, formalParams, formalResults := startClosure(pos, outer, typ)
// This is the dictionary we want to use.
// It may be a constant, or it may be a dictionary acquired from the outer function's dictionary.
@@ -395,38 +394,6 @@ func (g *irgen) buildClosure(outer *ir.Func, x ir.Node) ir.Node {
outer.Dcl = append(outer.Dcl, rcvrVar)
}
- // Build formal argument and return lists.
- var formalParams []*types.Field // arguments of closure
- var formalResults []*types.Field // returns of closure
- for i := 0; i < typ.NumParams(); i++ {
- t := typ.Params().Field(i).Type
- arg := ir.NewNameAt(pos, typecheck.LookupNum("a", i))
- arg.Class = ir.PPARAM
- typed(t, arg)
- arg.Curfn = fn
- fn.Dcl = append(fn.Dcl, arg)
- f := types.NewField(pos, arg.Sym(), t)
- f.Nname = arg
- formalParams = append(formalParams, f)
- }
- for i := 0; i < typ.NumResults(); i++ {
- t := typ.Results().Field(i).Type
- result := ir.NewNameAt(pos, typecheck.LookupNum("r", i)) // TODO: names not needed?
- result.Class = ir.PPARAMOUT
- typed(t, result)
- result.Curfn = fn
- fn.Dcl = append(fn.Dcl, result)
- f := types.NewField(pos, result.Sym(), t)
- f.Nname = result
- formalResults = append(formalResults, f)
- }
-
- // Build an internal function with the right signature.
- closureType := types.NewSignature(x.Type().Pkg(), nil, nil, formalParams, formalResults)
- typed(closureType, fn.Nname)
- typed(x.Type(), fn.OClosure)
- fn.SetTypecheck(1)
-
// Build body of closure. This involves just calling the wrapped function directly
// with the additional dictionary argument.
@@ -1092,14 +1059,15 @@ func getDictionaryEntry(pos src.XPos, dict *ir.Name, i int, size int) ir.Node {
return r
}
-// getDictionaryType returns a *runtime._type from the dictionary entry i
-// (which refers to a type param or a derived type that uses type params).
-func (subst *subster) getDictionaryType(pos src.XPos, i int) ir.Node {
- if i < 0 || i >= subst.info.startSubDict {
+// getDictionaryType returns a *runtime._type from the dictionary entry i (which
+// refers to a type param or a derived type that uses type params). It uses the
+// specified dictionary dictParam, rather than the one in info.dictParam.
+func getDictionaryType(info *instInfo, dictParam *ir.Name, pos src.XPos, i int) ir.Node {
+ if i < 0 || i >= info.startSubDict {
base.Fatalf(fmt.Sprintf("bad dict index %d", i))
}
- r := getDictionaryEntry(pos, subst.info.dictParam, i, subst.info.startSubDict)
+ r := getDictionaryEntry(pos, info.dictParam, i, info.startSubDict)
// change type of retrieved dictionary entry to *byte, which is the
// standard typing of a *runtime._type in the compiler
typed(types.Types[types.TUINT8].PtrTo(), r)
@@ -1235,9 +1203,9 @@ func (subst *subster) node(n ir.Node) ir.Node {
// The only dot on a shape type value are methods.
if mse.X.Op() == ir.OTYPE {
// Method expression T.M
- // Fall back from shape type to concrete type.
- src = subst.unshapifyTyp(src)
- mse.X = ir.TypeNode(src)
+ m = subst.g.buildClosure2(subst.newf, subst.info, m, x)
+ // No need for transformDot - buildClosure2 has already
+ // transformed to OCALLINTER/ODOTINTER.
} else {
// Implement x.M as a conversion-to-bound-interface
// 1) convert x to the bound interface
@@ -1247,12 +1215,11 @@ func (subst *subster) node(n ir.Node) ir.Node {
if dst.HasTParam() {
dst = subst.ts.Typ(dst)
}
- mse.X = subst.convertUsingDictionary(m.Pos(), mse.X, x, dst, gsrc)
+ mse.X = convertUsingDictionary(subst.info, subst.info.dictParam, m.Pos(), mse.X, x, dst, gsrc)
+ transformDot(mse, false)
}
- }
- transformDot(mse, false)
- if mse.Op() == ir.OMETHEXPR && mse.X.Type().HasShape() {
- mse.X = ir.TypeNodeAt(mse.X.Pos(), subst.unshapifyTyp(mse.X.Type()))
+ } else {
+ transformDot(mse, false)
}
m.SetTypecheck(1)
@@ -1341,11 +1308,14 @@ func (subst *subster) node(n ir.Node) ir.Node {
// outer function. Since the dictionary is shared, use the
// same entries for startSubDict, dictLen, dictEntryMap.
cinfo := &instInfo{
- fun: newfn,
- dictParam: cdict,
- startSubDict: subst.info.startSubDict,
- dictLen: subst.info.dictLen,
- dictEntryMap: subst.info.dictEntryMap,
+ fun: newfn,
+ dictParam: cdict,
+ gf: subst.info.gf,
+ gfInfo: subst.info.gfInfo,
+ startSubDict: subst.info.startSubDict,
+ startItabConv: subst.info.startItabConv,
+ dictLen: subst.info.dictLen,
+ dictEntryMap: subst.info.dictEntryMap,
}
subst.g.instInfoMap[newfn.Nname.Sym()] = cinfo
@@ -1353,8 +1323,11 @@ func (subst *subster) node(n ir.Node) ir.Node {
typed(newfn.Nname.Type(), newfn.OClosure)
newfn.SetTypecheck(1)
+ outerinfo := subst.info
+ subst.info = cinfo
// Make sure type of closure function is set before doing body.
newfn.Body = subst.list(oldfn.Body)
+ subst.info = outerinfo
subst.newf = saveNewf
ir.CurFunc = saveNewf
@@ -1366,15 +1339,15 @@ func (subst *subster) node(n ir.Node) ir.Node {
// Note: x's argument is still typed as a type parameter.
// m's argument now has an instantiated type.
if x.X.Type().HasTParam() {
- m = subst.convertUsingDictionary(m.Pos(), m.(*ir.ConvExpr).X, x, m.Type(), x.X.Type())
+ m = convertUsingDictionary(subst.info, subst.info.dictParam, m.Pos(), m.(*ir.ConvExpr).X, x, m.Type(), x.X.Type())
}
case ir.ODOTTYPE, ir.ODOTTYPE2:
dt := m.(*ir.TypeAssertExpr)
var rt ir.Node
if dt.Type().IsInterface() || dt.X.Type().IsEmptyInterface() {
- ix := subst.findDictType(x.Type())
+ ix := findDictType(subst.info, x.Type())
assert(ix >= 0)
- rt = subst.getDictionaryType(dt.Pos(), ix)
+ rt = getDictionaryType(subst.info, subst.info.dictParam, dt.Pos(), ix)
} else {
// nonempty interface to noninterface. Need an itab.
ix := -1
@@ -1395,9 +1368,6 @@ func (subst *subster) node(n ir.Node) ir.Node {
m.SetType(dt.Type())
m.SetTypecheck(1)
- case ir.OMETHEXPR:
- se := m.(*ir.SelectorExpr)
- se.X = ir.TypeNodeAt(se.X.Pos(), subst.unshapifyTyp(se.X.Type()))
case ir.OFUNCINST:
inst := m.(*ir.InstExpr)
targs2 := make([]ir.Node, len(inst.Targs))
@@ -1414,17 +1384,17 @@ func (subst *subster) node(n ir.Node) ir.Node {
}
// findDictType looks for type t in the typeparams or derived types in the generic
-// function info subst.info.gfInfo. This will indicate the dictionary entry with the
+// function info.gfInfo. This will indicate the dictionary entry with the
// correct concrete type for the associated instantiated function.
-func (subst *subster) findDictType(t *types.Type) int {
- for i, dt := range subst.info.gfInfo.tparams {
+func findDictType(info *instInfo, t *types.Type) int {
+ for i, dt := range info.gfInfo.tparams {
if dt == t {
return i
}
}
- for i, dt := range subst.info.gfInfo.derivedTypes {
+ for i, dt := range info.gfInfo.derivedTypes {
if types.Identical(dt, t) {
- return i + len(subst.info.gfInfo.tparams)
+ return i + len(info.gfInfo.tparams)
}
}
return -1
@@ -1435,7 +1405,7 @@ func (subst *subster) findDictType(t *types.Type) int {
// is the generic (not shape) type, and gn is the original generic node of the
// CONVIFACE node or XDOT node (for a bound method call) that is causing the
// conversion.
-func (subst *subster) convertUsingDictionary(pos src.XPos, v ir.Node, gn ir.Node, dst, src *types.Type) ir.Node {
+func convertUsingDictionary(info *instInfo, dictParam *ir.Name, pos src.XPos, v ir.Node, gn ir.Node, dst, src *types.Type) ir.Node {
assert(src.HasTParam())
assert(dst.IsInterface())
@@ -1445,19 +1415,19 @@ func (subst *subster) convertUsingDictionary(pos src.XPos, v ir.Node, gn ir.Node
// will be more efficient than converting to an empty interface first
// and then type asserting to dst.
ix := -1
- for i, ic := range subst.info.gfInfo.itabConvs {
+ for i, ic := range info.gfInfo.itabConvs {
if ic == gn {
- ix = subst.info.startItabConv + i
+ ix = info.startItabConv + i
break
}
}
assert(ix >= 0)
- rt = getDictionaryEntry(pos, subst.info.dictParam, ix, subst.info.dictLen)
+ rt = getDictionaryEntry(pos, dictParam, ix, info.dictLen)
} else {
- ix := subst.findDictType(src)
+ ix := findDictType(info, src)
assert(ix >= 0)
// Load the actual runtime._type of the type parameter from the dictionary.
- rt = subst.getDictionaryType(pos, ix)
+ rt = getDictionaryType(info, dictParam, pos, ix)
}
// Figure out what the data field of the interface will be.
@@ -1670,7 +1640,7 @@ func (g *irgen) getDictionarySym(gf *ir.Name, targs []*types.Type, isMeth bool)
case ir.OXDOT:
selExpr := n.(*ir.SelectorExpr)
- subtargs := selExpr.X.Type().RParams()
+ subtargs := deref(selExpr.X.Type()).RParams()
s2targs := make([]*types.Type, len(subtargs))
for i, t := range subtargs {
s2targs[i] = subst.Typ(t)
@@ -1842,7 +1812,7 @@ func (g *irgen) getGfInfo(gn *ir.Name) *gfInfo {
}
} else if n.Op() == ir.OXDOT && !n.(*ir.SelectorExpr).Implicit() &&
n.(*ir.SelectorExpr).Selection != nil &&
- len(n.(*ir.SelectorExpr).X.Type().RParams()) > 0 {
+ len(deref(n.(*ir.SelectorExpr).X.Type()).RParams()) > 0 {
if n.(*ir.SelectorExpr).X.Op() == ir.OTYPE {
infoPrint(" Closure&subdictionary required at generic meth expr %v\n", n)
} else {
@@ -2017,3 +1987,105 @@ func parameterizedBy1(t *types.Type, params []*types.Type, visited map[*types.Ty
return true
}
}
+
+// startClosures starts creation of a closure that has the function type typ. It
+// creates all the formal params and results according to the type typ. On return,
+// the body and closure variables of the closure must still be filled in, and
+// ir.UseClosure() called.
+func startClosure(pos src.XPos, outer *ir.Func, typ *types.Type) (*ir.Func, []*types.Field, []*types.Field) {
+ // Make a new internal function.
+ fn := ir.NewClosureFunc(pos, outer != nil)
+ ir.NameClosure(fn.OClosure, outer)
+
+ // Build formal argument and return lists.
+ var formalParams []*types.Field // arguments of closure
+ var formalResults []*types.Field // returns of closure
+ for i := 0; i < typ.NumParams(); i++ {
+ t := typ.Params().Field(i).Type
+ arg := ir.NewNameAt(pos, typecheck.LookupNum("a", i))
+ arg.Class = ir.PPARAM
+ typed(t, arg)
+ arg.Curfn = fn
+ fn.Dcl = append(fn.Dcl, arg)
+ f := types.NewField(pos, arg.Sym(), t)
+ f.Nname = arg
+ formalParams = append(formalParams, f)
+ }
+ for i := 0; i < typ.NumResults(); i++ {
+ t := typ.Results().Field(i).Type
+ result := ir.NewNameAt(pos, typecheck.LookupNum("r", i)) // TODO: names not needed?
+ result.Class = ir.PPARAMOUT
+ typed(t, result)
+ result.Curfn = fn
+ fn.Dcl = append(fn.Dcl, result)
+ f := types.NewField(pos, result.Sym(), t)
+ f.Nname = result
+ formalResults = append(formalResults, f)
+ }
+
+ // Build an internal function with the right signature.
+ closureType := types.NewSignature(typ.Pkg(), nil, nil, formalParams, formalResults)
+ typed(closureType, fn.Nname)
+ typed(typ, fn.OClosure)
+ fn.SetTypecheck(1)
+ return fn, formalParams, formalResults
+
+}
+
+// buildClosure2 makes a closure to implement a method expression m (generic form x)
+// which has a shape type as receiver. If the receiver is exactly a shape (i.e. from
+// a typeparam), then the body of the closure converts the first argument (the
+// receiver) to the interface bound type, and makes an interface call with the
+// remaining arguments.
+//
+// The returned closure is fully substituted and has already has any needed
+// transformations done.
+func (g *irgen) buildClosure2(outer *ir.Func, info *instInfo, m, x ir.Node) ir.Node {
+ pos := m.Pos()
+ typ := m.Type() // type of the closure
+
+ fn, formalParams, formalResults := startClosure(pos, outer, typ)
+
+ // Capture dictionary calculated in the outer function
+ dictVar := ir.CaptureName(pos, fn, info.dictParam)
+ typed(types.Types[types.TUINTPTR], dictVar)
+
+ // Build arguments to call inside the closure.
+ var args []ir.Node
+ for i := 0; i < typ.NumParams(); i++ {
+ args = append(args, formalParams[i].Nname.(*ir.Name))
+ }
+
+ // Build call itself. This involves converting the first argument to the
+ // bound type (an interface) using the dictionary, and then making an
+ // interface call with the remaining arguments.
+ var innerCall ir.Node
+ rcvr := args[0]
+ args = args[1:]
+ assert(m.(*ir.SelectorExpr).X.Type().IsShape())
+ rcvr = convertUsingDictionary(info, dictVar, pos, rcvr, x, x.(*ir.SelectorExpr).X.Type().Bound(), x.(*ir.SelectorExpr).X.Type())
+ dot := ir.NewSelectorExpr(pos, ir.ODOTINTER, rcvr, x.(*ir.SelectorExpr).Sel)
+ dot.Selection = typecheck.Lookdot1(dot, dot.Sel, dot.X.Type(), dot.X.Type().AllMethods(), 1)
+
+ typed(x.(*ir.SelectorExpr).Selection.Type, dot)
+ innerCall = ir.NewCallExpr(pos, ir.OCALLINTER, dot, args)
+ t := m.Type()
+ if t.NumResults() == 0 {
+ innerCall.SetTypecheck(1)
+ } else if t.NumResults() == 1 {
+ typed(t.Results().Field(0).Type, innerCall)
+ } else {
+ typed(t.Results(), innerCall)
+ }
+ if len(formalResults) > 0 {
+ innerCall = ir.NewReturnStmt(pos, []ir.Node{innerCall})
+ innerCall.SetTypecheck(1)
+ }
+ fn.Body = []ir.Node{innerCall}
+
+ // We're all done with the captured dictionary
+ ir.FinishCaptureNames(pos, outer, fn)
+
+ // Do final checks on closure and return it.
+ return ir.UseClosure(fn.OClosure, g.target)
+}
diff --git a/src/cmd/compile/internal/noder/transform.go b/src/cmd/compile/internal/noder/transform.go
index 2fe55a6852..9c791d8a7b 100644
--- a/src/cmd/compile/internal/noder/transform.go
+++ b/src/cmd/compile/internal/noder/transform.go
@@ -664,7 +664,11 @@ func transformMethodExpr(n *ir.SelectorExpr) (res ir.Node) {
s := n.Sel
m := typecheck.Lookdot1(n, s, t, ms, 0)
- assert(m != nil)
+ if !t.HasShape() {
+ // It's OK to not find the method if t is instantiated by shape types,
+ // because we will use the methods on the generic type anyway.
+ assert(m != nil)
+ }
n.SetOp(ir.OMETHEXPR)
n.Selection = m
diff --git a/test/typeparam/boundmethod.go b/test/typeparam/boundmethod.go
index 3deabbcdce..22f416422d 100644
--- a/test/typeparam/boundmethod.go
+++ b/test/typeparam/boundmethod.go
@@ -29,32 +29,78 @@ type Stringer interface {
func stringify[T Stringer](s []T) (ret []string) {
for _, v := range s {
+ // Test normal bounds method call on type param
+ x1 := v.String()
+
+ // Test converting type param to its bound interface first
+ v1 := Stringer(v)
+ x2 := v1.String()
+
+ // Test method expression with type param type
+ f1 := T.String
+ x3 := f1(v)
+
+ // Test creating and calling closure equivalent to the method expression
+ f2 := func(v1 T) string {
+ return Stringer(v1).String()
+ }
+ x4 := f2(v)
+
+ if x1 != x2 || x2 != x3 || x3 != x4 {
+ panic(fmt.Sprintf("Mismatched values %v, %v, %v, %v\n", x1, x2, x3, x4))
+ }
+
ret = append(ret, v.String())
}
return ret
}
-type StringInt[T any] T
+type Ints interface {
+ ~int32 | ~int
+}
+
+type StringInt[T Ints] T
//go:noinline
func (m StringInt[T]) String() string {
- return "aa"
+ return strconv.Itoa(int(m))
+}
+
+type StringStruct[T Ints] struct {
+ f T
+}
+
+func (m StringStruct[T]) String() string {
+ return strconv.Itoa(int(m.f))
}
func main() {
x := []myint{myint(1), myint(2), myint(3)}
+ // stringify on a normal type, whose bound method is associated with the base type.
got := stringify(x)
want := []string{"1", "2", "3"}
if !reflect.DeepEqual(got, want) {
panic(fmt.Sprintf("got %s, want %s", got, want))
}
- x2 := []StringInt[myint]{StringInt[myint](1), StringInt[myint](2), StringInt[myint](3)}
+ x2 := []StringInt[myint]{StringInt[myint](5), StringInt[myint](7), StringInt[myint](6)}
+ // stringify on an instantiated type, whose bound method is associated with
+ // the generic type StringInt[T], which maps directly to T.
got2 := stringify(x2)
- want2 := []string{"aa", "aa", "aa"}
+ want2 := []string{ "5", "7", "6" }
if !reflect.DeepEqual(got2, want2) {
panic(fmt.Sprintf("got %s, want %s", got2, want2))
}
+
+ // stringify on an instantiated type, whose bound method is associated with
+ // the generic type StringStruct[T], which maps to a struct containing T.
+ x3 := []StringStruct[myint]{StringStruct[myint]{f: 11}, StringStruct[myint]{f: 10}, StringStruct[myint]{f: 9}}
+
+ got3 := stringify(x3)
+ want3 := []string{ "11", "10", "9" }
+ if !reflect.DeepEqual(got3, want3) {
+ panic(fmt.Sprintf("got %s, want %s", got3, want3))
+ }
}