mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-20 21:21:44 +01:00
476 lines
13 KiB
Go
476 lines
13 KiB
Go
// Copyright 2024, Command Line Inc.
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package conncontroller
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/kevinburke/ssh_config"
|
|
"github.com/wavetermdev/thenextwave/pkg/remote"
|
|
"github.com/wavetermdev/thenextwave/pkg/userinput"
|
|
"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
|
|
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
|
|
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
|
"github.com/wavetermdev/thenextwave/pkg/wps"
|
|
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
|
|
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
const (
|
|
Status_Init = "init"
|
|
Status_Connecting = "connecting"
|
|
Status_Connected = "connected"
|
|
Status_Disconnected = "disconnected"
|
|
Status_Error = "error"
|
|
)
|
|
|
|
var globalLock = &sync.Mutex{}
|
|
var clientControllerMap = make(map[remote.SSHOpts]*SSHConn)
|
|
|
|
type SSHConn struct {
|
|
Lock *sync.Mutex
|
|
Status string
|
|
Opts *remote.SSHOpts
|
|
Client *ssh.Client
|
|
SockName string
|
|
DomainSockListener net.Listener
|
|
ConnController *ssh.Session
|
|
Error string
|
|
HasWaiter *atomic.Bool
|
|
}
|
|
|
|
func GetAllConnStatus() []wshrpc.ConnStatus {
|
|
globalLock.Lock()
|
|
defer globalLock.Unlock()
|
|
|
|
var connStatuses []wshrpc.ConnStatus
|
|
for _, conn := range clientControllerMap {
|
|
connStatuses = append(connStatuses, conn.DeriveConnStatus())
|
|
}
|
|
return connStatuses
|
|
}
|
|
|
|
func (conn *SSHConn) DeriveConnStatus() wshrpc.ConnStatus {
|
|
conn.Lock.Lock()
|
|
defer conn.Lock.Unlock()
|
|
return wshrpc.ConnStatus{
|
|
Status: conn.Status,
|
|
Connected: conn.Status == Status_Connected,
|
|
Connection: conn.Opts.String(),
|
|
Error: conn.Error,
|
|
}
|
|
}
|
|
|
|
func (conn *SSHConn) FireConnChangeEvent() {
|
|
status := conn.DeriveConnStatus()
|
|
event := wshrpc.WaveEvent{
|
|
Event: wshrpc.Event_ConnChange,
|
|
Scopes: []string{
|
|
fmt.Sprintf("connection:%s", conn.GetName()),
|
|
},
|
|
Data: status,
|
|
}
|
|
log.Printf("connstatus change %q => %s\n", conn.GetName(), status.Status)
|
|
wps.Broker.Publish(event)
|
|
}
|
|
|
|
func (conn *SSHConn) Close() error {
|
|
defer conn.FireConnChangeEvent()
|
|
conn.WithLock(func() {
|
|
if conn.Status == Status_Connected || conn.Status == Status_Connecting {
|
|
// if status is init, disconnected, or error don't change it
|
|
conn.Status = Status_Disconnected
|
|
}
|
|
conn.close_nolock()
|
|
})
|
|
// we must wait for the waiter to complete
|
|
startTime := time.Now()
|
|
for conn.HasWaiter.Load() {
|
|
time.Sleep(10 * time.Millisecond)
|
|
if time.Since(startTime) > 2*time.Second {
|
|
return fmt.Errorf("timeout waiting for waiter to complete")
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (conn *SSHConn) close_nolock() {
|
|
// does not set status (that should happen at another level)
|
|
if conn.DomainSockListener != nil {
|
|
conn.DomainSockListener.Close()
|
|
conn.DomainSockListener = nil
|
|
}
|
|
if conn.ConnController != nil {
|
|
conn.ConnController.Close()
|
|
conn.ConnController = nil
|
|
}
|
|
if conn.Client != nil {
|
|
conn.Client.Close()
|
|
conn.Client = nil
|
|
}
|
|
}
|
|
|
|
func (conn *SSHConn) GetDomainSocketName() string {
|
|
conn.Lock.Lock()
|
|
defer conn.Lock.Unlock()
|
|
return conn.SockName
|
|
}
|
|
|
|
func (conn *SSHConn) GetStatus() string {
|
|
conn.Lock.Lock()
|
|
defer conn.Lock.Unlock()
|
|
return conn.Status
|
|
}
|
|
|
|
func (conn *SSHConn) GetName() string {
|
|
// no lock required because opts is immutable
|
|
return conn.Opts.String()
|
|
}
|
|
|
|
func (conn *SSHConn) OpenDomainSocketListener() error {
|
|
var allowed bool
|
|
conn.WithLock(func() {
|
|
if conn.Status != Status_Connecting {
|
|
allowed = false
|
|
} else {
|
|
allowed = true
|
|
}
|
|
})
|
|
if !allowed {
|
|
return fmt.Errorf("cannot open domain socket for %q when status is %q", conn.GetName(), conn.GetStatus())
|
|
}
|
|
client := conn.GetClient()
|
|
randStr, err := utilfn.RandomHexString(16) // 64-bits of randomness
|
|
if err != nil {
|
|
return fmt.Errorf("error generating random string: %w", err)
|
|
}
|
|
sockName := fmt.Sprintf("/tmp/waveterm-%s.sock", randStr)
|
|
log.Printf("remote domain socket %s %q\n", conn.GetName(), sockName)
|
|
listener, err := client.ListenUnix(sockName)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to request connection domain socket: %v", err)
|
|
}
|
|
conn.WithLock(func() {
|
|
conn.SockName = sockName
|
|
conn.DomainSockListener = listener
|
|
})
|
|
go func() {
|
|
defer conn.WithLock(func() {
|
|
conn.DomainSockListener = nil
|
|
conn.SockName = ""
|
|
})
|
|
wshutil.RunWshRpcOverListener(listener)
|
|
}()
|
|
return nil
|
|
}
|
|
|
|
func (conn *SSHConn) StartConnServer() error {
|
|
var allowed bool
|
|
conn.WithLock(func() {
|
|
if conn.Status != Status_Connecting {
|
|
allowed = false
|
|
} else {
|
|
allowed = true
|
|
}
|
|
})
|
|
if !allowed {
|
|
return fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus())
|
|
}
|
|
client := conn.GetClient()
|
|
wshPath := remote.GetWshPath(client)
|
|
rpcCtx := wshrpc.RpcContext{
|
|
ClientType: wshrpc.ClientType_ConnServer,
|
|
Conn: conn.GetName(),
|
|
}
|
|
sockName := conn.GetDomainSocketName()
|
|
jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to create jwt token for conn controller: %w", err)
|
|
}
|
|
sshSession, err := client.NewSession()
|
|
if err != nil {
|
|
return fmt.Errorf("unable to create ssh session for conn controller: %w", err)
|
|
}
|
|
pipeRead, pipeWrite := io.Pipe()
|
|
sshSession.Stdout = pipeWrite
|
|
sshSession.Stderr = pipeWrite
|
|
cmdStr := fmt.Sprintf("%s=\"%s\" %s connserver", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
|
|
log.Printf("starting conn controller: %s\n", cmdStr)
|
|
err = sshSession.Start(cmdStr)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to start conn controller: %w", err)
|
|
}
|
|
conn.WithLock(func() {
|
|
conn.ConnController = sshSession
|
|
})
|
|
// service the I/O
|
|
go func() {
|
|
// wait for termination, clear the controller
|
|
defer conn.WithLock(func() {
|
|
conn.ConnController = nil
|
|
})
|
|
waitErr := sshSession.Wait()
|
|
log.Printf("conn controller (%q) terminated: %v", conn.GetName(), waitErr)
|
|
}()
|
|
go func() {
|
|
readErr := wshutil.StreamToLines(pipeRead, func(line []byte) {
|
|
lineStr := string(line)
|
|
if !strings.HasSuffix(lineStr, "\n") {
|
|
lineStr += "\n"
|
|
}
|
|
log.Printf("[conncontroller:%s:output] %s", conn.GetName(), lineStr)
|
|
})
|
|
if readErr != nil && readErr != io.EOF {
|
|
log.Printf("[conncontroller:%s] error reading output: %v\n", conn.GetName(), readErr)
|
|
}
|
|
}()
|
|
return nil
|
|
}
|
|
|
|
func (conn *SSHConn) checkAndInstallWsh(ctx context.Context) error {
|
|
client := conn.GetClient()
|
|
if client == nil {
|
|
return fmt.Errorf("client is nil")
|
|
}
|
|
// check that correct wsh extensions are installed
|
|
expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
|
|
clientVersion, err := remote.GetWshVersion(client)
|
|
if err == nil && clientVersion == expectedVersion {
|
|
return nil
|
|
}
|
|
// TODO add some progress to SSHConn about install status
|
|
var queryText string
|
|
var title string
|
|
if err != nil {
|
|
queryText = "Waveterm requires `wsh` shell extensions installed on your client to ensure a seamless experience. Would you like to install them?"
|
|
title = "Install Wsh Shell Extensions"
|
|
} else {
|
|
queryText = fmt.Sprintf("Waveterm requires `wsh` shell extensions installed on your client to be updated from %s to %s. Would you like to update?", clientVersion, expectedVersion)
|
|
title = "Update Wsh Shell Extensions"
|
|
}
|
|
request := &userinput.UserInputRequest{
|
|
ResponseType: "confirm",
|
|
QueryText: queryText,
|
|
Title: title,
|
|
CheckBoxMsg: "Don't show me this again",
|
|
}
|
|
response, err := userinput.GetUserInput(ctx, request)
|
|
if err != nil || !response.Confirm {
|
|
return err
|
|
}
|
|
log.Printf("attempting to install wsh to `%s@%s`", client.User(), client.RemoteAddr().String())
|
|
clientOs, err := remote.GetClientOs(client)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
clientArch, err := remote.GetClientArch(client)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// attempt to install extension
|
|
wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
|
|
err = remote.CpHostToRemote(client, wshLocalPath, "~/.waveterm/bin/wsh")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
log.Printf("successfully installed wsh on %s\n", conn.GetName())
|
|
return nil
|
|
}
|
|
|
|
func (conn *SSHConn) GetClient() *ssh.Client {
|
|
conn.Lock.Lock()
|
|
defer conn.Lock.Unlock()
|
|
return conn.Client
|
|
}
|
|
|
|
func (conn *SSHConn) Reconnect(ctx context.Context) error {
|
|
err := conn.Close()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return conn.Connect(ctx)
|
|
}
|
|
|
|
// does not return an error since that error is stored inside of SSHConn
|
|
func (conn *SSHConn) Connect(ctx context.Context) error {
|
|
var connectAllowed bool
|
|
conn.WithLock(func() {
|
|
if conn.Status == Status_Connecting || conn.Status == Status_Connected {
|
|
connectAllowed = false
|
|
} else {
|
|
conn.Status = Status_Connecting
|
|
conn.Error = ""
|
|
connectAllowed = true
|
|
}
|
|
})
|
|
if !connectAllowed {
|
|
return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus())
|
|
}
|
|
conn.FireConnChangeEvent()
|
|
err := conn.connectInternal(ctx)
|
|
conn.WithLock(func() {
|
|
if err != nil {
|
|
conn.Status = Status_Error
|
|
conn.Error = err.Error()
|
|
conn.close_nolock()
|
|
} else {
|
|
conn.Status = Status_Connected
|
|
}
|
|
})
|
|
conn.FireConnChangeEvent()
|
|
return err
|
|
}
|
|
|
|
func (conn *SSHConn) WithLock(fn func()) {
|
|
conn.Lock.Lock()
|
|
defer conn.Lock.Unlock()
|
|
fn()
|
|
}
|
|
|
|
func (conn *SSHConn) connectInternal(ctx context.Context) error {
|
|
client, err := remote.ConnectToClient(ctx, conn.Opts) //todo specify or remove opts
|
|
if err != nil {
|
|
return err
|
|
}
|
|
conn.WithLock(func() {
|
|
conn.Client = client
|
|
})
|
|
err = conn.OpenDomainSocketListener()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
installErr := conn.checkAndInstallWsh(ctx)
|
|
if installErr != nil {
|
|
return fmt.Errorf("conncontroller %s wsh install error: %v", conn.GetName(), installErr)
|
|
}
|
|
csErr := conn.StartConnServer()
|
|
if csErr != nil {
|
|
return fmt.Errorf("conncontroller %s start wsh connserver error: %v", conn.GetName(), csErr)
|
|
}
|
|
conn.HasWaiter.Store(true)
|
|
go conn.waitForDisconnect()
|
|
return nil
|
|
}
|
|
|
|
func (conn *SSHConn) waitForDisconnect() {
|
|
defer conn.FireConnChangeEvent()
|
|
defer conn.HasWaiter.Store(false)
|
|
client := conn.GetClient()
|
|
if client == nil {
|
|
return
|
|
}
|
|
err := client.Wait()
|
|
conn.WithLock(func() {
|
|
if err != nil {
|
|
if conn.Status != Status_Disconnected {
|
|
// don't set the error if our status is disconnected (because this error was caused by an explicit close)
|
|
conn.Status = Status_Error
|
|
conn.Error = err.Error()
|
|
}
|
|
} else {
|
|
// not sure if this is possible, because I think Wait() always returns an error (although that's not in the docs)
|
|
conn.Status = Status_Disconnected
|
|
}
|
|
conn.close_nolock()
|
|
})
|
|
}
|
|
|
|
func getConnInternal(opts *remote.SSHOpts) *SSHConn {
|
|
globalLock.Lock()
|
|
defer globalLock.Unlock()
|
|
rtn := clientControllerMap[*opts]
|
|
if rtn == nil {
|
|
rtn = &SSHConn{Lock: &sync.Mutex{}, Status: Status_Init, Opts: opts, HasWaiter: &atomic.Bool{}}
|
|
clientControllerMap[*opts] = rtn
|
|
}
|
|
return rtn
|
|
}
|
|
|
|
func GetConn(ctx context.Context, opts *remote.SSHOpts, shouldConnect bool) *SSHConn {
|
|
conn := getConnInternal(opts)
|
|
if conn.Client == nil && shouldConnect {
|
|
conn.Connect(ctx)
|
|
}
|
|
return conn
|
|
}
|
|
|
|
func DisconnectClient(opts *remote.SSHOpts) error {
|
|
conn := getConnInternal(opts)
|
|
if conn == nil {
|
|
return fmt.Errorf("client %q not found", opts.String())
|
|
}
|
|
err := conn.Close()
|
|
return err
|
|
}
|
|
|
|
func resolveSshConfigPatterns(configFiles []string) ([]string, error) {
|
|
// using two separate containers to track order and have O(1) lookups
|
|
// since go does not have an ordered map primitive
|
|
var discoveredPatterns []string
|
|
alreadyUsed := make(map[string]bool)
|
|
alreadyUsed[""] = true // this excludes the empty string from potential alias
|
|
var openedFiles []fs.File
|
|
|
|
defer func() {
|
|
for _, openedFile := range openedFiles {
|
|
openedFile.Close()
|
|
}
|
|
}()
|
|
|
|
var errs []error
|
|
for _, configFile := range configFiles {
|
|
fd, openErr := os.Open(configFile)
|
|
openedFiles = append(openedFiles, fd)
|
|
if fd == nil {
|
|
errs = append(errs, openErr)
|
|
continue
|
|
}
|
|
|
|
cfg, _ := ssh_config.Decode(fd)
|
|
for _, host := range cfg.Hosts {
|
|
// for each host, find the first good alias
|
|
for _, hostPattern := range host.Patterns {
|
|
hostPatternStr := hostPattern.String()
|
|
if !strings.Contains(hostPatternStr, "*") || alreadyUsed[hostPatternStr] {
|
|
discoveredPatterns = append(discoveredPatterns, hostPatternStr)
|
|
alreadyUsed[hostPatternStr] = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if len(errs) == len(configFiles) {
|
|
errs = append([]error{fmt.Errorf("no ssh config files could be opened:\n")}, errs...)
|
|
return nil, errors.Join(errs...)
|
|
}
|
|
if len(discoveredPatterns) == 0 {
|
|
return nil, fmt.Errorf("no compatible hostnames found in ssh config files")
|
|
}
|
|
|
|
return discoveredPatterns, nil
|
|
}
|
|
|
|
func GetConnectionsFromConfig() ([]string, error) {
|
|
home := wavebase.GetHomeDir()
|
|
localConfig := filepath.Join(home, ".ssh", "config")
|
|
systemConfig := filepath.Join("/etc", "ssh", "config")
|
|
sshConfigFiles := []string{localConfig, systemConfig}
|
|
ssh_config.ReloadConfigs()
|
|
|
|
return resolveSshConfigPatterns(sshConfigFiles)
|
|
}
|