harness-drone/vendor/code.google.com/p/go.crypto/ssh/channel.go
2015-09-29 18:21:17 -07:00

594 lines
14 KiB
Go

// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"errors"
"fmt"
"io"
"sync"
"sync/atomic"
)
// extendedDataTypeCode identifies an SSH extended data type. See RFC 4254,
// section 5.2.
type extendedDataTypeCode uint32
// Protocol constants used by the channel implementations in this file.
const (
// extendedDataStderr is the extended data type that is used for stderr.
extendedDataStderr extendedDataTypeCode = 1
// minPacketLength defines the smallest valid packet. It is the lower bound
// checked against a peer's advertised MaxPacketSize in
// waitForChannelOpenResponse.
minPacketLength = 9
// channelMaxPacketSize defines the maximum packet size advertised in open messages
channelMaxPacketSize = 1 << 15 // RFC 4253 6.1, minimum 32 KiB
// channelWindowSize defines the window size advertised in open messages
channelWindowSize = 64 * channelMaxPacketSize // Like OpenSSH
)
// A Channel is an ordered, reliable, duplex stream that is multiplexed over an
// SSH connection. Channel.Read can return a ChannelRequest as an error.
type Channel interface {
	// Read, Write and Close operate on the channel's byte stream.
	// Read may return a ChannelRequest as an error.
	io.ReadWriteCloser

	// Accept accepts the channel creation request.
	Accept() error

	// Reject rejects the channel creation request. Once Reject has been
	// called no other method on the Channel may be used; doing so is
	// likely to make the peer signal a protocol error and drop the
	// connection.
	Reject(reason RejectionReason, message string) error

	// Stderr returns an io.Writer whose writes are sent on this channel
	// with the extended data type set to stderr.
	Stderr() io.Writer

	// AckRequest answers a channel request with an ack (ok == true) or a
	// nack (ok == false).
	AckRequest(ok bool) error

	// ChannelType returns the type of the channel, as supplied by the
	// client.
	ChannelType() string

	// ExtraData returns the arbitrary payload for this channel, as
	// supplied by the client. Its meaning depends on the channel type.
	ExtraData() []byte
}
// ChannelRequest represents a request sent on a channel, outside of the normal
// stream of bytes. It may result from calling Read on a Channel.
type ChannelRequest struct {
	// Request is the request name sent by the peer.
	Request string
	// WantReply is the peer's want-reply flag for this request.
	WantReply bool
	// Payload carries the request-specific data.
	Payload []byte
}

// Error makes ChannelRequest usable as an error value, which is how Read
// delivers it. The text is fixed and does not depend on the request.
func (ChannelRequest) Error() string {
	return "ssh: channel request received"
}
// RejectionReason is an enumeration used when rejecting channel creation
// requests. See RFC 4254, section 5.1.
type RejectionReason uint32

// Reason codes that can be passed to Channel.Reject. Their numeric values
// are 1 through 4, in the order listed.
const (
	// Prohibited has the value 1 ("administratively prohibited").
	Prohibited RejectionReason = 1
	// ConnectionFailed has the value 2 ("connect failed").
	ConnectionFailed RejectionReason = 2
	// UnknownChannelType has the value 3 ("unknown channel type").
	UnknownChannelType RejectionReason = 3
	// ResourceShortage has the value 4 ("resource shortage").
	ResourceShortage RejectionReason = 4
)

