diff options
author | Robin Jarry <robin@jarry.cc> | 2023-10-03 15:48:15 +0200 |
---|---|---|
committer | Robin Jarry <robin@jarry.cc> | 2023-10-03 22:34:39 +0200 |
commit | ef69ca8707eb7072241c987527166b254378c1b5 (patch) | |
tree | 9b7ad5848830c9d83a8353c34517182b20e824d4 | |
parent | bbe5e81538a537a3c65d9b94e0dbb1d4cd1604df (diff) | |
download | aerc-ef69ca8707eb7072241c987527166b254378c1b5.tar.gz aerc-ef69ca8707eb7072241c987527166b254378c1b5.zip |
lib: add opt parser based on struct tags
Signed-off-by: Robin Jarry <robin@jarry.cc>
-rw-r--r-- | lib/opt/opt.go | 28 | ||||
-rw-r--r-- | lib/opt/spec.go | 475 | ||||
-rw-r--r-- | lib/opt/spec_test.go | 87 |
3 files changed, 590 insertions, 0 deletions
diff --git a/lib/opt/opt.go b/lib/opt/opt.go new file mode 100644 index 00000000..a671394c --- /dev/null +++ b/lib/opt/opt.go @@ -0,0 +1,28 @@ +package opt + +import ( + "errors" + "fmt" +) + +func CmdlineToStruct(cmdline string, v any) error { + args, err := SplitArgs(cmdline) + if err != nil { + return err + } + return ArgsToStruct(args, v) +} + +func ArgsToStruct(args *Args, v any) error { + if args.Count() == 0 { + return errors.New("empty command") + } + + cmd := NewCmdSpec(args.Arg(0), v) + err := cmd.ParseArgs(args) + if err != nil { + return fmt.Errorf("%w. %s", err, cmd.Usage()) + } + + return nil +} diff --git a/lib/opt/spec.go b/lib/opt/spec.go new file mode 100644 index 00000000..2cdb1be6 --- /dev/null +++ b/lib/opt/spec.go @@ -0,0 +1,475 @@ +package opt + +import ( + "fmt" + "reflect" + "regexp" + "strconv" + "strings" +) + +type optKind int + +const ( + unset optKind = iota + // flag without a value: "-f" or "--foo-baz" + flag + // flag with a value: "-f bar", "--foo-baz bar" or "--foo-baz=bar" + option + // positional argument after interpreting shell quotes + positional + // remaining positional arguments after interpreting shell quotes + remainderSplit + // remaining positional arguments without interpreting shell quotes + remainder +) + +// Command line options specifier +type CmdSpec struct { + // first argument, for Usage() generation + name string + // list of option specs extracted from struct tags + opts []optSpec + // indexes of short options in the list for quick access + shortOpts map[string]int + // indexes of long options in the list for quick access + longOpts map[string]int + // indexes of positional arguments + positionals []int +} + +// Option or argument specifier +type optSpec struct { + // kind of option/argument + kind optKind + // "f", "foo-baz" (only when kind is flag or option) + short, long string + // name of option/argument value in usage help + metavar string + // default string value before interpretation + defval string + // argument is required + required bool + // argument was processed + seen bool + // custom parse method + parse reflect.Value + // destination struct field + dest reflect.Value +} + +var ( + shortOptRe = regexp.MustCompile(`^-([a-zA-Z0-9])$`) + longOptRe = regexp.MustCompile(`^--([a-zA-Z0-9][a-zA-Z0-9-]+[a-zA-Z0-9])$`) + positionalRe = regexp.MustCompile(`^([a-zA-Z]\w*)$`) +) + +// Interpret all struct fields to a list of option specs +func NewCmdSpec(name string, v any) *CmdSpec { + typ := reflect.TypeOf(v) + val := reflect.ValueOf(v) + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + val = val.Elem() + } else { + panic("NewCmdSpec requires a pointer") + } + if typ.Kind() != reflect.Struct { + panic("NewCmdSpec requires a pointer to a struct") + } + + cmd := &CmdSpec{ + name: name, + opts: make([]optSpec, 0, typ.NumField()), + shortOpts: make(map[string]int), + longOpts: make(map[string]int), + } + + allPositionals := false + for i := 0; i < typ.NumField(); i++ { + var spec optSpec + spec.parseField(reflect.ValueOf(v), typ.Field(i)) + if spec.kind == unset { + // ignored field + continue + } + spec.dest = val.Field(i) + + if allPositionals { + panic(`opt:"..." must be last`) + } + + switch spec.kind { + case flag, option: + if spec.short != "" { + cmd.shortOpts[spec.short] = len(cmd.opts) + } + if spec.long != "" { + cmd.longOpts[spec.long] = len(cmd.opts) + } + case remainder, remainderSplit: + allPositionals = true + fallthrough + default: + cmd.positionals = append(cmd.positionals, len(cmd.opts)) + } + cmd.opts = append(cmd.opts, spec) + } + + return cmd +} + +func (spec *optSpec) parseField(struc reflect.Value, t reflect.StructField) { + abort := func(msg string, args ...any) { + msg = fmt.Sprintf(msg, args...) + panic(fmt.Sprintf("%s.%s: %s", struc.Type(), t.Name, msg)) + } + + // check what kind of argument this field maps to + opt := t.Tag.Get("opt") + + switch { + case opt == "" || opt == "-": + // ignored field + return + + case opt == "...": + // remainder + switch t.Type.Kind() { + case reflect.Slice: + if t.Type.Elem().Kind() != reflect.String { + abort("'...' only works with []string") + } + spec.kind = remainderSplit + case reflect.String: + spec.kind = remainder + default: + abort("'...' only works with string or []string") + } + spec.metavar = strings.ToUpper(t.Name) + spec.required = true + + case strings.Contains(opt, "-"): + // flag or option + for _, flag := range strings.Split(opt, ",") { + m := longOptRe.FindStringSubmatch(flag) + if m != nil { + spec.long = m[1] + continue + } + m = shortOptRe.FindStringSubmatch(flag) + if m != nil { + spec.short = m[1] + continue + } + abort("invalid opt tag: %q", opt) + } + if t.Type.Kind() == reflect.Bool { + spec.kind = flag + } else { + spec.kind = option + spec.metavar = strings.ToUpper(t.Name) + } + if spec.short == "" && spec.long == "" { + abort("invalid opt tag: %q", opt) + } + + case positionalRe.MatchString(opt): + // named positional + spec.kind = positional + spec.metavar = strings.ToUpper(opt) + spec.required = true + + default: + abort("invalid opt tag: %q", opt) + } + + if metavar, hasMetavar := t.Tag.Lookup("metavar"); hasMetavar { + // explicit metavar for the generated usage + spec.metavar = metavar + } + + spec.defval = t.Tag.Get("default") + + switch t.Tag.Get("required") { + case "true": + spec.required = true + case "false": + spec.required = false + case "": + if spec.defval != "" { + spec.required = false + } + default: + abort("invalid required value") + } + + if methodName, found := t.Tag.Lookup("parse"); found { + method := struc.MethodByName(methodName) + if !method.IsValid() { + abort("parse method not found: (*%s).%s", struc, methodName) + } + + ok := method.Type().NumIn() == 1 + ok = ok && method.Type().In(0).Kind() == reflect.String + ok = ok && method.Type().NumOut() == 1 + ok = ok && method.Type().Out(0).Kind() == reflect.Interface + ok = ok && method.Type().Out(0).Name() == "error" + if !ok { + abort("(*%s).%s: invalid signature, expected func(string) (error)", + struc.Elem().Type().Name(), methodName, + t.Type.Kind()) + } + spec.parse = method + } + + if !spec.parse.IsValid() { + switch t.Type.Kind() { + case reflect.String: + fallthrough + case reflect.Bool: + fallthrough + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + fallthrough + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + fallthrough + case reflect.Float32, reflect.Float64: + break + case reflect.Slice: + break + default: + abort("unsupported field type: %s", t.Type.Kind()) + } + } +} + +func (s *optSpec) Usage() string { + var usage string + + switch s.kind { + case flag: + if s.short != "" { + usage = "-" + s.short + } else { + usage = "--" + s.long + } + + case option: + if s.short != "" { + usage = "-" + s.short + } else { + usage = "--" + s.long + } + usage += " " + s.metavar + + case positional: + usage = s.metavar + + case remainder, remainderSplit: + usage = s.metavar + "..." + } + + if s.required { + return usage + } + return "[" + usage + "]" +} + +func (s *optSpec) Flag() string { + var usage string + + switch s.kind { + case flag, option: + if s.short != "" { + usage = "-" + s.short + } else { + usage = "--" + s.long + } + default: + usage = s.metavar + } + + return usage +} + +func (c *CmdSpec) Usage() string { + args := make([]string, len(c.opts)+2) + args[0] = "Usage:" + args[1] = c.name + for i := 0; i < len(c.opts); i++ { + args[i+2] = c.opts[i].Usage() + } + return strings.Join(args, " ") +} + +func (c *CmdSpec) ParseArgs(args *Args) error { + var spec *optSpec + var i int + + for i = 1; i < args.Count(); i++ { + arg := args.Arg(i) + + switch { + case len(arg) > 2 && strings.HasPrefix(arg, "--"): + // --long-flag + if spec != nil { + return fmt.Errorf("%s requires a value", spec.Flag()) + } + arg = arg[2:] + if o, ok := c.longOpts[arg]; ok { + spec = &c.opts[o] + if spec.kind == flag { + spec.dest.SetBool(true) + spec.seen = true + spec = nil + } + continue + } + if strings.Contains(arg, "=") { + arg, val, _ := strings.Cut(arg, "=") + if o, ok := c.longOpts[arg]; ok { + spec = &c.opts[o] + if spec.kind == flag { + return fmt.Errorf("%s does not take a value", spec.Flag()) + } + err := spec.parseValue(val) + if err != nil { + return fmt.Errorf("%s: %w", spec.Flag(), err) + } + spec.seen = true + spec = nil + continue + } + } + return fmt.Errorf("unknown flag --%s", arg) + + case len(arg) > 1 && strings.HasPrefix(arg, "-"): + // -x (short flag) + if spec != nil { + return fmt.Errorf("%s requires a value", spec.Flag()) + } + arg = arg[1:] + for len(arg) > 0 { + f := arg[:1] + arg = arg[1:] + if o, ok := c.shortOpts[f]; ok { + spec = &c.opts[o] + if spec.kind == flag { + spec.dest.SetBool(true) + spec.seen = true + spec = nil + } else if len(arg) > 0 { + err := spec.parseValue(arg) + if err != nil { + return fmt.Errorf("%s: %w", spec.Flag(), err) + } + spec.seen = true + spec = nil + arg = "" + } + } else { + return fmt.Errorf("unknown flag -%s", f) + } + } + + default: + // positional + if spec != nil { + if err := spec.parseValue(arg); err != nil { + return fmt.Errorf("%s: %w", spec.Flag(), err) + } + spec.seen = true + spec = nil + continue + } + if len(c.positionals) == 0 { + return fmt.Errorf("unexpected argument %q", arg) + } + spec = &c.opts[c.positionals[0]] + c.positionals = c.positionals[1:] + + switch spec.kind { + case remainder: + args.Shift(i) + spec.dest.SetString(args.String()) + i = args.Count() + + case remainderSplit: + args.Shift(i) + spec.dest.Set(reflect.ValueOf(args.Args())) + i = args.Count() + + default: + err := spec.parseValue(arg) + if err != nil { + return fmt.Errorf("%s: %w", spec.Flag(), err) + } + } + + spec.seen = true + spec = nil + } + } + + if spec != nil { + return fmt.Errorf("%s requires a value", spec.Flag()) + } + + for i = 0; i < len(c.opts); i++ { + spec := &c.opts[i] + if !spec.seen && spec.defval != "" { + err := spec.parseValue(spec.defval) + if err != nil { + return fmt.Errorf("%s: %w", spec.Flag(), err) + } + spec.seen = true + } + if spec.required && !spec.seen { + return fmt.Errorf("%s is required", spec.Flag()) + } + } + + return nil +} + +func (s *optSpec) parseValue(arg string) error { + if s.parse.IsValid() { + in := []reflect.Value{reflect.ValueOf(arg)} + out := s.parse.Call(in) + err, _ := out[0].Interface().(error) + return err + } + + switch s.dest.Type().Kind() { + case reflect.String: + s.dest.SetString(arg) + case reflect.Bool: + if b, err := strconv.ParseBool(arg); err == nil { + s.dest.SetBool(b) + } else { + return err + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if i, err := strconv.ParseInt(arg, 10, 64); err == nil { + s.dest.SetInt(i) + } else { + return err + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if u, err := strconv.ParseUint(arg, 10, 64); err == nil { + s.dest.SetUint(u) + } else { + return err + } + case reflect.Float32, reflect.Float64: + if f, err := strconv.ParseFloat(arg, 64); err == nil { + s.dest.SetFloat(f) + } else { + return err + } + default: + return fmt.Errorf("unsupported type: %s", s.dest) + } + + return nil +} diff --git a/lib/opt/spec_test.go b/lib/opt/spec_test.go new file mode 100644 index 00000000..eec9db4f --- /dev/null +++ b/lib/opt/spec_test.go @@ -0,0 +1,87 @@ +package opt + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type OptionStruct struct { + Jobs int `opt:"-j,--jobs" required:"true"` + Delay float64 `opt:"--delay" default:"0.5"` + Zero bool `opt:"-z"` + Backoff bool `opt:"-B,--backoff"` + Name string `opt:"NAME"` +} + +func TestArgsToStructErrors(t *testing.T) { + vectors := []struct { + cmdline string + err string + }{ + {"foo", "-j is required"}, + {"foo -j", "-j requires a value"}, + {"foo --delay -B", "--delay requires a value"}, + {"foo -j4", "NAME is required"}, + {"foo -j f", `strconv.ParseInt: parsing "f": invalid syntax.`}, + {"foo --delay=m", `strconv.ParseFloat: parsing "m": invalid syntax.`}, + {"foo --jobs 8 bar baz", `unexpected argument "baz"`}, + {"foo -u8", "unknown flag -u"}, + {"foo --remove", "unknown flag --remove"}, + } + + for _, v := range vectors { + t.Run(v.cmdline, func(t *testing.T) { + args, err := SplitArgs(v.cmdline) + if err != nil { + t.Fatal(err) + } + err = ArgsToStruct(args, new(OptionStruct)) + assert.ErrorContains(t, err, v.err) + }) + } +} + +func TestArgsToStruct(t *testing.T) { + vectors := []struct { + cmdline string + expected OptionStruct + }{ + { + cmdline: "foo -j4 bar", + expected: OptionStruct{ + Jobs: 4, + Delay: 0.5, + Name: "bar", + }, + }, + { + cmdline: "foo --delay 0.1 -zBj 8 bar", + expected: OptionStruct{ + Jobs: 8, + Delay: 0.1, + Zero: true, + Backoff: true, + Name: "bar", + }, + }, + { + cmdline: "foo -Bz --delay=0.1 --jobs=8 bar", + expected: OptionStruct{ + Jobs: 8, + Delay: 0.1, + Zero: true, + Backoff: true, + Name: "bar", + }, + }, + } + + for _, v := range vectors { + t.Run(v.cmdline, func(t *testing.T) { + var s OptionStruct + assert.Nil(t, CmdlineToStruct(v.cmdline, &s)) + assert.Equal(t, v.expected, s) + }) + } +} |