mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-22 21:42:49 +01:00
936d4bfb30
This migrates all remaining eventbus events sent over the websocket to use the wps interface. WPS is more flexible for registering events and callbacks and provides support for more reliable unsubscribes and resubscribes.
614 lines
17 KiB
Go
614 lines
17 KiB
Go
// Copyright 2024, Command Line Inc.
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package conncontroller
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/kevinburke/ssh_config"
|
|
"github.com/skeema/knownhosts"
|
|
"github.com/wavetermdev/waveterm/pkg/remote"
|
|
"github.com/wavetermdev/waveterm/pkg/userinput"
|
|
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
|
|
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
|
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
|
"github.com/wavetermdev/waveterm/pkg/wps"
|
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
|
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// Connection lifecycle states stored in SSHConn.Status and reported in the
// ConnStatus payload of connection-change events.
const (
	// Status_Init is the state of a newly created SSHConn (see getConnInternal).
	Status_Init = "init"
	// Status_Connecting is set by Connect while connectInternal runs.
	Status_Connecting = "connecting"
	// Status_Connected is set by Connect after connectInternal succeeds.
	Status_Connected = "connected"
	// Status_Disconnected is set by Close and by waitForDisconnect.
	Status_Disconnected = "disconnected"
	// Status_Error is set by Connect when connectInternal fails.
	Status_Error = "error"
)

