diff options
-rw-r--r-- | src/cmd/compile/internal/syntax/printer.go | 57 | ||||
-rw-r--r-- | src/cmd/compile/internal/syntax/printer_test.go | 37 |
2 files changed, 79 insertions, 15 deletions
diff --git a/src/cmd/compile/internal/syntax/printer.go b/src/cmd/compile/internal/syntax/printer.go index 9109ce2363..e557f5d924 100644 --- a/src/cmd/compile/internal/syntax/printer.go +++ b/src/cmd/compile/internal/syntax/printer.go @@ -481,10 +481,10 @@ func (p *printer) printRawNode(n Node) { if len(n.FieldList) > 0 { if p.linebreaks { p.print(newline, indent) - p.printFieldList(n.FieldList, n.TagList) + p.printFieldList(n.FieldList, n.TagList, _Semi) p.print(outdent, newline) } else { - p.printFieldList(n.FieldList, n.TagList) + p.printFieldList(n.FieldList, n.TagList, _Semi) } } p.print(_Rbrace) @@ -494,20 +494,40 @@ func (p *printer) printRawNode(n Node) { p.printSignature(n) case *InterfaceType: + // separate type list and method list + var types []Expr + var methods []*Field + for _, f := range n.MethodList { + if f.Name != nil && f.Name.Value == "type" { + types = append(types, f.Type) + } else { + // method or embedded interface + methods = append(methods, f) + } + } + + multiLine := len(n.MethodList) > 0 && p.linebreaks p.print(_Interface) - if len(n.MethodList) > 0 && p.linebreaks { + if multiLine { p.print(blank) } p.print(_Lbrace) - if len(n.MethodList) > 0 { - if p.linebreaks { - p.print(newline, indent) - p.printMethodList(n.MethodList) - p.print(outdent, newline) - } else { - p.printMethodList(n.MethodList) + if multiLine { + p.print(newline, indent) + } + if len(types) > 0 { + p.print(_Type, blank) + p.printExprList(types) + if len(methods) > 0 { + p.print(_Semi, blank) } } + if len(methods) > 0 { + p.printMethodList(methods) + } + if multiLine { + p.print(outdent, newline) + } p.print(_Rbrace) case *MapType: @@ -667,7 +687,13 @@ func (p *printer) printRawNode(n Node) { if n.Group == nil { p.print(_Type, blank) } - p.print(n.Name, blank) + p.print(n.Name) + if n.TParamList != nil { + p.print(_Lbrack) + p.printFieldList(n.TParamList, nil, _Comma) + p.print(_Rbrack) + } + p.print(blank) if n.Alias { p.print(_Assign, blank) } @@ -696,6 +722,11 @@ func (p *printer) printRawNode(n Node) { p.print(_Rparen, blank) } p.print(n.Name) + if n.TParamList != nil { + p.print(_Lbrack) + p.printFieldList(n.TParamList, nil, _Comma) + p.print(_Rbrack) + } p.printSignature(n.Type) if n.Body != nil { p.print(blank, n.Body) @@ -746,14 +777,14 @@ func (p *printer) printFields(fields []*Field, tags []*BasicLit, i, j int) { } } -func (p *printer) printFieldList(fields []*Field, tags []*BasicLit) { +func (p *printer) printFieldList(fields []*Field, tags []*BasicLit, sep token) { i0 := 0 var typ Expr for i, f := range fields { if f.Name == nil || f.Type != typ { if i0 < i { p.printFields(fields, tags, i0, i) - p.print(_Semi, newline) + p.print(sep, newline) i0 = i } typ = f.Type diff --git a/src/cmd/compile/internal/syntax/printer_test.go b/src/cmd/compile/internal/syntax/printer_test.go index bcae815a46..4890327595 100644 --- a/src/cmd/compile/internal/syntax/printer_test.go +++ b/src/cmd/compile/internal/syntax/printer_test.go @@ -61,6 +61,21 @@ var stringTests = []string{ "package p", "package p; type _ int; type T1 = struct{}; type ( _ *struct{}; T2 = float32 )", + // generic type declarations + "package p; type _[T any] struct{}", + "package p; type _[A, B, C interface{m()}] struct{}", + "package p; type _[T any, A, B, C interface{m()}, X, Y, Z interface{type int}] struct{}", + + // generic function declarations + "package p; func _[T any]()", + "package p; func _[A, B, C interface{m()}]()", + "package p; func _[T any, A, B, C interface{m()}, X, Y, Z interface{type int}]()", + + // methods with generic receiver types + "package p; func (R[T]) _()", + "package p; func (*R[A, B, C]) _()", + "package p; func (_ *R[A, B, C]) _()", + // channels "package p; type _ chan chan int", "package p; type _ chan (<-chan int)", @@ -79,7 +94,7 @@ var stringTests = []string{ func TestPrintString(t *testing.T) { for _, want := range stringTests { - ast, err := Parse(nil, strings.NewReader(want), nil, nil, 0) + ast, err := Parse(nil, strings.NewReader(want), nil, nil, AllowGenerics) if err != nil { t.Error(err) continue @@ -116,6 +131,24 @@ var exprTests = [][2]string{ {"func(x int) complex128 { return 0 }", "func(x int) complex128 {…}"}, {"[]int{1, 2, 3}", "[]int{…}"}, + // type expressions + dup("[1 << 10]byte"), + dup("[]int"), + dup("*int"), + dup("struct{x int}"), + dup("func()"), + dup("func(int, float32) string"), + dup("interface{m()}"), + dup("interface{m() string; n(x int)}"), + dup("interface{type int}"), + dup("interface{type int, float64, string}"), + dup("interface{type int; m()}"), + dup("interface{type int, float64, string; m() string; n(x int)}"), + dup("map[string]int"), + dup("chan E"), + dup("<-chan E"), + dup("chan<- E"), + // non-type expressions dup("(x)"), dup("x.f"), @@ -172,7 +205,7 @@ var exprTests = [][2]string{ func TestShortString(t *testing.T) { for _, test := range exprTests { src := "package p; var _ = " + test[0] - ast, err := Parse(nil, strings.NewReader(src), nil, nil, 0) + ast, err := Parse(nil, strings.NewReader(src), nil, nil, AllowGenerics) if err != nil { t.Errorf("%s: %s", test[0], err) continue |