aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDave Cheney <dave@cheney.net>2011-10-20 15:44:45 -0400
committerAdam Langley <agl@golang.org>2011-10-20 15:44:45 -0400
commit792a55f5db30c7280b2910a9621ea78ec6bd2c1c (patch)
tree6c7e23e458ee590bec498a83cc0ad7357ed3e84d
parente8a426aebe4968d5a27068e5aed2970a4c38f686 (diff)
downloadgo-792a55f5db30c7280b2910a9621ea78ec6bd2c1c.tar.gz
go-792a55f5db30c7280b2910a9621ea78ec6bd2c1c.zip
exp/ssh: add experimental ssh client
Requires CL 5285044 client.go: * add Dial, ClientConn, ClientChan, ClientConfig and Cmd. doc.go: * add Client documentation. server.go: * adjust for readVersion change. transport.go: * return an os.Error not a bool from readVersion. R=rsc, agl, n13m3y3r CC=golang-dev https://golang.org/cl/5162047
-rw-r--r--src/pkg/exp/ssh/Makefile1
-rw-r--r--src/pkg/exp/ssh/client.go628
-rw-r--r--src/pkg/exp/ssh/doc.go24
-rw-r--r--src/pkg/exp/ssh/server.go6
-rw-r--r--src/pkg/exp/ssh/transport.go17
-rw-r--r--src/pkg/exp/ssh/transport_test.go10
6 files changed, 668 insertions, 18 deletions
diff --git a/src/pkg/exp/ssh/Makefile b/src/pkg/exp/ssh/Makefile
index 1a100e9b69..1084d029db 100644
--- a/src/pkg/exp/ssh/Makefile
+++ b/src/pkg/exp/ssh/Makefile
@@ -7,6 +7,7 @@ include ../../../Make.inc
TARG=exp/ssh
GOFILES=\
channel.go\
+ client.go\
common.go\
messages.go\
transport.go\
diff --git a/src/pkg/exp/ssh/client.go b/src/pkg/exp/ssh/client.go
new file mode 100644
index 0000000000..edb95eccc6
--- /dev/null
+++ b/src/pkg/exp/ssh/client.go
@@ -0,0 +1,628 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "big"
+ "crypto"
+ "crypto/rand"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "os"
+ "net"
+ "sync"
+)
+
+// clientVersion is the fixed identification string that the client will use.
+var clientVersion = []byte("SSH-2.0-Go\r\n")
+
+// ClientConn represents the client side of an SSH connection.
+type ClientConn struct {
+ *transport
+ config *ClientConfig
+ chanlist
+}
+
+// Client returns a new SSH client connection using c as the underlying transport.
+func Client(c net.Conn, config *ClientConfig) (*ClientConn, os.Error) {
+ conn := &ClientConn{
+ transport: newTransport(c, config.rand()),
+ config: config,
+ chanlist: chanlist{
+ Mutex: new(sync.Mutex),
+ chans: make(map[uint32]*ClientChan),
+ },
+ }
+ if err := conn.handshake(); err != nil {
+ conn.Close()
+ return nil, err
+ }
+ if err := conn.authenticate(); err != nil {
+ conn.Close()
+ return nil, err
+ }
+ go conn.mainLoop()
+ return conn, nil
+}
+
+// handshake performs the client side key exchange. See RFC 4253 Section 7.
+func (c *ClientConn) handshake() os.Error {
+ var magics handshakeMagics
+
+ if _, err := c.Write(clientVersion); err != nil {
+ return err
+ }
+ if err := c.Flush(); err != nil {
+ return err
+ }
+ magics.clientVersion = clientVersion[:len(clientVersion)-2]
+
+ // read remote server version
+ version, err := readVersion(c)
+ if err != nil {
+ return err
+ }
+ magics.serverVersion = version
+ clientKexInit := kexInitMsg{
+ KexAlgos: supportedKexAlgos,
+ ServerHostKeyAlgos: supportedHostKeyAlgos,
+ CiphersClientServer: supportedCiphers,
+ CiphersServerClient: supportedCiphers,
+ MACsClientServer: supportedMACs,
+ MACsServerClient: supportedMACs,
+ CompressionClientServer: supportedCompressions,
+ CompressionServerClient: supportedCompressions,
+ }
+ kexInitPacket := marshal(msgKexInit, clientKexInit)
+ magics.clientKexInit = kexInitPacket
+
+ if err := c.writePacket(kexInitPacket); err != nil {
+ return err
+ }
+ packet, err := c.readPacket()
+ if err != nil {
+ return err
+ }
+
+ magics.serverKexInit = packet
+
+ var serverKexInit kexInitMsg
+ if err = unmarshal(&serverKexInit, packet, msgKexInit); err != nil {
+ return err
+ }
+
+ kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(c.transport, &clientKexInit, &serverKexInit)
+ if !ok {
+ return os.NewError("ssh: no common algorithms")
+ }
+
+ if serverKexInit.FirstKexFollows && kexAlgo != serverKexInit.KexAlgos[0] {
+ // The server sent a Kex message for the wrong algorithm,
+ // which we have to ignore.
+ if _, err := c.readPacket(); err != nil {
+ return err
+ }
+ }
+
+ var H, K []byte
+ var hashFunc crypto.Hash
+ switch kexAlgo {
+ case kexAlgoDH14SHA1:
+ hashFunc = crypto.SHA1
+ dhGroup14Once.Do(initDHGroup14)
+ H, K, err = c.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo)
+ default:
+ fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
+ }
+ if err != nil {
+ return err
+ }
+
+ if err = c.writePacket([]byte{msgNewKeys}); err != nil {
+ return err
+ }
+ if err = c.transport.writer.setupKeys(clientKeys, K, H, H, hashFunc); err != nil {
+ return err
+ }
+ if packet, err = c.readPacket(); err != nil {
+ return err
+ }
+ if packet[0] != msgNewKeys {
+ return UnexpectedMessageError{msgNewKeys, packet[0]}
+ }
+ return c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc)
+}
+
+// authenticate authenticates with the remote server. See RFC 4252.
+// Only "password" authentication is supported.
+func (c *ClientConn) authenticate() os.Error {
+ if err := c.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
+ return err
+ }
+ packet, err := c.readPacket()
+ if err != nil {
+ return err
+ }
+
+ var serviceAccept serviceAcceptMsg
+ if err = unmarshal(&serviceAccept, packet, msgServiceAccept); err != nil {
+ return err
+ }
+
+ // TODO(dfc) support proper authentication method negotation
+ method := "none"
+ if c.config.Password != "" {
+ method = "password"
+ }
+ if err := c.sendUserAuthReq(method); err != nil {
+ return err
+ }
+
+ if packet, err = c.readPacket(); err != nil {
+ return err
+ }
+
+ if packet[0] != msgUserAuthSuccess {
+ return UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
+ }
+ return nil
+}
+
+func (c *ClientConn) sendUserAuthReq(method string) os.Error {
+ length := stringLength([]byte(c.config.Password)) + 1
+ payload := make([]byte, length)
+ // always false for password auth, see RFC 4252 Section 8.
+ payload[0] = 0
+ marshalString(payload[1:], []byte(c.config.Password))
+
+ return c.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{
+ User: c.config.User,
+ Service: serviceSSH,
+ Method: method,
+ Payload: payload,
+ }))
+}
+
+// kexDH performs Diffie-Hellman key agreement on a ClientConn. The
+// returned values are given the same names as in RFC 4253, section 8.
+func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) ([]byte, []byte, os.Error) {
+ x, err := rand.Int(c.config.rand(), group.p)
+ if err != nil {
+ return nil, nil, err
+ }
+ X := new(big.Int).Exp(group.g, x, group.p)
+ kexDHInit := kexDHInitMsg{
+ X: X,
+ }
+ if err := c.writePacket(marshal(msgKexDHInit, kexDHInit)); err != nil {
+ return nil, nil, err
+ }
+
+ packet, err := c.readPacket()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ var kexDHReply = new(kexDHReplyMsg)
+ if err = unmarshal(kexDHReply, packet, msgKexDHReply); err != nil {
+ return nil, nil, err
+ }
+
+ if kexDHReply.Y.Sign() == 0 || kexDHReply.Y.Cmp(group.p) >= 0 {
+ return nil, nil, os.NewError("server DH parameter out of bounds")
+ }
+
+ kInt := new(big.Int).Exp(kexDHReply.Y, x, group.p)
+ h := hashFunc.New()
+ writeString(h, magics.clientVersion)
+ writeString(h, magics.serverVersion)
+ writeString(h, magics.clientKexInit)
+ writeString(h, magics.serverKexInit)
+ writeString(h, kexDHReply.HostKey)
+ writeInt(h, X)
+ writeInt(h, kexDHReply.Y)
+ K := make([]byte, intLength(kInt))
+ marshalInt(K, kInt)
+ h.Write(K)
+
+ H := h.Sum()
+
+ return H, K, nil
+}
+
+// OpenChan opens a new client channel. The most common session type is "session".
+// The full set of valid session types are listed in RFC 4250 4.9.1.
+func (c *ClientConn) OpenChan(typ string) (*ClientChan, os.Error) {
+ ch, id := c.newChan(c.transport)
+ if err := c.writePacket(marshal(msgChannelOpen, channelOpenMsg{
+ ChanType: typ,
+ PeersId: id,
+ PeersWindow: 8192,
+ MaxPacketSize: 16384,
+ })); err != nil {
+ // remove channel reference
+ c.chanlist.remove(id)
+ return nil, err
+ }
+ // wait for response
+ switch msg := (<-ch.msg).(type) {
+ case *channelOpenConfirmMsg:
+ ch.peersId = msg.MyId
+ case *channelOpenFailureMsg:
+ c.chanlist.remove(id)
+ return nil, os.NewError(msg.Message)
+ default:
+ c.chanlist.remove(id)
+ return nil, os.NewError("Unexpected packet")
+ }
+ return ch, nil
+}
+
+// mainloop reads incoming messages and routes channel messages
+// to their respective ClientChans.
+func (c *ClientConn) mainLoop() {
+ for {
+ packet, err := c.readPacket()
+ if err != nil {
+ // TODO(dfc) signal the underlying close to all channels
+ c.Close()
+ return
+ }
+ switch msg := decode(packet).(type) {
+ case *channelOpenMsg:
+ c.getChan(msg.PeersId).msg <- msg
+ case *channelOpenConfirmMsg:
+ c.getChan(msg.PeersId).msg <- msg
+ case *channelOpenFailureMsg:
+ c.getChan(msg.PeersId).msg <- msg
+ case *channelCloseMsg:
+ ch := c.getChan(msg.PeersId)
+ close(ch.stdinWriter.win)
+ close(ch.stdoutReader.data)
+ close(ch.stderrReader.dataExt)
+ c.chanlist.remove(msg.PeersId)
+ case *channelEOFMsg:
+ c.getChan(msg.PeersId).msg <- msg
+ case *channelRequestSuccessMsg:
+ c.getChan(msg.PeersId).msg <- msg
+ case *channelRequestFailureMsg:
+ c.getChan(msg.PeersId).msg <- msg
+ case *channelRequestMsg:
+ c.getChan(msg.PeersId).msg <- msg
+ case *windowAdjustMsg:
+ c.getChan(msg.PeersId).stdinWriter.win <- int(msg.AdditionalBytes)
+ case *channelData:
+ c.getChan(msg.PeersId).stdoutReader.data <- msg.Payload
+ case *channelExtendedData:
+ // TODO(dfc) should this send be non blocking. RFC 4254 5.2 suggests
+ // ext data consumes window size, does that need to be handled as well ?
+ c.getChan(msg.PeersId).stderrReader.dataExt <- msg.Data
+ default:
+ fmt.Printf("mainLoop: unhandled %#v\n", msg)
+ }
+ }
+}
+
+// Dial connects to the given network address using net.Dial and
+// then initiates a SSH handshake, returning the resulting client connection.
+func Dial(network, addr string, config *ClientConfig) (*ClientConn, os.Error) {
+ conn, err := net.Dial(network, addr)
+ if err != nil {
+ return nil, err
+ }
+ return Client(conn, config)
+}
+
+// A ClientConfig structure is used to configure a ClientConn. After one has
+// been passed to an SSH function it must not be modified.
+type ClientConfig struct {
+ // Rand provides the source of entropy for key exchange. If Rand is
+ // nil, the cryptographic random reader in package crypto/rand will
+ // be used.
+ Rand io.Reader
+
+ // The username to authenticate.
+ User string
+
+ // Used for "password" method authentication.
+ Password string
+}
+
+func (c *ClientConfig) rand() io.Reader {
+ if c.Rand == nil {
+ return rand.Reader
+ }
+ return c.Rand
+}
+
+// A ClientChan represents a single RFC 4254 channel that is multiplexed
+// over a single SSH connection.
+type ClientChan struct {
+ packetWriter
+ *stdinWriter // used by Exec and Shell
+ *stdoutReader // used by Exec and Shell
+ *stderrReader // used by Exec and Shell
+ id, peersId uint32
+ msg chan interface{} // incoming messages
+}
+
+func newClientChan(t *transport, id uint32) *ClientChan {
+ // TODO(DFC) allocating stdin/out/err on ClientChan creation is
+ // wasteful, but ClientConn.mainLoop() needs a way of finding
+ // those channels before Exec/Shell is called because the remote
+ // may send window adjustments at any time.
+ return &ClientChan{
+ packetWriter: t,
+ stdinWriter: &stdinWriter{
+ packetWriter: t,
+ id: id,
+ win: make(chan int, 16),
+ },
+ stdoutReader: &stdoutReader{
+ packetWriter: t,
+ id: id,
+ win: 8192,
+ data: make(chan []byte, 16),
+ },
+ stderrReader: &stderrReader{
+ dataExt: make(chan string, 16),
+ },
+ id: id,
+ msg: make(chan interface{}, 16),
+ }
+}
+
+// Close closes the channel. This does not close the underlying connection.
+func (c *ClientChan) Close() os.Error {
+ return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
+ PeersId: c.id,
+ }))
+}
+
+// Setenv sets an environment variable that will be applied to any
+// command executed by Shell or Exec.
+func (c *ClientChan) Setenv(name, value string) os.Error {
+ namLen := stringLength([]byte(name))
+ valLen := stringLength([]byte(value))
+ payload := make([]byte, namLen+valLen)
+ marshalString(payload[:namLen], []byte(name))
+ marshalString(payload[namLen:], []byte(value))
+
+ return c.sendChanReq(channelRequestMsg{
+ PeersId: c.id,
+ Request: "env",
+ WantReply: true,
+ RequestSpecificData: payload,
+ })
+}
+
+func (c *ClientChan) sendChanReq(req channelRequestMsg) os.Error {
+ if err := c.writePacket(marshal(msgChannelRequest, req)); err != nil {
+ return err
+ }
+ for {
+ switch msg := (<-c.msg).(type) {
+ case *channelRequestSuccessMsg:
+ return nil
+ case *channelRequestFailureMsg:
+ return os.NewError(req.Request)
+ default:
+ return fmt.Errorf("%#v", msg)
+ }
+ }
+ panic("unreachable")
+}
+
+// An empty mode list (a string of 1 character, opcode 0), see RFC 4254 Section 8.
+var emptyModeList = []byte{0, 0, 0, 1, 0}
+
+// RequstPty requests a pty to be allocated on the remote side of this channel.
+func (c *ClientChan) RequestPty(term string, h, w int) os.Error {
+ buf := make([]byte, 4+len(term)+16+len(emptyModeList))
+ b := marshalString(buf, []byte(term))
+ binary.BigEndian.PutUint32(b, uint32(h))
+ binary.BigEndian.PutUint32(b[4:], uint32(w))
+ binary.BigEndian.PutUint32(b[8:], uint32(h*8))
+ binary.BigEndian.PutUint32(b[12:], uint32(w*8))
+ copy(b[16:], emptyModeList)
+
+ return c.sendChanReq(channelRequestMsg{
+ PeersId: c.id,
+ Request: "pty-req",
+ WantReply: true,
+ RequestSpecificData: buf,
+ })
+}
+
+// Exec runs cmd on the remote host.
+// Typically, the remote server passes cmd to the shell for interpretation.
+func (c *ClientChan) Exec(cmd string) (*Cmd, os.Error) {
+ cmdLen := stringLength([]byte(cmd))
+ payload := make([]byte, cmdLen)
+ marshalString(payload, []byte(cmd))
+ err := c.sendChanReq(channelRequestMsg{
+ PeersId: c.id,
+ Request: "exec",
+ WantReply: true,
+ RequestSpecificData: payload,
+ })
+ return &Cmd{
+ c.stdinWriter,
+ c.stdoutReader,
+ c.stderrReader,
+ }, err
+}
+
+// Shell starts a login shell on the remote host.
+func (c *ClientChan) Shell() (*Cmd, os.Error) {
+ err := c.sendChanReq(channelRequestMsg{
+ PeersId: c.id,
+ Request: "shell",
+ WantReply: true,
+ })
+ return &Cmd{
+ c.stdinWriter,
+ c.stdoutReader,
+ c.stderrReader,
+ }, err
+
+}
+
+// Thread safe channel list.
+type chanlist struct {
+ *sync.Mutex
+ // TODO(dfc) should could be converted to a slice
+ chans map[uint32]*ClientChan
+}
+
+// Allocate a new ClientChan with the next avail local id.
+func (c *chanlist) newChan(t *transport) (*ClientChan, uint32) {
+ c.Lock()
+ defer c.Unlock()
+
+ for i := uint32(0); i < 1<<31; i++ {
+ if _, ok := c.chans[i]; !ok {
+ ch := newClientChan(t, i)
+ c.chans[i] = ch
+ return ch, uint32(i)
+ }
+ }
+ panic("unable to find free channel")
+}
+
+func (c *chanlist) getChan(id uint32) *ClientChan {
+ c.Lock()
+ defer c.Unlock()
+ return c.chans[id]
+}
+
+func (c *chanlist) remove(id uint32) {
+ c.Lock()
+ defer c.Unlock()
+ delete(c.chans, id)
+}
+
+// A Cmd represents a connection to a remote command or shell
+// Closing Cmd.Stdin will be observed by the remote process.
+type Cmd struct {
+ // Writes to Stdin are made available to the command's standard input.
+ // Closing Stdin causes the command to observe an EOF on its standard input.
+ Stdin io.WriteCloser
+
+ // Reads from Stdout consume the command's standard output.
+ // There is a fixed amount of buffering of the command's standard output.
+ // Failing to read from Stdout will eventually cause the command to block
+ // when writing to its standard output. Closing Stdout unblocks any
+ // such writes and makes them return errors.
+ Stdout io.ReadCloser
+
+ // Reads from Stderr consume the command's standard error.
+ // The SSH protocol assumes it can always send standard error;
+ // the command will never block writing to its standard error.
+ // However, failure to read from Stderr will eventually cause the
+ // SSH protocol to jam, so it is important to arrange for reading
+ // from Stderr, even if by
+ // go io.Copy(ioutil.Discard, cmd.Stderr)
+ Stderr io.Reader
+}
+
+// A stdinWriter represents the stdin of a remote process.
+type stdinWriter struct {
+ win chan int // receives window adjustments
+ id uint32
+ rwin int // current rwin size
+ packetWriter // for sending channelDataMsg
+}
+
+// Write writes data to the remote process's standard input.
+func (w *stdinWriter) Write(data []byte) (n int, err os.Error) {
+ for {
+ if w.rwin == 0 {
+ win, ok := <-w.win
+ if !ok {
+ return 0, os.EOF
+ }
+ w.rwin += win
+ continue
+ }
+ n = len(data)
+ packet := make([]byte, 0, 9+n)
+ packet = append(packet, msgChannelData,
+ byte(w.id)>>24, byte(w.id)>>16, byte(w.id)>>8, byte(w.id),
+ byte(n)>>24, byte(n)>>16, byte(n)>>8, byte(n))
+ err = w.writePacket(append(packet, data...))
+ w.rwin -= n
+ return
+ }
+ panic("unreachable")
+}
+
+func (w *stdinWriter) Close() os.Error {
+ return w.writePacket(marshal(msgChannelEOF, channelEOFMsg{w.id}))
+}
+
+// A stdoutReader represents the stdout of a remote process.
+type stdoutReader struct {
+ // TODO(dfc) a fixed size channel may not be the right data structure.
+ // If writes to this channel block, they will block mainLoop, making
+ // it unable to receive new messages from the remote side.
+ data chan []byte // receives data from remote
+ id uint32
+ win int // current win size
+ packetWriter // for sending windowAdjustMsg
+ buf []byte
+}
+
+// Read reads data from the remote process's standard output.
+func (r *stdoutReader) Read(data []byte) (int, os.Error) {
+ var ok bool
+ for {
+ if len(r.buf) > 0 {
+ n := copy(data, r.buf)
+ r.buf = r.buf[n:]
+ r.win += n
+ msg := windowAdjustMsg{
+ PeersId: r.id,
+ AdditionalBytes: uint32(n),
+ }
+ err := r.writePacket(marshal(msgChannelWindowAdjust, msg))
+ return n, err
+ }
+ r.buf, ok = <-r.data
+ if !ok {
+ return 0, os.EOF
+ }
+ r.win -= len(r.buf)
+ }
+ panic("unreachable")
+}
+
+func (r *stdoutReader) Close() os.Error {
+ return r.writePacket(marshal(msgChannelEOF, channelEOFMsg{r.id}))
+}
+
+// A stderrReader represents the stderr of a remote process.
+type stderrReader struct {
+ dataExt chan string // receives dataExt from remote
+ buf []byte // buffer current dataExt
+}
+
+// Read reads a line of data from the remote process's stderr.
+func (r *stderrReader) Read(data []byte) (int, os.Error) {
+ for {
+ if len(r.buf) > 0 {
+ n := copy(data, r.buf)
+ r.buf = r.buf[n:]
+ return n, nil
+ }
+ buf, ok := <-r.dataExt
+ if !ok {
+ return 0, os.EOF
+ }
+ r.buf = []byte(buf)
+ }
+ panic("unreachable")
+}
diff --git a/src/pkg/exp/ssh/doc.go b/src/pkg/exp/ssh/doc.go
index 54a7ba9fda..a2ec3faca7 100644
--- a/src/pkg/exp/ssh/doc.go
+++ b/src/pkg/exp/ssh/doc.go
@@ -3,7 +3,7 @@
// license that can be found in the LICENSE file.
/*
-Package ssh implements an SSH server.
+Package ssh implements an SSH client and server.
SSH is a transport security protocol, an authentication protocol and a
family of application protocols. The most typical application level
@@ -75,5 +75,27 @@ present a simple terminal interface.
}
return
}()
+
+An SSH client is represented with a ClientConn. Currently only the "password"
+authentication method is supported.
+
+ config := &ClientConfig{
+ User: "username",
+ Password: "123456",
+ }
+ client, err := Dial("yourserver.com:22", config)
+
+Each ClientConn can support multiple channels, represented by ClientChan. Each
+channel should be of a type specified in rfc4250, 4.9.1.
+
+ ch, err := client.OpenChan("session")
+
+Once the ClientChan is opened, you can execute a single command on the remote side
+using the Exec method.
+
+ cmd, err := ch.Exec("/usr/bin/whoami")
+ reader := bufio.NewReader(cmd.Stdin)
+ line, _, _ := reader.ReadLine()
+ fmt.Println(line)
*/
package ssh
diff --git a/src/pkg/exp/ssh/server.go b/src/pkg/exp/ssh/server.go
index 410cafc44c..b5a5e017d3 100644
--- a/src/pkg/exp/ssh/server.go
+++ b/src/pkg/exp/ssh/server.go
@@ -267,9 +267,9 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error {
}
magics.serverVersion = serverVersion[:len(serverVersion)-2]
- version, ok := readVersion(s.transport)
- if !ok {
- return os.NewError("failed to read version string from client")
+ version, err := readVersion(s.transport)
+ if err != nil {
+ return err
}
magics.clientVersion = version
diff --git a/src/pkg/exp/ssh/transport.go b/src/pkg/exp/ssh/transport.go
index 5994004d86..97eaf975d1 100644
--- a/src/pkg/exp/ssh/transport.go
+++ b/src/pkg/exp/ssh/transport.go
@@ -332,16 +332,15 @@ func (t truncatingMAC) Size() int {
const maxVersionStringBytes = 1024
// Read version string as specified by RFC 4253, section 4.2.
-func readVersion(r io.Reader) (versionString []byte, ok bool) {
- versionString = make([]byte, 0, 64)
- seenCR := false
-
+func readVersion(r io.Reader) ([]byte, os.Error) {
+ versionString := make([]byte, 0, 64)
+ var ok, seenCR bool
var buf [1]byte
forEachByte:
for len(versionString) < maxVersionStringBytes {
_, err := io.ReadFull(r, buf[:])
if err != nil {
- return
+ return nil, err
}
b := buf[0]
@@ -360,10 +359,10 @@ forEachByte:
versionString = append(versionString, b)
}
- if ok {
- // We need to remove the CR from versionString
- versionString = versionString[:len(versionString)-1]
+ if !ok {
+ return nil, os.NewError("failed to read version string")
}
- return
+ // We need to remove the CR from versionString
+ return versionString[:len(versionString)-1], nil
}
diff --git a/src/pkg/exp/ssh/transport_test.go b/src/pkg/exp/ssh/transport_test.go
index 9a610a7803..b2e2a7fc92 100644
--- a/src/pkg/exp/ssh/transport_test.go
+++ b/src/pkg/exp/ssh/transport_test.go
@@ -12,9 +12,9 @@ import (
func TestReadVersion(t *testing.T) {
buf := []byte(serverVersion)
- result, ok := readVersion(bufio.NewReader(bytes.NewBuffer(buf)))
- if !ok {
- t.Error("readVersion didn't read version correctly")
+ result, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf)))
+ if err != nil {
+ t.Errorf("readVersion didn't read version correctly: %s", err)
}
if !bytes.Equal(buf[:len(buf)-2], result) {
t.Error("version read did not match expected")
@@ -23,7 +23,7 @@ func TestReadVersion(t *testing.T) {
func TestReadVersionTooLong(t *testing.T) {
buf := make([]byte, maxVersionStringBytes+1)
- if _, ok := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); ok {
+ if _, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); err == nil {
t.Errorf("readVersion consumed %d bytes without error", len(buf))
}
}
@@ -31,7 +31,7 @@ func TestReadVersionTooLong(t *testing.T) {
func TestReadVersionWithoutCRLF(t *testing.T) {
buf := []byte(serverVersion)
buf = buf[:len(buf)-1]
- if _, ok := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); ok {
+ if _, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); err == nil {
t.Error("readVersion did not notice \\n was missing")
}
}