aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/emersion/go-smtp/conn.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/emersion/go-smtp/conn.go')
-rw-r--r--vendor/github.com/emersion/go-smtp/conn.go986
1 files changed, 986 insertions, 0 deletions
diff --git a/vendor/github.com/emersion/go-smtp/conn.go b/vendor/github.com/emersion/go-smtp/conn.go
new file mode 100644
index 0000000..72a67d8
--- /dev/null
+++ b/vendor/github.com/emersion/go-smtp/conn.go
@@ -0,0 +1,986 @@
+package smtp
+
+import (
+ "crypto/tls"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+ "net/textproto"
+ "regexp"
+ "runtime/debug"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+)
+
+// Number of errors we'll tolerate per connection before closing. Defaults to 3.
+const errThreshold = 3
+
+type Conn struct {
+ conn net.Conn
+ text *textproto.Conn
+ server *Server
+ helo string
+
+ // Number of errors witnessed on this connection
+ errCount int
+
+ session Session
+ locker sync.Mutex
+ binarymime bool
+
+ lineLimitReader *lineLimitReader
+ bdatPipe *io.PipeWriter
+ bdatStatus *statusCollector // used for BDAT on LMTP
+ dataResult chan error
+ bytesReceived int // counts total size of chunks when BDAT is used
+
+ fromReceived bool
+ recipients []string
+ didAuth bool
+}
+
+func newConn(c net.Conn, s *Server) *Conn {
+ sc := &Conn{
+ server: s,
+ conn: c,
+ }
+
+ sc.init()
+ return sc
+}
+
+func (c *Conn) init() {
+ c.lineLimitReader = &lineLimitReader{
+ R: c.conn,
+ LineLimit: c.server.MaxLineLength,
+ }
+ rwc := struct {
+ io.Reader
+ io.Writer
+ io.Closer
+ }{
+ Reader: c.lineLimitReader,
+ Writer: c.conn,
+ Closer: c.conn,
+ }
+
+ if c.server.Debug != nil {
+ rwc = struct {
+ io.Reader
+ io.Writer
+ io.Closer
+ }{
+ io.TeeReader(rwc.Reader, c.server.Debug),
+ io.MultiWriter(rwc.Writer, c.server.Debug),
+ rwc.Closer,
+ }
+ }
+
+ c.text = textproto.NewConn(rwc)
+}
+
+// Commands are dispatched to the appropriate handler functions.
+func (c *Conn) handle(cmd string, arg string) {
+ // If panic happens during command handling - send 421 response
+ // and close connection.
+ defer func() {
+ if err := recover(); err != nil {
+ c.writeResponse(421, EnhancedCode{4, 0, 0}, "Internal server error")
+ c.Close()
+
+ stack := debug.Stack()
+ c.server.ErrorLog.Printf("panic serving %v: %v\n%s", c.conn.RemoteAddr(), err, stack)
+ }
+ }()
+
+ if cmd == "" {
+ c.protocolError(500, EnhancedCode{5, 5, 2}, "Error: bad syntax")
+ return
+ }
+
+ cmd = strings.ToUpper(cmd)
+ switch cmd {
+ case "SEND", "SOML", "SAML", "EXPN", "HELP", "TURN":
+ // These commands are not implemented in any state
+ c.writeResponse(502, EnhancedCode{5, 5, 1}, fmt.Sprintf("%v command not implemented", cmd))
+ case "HELO", "EHLO", "LHLO":
+ lmtp := cmd == "LHLO"
+ enhanced := lmtp || cmd == "EHLO"
+ if c.server.LMTP && !lmtp {
+ c.writeResponse(500, EnhancedCode{5, 5, 1}, "This is a LMTP server, use LHLO")
+ return
+ }
+ if !c.server.LMTP && lmtp {
+ c.writeResponse(500, EnhancedCode{5, 5, 1}, "This is not a LMTP server")
+ return
+ }
+ c.handleGreet(enhanced, arg)
+ case "MAIL":
+ c.handleMail(arg)
+ case "RCPT":
+ c.handleRcpt(arg)
+ case "VRFY":
+ c.writeResponse(252, EnhancedCode{2, 5, 0}, "Cannot VRFY user, but will accept message")
+ case "NOOP":
+ c.writeResponse(250, EnhancedCode{2, 0, 0}, "I have sucessfully done nothing")
+ case "RSET": // Reset session
+ c.reset()
+ c.writeResponse(250, EnhancedCode{2, 0, 0}, "Session reset")
+ case "BDAT":
+ c.handleBdat(arg)
+ case "DATA":
+ c.handleData(arg)
+ case "QUIT":
+ c.writeResponse(221, EnhancedCode{2, 0, 0}, "Bye")
+ c.Close()
+ case "AUTH":
+ if c.server.AuthDisabled {
+ c.protocolError(500, EnhancedCode{5, 5, 2}, "Syntax error, AUTH command unrecognized")
+ } else {
+ c.handleAuth(arg)
+ }
+ case "STARTTLS":
+ c.handleStartTLS()
+ default:
+ msg := fmt.Sprintf("Syntax errors, %v command unrecognized", cmd)
+ c.protocolError(500, EnhancedCode{5, 5, 2}, msg)
+ }
+}
+
+func (c *Conn) Server() *Server {
+ return c.server
+}
+
+func (c *Conn) Session() Session {
+ c.locker.Lock()
+ defer c.locker.Unlock()
+ return c.session
+}
+
+func (c *Conn) setSession(session Session) {
+ c.locker.Lock()
+ defer c.locker.Unlock()
+ c.session = session
+}
+
+func (c *Conn) Close() error {
+ c.locker.Lock()
+ defer c.locker.Unlock()
+
+ if c.bdatPipe != nil {
+ c.bdatPipe.CloseWithError(ErrDataReset)
+ c.bdatPipe = nil
+ }
+
+ if c.session != nil {
+ c.session.Logout()
+ c.session = nil
+ }
+
+ return c.conn.Close()
+}
+
+// TLSConnectionState returns the connection's TLS connection state.
+// Zero values are returned if the connection doesn't use TLS.
+func (c *Conn) TLSConnectionState() (state tls.ConnectionState, ok bool) {
+ tc, ok := c.conn.(*tls.Conn)
+ if !ok {
+ return
+ }
+ return tc.ConnectionState(), true
+}
+
+func (c *Conn) Hostname() string {
+ return c.helo
+}
+
+func (c *Conn) Conn() net.Conn {
+ return c.conn
+}
+
+func (c *Conn) authAllowed() bool {
+ _, isTLS := c.TLSConnectionState()
+ return !c.server.AuthDisabled && (isTLS || c.server.AllowInsecureAuth)
+}
+
+// protocolError writes errors responses and closes the connection once too many
+// have occurred.
+func (c *Conn) protocolError(code int, ec EnhancedCode, msg string) {
+ c.writeResponse(code, ec, msg)
+
+ c.errCount++
+ if c.errCount > errThreshold {
+ c.writeResponse(500, EnhancedCode{5, 5, 1}, "Too many errors. Quiting now")
+ c.Close()
+ }
+}
+
+// GREET state -> waiting for HELO
+func (c *Conn) handleGreet(enhanced bool, arg string) {
+ domain, err := parseHelloArgument(arg)
+ if err != nil {
+ c.writeResponse(501, EnhancedCode{5, 5, 2}, "Domain/address argument required for HELO")
+ return
+ }
+ c.helo = domain
+
+ sess, err := c.server.Backend.NewSession(c)
+ if err != nil {
+ if smtpErr, ok := err.(*SMTPError); ok {
+ c.writeResponse(smtpErr.Code, smtpErr.EnhancedCode, smtpErr.Message)
+ return
+ }
+ c.writeResponse(451, EnhancedCode{4, 0, 0}, err.Error())
+ return
+ }
+ c.setSession(sess)
+
+ if !enhanced {
+ c.writeResponse(250, EnhancedCode{2, 0, 0}, fmt.Sprintf("Hello %s", domain))
+ return
+ }
+
+ caps := []string{}
+ caps = append(caps, c.server.caps...)
+ if _, isTLS := c.TLSConnectionState(); c.server.TLSConfig != nil && !isTLS {
+ caps = append(caps, "STARTTLS")
+ }
+ if c.authAllowed() {
+ authCap := "AUTH"
+ for name := range c.server.auths {
+ authCap += " " + name
+ }
+
+ caps = append(caps, authCap)
+ }
+ if c.server.EnableSMTPUTF8 {
+ caps = append(caps, "SMTPUTF8")
+ }
+ if _, isTLS := c.TLSConnectionState(); isTLS && c.server.EnableREQUIRETLS {
+ caps = append(caps, "REQUIRETLS")
+ }
+ if c.server.EnableBINARYMIME {
+ caps = append(caps, "BINARYMIME")
+ }
+ if c.server.MaxMessageBytes > 0 {
+ caps = append(caps, fmt.Sprintf("SIZE %v", c.server.MaxMessageBytes))
+ } else {
+ caps = append(caps, "SIZE")
+ }
+
+ args := []string{"Hello " + domain}
+ args = append(args, caps...)
+ c.writeResponse(250, NoEnhancedCode, args...)
+}
+
+// READY state -> waiting for MAIL
+func (c *Conn) handleMail(arg string) {
+ if c.helo == "" {
+ c.writeResponse(502, EnhancedCode{2, 5, 1}, "Please introduce yourself first.")
+ return
+ }
+ if c.bdatPipe != nil {
+ c.writeResponse(502, EnhancedCode{5, 5, 1}, "MAIL not allowed during message transfer")
+ return
+ }
+
+ if len(arg) < 6 || strings.ToUpper(arg[0:5]) != "FROM:" {
+ c.writeResponse(501, EnhancedCode{5, 5, 2}, "Was expecting MAIL arg syntax of FROM:<address>")
+ return
+ }
+ fromArgs := strings.Split(strings.Trim(arg[5:], " "), " ")
+ if c.server.Strict {
+ if !strings.HasPrefix(fromArgs[0], "<") || !strings.HasSuffix(fromArgs[0], ">") {
+ c.writeResponse(501, EnhancedCode{5, 5, 2}, "Was expecting MAIL arg syntax of FROM:<address>")
+ return
+ }
+ }
+ from := fromArgs[0]
+ if from == "" {
+ c.writeResponse(501, EnhancedCode{5, 5, 2}, "Was expecting MAIL arg syntax of FROM:<address>")
+ return
+ }
+ from = strings.Trim(from, "<>")
+
+ opts := &MailOptions{}
+
+ c.binarymime = false
+ // This is where the Conn may put BODY=8BITMIME, but we already
+ // read the DATA as bytes, so it does not effect our processing.
+ if len(fromArgs) > 1 {
+ args, err := parseArgs(fromArgs[1:])
+ if err != nil {
+ c.writeResponse(501, EnhancedCode{5, 5, 4}, "Unable to parse MAIL ESMTP parameters")
+ return
+ }
+
+ for key, value := range args {
+ switch key {
+ case "SIZE":
+ size, err := strconv.ParseInt(value, 10, 32)
+ if err != nil {
+ c.writeResponse(501, EnhancedCode{5, 5, 4}, "Unable to parse SIZE as an integer")
+ return
+ }
+
+ if c.server.MaxMessageBytes > 0 && int(size) > c.server.MaxMessageBytes {
+ c.writeResponse(552, EnhancedCode{5, 3, 4}, "Max message size exceeded")
+ return
+ }
+
+ opts.Size = int(size)
+ case "SMTPUTF8":
+ if !c.server.EnableSMTPUTF8 {
+ c.writeResponse(504, EnhancedCode{5, 5, 4}, "SMTPUTF8 is not implemented")
+ return
+ }
+ opts.UTF8 = true
+ case "REQUIRETLS":
+ if !c.server.EnableREQUIRETLS {
+ c.writeResponse(504, EnhancedCode{5, 5, 4}, "REQUIRETLS is not implemented")
+ return
+ }
+ opts.RequireTLS = true
+ case "BODY":
+ switch value {
+ case "BINARYMIME":
+ if !c.server.EnableBINARYMIME {
+ c.writeResponse(504, EnhancedCode{5, 5, 4}, "BINARYMIME is not implemented")
+ return
+ }
+ c.binarymime = true
+ case "7BIT", "8BITMIME":
+ default:
+ c.writeResponse(500, EnhancedCode{5, 5, 4}, "Unknown BODY value")
+ return
+ }
+ opts.Body = BodyType(value)
+ case "AUTH":
+ value, err := decodeXtext(value)
+ if err != nil {
+ c.writeResponse(500, EnhancedCode{5, 5, 4}, "Malformed AUTH parameter value")
+ return
+ }
+ if !strings.HasPrefix(value, "<") {
+ c.writeResponse(500, EnhancedCode{5, 5, 4}, "Missing opening angle bracket")
+ return
+ }
+ if !strings.HasSuffix(value, ">") {
+ c.writeResponse(500, EnhancedCode{5, 5, 4}, "Missing closing angle bracket")
+ return
+ }
+ decodedMbox := value[1 : len(value)-1]
+ opts.Auth = &decodedMbox
+ default:
+ c.writeResponse(500, EnhancedCode{5, 5, 4}, "Unknown MAIL FROM argument")
+ return
+ }
+ }
+ }
+
+ if err := c.Session().Mail(from, opts); err != nil {
+ if smtpErr, ok := err.(*SMTPError); ok {
+ c.writeResponse(smtpErr.Code, smtpErr.EnhancedCode, smtpErr.Message)
+ return
+ }
+ c.writeResponse(451, EnhancedCode{4, 0, 0}, err.Error())
+ return
+ }
+
+ c.writeResponse(250, EnhancedCode{2, 0, 0}, fmt.Sprintf("Roger, accepting mail from <%v>", from))
+ c.fromReceived = true
+}
+
+// This regexp matches 'hexchar' token defined in
+// https://tools.ietf.org/html/rfc4954#section-8 however it is intentionally
+// relaxed by requiring only '+' to be present. It allows us to detect
+// malformed values such as +A or +HH and report them appropriately.
+var hexcharRe = regexp.MustCompile(`\+[0-9A-F]?[0-9A-F]?`)
+
+func decodeXtext(val string) (string, error) {
+ if !strings.Contains(val, "+") {
+ return val, nil
+ }
+
+ var replaceErr error
+ decoded := hexcharRe.ReplaceAllStringFunc(val, func(match string) string {
+ if len(match) != 3 {
+ replaceErr = errors.New("incomplete hexchar")
+ return ""
+ }
+ char, err := strconv.ParseInt(match, 16, 8)
+ if err != nil {
+ replaceErr = err
+ return ""
+ }
+
+ return string(rune(char))
+ })
+ if replaceErr != nil {
+ return "", replaceErr
+ }
+
+ return decoded, nil
+}
+
+func encodeXtext(raw string) string {
+ var out strings.Builder
+ out.Grow(len(raw))
+
+ for _, ch := range raw {
+ if ch == '+' || ch == '=' {
+ out.WriteRune('+')
+ out.WriteString(strings.ToUpper(strconv.FormatInt(int64(ch), 16)))
+ }
+ if ch > '!' && ch < '~' { // printable non-space US-ASCII
+ out.WriteRune(ch)
+ }
+ // Non-ASCII.
+ out.WriteRune('+')
+ out.WriteString(strings.ToUpper(strconv.FormatInt(int64(ch), 16)))
+ }
+ return out.String()
+}
+
+// MAIL state -> waiting for RCPTs followed by DATA
+func (c *Conn) handleRcpt(arg string) {
+ if !c.fromReceived {
+ c.writeResponse(502, EnhancedCode{5, 5, 1}, "Missing MAIL FROM command.")
+ return
+ }
+ if c.bdatPipe != nil {
+ c.writeResponse(502, EnhancedCode{5, 5, 1}, "RCPT not allowed during message transfer")
+ return
+ }
+
+ if (len(arg) < 4) || (strings.ToUpper(arg[0:3]) != "TO:") {
+ c.writeResponse(501, EnhancedCode{5, 5, 2}, "Was expecting RCPT arg syntax of TO:<address>")
+ return
+ }
+
+ // TODO: This trim is probably too forgiving
+ recipient := strings.Trim(arg[3:], "<> ")
+
+ if c.server.MaxRecipients > 0 && len(c.recipients) >= c.server.MaxRecipients {
+ c.writeResponse(552, EnhancedCode{5, 5, 3}, fmt.Sprintf("Maximum limit of %v recipients reached", c.server.MaxRecipients))
+ return
+ }
+
+ if err := c.Session().Rcpt(recipient); err != nil {
+ if smtpErr, ok := err.(*SMTPError); ok {
+ c.writeResponse(smtpErr.Code, smtpErr.EnhancedCode, smtpErr.Message)
+ return
+ }
+ c.writeResponse(451, EnhancedCode{4, 0, 0}, err.Error())
+ return
+ }
+ c.recipients = append(c.recipients, recipient)
+ c.writeResponse(250, EnhancedCode{2, 0, 0}, fmt.Sprintf("I'll make sure <%v> gets this", recipient))
+}
+
+func (c *Conn) handleAuth(arg string) {
+ if c.helo == "" {
+ c.writeResponse(502, EnhancedCode{5, 5, 1}, "Please introduce yourself first.")
+ return
+ }
+ if c.didAuth {
+ c.writeResponse(503, EnhancedCode{5, 5, 1}, "Already authenticated")
+ return
+ }
+
+ parts := strings.Fields(arg)
+ if len(parts) == 0 {
+ c.writeResponse(502, EnhancedCode{5, 5, 4}, "Missing parameter")
+ return
+ }
+
+ if _, isTLS := c.TLSConnectionState(); !isTLS && !c.server.AllowInsecureAuth {
+ c.writeResponse(523, EnhancedCode{5, 7, 10}, "TLS is required")
+ return
+ }
+
+ mechanism := strings.ToUpper(parts[0])
+
+ // Parse client initial response if there is one
+ var ir []byte
+ if len(parts) > 1 {
+ var err error
+ ir, err = base64.StdEncoding.DecodeString(parts[1])
+ if err != nil {
+ return
+ }
+ }
+
+ newSasl, ok := c.server.auths[mechanism]
+ if !ok {
+ c.writeResponse(504, EnhancedCode{5, 7, 4}, "Unsupported authentication mechanism")
+ return
+ }
+
+ sasl := newSasl(c)
+
+ response := ir
+ for {
+ challenge, done, err := sasl.Next(response)
+ if err != nil {
+ if smtpErr, ok := err.(*SMTPError); ok {
+ c.writeResponse(smtpErr.Code, smtpErr.EnhancedCode, smtpErr.Message)
+ return
+ }
+ c.writeResponse(454, EnhancedCode{4, 7, 0}, err.Error())
+ return
+ }
+
+ if done {
+ break
+ }
+
+ encoded := ""
+ if len(challenge) > 0 {
+ encoded = base64.StdEncoding.EncodeToString(challenge)
+ }
+ c.writeResponse(334, NoEnhancedCode, encoded)
+
+ encoded, err = c.readLine()
+ if err != nil {
+ return // TODO: error handling
+ }
+
+ if encoded == "*" {
+ // https://tools.ietf.org/html/rfc4954#page-4
+ c.writeResponse(501, EnhancedCode{5, 0, 0}, "Negotiation cancelled")
+ return
+ }
+
+ response, err = base64.StdEncoding.DecodeString(encoded)
+ if err != nil {
+ c.writeResponse(454, EnhancedCode{4, 7, 0}, "Invalid base64 data")
+ return
+ }
+ }
+
+ c.writeResponse(235, EnhancedCode{2, 0, 0}, "Authentication succeeded")
+ c.didAuth = true
+}
+
+func (c *Conn) handleStartTLS() {
+ if _, isTLS := c.TLSConnectionState(); isTLS {
+ c.writeResponse(502, EnhancedCode{5, 5, 1}, "Already running in TLS")
+ return
+ }
+
+ if c.server.TLSConfig == nil {
+ c.writeResponse(502, EnhancedCode{5, 5, 1}, "TLS not supported")
+ return
+ }
+
+ c.writeResponse(220, EnhancedCode{2, 0, 0}, "Ready to start TLS")
+
+ // Upgrade to TLS
+ tlsConn := tls.Server(c.conn, c.server.TLSConfig)
+
+ if err := tlsConn.Handshake(); err != nil {
+ c.writeResponse(550, EnhancedCode{5, 0, 0}, "Handshake error")
+ return
+ }
+
+ c.conn = tlsConn
+ c.init()
+
+ // Reset all state and close the previous Session.
+ // This is different from just calling reset() since we want the Backend to
+ // be able to see the information about TLS connection in the
+ // ConnectionState object passed to it.
+ if session := c.Session(); session != nil {
+ session.Logout()
+ c.setSession(nil)
+ }
+ c.helo = ""
+ c.didAuth = false
+ c.reset()
+}
+
+// DATA
+func (c *Conn) handleData(arg string) {
+ if arg != "" {
+ c.writeResponse(501, EnhancedCode{5, 5, 4}, "DATA command should not have any arguments")
+ return
+ }
+ if c.bdatPipe != nil {
+ c.writeResponse(502, EnhancedCode{5, 5, 1}, "DATA not allowed during message transfer")
+ return
+ }
+ if c.binarymime {
+ c.writeResponse(502, EnhancedCode{5, 5, 1}, "DATA not allowed for BINARYMIME messages")
+ return
+ }
+
+ if !c.fromReceived || len(c.recipients) == 0 {
+ c.writeResponse(502, EnhancedCode{5, 5, 1}, "Missing RCPT TO command.")
+ return
+ }
+
+ // We have recipients, go to accept data
+ c.writeResponse(354, EnhancedCode{2, 0, 0}, "Go ahead. End your data with <CR><LF>.<CR><LF>")
+
+ defer c.reset()
+
+ if c.server.LMTP {
+ c.handleDataLMTP()
+ return
+ }
+
+ r := newDataReader(c)
+ code, enhancedCode, msg := toSMTPStatus(c.Session().Data(r))
+ r.limited = false
+ io.Copy(ioutil.Discard, r) // Make sure all the data has been consumed
+ c.writeResponse(code, enhancedCode, msg)
+}
+
+func (c *Conn) handleBdat(arg string) {
+ args := strings.Fields(arg)
+ if len(args) == 0 {
+ c.writeResponse(501, EnhancedCode{5, 5, 4}, "Missing chunk size argument")
+ return
+ }
+ if len(args) > 2 {
+ c.writeResponse(501, EnhancedCode{5, 5, 4}, "Too many arguments")
+ return
+ }
+
+ if !c.fromReceived || len(c.recipients) == 0 {
+ c.writeResponse(502, EnhancedCode{5, 5, 1}, "Missing RCPT TO command.")
+ return
+ }
+
+ last := false
+ if len(args) == 2 {
+ if !strings.EqualFold(args[1], "LAST") {
+ c.writeResponse(501, EnhancedCode{5, 5, 4}, "Unknown BDAT argument")
+ return
+ }
+ last = true
+ }
+
+ // ParseUint instead of Atoi so we will not accept negative values.
+ size, err := strconv.ParseUint(args[0], 10, 32)
+ if err != nil {
+ c.writeResponse(501, EnhancedCode{5, 5, 4}, "Malformed size argument")
+ return
+ }
+
+ if c.server.MaxMessageBytes != 0 && c.bytesReceived+int(size) > c.server.MaxMessageBytes {
+ c.writeResponse(552, EnhancedCode{5, 3, 4}, "Max message size exceeded")
+
+ // Discard chunk itself without passing it to backend.
+ io.Copy(ioutil.Discard, io.LimitReader(c.text.R, int64(size)))
+
+ c.reset()
+ return
+ }
+
+ if c.bdatStatus == nil && c.server.LMTP {
+ c.bdatStatus = c.createStatusCollector()
+ }
+
+ if c.bdatPipe == nil {
+ var r *io.PipeReader
+ r, c.bdatPipe = io.Pipe()
+
+ c.dataResult = make(chan error, 1)
+
+ go func() {
+ defer func() {
+ if err := recover(); err != nil {
+ c.handlePanic(err, c.bdatStatus)
+
+ c.dataResult <- errPanic
+ r.CloseWithError(errPanic)
+ }
+ }()
+
+ var err error
+ if !c.server.LMTP {
+ err = c.Session().Data(r)
+ } else {
+ lmtpSession, ok := c.Session().(LMTPSession)
+ if !ok {
+ err = c.Session().Data(r)
+ for _, rcpt := range c.recipients {
+ c.bdatStatus.SetStatus(rcpt, err)
+ }
+ } else {
+ err = lmtpSession.LMTPData(r, c.bdatStatus)
+ }
+ }
+
+ c.dataResult <- err
+ r.CloseWithError(err)
+ }()
+ }
+
+ c.lineLimitReader.LineLimit = 0
+
+ chunk := io.LimitReader(c.text.R, int64(size))
+ _, err = io.Copy(c.bdatPipe, chunk)
+ if err != nil {
+ // Backend might return an error early using CloseWithError without consuming
+ // the whole chunk.
+ io.Copy(ioutil.Discard, chunk)
+
+ c.writeResponse(toSMTPStatus(err))
+
+ if err == errPanic {
+ c.Close()
+ }
+
+ c.reset()
+ c.lineLimitReader.LineLimit = c.server.MaxLineLength
+ return
+ }
+
+ c.bytesReceived += int(size)
+
+ if last {
+ c.lineLimitReader.LineLimit = c.server.MaxLineLength
+
+ c.bdatPipe.Close()
+
+ err := <-c.dataResult
+
+ if c.server.LMTP {
+ c.bdatStatus.fillRemaining(err)
+ for i, rcpt := range c.recipients {
+ code, enchCode, msg := toSMTPStatus(<-c.bdatStatus.status[i])
+ c.writeResponse(code, enchCode, "<"+rcpt+"> "+msg)
+ }
+ } else {
+ c.writeResponse(toSMTPStatus(err))
+ }
+
+ if err == errPanic {
+ c.Close()
+ return
+ }
+
+ c.reset()
+ } else {
+ c.writeResponse(250, EnhancedCode{2, 0, 0}, "Continue")
+ }
+}
+
+// ErrDataReset is returned by Reader pased to Data function if client does not
+// send another BDAT command and instead closes connection or issues RSET command.
+var ErrDataReset = errors.New("smtp: message transmission aborted")
+
+var errPanic = &SMTPError{
+ Code: 421,
+ EnhancedCode: EnhancedCode{4, 0, 0},
+ Message: "Internal server error",
+}
+
+func (c *Conn) handlePanic(err interface{}, status *statusCollector) {
+ if status != nil {
+ status.fillRemaining(errPanic)
+ }
+
+ stack := debug.Stack()
+ c.server.ErrorLog.Printf("panic serving %v: %v\n%s", c.conn.RemoteAddr(), err, stack)
+}
+
+func (c *Conn) createStatusCollector() *statusCollector {
+ rcptCounts := make(map[string]int, len(c.recipients))
+
+ status := &statusCollector{
+ statusMap: make(map[string]chan error, len(c.recipients)),
+ status: make([]chan error, 0, len(c.recipients)),
+ }
+ for _, rcpt := range c.recipients {
+ rcptCounts[rcpt]++
+ }
+ // Create channels with buffer sizes necessary to fit all
+ // statuses for a single recipient to avoid deadlocks.
+ for rcpt, count := range rcptCounts {
+ status.statusMap[rcpt] = make(chan error, count)
+ }
+ for _, rcpt := range c.recipients {
+ status.status = append(status.status, status.statusMap[rcpt])
+ }
+
+ return status
+}
+
+type statusCollector struct {
+ // Contains map from recipient to list of channels that are used for that
+ // recipient.
+ statusMap map[string]chan error
+
+ // Contains channels from statusMap, in the same
+ // order as Conn.recipients.
+ status []chan error
+}
+
+// fillRemaining sets status for all recipients SetStatus was not called for before.
+func (s *statusCollector) fillRemaining(err error) {
+ // Amount of times certain recipient was specified is indicated by the channel
+ // buffer size, so once we fill it, we can be confident that we sent
+ // at least as much statuses as needed. Extra statuses will be ignored anyway.
+chLoop:
+ for _, ch := range s.statusMap {
+ for {
+ select {
+ case ch <- err:
+ default:
+ continue chLoop
+ }
+ }
+ }
+}
+
+func (s *statusCollector) SetStatus(rcptTo string, err error) {
+ ch := s.statusMap[rcptTo]
+ if ch == nil {
+ panic("SetStatus is called for recipient that was not specified before")
+ }
+
+ select {
+ case ch <- err:
+ default:
+ // There enough buffer space to fit all statuses at once, if this is
+ // not the case - backend is doing something wrong.
+ panic("SetStatus is called more times than particular recipient was specified")
+ }
+}
+
+func (c *Conn) handleDataLMTP() {
+ r := newDataReader(c)
+ status := c.createStatusCollector()
+
+ done := make(chan bool, 1)
+
+ lmtpSession, ok := c.Session().(LMTPSession)
+ if !ok {
+ // Fallback to using a single status for all recipients.
+ err := c.Session().Data(r)
+ io.Copy(ioutil.Discard, r) // Make sure all the data has been consumed
+ for _, rcpt := range c.recipients {
+ status.SetStatus(rcpt, err)
+ }
+ done <- true
+ } else {
+ go func() {
+ defer func() {
+ if err := recover(); err != nil {
+ status.fillRemaining(&SMTPError{
+ Code: 421,
+ EnhancedCode: EnhancedCode{4, 0, 0},
+ Message: "Internal server error",
+ })
+
+ stack := debug.Stack()
+ c.server.ErrorLog.Printf("panic serving %v: %v\n%s", c.conn.RemoteAddr(), err, stack)
+ done <- false
+ }
+ }()
+
+ status.fillRemaining(lmtpSession.LMTPData(r, status))
+ io.Copy(ioutil.Discard, r) // Make sure all the data has been consumed
+ done <- true
+ }()
+ }
+
+ for i, rcpt := range c.recipients {
+ code, enchCode, msg := toSMTPStatus(<-status.status[i])
+ c.writeResponse(code, enchCode, "<"+rcpt+"> "+msg)
+ }
+
+ // If done gets false, the panic occured in LMTPData and the connection
+ // should be closed.
+ if !<-done {
+ c.Close()
+ }
+}
+
+func toSMTPStatus(err error) (code int, enchCode EnhancedCode, msg string) {
+ if err != nil {
+ if smtperr, ok := err.(*SMTPError); ok {
+ return smtperr.Code, smtperr.EnhancedCode, smtperr.Message
+ } else {
+ return 554, EnhancedCode{5, 0, 0}, "Error: transaction failed, blame it on the weather: " + err.Error()
+ }
+ }
+
+ return 250, EnhancedCode{2, 0, 0}, "OK: queued"
+}
+
+func (c *Conn) Reject() {
+ c.writeResponse(421, EnhancedCode{4, 4, 5}, "Too busy. Try again later.")
+ c.Close()
+}
+
+func (c *Conn) greet() {
+ c.writeResponse(220, NoEnhancedCode, fmt.Sprintf("%v ESMTP Service Ready", c.server.Domain))
+}
+
+func (c *Conn) writeResponse(code int, enhCode EnhancedCode, text ...string) {
+ // TODO: error handling
+ if c.server.WriteTimeout != 0 {
+ c.conn.SetWriteDeadline(time.Now().Add(c.server.WriteTimeout))
+ }
+
+ // All responses must include an enhanced code, if it is missing - use
+ // a generic code X.0.0.
+ if enhCode == EnhancedCodeNotSet {
+ cat := code / 100
+ switch cat {
+ case 2, 4, 5:
+ enhCode = EnhancedCode{cat, 0, 0}
+ default:
+ enhCode = NoEnhancedCode
+ }
+ }
+
+ for i := 0; i < len(text)-1; i++ {
+ c.text.PrintfLine("%d-%v", code, text[i])
+ }
+ if enhCode == NoEnhancedCode {
+ c.text.PrintfLine("%d %v", code, text[len(text)-1])
+ } else {
+ c.text.PrintfLine("%d %v.%v.%v %v", code, enhCode[0], enhCode[1], enhCode[2], text[len(text)-1])
+ }
+}
+
+// Reads a line of input
+func (c *Conn) readLine() (string, error) {
+ if c.server.ReadTimeout != 0 {
+ if err := c.conn.SetReadDeadline(time.Now().Add(c.server.ReadTimeout)); err != nil {
+ return "", err
+ }
+ }
+
+ return c.text.ReadLine()
+}
+
+func (c *Conn) reset() {
+ c.locker.Lock()
+ defer c.locker.Unlock()
+
+ if c.bdatPipe != nil {
+ c.bdatPipe.CloseWithError(ErrDataReset)
+ c.bdatPipe = nil
+ }
+ c.bdatStatus = nil
+ c.bytesReceived = 0
+
+ if c.session != nil {
+ c.session.Reset()
+ }
+
+ c.fromReceived = false
+ c.recipients = nil
+}