aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/emersion/go-sasl/oauthbearer.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/emersion/go-sasl/oauthbearer.go')
-rw-r--r--vendor/github.com/emersion/go-sasl/oauthbearer.go191
1 files changed, 191 insertions, 0 deletions
diff --git a/vendor/github.com/emersion/go-sasl/oauthbearer.go b/vendor/github.com/emersion/go-sasl/oauthbearer.go
new file mode 100644
index 0000000..a0639b1
--- /dev/null
+++ b/vendor/github.com/emersion/go-sasl/oauthbearer.go
@@ -0,0 +1,191 @@
+package sasl
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "strconv"
+ "strings"
+)
+
+// The OAUTHBEARER mechanism name.
+const OAuthBearer = "OAUTHBEARER"
+
+type OAuthBearerError struct {
+ Status string `json:"status"`
+ Schemes string `json:"schemes"`
+ Scope string `json:"scope"`
+}
+
+type OAuthBearerOptions struct {
+ Username string
+ Token string
+ Host string
+ Port int
+}
+
+// Implements error
+func (err *OAuthBearerError) Error() string {
+ return fmt.Sprintf("OAUTHBEARER authentication error (%v)", err.Status)
+}
+
+type oauthBearerClient struct {
+ OAuthBearerOptions
+}
+
+func (a *oauthBearerClient) Start() (mech string, ir []byte, err error) {
+ mech = OAuthBearer
+ var str = "n,a=" + a.Username + ","
+
+ if a.Host != "" {
+ str += "\x01host=" + a.Host
+ }
+
+ if a.Port != 0 {
+ str += "\x01port=" + strconv.Itoa(a.Port)
+ }
+ str += "\x01auth=Bearer " + a.Token + "\x01\x01"
+ ir = []byte(str)
+ return
+}
+
+func (a *oauthBearerClient) Next(challenge []byte) ([]byte, error) {
+ authBearerErr := &OAuthBearerError{}
+ if err := json.Unmarshal(challenge, authBearerErr); err != nil {
+ return nil, err
+ } else {
+ return nil, authBearerErr
+ }
+}
+
+// An implementation of the OAUTHBEARER authentication mechanism, as
+// described in RFC 7628.
+func NewOAuthBearerClient(opt *OAuthBearerOptions) Client {
+ return &oauthBearerClient{*opt}
+}
+
+type OAuthBearerAuthenticator func(opts OAuthBearerOptions) *OAuthBearerError
+
+type oauthBearerServer struct {
+ done bool
+ failErr error
+ authenticate OAuthBearerAuthenticator
+}
+
+func (a *oauthBearerServer) fail(descr string) ([]byte, bool, error) {
+ blob, err := json.Marshal(OAuthBearerError{
+ Status: "invalid_request",
+ Schemes: "bearer",
+ })
+ if err != nil {
+ panic(err) // wtf
+ }
+ a.failErr = errors.New(descr)
+ return blob, false, nil
+}
+
+func (a *oauthBearerServer) Next(response []byte) (challenge []byte, done bool, err error) {
+ // Per RFC, we cannot just send an error, we need to return JSON-structured
+ // value as a challenge and then after getting dummy response from the
+ // client stop the exchange.
+ if a.failErr != nil {
+ // Server libraries (go-smtp, go-imap) will not call Next on
+ // protocol-specific SASL cancel response ('*'). However, GS2 (and
+ // indirectly OAUTHBEARER) defines a protocol-independent way to do so
+ // using 0x01.
+ if len(response) != 1 && response[0] != 0x01 {
+ return nil, true, errors.New("unexpected response")
+ }
+ return nil, true, a.failErr
+ }
+
+ if a.done {
+ err = ErrUnexpectedClientResponse
+ return
+ }
+
+ // Generate empty challenge.
+ if response == nil {
+ return []byte{}, false, nil
+ }
+
+ a.done = true
+
+ // Cut n,a=username,\x01host=...\x01auth=...
+ // into
+ // n
+ // a=username
+ // \x01host=...\x01auth=...\x01\x01
+ parts := bytes.SplitN(response, []byte{','}, 3)
+ if len(parts) != 3 {
+ return a.fail("Invalid response")
+ }
+ if !bytes.Equal(parts[0], []byte{'n'}) {
+ return a.fail("Invalid response, missing 'n'")
+ }
+ opts := OAuthBearerOptions{}
+ if !bytes.HasPrefix(parts[1], []byte("a=")) {
+ return a.fail("Invalid response, missing 'a'")
+ }
+ opts.Username = string(bytes.TrimPrefix(parts[1], []byte("a=")))
+
+ // Cut \x01host=...\x01auth=...\x01\x01
+ // into
+ // *empty*
+ // host=...
+ // auth=...
+ // *empty*
+ //
+ // Note that this code does not do a lot of checks to make sure the input
+ // follows the exact format specified by RFC.
+ params := bytes.Split(parts[2], []byte{0x01})
+ for _, p := range params {
+ // Skip empty fields (one at start and end).
+ if len(p) == 0 {
+ continue
+ }
+
+ pParts := bytes.SplitN(p, []byte{'='}, 2)
+ if len(pParts) != 2 {
+ return a.fail("Invalid response, missing '='")
+ }
+
+ switch string(pParts[0]) {
+ case "host":
+ opts.Host = string(pParts[1])
+ case "port":
+ port, err := strconv.ParseUint(string(pParts[1]), 10, 16)
+ if err != nil {
+ return a.fail("Invalid response, malformed 'port' value")
+ }
+ opts.Port = int(port)
+ case "auth":
+ const prefix = "bearer "
+ strValue := string(pParts[1])
+ // Token type is case-insensitive.
+ if !strings.HasPrefix(strings.ToLower(strValue), prefix) {
+ return a.fail("Unsupported token type")
+ }
+ opts.Token = strValue[len(prefix):]
+ default:
+ return a.fail("Invalid response, unknown parameter: " + string(pParts[0]))
+ }
+ }
+
+ authzErr := a.authenticate(opts)
+ if authzErr != nil {
+ blob, err := json.Marshal(authzErr)
+ if err != nil {
+ panic(err) // wtf
+ }
+ a.failErr = authzErr
+ return blob, false, nil
+ }
+
+ return nil, true, nil
+}
+
+func NewOAuthBearerServer(auth OAuthBearerAuthenticator) Server {
+ return &oauthBearerServer{authenticate: auth}
+}