// DefaultConnectionTimeout is a 60 second timeout; it is not referenced in
// this file (presumably used by callers when establishing connections).
const DefaultConnectionTimeout = 60 * time.Second
|
|
|
|
var globalLock = &sync.Mutex{}
|
|
var clientControllerMap = make(map[remote.SSHOpts]*SSHConn)
|
|
var activeConnCounter = &atomic.Int32{}
|
|
|
|
type SSHConn struct {
|
|
Lock *sync.Mutex
|
|
Status string
|
|
Opts *remote.SSHOpts
|
|
Client *ssh.Client
|
|
SockName string
|
|
DomainSockListener net.Listener
|
|
ConnController *ssh.Session
|
|
Error string
|
|
HasWaiter *atomic.Bool
|
|
LastConnectTime int64
|
|
ActiveConnNum int
|
|
}
|
|
|
|
func GetAllConnStatus() []wshrpc.ConnStatus {
|
|
globalLock.Lock()
|
|
defer globalLock.Unlock()
|
|
|
|
var connStatuses []wshrpc.ConnStatus
|
|
for _, conn := range clientControllerMap {
|
|
connStatuses = append(connStatuses, conn.DeriveConnStatus())
|
|
}
|
|
return connStatuses
|
|
}
|
|
|
|
func (conn *SSHConn) DeriveConnStatus() wshrpc.ConnStatus {
|
|
conn.Lock.Lock()
|
|
defer conn.Lock.Unlock()
|
|
return wshrpc.ConnStatus{
|
|
Status: conn.Status,
|
|
Connected: conn.Status == Status_Connected,
|
|
Connection: conn.Opts.String(),
|
|
HasConnected: (conn.LastConnectTime > 0),
|
|
ActiveConnNum: conn.ActiveConnNum,
|
|
Error: conn.Error,
|
|
}
|
|
}
|
|
|
|
func (conn *SSHConn) FireConnChangeEvent() {
|
|
status := conn.DeriveConnStatus()
|
|
event := wps.WaveEvent{
|
|
Event: wps.Event_ConnChange,
|
|
Scopes: []string{
|
|
fmt.Sprintf("connection:%s", conn.GetName()),
|
|
},
|
|
Data: status,
|
|
}
|
|
log.Printf("sending event: %+#v", event)
|
|
wps.Broker.Publish(event)
|
|
}
|
|
|
|
func (conn *SSHConn) Close() error {
|
|
defer conn.FireConnChangeEvent()
|
|
conn.WithLock(func() {
|
|
if conn.Status == Status_Connected || conn.Status == Status_Connecting {
|
|
// if status is init, disconnected, or error don't change it
|
|
conn.Status = Status_Disconnected
|
|
}
|
|
conn.close_nolock()
|
|
})
|
|
// we must wait for the waiter to complete
|
|
startTime := time.Now()
|
|
for conn.HasWaiter.Load() {
|
|
time.Sleep(10 * time.Millisecond)
|
|
if time.Since(startTime) > 2*time.Second {
|
|
return fmt.Errorf("timeout waiting for waiter to complete")
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (conn *SSHConn) close_nolock() {
|
|
// does not set status (that should happen at another level)
|
|
if conn.DomainSockListener != nil {
|
|
conn.DomainSockListener.Close()
|
|
conn.DomainSockListener = nil
|
|
}
|
|
if conn.ConnController != nil {
|
|
conn.ConnController.Close()
|
|
conn.ConnController = nil
|
|
}
|
|
if conn.Client != nil {
|
|
conn.Client.Close()
|
|
conn.Client = nil
|
|
}
|
|
}
|
|
|
|
func (conn *SSHConn) GetDomainSocketName() string {
|
|
conn.Lock.Lock()
|
|
defer conn.Lock.Unlock()
|
|
return conn.SockName
|
|
}
|
|
|
|
func (conn *SSHConn) GetStatus() string {
|
|
conn.Lock.Lock()
|
|
defer conn.Lock.Unlock()
|
|
return conn.Status
|
|
}
|
|
|
|
func (conn *SSHConn) GetName() string {
|
|
// no lock required because opts is immutable
|
|
return conn.Opts.String()
|
|
}
|
|
|
|
func (conn *SSHConn) OpenDomainSocketListener() error {
|
|
var allowed bool
|
|
conn.WithLock(func() {
|
|
if conn.Status != Status_Connecting {
|
|
allowed = false
|
|
} else {
|
|
allowed = true
|
|
}
|
|
})
|
|
if !allowed {
|
|
return fmt.Errorf("cannot open domain socket for %q when status is %q", conn.GetName(), conn.GetStatus())
|
|
}
|
|
client := conn.GetClient()
|
|
randStr, err := utilfn.RandomHexString(16) // 64-bits of randomness
|
|
if err != nil {
|
|
return fmt.Errorf("error generating random string: %w", err)
|
|
}
|
|
sockName := fmt.Sprintf("/tmp/waveterm-%s.sock", randStr)
|
|
log.Printf("remote domain socket %s %q\n", conn.GetName(), conn.GetDomainSocketName())
|
|
listener, err := client.ListenUnix(sockName)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to request connection domain socket: %v", err)
|
|
}
|
|
conn.WithLock(func() {
|
|
conn.SockName = sockName
|
|
conn.DomainSockListener = listener
|
|
})
|
|
go func() {
|
|
defer conn.WithLock(func() {
|
|
conn.DomainSockListener = nil
|
|
conn.SockName = ""
|
|
})
|
|
wshutil.RunWshRpcOverListener(listener)
|
|
}()
|
|
return nil
|
|
}
|
|
|
|
func (conn *SSHConn) StartConnServer() error {
|
|
var allowed bool
|
|
conn.WithLock(func() {
|
|
if conn.Status != Status_Connecting {
|
|
allowed = false
|
|
} else {
|
|
allowed = true
|
|
}
|
|
})
|
|
if !allowed {
|
|
return fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus())
|
|
}
|
|
client := conn.GetClient()
|
|
wshPath := remote.GetWshPath(client)
|
|
rpcCtx := wshrpc.RpcContext{
|
|
ClientType: wshrpc.ClientType_ConnServer,
|
|
Conn: conn.GetName(),
|
|
}
|
|
sockName := conn.GetDomainSocketName()
|
|
jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to create jwt token for conn controller: %w", err)
|
|
}
|
|
sshSession, err := client.NewSession()
|
|
if err != nil {
|
|
return fmt.Errorf("unable to create ssh session for conn controller: %w", err)
|
|
}
|
|
pipeRead, pipeWrite := io.Pipe()
|
|
sshSession.Stdout = pipeWrite
|
|
sshSession.Stderr = pipeWrite
|
|
shellPath, err := remote.DetectShell(client)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var cmdStr string
|
|
if remote.IsPowershell(shellPath) {
|
|
cmdStr = fmt.Sprintf("$env:%s=\"%s\"; %s connserver", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
|
|
} else {
|
|
cmdStr = fmt.Sprintf("%s=\"%s\" %s connserver", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
|
|
}
|
|
log.Printf("starting conn controller: %s\n", cmdStr)
|
|
err = sshSession.Start(cmdStr)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to start conn controller: %w", err)
|
|
}
|
|
conn.WithLock(func() {
|
|
conn.ConnController = sshSession
|
|
})
|
|
// service the I/O
|
|
go func() {
|
|
// wait for termination, clear the controller
|
|
defer conn.WithLock(func() {
|
|
conn.ConnController = nil
|
|
})
|
|
waitErr := sshSession.Wait()
|
|
log.Printf("conn controller (%q) terminated: %v", conn.GetName(), waitErr)
|
|
}()
|
|
go func() {
|
|
readErr := wshutil.StreamToLines(pipeRead, func(line []byte) {
|
|
lineStr := string(line)
|
|
if !strings.HasSuffix(lineStr, "\n") {
|
|
lineStr += "\n"
|
|
}
|
|
log.Printf("[conncontroller:%s:output] %s", conn.GetName(), lineStr)
|
|
})
|
|
if readErr != nil && readErr != io.EOF {
|
|
log.Printf("[conncontroller:%s] error reading output: %v\n", conn.GetName(), readErr)
|
|
}
|
|
}()
|
|
regCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancelFn()
|
|
err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(rpcCtx.Conn))
|
|
if err != nil {
|
|
return fmt.Errorf("timeout waiting for connserver to register")
|
|
}
|
|
time.Sleep(300 * time.Millisecond) // TODO remove this sleep (but we need to wait until connserver is "ready")
|
|
return nil
|
|
}
|
|
|
|
// WshInstallOpts tunes CheckAndInstallWsh; the zero value means "install only
// if needed, after asking the user".
type WshInstallOpts struct {
	// Force installs wsh even when the remote version already matches.
	Force bool
	// NoUserPrompt skips the confirmation dialog.
	NoUserPrompt bool
}
|
|
|
|
func (conn *SSHConn) CheckAndInstallWsh(ctx context.Context, clientDisplayName string, opts *WshInstallOpts) error {
|
|
if opts == nil {
|
|
opts = &WshInstallOpts{}
|
|
}
|
|
client := conn.GetClient()
|
|
if client == nil {
|
|
return fmt.Errorf("client is nil")
|
|
}
|
|
// check that correct wsh extensions are installed
|
|
expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
|
|
clientVersion, err := remote.GetWshVersion(client)
|
|
if err == nil && clientVersion == expectedVersion && !opts.Force {
|
|
return nil
|
|
}
|
|
var queryText string
|
|
var title string
|
|
if opts.Force {
|
|
queryText = fmt.Sprintf("ReInstalling Wave Shell Extensions (%s) on `%s`\n", wavebase.WaveVersion, clientDisplayName)
|
|
title = "Install Wave Shell Extensions"
|
|
} else if err != nil {
|
|
queryText = fmt.Sprintf("Wave requires Wave Shell Extensions to be \n"+
|
|
"installed on `%s` \n"+
|
|
"to ensure a seamless experience. \n\n"+
|
|
"Would you like to install them?", clientDisplayName)
|
|
title = "Install Wave Shell Extensions"
|
|
} else {
|
|
queryText = fmt.Sprintf("Wave requires the Wave Shell Extensions \n"+
|
|
"installed on `%s` \n"+
|
|
"to be updated from %s to %s. \n\n"+
|
|
"Would you like to update?", clientDisplayName, clientVersion, expectedVersion)
|
|
title = "Update Wave Shell Extensions"
|
|
}
|
|
if !opts.NoUserPrompt {
|
|
request := &userinput.UserInputRequest{
|
|
ResponseType: "confirm",
|
|
QueryText: queryText,
|
|
Title: title,
|
|
Markdown: true,
|
|
CheckBoxMsg: "Don't show me this again",
|
|
}
|
|
response, err := userinput.GetUserInput(ctx, request)
|
|
if err != nil || !response.Confirm {
|
|
return err
|
|
}
|
|
}
|
|
log.Printf("attempting to install wsh to `%s`", clientDisplayName)
|
|
clientOs, err := remote.GetClientOs(client)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
clientArch, err := remote.GetClientArch(client)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// attempt to install extension
|
|
wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
|
|
err = remote.CpHostToRemote(client, wshLocalPath, "~/.waveterm/bin/wsh")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
log.Printf("successfully installed wsh on %s\n", conn.GetName())
|
|
return nil
|
|
}
|
|
|
|
func (conn *SSHConn) GetClient() *ssh.Client {
|
|
conn.Lock.Lock()
|
|
defer conn.Lock.Unlock()
|
|
return conn.Client
|
|
}
|
|
|
|
func (conn *SSHConn) Reconnect(ctx context.Context) error {
|
|
err := conn.Close()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return conn.Connect(ctx)
|
|
}
|
|
|
|
func (conn *SSHConn) WaitForConnect(ctx context.Context) error {
|
|
for {
|
|
status := conn.DeriveConnStatus()
|
|
if status.Status == Status_Connected {
|
|
return nil
|
|
}
|
|
if status.Status == Status_Connecting {
|
|
select {
|
|
case <-ctx.Done():
|
|
return fmt.Errorf("context timeout")
|
|
case <-time.After(100 * time.Millisecond):
|
|
continue
|
|
}
|
|
}
|
|
if status.Status == Status_Init || status.Status == Status_Disconnected {
|
|
return fmt.Errorf("disconnected")
|
|
}
|
|
if status.Status == Status_Error {
|
|
return fmt.Errorf("error: %v", status.Error)
|
|
}
|
|
return fmt.Errorf("unknown status: %q", status.Status)
|
|
}
|
|
}
|
|
|
|
// does not return an error since that error is stored inside of SSHConn
|
|
func (conn *SSHConn) Connect(ctx context.Context) error {
|
|
var connectAllowed bool
|
|
conn.WithLock(func() {
|
|
if conn.Status == Status_Connecting || conn.Status == Status_Connected {
|
|
connectAllowed = false
|
|
} else {
|
|
conn.Status = Status_Connecting
|
|
conn.Error = ""
|
|
connectAllowed = true
|
|
}
|
|
})
|
|
log.Printf("Connect %s\n", conn.GetName())
|
|
if !connectAllowed {
|
|
return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus())
|
|
}
|
|
conn.FireConnChangeEvent()
|
|
err := conn.connectInternal(ctx)
|
|
conn.WithLock(func() {
|
|
if err != nil {
|
|
conn.Status = Status_Error
|
|
conn.Error = err.Error()
|
|
conn.close_nolock()
|
|
} else {
|
|
conn.Status = Status_Connected
|
|
conn.LastConnectTime = time.Now().UnixMilli()
|
|
if conn.ActiveConnNum == 0 {
|
|
conn.ActiveConnNum = int(activeConnCounter.Add(1))
|
|
}
|
|
}
|
|
})
|
|
conn.FireConnChangeEvent()
|
|
return err
|
|
}
|
|
|
|
func (conn *SSHConn) WithLock(fn func()) {
|
|
conn.Lock.Lock()
|
|
defer conn.Lock.Unlock()
|
|
fn()
|
|
}
|
|
|
|
func (conn *SSHConn) connectInternal(ctx context.Context) error {
|
|
client, err := remote.ConnectToClient(ctx, conn.Opts) //todo specify or remove opts
|
|
if err != nil {
|
|
return err
|
|
}
|
|
fmtAddr := knownhosts.Normalize(fmt.Sprintf("%s@%s", client.User(), client.RemoteAddr().String()))
|
|
clientDisplayName := fmt.Sprintf("%s (%s)", conn.GetName(), fmtAddr)
|
|
conn.WithLock(func() {
|
|
conn.Client = client
|
|
})
|
|
err = conn.OpenDomainSocketListener()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
installErr := conn.CheckAndInstallWsh(ctx, clientDisplayName, nil)
|
|
if installErr != nil {
|
|
return fmt.Errorf("conncontroller %s wsh install error: %v", conn.GetName(), installErr)
|
|
}
|
|
csErr := conn.StartConnServer()
|
|
if csErr != nil {
|
|
return fmt.Errorf("conncontroller %s start wsh connserver error: %v", conn.GetName(), csErr)
|
|
}
|
|
conn.HasWaiter.Store(true)
|
|
go conn.waitForDisconnect()
|
|
return nil
|
|
}
|
|
|
|
func (conn *SSHConn) waitForDisconnect() {
|
|
defer conn.FireConnChangeEvent()
|
|
defer conn.HasWaiter.Store(false)
|
|
client := conn.GetClient()
|
|
if client == nil {
|
|
return
|
|
}
|
|
err := client.Wait()
|
|
conn.WithLock(func() {
|
|
// disconnects happen for a variety of reasons (like network, etc. and are typically transient)
|
|
// so we just set the status to "disconnected" here (not error)
|
|
// don't overwrite any existing error (or error status)
|
|
if err != nil && conn.Error == "" {
|
|
conn.Error = err.Error()
|
|
}
|
|
if conn.Status != Status_Error {
|
|
conn.Status = Status_Disconnected
|
|
}
|
|
conn.close_nolock()
|
|
})
|
|
}
|
|
|
|
func getConnInternal(opts *remote.SSHOpts) *SSHConn {
|
|
globalLock.Lock()
|
|
defer globalLock.Unlock()
|
|
rtn := clientControllerMap[*opts]
|
|
if rtn == nil {
|
|
rtn = &SSHConn{Lock: &sync.Mutex{}, Status: Status_Init, Opts: opts, HasWaiter: &atomic.Bool{}}
|
|
clientControllerMap[*opts] = rtn
|
|
}
|
|
return rtn
|
|
}
|
|
|
|
func GetConn(ctx context.Context, opts *remote.SSHOpts, shouldConnect bool) *SSHConn {
|
|
conn := getConnInternal(opts)
|
|
if conn.Client == nil && shouldConnect {
|
|
conn.Connect(ctx)
|
|
}
|
|
return conn
|
|
}
|
|
|
|
// Convenience function for ensuring a connection is established
|
|
func EnsureConnection(ctx context.Context, connName string) error {
|
|
if connName == "" {
|
|
return nil
|
|
}
|
|
connOpts, err := remote.ParseOpts(connName)
|
|
if err != nil {
|
|
return fmt.Errorf("error parsing connection name: %w", err)
|
|
}
|
|
conn := GetConn(ctx, connOpts, false)
|
|
if conn == nil {
|
|
return fmt.Errorf("connection not found: %s", connName)
|
|
}
|
|
connStatus := conn.DeriveConnStatus()
|
|
switch connStatus.Status {
|
|
case Status_Connected:
|
|
return nil
|
|
case Status_Connecting:
|
|
return conn.WaitForConnect(ctx)
|
|
case Status_Init, Status_Disconnected:
|
|
return conn.Connect(ctx)
|
|
case Status_Error:
|
|
return fmt.Errorf("connection error: %s", connStatus.Error)
|
|
default:
|
|
return fmt.Errorf("unknown connection status %q", connStatus.Status)
|
|
}
|
|
}
|
|
|
|
func DisconnectClient(opts *remote.SSHOpts) error {
|
|
conn := getConnInternal(opts)
|
|
if conn == nil {
|
|
return fmt.Errorf("client %q not found", opts.String())
|
|
}
|
|
err := conn.Close()
|
|
return err
|
|
}
|
|
|
|
func resolveSshConfigPatterns(configFiles []string) ([]string, error) {
|
|
// using two separate containers to track order and have O(1) lookups
|
|
// since go does not have an ordered map primitive
|
|
var discoveredPatterns []string
|
|
alreadyUsed := make(map[string]bool)
|
|
alreadyUsed[""] = true // this excludes the empty string from potential alias
|
|
var openedFiles []fs.File
|
|
|
|
defer func() {
|
|
for _, openedFile := range openedFiles {
|
|
openedFile.Close()
|
|
}
|
|
}()
|
|
|
|
var errs []error
|
|
for _, configFile := range configFiles {
|
|
fd, openErr := os.Open(configFile)
|
|
openedFiles = append(openedFiles, fd)
|
|
if fd == nil {
|
|
errs = append(errs, openErr)
|
|
continue
|
|
}
|
|
|
|
cfg, _ := ssh_config.Decode(fd)
|
|
for _, host := range cfg.Hosts {
|
|
// for each host, find the first good alias
|
|
for _, hostPattern := range host.Patterns {
|
|
hostPatternStr := hostPattern.String()
|
|
normalized := remote.NormalizeConfigPattern(hostPatternStr)
|
|
if (!strings.Contains(hostPatternStr, "*") && !strings.Contains(hostPatternStr, "?") && !strings.Contains(hostPatternStr, "!")) || alreadyUsed[normalized] {
|
|
discoveredPatterns = append(discoveredPatterns, normalized)
|
|
alreadyUsed[normalized] = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if len(errs) == len(configFiles) {
|
|
errs = append([]error{fmt.Errorf("no ssh config files could be opened: ")}, errs...)
|
|
return nil, errors.Join(errs...)
|
|
}
|
|
if len(discoveredPatterns) == 0 {
|
|
return nil, fmt.Errorf("no compatible hostnames found in ssh config files")
|
|
}
|
|
|
|
return discoveredPatterns, nil
|
|
}
|
|
|
|
func GetConnectionsList() ([]string, error) {
|
|
existing := GetAllConnStatus()
|
|
var currentlyRunning []string
|
|
var hasConnected []string
|
|
|
|
// populate all lists
|
|
for _, stat := range existing {
|
|
if stat.Connected {
|
|
currentlyRunning = append(currentlyRunning, stat.Connection)
|
|
}
|
|
|
|
if stat.HasConnected {
|
|
hasConnected = append(hasConnected, stat.Connection)
|
|
}
|
|
}
|
|
fromConfig, err := GetConnectionsFromConfig()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// sort into one final list and remove duplicates
|
|
alreadyUsed := make(map[string]struct{})
|
|
var connList []string
|
|
|
|
for _, subList := range [][]string{currentlyRunning, hasConnected, fromConfig} {
|
|
for _, pattern := range subList {
|
|
if _, used := alreadyUsed[pattern]; !used {
|
|
connList = append(connList, pattern)
|
|
alreadyUsed[pattern] = struct{}{}
|
|
}
|
|
}
|
|
}
|
|
|
|
return connList, nil
|
|
}
|
|
|
|
func GetConnectionsFromConfig() ([]string, error) {
|
|
home := wavebase.GetHomeDir()
|
|
localConfig := filepath.Join(home, ".ssh", "config")
|
|
systemConfig := filepath.Join("/etc", "ssh", "config")
|
|
sshConfigFiles := []string{localConfig, systemConfig}
|
|
ssh_config.ReloadConfigs()
|
|
|
|
return resolveSshConfigPatterns(sshConfigFiles)
|
|
}
|