aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRobert Griesemer <gri@golang.org>2021-03-03 18:33:45 -0800
committerRobert Griesemer <gri@golang.org>2021-03-04 22:20:29 +0000
commita99ff24a26939f30440dd0f06dce426ed5e638ee (patch)
tree130202788ccb64a05f53351f985780cc03d71b3c
parent9d3718e834fcf5b602b84539364606445cfc8a1a (diff)
downloadgo-a99ff24a26939f30440dd0f06dce426ed5e638ee.tar.gz
go-a99ff24a26939f30440dd0f06dce426ed5e638ee.zip
cmd/compile/internal/syntax: print type parameters and type lists
types2 uses the syntax printer to print expressions (for tracing or error messages), so we need to (at least) print type lists in interfaces. While at it, also implement the printing of type parameter lists. Fixes #44766. Change-Id: I36a4a7152d9bef7251af264b5c7890aca88d8dc3 Reviewed-on: https://go-review.googlesource.com/c/go/+/298549 Trust: Robert Griesemer <gri@golang.org> Reviewed-by: Robert Findley <rfindley@google.com>
-rw-r--r--src/cmd/compile/internal/syntax/printer.go57
-rw-r--r--src/cmd/compile/internal/syntax/printer_test.go37
2 files changed, 79 insertions, 15 deletions
diff --git a/src/cmd/compile/internal/syntax/printer.go b/src/cmd/compile/internal/syntax/printer.go
index 9109ce2363..e557f5d924 100644
--- a/src/cmd/compile/internal/syntax/printer.go
+++ b/src/cmd/compile/internal/syntax/printer.go
@@ -481,10 +481,10 @@ func (p *printer) printRawNode(n Node) {
if len(n.FieldList) > 0 {
if p.linebreaks {
p.print(newline, indent)
- p.printFieldList(n.FieldList, n.TagList)
+ p.printFieldList(n.FieldList, n.TagList, _Semi)
p.print(outdent, newline)
} else {
- p.printFieldList(n.FieldList, n.TagList)
+ p.printFieldList(n.FieldList, n.TagList, _Semi)
}
}
p.print(_Rbrace)
@@ -494,20 +494,40 @@ func (p *printer) printRawNode(n Node) {
p.printSignature(n)
case *InterfaceType:
+ // separate type list and method list
+ var types []Expr
+ var methods []*Field
+ for _, f := range n.MethodList {
+ if f.Name != nil && f.Name.Value == "type" {
+ types = append(types, f.Type)
+ } else {
+ // method or embedded interface
+ methods = append(methods, f)
+ }
+ }
+
+ multiLine := len(n.MethodList) > 0 && p.linebreaks
p.print(_Interface)
- if len(n.MethodList) > 0 && p.linebreaks {
+ if multiLine {
p.print(blank)
}
p.print(_Lbrace)
- if len(n.MethodList) > 0 {
- if p.linebreaks {
- p.print(newline, indent)
- p.printMethodList(n.MethodList)
- p.print(outdent, newline)
- } else {
- p.printMethodList(n.MethodList)
+ if multiLine {
+ p.print(newline, indent)
+ }
+ if len(types) > 0 {
+ p.print(_Type, blank)
+ p.printExprList(types)
+ if len(methods) > 0 {
+ p.print(_Semi, blank)
}
}
+ if len(methods) > 0 {
+ p.printMethodList(methods)
+ }
+ if multiLine {
+ p.print(outdent, newline)
+ }
p.print(_Rbrace)
case *MapType:
@@ -667,7 +687,13 @@ func (p *printer) printRawNode(n Node) {
if n.Group == nil {
p.print(_Type, blank)
}
- p.print(n.Name, blank)
+ p.print(n.Name)
+ if n.TParamList != nil {
+ p.print(_Lbrack)
+ p.printFieldList(n.TParamList, nil, _Comma)
+ p.print(_Rbrack)
+ }
+ p.print(blank)
if n.Alias {
p.print(_Assign, blank)
}
@@ -696,6 +722,11 @@ func (p *printer) printRawNode(n Node) {
p.print(_Rparen, blank)
}
p.print(n.Name)
+ if n.TParamList != nil {
+ p.print(_Lbrack)
+ p.printFieldList(n.TParamList, nil, _Comma)
+ p.print(_Rbrack)
+ }
p.printSignature(n.Type)
if n.Body != nil {
p.print(blank, n.Body)
@@ -746,14 +777,14 @@ func (p *printer) printFields(fields []*Field, tags []*BasicLit, i, j int) {
}
}
-func (p *printer) printFieldList(fields []*Field, tags []*BasicLit) {
+func (p *printer) printFieldList(fields []*Field, tags []*BasicLit, sep token) {
i0 := 0
var typ Expr
for i, f := range fields {
if f.Name == nil || f.Type != typ {
if i0 < i {
p.printFields(fields, tags, i0, i)
- p.print(_Semi, newline)
+ p.print(sep, newline)
i0 = i
}
typ = f.Type
diff --git a/src/cmd/compile/internal/syntax/printer_test.go b/src/cmd/compile/internal/syntax/printer_test.go
index bcae815a46..4890327595 100644
--- a/src/cmd/compile/internal/syntax/printer_test.go
+++ b/src/cmd/compile/internal/syntax/printer_test.go
@@ -61,6 +61,21 @@ var stringTests = []string{
"package p",
"package p; type _ int; type T1 = struct{}; type ( _ *struct{}; T2 = float32 )",
+ // generic type declarations
+ "package p; type _[T any] struct{}",
+ "package p; type _[A, B, C interface{m()}] struct{}",
+ "package p; type _[T any, A, B, C interface{m()}, X, Y, Z interface{type int}] struct{}",
+
+ // generic function declarations
+ "package p; func _[T any]()",
+ "package p; func _[A, B, C interface{m()}]()",
+ "package p; func _[T any, A, B, C interface{m()}, X, Y, Z interface{type int}]()",
+
+ // methods with generic receiver types
+ "package p; func (R[T]) _()",
+ "package p; func (*R[A, B, C]) _()",
+ "package p; func (_ *R[A, B, C]) _()",
+
// channels
"package p; type _ chan chan int",
"package p; type _ chan (<-chan int)",
@@ -79,7 +94,7 @@ var stringTests = []string{
func TestPrintString(t *testing.T) {
for _, want := range stringTests {
- ast, err := Parse(nil, strings.NewReader(want), nil, nil, 0)
+ ast, err := Parse(nil, strings.NewReader(want), nil, nil, AllowGenerics)
if err != nil {
t.Error(err)
continue
@@ -116,6 +131,24 @@ var exprTests = [][2]string{
{"func(x int) complex128 { return 0 }", "func(x int) complex128 {…}"},
{"[]int{1, 2, 3}", "[]int{…}"},
+ // type expressions
+ dup("[1 << 10]byte"),
+ dup("[]int"),
+ dup("*int"),
+ dup("struct{x int}"),
+ dup("func()"),
+ dup("func(int, float32) string"),
+ dup("interface{m()}"),
+ dup("interface{m() string; n(x int)}"),
+ dup("interface{type int}"),
+ dup("interface{type int, float64, string}"),
+ dup("interface{type int; m()}"),
+ dup("interface{type int, float64, string; m() string; n(x int)}"),
+ dup("map[string]int"),
+ dup("chan E"),
+ dup("<-chan E"),
+ dup("chan<- E"),
+
// non-type expressions
dup("(x)"),
dup("x.f"),
@@ -172,7 +205,7 @@ var exprTests = [][2]string{
func TestShortString(t *testing.T) {
for _, test := range exprTests {
src := "package p; var _ = " + test[0]
- ast, err := Parse(nil, strings.NewReader(src), nil, nil, 0)
+ ast, err := Parse(nil, strings.NewReader(src), nil, nil, AllowGenerics)
if err != nil {
t.Errorf("%s: %s", test[0], err)
continue