aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRobin Jarry <robin@jarry.cc>2023-10-03 15:48:15 +0200
committerRobin Jarry <robin@jarry.cc>2023-10-03 22:34:39 +0200
commitef69ca8707eb7072241c987527166b254378c1b5 (patch)
tree9b7ad5848830c9d83a8353c34517182b20e824d4
parentbbe5e81538a537a3c65d9b94e0dbb1d4cd1604df (diff)
downloadaerc-ef69ca8707eb7072241c987527166b254378c1b5.tar.gz
aerc-ef69ca8707eb7072241c987527166b254378c1b5.zip
lib: add opt parser based on struct tags
Signed-off-by: Robin Jarry <robin@jarry.cc>
-rw-r--r--lib/opt/opt.go28
-rw-r--r--lib/opt/spec.go475
-rw-r--r--lib/opt/spec_test.go87
3 files changed, 590 insertions, 0 deletions
diff --git a/lib/opt/opt.go b/lib/opt/opt.go
new file mode 100644
index 00000000..a671394c
--- /dev/null
+++ b/lib/opt/opt.go
@@ -0,0 +1,28 @@
+package opt
+
+import (
+ "errors"
+ "fmt"
+)
+
+func CmdlineToStruct(cmdline string, v any) error {
+ args, err := SplitArgs(cmdline)
+ if err != nil {
+ return err
+ }
+ return ArgsToStruct(args, v)
+}
+
+func ArgsToStruct(args *Args, v any) error {
+ if args.Count() == 0 {
+ return errors.New("empty command")
+ }
+
+ cmd := NewCmdSpec(args.Arg(0), v)
+ err := cmd.ParseArgs(args)
+ if err != nil {
+ return fmt.Errorf("%w. %s", err, cmd.Usage())
+ }
+
+ return nil
+}
diff --git a/lib/opt/spec.go b/lib/opt/spec.go
new file mode 100644
index 00000000..2cdb1be6
--- /dev/null
+++ b/lib/opt/spec.go
@@ -0,0 +1,475 @@
+package opt
+
+import (
+ "fmt"
+ "reflect"
+ "regexp"
+ "strconv"
+ "strings"
+)
+
+type optKind int
+
+const (
+ unset optKind = iota
+ // flag without a value: "-f" or "--foo-baz"
+ flag
+ // flag with a value: "-f bar", "--foo-baz bar" or "--foo-baz=bar"
+ option
+ // positional argument after interpreting shell quotes
+ positional
+ // remaining positional arguments after interpreting shell quotes
+ remainderSplit
+ // remaining positional arguments without interpreting shell quotes
+ remainder
+)
+
+// Command line options specifier
+type CmdSpec struct {
+ // first argument, for Usage() generation
+ name string
+ // list of option specs extracted from struct tags
+ opts []optSpec
+ // indexes of short options in the list for quick access
+ shortOpts map[string]int
+ // indexes of long options in the list for quick access
+ longOpts map[string]int
+ // indexes of positional arguments
+ positionals []int
+}
+
+// Option or argument specifier
+type optSpec struct {
+ // kind of option/argument
+ kind optKind
+ // "f", "foo-baz" (only when kind is flag or option)
+ short, long string
+ // name of option/argument value in usage help
+ metavar string
+ // default string value before interpretation
+ defval string
+ // argument is required
+ required bool
+ // argument was processed
+ seen bool
+ // custom parse method
+ parse reflect.Value
+ // destination struct field
+ dest reflect.Value
+}
+
+var (
+ shortOptRe = regexp.MustCompile(`^-([a-zA-Z0-9])$`)
+ longOptRe = regexp.MustCompile(`^--([a-zA-Z0-9][a-zA-Z0-9-]+[a-zA-Z0-9])$`)
+ positionalRe = regexp.MustCompile(`^([a-zA-Z]\w*)$`)
+)
+
+// Interpret all struct fields to a list of option specs
+func NewCmdSpec(name string, v any) *CmdSpec {
+ typ := reflect.TypeOf(v)
+ val := reflect.ValueOf(v)
+ if typ.Kind() == reflect.Ptr {
+ typ = typ.Elem()
+ val = val.Elem()
+ } else {
+ panic("NewCmdSpec requires a pointer")
+ }
+ if typ.Kind() != reflect.Struct {
+ panic("NewCmdSpec requires a pointer to a struct")
+ }
+
+ cmd := &CmdSpec{
+ name: name,
+ opts: make([]optSpec, 0, typ.NumField()),
+ shortOpts: make(map[string]int),
+ longOpts: make(map[string]int),
+ }
+
+ allPositionals := false
+ for i := 0; i < typ.NumField(); i++ {
+ var spec optSpec
+ spec.parseField(reflect.ValueOf(v), typ.Field(i))
+ if spec.kind == unset {
+ // ignored field
+ continue
+ }
+ spec.dest = val.Field(i)
+
+ if allPositionals {
+ panic(`opt:"..." must be last`)
+ }
+
+ switch spec.kind {
+ case flag, option:
+ if spec.short != "" {
+ cmd.shortOpts[spec.short] = len(cmd.opts)
+ }
+ if spec.long != "" {
+ cmd.longOpts[spec.long] = len(cmd.opts)
+ }
+ case remainder, remainderSplit:
+ allPositionals = true
+ fallthrough
+ default:
+ cmd.positionals = append(cmd.positionals, len(cmd.opts))
+ }
+ cmd.opts = append(cmd.opts, spec)
+ }
+
+ return cmd
+}
+
+func (spec *optSpec) parseField(struc reflect.Value, t reflect.StructField) {
+ abort := func(msg string, args ...any) {
+ msg = fmt.Sprintf(msg, args...)
+ panic(fmt.Sprintf("%s.%s: %s", struc.Type(), t.Name, msg))
+ }
+
+ // check what kind of argument this field maps to
+ opt := t.Tag.Get("opt")
+
+ switch {
+ case opt == "" || opt == "-":
+ // ignored field
+ return
+
+ case opt == "...":
+ // remainder
+ switch t.Type.Kind() {
+ case reflect.Slice:
+ if t.Type.Elem().Kind() != reflect.String {
+ abort("'...' only works with []string")
+ }
+ spec.kind = remainderSplit
+ case reflect.String:
+ spec.kind = remainder
+ default:
+ abort("'...' only works with string or []string")
+ }
+ spec.metavar = strings.ToUpper(t.Name)
+ spec.required = true
+
+ case strings.Contains(opt, "-"):
+ // flag or option
+ for _, flag := range strings.Split(opt, ",") {
+ m := longOptRe.FindStringSubmatch(flag)
+ if m != nil {
+ spec.long = m[1]
+ continue
+ }
+ m = shortOptRe.FindStringSubmatch(flag)
+ if m != nil {
+ spec.short = m[1]
+ continue
+ }
+ abort("invalid opt tag: %q", opt)
+ }
+ if t.Type.Kind() == reflect.Bool {
+ spec.kind = flag
+ } else {
+ spec.kind = option
+ spec.metavar = strings.ToUpper(t.Name)
+ }
+ if spec.short == "" && spec.long == "" {
+ abort("invalid opt tag: %q", opt)
+ }
+
+ case positionalRe.MatchString(opt):
+ // named positional
+ spec.kind = positional
+ spec.metavar = strings.ToUpper(opt)
+ spec.required = true
+
+ default:
+ abort("invalid opt tag: %q", opt)
+ }
+
+ if metavar, hasMetavar := t.Tag.Lookup("metavar"); hasMetavar {
+ // explicit metavar for the generated usage
+ spec.metavar = metavar
+ }
+
+ spec.defval = t.Tag.Get("default")
+
+ switch t.Tag.Get("required") {
+ case "true":
+ spec.required = true
+ case "false":
+ spec.required = false
+ case "":
+ if spec.defval != "" {
+ spec.required = false
+ }
+ default:
+ abort("invalid required value")
+ }
+
+ if methodName, found := t.Tag.Lookup("parse"); found {
+ method := struc.MethodByName(methodName)
+ if !method.IsValid() {
+ abort("parse method not found: (*%s).%s", struc, methodName)
+ }
+
+ ok := method.Type().NumIn() == 1
+ ok = ok && method.Type().In(0).Kind() == reflect.String
+ ok = ok && method.Type().NumOut() == 1
+ ok = ok && method.Type().Out(0).Kind() == reflect.Interface
+ ok = ok && method.Type().Out(0).Name() == "error"
+ if !ok {
+ abort("(*%s).%s: invalid signature, expected func(string) (error)",
+ struc.Elem().Type().Name(), methodName,
+ t.Type.Kind())
+ }
+ spec.parse = method
+ }
+
+ if !spec.parse.IsValid() {
+ switch t.Type.Kind() {
+ case reflect.String:
+ fallthrough
+ case reflect.Bool:
+ fallthrough
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ fallthrough
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ fallthrough
+ case reflect.Float32, reflect.Float64:
+ break
+ case reflect.Slice:
+ break
+ default:
+ abort("unsupported field type: %s", t.Type.Kind())
+ }
+ }
+}
+
+func (s *optSpec) Usage() string {
+ var usage string
+
+ switch s.kind {
+ case flag:
+ if s.short != "" {
+ usage = "-" + s.short
+ } else {
+ usage = "--" + s.long
+ }
+
+ case option:
+ if s.short != "" {
+ usage = "-" + s.short
+ } else {
+ usage = "--" + s.long
+ }
+ usage += " " + s.metavar
+
+ case positional:
+ usage = s.metavar
+
+ case remainder, remainderSplit:
+ usage = s.metavar + "..."
+ }
+
+ if s.required {
+ return usage
+ }
+ return "[" + usage + "]"
+}
+
+func (s *optSpec) Flag() string {
+ var usage string
+
+ switch s.kind {
+ case flag, option:
+ if s.short != "" {
+ usage = "-" + s.short
+ } else {
+ usage = "--" + s.long
+ }
+ default:
+ usage = s.metavar
+ }
+
+ return usage
+}
+
+func (c *CmdSpec) Usage() string {
+ args := make([]string, len(c.opts)+2)
+ args[0] = "Usage:"
+ args[1] = c.name
+ for i := 0; i < len(c.opts); i++ {
+ args[i+2] = c.opts[i].Usage()
+ }
+ return strings.Join(args, " ")
+}
+
+func (c *CmdSpec) ParseArgs(args *Args) error {
+ var spec *optSpec
+ var i int
+
+ for i = 1; i < args.Count(); i++ {
+ arg := args.Arg(i)
+
+ switch {
+ case len(arg) > 2 && strings.HasPrefix(arg, "--"):
+ // --long-flag
+ if spec != nil {
+ return fmt.Errorf("%s requires a value", spec.Flag())
+ }
+ arg = arg[2:]
+ if o, ok := c.longOpts[arg]; ok {
+ spec = &c.opts[o]
+ if spec.kind == flag {
+ spec.dest.SetBool(true)
+ spec.seen = true
+ spec = nil
+ }
+ continue
+ }
+ if strings.Contains(arg, "=") {
+ arg, val, _ := strings.Cut(arg, "=")
+ if o, ok := c.longOpts[arg]; ok {
+ spec = &c.opts[o]
+ if spec.kind == flag {
+ return fmt.Errorf("%s does not take a value", spec.Flag())
+ }
+ err := spec.parseValue(val)
+ if err != nil {
+ return fmt.Errorf("%s: %w", spec.Flag(), err)
+ }
+ spec.seen = true
+ spec = nil
+ continue
+ }
+ }
+ return fmt.Errorf("unknown flag --%s", arg)
+
+ case len(arg) > 1 && strings.HasPrefix(arg, "-"):
+ // -x (short flag)
+ if spec != nil {
+ return fmt.Errorf("%s requires a value", spec.Flag())
+ }
+ arg = arg[1:]
+ for len(arg) > 0 {
+ f := arg[:1]
+ arg = arg[1:]
+ if o, ok := c.shortOpts[f]; ok {
+ spec = &c.opts[o]
+ if spec.kind == flag {
+ spec.dest.SetBool(true)
+ spec.seen = true
+ spec = nil
+ } else if len(arg) > 0 {
+ err := spec.parseValue(arg)
+ if err != nil {
+ return fmt.Errorf("%s: %w", spec.Flag(), err)
+ }
+ spec.seen = true
+ spec = nil
+ arg = ""
+ }
+ } else {
+ return fmt.Errorf("unknown flag -%s", f)
+ }
+ }
+
+ default:
+ // positional
+ if spec != nil {
+ if err := spec.parseValue(arg); err != nil {
+ return fmt.Errorf("%s: %w", spec.Flag(), err)
+ }
+ spec.seen = true
+ spec = nil
+ continue
+ }
+ if len(c.positionals) == 0 {
+ return fmt.Errorf("unexpected argument %q", arg)
+ }
+ spec = &c.opts[c.positionals[0]]
+ c.positionals = c.positionals[1:]
+
+ switch spec.kind {
+ case remainder:
+ args.Shift(i)
+ spec.dest.SetString(args.String())
+ i = args.Count()
+
+ case remainderSplit:
+ args.Shift(i)
+ spec.dest.Set(reflect.ValueOf(args.Args()))
+ i = args.Count()
+
+ default:
+ err := spec.parseValue(arg)
+ if err != nil {
+ return fmt.Errorf("%s: %w", spec.Flag(), err)
+ }
+ }
+
+ spec.seen = true
+ spec = nil
+ }
+ }
+
+ if spec != nil {
+ return fmt.Errorf("%s requires a value", spec.Flag())
+ }
+
+ for i = 0; i < len(c.opts); i++ {
+ spec := &c.opts[i]
+ if !spec.seen && spec.defval != "" {
+ err := spec.parseValue(spec.defval)
+ if err != nil {
+ return fmt.Errorf("%s: %w", spec.Flag(), err)
+ }
+ spec.seen = true
+ }
+ if spec.required && !spec.seen {
+ return fmt.Errorf("%s is required", spec.Flag())
+ }
+ }
+
+ return nil
+}
+
+func (s *optSpec) parseValue(arg string) error {
+ if s.parse.IsValid() {
+ in := []reflect.Value{reflect.ValueOf(arg)}
+ out := s.parse.Call(in)
+ err, _ := out[0].Interface().(error)
+ return err
+ }
+
+ switch s.dest.Type().Kind() {
+ case reflect.String:
+ s.dest.SetString(arg)
+ case reflect.Bool:
+ if b, err := strconv.ParseBool(arg); err == nil {
+ s.dest.SetBool(b)
+ } else {
+ return err
+ }
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ if i, err := strconv.ParseInt(arg, 10, 64); err == nil {
+ s.dest.SetInt(i)
+ } else {
+ return err
+ }
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ if u, err := strconv.ParseUint(arg, 10, 64); err == nil {
+ s.dest.SetUint(u)
+ } else {
+ return err
+ }
+ case reflect.Float32, reflect.Float64:
+ if f, err := strconv.ParseFloat(arg, 64); err == nil {
+ s.dest.SetFloat(f)
+ } else {
+ return err
+ }
+ default:
+ return fmt.Errorf("unsupported type: %s", s.dest)
+ }
+
+ return nil
+}
diff --git a/lib/opt/spec_test.go b/lib/opt/spec_test.go
new file mode 100644
index 00000000..eec9db4f
--- /dev/null
+++ b/lib/opt/spec_test.go
@@ -0,0 +1,87 @@
+package opt
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+type OptionStruct struct {
+ Jobs int `opt:"-j,--jobs" required:"true"`
+ Delay float64 `opt:"--delay" default:"0.5"`
+ Zero bool `opt:"-z"`
+ Backoff bool `opt:"-B,--backoff"`
+ Name string `opt:"NAME"`
+}
+
+func TestArgsToStructErrors(t *testing.T) {
+ vectors := []struct {
+ cmdline string
+ err string
+ }{
+ {"foo", "-j is required"},
+ {"foo -j", "-j requires a value"},
+ {"foo --delay -B", "--delay requires a value"},
+ {"foo -j4", "NAME is required"},
+ {"foo -j f", `strconv.ParseInt: parsing "f": invalid syntax.`},
+ {"foo --delay=m", `strconv.ParseFloat: parsing "m": invalid syntax.`},
+ {"foo --jobs 8 bar baz", `unexpected argument "baz"`},
+ {"foo -u8", "unknown flag -u"},
+ {"foo --remove", "unknown flag --remove"},
+ }
+
+ for _, v := range vectors {
+ t.Run(v.cmdline, func(t *testing.T) {
+ args, err := SplitArgs(v.cmdline)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = ArgsToStruct(args, new(OptionStruct))
+ assert.ErrorContains(t, err, v.err)
+ })
+ }
+}
+
+func TestArgsToStruct(t *testing.T) {
+ vectors := []struct {
+ cmdline string
+ expected OptionStruct
+ }{
+ {
+ cmdline: "foo -j4 bar",
+ expected: OptionStruct{
+ Jobs: 4,
+ Delay: 0.5,
+ Name: "bar",
+ },
+ },
+ {
+ cmdline: "foo --delay 0.1 -zBj 8 bar",
+ expected: OptionStruct{
+ Jobs: 8,
+ Delay: 0.1,
+ Zero: true,
+ Backoff: true,
+ Name: "bar",
+ },
+ },
+ {
+ cmdline: "foo -Bz --delay=0.1 --jobs=8 bar",
+ expected: OptionStruct{
+ Jobs: 8,
+ Delay: 0.1,
+ Zero: true,
+ Backoff: true,
+ Name: "bar",
+ },
+ },
+ }
+
+ for _, v := range vectors {
+ t.Run(v.cmdline, func(t *testing.T) {
+ var s OptionStruct
+ assert.Nil(t, CmdlineToStruct(v.cmdline, &s))
+ assert.Equal(t, v.expected, s)
+ })
+ }
+}