aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/gorilla/websocket/conn.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/gorilla/websocket/conn.go')
-rw-r--r--vendor/github.com/gorilla/websocket/conn.go189
1 files changed, 127 insertions, 62 deletions
diff --git a/vendor/github.com/gorilla/websocket/conn.go b/vendor/github.com/gorilla/websocket/conn.go
index d2a21c1..331eebc 100644
--- a/vendor/github.com/gorilla/websocket/conn.go
+++ b/vendor/github.com/gorilla/websocket/conn.go
@@ -13,6 +13,7 @@ import (
"math/rand"
"net"
"strconv"
+ "strings"
"sync"
"time"
"unicode/utf8"
@@ -244,8 +245,8 @@ type Conn struct {
subprotocol string
// Write fields
- mu chan bool // used as mutex to protect write to conn
- writeBuf []byte // frame is constructed in this buffer.
+ mu chan struct{} // used as mutex to protect write to conn
+ writeBuf []byte // frame is constructed in this buffer.
writePool BufferPool
writeBufSize int
writeDeadline time.Time
@@ -260,10 +261,12 @@ type Conn struct {
newCompressionWriter func(io.WriteCloser, int) io.WriteCloser
// Read fields
- reader io.ReadCloser // the current reader returned to the application
- readErr error
- br *bufio.Reader
- readRemaining int64 // bytes remaining in current frame.
+ reader io.ReadCloser // the current reader returned to the application
+ readErr error
+ br *bufio.Reader
+ // bytes remaining in current frame.
+ // set setReadRemaining to safely update this value and prevent overflow
+ readRemaining int64
readFinal bool // true the current message has more frames.
readLength int64 // Message size.
readLimit int64 // Maximum message size.
@@ -300,8 +303,8 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int,
writeBuf = make([]byte, writeBufferSize)
}
- mu := make(chan bool, 1)
- mu <- true
+ mu := make(chan struct{}, 1)
+ mu <- struct{}{}
c := &Conn{
isServer: isServer,
br: br,
@@ -320,6 +323,17 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int,
return c
}
+// setReadRemaining tracks the number of bytes remaining on the connection. If n
+// overflows, an ErrReadLimit is returned.
+func (c *Conn) setReadRemaining(n int64) error {
+ if n < 0 {
+ return ErrReadLimit
+ }
+
+ c.readRemaining = n
+ return nil
+}
+
// Subprotocol returns the negotiated protocol for the connection.
func (c *Conn) Subprotocol() string {
return c.subprotocol
@@ -364,7 +378,7 @@ func (c *Conn) read(n int) ([]byte, error) {
func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error {
<-c.mu
- defer func() { c.mu <- true }()
+ defer func() { c.mu <- struct{}{} }()
c.writeErrMu.Lock()
err := c.writeErr
@@ -388,6 +402,12 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
return nil
}
+func (c *Conn) writeBufs(bufs ...[]byte) error {
+ b := net.Buffers(bufs)
+ _, err := b.WriteTo(c.conn)
+ return err
+}
+
// WriteControl writes a control message with the given deadline. The allowed
// message types are CloseMessage, PingMessage and PongMessage.
func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
@@ -416,7 +436,7 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
maskBytes(key, 0, buf[6:])
}
- d := time.Hour * 1000
+ d := 1000 * time.Hour
if !deadline.IsZero() {
d = deadline.Sub(time.Now())
if d < 0 {
@@ -431,7 +451,7 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
case <-timer.C:
return errWriteTimeout
}
- defer func() { c.mu <- true }()
+ defer func() { c.mu <- struct{}{} }()
c.writeErrMu.Lock()
err := c.writeErr
@@ -451,7 +471,8 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
return err
}
-func (c *Conn) prepWrite(messageType int) error {
+// beginMessage prepares a connection and message writer for a new message.
+func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
// Close previous writer if not already closed by the application. It's
// probably better to return an error in this situation, but we cannot
// change this without breaking existing applications.
@@ -471,6 +492,10 @@ func (c *Conn) prepWrite(messageType int) error {
return err
}
+ mw.c = c
+ mw.frameType = messageType
+ mw.pos = maxFrameHeaderSize
+
if c.writeBuf == nil {
wpd, ok := c.writePool.Get().(writePoolData)
if ok {
@@ -491,16 +516,11 @@ func (c *Conn) prepWrite(messageType int) error {
// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
// PongMessage) are supported.
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
- if err := c.prepWrite(messageType); err != nil {
+ var mw messageWriter
+ if err := c.beginMessage(&mw, messageType); err != nil {
return nil, err
}
-
- mw := &messageWriter{
- c: c,
- frameType: messageType,
- pos: maxFrameHeaderSize,
- }
- c.writer = mw
+ c.writer = &mw
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
w := c.newCompressionWriter(c.writer, c.compressionLevel)
mw.compress = true
@@ -517,10 +537,16 @@ type messageWriter struct {
err error
}
-func (w *messageWriter) fatal(err error) error {
+func (w *messageWriter) endMessage(err error) error {
if w.err != nil {
- w.err = err
- w.c.writer = nil
+ return err
+ }
+ c := w.c
+ w.err = err
+ c.writer = nil
+ if c.writePool != nil {
+ c.writePool.Put(writePoolData{buf: c.writeBuf})
+ c.writeBuf = nil
}
return err
}
@@ -534,7 +560,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
// Check for invalid control frames.
if isControl(w.frameType) &&
(!final || length > maxControlFramePayloadSize) {
- return w.fatal(errInvalidControlFrame)
+ return w.endMessage(errInvalidControlFrame)
}
b0 := byte(w.frameType)
@@ -579,7 +605,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
if len(extra) > 0 {
- return c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))
+ return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode")))
}
}
@@ -600,15 +626,11 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
c.isWriting = false
if err != nil {
- return w.fatal(err)
+ return w.endMessage(err)
}
if final {
- c.writer = nil
- if c.writePool != nil {
- c.writePool.Put(writePoolData{buf: c.writeBuf})
- c.writeBuf = nil
- }
+ w.endMessage(errWriteClosed)
return nil
}
@@ -706,11 +728,7 @@ func (w *messageWriter) Close() error {
if w.err != nil {
return w.err
}
- if err := w.flushFrame(true, nil); err != nil {
- return err
- }
- w.err = errWriteClosed
- return nil
+ return w.flushFrame(true, nil)
}
// WritePreparedMessage writes prepared message into connection.
@@ -742,10 +760,10 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
// Fast path with no allocations and single frame.
- if err := c.prepWrite(messageType); err != nil {
+ var mw messageWriter
+ if err := c.beginMessage(&mw, messageType); err != nil {
return err
}
- mw := messageWriter{c: c, frameType: messageType, pos: maxFrameHeaderSize}
n := copy(c.writeBuf[mw.pos:], data)
mw.pos += n
data = data[n:]
@@ -783,50 +801,82 @@ func (c *Conn) advanceFrame() (int, error) {
}
// 2. Read and parse first two bytes of frame header.
+ // To aid debugging, collect and report all errors in the first two bytes
+ // of the header.
+
+ var errors []string
p, err := c.read(2)
if err != nil {
return noFrame, err
}
- final := p[0]&finalBit != 0
frameType := int(p[0] & 0xf)
+ final := p[0]&finalBit != 0
+ rsv1 := p[0]&rsv1Bit != 0
+ rsv2 := p[0]&rsv2Bit != 0
+ rsv3 := p[0]&rsv3Bit != 0
mask := p[1]&maskBit != 0
- c.readRemaining = int64(p[1] & 0x7f)
+ c.setReadRemaining(int64(p[1] & 0x7f))
c.readDecompress = false
- if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
- c.readDecompress = true
- p[0] &^= rsv1Bit
+ if rsv1 {
+ if c.newDecompressionReader != nil {
+ c.readDecompress = true
+ } else {
+ errors = append(errors, "RSV1 set")
+ }
+ }
+
+ if rsv2 {
+ errors = append(errors, "RSV2 set")
}
- if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
- return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
+ if rsv3 {
+ errors = append(errors, "RSV3 set")
}
switch frameType {
case CloseMessage, PingMessage, PongMessage:
if c.readRemaining > maxControlFramePayloadSize {
- return noFrame, c.handleProtocolError("control frame length > 125")
+ errors = append(errors, "len > 125 for control")
}
if !final {
- return noFrame, c.handleProtocolError("control frame not final")
+ errors = append(errors, "FIN not set on control")
}
case TextMessage, BinaryMessage:
if !c.readFinal {
- return noFrame, c.handleProtocolError("message start before final message frame")
+ errors = append(errors, "data before FIN")
}
c.readFinal = final
case continuationFrame:
if c.readFinal {
- return noFrame, c.handleProtocolError("continuation after final message frame")
+ errors = append(errors, "continuation after FIN")
}
c.readFinal = final
default:
- return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
+ errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
+ }
+
+ if mask != c.isServer {
+ errors = append(errors, "bad MASK")
}
- // 3. Read and parse frame length.
+ if len(errors) > 0 {
+ return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
+ }
+
+ // 3. Read and parse frame length as per
+ // https://tools.ietf.org/html/rfc6455#section-5.2
+ //
+ // The length of the "Payload data", in bytes: if 0-125, that is the payload
+ // length.
+ // - If 126, the following 2 bytes interpreted as a 16-bit unsigned
+ // integer are the payload length.
+ // - If 127, the following 8 bytes interpreted as
+ // a 64-bit unsigned integer (the most significant bit MUST be 0) are the
+ // payload length. Multibyte length quantities are expressed in network byte
+ // order.
switch c.readRemaining {
case 126:
@@ -834,21 +884,23 @@ func (c *Conn) advanceFrame() (int, error) {
if err != nil {
return noFrame, err
}
- c.readRemaining = int64(binary.BigEndian.Uint16(p))
+
+ if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
+ return noFrame, err
+ }
case 127:
p, err := c.read(8)
if err != nil {
return noFrame, err
}
- c.readRemaining = int64(binary.BigEndian.Uint64(p))
+
+ if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
+ return noFrame, err
+ }
}
// 4. Handle frame masking.
- if mask != c.isServer {
- return noFrame, c.handleProtocolError("incorrect mask flag")
- }
-
if mask {
c.readMaskPos = 0
p, err := c.read(len(c.readMaskKey))
@@ -863,6 +915,12 @@ func (c *Conn) advanceFrame() (int, error) {
if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
c.readLength += c.readRemaining
+ // Don't allow readLength to overflow in the presence of a large readRemaining
+ // counter.
+ if c.readLength < 0 {
+ return noFrame, ErrReadLimit
+ }
+
if c.readLimit > 0 && c.readLength > c.readLimit {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
return noFrame, ErrReadLimit
@@ -876,7 +934,7 @@ func (c *Conn) advanceFrame() (int, error) {
var payload []byte
if c.readRemaining > 0 {
payload, err = c.read(int(c.readRemaining))
- c.readRemaining = 0
+ c.setReadRemaining(0)
if err != nil {
return noFrame, err
}
@@ -902,7 +960,7 @@ func (c *Conn) advanceFrame() (int, error) {
if len(payload) >= 2 {
closeCode = int(binary.BigEndian.Uint16(payload))
if !isValidReceivedCloseCode(closeCode) {
- return noFrame, c.handleProtocolError("invalid close code")
+ return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
}
closeText = string(payload[2:])
if !utf8.ValidString(closeText) {
@@ -919,7 +977,11 @@ func (c *Conn) advanceFrame() (int, error) {
}
func (c *Conn) handleProtocolError(message string) error {
- c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait))
+ data := FormatCloseMessage(CloseProtocolError, message)
+ if len(data) > maxControlFramePayloadSize {
+ data = data[:maxControlFramePayloadSize]
+ }
+ c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
return errors.New("websocket: " + message)
}
@@ -949,6 +1011,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
c.readErr = hideTempErr(err)
break
}
+
if frameType == TextMessage || frameType == BinaryMessage {
c.messageReader = &messageReader{c}
c.reader = c.messageReader
@@ -989,7 +1052,9 @@ func (r *messageReader) Read(b []byte) (int, error) {
if c.isServer {
c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
}
- c.readRemaining -= int64(n)
+ rem := c.readRemaining
+ rem -= int64(n)
+ c.setReadRemaining(rem)
if c.readRemaining > 0 && c.readErr == io.EOF {
c.readErr = errUnexpectedEOF
}
@@ -1041,7 +1106,7 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
-// SetReadLimit sets the maximum size for a message read from the peer. If a
+// SetReadLimit sets the maximum size in bytes for a message read from the peer. If a
// message exceeds the limit, the connection sends a close message to the peer
// and returns ErrReadLimit to the application.
func (c *Conn) SetReadLimit(limit int64) {