diff options
author | KimMachineGun <geon0250@gmail.com> | 2021-03-10 12:55:56 +0000 |
---|---|---|
committer | Emmanuel Odeke <emmanuel@orijtech.com> | 2021-03-10 18:33:52 +0000 |
commit | 5ce51ea74177513b473e2da05d0e4e9b7affb3c4 (patch) | |
tree | ba8940f62fc8d5552e90987faa67f9713b0baa0d /src/flag | |
parent | d0b79e3513a29628f3599dc8860666b6eed75372 (diff) | |
download | go-5ce51ea74177513b473e2da05d0e4e9b7affb3c4.tar.gz go-5ce51ea74177513b473e2da05d0e4e9b7affb3c4.zip |
flag: panic if flag name begins with - or contains =
Fixes #41792
Change-Id: I9b4aae8a899e3c3ac9532d27932d275cfb1fab48
GitHub-Last-Rev: f06b1e17674bf77bdc2d3e798df4c4379748c8d2
GitHub-Pull-Request: golang/go#42737
Reviewed-on: https://go-review.googlesource.com/c/go/+/271788
Reviewed-by: Emmanuel Odeke <emmanuel@orijtech.com>
Reviewed-by: Rob Pike <r@golang.org>
Run-TryBot: Emmanuel Odeke <emmanuel@orijtech.com>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Rob Pike <r@golang.org>
Diffstat (limited to 'src/flag')
-rw-r--r-- | src/flag/flag.go | 24 | ||||
-rw-r--r-- | src/flag/flag_test.go | 83 |
2 files changed, 101 insertions, 6 deletions
diff --git a/src/flag/flag.go b/src/flag/flag.go index a8485f034f..f7598a6758 100644 --- a/src/flag/flag.go +++ b/src/flag/flag.go @@ -857,17 +857,23 @@ func Func(name, usage string, fn func(string) error) { // of strings by giving the slice the methods of Value; in particular, Set would // decompose the comma-separated string into the slice. func (f *FlagSet) Var(value Value, name string, usage string) { + // Flag must not begin "-" or contain "=". + if strings.HasPrefix(name, "-") { + panic(f.sprintf("flag %q begins with -", name)) + } else if strings.Contains(name, "=") { + panic(f.sprintf("flag %q contains =", name)) + } + // Remember the default value as a string; it won't change. flag := &Flag{name, usage, value, value.String()} _, alreadythere := f.formal[name] if alreadythere { var msg string if f.name == "" { - msg = fmt.Sprintf("flag redefined: %s", name) + msg = f.sprintf("flag redefined: %s", name) } else { - msg = fmt.Sprintf("%s flag redefined: %s", f.name, name) + msg = f.sprintf("%s flag redefined: %s", f.name, name) } - fmt.Fprintln(f.Output(), msg) panic(msg) // Happens only if flags are declared with identical names } if f.formal == nil { @@ -886,13 +892,19 @@ func Var(value Value, name string, usage string) { CommandLine.Var(value, name, usage) } +// sprintf formats the message, prints it to output, and returns it. +func (f *FlagSet) sprintf(format string, a ...interface{}) string { + msg := fmt.Sprintf(format, a...) + fmt.Fprintln(f.Output(), msg) + return msg +} + // failf prints to standard error a formatted error and usage message and // returns the error. func (f *FlagSet) failf(format string, a ...interface{}) error { - err := fmt.Errorf(format, a...) - fmt.Fprintln(f.Output(), err) + msg := f.sprintf(format, a...) f.usage() - return err + return errors.New(msg) } // usage calls the Usage method for the flag set if one is specified, diff --git a/src/flag/flag_test.go b/src/flag/flag_test.go index 06cab79405..5835fcf22c 100644 --- a/src/flag/flag_test.go +++ b/src/flag/flag_test.go @@ -655,3 +655,86 @@ func TestExitCode(t *testing.T) { } } } + +func mustPanic(t *testing.T, testName string, expected string, f func()) { + t.Helper() + defer func() { + switch msg := recover().(type) { + case nil: + t.Errorf("%s\n: expected panic(%q), but did not panic", testName, expected) + case string: + if msg != expected { + t.Errorf("%s\n: expected panic(%q), but got panic(%q)", testName, expected, msg) + } + default: + t.Errorf("%s\n: expected panic(%q), but got panic(%T%v)", testName, expected, msg, msg) + } + }() + f() +} + +func TestInvalidFlags(t *testing.T) { + tests := []struct { + flag string + errorMsg string + }{ + { + flag: "-foo", + errorMsg: "flag \"-foo\" begins with -", + }, + { + flag: "foo=bar", + errorMsg: "flag \"foo=bar\" contains =", + }, + } + + for _, test := range tests { + testName := fmt.Sprintf("FlagSet.Var(&v, %q, \"\")", test.flag) + + fs := NewFlagSet("", ContinueOnError) + buf := bytes.NewBuffer(nil) + fs.SetOutput(buf) + + mustPanic(t, testName, test.errorMsg, func() { + var v flagVar + fs.Var(&v, test.flag, "") + }) + if msg := test.errorMsg + "\n"; msg != buf.String() { + t.Errorf("%s\n: unexpected output: expected %q, bug got %q", testName, msg, buf) + } + } +} + +func TestRedefinedFlags(t *testing.T) { + tests := []struct { + flagSetName string + errorMsg string + }{ + { + flagSetName: "", + errorMsg: "flag redefined: foo", + }, + { + flagSetName: "fs", + errorMsg: "fs flag redefined: foo", + }, + } + + for _, test := range tests { + testName := fmt.Sprintf("flag redefined in FlagSet(%q)", test.flagSetName) + + fs := NewFlagSet(test.flagSetName, ContinueOnError) + buf := bytes.NewBuffer(nil) + fs.SetOutput(buf) + + var v flagVar + fs.Var(&v, "foo", "") + + mustPanic(t, testName, test.errorMsg, func() { + fs.Var(&v, "foo", "") + }) + if msg := test.errorMsg + "\n"; msg != buf.String() { + t.Errorf("%s\n: unexpected output: expected %q, bug got %q", testName, msg, buf) + } + } +} |