aboutsummaryrefslogtreecommitdiff
path: root/src/flag
diff options
context:
space:
mode:
authorKimMachineGun <geon0250@gmail.com>2021-03-10 12:55:56 +0000
committerEmmanuel Odeke <emmanuel@orijtech.com>2021-03-10 18:33:52 +0000
commit5ce51ea74177513b473e2da05d0e4e9b7affb3c4 (patch)
treeba8940f62fc8d5552e90987faa67f9713b0baa0d /src/flag
parentd0b79e3513a29628f3599dc8860666b6eed75372 (diff)
downloadgo-5ce51ea74177513b473e2da05d0e4e9b7affb3c4.tar.gz
go-5ce51ea74177513b473e2da05d0e4e9b7affb3c4.zip
flag: panic if flag name begins with - or contains =
Fixes #41792 Change-Id: I9b4aae8a899e3c3ac9532d27932d275cfb1fab48 GitHub-Last-Rev: f06b1e17674bf77bdc2d3e798df4c4379748c8d2 GitHub-Pull-Request: golang/go#42737 Reviewed-on: https://go-review.googlesource.com/c/go/+/271788 Reviewed-by: Emmanuel Odeke <emmanuel@orijtech.com> Reviewed-by: Rob Pike <r@golang.org> Run-TryBot: Emmanuel Odeke <emmanuel@orijtech.com> TryBot-Result: Go Bot <gobot@golang.org> Trust: Rob Pike <r@golang.org>
Diffstat (limited to 'src/flag')
-rw-r--r--src/flag/flag.go24
-rw-r--r--src/flag/flag_test.go83
2 files changed, 101 insertions, 6 deletions
diff --git a/src/flag/flag.go b/src/flag/flag.go
index a8485f034f..f7598a6758 100644
--- a/src/flag/flag.go
+++ b/src/flag/flag.go
@@ -857,17 +857,23 @@ func Func(name, usage string, fn func(string) error) {
// of strings by giving the slice the methods of Value; in particular, Set would
// decompose the comma-separated string into the slice.
func (f *FlagSet) Var(value Value, name string, usage string) {
+ // Flag must not begin "-" or contain "=".
+ if strings.HasPrefix(name, "-") {
+ panic(f.sprintf("flag %q begins with -", name))
+ } else if strings.Contains(name, "=") {
+ panic(f.sprintf("flag %q contains =", name))
+ }
+
// Remember the default value as a string; it won't change.
flag := &Flag{name, usage, value, value.String()}
_, alreadythere := f.formal[name]
if alreadythere {
var msg string
if f.name == "" {
- msg = fmt.Sprintf("flag redefined: %s", name)
+ msg = f.sprintf("flag redefined: %s", name)
} else {
- msg = fmt.Sprintf("%s flag redefined: %s", f.name, name)
+ msg = f.sprintf("%s flag redefined: %s", f.name, name)
}
- fmt.Fprintln(f.Output(), msg)
panic(msg) // Happens only if flags are declared with identical names
}
if f.formal == nil {
@@ -886,13 +892,19 @@ func Var(value Value, name string, usage string) {
CommandLine.Var(value, name, usage)
}
+// sprintf formats the message, prints it to output, and returns it.
+func (f *FlagSet) sprintf(format string, a ...interface{}) string {
+ msg := fmt.Sprintf(format, a...)
+ fmt.Fprintln(f.Output(), msg)
+ return msg
+}
+
// failf prints to standard error a formatted error and usage message and
// returns the error.
func (f *FlagSet) failf(format string, a ...interface{}) error {
- err := fmt.Errorf(format, a...)
- fmt.Fprintln(f.Output(), err)
+ msg := f.sprintf(format, a...)
f.usage()
- return err
+ return errors.New(msg)
}
// usage calls the Usage method for the flag set if one is specified,
diff --git a/src/flag/flag_test.go b/src/flag/flag_test.go
index 06cab79405..5835fcf22c 100644
--- a/src/flag/flag_test.go
+++ b/src/flag/flag_test.go
@@ -655,3 +655,86 @@ func TestExitCode(t *testing.T) {
}
}
}
+
+func mustPanic(t *testing.T, testName string, expected string, f func()) {
+ t.Helper()
+ defer func() {
+ switch msg := recover().(type) {
+ case nil:
+ t.Errorf("%s\n: expected panic(%q), but did not panic", testName, expected)
+ case string:
+ if msg != expected {
+ t.Errorf("%s\n: expected panic(%q), but got panic(%q)", testName, expected, msg)
+ }
+ default:
+ t.Errorf("%s\n: expected panic(%q), but got panic(%T%v)", testName, expected, msg, msg)
+ }
+ }()
+ f()
+}
+
+func TestInvalidFlags(t *testing.T) {
+ tests := []struct {
+ flag string
+ errorMsg string
+ }{
+ {
+ flag: "-foo",
+ errorMsg: "flag \"-foo\" begins with -",
+ },
+ {
+ flag: "foo=bar",
+ errorMsg: "flag \"foo=bar\" contains =",
+ },
+ }
+
+ for _, test := range tests {
+ testName := fmt.Sprintf("FlagSet.Var(&v, %q, \"\")", test.flag)
+
+ fs := NewFlagSet("", ContinueOnError)
+ buf := bytes.NewBuffer(nil)
+ fs.SetOutput(buf)
+
+ mustPanic(t, testName, test.errorMsg, func() {
+ var v flagVar
+ fs.Var(&v, test.flag, "")
+ })
+ if msg := test.errorMsg + "\n"; msg != buf.String() {
+ t.Errorf("%s\n: unexpected output: expected %q, bug got %q", testName, msg, buf)
+ }
+ }
+}
+
+func TestRedefinedFlags(t *testing.T) {
+ tests := []struct {
+ flagSetName string
+ errorMsg string
+ }{
+ {
+ flagSetName: "",
+ errorMsg: "flag redefined: foo",
+ },
+ {
+ flagSetName: "fs",
+ errorMsg: "fs flag redefined: foo",
+ },
+ }
+
+ for _, test := range tests {
+ testName := fmt.Sprintf("flag redefined in FlagSet(%q)", test.flagSetName)
+
+ fs := NewFlagSet(test.flagSetName, ContinueOnError)
+ buf := bytes.NewBuffer(nil)
+ fs.SetOutput(buf)
+
+ var v flagVar
+ fs.Var(&v, "foo", "")
+
+ mustPanic(t, testName, test.errorMsg, func() {
+ fs.Var(&v, "foo", "")
+ })
+ if msg := test.errorMsg + "\n"; msg != buf.String() {
+ t.Errorf("%s\n: unexpected output: expected %q, bug got %q", testName, msg, buf)
+ }
+ }
+}