// String converts the rejection reason to human readable form. Values
// outside the known range are rendered as "unknown reason N".
func (r RejectionReason) String() string {
	descriptions := [...]string{
		Prohibited:         "administratively prohibited",
		ConnectionFailed:   "connect failed",
		UnknownChannelType: "unknown channel type",
		ResourceShortage:   "resource shortage",
	}
	if r >= Prohibited && r <= ResourceShortage {
		return descriptions[r]
	}
	return fmt.Sprintf("unknown reason %d", int(r))
}
// channel holds the state shared by the channel types in this file; it is
// embedded by serverChan and clientChan and referenced by chanWriter and
// chanReader.
type channel struct {
packetConn // the underlying transport
// localId is sent as MyId when a channel is confirmed; remoteId is placed in
// the PeersId field of every outgoing channel message.
localId, remoteId uint32
// remoteWin is the window space reserved from before data is written and
// added to when a window adjust message arrives.
remoteWin window
// maxPacket is the largest packet, header included, that writePacket will
// send.
maxPacket uint32
isClosed uint32 // atomic bool, non zero if true
}
// sendWindowAdj sends a window adjust message to the peer announcing n
// additional bytes. n is converted to uint32 without range checking.
func (c *channel) sendWindowAdj(n int) error {
	packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{
		PeersId:         c.remoteId,
		AdditionalBytes: uint32(n),
	})
	return c.writePacket(packet)
}
// sendEOF sends EOF to the remote side. RFC 4254 Section 5.3
func (c *channel) sendEOF() error {
	eof := channelEOFMsg{PeersId: c.remoteId}
	return c.writePacket(marshal(msgChannelEOF, eof))
}
// sendClose informs the remote side of our intent to close the channel.
//
// The message goes straight to the underlying packetConn rather than through
// channel.writePacket, so it is not subject to that method's closed and
// maxPacket checks.
func (c *channel) sendClose() error {
	closeMsg := channelCloseMsg{PeersId: c.remoteId}
	return c.packetConn.writePacket(marshal(msgChannelClose, closeMsg))
}
// sendChannelOpenFailure tells the peer that its channel open request was
// refused, passing along reason and message. The language tag is always "en".
func (c *channel) sendChannelOpenFailure(reason RejectionReason, message string) error {
	return c.writePacket(marshal(msgChannelOpenFailure, channelOpenFailureMsg{
		PeersId:  c.remoteId,
		Reason:   reason,
		Message:  message,
		Language: "en",
	}))
}
// writePacket sends b on the underlying transport. It returns io.EOF if the
// channel has been closed locally, and an error if b is longer than
// maxPacket bytes.
func (c *channel) writePacket(b []byte) error {
	switch {
	case c.closed():
		return io.EOF
	case uint32(len(b)) > c.maxPacket:
		return fmt.Errorf("ssh: cannot write %d bytes, maxPacket is %d bytes", len(b), c.maxPacket)
	}
	return c.packetConn.writePacket(b)
}
// closed reports whether setClosed has marked this channel as closed.
func (c *channel) closed() bool {
	return atomic.LoadUint32(&c.isClosed) != 0
}
// setClosed atomically marks the channel as closed. It returns true only for
// the call that performed the transition, so callers can detect a second
// close.
func (c *channel) setClosed() bool {
	firstClose := atomic.CompareAndSwapUint32(&c.isClosed, 0, 1)
	return firstClose
}
// serverChan is the channel type whose methods in this file (Accept, Reject,
// Read, Write, Close, Stderr, AckRequest, ChannelType, ExtraData) match the
// Channel interface.
type serverChan struct {
channel
// immutable once created
chanType string
extraData []byte
serverConn *ServerConn
// myWindow is the receive window: it is advertised in Accept, decremented by
// handleData and replenished by read.
myWindow uint32
theyClosed bool // indicates the close msg has been received from the remote side
// theySentEOF is set by handlePacket when a channelEOFMsg arrives.
theySentEOF bool
// isDead is read and written atomically through dead and setDead; non zero
// means dead.
isDead uint32
// err, when non-nil, is returned by read before anything else.
err error
// pendingRequests queues channel requests until read hands them out.
pendingRequests []ChannelRequest
// pendingData is used as a circular buffer: head is the index of the first
// unread byte and length is the number of unread bytes.
pendingData []byte
head, length int
// This lock is inferior to serverConn.lock
cond *sync.Cond
}
// Accept accepts the channel creation request by sending a channel open
// confirmation that carries our id, receive window and maximum packet size.
// It returns the connection's error instead if one has been recorded.
func (c *serverChan) Accept() error {
	conn := c.serverConn
	conn.lock.Lock()
	defer conn.lock.Unlock()

	if conn.err != nil {
		return conn.err
	}

	return c.writePacket(marshal(msgChannelOpenConfirm, channelOpenConfirmMsg{
		PeersId:       c.remoteId,
		MyId:          c.localId,
		MyWindow:      c.myWindow,
		MaxPacketSize: c.maxPacket,
	}))
}
// Reject refuses the channel creation request, sending reason and message to
// the peer. It returns the connection's error instead if one has been
// recorded.
func (c *serverChan) Reject(reason RejectionReason, message string) error {
	conn := c.serverConn
	conn.lock.Lock()
	defer conn.lock.Unlock()

	if conn.err != nil {
		return conn.err
	}
	return c.sendChannelOpenFailure(reason, message)
}
// handlePacket records the effect of a channel-level message received from
// the peer and wakes a goroutine blocked in read where appropriate.
//
// Supported messages are *channelRequestMsg (queued for read),
// *channelCloseMsg, *channelEOFMsg and *windowAdjustMsg. Passing any other
// type is a programming error and panics.
//
// A window adjustment that the peer's window rejects is input controlled by
// the remote side, so it must not crash the process: instead the channel is
// marked dead and a protocol error is stored for subsequent reads.
func (c *serverChan) handlePacket(packet interface{}) {
	c.cond.L.Lock()
	defer c.cond.L.Unlock()

	switch msg := packet.(type) {
	case *channelRequestMsg:
		c.pendingRequests = append(c.pendingRequests, ChannelRequest{
			Request:   msg.Request,
			WantReply: msg.WantReply,
			Payload:   msg.RequestSpecificData,
		})
		c.cond.Signal()
	case *channelCloseMsg:
		c.theyClosed = true
		c.cond.Signal()
	case *channelEOFMsg:
		c.theySentEOF = true
		c.cond.Signal()
	case *windowAdjustMsg:
		if !c.remoteWin.add(msg.AdditionalBytes) {
			// Keep the first recorded error; later failures add nothing.
			if c.err == nil {
				c.err = errors.New("ssh: illegal window update")
			}
			c.setDead()
			// Wake a blocked reader so that it observes the dead channel.
			c.cond.Signal()
		}
	default:
		panic("unknown packet type")
	}
}
// handleData appends data received from the peer to the circular
// pendingData buffer and wakes a goroutine blocked in read.
//
// Data is dropped, without any state change, when it does not fit into the
// free part of the buffer or when it is larger than the window currently
// granted to the peer. The window check keeps myWindow from wrapping around
// (it is an unsigned counter) when a peer ignores flow control.
func (c *serverChan) handleData(data []byte) {
	c.cond.L.Lock()
	defer c.cond.L.Unlock()

	// The other side should never send us more than our window.
	overflowsBuffer := len(data)+c.length > len(c.pendingData)
	exceedsWindow := uint64(len(data)) > uint64(c.myWindow)
	if overflowsBuffer || exceedsWindow {
		// TODO(agl): we should tear down the channel with a protocol
		// error.
		return
	}
	c.myWindow -= uint32(len(data))

	// The free region may wrap around the end of the buffer, so up to two
	// copies are needed: one up to the end of pendingData and one from its
	// start. The second pass copies nothing when the first consumed all data.
	for i := 0; i < 2; i++ {
		tail := c.head + c.length
		if tail >= len(c.pendingData) {
			tail -= len(c.pendingData)
		}
		n := copy(c.pendingData[tail:], data)
		data = data[n:]
		c.length += n
	}
	c.cond.Signal()
}
// Stderr returns an io.Writer that writes to this channel with the extended
// data type set to stderr.
func (c *serverChan) Stderr() io.Writer {
	stderr := extendedDataChannel{t: extendedDataStderr, c: c}
	return stderr
}
// extendedDataChannel is an io.Writer that writes any data to c as extended
// data of the given type.
type extendedDataChannel struct {
// t is the extended data type code placed in every packet written.
t extendedDataTypeCode
// c is the server channel the extended data is written to.
c *serverChan
}
// Write sends data to the peer as extended data of type edc.t, splitting it
// into packets that respect both the channel's maximum packet size and the
// window space that can be reserved from the peer.
//
// It returns the number of bytes of data that were handed to the transport
// and the first error encountered. As required of an io.Writer, the count
// reflects bytes already written even when a later packet fails.
func (edc extendedDataChannel) Write(data []byte) (n int, err error) {
	const headerLength = 13 // 1 byte message type, 4 bytes remoteId, 4 bytes extended message type, 4 bytes data length
	c := edc.c
	for len(data) > 0 {
		// maxPacket is unsigned: without this guard the subtraction below
		// would wrap around (or leave no room for payload at all).
		if c.maxPacket <= headerLength {
			return n, fmt.Errorf("ssh: maxPacket of %d bytes is too small to carry extended data", c.maxPacket)
		}
		space := min(c.maxPacket-headerLength, len(data))
		if space, err = c.getWindowSpace(space); err != nil {
			return n, err
		}
		todo := data
		if uint32(len(todo)) > space {
			todo = todo[:space]
		}

		packet := make([]byte, headerLength+len(todo))
		packet[0] = msgChannelExtendedData
		marshalUint32(packet[1:], c.remoteId)
		marshalUint32(packet[5:], uint32(edc.t))
		marshalUint32(packet[9:], uint32(len(todo)))
		copy(packet[13:], todo)

		if err = c.writePacket(packet); err != nil {
			return n, err
		}

		n += len(todo)
		data = data[len(todo):]
	}
	return n, nil
}
// Read reads buffered channel data into data, or returns a pending
// ChannelRequest as the error. When the internal read asks for a window
// adjustment, the adjustment is sent to the peer and the result of that send
// becomes the returned error.
func (c *serverChan) Read(data []byte) (n int, err error) {
	var adjustment uint32
	n, err, adjustment = c.read(data)
	if adjustment == 0 {
		return n, err
	}

	err = c.writePacket(marshal(msgChannelWindowAdjust, windowAdjustMsg{
		PeersId:         c.remoteId,
		AdditionalBytes: adjustment,
	}))
	return n, err
}
// read is the locked core of Read. It blocks until one of the following can
// be returned, checked in this order:
//
//   - a queued ChannelRequest, returned as err with n == 0;
//   - buffered data, copied into data, together with the window adjustment
//     (possibly zero) that the caller should announce to the peer;
//   - io.EOF, once the peer has sent EOF or close, or the channel is dead.
//
// Buffered requests and data are delivered before io.EOF so that bytes which
// arrived ahead of the peer's EOF/close message are not lost. A previously
// recorded c.err is returned immediately.
func (c *serverChan) read(data []byte) (n int, err error, windowAdjustment uint32) {
	c.cond.L.Lock()
	defer c.cond.L.Unlock()

	if c.err != nil {
		return 0, c.err, 0
	}

	for {
		if len(c.pendingRequests) > 0 {
			req := c.pendingRequests[0]
			// Copy the remainder into a fresh slice (nil when empty) so the
			// delivered request is not kept alive by the backing array.
			c.pendingRequests = append([]ChannelRequest(nil), c.pendingRequests[1:]...)
			return 0, req, 0
		}

		if c.length > 0 {
			// Only the contiguous part up to the end of the buffer is
			// returned; wrapped data is picked up by the next call.
			tail := min(uint32(c.head+c.length), len(c.pendingData))
			n = copy(data, c.pendingData[c.head:tail])
			c.head += n
			c.length -= n
			if c.head == len(c.pendingData) {
				c.head = 0
			}

			// Offer the peer the free space it does not know about yet, but
			// only once that amounts to at least half of the buffer.
			windowAdjustment = uint32(len(c.pendingData)-c.length) - c.myWindow
			if windowAdjustment < uint32(len(c.pendingData)/2) {
				windowAdjustment = 0
			}
			c.myWindow += windowAdjustment

			return n, nil, windowAdjustment
		}

		// Nothing left to deliver: report end of stream if no more input
		// can arrive, otherwise wait for handleData/handlePacket.
		if c.theySentEOF || c.theyClosed || c.dead() {
			return 0, io.EOF, 0
		}

		c.cond.Wait()
	}
}
// getWindowSpace takes, at most, max bytes of space from the peer's window. It
// returns the number of bytes actually reserved, or io.EOF when the channel
// is dead or has been closed.
func (c *serverChan) getWindowSpace(max uint32) (uint32, error) {
	if unusable := c.dead() || c.closed(); unusable {
		return 0, io.EOF
	}
	reserved := c.remoteWin.reserve(max)
	return reserved, nil
}
// dead reports whether setDead has been called on this channel.
func (c *serverChan) dead() bool {
	return atomic.LoadUint32(&c.isDead) != 0
}
// setDead atomically marks the channel as dead. After this, dead reports
// true, read stops blocking for new input and getWindowSpace returns io.EOF.
func (c *serverChan) setDead() {
	const deadFlag uint32 = 1
	atomic.StoreUint32(&c.isDead, deadFlag)
}
// Write sends data to the peer as channel data, splitting it into packets
// that respect both the channel's maximum packet size and the window space
// that can be reserved from the peer.
//
// It returns the number of bytes of data that were handed to the transport
// and the first error encountered. As required of an io.Writer, the count
// reflects bytes already written even when a later packet fails.
func (c *serverChan) Write(data []byte) (n int, err error) {
	const headerLength = 9 // 1 byte message type, 4 bytes remoteId, 4 bytes data length
	for len(data) > 0 {
		// maxPacket is unsigned: without this guard the subtraction below
		// could wrap around, and a limit equal to the header size would
		// make the loop send empty packets forever.
		if c.maxPacket <= headerLength {
			return n, fmt.Errorf("ssh: maxPacket of %d bytes is too small to carry channel data", c.maxPacket)
		}
		space := min(c.maxPacket-headerLength, len(data))
		if space, err = c.getWindowSpace(space); err != nil {
			return n, err
		}
		todo := data
		if uint32(len(todo)) > space {
			todo = todo[:space]
		}

		packet := make([]byte, headerLength+len(todo))
		packet[0] = msgChannelData
		marshalUint32(packet[1:], c.remoteId)
		marshalUint32(packet[5:], uint32(len(todo)))
		copy(packet[9:], todo)

		if err = c.writePacket(packet); err != nil {
			return n, err
		}

		n += len(todo)
		data = data[len(todo):]
	}
	return n, nil
}
// Close signals the intent to close the channel. It fails with the
// connection's error if one has been recorded, and with "channel already
// closed" on any call after the first successful one.
func (c *serverChan) Close() error {
	conn := c.serverConn
	conn.lock.Lock()
	defer conn.lock.Unlock()

	if conn.err != nil {
		return conn.err
	}

	if firstClose := c.setClosed(); !firstClose {
		return errors.New("ssh: channel already closed")
	}
	return c.sendClose()
}
// AckRequest replies to a channel request: a success message when ok is
// true, a failure message otherwise. It returns the connection's error
// instead if one has been recorded.
func (c *serverChan) AckRequest(ok bool) error {
	conn := c.serverConn
	conn.lock.Lock()
	defer conn.lock.Unlock()

	if conn.err != nil {
		return conn.err
	}

	if ok {
		success := channelRequestSuccessMsg{PeersId: c.remoteId}
		return c.writePacket(marshal(msgChannelSuccess, success))
	}

	failure := channelRequestFailureMsg{PeersId: c.remoteId}
	return c.writePacket(marshal(msgChannelFailure, failure))
}
// ChannelType returns the type of the channel, as supplied by the client.
// It reports the chanType field unchanged.
func (c *serverChan) ChannelType() string {
return c.chanType
}
// ExtraData returns the arbitrary payload for this channel, as supplied by
// the client. The returned slice is the extraData field itself, not a copy.
func (c *serverChan) ExtraData() []byte {
return c.extraData
}
// A clientChan represents a single RFC 4254 channel multiplexed
// over a SSH connection.
type clientChan struct {
channel
// stdin, stdout and stderr all share the embedded channel (see
// newClientChan).
stdin *chanWriter
stdout *chanReader
stderr *chanReader
// msg delivers messages for this channel; waitForChannelOpenResponse
// receives the reply to the channel open request from it.
msg chan interface{}
}
// newClientChan returns a partially constructed *clientChan
// using the local id provided. To be usable clientChan.remoteId
// needs to be assigned once known.
//
// The stdin writer and the stdout/stderr readers all point at the same
// embedded channel; each reader gets its own buffer, and msg is buffered
// for 16 messages.
func newClientChan(cc packetConn, id uint32) *clientChan {
	c := new(clientChan)
	c.channel = channel{
		packetConn: cc,
		localId:    id,
		remoteWin:  window{Cond: newCond()},
	}
	c.msg = make(chan interface{}, 16)

	shared := &c.channel
	c.stdin = &chanWriter{channel: shared}
	c.stdout = &chanReader{channel: shared, buffer: newBuffer()}
	c.stderr = &chanReader{channel: shared, buffer: newBuffer()}
	return c
}
// waitForChannelOpenResponse, if successful, fills out
// the remoteId and records any initial window advertisement.
//
// It blocks for one message on c.msg. A confirmation with an out-of-range
// MaxPacketSize, a failure message, or any other message yields an error.
func (c *clientChan) waitForChannelOpenResponse() error {
	reply := <-c.msg

	if confirm, ok := reply.(*channelOpenConfirmMsg); ok {
		tooSmall := confirm.MaxPacketSize < minPacketLength
		tooLarge := confirm.MaxPacketSize > 1<<31
		if tooSmall || tooLarge {
			return errors.New("ssh: invalid MaxPacketSize from peer")
		}
		// fixup remoteId field
		c.remoteId = confirm.MyId
		c.maxPacket = confirm.MaxPacketSize
		c.remoteWin.add(confirm.MyWindow)
		return nil
	}

	if failure, ok := reply.(*channelOpenFailureMsg); ok {
		return errors.New(safeString(failure.Message))
	}

	return errors.New("ssh: unexpected packet")
}
// Close signals the intent to close the channel. The first call marks both
// output readers as having reached EOF and sends the close message; any
// later call returns "channel already closed".
func (c *clientChan) Close() error {
	if firstClose := c.setClosed(); !firstClose {
		return errors.New("ssh: channel already closed")
	}
	for _, reader := range []*chanReader{c.stdout, c.stderr} {
		reader.eof()
	}
	return c.sendClose()
}
// A chanWriter represents the stdin of a remote process.
type chanWriter struct {
// The embedded channel supplies the remote id, packet size limit, peer
// window and packet writer used by Write and Close.
*channel
// indicates the writer has been closed. eof is owned by the
// caller of Write/Close.
eof bool
}
// Write writes data to the remote process's standard input.
//
// The data is sent as a sequence of channel data packets. It returns the
// number of bytes sent and io.EOF if the writer or channel has been closed,
// or the error from the first packet that could not be written.
func (w *chanWriter) Write(data []byte) (written int, err error) {
	const headerLength = 9 // 1 byte message type, 4 bytes remoteId, 4 bytes data length
	for len(data) > 0 {
		if w.eof || w.closed() {
			return written, io.EOF
		}

		// never send more data than maxPacket even if
		// there is sufficient window.
		limit := min(w.maxPacket-headerLength, len(data))
		n := w.remoteWin.reserve(limit)

		// Header: message type, then remote id and payload length, each as
		// four bytes with the most significant byte first.
		id := w.remoteId
		packet := make([]byte, headerLength, headerLength+int(n))
		packet[0] = msgChannelData
		packet[1], packet[2], packet[3], packet[4] = byte(id>>24), byte(id>>16), byte(id>>8), byte(id)
		packet[5], packet[6], packet[7], packet[8] = byte(n>>24), byte(n>>16), byte(n>>8), byte(n)
		packet = append(packet, data[:n]...)

		if err = w.writePacket(packet); err != nil {
			return written, err
		}
		data = data[n:]
		written += int(n)
	}
	return written, err
}
// min returns the smaller of a and b as a uint32.
//
// b is expected to be a slice length. The comparison is done in 64 bits
// because converting b straight to uint32 would truncate lengths of 4 GiB or
// more on 64-bit platforms and could report a bogus (even zero) minimum.
// A negative b is never smaller than a here; a is returned in that case.
func min(a uint32, b int) uint32 {
	if b >= 0 && uint64(b) < uint64(a) {
		return uint32(b)
	}
	return a
}
// Close marks the writer as closed, so that later Writes return io.EOF, and
// sends EOF to the remote side. The flag is set even if sending EOF fails.
func (w *chanWriter) Close() error {
	w.eof = true
	err := w.sendEOF()
	return err
}
// A chanReader represents stdout or stderr of a remote process.
type chanReader struct {
*channel // the channel backing this reader
// The embedded buffer is what Read drains; its Read and eof methods are
// used in this file (the buffer type itself is not shown here).
*buffer
}
// Read reads data from the remote process's stdout or stderr.
//
// After a successful read from the buffer, the peer's window is topped up by
// the number of bytes consumed.
func (r *chanReader) Read(buf []byte) (int, error) {
	n, err := r.buffer.Read(buf)
	switch {
	case err == io.EOF:
		// End of stream: hand back whatever was read along with EOF.
		return n, err
	case err != nil:
		return 0, err
	}

	// sendWindowAdj can return io.EOF if the remote peer has closed the
	// connection, however we want to defer forwarding io.EOF to the
	// caller of Read until the buffer has been drained.
	if err = r.sendWindowAdj(n); err == io.EOF && n > 0 {
		err = nil
	}
	return n, err
}