aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/cmd/compile/internal/syntax/printer.go57
-rw-r--r--src/cmd/compile/internal/syntax/printer_test.go37
2 files changed, 79 insertions, 15 deletions
diff --git a/src/cmd/compile/internal/syntax/printer.go b/src/cmd/compile/internal/syntax/printer.go
index 9109ce2363..e557f5d924 100644
--- a/src/cmd/compile/internal/syntax/printer.go
+++ b/src/cmd/compile/internal/syntax/printer.go
@@ -481,10 +481,10 @@ func (p *printer) printRawNode(n Node) {
if len(n.FieldList) > 0 {
if p.linebreaks {
p.print(newline, indent)
- p.printFieldList(n.FieldList, n.TagList)
+ p.printFieldList(n.FieldList, n.TagList, _Semi)
p.print(outdent, newline)
} else {
- p.printFieldList(n.FieldList, n.TagList)
+ p.printFieldList(n.FieldList, n.TagList, _Semi)
}
}
p.print(_Rbrace)
@@ -494,20 +494,40 @@ func (p *printer) printRawNode(n Node) {
p.printSignature(n)
case *InterfaceType:
+ // separate type list and method list
+ var types []Expr
+ var methods []*Field
+ for _, f := range n.MethodList {
+ if f.Name != nil && f.Name.Value == "type" {
+ types = append(types, f.Type)
+ } else {
+ // method or embedded interface
+ methods = append(methods, f)
+ }
+ }
+
+ multiLine := len(n.MethodList) > 0 && p.linebreaks
p.print(_Interface)
- if len(n.MethodList) > 0 && p.linebreaks {
+ if multiLine {
p.print(blank)
}
p.print(_Lbrace)
- if len(n.MethodList) > 0 {
- if p.linebreaks {
- p.print(newline, indent)
- p.printMethodList(n.MethodList)
- p.print(outdent, newline)
- } else {
- p.printMethodList(n.MethodList)
+ if multiLine {
+ p.print(newline, indent)
+ }
+ if len(types) > 0 {
+ p.print(_Type, blank)
+ p.printExprList(types)
+ if len(methods) > 0 {
+ p.print(_Semi, blank)
}
}
+ if len(methods) > 0 {
+ p.printMethodList(methods)
+ }
+ if multiLine {
+ p.print(outdent, newline)
+ }
p.print(_Rbrace)
case *MapType:
@@ -667,7 +687,13 @@ func (p *printer) printRawNode(n Node) {
if n.Group == nil {
p.print(_Type, blank)
}
- p.print(n.Name, blank)
+ p.print(n.Name)
+ if n.TParamList != nil {
+ p.print(_Lbrack)
+ p.printFieldList(n.TParamList, nil, _Comma)
+ p.print(_Rbrack)
+ }
+ p.print(blank)
if n.Alias {
p.print(_Assign, blank)
}
@@ -696,6 +722,11 @@ func (p *printer) printRawNode(n Node) {
p.print(_Rparen, blank)
}
p.print(n.Name)
+ if n.TParamList != nil {
+ p.print(_Lbrack)
+ p.printFieldList(n.TParamList, nil, _Comma)
+ p.print(_Rbrack)
+ }
p.printSignature(n.Type)
if n.Body != nil {
p.print(blank, n.Body)
@@ -746,14 +777,14 @@ func (p *printer) printFields(fields []*Field, tags []*BasicLit, i, j int) {
}
}
-func (p *printer) printFieldList(fields []*Field, tags []*BasicLit) {
+func (p *printer) printFieldList(fields []*Field, tags []*BasicLit, sep token) {
i0 := 0
var typ Expr
for i, f := range fields {
if f.Name == nil || f.Type != typ {
if i0 < i {
p.printFields(fields, tags, i0, i)
- p.print(_Semi, newline)
+ p.print(sep, newline)
i0 = i
}
typ = f.Type
diff --git a/src/cmd/compile/internal/syntax/printer_test.go b/src/cmd/compile/internal/syntax/printer_test.go
index bcae815a46..4890327595 100644
--- a/src/cmd/compile/internal/syntax/printer_test.go
+++ b/src/cmd/compile/internal/syntax/printer_test.go
@@ -61,6 +61,21 @@ var stringTests = []string{
"package p",
"package p; type _ int; type T1 = struct{}; type ( _ *struct{}; T2 = float32 )",
+ // generic type declarations
+ "package p; type _[T any] struct{}",
+ "package p; type _[A, B, C interface{m()}] struct{}",
+ "package p; type _[T any, A, B, C interface{m()}, X, Y, Z interface{type int}] struct{}",
+
+ // generic function declarations
+ "package p; func _[T any]()",
+ "package p; func _[A, B, C interface{m()}]()",
+ "package p; func _[T any, A, B, C interface{m()}, X, Y, Z interface{type int}]()",
+
+ // methods with generic receiver types
+ "package p; func (R[T]) _()",
+ "package p; func (*R[A, B, C]) _()",
+ "package p; func (_ *R[A, B, C]) _()",
+
// channels
"package p; type _ chan chan int",
"package p; type _ chan (<-chan int)",
@@ -79,7 +94,7 @@ var stringTests = []string{
func TestPrintString(t *testing.T) {
for _, want := range stringTests {
- ast, err := Parse(nil, strings.NewReader(want), nil, nil, 0)
+ ast, err := Parse(nil, strings.NewReader(want), nil, nil, AllowGenerics)
if err != nil {
t.Error(err)
continue
@@ -116,6 +131,24 @@ var exprTests = [][2]string{
{"func(x int) complex128 { return 0 }", "func(x int) complex128 {…}"},
{"[]int{1, 2, 3}", "[]int{…}"},
+ // type expressions
+ dup("[1 << 10]byte"),
+ dup("[]int"),
+ dup("*int"),
+ dup("struct{x int}"),
+ dup("func()"),
+ dup("func(int, float32) string"),
+ dup("interface{m()}"),
+ dup("interface{m() string; n(x int)}"),
+ dup("interface{type int}"),
+ dup("interface{type int, float64, string}"),
+ dup("interface{type int; m()}"),
+ dup("interface{type int, float64, string; m() string; n(x int)}"),
+ dup("map[string]int"),
+ dup("chan E"),
+ dup("<-chan E"),
+ dup("chan<- E"),
+
// non-type expressions
dup("(x)"),
dup("x.f"),
@@ -172,7 +205,7 @@ var exprTests = [][2]string{
func TestShortString(t *testing.T) {
for _, test := range exprTests {
src := "package p; var _ = " + test[0]
- ast, err := Parse(nil, strings.NewReader(src), nil, nil, 0)
+ ast, err := Parse(nil, strings.NewReader(src), nil, nil, AllowGenerics)
if err != nil {
t.Errorf("%s: %s", test[0], err)
continue