aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Langley <agl@golang.org>2011-09-17 15:57:24 -0400
committerAdam Langley <agl@golang.org>2011-09-17 15:57:24 -0400
commit605e57d8fee696238f3338c415043f16a7743731 (patch)
treeeb90377904544cacd84f2c6a0e3d6de3ca86f5ef
parentb71a805cd5131ff6407a25af540e3dd80fa883c2 (diff)
downloadgo-605e57d8fee696238f3338c415043f16a7743731.tar.gz
go-605e57d8fee696238f3338c415043f16a7743731.zip
exp/ssh: new package.
The typical UNIX method for controlling long running process is to send the process signals. Since this doesn't get you very far, various ad-hoc, remote-control protocols have been used over time by programs like Apache and BIND. Implementing an SSH server means that Go code will have a standard, secure way to do this in the future. R=bradfitz, borman, dave, gustavo, dsymonds, r, adg, rsc, rogpeppe, lvd, kevlar, raul.san CC=golang-dev https://golang.org/cl/4962064
-rw-r--r--src/pkg/exp/ssh/Makefile16
-rw-r--r--src/pkg/exp/ssh/channel.go317
-rw-r--r--src/pkg/exp/ssh/common.go96
-rw-r--r--src/pkg/exp/ssh/doc.go79
-rw-r--r--src/pkg/exp/ssh/messages.go557
-rw-r--r--src/pkg/exp/ssh/messages_test.go125
-rw-r--r--src/pkg/exp/ssh/server.go711
-rw-r--r--src/pkg/exp/ssh/server_shell.go399
-rw-r--r--src/pkg/exp/ssh/server_shell_test.go134
-rw-r--r--src/pkg/exp/ssh/transport.go308
10 files changed, 2742 insertions, 0 deletions
diff --git a/src/pkg/exp/ssh/Makefile b/src/pkg/exp/ssh/Makefile
new file mode 100644
index 0000000000..e8f33b708c
--- /dev/null
+++ b/src/pkg/exp/ssh/Makefile
@@ -0,0 +1,16 @@
+# Copyright 2011 The Go Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+include ../../../Make.inc
+
+TARG=exp/ssh
+GOFILES=\
+ common.go\
+ messages.go\
+ server.go\
+ transport.go\
+ channel.go\
+ server_shell.go\
+
+include ../../../Make.pkg
diff --git a/src/pkg/exp/ssh/channel.go b/src/pkg/exp/ssh/channel.go
new file mode 100644
index 0000000000..10f62354f4
--- /dev/null
+++ b/src/pkg/exp/ssh/channel.go
@@ -0,0 +1,317 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "os"
+ "sync"
+)
+
+// A Channel is an ordered, reliable, duplex stream that is multiplexed over an
+// SSH connection.
+type Channel interface {
+ // Accept accepts the channel creation request.
+ Accept() os.Error
+ // Reject rejects the channel creation request. After calling this, no
+ // other methods on the Channel may be called. If they are then the
+ // peer is likely to signal a protocol error and drop the connection.
+ Reject(reason RejectionReason, message string) os.Error
+
+ // Read may return a ChannelRequest as an os.Error.
+ Read(data []byte) (int, os.Error)
+ Write(data []byte) (int, os.Error)
+ Close() os.Error
+
+ // AckRequest either sends an ack or nack to the channel request.
+ AckRequest(ok bool) os.Error
+
+ // ChannelType returns the type of the channel, as supplied by the
+ // client.
+ ChannelType() string
+ // ExtraData returns the arbitary payload for this channel, as supplied
+ // by the client. This data is specific to the channel type.
+ ExtraData() []byte
+}
+
+// ChannelRequest represents a request sent on a channel, outside of the normal
+// stream of bytes. It may result from calling Read on a Channel.
+type ChannelRequest struct {
+ Request string
+ WantReply bool
+ Payload []byte
+}
+
+func (c ChannelRequest) String() string {
+ return "channel request received"
+}
+
+// RejectionReason is an enumeration used when rejecting channel creation
+// requests. See RFC 4254, section 5.1.
+type RejectionReason int
+
+const (
+ Prohibited RejectionReason = iota + 1
+ ConnectionFailed
+ UnknownChannelType
+ ResourceShortage
+)
+
+type channel struct {
+ // immutable once created
+ chanType string
+ extraData []byte
+
+ theyClosed bool
+ theySentEOF bool
+ weClosed bool
+ dead bool
+
+ serverConn *ServerConnection
+ myId, theirId uint32
+ myWindow, theirWindow uint32
+ maxPacketSize uint32
+ err os.Error
+
+ pendingRequests []ChannelRequest
+ pendingData []byte
+ head, length int
+
+ // This lock is inferior to serverConn.lock
+ lock sync.Mutex
+ cond *sync.Cond
+}
+
+func (c *channel) Accept() os.Error {
+ c.serverConn.lock.Lock()
+ defer c.serverConn.lock.Unlock()
+
+ if c.serverConn.err != nil {
+ return c.serverConn.err
+ }
+
+ confirm := channelOpenConfirmMsg{
+ PeersId: c.theirId,
+ MyId: c.myId,
+ MyWindow: c.myWindow,
+ MaxPacketSize: c.maxPacketSize,
+ }
+ return c.serverConn.out.writePacket(marshal(msgChannelOpenConfirm, confirm))
+}
+
+func (c *channel) Reject(reason RejectionReason, message string) os.Error {
+ c.serverConn.lock.Lock()
+ defer c.serverConn.lock.Unlock()
+
+ if c.serverConn.err != nil {
+ return c.serverConn.err
+ }
+
+ reject := channelOpenFailureMsg{
+ PeersId: c.theirId,
+ Reason: uint32(reason),
+ Message: message,
+ Language: "en",
+ }
+ return c.serverConn.out.writePacket(marshal(msgChannelOpenFailure, reject))
+}
+
+func (c *channel) handlePacket(packet interface{}) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ switch packet := packet.(type) {
+ case *channelRequestMsg:
+ req := ChannelRequest{
+ Request: packet.Request,
+ WantReply: packet.WantReply,
+ Payload: packet.RequestSpecificData,
+ }
+
+ c.pendingRequests = append(c.pendingRequests, req)
+ c.cond.Signal()
+ case *channelCloseMsg:
+ c.theyClosed = true
+ c.cond.Signal()
+ case *channelEOFMsg:
+ c.theySentEOF = true
+ c.cond.Signal()
+ default:
+ panic("unknown packet type")
+ }
+}
+
+func (c *channel) handleData(data []byte) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ // The other side should never send us more than our window.
+ if len(data)+c.length > len(c.pendingData) {
+ // TODO(agl): we should tear down the channel with a protocol
+ // error.
+ return
+ }
+
+ c.myWindow -= uint32(len(data))
+ for i := 0; i < 2; i++ {
+ tail := c.head + c.length
+ if tail > len(c.pendingData) {
+ tail -= len(c.pendingData)
+ }
+ n := copy(c.pendingData[tail:], data)
+ data = data[n:]
+ c.length += n
+ }
+
+ c.cond.Signal()
+}
+
+func (c *channel) Read(data []byte) (n int, err os.Error) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ if c.err != nil {
+ return 0, c.err
+ }
+
+ if c.myWindow <= uint32(len(c.pendingData))/2 {
+ packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{
+ PeersId: c.theirId,
+ AdditionalBytes: uint32(len(c.pendingData)) - c.myWindow,
+ })
+ if err := c.serverConn.out.writePacket(packet); err != nil {
+ return 0, err
+ }
+ }
+
+ for {
+ if c.theySentEOF || c.theyClosed || c.dead {
+ return 0, os.EOF
+ }
+
+ if len(c.pendingRequests) > 0 {
+ req := c.pendingRequests[0]
+ if len(c.pendingRequests) == 1 {
+ c.pendingRequests = nil
+ } else {
+ oldPendingRequests := c.pendingRequests
+ c.pendingRequests = make([]ChannelRequest, len(oldPendingRequests)-1)
+ copy(c.pendingRequests, oldPendingRequests[1:])
+ }
+
+ return 0, req
+ }
+
+ if c.length > 0 {
+ tail := c.head + c.length
+ if tail > len(c.pendingData) {
+ tail -= len(c.pendingData)
+ }
+ n = copy(data, c.pendingData[c.head:tail])
+ c.head += n
+ c.length -= n
+ if c.head == len(c.pendingData) {
+ c.head = 0
+ }
+ return
+ }
+
+ c.cond.Wait()
+ }
+
+ panic("unreachable")
+}
+
+func (c *channel) Write(data []byte) (n int, err os.Error) {
+ for len(data) > 0 {
+ c.lock.Lock()
+ if c.dead || c.weClosed {
+ return 0, os.EOF
+ }
+
+ if c.theirWindow == 0 {
+ c.cond.Wait()
+ continue
+ }
+ c.lock.Unlock()
+
+ todo := data
+ if uint32(len(todo)) > c.theirWindow {
+ todo = todo[:c.theirWindow]
+ }
+
+ packet := make([]byte, 1+4+4+len(todo))
+ packet[0] = msgChannelData
+ packet[1] = byte(c.theirId) >> 24
+ packet[2] = byte(c.theirId) >> 16
+ packet[3] = byte(c.theirId) >> 8
+ packet[4] = byte(c.theirId)
+ packet[5] = byte(len(todo)) >> 24
+ packet[6] = byte(len(todo)) >> 16
+ packet[7] = byte(len(todo)) >> 8
+ packet[8] = byte(len(todo))
+ copy(packet[9:], todo)
+
+ c.serverConn.lock.Lock()
+ if err = c.serverConn.out.writePacket(packet); err != nil {
+ c.serverConn.lock.Unlock()
+ return
+ }
+ c.serverConn.lock.Unlock()
+
+ n += len(todo)
+ data = data[len(todo):]
+ }
+
+ return
+}
+
+func (c *channel) Close() os.Error {
+ c.serverConn.lock.Lock()
+ defer c.serverConn.lock.Unlock()
+
+ if c.serverConn.err != nil {
+ return c.serverConn.err
+ }
+
+ if c.weClosed {
+ return os.NewError("ssh: channel already closed")
+ }
+ c.weClosed = true
+
+ closeMsg := channelCloseMsg{
+ PeersId: c.theirId,
+ }
+ return c.serverConn.out.writePacket(marshal(msgChannelClose, closeMsg))
+}
+
+func (c *channel) AckRequest(ok bool) os.Error {
+ c.serverConn.lock.Lock()
+ defer c.serverConn.lock.Unlock()
+
+ if c.serverConn.err != nil {
+ return c.serverConn.err
+ }
+
+ if ok {
+ ack := channelRequestSuccessMsg{
+ PeersId: c.theirId,
+ }
+ return c.serverConn.out.writePacket(marshal(msgChannelSuccess, ack))
+ } else {
+ ack := channelRequestFailureMsg{
+ PeersId: c.theirId,
+ }
+ return c.serverConn.out.writePacket(marshal(msgChannelFailure, ack))
+ }
+ panic("unreachable")
+}
+
+func (c *channel) ChannelType() string {
+ return c.chanType
+}
+
+func (c *channel) ExtraData() []byte {
+ return c.extraData
+}
diff --git a/src/pkg/exp/ssh/common.go b/src/pkg/exp/ssh/common.go
new file mode 100644
index 0000000000..c951d1a753
--- /dev/null
+++ b/src/pkg/exp/ssh/common.go
@@ -0,0 +1,96 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "strconv"
+)
+
+// These are string constants in the SSH protocol.
+const (
+ kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
+ hostAlgoRSA = "ssh-rsa"
+ cipherAES128CTR = "aes128-ctr"
+ macSHA196 = "hmac-sha1-96"
+ compressionNone = "none"
+ serviceUserAuth = "ssh-userauth"
+ serviceSSH = "ssh-connection"
+)
+
+// UnexpectedMessageError results when the SSH message that we received didn't
+// match what we wanted.
+type UnexpectedMessageError struct {
+ expected, got uint8
+}
+
+func (u UnexpectedMessageError) String() string {
+ return "ssh: unexpected message type " + strconv.Itoa(int(u.got)) + " (expected " + strconv.Itoa(int(u.expected)) + ")"
+}
+
+// ParseError results from a malformed SSH message.
+type ParseError struct {
+ msgType uint8
+}
+
+func (p ParseError) String() string {
+ return "ssh: parse error in message type " + strconv.Itoa(int(p.msgType))
+}
+
+func findCommonAlgorithm(clientAlgos []string, serverAlgos []string) (commonAlgo string, ok bool) {
+ for _, clientAlgo := range clientAlgos {
+ for _, serverAlgo := range serverAlgos {
+ if clientAlgo == serverAlgo {
+ return clientAlgo, true
+ }
+ }
+ }
+
+ return
+}
+
+func findAgreedAlgorithms(clientToServer, serverToClient *halfConnection, clientKexInit, serverKexInit *kexInitMsg) (kexAlgo, hostKeyAlgo string, ok bool) {
+ kexAlgo, ok = findCommonAlgorithm(clientKexInit.KexAlgos, serverKexInit.KexAlgos)
+ if !ok {
+ return
+ }
+
+ hostKeyAlgo, ok = findCommonAlgorithm(clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
+ if !ok {
+ return
+ }
+
+ clientToServer.cipherAlgo, ok = findCommonAlgorithm(clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
+ if !ok {
+ return
+ }
+
+ serverToClient.cipherAlgo, ok = findCommonAlgorithm(clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
+ if !ok {
+ return
+ }
+
+ clientToServer.macAlgo, ok = findCommonAlgorithm(clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
+ if !ok {
+ return
+ }
+
+ serverToClient.macAlgo, ok = findCommonAlgorithm(clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
+ if !ok {
+ return
+ }
+
+ clientToServer.compressionAlgo, ok = findCommonAlgorithm(clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
+ if !ok {
+ return
+ }
+
+ serverToClient.compressionAlgo, ok = findCommonAlgorithm(clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
+ if !ok {
+ return
+ }
+
+ ok = true
+ return
+}
diff --git a/src/pkg/exp/ssh/doc.go b/src/pkg/exp/ssh/doc.go
new file mode 100644
index 0000000000..8dbdb0777c
--- /dev/null
+++ b/src/pkg/exp/ssh/doc.go
@@ -0,0 +1,79 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+Package ssh implements an SSH server.
+
+SSH is a transport security protocol, an authentication protocol and a
+family of application protocols. The most typical application level
+protocol is a remote shell and this is specifically implemented. However,
+the multiplexed nature of SSH is exposed to users that wish to support
+others.
+
+An SSH server is represented by a Server, which manages a number of
+ServerConnections and handles authentication.
+
+ var s Server
+ s.PubKeyCallback = pubKeyAuth
+ s.PasswordCallback = passwordAuth
+
+ pemBytes, err := ioutil.ReadFile("id_rsa")
+ if err != nil {
+ panic("Failed to load private key")
+ }
+ err = s.SetRSAPrivateKey(pemBytes)
+ if err != nil {
+ panic("Failed to parse private key")
+ }
+
+Once a Server has been set up, connections can be attached.
+
+ var sConn ServerConnection
+ sConn.Server = &s
+ err = sConn.Handshake(conn)
+ if err != nil {
+ panic("failed to handshake")
+ }
+
+An SSH connection multiplexes several channels, which must be accepted themselves:
+
+
+ for {
+ channel, err := sConn.Accept()
+ if err != nil {
+ panic("error from Accept")
+ }
+
+ ...
+ }
+
+Accept reads from the connection, demultiplexes packets to their corresponding
+channels and returns when a new channel request is seen. Some goroutine must
+always be calling Accept; otherwise no messages will be forwarded to the
+channels.
+
+Channels have a type, depending on the application level protocol intended. In
+the case of a shell, the type is "session" and ServerShell may be used to
+present a simple terminal interface.
+
+ if channel.ChannelType() != "session" {
+ c.Reject(RejectUnknownChannelType, "unknown channel type")
+ return
+ }
+ channel.Accept()
+
+ shell := NewServerShell(channel, "> ")
+ go func() {
+ defer channel.Close()
+ for {
+ line, err := shell.ReadLine()
+ if err != nil {
+ break
+ }
+ println(line)
+ }
+ return
+ }()
+*/
+package ssh
diff --git a/src/pkg/exp/ssh/messages.go b/src/pkg/exp/ssh/messages.go
new file mode 100644
index 0000000000..d375eafae9
--- /dev/null
+++ b/src/pkg/exp/ssh/messages.go
@@ -0,0 +1,557 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "big"
+ "bytes"
+ "io"
+ "os"
+ "reflect"
+)
+
+// These are SSH message type numbers. They are scattered around several
+// documents but many were taken from
+// http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
+const (
+ msgDisconnect = 1
+ msgIgnore = 2
+ msgUnimplemented = 3
+ msgDebug = 4
+ msgServiceRequest = 5
+ msgServiceAccept = 6
+
+ msgKexInit = 20
+ msgNewKeys = 21
+
+ msgKexDHInit = 30
+ msgKexDHReply = 31
+
+ msgUserAuthRequest = 50
+ msgUserAuthFailure = 51
+ msgUserAuthSuccess = 52
+ msgUserAuthBanner = 53
+ msgUserAuthPubKeyOk = 60
+
+ msgGlobalRequest = 80
+ msgRequestSuccess = 81
+ msgRequestFailure = 82
+
+ msgChannelOpen = 90
+ msgChannelOpenConfirm = 91
+ msgChannelOpenFailure = 92
+ msgChannelWindowAdjust = 93
+ msgChannelData = 94
+ msgChannelExtendedData = 95
+ msgChannelEOF = 96
+ msgChannelClose = 97
+ msgChannelRequest = 98
+ msgChannelSuccess = 99
+ msgChannelFailure = 100
+)
+
+// SSH messages:
+//
+// These structures mirror the wire format of the corresponding SSH messages.
+// They are marshaled using reflection with the marshal and unmarshal functions
+// in this file. The only wrinkle is that a final member of type []byte with a
+// tag of "rest" receives the remainder of a packet when unmarshaling.
+
+// See RFC 4253, section 7.1.
+type kexInitMsg struct {
+ Cookie [16]byte
+ KexAlgos []string
+ ServerHostKeyAlgos []string
+ CiphersClientServer []string
+ CiphersServerClient []string
+ MACsClientServer []string
+ MACsServerClient []string
+ CompressionClientServer []string
+ CompressionServerClient []string
+ LanguagesClientServer []string
+ LanguagesServerClient []string
+ FirstKexFollows bool
+ Reserved uint32
+}
+
+// See RFC 4253, section 8.
+type kexDHInitMsg struct {
+ X *big.Int
+}
+
+type kexDHReplyMsg struct {
+ HostKey []byte
+ Y *big.Int
+ Signature []byte
+}
+
+// See RFC 4253, section 10.
+type serviceRequestMsg struct {
+ Service string
+}
+
+// See RFC 4253, section 10.
+type serviceAcceptMsg struct {
+ Service string
+}
+
+// See RFC 4252, section 5.
+type userAuthRequestMsg struct {
+ User string
+ Service string
+ Method string
+ Payload []byte "rest"
+}
+
+// See RFC 4252, section 5.1
+type userAuthFailureMsg struct {
+ Methods []string
+ PartialSuccess bool
+}
+
+// See RFC 4254, section 5.1.
+type channelOpenMsg struct {
+ ChanType string
+ PeersId uint32
+ PeersWindow uint32
+ MaxPacketSize uint32
+ TypeSpecificData []byte "rest"
+}
+
+// See RFC 4254, section 5.1.
+type channelOpenConfirmMsg struct {
+ PeersId uint32
+ MyId uint32
+ MyWindow uint32
+ MaxPacketSize uint32
+ TypeSpecificData []byte "rest"
+}
+
+// See RFC 4254, section 5.1.
+type channelOpenFailureMsg struct {
+ PeersId uint32
+ Reason uint32
+ Message string
+ Language string
+}
+
+type channelRequestMsg struct {
+ PeersId uint32
+ Request string
+ WantReply bool
+ RequestSpecificData []byte "rest"
+}
+
+// See RFC 4254, section 5.4.
+type channelRequestSuccessMsg struct {
+ PeersId uint32
+}
+
+// See RFC 4254, section 5.4.
+type channelRequestFailureMsg struct {
+ PeersId uint32
+}
+
+// See RFC 4254, section 5.3
+type channelCloseMsg struct {
+ PeersId uint32
+}
+
+// See RFC 4254, section 5.3
+type channelEOFMsg struct {
+ PeersId uint32
+}
+
+// See RFC 4254, section 4
+type globalRequestMsg struct {
+ Type string
+ WantReply bool
+}
+
+// See RFC 4254, section 5.2
+type windowAdjustMsg struct {
+ PeersId uint32
+ AdditionalBytes uint32
+}
+
+// See RFC 4252, section 7
+type userAuthPubKeyOkMsg struct {
+ Algo string
+ PubKey string
+}
+
+// unmarshal parses the SSH wire data in packet into out using reflection.
+// expectedType is the expected SSH message type. It either returns nil on
+// success, or a ParseError or UnexpectedMessageError on error.
+func unmarshal(out interface{}, packet []byte, expectedType uint8) os.Error {
+ if len(packet) == 0 {
+ return ParseError{expectedType}
+ }
+ if packet[0] != expectedType {
+ return UnexpectedMessageError{expectedType, packet[0]}
+ }
+ packet = packet[1:]
+
+ v := reflect.ValueOf(out).Elem()
+ structType := v.Type()
+ var ok bool
+ for i := 0; i < v.NumField(); i++ {
+ field := v.Field(i)
+ t := field.Type()
+ switch t.Kind() {
+ case reflect.Bool:
+ if len(packet) < 1 {
+ return ParseError{expectedType}
+ }
+ field.SetBool(packet[0] != 0)
+ packet = packet[1:]
+ case reflect.Array:
+ if t.Elem().Kind() != reflect.Uint8 {
+ panic("array of non-uint8")
+ }
+ if len(packet) < t.Len() {
+ return ParseError{expectedType}
+ }
+ for j := 0; j < t.Len(); j++ {
+ field.Index(j).Set(reflect.ValueOf(packet[j]))
+ }
+ packet = packet[t.Len():]
+ case reflect.Uint32:
+ var u32 uint32
+ if u32, packet, ok = parseUint32(packet); !ok {
+ return ParseError{expectedType}
+ }
+ field.SetUint(uint64(u32))
+ case reflect.String:
+ var s []byte
+ if s, packet, ok = parseString(packet); !ok {
+ return ParseError{expectedType}
+ }
+ field.SetString(string(s))
+ case reflect.Slice:
+ switch t.Elem().Kind() {
+ case reflect.Uint8:
+ if structType.Field(i).Tag == "rest" {
+ field.Set(reflect.ValueOf(packet))
+ packet = nil
+ } else {
+ var s []byte
+ if s, packet, ok = parseString(packet); !ok {
+ return ParseError{expectedType}
+ }
+ field.Set(reflect.ValueOf(s))
+ }
+ case reflect.String:
+ var nl []string
+ if nl, packet, ok = parseNameList(packet); !ok {
+ return ParseError{expectedType}
+ }
+ field.Set(reflect.ValueOf(nl))
+ default:
+ panic("slice of unknown type")
+ }
+ case reflect.Ptr:
+ if t == bigIntType {
+ var n *big.Int
+ if n, packet, ok = parseInt(packet); !ok {
+ return ParseError{expectedType}
+ }
+ field.Set(reflect.ValueOf(n))
+ } else {
+ panic("pointer to unknown type")
+ }
+ default:
+ panic("unknown type")
+ }
+ }
+
+ if len(packet) != 0 {
+ return ParseError{expectedType}
+ }
+
+ return nil
+}
+
+// marshal serializes the message in msg, using the given message type.
+func marshal(msgType uint8, msg interface{}) []byte {
+ var out []byte
+ out = append(out, msgType)
+
+ v := reflect.ValueOf(msg)
+ structType := v.Type()
+ for i := 0; i < v.NumField(); i++ {
+ field := v.Field(i)
+ t := field.Type()
+ switch t.Kind() {
+ case reflect.Bool:
+ var v uint8
+ if field.Bool() {
+ v = 1
+ }
+ out = append(out, v)
+ case reflect.Array:
+ if t.Elem().Kind() != reflect.Uint8 {
+ panic("array of non-uint8")
+ }
+ for j := 0; j < t.Len(); j++ {
+ out = append(out, byte(field.Index(j).Uint()))
+ }
+ case reflect.Uint32:
+ u32 := uint32(field.Uint())
+ out = append(out, byte(u32>>24))
+ out = append(out, byte(u32>>16))
+ out = append(out, byte(u32>>8))
+ out = append(out, byte(u32))
+ case reflect.String:
+ s := field.String()
+ out = append(out, byte(len(s)>>24))
+ out = append(out, byte(len(s)>>16))
+ out = append(out, byte(len(s)>>8))
+ out = append(out, byte(len(s)))
+ out = append(out, []byte(s)...)
+ case reflect.Slice:
+ switch t.Elem().Kind() {
+ case reflect.Uint8:
+ length := field.Len()
+ if structType.Field(i).Tag != "rest" {
+ out = append(out, byte(length>>24))
+ out = append(out, byte(length>>16))
+ out = append(out, byte(length>>8))
+ out = append(out, byte(length))
+ }
+ for j := 0; j < length; j++ {
+ out = append(out, byte(field.Index(j).Uint()))
+ }
+ case reflect.String:
+ var length int
+ for j := 0; j < field.Len(); j++ {
+ if j != 0 {
+ length++ /* comma */
+ }
+ length += len(field.Index(j).String())
+ }
+
+ out = append(out, byte(length>>24))
+ out = append(out, byte(length>>16))
+ out = append(out, byte(length>>8))
+ out = append(out, byte(length))
+ for j := 0; j < field.Len(); j++ {
+ if j != 0 {
+ out = append(out, ',')
+ }
+ out = append(out, []byte(field.Index(j).String())...)
+ }
+ default:
+ panic("slice of unknown type")
+ }
+ case reflect.Ptr:
+ if t == bigIntType {
+ var n *big.Int
+ nValue := reflect.ValueOf(&n)
+ nValue.Elem().Set(field)
+ needed := intLength(n)
+ oldLength := len(out)
+
+ if cap(out)-len(out) < needed {
+ newOut := make([]byte, len(out), 2*(len(out)+needed))
+ copy(newOut, out)
+ out = newOut
+ }
+ out = out[:oldLength+needed]
+ marshalInt(out[oldLength:], n)
+ } else {
+ panic("pointer to unknown type")
+ }
+ }
+ }
+
+ return out
+}
+
+var bigOne = big.NewInt(1)
+
+func parseString(in []byte) (out, rest []byte, ok bool) {
+ if len(in) < 4 {
+ return
+ }
+ length := uint32(in[0])<<24 | uint32(in[1])<<16 | uint32(in[2])<<8 | uint32(in[3])
+ if uint32(len(in)) < 4+length {
+ return
+ }
+ out = in[4 : 4+length]
+ rest = in[4+length:]
+ ok = true
+ return
+}
+
+var comma = []byte{','}
+
+func parseNameList(in []byte) (out []string, rest []byte, ok bool) {
+ contents, rest, ok := parseString(in)
+ if !ok {
+ return
+ }
+ if len(contents) == 0 {
+ return
+ }
+ parts := bytes.Split(contents, comma)
+ out = make([]string, len(parts))
+ for i, part := range parts {
+ out[i] = string(part)
+ }
+ return
+}
+
+func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) {
+ contents, rest, ok := parseString(in)
+ if !ok {
+ return
+ }
+ out = new(big.Int)
+
+ if len(contents) > 0 && contents[0]&0x80 == 0x80 {
+ // This is a negative number
+ notBytes := make([]byte, len(contents))
+ for i := range notBytes {
+ notBytes[i] = ^contents[i]
+ }
+ out.SetBytes(notBytes)
+ out.Add(out, bigOne)
+ out.Neg(out)
+ } else {
+ // Positive number
+ out.SetBytes(contents)
+ }
+ ok = true
+ return
+}
+
+func parseUint32(in []byte) (out uint32, rest []byte, ok bool) {
+ if len(in) < 4 {
+ return
+ }
+ out = uint32(in[0])<<24 | uint32(in[1])<<16 | uint32(in[2])<<8 | uint32(in[3])
+ rest = in[4:]
+ ok = true
+ return
+}
+
+const maxPacketSize = 36000
+
+func nameListLength(namelist []string) int {
+ length := 4 /* uint32 length prefix */
+ for i, name := range namelist {
+ if i != 0 {
+ length++ /* comma */
+ }
+ length += len(name)
+ }
+ return length
+}
+
+func intLength(n *big.Int) int {
+ length := 4 /* length bytes */
+ if n.Sign() < 0 {
+ nMinus1 := new(big.Int).Neg(n)
+ nMinus1.Sub(nMinus1, bigOne)
+ bitLen := nMinus1.BitLen()
+ if bitLen%8 == 0 {
+ // The number will need 0xff padding
+ length++
+ }
+ length += (bitLen + 7) / 8
+ } else if n.Sign() == 0 {
+ // A zero is the zero length string
+ } else {
+ bitLen := n.BitLen()
+ if bitLen%8 == 0 {
+ // The number will need 0x00 padding
+ length++
+ }
+ length += (bitLen + 7) / 8
+ }
+
+ return length
+}
+
+func marshalInt(to []byte, n *big.Int) []byte {
+ lengthBytes := to
+ to = to[4:]
+ length := 0
+
+ if n.Sign() < 0 {
+ // A negative number has to be converted to two's-complement
+ // form. So we'll subtract 1 and invert. If the
+ // most-significant-bit isn't set then we'll need to pad the
+ // beginning with 0xff in order to keep the number negative.
+ nMinus1 := new(big.Int).Neg(n)
+ nMinus1.Sub(nMinus1, bigOne)
+ bytes := nMinus1.Bytes()
+ for i := range bytes {
+ bytes[i] ^= 0xff
+ }
+ if len(bytes) == 0 || bytes[0]&0x80 == 0 {
+ to[0] = 0xff
+ to = to[1:]
+ length++
+ }
+ nBytes := copy(to, bytes)
+ to = to[nBytes:]
+ length += nBytes
+ } else if n.Sign() == 0 {
+ // A zero is the zero length string
+ } else {
+ bytes := n.Bytes()
+ if len(bytes) > 0 && bytes[0]&0x80 != 0 {
+ // We'll have to pad this with a 0x00 in order to
+ // stop it looking like a negative number.
+ to[0] = 0
+ to = to[1:]
+ length++
+ }
+ nBytes := copy(to, bytes)
+ to = to[nBytes:]
+ length += nBytes
+ }
+
+ lengthBytes[0] = byte(length >> 24)
+ lengthBytes[1] = byte(length >> 16)
+ lengthBytes[2] = byte(length >> 8)
+ lengthBytes[3] = byte(length)
+ return to
+}
+
+func writeInt(w io.Writer, n *big.Int) {
+ length := intLength(n)
+ buf := make([]byte, length)
+ marshalInt(buf, n)
+ w.Write(buf)
+}
+
+func writeString(w io.Writer, s []byte) {
+ var lengthBytes [4]byte
+ lengthBytes[0] = byte(len(s) >> 24)
+ lengthBytes[1] = byte(len(s) >> 16)
+ lengthBytes[2] = byte(len(s) >> 8)
+ lengthBytes[3] = byte(len(s))
+ w.Write(lengthBytes[:])
+ w.Write(s)
+}
+
+func stringLength(s []byte) int {
+ return 4 + len(s)
+}
+
+func marshalString(to []byte, s []byte) []byte {
+ to[0] = byte(len(s) >> 24)
+ to[1] = byte(len(s) >> 16)
+ to[2] = byte(len(s) >> 8)
+ to[3] = byte(len(s))
+ to = to[4:]
+ copy(to, s)
+ return to[len(s):]
+}
+
+var bigIntType = reflect.TypeOf((*big.Int)(nil))
diff --git a/src/pkg/exp/ssh/messages_test.go b/src/pkg/exp/ssh/messages_test.go
new file mode 100644
index 0000000000..629f3d3b14
--- /dev/null
+++ b/src/pkg/exp/ssh/messages_test.go
@@ -0,0 +1,125 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "big"
+ "rand"
+ "reflect"
+ "testing"
+ "testing/quick"
+)
+
+var intLengthTests = []struct {
+ val, length int
+}{
+ {0, 4 + 0},
+ {1, 4 + 1},
+ {127, 4 + 1},
+ {128, 4 + 2},
+ {-1, 4 + 1},
+}
+
+func TestIntLength(t *testing.T) {
+ for _, test := range intLengthTests {
+ v := new(big.Int).SetInt64(int64(test.val))
+ length := intLength(v)
+ if length != test.length {
+ t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length)
+ }
+ }
+}
+
+var messageTypes = []interface{}{
+ &kexInitMsg{},
+ &kexDHInitMsg{},
+ &serviceRequestMsg{},
+ &serviceAcceptMsg{},
+ &userAuthRequestMsg{},
+ &channelOpenMsg{},
+ &channelOpenConfirmMsg{},
+ &channelRequestMsg{},
+ &channelRequestSuccessMsg{},
+}
+
+func TestMarshalUnmarshal(t *testing.T) {
+ rand := rand.New(rand.NewSource(0))
+ for i, iface := range messageTypes {
+ ty := reflect.ValueOf(iface).Type()
+
+ n := 100
+ if testing.Short() {
+ n = 5
+ }
+ for j := 0; j < n; j++ {
+ v, ok := quick.Value(ty, rand)
+ if !ok {
+ t.Errorf("#%d: failed to create value", i)
+ break
+ }
+
+ m1 := v.Elem().Interface()
+ m2 := iface
+
+ marshaled := marshal(msgIgnore, m1)
+ if err := unmarshal(m2, marshaled, msgIgnore); err != nil {
+ t.Errorf("#%d failed to unmarshal %#v: %s", i, m1, err)
+ break
+ }
+
+ if !reflect.DeepEqual(v.Interface(), m2) {
+ t.Errorf("#%d\ngot: %#v\nwant:%#v\n%x", i, m2, m1, marshaled)
+ break
+ }
+ }
+ }
+}
+
+func randomBytes(out []byte, rand *rand.Rand) {
+ for i := 0; i < len(out); i++ {
+ out[i] = byte(rand.Int31())
+ }
+}
+
+func randomNameList(rand *rand.Rand) []string {
+ ret := make([]string, rand.Int31()&15)
+ for i := range ret {
+ s := make([]byte, 1+(rand.Int31()&15))
+ for j := range s {
+ s[j] = 'a' + uint8(rand.Int31()&15)
+ }
+ ret[i] = string(s)
+ }
+ return ret
+}
+
+func randomInt(rand *rand.Rand) *big.Int {
+ return new(big.Int).SetInt64(int64(int32(rand.Uint32())))
+}
+
+func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
+ ki := &kexInitMsg{}
+ randomBytes(ki.Cookie[:], rand)
+ ki.KexAlgos = randomNameList(rand)
+ ki.ServerHostKeyAlgos = randomNameList(rand)
+ ki.CiphersClientServer = randomNameList(rand)
+ ki.CiphersServerClient = randomNameList(rand)
+ ki.MACsClientServer = randomNameList(rand)
+ ki.MACsServerClient = randomNameList(rand)
+ ki.CompressionClientServer = randomNameList(rand)
+ ki.CompressionServerClient = randomNameList(rand)
+ ki.LanguagesClientServer = randomNameList(rand)
+ ki.LanguagesServerClient = randomNameList(rand)
+ if rand.Int31()&1 == 1 {
+ ki.FirstKexFollows = true
+ }
+ return reflect.ValueOf(ki)
+}
+
+func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
+ dhi := &kexDHInitMsg{}
+ dhi.X = randomInt(rand)
+ return reflect.ValueOf(dhi)
+}
diff --git a/src/pkg/exp/ssh/server.go b/src/pkg/exp/ssh/server.go
new file mode 100644
index 0000000000..57cd597106
--- /dev/null
+++ b/src/pkg/exp/ssh/server.go
@@ -0,0 +1,711 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "big"
+ "bufio"
+ "bytes"
+ "crypto"
+ "crypto/rand"
+ "crypto/rsa"
+ _ "crypto/sha1"
+ "crypto/x509"
+ "encoding/pem"
+ "net"
+ "os"
+ "sync"
+)
+
+var supportedKexAlgos = []string{kexAlgoDH14SHA1}
+var supportedHostKeyAlgos = []string{hostAlgoRSA}
+var supportedCiphers = []string{cipherAES128CTR}
+var supportedMACs = []string{macSHA196}
+var supportedCompressions = []string{compressionNone}
+
+// Server represents an SSH server. A Server may have several ServerConnections.
+type Server struct {
+ rsa *rsa.PrivateKey
+ rsaSerialized []byte
+
+ // NoClientAuth is true if clients are allowed to connect without
+ // authenticating.
+ NoClientAuth bool
+
+ // PasswordCallback, if non-nil, is called when a user attempts to
+ // authenticate using a password. It may be called concurrently from
+ // several goroutines.
+ PasswordCallback func(user, password string) bool
+
+ // PubKeyCallback, if non-nil, is called when a client attempts public
+ // key authentication. It must return true iff the given public key is
+ // valid for the given user.
+ PubKeyCallback func(user, algo string, pubkey []byte) bool
+}
+
+// SetRSAPrivateKey sets the private key for a Server. A Server must have a
+// private key configured in order to accept connections. The private key must
+// be in the form of a PEM encoded, PKCS#1, RSA private key. The file "id_rsa"
+// typically contains such a key.
+func (s *Server) SetRSAPrivateKey(pemBytes []byte) os.Error {
+ block, _ := pem.Decode(pemBytes)
+ if block == nil {
+ return os.NewError("ssh: no key found")
+ }
+ var err os.Error
+ s.rsa, err = x509.ParsePKCS1PrivateKey(block.Bytes)
+ if err != nil {
+ return err
+ }
+
+ s.rsaSerialized = marshalRSA(s.rsa)
+ return nil
+}
+
+// marshalRSA serializes an RSA private key according to RFC 4256, section 6.6.
+func marshalRSA(priv *rsa.PrivateKey) []byte {
+ e := new(big.Int).SetInt64(int64(priv.E))
+ length := stringLength([]byte(hostAlgoRSA))
+ length += intLength(e)
+ length += intLength(priv.N)
+
+ ret := make([]byte, length)
+ r := marshalString(ret, []byte(hostAlgoRSA))
+ r = marshalInt(r, e)
+ r = marshalInt(r, priv.N)
+
+ return ret
+}
+
+// parseRSA parses an RSA key according to RFC 4256, section 6.6.
+func parseRSA(in []byte) (pubKey *rsa.PublicKey, ok bool) {
+ algo, in, ok := parseString(in)
+ if !ok || string(algo) != hostAlgoRSA {
+ return nil, false
+ }
+ bigE, in, ok := parseInt(in)
+ if !ok || bigE.BitLen() > 24 {
+ return nil, false
+ }
+ e := bigE.Int64()
+ if e < 3 || e&1 == 0 {
+ return nil, false
+ }
+ N, in, ok := parseInt(in)
+ if !ok || len(in) > 0 {
+ return nil, false
+ }
+ return &rsa.PublicKey{
+ N: N,
+ E: int(e),
+ }, true
+}
+
+func parseRSASig(in []byte) (sig []byte, ok bool) {
+ algo, in, ok := parseString(in)
+ if !ok || string(algo) != hostAlgoRSA {
+ return nil, false
+ }
+ sig, in, ok = parseString(in)
+ if len(in) > 0 {
+ ok = false
+ }
+ return
+}
+
+// cachedPubKey contains the results of querying whether a public key is
+// acceptable for a user. The cache only applies to a single ServerConnection.
+type cachedPubKey struct {
+ user, algo string
+ pubKey []byte
+ result bool
+}
+
+const maxCachedPubKeys = 16
+
+// ServerConnection represents an incomming connection to a Server.
+type ServerConnection struct {
+ Server *Server
+
+ in, out *halfConnection
+
+ channels map[uint32]*channel
+ nextChanId uint32
+
+ // lock protects err and also allows Channels to serialise their writes
+ // to out.
+ lock sync.RWMutex
+ err os.Error
+
+ // cachedPubKeys contains the cache results of tests for public keys.
+ // Since SSH clients will query whether a public key is acceptable
+ // before attempting to authenticate with it, we end up with duplicate
+ // queries for public key validity.
+ cachedPubKeys []cachedPubKey
+}
+
+// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
+type dhGroup struct {
+ g, p *big.Int
+}
+
+// dhGroup14 is the group called diffie-hellman-group14-sha1 in RFC 4253 and
+// Oakley Group 14 in RFC 3526.
+var dhGroup14 *dhGroup
+
+var dhGroup14Once sync.Once
+
+func initDHGroup14() {
+ p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
+
+ dhGroup14 = &dhGroup{
+ g: new(big.Int).SetInt64(2),
+ p: p,
+ }
+}
+
+type handshakeMagics struct {
+ clientVersion, serverVersion []byte
+ clientKexInit, serverKexInit []byte
+}
+
+// kexDH performs Diffie-Hellman key agreement on a ServerConnection. The
+// returned values are given the same names as in RFC 4253, section 8.
+func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err os.Error) {
+ packet, err := s.in.readPacket()
+ if err != nil {
+ return
+ }
+ var kexDHInit kexDHInitMsg
+ if err = unmarshal(&kexDHInit, packet, msgKexDHInit); err != nil {
+ return
+ }
+
+ if kexDHInit.X.Sign() == 0 || kexDHInit.X.Cmp(group.p) >= 0 {
+ return nil, nil, os.NewError("client DH parameter out of bounds")
+ }
+
+ y, err := rand.Int(rand.Reader, group.p)
+ if err != nil {
+ return
+ }
+
+ Y := new(big.Int).Exp(group.g, y, group.p)
+ kInt := new(big.Int).Exp(kexDHInit.X, y, group.p)
+
+ var serializedHostKey []byte
+ switch hostKeyAlgo {
+ case hostAlgoRSA:
+ serializedHostKey = s.Server.rsaSerialized
+ default:
+ return nil, nil, os.NewError("internal error")
+ }
+
+ h := hashFunc.New()
+ writeString(h, magics.clientVersion)
+ writeString(h, magics.serverVersion)
+ writeString(h, magics.clientKexInit)
+ writeString(h, magics.serverKexInit)
+ writeString(h, serializedHostKey)
+ writeInt(h, kexDHInit.X)
+ writeInt(h, Y)
+ K = make([]byte, intLength(kInt))
+ marshalInt(K, kInt)
+ h.Write(K)
+
+ H = h.Sum()
+
+ h.Reset()
+ h.Write(H)
+ hh := h.Sum()
+
+ var sig []byte
+ switch hostKeyAlgo {
+ case hostAlgoRSA:
+ sig, err = rsa.SignPKCS1v15(rand.Reader, s.Server.rsa, hashFunc, hh)
+ if err != nil {
+ return
+ }
+ default:
+ return nil, nil, os.NewError("internal error")
+ }
+
+ serializedSig := serializeRSASignature(sig)
+
+ kexDHReply := kexDHReplyMsg{
+ HostKey: serializedHostKey,
+ Y: Y,
+ Signature: serializedSig,
+ }
+ packet = marshal(msgKexDHReply, kexDHReply)
+
+ err = s.out.writePacket(packet)
+ return
+}
+
+func serializeRSASignature(sig []byte) []byte {
+ length := stringLength([]byte(hostAlgoRSA))
+ length += stringLength(sig)
+
+ ret := make([]byte, length)
+ r := marshalString(ret, []byte(hostAlgoRSA))
+ r = marshalString(r, sig)
+
+ return ret
+}
+
+// serverVersion is the fixed identification string that Server will use.
+var serverVersion = []byte("SSH-2.0-Go\r\n")
+
+// buildDataSignedForAuth returns the data that is signed in order to prove
+// posession of a private key. See RFC 4252, section 7.
+func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
+ user := []byte(req.User)
+ service := []byte(req.Service)
+ method := []byte(req.Method)
+
+ length := stringLength(sessionId)
+ length += 1
+ length += stringLength(user)
+ length += stringLength(service)
+ length += stringLength(method)
+ length += 1
+ length += stringLength(algo)
+ length += stringLength(pubKey)
+
+ ret := make([]byte, length)
+ r := marshalString(ret, sessionId)
+ r[0] = msgUserAuthRequest
+ r = r[1:]
+ r = marshalString(r, user)
+ r = marshalString(r, service)
+ r = marshalString(r, method)
+ r[0] = 1
+ r = r[1:]
+ r = marshalString(r, algo)
+ r = marshalString(r, pubKey)
+ return ret
+}
+
+// Handshake performs an SSH transport and client authentication on the given ServerConnection.
+func (s *ServerConnection) Handshake(conn net.Conn) os.Error {
+ var magics handshakeMagics
+ inBuf := bufio.NewReader(conn)
+
+ _, err := conn.Write(serverVersion)
+ if err != nil {
+ return err
+ }
+
+ magics.serverVersion = serverVersion[:len(serverVersion)-2]
+ serverKexInit := kexInitMsg{
+ KexAlgos: supportedKexAlgos,
+ ServerHostKeyAlgos: supportedHostKeyAlgos,
+ CiphersClientServer: supportedCiphers,
+ CiphersServerClient: supportedCiphers,
+ MACsClientServer: supportedMACs,
+ MACsServerClient: supportedMACs,
+ CompressionClientServer: supportedCompressions,
+ CompressionServerClient: supportedCompressions,
+ }
+ kexInitPacket := marshal(msgKexInit, serverKexInit)
+ magics.serverKexInit = kexInitPacket
+
+ var out halfConnection
+ out.out = conn
+ out.rand = rand.Reader
+ s.out = &out
+ err = out.writePacket(kexInitPacket)
+ if err != nil {
+ return err
+ }
+
+ version, ok := readVersion(inBuf)
+ if !ok {
+ return os.NewError("failed to read version string from client")
+ }
+ magics.clientVersion = version
+
+ var in halfConnection
+ in.in = inBuf
+ s.in = &in
+ packet, err := in.readPacket()
+ if err != nil {
+ return err
+ }
+ magics.clientKexInit = packet
+
+ var clientKexInit kexInitMsg
+ if err = unmarshal(&clientKexInit, packet, msgKexInit); err != nil {
+ return err
+ }
+
+ kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(&in, &out, &clientKexInit, &serverKexInit)
+ if !ok {
+ return os.NewError("ssh: no common algorithms")
+ }
+
+ if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] {
+ // The client sent a Kex message for the wrong algorithm,
+ // which we have to ignore.
+ _, err := in.readPacket()
+ if err != nil {
+ return err
+ }
+ }
+
+ var H, K []byte
+ var hashFunc crypto.Hash
+ switch kexAlgo {
+ case kexAlgoDH14SHA1:
+ hashFunc = crypto.SHA1
+ dhGroup14Once.Do(initDHGroup14)
+ H, K, err = s.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo)
+ default:
+ err = os.NewError("ssh: internal error")
+ }
+
+ if err != nil {
+ return err
+ }
+
+ packet = []byte{msgNewKeys}
+ if err = out.writePacket(packet); err != nil {
+ return err
+ }
+ if err = out.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
+ return err
+ }
+
+ if packet, err = in.readPacket(); err != nil {
+ return err
+ }
+ if packet[0] != msgNewKeys {
+ return UnexpectedMessageError{msgNewKeys, packet[0]}
+ }
+
+ in.setupKeys(clientKeys, K, H, H, hashFunc)
+
+ packet, err = in.readPacket()
+ if err != nil {
+ return err
+ }
+
+ var serviceRequest serviceRequestMsg
+ if err = unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil {
+ return err
+ }
+ if serviceRequest.Service != serviceUserAuth {
+ return os.NewError("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
+ }
+
+ serviceAccept := serviceAcceptMsg{
+ Service: serviceUserAuth,
+ }
+ packet = marshal(msgServiceAccept, serviceAccept)
+ if err = out.writePacket(packet); err != nil {
+ return err
+ }
+
+ if err = s.authenticate(H); err != nil {
+ return err
+ }
+
+ s.channels = make(map[uint32]*channel)
+ return nil
+}
+
+func isAcceptableAlgo(algo string) bool {
+ return algo == hostAlgoRSA
+}
+
+// testPubKey returns true if the given public key is acceptable for the user.
+func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool {
+ if s.Server.PubKeyCallback == nil || !isAcceptableAlgo(algo) {
+ return false
+ }
+
+ for _, c := range s.cachedPubKeys {
+ if c.user == user && c.algo == algo && bytes.Equal(c.pubKey, pubKey) {
+ return c.result
+ }
+ }
+
+ result := s.Server.PubKeyCallback(user, algo, pubKey)
+ if len(s.cachedPubKeys) < maxCachedPubKeys {
+ c := cachedPubKey{
+ user: user,
+ algo: algo,
+ pubKey: make([]byte, len(pubKey)),
+ result: result,
+ }
+ copy(c.pubKey, pubKey)
+ s.cachedPubKeys = append(s.cachedPubKeys, c)
+ }
+
+ return result
+}
+
+func (s *ServerConnection) authenticate(H []byte) os.Error {
+ var userAuthReq userAuthRequestMsg
+ var err os.Error
+ var packet []byte
+
+userAuthLoop:
+ for {
+ if packet, err = s.in.readPacket(); err != nil {
+ return err
+ }
+ if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); err != nil {
+ return err
+ }
+
+ if userAuthReq.Service != serviceSSH {
+ return os.NewError("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
+ }
+
+ switch userAuthReq.Method {
+ case "none":
+ if s.Server.NoClientAuth {
+ break userAuthLoop
+ }
+ case "password":
+ if s.Server.PasswordCallback == nil {
+ break
+ }
+ payload := userAuthReq.Payload
+ if len(payload) < 1 || payload[0] != 0 {
+ return ParseError{msgUserAuthRequest}
+ }
+ payload = payload[1:]
+ password, payload, ok := parseString(payload)
+ if !ok || len(payload) > 0 {
+ return ParseError{msgUserAuthRequest}
+ }
+
+ if s.Server.PasswordCallback(userAuthReq.User, string(password)) {
+ break userAuthLoop
+ }
+ case "publickey":
+ if s.Server.PubKeyCallback == nil {
+ break
+ }
+ payload := userAuthReq.Payload
+ if len(payload) < 1 {
+ return ParseError{msgUserAuthRequest}
+ }
+ isQuery := payload[0] == 0
+ payload = payload[1:]
+ algoBytes, payload, ok := parseString(payload)
+ if !ok {
+ return ParseError{msgUserAuthRequest}
+ }
+ algo := string(algoBytes)
+
+ pubKey, payload, ok := parseString(payload)
+ if !ok {
+ return ParseError{msgUserAuthRequest}
+ }
+ if isQuery {
+ // The client can query if the given public key
+ // would be ok.
+ if len(payload) > 0 {
+ return ParseError{msgUserAuthRequest}
+ }
+ if s.testPubKey(userAuthReq.User, algo, pubKey) {
+ okMsg := userAuthPubKeyOkMsg{
+ Algo: algo,
+ PubKey: string(pubKey),
+ }
+ if err = s.out.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil {
+ return err
+ }
+ continue userAuthLoop
+ }
+ } else {
+ sig, payload, ok := parseString(payload)
+ if !ok || len(payload) > 0 {
+ return ParseError{msgUserAuthRequest}
+ }
+ if !isAcceptableAlgo(algo) {
+ break
+ }
+ rsaSig, ok := parseRSASig(sig)
+ if !ok {
+ return ParseError{msgUserAuthRequest}
+ }
+ signedData := buildDataSignedForAuth(H, userAuthReq, algoBytes, pubKey)
+ switch algo {
+ case hostAlgoRSA:
+ hashFunc := crypto.SHA1
+ h := hashFunc.New()
+ h.Write(signedData)
+ digest := h.Sum()
+ rsaKey, ok := parseRSA(pubKey)
+ if !ok {
+ return ParseError{msgUserAuthRequest}
+ }
+ if rsa.VerifyPKCS1v15(rsaKey, hashFunc, digest, rsaSig) != nil {
+ return ParseError{msgUserAuthRequest}
+ }
+ default:
+ return os.NewError("ssh: isAcceptableAlgo incorrect")
+ }
+ if s.testPubKey(userAuthReq.User, algo, pubKey) {
+ break userAuthLoop
+ }
+ }
+ }
+
+ var failureMsg userAuthFailureMsg
+ if s.Server.PasswordCallback != nil {
+ failureMsg.Methods = append(failureMsg.Methods, "password")
+ }
+ if s.Server.PubKeyCallback != nil {
+ failureMsg.Methods = append(failureMsg.Methods, "publickey")
+ }
+
+ if len(failureMsg.Methods) == 0 {
+ return os.NewError("ssh: no authentication methods configured but NoClientAuth is also false")
+ }
+
+ if err = s.out.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil {
+ return err
+ }
+ }
+
+ packet = []byte{msgUserAuthSuccess}
+ if err = s.out.writePacket(packet); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+const defaultWindowSize = 32768
+
+// Accept reads and processes messages on a ServerConnection. It must be called
+// in order to demultiplex messages to any resulting Channels.
+func (s *ServerConnection) Accept() (Channel, os.Error) {
+ if s.err != nil {
+ return nil, s.err
+ }
+
+ for {
+ packet, err := s.in.readPacket()
+ if err != nil {
+
+ s.lock.Lock()
+ s.err = err
+ s.lock.Unlock()
+
+ for _, c := range s.channels {
+ c.dead = true
+ c.handleData(nil)
+ }
+
+ return nil, err
+ }
+
+ switch packet[0] {
+ case msgChannelOpen:
+ var chanOpen channelOpenMsg
+ if err := unmarshal(&chanOpen, packet, msgChannelOpen); err != nil {
+ return nil, err
+ }
+
+ c := new(channel)
+ c.chanType = chanOpen.ChanType
+ c.theirId = chanOpen.PeersId
+ c.theirWindow = chanOpen.PeersWindow
+ c.maxPacketSize = chanOpen.MaxPacketSize
+ c.extraData = chanOpen.TypeSpecificData
+ c.myWindow = defaultWindowSize
+ c.serverConn = s
+ c.cond = sync.NewCond(&c.lock)
+ c.pendingData = make([]byte, c.myWindow)
+
+ s.lock.Lock()
+ c.myId = s.nextChanId
+ s.nextChanId++
+ s.channels[c.myId] = c
+ s.lock.Unlock()
+ return c, nil
+
+ case msgChannelRequest:
+ var chanRequest channelRequestMsg
+ if err := unmarshal(&chanRequest, packet, msgChannelRequest); err != nil {
+ return nil, err
+ }
+
+ s.lock.Lock()
+ c, ok := s.channels[chanRequest.PeersId]
+ if !ok {
+ continue
+ }
+ c.handlePacket(&chanRequest)
+ s.lock.Unlock()
+
+ case msgChannelData:
+ if len(packet) < 5 {
+ return nil, ParseError{msgChannelData}
+ }
+ chanId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
+
+ s.lock.Lock()
+ c, ok := s.channels[chanId]
+ if !ok {
+ continue
+ }
+ c.handleData(packet[9:])
+ s.lock.Unlock()
+
+ case msgChannelEOF:
+ var eofMsg channelEOFMsg
+ if err := unmarshal(&eofMsg, packet, msgChannelEOF); err != nil {
+ return nil, err
+ }
+
+ s.lock.Lock()
+ c, ok := s.channels[eofMsg.PeersId]
+ if !ok {
+ continue
+ }
+ c.handlePacket(&eofMsg)
+ s.lock.Unlock()
+
+ case msgChannelClose:
+ var closeMsg channelCloseMsg
+ if err := unmarshal(&closeMsg, packet, msgChannelClose); err != nil {
+ return nil, err
+ }
+
+ s.lock.Lock()
+ c, ok := s.channels[closeMsg.PeersId]
+ if !ok {
+ continue
+ }
+ c.handlePacket(&closeMsg)
+ s.lock.Unlock()
+
+ case msgGlobalRequest:
+ var request globalRequestMsg
+ if err := unmarshal(&request, packet, msgGlobalRequest); err != nil {
+ return nil, err
+ }
+
+ if request.WantReply {
+ if err := s.out.writePacket([]byte{msgRequestFailure}); err != nil {
+ return nil, err
+ }
+ }
+
+ default:
+ // Unknown message. Ignore.
+ }
+ }
+
+ panic("unreachable")
+}
diff --git a/src/pkg/exp/ssh/server_shell.go b/src/pkg/exp/ssh/server_shell.go
new file mode 100644
index 0000000000..53a3241f5e
--- /dev/null
+++ b/src/pkg/exp/ssh/server_shell.go
@@ -0,0 +1,399 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "os"
+)
+
+// ServerShell contains the state for running a VT100 terminal that is capable
+// of reading lines of input.
+type ServerShell struct {
+ c Channel
+ prompt string
+
+ // line is the current line being entered.
+ line []byte
+ // pos is the logical position of the cursor in line
+ pos int
+
+ // cursorX contains the current X value of the cursor where the left
+ // edge is 0. cursorY contains the row number where the first row of
+ // the current line is 0.
+ cursorX, cursorY int
+ // maxLine is the greatest value of cursorY so far.
+ maxLine int
+
+ termWidth, termHeight int
+
+ // outBuf contains the terminal data to be sent.
+ outBuf []byte
+ // remainder contains the remainder of any partial key sequences after
+ // a read. It aliases into inBuf.
+ remainder []byte
+ inBuf [256]byte
+}
+
+// NewServerShell runs a VT100 terminal on the given channel. prompt is a
+// string that is written at the start of each input line. For example: "> ".
+func NewServerShell(c Channel, prompt string) *ServerShell {
+ return &ServerShell{
+ c: c,
+ prompt: prompt,
+ termWidth: 80,
+ termHeight: 24,
+ }
+}
+
+const (
+ keyCtrlD = 4
+ keyEnter = '\r'
+ keyEscape = 27
+ keyBackspace = 127
+ keyUnknown = 256 + iota
+ keyUp
+ keyDown
+ keyLeft
+ keyRight
+ keyAltLeft
+ keyAltRight
+)
+
+// bytesToKey tries to parse a key sequence from b. If successful, it returns
+// the key and the remainder of the input. Otherwise it returns -1.
+func bytesToKey(b []byte) (int, []byte) {
+ if len(b) == 0 {
+ return -1, nil
+ }
+
+ if b[0] != keyEscape {
+ return int(b[0]), b[1:]
+ }
+
+ if len(b) >= 3 && b[0] == keyEscape && b[1] == '[' {
+ switch b[2] {
+ case 'A':
+ return keyUp, b[3:]
+ case 'B':
+ return keyDown, b[3:]
+ case 'C':
+ return keyRight, b[3:]
+ case 'D':
+ return keyLeft, b[3:]
+ }
+ }
+
+ if len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' {
+ switch b[5] {
+ case 'C':
+ return keyAltRight, b[6:]
+ case 'D':
+ return keyAltLeft, b[6:]
+ }
+ }
+
+ // If we get here then we have a key that we don't recognise, or a
+ // partial sequence. It's not clear how one should find the end of a
+ // sequence without knowing them all, but it seems that [a-zA-Z] only
+ // appears at the end of a sequence.
+ for i, c := range b[0:] {
+ if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' {
+ return keyUnknown, b[i+1:]
+ }
+ }
+
+ return -1, b
+}
+
+// queue appends data to the end of ss.outBuf
+func (ss *ServerShell) queue(data []byte) {
+ if len(ss.outBuf)+len(data) > cap(ss.outBuf) {
+ newOutBuf := make([]byte, len(ss.outBuf), 2*(len(ss.outBuf)+len(data)))
+ copy(newOutBuf, ss.outBuf)
+ ss.outBuf = newOutBuf
+ }
+
+ oldLen := len(ss.outBuf)
+ ss.outBuf = ss.outBuf[:len(ss.outBuf)+len(data)]
+ copy(ss.outBuf[oldLen:], data)
+}
+
+var eraseUnderCursor = []byte{' ', keyEscape, '[', 'D'}
+
+func isPrintable(key int) bool {
+ return key >= 32 && key < 127
+}
+
+// moveCursorToPos appends data to ss.outBuf which will move the cursor to the
+// given, logical position in the text.
+func (ss *ServerShell) moveCursorToPos(pos int) {
+ x := len(ss.prompt) + pos
+ y := x / ss.termWidth
+ x = x % ss.termWidth
+
+ up := 0
+ if y < ss.cursorY {
+ up = ss.cursorY - y
+ }
+
+ down := 0
+ if y > ss.cursorY {
+ down = y - ss.cursorY
+ }
+
+ left := 0
+ if x < ss.cursorX {
+ left = ss.cursorX - x
+ }
+
+ right := 0
+ if x > ss.cursorX {
+ right = x - ss.cursorX
+ }
+
+ movement := make([]byte, 3*(up+down+left+right))
+ m := movement
+ for i := 0; i < up; i++ {
+ m[0] = keyEscape
+ m[1] = '['
+ m[2] = 'A'
+ m = m[3:]
+ }
+ for i := 0; i < down; i++ {
+ m[0] = keyEscape
+ m[1] = '['
+ m[2] = 'B'
+ m = m[3:]
+ }
+ for i := 0; i < left; i++ {
+ m[0] = keyEscape
+ m[1] = '['
+ m[2] = 'D'
+ m = m[3:]
+ }
+ for i := 0; i < right; i++ {
+ m[0] = keyEscape
+ m[1] = '['
+ m[2] = 'C'
+ m = m[3:]
+ }
+
+ ss.cursorX = x
+ ss.cursorY = y
+ ss.queue(movement)
+}
+
+const maxLineLength = 4096
+
+// handleKey processes the given key and, optionally, returns a line of text
+// that the user has entered.
+func (ss *ServerShell) handleKey(key int) (line string, ok bool) {
+ switch key {
+ case keyBackspace:
+ if ss.pos == 0 {
+ return
+ }
+ ss.pos--
+
+ copy(ss.line[ss.pos:], ss.line[1+ss.pos:])
+ ss.line = ss.line[:len(ss.line)-1]
+ ss.writeLine(ss.line[ss.pos:])
+ ss.moveCursorToPos(ss.pos)
+ ss.queue(eraseUnderCursor)
+ case keyAltLeft:
+ // move left by a word.
+ if ss.pos == 0 {
+ return
+ }
+ ss.pos--
+ for ss.pos > 0 {
+ if ss.line[ss.pos] != ' ' {
+ break
+ }
+ ss.pos--
+ }
+ for ss.pos > 0 {
+ if ss.line[ss.pos] == ' ' {
+ ss.pos++
+ break
+ }
+ ss.pos--
+ }
+ ss.moveCursorToPos(ss.pos)
+ case keyAltRight:
+ // move right by a word.
+ for ss.pos < len(ss.line) {
+ if ss.line[ss.pos] == ' ' {
+ break
+ }
+ ss.pos++
+ }
+ for ss.pos < len(ss.line) {
+ if ss.line[ss.pos] != ' ' {
+ break
+ }
+ ss.pos++
+ }
+ ss.moveCursorToPos(ss.pos)
+ case keyLeft:
+ if ss.pos == 0 {
+ return
+ }
+ ss.pos--
+ ss.moveCursorToPos(ss.pos)
+ case keyRight:
+ if ss.pos == len(ss.line) {
+ return
+ }
+ ss.pos++
+ ss.moveCursorToPos(ss.pos)
+ case keyEnter:
+ ss.moveCursorToPos(len(ss.line))
+ ss.queue([]byte("\r\n"))
+ line = string(ss.line)
+ ok = true
+ ss.line = ss.line[:0]
+ ss.pos = 0
+ ss.cursorX = 0
+ ss.cursorY = 0
+ ss.maxLine = 0
+ default:
+ if !isPrintable(key) {
+ return
+ }
+ if len(ss.line) == maxLineLength {
+ return
+ }
+ if len(ss.line) == cap(ss.line) {
+ newLine := make([]byte, len(ss.line), 2*(1+len(ss.line)))
+ copy(newLine, ss.line)
+ ss.line = newLine
+ }
+ ss.line = ss.line[:len(ss.line)+1]
+ copy(ss.line[ss.pos+1:], ss.line[ss.pos:])
+ ss.line[ss.pos] = byte(key)
+ ss.writeLine(ss.line[ss.pos:])
+ ss.pos++
+ ss.moveCursorToPos(ss.pos)
+ }
+ return
+}
+
+func (ss *ServerShell) writeLine(line []byte) {
+ for len(line) != 0 {
+ if ss.cursorX == ss.termWidth {
+ ss.queue([]byte("\r\n"))
+ ss.cursorX = 0
+ ss.cursorY++
+ if ss.cursorY > ss.maxLine {
+ ss.maxLine = ss.cursorY
+ }
+ }
+
+ remainingOnLine := ss.termWidth - ss.cursorX
+ todo := len(line)
+ if todo > remainingOnLine {
+ todo = remainingOnLine
+ }
+ ss.queue(line[:todo])
+ ss.cursorX += todo
+ line = line[todo:]
+ }
+}
+
+// parsePtyRequest parses the payload of the pty-req message and extracts the
+// dimensions of the terminal. See RFC 4254, section 6.2.
+func parsePtyRequest(s []byte) (width, height int, ok bool) {
+ _, s, ok = parseString(s)
+ if !ok {
+ return
+ }
+ width32, s, ok := parseUint32(s)
+ if !ok {
+ return
+ }
+ height32, _, ok := parseUint32(s)
+ width = int(width32)
+ height = int(height32)
+ if width < 1 {
+ ok = false
+ }
+ if height < 1 {
+ ok = false
+ }
+ return
+}
+
+func (ss *ServerShell) Write(buf []byte) (n int, err os.Error) {
+ return ss.c.Write(buf)
+}
+
+// ReadLine returns a line of input from the terminal.
+func (ss *ServerShell) ReadLine() (line string, err os.Error) {
+ ss.writeLine([]byte(ss.prompt))
+ ss.c.Write(ss.outBuf)
+ ss.outBuf = ss.outBuf[:0]
+
+ for {
+ // ss.remainder is a slice at the beginning of ss.inBuf
+ // containing a partial key sequence
+ readBuf := ss.inBuf[len(ss.remainder):]
+ n, err := ss.c.Read(readBuf)
+ if err == nil {
+ ss.remainder = ss.inBuf[:n+len(ss.remainder)]
+ rest := ss.remainder
+ lineOk := false
+ for !lineOk {
+ var key int
+ key, rest = bytesToKey(rest)
+ if key < 0 {
+ break
+ }
+ if key == keyCtrlD {
+ return "", os.EOF
+ }
+ line, lineOk = ss.handleKey(key)
+ }
+ if len(rest) > 0 {
+ n := copy(ss.inBuf[:], rest)
+ ss.remainder = ss.inBuf[:n]
+ } else {
+ ss.remainder = nil
+ }
+ ss.c.Write(ss.outBuf)
+ ss.outBuf = ss.outBuf[:0]
+ if lineOk {
+ return
+ }
+ continue
+ }
+
+ if req, ok := err.(ChannelRequest); ok {
+ ok := false
+ switch req.Request {
+ case "pty-req":
+ ss.termWidth, ss.termHeight, ok = parsePtyRequest(req.Payload)
+ if !ok {
+ ss.termWidth = 80
+ ss.termHeight = 24
+ }
+ case "shell":
+ ok = true
+ if len(req.Payload) > 0 {
+ // We don't accept any commands, only the default shell.
+ ok = false
+ }
+ case "env":
+ ok = true
+ }
+ if req.WantReply {
+ ss.c.AckRequest(ok)
+ }
+ } else {
+ return "", err
+ }
+ }
+ panic("unreachable")
+}
diff --git a/src/pkg/exp/ssh/server_shell_test.go b/src/pkg/exp/ssh/server_shell_test.go
new file mode 100644
index 0000000000..622cf7cfad
--- /dev/null
+++ b/src/pkg/exp/ssh/server_shell_test.go
@@ -0,0 +1,134 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "testing"
+ "os"
+)
+
+type MockChannel struct {
+ toSend []byte
+ bytesPerRead int
+ received []byte
+}
+
+func (c *MockChannel) Accept() os.Error {
+ return nil
+}
+
+func (c *MockChannel) Reject(RejectionReason, string) os.Error {
+ return nil
+}
+
+func (c *MockChannel) Read(data []byte) (n int, err os.Error) {
+ n = len(data)
+ if n == 0 {
+ return
+ }
+ if n > len(c.toSend) {
+ n = len(c.toSend)
+ }
+ if n == 0 {
+ return 0, os.EOF
+ }
+ if c.bytesPerRead > 0 && n > c.bytesPerRead {
+ n = c.bytesPerRead
+ }
+ copy(data, c.toSend[:n])
+ c.toSend = c.toSend[n:]
+ return
+}
+
+func (c *MockChannel) Write(data []byte) (n int, err os.Error) {
+ c.received = append(c.received, data...)
+ return len(data), nil
+}
+
+func (c *MockChannel) Close() os.Error {
+ return nil
+}
+
+func (c *MockChannel) AckRequest(ok bool) os.Error {
+ return nil
+}
+
+func (c *MockChannel) ChannelType() string {
+ return ""
+}
+
+func (c *MockChannel) ExtraData() []byte {
+ return nil
+}
+
+func TestClose(t *testing.T) {
+ c := &MockChannel{}
+ ss := NewServerShell(c, "> ")
+ line, err := ss.ReadLine()
+ if line != "" {
+ t.Errorf("Expected empty line but got: %s", line)
+ }
+ if err != os.EOF {
+ t.Errorf("Error should have been EOF but got: %s", err)
+ }
+}
+
+var keyPressTests = []struct {
+ in string
+ line string
+ err os.Error
+}{
+ {
+ "",
+ "",
+ os.EOF,
+ },
+ {
+ "\r",
+ "",
+ nil,
+ },
+ {
+ "foo\r",
+ "foo",
+ nil,
+ },
+ {
+ "a\x1b[Cb\r", // right
+ "ab",
+ nil,
+ },
+ {
+ "a\x1b[Db\r", // left
+ "ba",
+ nil,
+ },
+ {
+ "a\177b\r", // backspace
+ "b",
+ nil,
+ },
+}
+
+func TestKeyPresses(t *testing.T) {
+ for i, test := range keyPressTests {
+ for j := 0; j < len(test.in); j++ {
+ c := &MockChannel{
+ toSend: []byte(test.in),
+ bytesPerRead: j,
+ }
+ ss := NewServerShell(c, "> ")
+ line, err := ss.ReadLine()
+ if line != test.line {
+ t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line)
+ break
+ }
+ if err != test.err {
+ t.Errorf("Error resulting from test %d (%d bytes per read) was '%v', expected '%v'", i, j, err, test.err)
+ break
+ }
+ }
+ }
+}
diff --git a/src/pkg/exp/ssh/transport.go b/src/pkg/exp/ssh/transport.go
new file mode 100644
index 0000000000..919759ff98
--- /dev/null
+++ b/src/pkg/exp/ssh/transport.go
@@ -0,0 +1,308 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "bufio"
+ "crypto"
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/hmac"
+ "crypto/subtle"
+ "hash"
+ "io"
+ "net"
+ "os"
+)
+
+// halfConnection represents one direction of an SSH connection. It maintains
+// the cipher state needed to process messages.
+type halfConnection struct {
+ // Only one of these two will be non-nil
+ in *bufio.Reader
+ out net.Conn
+
+ rand io.Reader
+ cipherAlgo string
+ macAlgo string
+ compressionAlgo string
+ paddingMultiple int
+
+ seqNum uint32
+
+ mac hash.Hash
+ cipher cipher.Stream
+}
+
+func (hc *halfConnection) readOnePacket() (packet []byte, err os.Error) {
+ var lengthBytes [5]byte
+
+ _, err = io.ReadFull(hc.in, lengthBytes[:])
+ if err != nil {
+ return
+ }
+
+ if hc.cipher != nil {
+ hc.cipher.XORKeyStream(lengthBytes[:], lengthBytes[:])
+ }
+
+ macSize := 0
+ if hc.mac != nil {
+ hc.mac.Reset()
+ var seqNumBytes [4]byte
+ seqNumBytes[0] = byte(hc.seqNum >> 24)
+ seqNumBytes[1] = byte(hc.seqNum >> 16)
+ seqNumBytes[2] = byte(hc.seqNum >> 8)
+ seqNumBytes[3] = byte(hc.seqNum)
+ hc.mac.Write(seqNumBytes[:])
+ hc.mac.Write(lengthBytes[:])
+ macSize = hc.mac.Size()
+ }
+
+ length := uint32(lengthBytes[0])<<24 | uint32(lengthBytes[1])<<16 | uint32(lengthBytes[2])<<8 | uint32(lengthBytes[3])
+
+ paddingLength := uint32(lengthBytes[4])
+
+ if length <= paddingLength+1 {
+ return nil, os.NewError("invalid packet length")
+ }
+ if length > maxPacketSize {
+ return nil, os.NewError("packet too large")
+ }
+
+ packet = make([]byte, length-1+uint32(macSize))
+ _, err = io.ReadFull(hc.in, packet)
+ if err != nil {
+ return nil, err
+ }
+ mac := packet[length-1:]
+ if hc.cipher != nil {
+ hc.cipher.XORKeyStream(packet, packet[:length-1])
+ }
+
+ if hc.mac != nil {
+ hc.mac.Write(packet[:length-1])
+ if subtle.ConstantTimeCompare(hc.mac.Sum(), mac) != 1 {
+ return nil, os.NewError("ssh: MAC failure")
+ }
+ }
+
+ hc.seqNum++
+ packet = packet[:length-paddingLength-1]
+ return
+}
+
+func (hc *halfConnection) readPacket() (packet []byte, err os.Error) {
+ for {
+ packet, err := hc.readOnePacket()
+ if err != nil {
+ return nil, err
+ }
+ if packet[0] != msgIgnore && packet[0] != msgDebug {
+ return packet, nil
+ }
+ }
+ panic("unreachable")
+}
+
+func (hc *halfConnection) writePacket(packet []byte) os.Error {
+ paddingMultiple := hc.paddingMultiple
+ if paddingMultiple == 0 {
+ paddingMultiple = 8
+ }
+
+ paddingLength := paddingMultiple - (4+1+len(packet))%paddingMultiple
+ if paddingLength < 4 {
+ paddingLength += paddingMultiple
+ }
+
+ var lengthBytes [5]byte
+ length := len(packet) + 1 + paddingLength
+ lengthBytes[0] = byte(length >> 24)
+ lengthBytes[1] = byte(length >> 16)
+ lengthBytes[2] = byte(length >> 8)
+ lengthBytes[3] = byte(length)
+ lengthBytes[4] = byte(paddingLength)
+
+ var padding [32]byte
+ _, err := io.ReadFull(hc.rand, padding[:paddingLength])
+ if err != nil {
+ return err
+ }
+
+ if hc.mac != nil {
+ hc.mac.Reset()
+ var seqNumBytes [4]byte
+ seqNumBytes[0] = byte(hc.seqNum >> 24)
+ seqNumBytes[1] = byte(hc.seqNum >> 16)
+ seqNumBytes[2] = byte(hc.seqNum >> 8)
+ seqNumBytes[3] = byte(hc.seqNum)
+ hc.mac.Write(seqNumBytes[:])
+ hc.mac.Write(lengthBytes[:])
+ hc.mac.Write(packet)
+ hc.mac.Write(padding[:paddingLength])
+ }
+
+ if hc.cipher != nil {
+ hc.cipher.XORKeyStream(lengthBytes[:], lengthBytes[:])
+ hc.cipher.XORKeyStream(packet, packet)
+ hc.cipher.XORKeyStream(padding[:], padding[:paddingLength])
+ }
+
+ _, err = hc.out.Write(lengthBytes[:])
+ if err != nil {
+ return err
+ }
+ _, err = hc.out.Write(packet)
+ if err != nil {
+ return err
+ }
+ _, err = hc.out.Write(padding[:paddingLength])
+ if err != nil {
+ return err
+ }
+
+ if hc.mac != nil {
+ _, err = hc.out.Write(hc.mac.Sum())
+ }
+
+ hc.seqNum++
+
+ return err
+}
+
+const (
+ serverKeys = iota
+ clientKeys
+)
+
+// setupServerKeys sets the cipher and MAC keys from K, H and sessionId, as
+// described in RFC 4253, section 6.4. direction should either be serverKeys
+// (to setup server->client keys) or clientKeys (for client->server keys).
+func (hc *halfConnection) setupKeys(direction int, K, H, sessionId []byte, hashFunc crypto.Hash) os.Error {
+ h := hashFunc.New()
+
+ // We only support these algorithms for now.
+ if hc.cipherAlgo != cipherAES128CTR || hc.macAlgo != macSHA196 {
+ return os.NewError("ssh: setupServerKeys internal error")
+ }
+
+ blockSize := 16
+ keySize := 16
+ macKeySize := 20
+
+ var ivTag, keyTag, macKeyTag byte
+ if direction == serverKeys {
+ ivTag, keyTag, macKeyTag = 'B', 'D', 'F'
+ } else {
+ ivTag, keyTag, macKeyTag = 'A', 'C', 'E'
+ }
+
+ iv := make([]byte, blockSize)
+ key := make([]byte, keySize)
+ macKey := make([]byte, macKeySize)
+ generateKeyMaterial(iv, ivTag, K, H, sessionId, h)
+ generateKeyMaterial(key, keyTag, K, H, sessionId, h)
+ generateKeyMaterial(macKey, macKeyTag, K, H, sessionId, h)
+
+ hc.mac = truncatingMAC{12, hmac.NewSHA1(macKey)}
+ aes, err := aes.NewCipher(key)
+ if err != nil {
+ return err
+ }
+ hc.cipher = cipher.NewCTR(aes, iv)
+ hc.paddingMultiple = 16
+ return nil
+}
+
+// generateKeyMaterial fills out with key material generated from tag, K, H
+// and sessionId, as specified in RFC 4253, section 7.2.
+func generateKeyMaterial(out []byte, tag byte, K, H, sessionId []byte, h hash.Hash) {
+ var digestsSoFar []byte
+
+ for len(out) > 0 {
+ h.Reset()
+ h.Write(K)
+ h.Write(H)
+
+ if len(digestsSoFar) == 0 {
+ h.Write([]byte{tag})
+ h.Write(sessionId)
+ } else {
+ h.Write(digestsSoFar)
+ }
+
+ digest := h.Sum()
+ n := copy(out, digest)
+ out = out[n:]
+ if len(out) > 0 {
+ digestsSoFar = append(digestsSoFar, digest...)
+ }
+ }
+}
+
+// truncatingMAC wraps around a hash.Hash and truncates the output digest to
+// a given size.
+type truncatingMAC struct {
+ length int
+ hmac hash.Hash
+}
+
+func (t truncatingMAC) Write(data []byte) (int, os.Error) {
+ return t.hmac.Write(data)
+}
+
+func (t truncatingMAC) Sum() []byte {
+ digest := t.hmac.Sum()
+ return digest[:t.length]
+}
+
+func (t truncatingMAC) Reset() {
+ t.hmac.Reset()
+}
+
+func (t truncatingMAC) Size() int {
+ return t.length
+}
+
+// maxVersionStringBytes is the maximum number of bytes that we'll accept as a
+// version string. In the event that the client is talking a different protocol
+// we need to set a limit otherwise we will keep using more and more memory
+// while searching for the end of the version handshake.
+const maxVersionStringBytes = 1024
+
+func readVersion(r *bufio.Reader) (versionString []byte, ok bool) {
+ versionString = make([]byte, 0, 64)
+ seenCR := false
+
+forEachByte:
+ for len(versionString) < maxVersionStringBytes {
+ b, err := r.ReadByte()
+ if err != nil {
+ return
+ }
+
+ if !seenCR {
+ if b == '\r' {
+ seenCR = true
+ }
+ } else {
+ if b == '\n' {
+ ok = true
+ break forEachByte
+ } else {
+ seenCR = false
+ }
+ }
+ versionString = append(versionString, b)
+ }
+
+ if ok {
+ // We need to remove the CR from versionString
+ versionString = versionString[:len(versionString)-1]
+ }
+
+ return
+}