waveterm/pkg/genconn/ssh-impl.go
Evan Simkowitz b51ff834b2
Happy new year! (#1684)
Update all 2024 references to 2025
2025-01-04 20:56:57 -08:00

146 lines
3.4 KiB
Go

// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package genconn
import (
"fmt"
"io"
"sync"
"golang.org/x/crypto/ssh"
)
// Compile-time check that *SSHShellClient satisfies the ShellClient interface.
var _ ShellClient = (*SSHShellClient)(nil)
// SSHShellClient is the ShellClient implementation that wraps an *ssh.Client.
type SSHShellClient struct {
// client is the SSH client passed to MakeSSHCmdClient by MakeProcessController.
client *ssh.Client
}
// MakeSSHShellClient returns a new SSHShellClient that wraps the given SSH
// client. The client is stored as-is; no connection work happens here.
func MakeSSHShellClient(client *ssh.Client) *SSHShellClient {
	shellClient := new(SSHShellClient)
	shellClient.client = client
	return shellClient
}
// MakeProcessController creates a process controller for cmdSpec by calling
// MakeSSHCmdClient with the wrapped SSH client.
//
// On failure it returns a literal nil interface together with the error.
// Returning the *SSHProcessController result directly would wrap a nil pointer
// in a non-nil ShellProcessController value, so a caller comparing the
// controller against nil would get the wrong answer.
func (c *SSHShellClient) MakeProcessController(cmdSpec CommandSpec) (ShellProcessController, error) {
	controller, err := MakeSSHCmdClient(c.client, cmdSpec)
	if err != nil {
		return nil, err
	}
	return controller, nil
}
// SSHProcessController implements ShellProcessController for SSH connections.
// Each controller owns one SSH session and runs the command described by
// cmdSpec on it.
type SSHProcessController struct {
// client is the SSH client the session was created from.
client *ssh.Client
// session is the SSH session created in MakeSSHCmdClient; the command runs on it.
session *ssh.Session
// lock is held by Start, Kill and the *Pipe methods while they read or
// update the started / *Piped flags and use the session.
lock *sync.Mutex
// once makes sure session.Wait is invoked at most one time by Wait.
once *sync.Once
// stdinPiped, stdoutPiped and stderrPiped record that the corresponding
// pipe has already been requested, so a second request is rejected.
stdinPiped bool
stdoutPiped bool
stderrPiped bool
// waitErr caches the result of session.Wait so repeated Wait calls return it.
waitErr error
// started is set once session.Start has succeeded.
started bool
// cmdSpec is passed to BuildShellCommand in Start to build the command line.
cmdSpec CommandSpec
}
// MakeSSHCmdClient opens a new session on client and returns an
// SSHProcessController that will run cmdSpec on that session.
// The command is not started here; call Start on the returned controller.
// If the session cannot be created, a wrapped error is returned.
func MakeSSHCmdClient(client *ssh.Client, cmdSpec CommandSpec) (*SSHProcessController, error) {
	sshSession, sessionErr := client.NewSession()
	if sessionErr != nil {
		return nil, fmt.Errorf("failed to create SSH session: %w", sessionErr)
	}

	controller := &SSHProcessController{
		client:  client,
		session: sshSession,
		cmdSpec: cmdSpec,
		lock:    new(sync.Mutex),
		once:    new(sync.Once),
	}
	return controller, nil
}
// Start builds the shell command from cmdSpec and starts it on the SSH
// session. It returns an error if the command was already started, if the
// command line cannot be built, or if the session fails to start it.
func (s *SSHProcessController) Start() error {
	s.lock.Lock()
	defer s.lock.Unlock()

	if s.started {
		return fmt.Errorf("command already started")
	}

	shellCmd, buildErr := BuildShellCommand(s.cmdSpec)
	if buildErr != nil {
		return fmt.Errorf("failed to build shell command: %w", buildErr)
	}

	// Streams that were never piped are left nil on the session. In that case
	// the library guarantees that stdout/stderr are attached to io.Discard and
	// that stdin is attached to an empty bytes.Buffer, which produces an
	// immediate EOF.
	// tl;dr: we don't need to worry about hanging because of long input, or
	// about explicitly closing stdin.
	startErr := s.session.Start(shellCmd)
	if startErr != nil {
		return fmt.Errorf("failed to start command: %w", startErr)
	}

	s.started = true
	return nil
}
// Wait blocks until the started command completes and returns the result of
// the session's Wait. The session is waited on only once; later calls return
// the cached result.
//
// If the command has not been started, Wait returns an error immediately and
// does NOT consume the once. Otherwise a premature Wait would cache the
// "not started" error, and every later Wait (even after a successful Start)
// would return that stale error without actually waiting.
func (s *SSHProcessController) Wait() error {
	// Read started under the lock, but do not hold the lock while blocking in
	// session.Wait: Kill needs the lock to close the session.
	s.lock.Lock()
	started := s.started
	s.lock.Unlock()
	if !started {
		return fmt.Errorf("command not started")
	}

	s.once.Do(func() {
		s.waitErr = s.session.Wait()
	})
	return s.waitErr
}
// Kill terminates the command by closing the SSH session, if there is one.
// Any error returned by Close is discarded.
func (s *SSHProcessController) Kill() {
	s.lock.Lock()
	defer s.lock.Unlock()

	if s.session == nil {
		return
	}
	s.session.Close()
}
// StdinPipe returns a writer connected to the command's standard input.
// It must be called before Start and at most once; otherwise an error is
// returned. The "already piped" flag is set before the session is asked for
// the pipe.
func (s *SSHProcessController) StdinPipe() (io.WriteCloser, error) {
	s.lock.Lock()
	defer s.lock.Unlock()

	switch {
	case s.started:
		return nil, fmt.Errorf("command already started")
	case s.stdinPiped:
		return nil, fmt.Errorf("stdin already piped")
	}

	s.stdinPiped = true
	stdin, pipeErr := s.session.StdinPipe()
	return stdin, pipeErr
}
// StdoutPipe returns a reader connected to the command's standard output.
// It must be called before Start and at most once; otherwise an error is
// returned. The "already piped" flag is set before the session is asked for
// the pipe.
func (s *SSHProcessController) StdoutPipe() (io.Reader, error) {
	s.lock.Lock()
	defer s.lock.Unlock()

	switch {
	case s.started:
		return nil, fmt.Errorf("command already started")
	case s.stdoutPiped:
		return nil, fmt.Errorf("stdout already piped")
	}

	s.stdoutPiped = true
	reader, pipeErr := s.session.StdoutPipe()
	if pipeErr == nil {
		return reader, nil
	}
	return nil, pipeErr
}
// StderrPipe returns a reader connected to the command's standard error.
// It must be called before Start and at most once; otherwise an error is
// returned. The "already piped" flag is set before the session is asked for
// the pipe.
func (s *SSHProcessController) StderrPipe() (io.Reader, error) {
	s.lock.Lock()
	defer s.lock.Unlock()

	switch {
	case s.started:
		return nil, fmt.Errorf("command already started")
	case s.stderrPiped:
		return nil, fmt.Errorf("stderr already piped")
	}

	s.stderrPiped = true
	reader, pipeErr := s.session.StderrPipe()
	if pipeErr == nil {
		return reader, nil
	}
	return nil, pipeErr
}