aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDave Cheney <dave@cheney.net>2011-12-01 08:30:16 -0200
committerGustavo Niemeyer <gustavo@niemeyer.net>2011-12-01 08:30:16 -0200
commitc4d0ac0e2f7a12cf44f4711b47bbc5737c14ce9c (patch)
tree89b7697de08b0a172e36c155861dcfd45a67554f
parentc0a53bbc4ac041e0f547c46bf244196eab3caef9 (diff)
downloadgo-c4d0ac0e2f7a12cf44f4711b47bbc5737c14ce9c.tar.gz
go-c4d0ac0e2f7a12cf44f4711b47bbc5737c14ce9c.zip
exp/ssh: add Std{in,out,err}Pipe methods to Session
R=gustav.paul, cw, agl, rsc, n13m3y3r CC=golang-dev https://golang.org/cl/5433080
-rw-r--r--src/pkg/exp/ssh/session.go87
-rw-r--r--src/pkg/exp/ssh/session_test.go149
2 files changed, 224 insertions, 12 deletions
diff --git a/src/pkg/exp/ssh/session.go b/src/pkg/exp/ssh/session.go
index dab0113f4b..8eea8b287b 100644
--- a/src/pkg/exp/ssh/session.go
+++ b/src/pkg/exp/ssh/session.go
@@ -54,9 +54,10 @@ type Session struct {
*clientChan // the channel backing this session
- started bool // true once a Shell or Run is invoked.
- copyFuncs []func() error
- errch chan error // one send per copyFunc
+ started bool // true once Start, Run or Shell is invoked.
+ closeAfterWait []io.Closer
+ copyFuncs []func() error
+ errch chan error // one send per copyFunc
}
// RFC 4254 Section 6.4.
@@ -231,7 +232,7 @@ func (s *Session) start() error {
return nil
}
-// Wait waits for the remote command to exit.
+// Wait waits for the remote command to exit.
func (s *Session) Wait() error {
if !s.started {
return errors.New("ssh: session not started")
@@ -244,11 +245,12 @@ func (s *Session) Wait() error {
copyError = err
}
}
-
+ for _, fd := range s.closeAfterWait {
+ fd.Close()
+ }
if waitErr != nil {
return waitErr
}
-
return copyError
}
@@ -283,11 +285,15 @@ func (s *Session) stdin() error {
s.Stdin = new(bytes.Buffer)
}
s.copyFuncs = append(s.copyFuncs, func() error {
- _, err := io.Copy(&chanWriter{
+ w := &chanWriter{
packetWriter: s,
peersId: s.peersId,
win: s.win,
- }, s.Stdin)
+ }
+ _, err := io.Copy(w, s.Stdin)
+ if err1 := w.Close(); err == nil {
+ err = err1
+ }
return err
})
return nil
@@ -298,11 +304,12 @@ func (s *Session) stdout() error {
s.Stdout = ioutil.Discard
}
s.copyFuncs = append(s.copyFuncs, func() error {
- _, err := io.Copy(s.Stdout, &chanReader{
+ r := &chanReader{
packetWriter: s,
peersId: s.peersId,
data: s.data,
- })
+ }
+ _, err := io.Copy(s.Stdout, r)
return err
})
return nil
@@ -313,16 +320,72 @@ func (s *Session) stderr() error {
s.Stderr = ioutil.Discard
}
s.copyFuncs = append(s.copyFuncs, func() error {
- _, err := io.Copy(s.Stderr, &chanReader{
+ r := &chanReader{
packetWriter: s,
peersId: s.peersId,
data: s.dataExt,
- })
+ }
+ _, err := io.Copy(s.Stderr, r)
return err
})
return nil
}
+// StdinPipe returns a pipe that will be connected to the
+// remote command's standard input when the command starts.
+func (s *Session) StdinPipe() (io.WriteCloser, error) {
+ if s.Stdin != nil {
+ return nil, errors.New("ssh: Stdin already set")
+ }
+ if s.started {
+ return nil, errors.New("ssh: StdinPipe after process started")
+ }
+ pr, pw := io.Pipe()
+ s.Stdin = pr
+ s.closeAfterWait = append(s.closeAfterWait, pr)
+ return pw, nil
+}
+
+// StdoutPipe returns a pipe that will be connected to the
+// remote command's standard output when the command starts.
+// There is a fixed amount of buffering that is shared between
+// stdout and stderr streams. If the StdoutPipe reader is
+// not serviced fast enought it may eventually cause the
+// remote command to block.
+func (s *Session) StdoutPipe() (io.ReadCloser, error) {
+ if s.Stdout != nil {
+ return nil, errors.New("ssh: Stdout already set")
+ }
+ if s.started {
+ return nil, errors.New("ssh: StdoutPipe after process started")
+ }
+ pr, pw := io.Pipe()
+ s.Stdout = pw
+ s.closeAfterWait = append(s.closeAfterWait, pw)
+ return pr, nil
+}
+
+// StderrPipe returns a pipe that will be connected to the
+// remote command's standard error when the command starts.
+// There is a fixed amount of buffering that is shared between
+// stdout and stderr streams. If the StderrPipe reader is
+// not serviced fast enought it may eventually cause the
+// remote command to block.
+func (s *Session) StderrPipe() (io.ReadCloser, error) {
+ if s.Stderr != nil {
+ return nil, errors.New("ssh: Stderr already set")
+ }
+ if s.started {
+ return nil, errors.New("ssh: StderrPipe after process started")
+ }
+ pr, pw := io.Pipe()
+ s.Stderr = pw
+ s.closeAfterWait = append(s.closeAfterWait, pw)
+ return pr, nil
+}
+
+// TODO(dfc) add Output and CombinedOutput helpers
+
// NewSession returns a new interactive session on the remote host.
func (c *ClientConn) NewSession() (*Session, error) {
ch := c.newChan(c.transport)
diff --git a/src/pkg/exp/ssh/session_test.go b/src/pkg/exp/ssh/session_test.go
new file mode 100644
index 0000000000..4be7746d17
--- /dev/null
+++ b/src/pkg/exp/ssh/session_test.go
@@ -0,0 +1,149 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+// Session tests.
+
+import (
+ "bytes"
+ "io"
+ "testing"
+)
+
+// dial constructs a new test server and returns a *ClientConn.
+func dial(t *testing.T) *ClientConn {
+ pw := password("tiger")
+ serverConfig.PasswordCallback = func(user, pass string) bool {
+ return user == "testuser" && pass == string(pw)
+ }
+ serverConfig.PubKeyCallback = nil
+
+ l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
+ if err != nil {
+ t.Fatalf("unable to listen: %s", err)
+ }
+ go func() {
+ defer l.Close()
+ conn, err := l.Accept()
+ if err != nil {
+ t.Errorf("Unable to accept: %v", err)
+ return
+ }
+ defer conn.Close()
+ if err := conn.Handshake(); err != nil {
+ t.Errorf("Unable to handshake: %v", err)
+ return
+ }
+ for {
+ ch, err := conn.Accept()
+ if err == io.EOF {
+ return
+ }
+ if err != nil {
+ t.Errorf("Unable to accept incoming channel request: %v", err)
+ return
+ }
+ if ch.ChannelType() != "session" {
+ ch.Reject(UnknownChannelType, "unknown channel type")
+ continue
+ }
+ ch.Accept()
+ go func() {
+ defer ch.Close()
+ // this string is returned to stdout
+ shell := NewServerShell(ch, "golang")
+ shell.ReadLine()
+ type exitMsg struct {
+ PeersId uint32
+ Request string
+ WantReply bool
+ Status uint32
+ }
+ // TODO(dfc) casting to the concrete type should not be
+ // necessary to send a packet.
+ msg := exitMsg{
+ PeersId: ch.(*channel).theirId,
+ Request: "exit-status",
+ WantReply: false,
+ Status: 0,
+ }
+ ch.(*channel).serverConn.writePacket(marshal(msgChannelRequest, msg))
+ }()
+ }
+ t.Log("done")
+ }()
+
+ config := &ClientConfig{
+ User: "testuser",
+ Auth: []ClientAuth{
+ ClientAuthPassword(pw),
+ },
+ }
+
+ c, err := Dial("tcp", l.Addr().String(), config)
+ if err != nil {
+ t.Fatalf("unable to dial remote side: %s", err)
+ }
+ return c
+}
+
+// Test a simple string is returned to session.Stdout.
+func TestSessionShell(t *testing.T) {
+ conn := dial(t)
+ defer conn.Close()
+ session, err := conn.NewSession()
+ if err != nil {
+ t.Fatalf("Unable to request new session: %s", err)
+ }
+ defer session.Close()
+ stdout := new(bytes.Buffer)
+ session.Stdout = stdout
+ if err := session.Shell(); err != nil {
+ t.Fatalf("Unable to execute command: %s", err)
+ }
+ if err := session.Wait(); err != nil {
+ t.Fatalf("Remote command did not exit cleanly: %s", err)
+ }
+ actual := stdout.String()
+ if actual != "golang" {
+ t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
+ }
+}
+
+// TODO(dfc) add support for Std{in,err}Pipe when the Server supports it.
+
+// Test a simple string is returned via StdoutPipe.
+func TestSessionStdoutPipe(t *testing.T) {
+ conn := dial(t)
+ defer conn.Close()
+ session, err := conn.NewSession()
+ if err != nil {
+ t.Fatalf("Unable to request new session: %s", err)
+ }
+ defer session.Close()
+ stdout, err := session.StdoutPipe()
+ if err != nil {
+ t.Fatalf("Unable to request StdoutPipe(): %v", err)
+ }
+ var buf bytes.Buffer
+ if err := session.Shell(); err != nil {
+ t.Fatalf("Unable to execute command: %s", err)
+ }
+ done := make(chan bool, 1)
+ go func() {
+ if _, err := io.Copy(&buf, stdout); err != nil {
+ t.Errorf("Copy of stdout failed: %v", err)
+ }
+ done <- true
+ }()
+ if err := session.Wait(); err != nil {
+ t.Fatalf("Remote command did not exit cleanly: %s", err)
+ }
+ <-done
+ actual := buf.String()
+ if actual != "golang" {
+ t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
+ }
+}