diff options
-rw-r--r-- | src/cmd/compile/internal/inline/inl.go | 2 | ||||
-rw-r--r-- | src/cmd/compile/internal/noder/expr.go | 11 | ||||
-rw-r--r-- | src/cmd/compile/internal/noder/stencil.go | 182 | ||||
-rw-r--r-- | test/typeparam/combine.go | 65 | ||||
-rw-r--r-- | test/typeparam/settable.go | 27 |
5 files changed, 236 insertions, 51 deletions
diff --git a/src/cmd/compile/internal/inline/inl.go b/src/cmd/compile/internal/inline/inl.go index e961b10844..1d049298d7 100644 --- a/src/cmd/compile/internal/inline/inl.go +++ b/src/cmd/compile/internal/inline/inl.go @@ -354,7 +354,7 @@ func (v *hairyVisitor) doNode(n ir.Node) bool { return true case ir.OCLOSURE: - if base.Debug.InlFuncsWithClosures == 0 || base.Flag.G > 0 { + if base.Debug.InlFuncsWithClosures == 0 { v.reason = "not inlining functions with closures" return true } diff --git a/src/cmd/compile/internal/noder/expr.go b/src/cmd/compile/internal/noder/expr.go index b166d34ead..3fded144dc 100644 --- a/src/cmd/compile/internal/noder/expr.go +++ b/src/cmd/compile/internal/noder/expr.go @@ -325,19 +325,22 @@ func (g *irgen) compLit(typ types2.Type, lit *syntax.CompositeLit) ir.Node { return typecheck.Expr(ir.NewCompLitExpr(g.pos(lit), ir.OCOMPLIT, ir.TypeNode(g.typ(typ)), exprs)) } -func (g *irgen) funcLit(typ types2.Type, expr *syntax.FuncLit) ir.Node { +func (g *irgen) funcLit(typ2 types2.Type, expr *syntax.FuncLit) ir.Node { fn := ir.NewFunc(g.pos(expr)) fn.SetIsHiddenClosure(ir.CurFunc != nil) fn.Nname = ir.NewNameAt(g.pos(expr), typecheck.ClosureName(ir.CurFunc)) ir.MarkFunc(fn.Nname) - fn.Nname.SetType(g.typ(typ)) + typ := g.typ(typ2) fn.Nname.Func = fn fn.Nname.Defn = fn + // Set Ntype for now to be compatible with later parts of compile, remove later. + fn.Nname.Ntype = ir.TypeNode(typ) + typed(typ, fn.Nname) + fn.SetTypecheck(1) fn.OClosure = ir.NewClosureExpr(g.pos(expr), fn) - fn.OClosure.SetType(fn.Nname.Type()) - fn.OClosure.SetTypecheck(1) + typed(typ, fn.OClosure) g.funcBody(fn, nil, expr.Type, expr.Body) diff --git a/src/cmd/compile/internal/noder/stencil.go b/src/cmd/compile/internal/noder/stencil.go index 69461a8190..fb1bbfedc8 100644 --- a/src/cmd/compile/internal/noder/stencil.go +++ b/src/cmd/compile/internal/noder/stencil.go @@ -27,19 +27,37 @@ func (g *irgen) stencil() { // functions calling other generic functions. for i := 0; i < len(g.target.Decls); i++ { decl := g.target.Decls[i] - if decl.Op() != ir.ODCLFUNC || decl.Type().NumTParams() > 0 { - // Skip any non-function declarations and skip generic functions + + // Look for function instantiations in bodies of non-generic + // functions or in global assignments (ignore global type and + // constant declarations). + switch decl.Op() { + case ir.ODCLFUNC: + if decl.Type().HasTParam() { + // Skip any generic functions + continue + } + + case ir.OAS: + + case ir.OAS2: + + default: continue } - // For each non-generic function, search for any function calls using - // generic function instantiations. (We don't yet handle generic - // function instantiations that are not immediately called.) - // Then create the needed instantiated function if it hasn't been - // created yet, and change to calling that function directly. - f := decl.(*ir.Func) + // For all non-generic code, search for any function calls using + // generic function instantiations. Then create the needed + // instantiated function if it hasn't been created yet, and change + // to calling that function directly. modified := false - ir.VisitList(f.Body, func(n ir.Node) { + foundFuncInst := false + ir.Visit(decl, func(n ir.Node) { + if n.Op() == ir.OFUNCINST { + // We found a function instantiation that is not + // immediately called. + foundFuncInst = true + } if n.Op() != ir.OCALLFUNC || n.(*ir.CallExpr).X.Op() != ir.OFUNCINST { return } @@ -47,19 +65,7 @@ func (g *irgen) stencil() { // instantiation. call := n.(*ir.CallExpr) inst := call.X.(*ir.InstExpr) - sym := makeInstName(inst) - //fmt.Printf("Found generic func call in %v to %v\n", f, s) - st := g.target.Stencils[sym] - if st == nil { - // If instantiation doesn't exist yet, create it and add - // to the list of decls. - st = genericSubst(sym, inst) - g.target.Stencils[sym] = st - g.target.Decls = append(g.target.Decls, st) - if base.Flag.W > 1 { - ir.Dump(fmt.Sprintf("\nstenciled %v", st), st) - } - } + st := g.getInstantiation(inst) // Replace the OFUNCINST with a direct reference to the // new stenciled function call.X = st.Nname @@ -76,6 +82,26 @@ func (g *irgen) stencil() { } modified = true }) + + // If we found an OFUNCINST without a corresponding call in the + // above decl, then traverse the nodes of decl again (with + // EditChildren rather than Visit), where we actually change the + // OFUNCINST node to an ONAME for the instantiated function. + // EditChildren is more expensive than Visit, so we only do this + // in the infrequent case of an OFUNCINSt without a corresponding + // call. + if foundFuncInst { + var edit func(ir.Node) ir.Node + edit = func(x ir.Node) ir.Node { + if x.Op() == ir.OFUNCINST { + st := g.getInstantiation(x.(*ir.InstExpr)) + return st.Nname + } + ir.EditChildren(x, edit) + return x + } + edit(decl) + } if base.Flag.W > 1 && modified { ir.Dump(fmt.Sprintf("\nmodified %v", decl), decl) } @@ -83,18 +109,39 @@ func (g *irgen) stencil() { } -// makeInstName makes the unique name for a stenciled generic function, based on -// the name of the function and the types of the type params. -func makeInstName(inst *ir.InstExpr) *types.Sym { - b := bytes.NewBufferString("#") +// getInstantiation gets the instantiated function corresponding to inst. If the +// instantiated function is not already cached, then it calls genericStub to +// create the new instantiation. +func (g *irgen) getInstantiation(inst *ir.InstExpr) *ir.Func { + var sym *types.Sym if meth, ok := inst.X.(*ir.SelectorExpr); ok { // Write the name of the generic method, including receiver type - b.WriteString(meth.Selection.Nname.Sym().Name) + sym = makeInstName(meth.Selection.Nname.Sym(), inst.Targs) } else { - b.WriteString(inst.X.(*ir.Name).Name().Sym().Name) + sym = makeInstName(inst.X.(*ir.Name).Name().Sym(), inst.Targs) } + //fmt.Printf("Found generic func call in %v to %v\n", f, s) + st := g.target.Stencils[sym] + if st == nil { + // If instantiation doesn't exist yet, create it and add + // to the list of decls. + st = g.genericSubst(sym, inst) + g.target.Stencils[sym] = st + g.target.Decls = append(g.target.Decls, st) + if base.Flag.W > 1 { + ir.Dump(fmt.Sprintf("\nstenciled %v", st), st) + } + } + return st +} + +// makeInstName makes the unique name for a stenciled generic function, based on +// the name of the function and the targs. +func makeInstName(fnsym *types.Sym, targs []ir.Node) *types.Sym { + b := bytes.NewBufferString("#") + b.WriteString(fnsym.Name) b.WriteString("[") - for i, targ := range inst.Targs { + for i, targ := range targs { if i > 0 { b.WriteString(",") } @@ -107,6 +154,7 @@ func makeInstName(inst *ir.InstExpr) *types.Sym { // Struct containing info needed for doing the substitution as we create the // instantiation of a generic function with specified type arguments. type subster struct { + g *irgen newf *ir.Func // Func node for the new stenciled function tparams []*types.Field targs []ir.Node @@ -121,7 +169,7 @@ type subster struct { // inst. For a method with a generic receiver, it returns an instantiated function // type where the receiver becomes the first parameter. Otherwise the instantiated // method would still need to be transformed by later compiler phases. -func genericSubst(name *types.Sym, inst *ir.InstExpr) *ir.Func { +func (g *irgen) genericSubst(name *types.Sym, inst *ir.InstExpr) *ir.Func { var nameNode *ir.Name var tparams []*types.Field if selExpr, ok := inst.X.(*ir.SelectorExpr); ok { @@ -148,6 +196,7 @@ func genericSubst(name *types.Sym, inst *ir.InstExpr) *ir.Func { name.Def = newf.Nname subst := &subster{ + g: g, newf: newf, tparams: tparams, targs: inst.Targs, @@ -198,6 +247,9 @@ func (subst *subster) node(n ir.Node) ir.Node { return v } m := ir.NewNameAt(name.Pos(), name.Sym()) + if name.IsClosureVar() { + m.SetIsClosureVar(true) + } t := x.Type() newt := subst.typ(t) m.SetType(newt) @@ -219,10 +271,12 @@ func (subst *subster) node(n ir.Node) ir.Node { // t can be nil only if this is a call that has no // return values, so allow that and otherwise give // an error. - if _, isCallExpr := m.(*ir.CallExpr); !isCallExpr { + _, isCallExpr := m.(*ir.CallExpr) + _, isStructKeyExpr := m.(*ir.StructKeyExpr) + if !isCallExpr && !isStructKeyExpr { base.Fatalf(fmt.Sprintf("Nil type for %v", x)) } - } else { + } else if x.Op() != ir.OCLOSURE { m.SetType(subst.typ(x.Type())) } } @@ -270,14 +324,27 @@ func (subst *subster) node(n ir.Node) ir.Node { if oldfn.ClosureCalled() { newfn.SetClosureCalled(true) } + newfn.SetIsHiddenClosure(true) m.(*ir.ClosureExpr).Func = newfn - newfn.Nname = ir.NewNameAt(oldfn.Nname.Pos(), oldfn.Nname.Sym()) - newfn.Nname.SetType(oldfn.Nname.Type()) - newfn.Nname.Ntype = subst.node(oldfn.Nname.Ntype).(ir.Ntype) + newsym := makeInstName(oldfn.Nname.Sym(), subst.targs) + newfn.Nname = ir.NewNameAt(oldfn.Nname.Pos(), newsym) + newfn.Nname.Func = newfn + newfn.Nname.Defn = newfn + ir.MarkFunc(newfn.Nname) + newfn.OClosure = m.(*ir.ClosureExpr) + + saveNewf := subst.newf + subst.newf = newfn + newfn.Dcl = subst.namelist(oldfn.Dcl) + newfn.ClosureVars = subst.namelist(oldfn.ClosureVars) newfn.Body = subst.list(oldfn.Body) - // Make shallow copy of the Dcl and ClosureVar slices - newfn.Dcl = append([]*ir.Name(nil), oldfn.Dcl...) - newfn.ClosureVars = append([]*ir.Name(nil), oldfn.ClosureVars...) + subst.newf = saveNewf + + // Set Ntype for now to be compatible with later parts of compiler + newfn.Nname.Ntype = subst.node(oldfn.Nname.Ntype).(ir.Ntype) + typed(subst.typ(oldfn.Nname.Type()), newfn.Nname) + newfn.SetTypecheck(1) + subst.g.target.Decls = append(subst.g.target.Decls, newfn) } return m } @@ -285,6 +352,20 @@ func (subst *subster) node(n ir.Node) ir.Node { return edit(n) } +func (subst *subster) namelist(l []*ir.Name) []*ir.Name { + s := make([]*ir.Name, len(l)) + for i, n := range l { + s[i] = subst.node(n).(*ir.Name) + if n.Defn != nil { + s[i].Defn = subst.node(n.Defn) + } + if n.Outer != nil { + s[i].Outer = subst.node(n.Outer).(*ir.Name) + } + } + return s +} + func (subst *subster) list(l []ir.Node) []ir.Node { s := make([]ir.Node, len(l)) for i, n := range l { @@ -293,7 +374,9 @@ func (subst *subster) list(l []ir.Node) []ir.Node { return s } -// tstruct substitutes type params in a structure type +// tstruct substitutes type params in types of the fields of a structure type. For +// each field, if Nname is set, tstruct also translates the Nname using subst.vars, if +// Nname is in subst.vars. func (subst *subster) tstruct(t *types.Type) *types.Type { if t.NumFields() == 0 { return t @@ -301,7 +384,7 @@ func (subst *subster) tstruct(t *types.Type) *types.Type { var newfields []*types.Field for i, f := range t.Fields().Slice() { t2 := subst.typ(f.Type) - if t2 != f.Type && newfields == nil { + if (t2 != f.Type || f.Nname != nil) && newfields == nil { newfields = make([]*types.Field, t.NumFields()) for j := 0; j < i; j++ { newfields[j] = t.Field(j) @@ -309,6 +392,12 @@ func (subst *subster) tstruct(t *types.Type) *types.Type { } if newfields != nil { newfields[i] = types.NewField(f.Pos, f.Sym, t2) + if f.Nname != nil { + // f.Nname may not be in subst.vars[] if this is + // a function name or a function instantiation type + // that we are translating + newfields[i].Nname = subst.vars[f.Nname.(*ir.Name)] + } } } if newfields != nil { @@ -319,14 +408,14 @@ func (subst *subster) tstruct(t *types.Type) *types.Type { } // instTypeName creates a name for an instantiated type, based on the type args -func instTypeName(name string, targs []ir.Node) string { +func instTypeName(name string, targs []*types.Type) string { b := bytes.NewBufferString(name) b.WriteByte('[') for i, targ := range targs { if i > 0 { b.WriteByte(',') } - b.WriteString(targ.Type().String()) + b.WriteString(targ.String()) } b.WriteByte(']') return b.String() @@ -415,10 +504,17 @@ func (subst *subster) typ(t *types.Type) *types.Type { // Since we've substituted types, we also need to change // the defined name of the type, by removing the old types // (in brackets) from the name, and adding the new types. + + // Translate the type params for this type according to + // the tparam/targs mapping of the function. + neededTargs := make([]*types.Type, len(t.RParams)) + for i, rparam := range t.RParams { + neededTargs[i] = subst.typ(rparam) + } oldname := t.Sym().Name i := strings.Index(oldname, "[") oldname = oldname[:i] - sym := t.Sym().Pkg.Lookup(instTypeName(oldname, subst.targs)) + sym := t.Sym().Pkg.Lookup(instTypeName(oldname, neededTargs)) if sym.Def != nil { // We've already created this instantiated defined type. return sym.Def.Type() diff --git a/test/typeparam/combine.go b/test/typeparam/combine.go new file mode 100644 index 0000000000..d4a2988a7b --- /dev/null +++ b/test/typeparam/combine.go @@ -0,0 +1,65 @@ +// run -gcflags=-G=3 + +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" +) + +type _Gen[A any] func() (A, bool) + +func combine[T1, T2, T any](g1 _Gen[T1], g2 _Gen[T2], join func(T1, T2) T) _Gen[T] { + return func() (T, bool) { + var t T + t1, ok := g1() + if !ok { + return t, false + } + t2, ok := g2() + if !ok { + return t, false + } + return join(t1, t2), true + } +} + +type _Pair[A, B any] struct { + A A + B B +} + +func _NewPair[A, B any](a A, b B) _Pair[A, B] { + return _Pair[A, B]{a, b} +} + +func _Combine2[A, B any](ga _Gen[A], gb _Gen[B]) _Gen[_Pair[A, B]] { + return combine(ga, gb, _NewPair[A, B]) +} + +func main() { + var g1 _Gen[int] = func() (int, bool) { return 3, true } + var g2 _Gen[string] = func() (string, bool) { return "x", false } + var g3 _Gen[string] = func() (string, bool) { return "y", true } + + gc := combine(g1, g2, _NewPair[int, string]) + if got, ok := gc(); ok { + panic(fmt.Sprintf("got %v, %v, wanted -/false", got, ok)) + } + gc2 := _Combine2(g1, g2) + if got, ok := gc2(); ok { + panic(fmt.Sprintf("got %v, %v, wanted -/false", got, ok)) + } + + gc3 := combine(g1, g3, _NewPair[int, string]) + if got, ok := gc3(); !ok || got.A != 3 || got.B != "y" { + panic(fmt.Sprintf("got %v, %v, wanted {3, y}, true", got, ok)) + } + gc4 := _Combine2(g1, g3) + if got, ok := gc4(); !ok || got.A != 3 || got.B != "y" { + panic (fmt.Sprintf("got %v, %v, wanted {3, y}, true", got, ok)) + } +} diff --git a/test/typeparam/settable.go b/test/typeparam/settable.go index 3bd141f784..7532953a77 100644 --- a/test/typeparam/settable.go +++ b/test/typeparam/settable.go @@ -11,7 +11,24 @@ import ( "strconv" ) -func fromStrings3[T any](s []string, set func(*T, string)) []T { +type Setter[B any] interface { + Set(string) + type *B +} + +func fromStrings1[T any, PT Setter[T]](s []string) []T { + result := make([]T, len(s)) + for i, v := range s { + // The type of &result[i] is *T which is in the type list + // of Setter, so we can convert it to PT. + p := PT(&result[i]) + // PT has a Set method. + p.Set(v) + } + return result +} + +func fromStrings2[T any](s []string, set func(*T, string)) []T { results := make([]T, len(s)) for i, v := range s { set(&results[i], v) @@ -30,8 +47,12 @@ func (p *Settable) Set(s string) { } func main() { - s := fromStrings3([]string{"1"}, - func(p *Settable, s string) { p.Set(s) }) + s := fromStrings1[Settable, *Settable]([]string{"1"}) + if len(s) != 1 || s[0] != 1 { + panic(fmt.Sprintf("got %v, want %v", s, []int{1})) + } + + s = fromStrings2([]string{"1"}, func(p *Settable, s string) { p.Set(s) }) if len(s) != 1 || s[0] != 1 { panic(fmt.Sprintf("got %v, want %v", s, []int{1})) } |