mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-02-25 03:12:07 +01:00
WSL Updates for New Architecture (#1756)
This adapts most of the WSL code to follow the new architecture that ssh uses. --------- Co-authored-by: sawka <mike@commandline.dev>
This commit is contained in:
parent
b7dca41b9c
commit
ff5f26709c
cmd/server
pkg
blockcontroller
genconn
remote/conncontroller
service/clientservice
shellexec
util
wshrpc/wshserver
wshutil
wsl
wslconn
@ -35,7 +35,7 @@ import (
|
|||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshremote"
|
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshremote"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshserver"
|
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshserver"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wsl"
|
"github.com/wavetermdev/waveterm/pkg/wslconn"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wstore"
|
"github.com/wavetermdev/waveterm/pkg/wstore"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -145,7 +145,7 @@ func beforeSendActivityUpdate(ctx context.Context) {
|
|||||||
activity.Blocks, _ = wstore.DBGetBlockViewCounts(ctx)
|
activity.Blocks, _ = wstore.DBGetBlockViewCounts(ctx)
|
||||||
activity.NumWindows, _ = wstore.DBGetCount[*waveobj.Window](ctx)
|
activity.NumWindows, _ = wstore.DBGetCount[*waveobj.Window](ctx)
|
||||||
activity.NumSSHConn = conncontroller.GetNumSSHHasConnected()
|
activity.NumSSHConn = conncontroller.GetNumSSHHasConnected()
|
||||||
activity.NumWSLConn = wsl.GetNumWSLHasConnected()
|
activity.NumWSLConn = wslconn.GetNumWSLHasConnected()
|
||||||
activity.NumWSNamed, activity.NumWS, _ = wstore.DBGetWSCounts(ctx)
|
activity.NumWSNamed, activity.NumWS, _ = wstore.DBGetWSCounts(ctx)
|
||||||
err := telemetry.UpdateActivity(ctx, activity)
|
err := telemetry.UpdateActivity(ctx, activity)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -33,7 +33,7 @@ import (
|
|||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
|
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wsl"
|
"github.com/wavetermdev/waveterm/pkg/wslconn"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wstore"
|
"github.com/wavetermdev/waveterm/pkg/wstore"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -369,7 +369,7 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc
|
|||||||
credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
|
credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
|
||||||
defer cancelFunc()
|
defer cancelFunc()
|
||||||
|
|
||||||
wslConn := wsl.GetWslConn(credentialCtx, wslName, false)
|
wslConn := wslconn.GetWslConn(credentialCtx, wslName, false)
|
||||||
connStatus := wslConn.DeriveConnStatus()
|
connStatus := wslConn.DeriveConnStatus()
|
||||||
if connStatus.Status != conncontroller.Status_Connected {
|
if connStatus.Status != conncontroller.Status_Connected {
|
||||||
return nil, fmt.Errorf("not connected, cannot start shellproc")
|
return nil, fmt.Errorf("not connected, cannot start shellproc")
|
||||||
@ -377,10 +377,14 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc
|
|||||||
|
|
||||||
// create jwt
|
// create jwt
|
||||||
if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) {
|
if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) {
|
||||||
jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: wslConn.GetName()}, wslConn.GetDomainSocketName())
|
sockName := wslConn.GetDomainSocketName()
|
||||||
|
rpcContext := wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: wslConn.GetName()}
|
||||||
|
jwtStr, err := wshutil.MakeClientJWTToken(rpcContext, sockName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error making jwt token: %w", err)
|
return nil, fmt.Errorf("error making jwt token: %w", err)
|
||||||
}
|
}
|
||||||
|
swapToken.SockName = sockName
|
||||||
|
swapToken.RpcContext = &rpcContext
|
||||||
swapToken.Env[wshutil.WaveJwtTokenVarName] = jwtStr
|
swapToken.Env[wshutil.WaveJwtTokenVarName] = jwtStr
|
||||||
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
|
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
|
||||||
}
|
}
|
||||||
@ -747,7 +751,7 @@ func CheckConnStatus(blockId string) error {
|
|||||||
}
|
}
|
||||||
if strings.HasPrefix(connName, "wsl://") {
|
if strings.HasPrefix(connName, "wsl://") {
|
||||||
distroName := strings.TrimPrefix(connName, "wsl://")
|
distroName := strings.TrimPrefix(connName, "wsl://")
|
||||||
conn := wsl.GetWslConn(context.Background(), distroName, false)
|
conn := wslconn.GetWslConn(context.Background(), distroName, false)
|
||||||
connStatus := conn.DeriveConnStatus()
|
connStatus := conn.DeriveConnStatus()
|
||||||
if connStatus.Status != conncontroller.Status_Connected {
|
if connStatus.Status != conncontroller.Status_Connected {
|
||||||
return fmt.Errorf("not connected: %s", connStatus.Status)
|
return fmt.Errorf("not connected: %s", connStatus.Status)
|
||||||
|
@ -1,25 +1,24 @@
|
|||||||
//go:build windows
|
|
||||||
|
|
||||||
// Copyright 2025, Command Line Inc.
|
// Copyright 2025, Command Line Inc.
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
package genconn
|
package genconn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/ubuntu/gowsl"
|
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ ShellClient = (*WSLShellClient)(nil)
|
var _ ShellClient = (*WSLShellClient)(nil)
|
||||||
|
|
||||||
type WSLShellClient struct {
|
type WSLShellClient struct {
|
||||||
distro *gowsl.Distro
|
distro *wsl.Distro
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeWSLShellClient(distro *gowsl.Distro) *WSLShellClient {
|
func MakeWSLShellClient(distro *wsl.Distro) *WSLShellClient {
|
||||||
return &WSLShellClient{distro: distro}
|
return &WSLShellClient{distro: distro}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -28,8 +27,8 @@ func (c *WSLShellClient) MakeProcessController(cmdSpec CommandSpec) (ShellProces
|
|||||||
}
|
}
|
||||||
|
|
||||||
type WSLProcessController struct {
|
type WSLProcessController struct {
|
||||||
distro *gowsl.Distro
|
distro *wsl.Distro
|
||||||
cmd *gowsl.Cmd
|
cmd *wsl.WslCmd
|
||||||
lock *sync.Mutex
|
lock *sync.Mutex
|
||||||
once *sync.Once
|
once *sync.Once
|
||||||
stdinPiped bool
|
stdinPiped bool
|
||||||
@ -40,13 +39,13 @@ type WSLProcessController struct {
|
|||||||
cmdSpec CommandSpec
|
cmdSpec CommandSpec
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeWSLProcessController(distro *gowsl.Distro, cmdSpec CommandSpec) (*WSLProcessController, error) {
|
func MakeWSLProcessController(distro *wsl.Distro, cmdSpec CommandSpec) (*WSLProcessController, error) {
|
||||||
fullCmd, err := BuildShellCommand(cmdSpec)
|
fullCmd, err := BuildShellCommand(cmdSpec)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to build shell command: %w", err)
|
return nil, fmt.Errorf("failed to build shell command: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := distro.Command(nil, fullCmd)
|
cmd := distro.WslCommand(context.Background(), fullCmd)
|
||||||
if cmd == nil {
|
if cmd == nil {
|
||||||
return nil, fmt.Errorf("failed to create WSL command")
|
return nil, fmt.Errorf("failed to create WSL command")
|
||||||
}
|
}
|
||||||
@ -87,9 +86,14 @@ func (w *WSLProcessController) Kill() {
|
|||||||
w.lock.Lock()
|
w.lock.Lock()
|
||||||
defer w.lock.Unlock()
|
defer w.lock.Unlock()
|
||||||
|
|
||||||
if w.cmd != nil && w.cmd.Process != nil {
|
if w.cmd == nil {
|
||||||
w.cmd.Process.Kill()
|
return
|
||||||
}
|
}
|
||||||
|
process := w.cmd.GetProcess()
|
||||||
|
if process == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
process.Kill()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WSLProcessController) StdinPipe() (io.WriteCloser, error) {
|
func (w *WSLProcessController) StdinPipe() (io.WriteCloser, error) {
|
||||||
|
@ -308,12 +308,13 @@ func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool) (boo
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return false, "", "", fmt.Errorf("unable to start conn controller command: %w", err)
|
return false, "", "", fmt.Errorf("unable to start conn controller command: %w", err)
|
||||||
}
|
}
|
||||||
linesChan := wshutil.StreamToLinesChan(pipeRead)
|
linesChan := utilfn.StreamToLinesChan(pipeRead)
|
||||||
versionLine, err := wshutil.ReadLineWithTimeout(linesChan, 2*time.Second)
|
versionLine, err := utilfn.ReadLineWithTimeout(linesChan, 2*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sshSession.Close()
|
sshSession.Close()
|
||||||
return false, "", "", fmt.Errorf("error reading wsh version: %w", err)
|
return false, "", "", fmt.Errorf("error reading wsh version: %w", err)
|
||||||
}
|
}
|
||||||
|
conn.Infof(ctx, "actual connnserverversion: %q\n", versionLine)
|
||||||
conn.Infof(ctx, "got connserver version: %s\n", strings.TrimSpace(versionLine))
|
conn.Infof(ctx, "got connserver version: %s\n", strings.TrimSpace(versionLine))
|
||||||
isUpToDate, clientVersion, osArchStr, err := IsWshVersionUpToDate(ctx, versionLine)
|
isUpToDate, clientVersion, osArchStr, err := IsWshVersionUpToDate(ctx, versionLine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -326,11 +327,10 @@ func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool) (boo
|
|||||||
}
|
}
|
||||||
conn.Infof(ctx, "connserver up-to-date: %v\n", isUpToDate)
|
conn.Infof(ctx, "connserver up-to-date: %v\n", isUpToDate)
|
||||||
if !isUpToDate {
|
if !isUpToDate {
|
||||||
|
|
||||||
sshSession.Close()
|
sshSession.Close()
|
||||||
return true, clientVersion, osArchStr, nil
|
return true, clientVersion, osArchStr, nil
|
||||||
}
|
}
|
||||||
jwtLine, err := wshutil.ReadLineWithTimeout(linesChan, 3*time.Second)
|
jwtLine, err := utilfn.ReadLineWithTimeout(linesChan, 3*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sshSession.Close()
|
sshSession.Close()
|
||||||
return false, clientVersion, "", fmt.Errorf("error reading jwt status line: %w", err)
|
return false, clientVersion, "", fmt.Errorf("error reading jwt status line: %w", err)
|
||||||
@ -401,12 +401,6 @@ type WshInstallOpts struct {
|
|||||||
NoUserPrompt bool
|
NoUserPrompt bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type WshInstallSkipError struct{}
|
|
||||||
|
|
||||||
func (wise *WshInstallSkipError) Error() string {
|
|
||||||
return "skipping wsh installation"
|
|
||||||
}
|
|
||||||
|
|
||||||
var queryTextTemplate = strings.TrimSpace(`
|
var queryTextTemplate = strings.TrimSpace(`
|
||||||
Wave requires Wave Shell Extensions to be
|
Wave requires Wave Shell Extensions to be
|
||||||
installed on %q
|
installed on %q
|
||||||
@ -555,7 +549,7 @@ func (conn *SSHConn) Connect(ctx context.Context, connFlags *wshrpc.ConnKeywords
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
if !connectAllowed {
|
if !connectAllowed {
|
||||||
conn.Infof(ctx, "cannot connect to when status is %q\n", conn.GetStatus())
|
conn.Infof(ctx, "cannot connect to %q when status is %q\n", conn.GetName(), conn.GetStatus())
|
||||||
return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus())
|
return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus())
|
||||||
}
|
}
|
||||||
conn.Infof(ctx, "trying to connect to %q...\n", conn.GetName())
|
conn.Infof(ctx, "trying to connect to %q...\n", conn.GetName())
|
||||||
@ -754,7 +748,12 @@ func (conn *SSHConn) connectInternal(ctx context.Context, connFlags *wshrpc.Conn
|
|||||||
conn.WithLock(func() {
|
conn.WithLock(func() {
|
||||||
conn.Client = client
|
conn.Client = client
|
||||||
})
|
})
|
||||||
go conn.waitForDisconnect()
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
panichandler.PanicHandler("conncontroller:waitForDisconnect", recover())
|
||||||
|
}()
|
||||||
|
conn.waitForDisconnect()
|
||||||
|
}()
|
||||||
fmtAddr := knownhosts.Normalize(fmt.Sprintf("%s@%s", client.User(), client.RemoteAddr().String()))
|
fmtAddr := knownhosts.Normalize(fmt.Sprintf("%s@%s", client.User(), client.RemoteAddr().String()))
|
||||||
conn.Infof(ctx, "normalized knownhosts address: %s\n", fmtAddr)
|
conn.Infof(ctx, "normalized knownhosts address: %s\n", fmtAddr)
|
||||||
clientDisplayName := fmt.Sprintf("%s (%s)", conn.GetName(), fmtAddr)
|
clientDisplayName := fmt.Sprintf("%s (%s)", conn.GetName(), fmtAddr)
|
||||||
|
@ -15,7 +15,7 @@ import (
|
|||||||
"github.com/wavetermdev/waveterm/pkg/wconfig"
|
"github.com/wavetermdev/waveterm/pkg/wconfig"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wcore"
|
"github.com/wavetermdev/waveterm/pkg/wcore"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wsl"
|
"github.com/wavetermdev/waveterm/pkg/wslconn"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wstore"
|
"github.com/wavetermdev/waveterm/pkg/wstore"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -42,7 +42,7 @@ func (cs *ClientService) GetTab(tabId string) (*waveobj.Tab, error) {
|
|||||||
|
|
||||||
func (cs *ClientService) GetAllConnStatus(ctx context.Context) ([]wshrpc.ConnStatus, error) {
|
func (cs *ClientService) GetAllConnStatus(ctx context.Context) ([]wshrpc.ConnStatus, error) {
|
||||||
sshStatuses := conncontroller.GetAllConnStatus()
|
sshStatuses := conncontroller.GetAllConnStatus()
|
||||||
wslStatuses := wsl.GetAllConnStatus()
|
wslStatuses := wslconn.GetAllConnStatus()
|
||||||
return append(sshStatuses, wslStatuses...), nil
|
return append(sshStatuses, wslStatuses...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -21,7 +21,6 @@ import (
|
|||||||
"github.com/creack/pty"
|
"github.com/creack/pty"
|
||||||
"github.com/wavetermdev/waveterm/pkg/blocklogger"
|
"github.com/wavetermdev/waveterm/pkg/blocklogger"
|
||||||
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
||||||
"github.com/wavetermdev/waveterm/pkg/remote"
|
|
||||||
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
|
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
|
||||||
"github.com/wavetermdev/waveterm/pkg/util/pamparse"
|
"github.com/wavetermdev/waveterm/pkg/util/pamparse"
|
||||||
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
|
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
|
||||||
@ -30,7 +29,7 @@ import (
|
|||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
|
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wsl"
|
"github.com/wavetermdev/waveterm/pkg/wslconn"
|
||||||
)
|
)
|
||||||
|
|
||||||
const DefaultGracefulKillWait = 400 * time.Millisecond
|
const DefaultGracefulKillWait = 400 * time.Millisecond
|
||||||
@ -151,85 +150,100 @@ func (pp *PipePty) WriteString(s string) (n int, err error) {
|
|||||||
return pp.Write([]byte(s))
|
return pp.Write([]byte(s))
|
||||||
}
|
}
|
||||||
|
|
||||||
func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wsl.WslConn) (*ShellProc, error) {
|
func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wslconn.WslConn) (*ShellProc, error) {
|
||||||
utilCtx, cancelFn := context.WithTimeout(ctx, 2*time.Second)
|
|
||||||
defer cancelFn()
|
|
||||||
client := conn.GetClient()
|
client := conn.GetClient()
|
||||||
shellPath := cmdOpts.ShellPath
|
conn.Infof(ctx, "WSL-NEWSESSION (StartWslShellProc)")
|
||||||
|
connRoute := wshutil.MakeConnectionRouteId(conn.GetName())
|
||||||
|
rpcClient := wshclient.GetBareRpcClient()
|
||||||
|
remoteInfo, err := wshclient.RemoteGetInfoCommand(rpcClient, &wshrpc.RpcOpts{Route: connRoute, Timeout: 2000})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to obtain client info: %w", err)
|
||||||
|
}
|
||||||
|
log.Printf("client info collected: %+#v", remoteInfo)
|
||||||
|
var shellPath string
|
||||||
|
if cmdOpts.ShellPath != "" {
|
||||||
|
conn.Infof(ctx, "using shell path from command opts: %s\n", cmdOpts.ShellPath)
|
||||||
|
shellPath = cmdOpts.ShellPath
|
||||||
|
}
|
||||||
|
configShellPath := conn.GetConfigShellPath()
|
||||||
|
if shellPath == "" && configShellPath != "" {
|
||||||
|
conn.Infof(ctx, "using shell path from config (conn:shellpath): %s\n", configShellPath)
|
||||||
|
shellPath = configShellPath
|
||||||
|
}
|
||||||
|
if shellPath == "" && remoteInfo.Shell != "" {
|
||||||
|
conn.Infof(ctx, "using shell path detected on remote machine: %s\n", remoteInfo.Shell)
|
||||||
|
shellPath = remoteInfo.Shell
|
||||||
|
}
|
||||||
if shellPath == "" {
|
if shellPath == "" {
|
||||||
remoteShellPath, err := wsl.DetectShell(utilCtx, client)
|
conn.Infof(ctx, "no shell path detected, using default (/bin/bash)\n")
|
||||||
if err != nil {
|
shellPath = "/bin/bash"
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
shellPath = remoteShellPath
|
|
||||||
}
|
}
|
||||||
var shellOpts []string
|
var shellOpts []string
|
||||||
log.Printf("detected shell: %s", shellPath)
|
var cmdCombined string
|
||||||
|
log.Printf("detected shell %q for conn %q\n", shellPath, conn.GetName())
|
||||||
|
|
||||||
err := wsl.InstallClientRcFiles(utilCtx, client)
|
err = wshclient.RemoteInstallRcFilesCommand(rpcClient, &wshrpc.RpcOpts{Route: connRoute, Timeout: 2000})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("error installing rc files: %v", err)
|
log.Printf("error installing rc files: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
shellOpts = append(shellOpts, cmdOpts.ShellOpts...)
|
||||||
homeDir := wsl.GetHomeDir(utilCtx, client)
|
shellType := shellutil.GetShellTypeFromShellPath(shellPath)
|
||||||
shellOpts = append(shellOpts, "~", "-d", client.Name())
|
conn.Infof(ctx, "detected shell type: %s\n", shellType)
|
||||||
|
|
||||||
var subShellOpts []string
|
|
||||||
|
|
||||||
if cmdStr == "" {
|
if cmdStr == "" {
|
||||||
/* transform command in order to inject environment vars */
|
/* transform command in order to inject environment vars */
|
||||||
if isBashShell(shellPath) {
|
if shellType == shellutil.ShellType_bash {
|
||||||
log.Printf("recognized as bash shell")
|
|
||||||
// add --rcfile
|
// add --rcfile
|
||||||
// cant set -l or -i with --rcfile
|
// cant set -l or -i with --rcfile
|
||||||
subShellOpts = append(subShellOpts, "--rcfile", fmt.Sprintf(`%s/.waveterm/%s/.bashrc`, homeDir, shellutil.BashIntegrationDir))
|
bashPath := fmt.Sprintf("~/.waveterm/%s/.bashrc", shellutil.BashIntegrationDir)
|
||||||
} else if isFishShell(shellPath) {
|
shellOpts = append(shellOpts, "--rcfile", bashPath)
|
||||||
carg := fmt.Sprintf(`"set -x PATH \"%s\"/.waveterm/%s $PATH"`, homeDir, shellutil.WaveHomeBinDir)
|
} else if shellType == shellutil.ShellType_fish {
|
||||||
subShellOpts = append(subShellOpts, "-C", carg)
|
if cmdOpts.Login {
|
||||||
} else if wsl.IsPowershell(shellPath) {
|
shellOpts = append(shellOpts, "-l")
|
||||||
|
}
|
||||||
|
// source the wave.fish file
|
||||||
|
waveFishPath := fmt.Sprintf("~/.waveterm/%s/wave.fish", shellutil.FishIntegrationDir)
|
||||||
|
carg := fmt.Sprintf(`"source %s"`, waveFishPath)
|
||||||
|
shellOpts = append(shellOpts, "-C", carg)
|
||||||
|
} else if shellType == shellutil.ShellType_pwsh {
|
||||||
|
pwshPath := fmt.Sprintf("~/.waveterm/%s/wavepwsh.ps1", shellutil.PwshIntegrationDir)
|
||||||
// powershell is weird about quoted path executables and requires an ampersand first
|
// powershell is weird about quoted path executables and requires an ampersand first
|
||||||
shellPath = "& " + shellPath
|
shellPath = "& " + shellPath
|
||||||
subShellOpts = append(subShellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", fmt.Sprintf("%s/.waveterm/%s/wavepwsh.ps1", homeDir, shellutil.PwshIntegrationDir))
|
shellOpts = append(shellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", pwshPath)
|
||||||
} else {
|
} else {
|
||||||
if cmdOpts.Login {
|
if cmdOpts.Login {
|
||||||
subShellOpts = append(subShellOpts, "-l")
|
shellOpts = append(shellOpts, "-l")
|
||||||
}
|
}
|
||||||
if cmdOpts.Interactive {
|
if cmdOpts.Interactive {
|
||||||
subShellOpts = append(subShellOpts, "-i")
|
shellOpts = append(shellOpts, "-i")
|
||||||
}
|
}
|
||||||
// can't set environment vars this way
|
// zdotdir setting moved to after session is created
|
||||||
// will try to do later if possible
|
|
||||||
}
|
}
|
||||||
|
cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
|
||||||
} else {
|
} else {
|
||||||
|
// TODO check quoting of cmdStr
|
||||||
shellPath = cmdStr
|
shellPath = cmdStr
|
||||||
if cmdOpts.Login {
|
shellOpts = append(shellOpts, "-c", cmdStr)
|
||||||
subShellOpts = append(subShellOpts, "-l")
|
cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
|
||||||
}
|
}
|
||||||
if cmdOpts.Interactive {
|
conn.Infof(ctx, "starting shell, using command: %s\n", cmdCombined)
|
||||||
subShellOpts = append(subShellOpts, "-i")
|
conn.Infof(ctx, "WSL-NEWSESSION (StartWslShellProc)\n")
|
||||||
}
|
|
||||||
subShellOpts = append(subShellOpts, "-c", cmdStr)
|
if shellType == shellutil.ShellType_zsh {
|
||||||
|
zshDir := fmt.Sprintf("~/.waveterm/%s", shellutil.ZshIntegrationDir)
|
||||||
|
conn.Infof(ctx, "setting ZDOTDIR to %s\n", zshDir)
|
||||||
|
cmdCombined = fmt.Sprintf(`ZDOTDIR=%s %s`, zshDir, cmdCombined)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtToken, ok := cmdOpts.Env[wshutil.WaveJwtTokenVarName]
|
jwtToken, ok := cmdOpts.Env[wshutil.WaveJwtTokenVarName]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("no jwt token provided to connection")
|
return nil, fmt.Errorf("no jwt token provided to connection")
|
||||||
}
|
}
|
||||||
if remote.IsPowershell(shellPath) {
|
cmdCombined = fmt.Sprintf(`%s=%s %s`, wshutil.WaveJwtTokenVarName, jwtToken, cmdCombined)
|
||||||
shellOpts = append(shellOpts, "--", fmt.Sprintf(`$env:%s=%s;`, wshutil.WaveJwtTokenVarName, jwtToken))
|
|
||||||
} else {
|
|
||||||
shellOpts = append(shellOpts, "--", fmt.Sprintf(`%s=%s`, wshutil.WaveJwtTokenVarName, jwtToken))
|
|
||||||
}
|
|
||||||
|
|
||||||
if isZshShell(shellPath) {
|
log.Printf("full combined command: %s", cmdCombined)
|
||||||
shellOpts = append(shellOpts, fmt.Sprintf(`ZDOTDIR=%s/.waveterm/%s`, homeDir, shellutil.ZshIntegrationDir))
|
ecmd := exec.Command("wsl.exe", "~", "-d", client.Name(), "--", "sh", "-c", cmdCombined)
|
||||||
}
|
|
||||||
shellOpts = append(shellOpts, shellPath)
|
|
||||||
shellOpts = append(shellOpts, subShellOpts...)
|
|
||||||
log.Printf("full cmd is: %s %s", "wsl.exe", strings.Join(shellOpts, " "))
|
|
||||||
|
|
||||||
ecmd := exec.Command("wsl.exe", shellOpts...)
|
|
||||||
if termSize.Rows == 0 || termSize.Cols == 0 {
|
if termSize.Rows == 0 || termSize.Cols == 0 {
|
||||||
termSize.Rows = shellutil.DefaultTermRows
|
termSize.Rows = shellutil.DefaultTermRows
|
||||||
termSize.Cols = shellutil.DefaultTermCols
|
termSize.Cols = shellutil.DefaultTermCols
|
||||||
@ -237,6 +251,7 @@ func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr st
|
|||||||
if termSize.Rows <= 0 || termSize.Cols <= 0 {
|
if termSize.Rows <= 0 || termSize.Cols <= 0 {
|
||||||
return nil, fmt.Errorf("invalid term size: %v", termSize)
|
return nil, fmt.Errorf("invalid term size: %v", termSize)
|
||||||
}
|
}
|
||||||
|
shellutil.AddTokenSwapEntry(cmdOpts.SwapToken)
|
||||||
cmdPty, err := pty.StartWithSize(ecmd, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)})
|
cmdPty, err := pty.StartWithSize(ecmd, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -8,6 +8,9 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
||||||
)
|
)
|
||||||
|
|
||||||
type PacketParser struct {
|
type PacketParser struct {
|
||||||
@ -15,11 +18,38 @@ type PacketParser struct {
|
|||||||
Ch chan []byte
|
Ch chan []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ParseWithLinesChan(input chan utilfn.LineOutput, packetCh chan []byte, rawCh chan []byte) {
|
||||||
|
defer close(packetCh)
|
||||||
|
defer close(rawCh)
|
||||||
|
for {
|
||||||
|
// note this line doesn't have a trailing newline
|
||||||
|
line, ok := <-input
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if line.Error != nil {
|
||||||
|
log.Printf("ParseWithLinesChan: error reading line: %v", line.Error)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(line.Line) <= 1 {
|
||||||
|
// just a blank line
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if bytes.HasPrefix([]byte(line.Line), []byte{'#', '#', 'N', '{'}) && bytes.HasSuffix([]byte(line.Line), []byte{'}'}) {
|
||||||
|
// strip off the leading "##"
|
||||||
|
packetCh <- []byte(line.Line[3:len(line.Line)])
|
||||||
|
} else {
|
||||||
|
rawCh <- []byte(line.Line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Parse(input io.Reader, packetCh chan []byte, rawCh chan []byte) error {
|
func Parse(input io.Reader, packetCh chan []byte, rawCh chan []byte) error {
|
||||||
bufReader := bufio.NewReader(input)
|
bufReader := bufio.NewReader(input)
|
||||||
defer close(packetCh)
|
defer close(packetCh)
|
||||||
defer close(rawCh)
|
defer close(rawCh)
|
||||||
for {
|
for {
|
||||||
|
// note this line does have a trailing newline
|
||||||
line, err := bufReader.ReadBytes('\n')
|
line, err := bufReader.ReadBytes('\n')
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
return nil
|
return nil
|
||||||
|
@ -7,7 +7,6 @@ import "regexp"
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
safePattern = regexp.MustCompile(`^[a-zA-Z0-9_/.-]+$`)
|
safePattern = regexp.MustCompile(`^[a-zA-Z0-9_/.-]+$`)
|
||||||
psSafePattern = regexp.MustCompile(`^[a-zA-Z0-9_.-]+$`)
|
|
||||||
envVarNamePattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`)
|
envVarNamePattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -73,10 +72,6 @@ func HardQuotePowerShell(s string) string {
|
|||||||
return "\"\""
|
return "\"\""
|
||||||
}
|
}
|
||||||
|
|
||||||
if psSafePattern.MatchString(s) {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := make([]byte, 0, len(s)+5)
|
buf := make([]byte, 0, len(s)+5)
|
||||||
buf = append(buf, '"')
|
buf = append(buf, '"')
|
||||||
|
|
||||||
|
@ -153,7 +153,7 @@ wsh completion fish | source
|
|||||||
$env:PATH = {{.WSHBINDIR_PWSH}} + "{{.PATHSEP}}" + $env:PATH
|
$env:PATH = {{.WSHBINDIR_PWSH}} + "{{.PATHSEP}}" + $env:PATH
|
||||||
|
|
||||||
# Source dynamic script from wsh token
|
# Source dynamic script from wsh token
|
||||||
$waveterm_swaptoken_output = wsh token $env:WAVETERM_SWAPTOKEN pwsh 2>$null
|
$waveterm_swaptoken_output = wsh token $env:WAVETERM_SWAPTOKEN pwsh 2>$null | Out-String
|
||||||
if ($waveterm_swaptoken_output -and $waveterm_swaptoken_output -ne "") {
|
if ($waveterm_swaptoken_output -and $waveterm_swaptoken_output -ne "") {
|
||||||
Invoke-Expression $waveterm_swaptoken_output
|
Invoke-Expression $waveterm_swaptoken_output
|
||||||
}
|
}
|
||||||
|
85
pkg/util/utilfn/streamtolines.go
Normal file
85
pkg/util/utilfn/streamtolines.go
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
// Copyright 2025, Command Line Inc.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package utilfn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LineOutput struct {
|
||||||
|
Line string
|
||||||
|
Error error
|
||||||
|
}
|
||||||
|
|
||||||
|
type lineBuf struct {
|
||||||
|
buf []byte
|
||||||
|
inLongLine bool
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxLineLength = 128 * 1024
|
||||||
|
|
||||||
|
func ReadLineWithTimeout(ch chan LineOutput, timeout time.Duration) (string, error) {
|
||||||
|
select {
|
||||||
|
case output := <-ch:
|
||||||
|
if output.Error != nil {
|
||||||
|
return "", output.Error
|
||||||
|
}
|
||||||
|
return output.Line, nil
|
||||||
|
case <-time.After(timeout):
|
||||||
|
return "", context.DeadlineExceeded
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamToLines_processBuf(lineBuf *lineBuf, readBuf []byte, lineFn func([]byte)) {
|
||||||
|
for len(readBuf) > 0 {
|
||||||
|
nlIdx := bytes.IndexByte(readBuf, '\n')
|
||||||
|
if nlIdx == -1 {
|
||||||
|
if lineBuf.inLongLine || len(lineBuf.buf)+len(readBuf) > maxLineLength {
|
||||||
|
lineBuf.buf = nil
|
||||||
|
lineBuf.inLongLine = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
lineBuf.buf = append(lineBuf.buf, readBuf...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !lineBuf.inLongLine && len(lineBuf.buf)+nlIdx <= maxLineLength {
|
||||||
|
line := append(lineBuf.buf, readBuf[:nlIdx]...)
|
||||||
|
lineFn(line)
|
||||||
|
}
|
||||||
|
lineBuf.buf = nil
|
||||||
|
lineBuf.inLongLine = false
|
||||||
|
readBuf = readBuf[nlIdx+1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func StreamToLines(input io.Reader, lineFn func([]byte)) error {
|
||||||
|
var lineBuf lineBuf
|
||||||
|
readBuf := make([]byte, 16*1024)
|
||||||
|
for {
|
||||||
|
n, err := input.Read(readBuf)
|
||||||
|
streamToLines_processBuf(&lineBuf, readBuf[:n], lineFn)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// starts a goroutine to drive the channel
|
||||||
|
// line output does not include the trailing newline
|
||||||
|
func StreamToLinesChan(input io.Reader) chan LineOutput {
|
||||||
|
ch := make(chan LineOutput)
|
||||||
|
go func() {
|
||||||
|
defer close(ch)
|
||||||
|
err := StreamToLines(input, func(line []byte) {
|
||||||
|
ch <- LineOutput{Line: string(line)}
|
||||||
|
})
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
ch <- LineOutput{Error: err}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return ch
|
||||||
|
}
|
@ -37,6 +37,7 @@ import (
|
|||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wsl"
|
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wslconn"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wstore"
|
"github.com/wavetermdev/waveterm/pkg/wstore"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -609,7 +610,7 @@ func (ws *WshServer) ConnStatusCommand(ctx context.Context) ([]wshrpc.ConnStatus
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WshServer) WslStatusCommand(ctx context.Context) ([]wshrpc.ConnStatus, error) {
|
func (ws *WshServer) WslStatusCommand(ctx context.Context) ([]wshrpc.ConnStatus, error) {
|
||||||
rtn := wsl.GetAllConnStatus()
|
rtn := wslconn.GetAllConnStatus()
|
||||||
return rtn, nil
|
return rtn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -633,7 +634,7 @@ func (ws *WshServer) ConnEnsureCommand(ctx context.Context, data wshrpc.ConnExtD
|
|||||||
ctx = termCtxWithLogBlockId(ctx, data.LogBlockId)
|
ctx = termCtxWithLogBlockId(ctx, data.LogBlockId)
|
||||||
if strings.HasPrefix(data.ConnName, "wsl://") {
|
if strings.HasPrefix(data.ConnName, "wsl://") {
|
||||||
distroName := strings.TrimPrefix(data.ConnName, "wsl://")
|
distroName := strings.TrimPrefix(data.ConnName, "wsl://")
|
||||||
return wsl.EnsureConnection(ctx, distroName)
|
return wslconn.EnsureConnection(ctx, distroName)
|
||||||
}
|
}
|
||||||
return conncontroller.EnsureConnection(ctx, data.ConnName)
|
return conncontroller.EnsureConnection(ctx, data.ConnName)
|
||||||
}
|
}
|
||||||
@ -641,7 +642,7 @@ func (ws *WshServer) ConnEnsureCommand(ctx context.Context, data wshrpc.ConnExtD
|
|||||||
func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string) error {
|
func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string) error {
|
||||||
if strings.HasPrefix(connName, "wsl://") {
|
if strings.HasPrefix(connName, "wsl://") {
|
||||||
distroName := strings.TrimPrefix(connName, "wsl://")
|
distroName := strings.TrimPrefix(connName, "wsl://")
|
||||||
conn := wsl.GetWslConn(ctx, distroName, false)
|
conn := wslconn.GetWslConn(ctx, distroName, false)
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
return fmt.Errorf("distro not found: %s", connName)
|
return fmt.Errorf("distro not found: %s", connName)
|
||||||
}
|
}
|
||||||
@ -664,7 +665,7 @@ func (ws *WshServer) ConnConnectCommand(ctx context.Context, connRequest wshrpc.
|
|||||||
connName := connRequest.Host
|
connName := connRequest.Host
|
||||||
if strings.HasPrefix(connName, "wsl://") {
|
if strings.HasPrefix(connName, "wsl://") {
|
||||||
distroName := strings.TrimPrefix(connName, "wsl://")
|
distroName := strings.TrimPrefix(connName, "wsl://")
|
||||||
conn := wsl.GetWslConn(ctx, distroName, false)
|
conn := wslconn.GetWslConn(ctx, distroName, false)
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
return fmt.Errorf("connection not found: %s", connName)
|
return fmt.Errorf("connection not found: %s", connName)
|
||||||
}
|
}
|
||||||
@ -687,11 +688,11 @@ func (ws *WshServer) ConnReinstallWshCommand(ctx context.Context, data wshrpc.Co
|
|||||||
connName := data.ConnName
|
connName := data.ConnName
|
||||||
if strings.HasPrefix(connName, "wsl://") {
|
if strings.HasPrefix(connName, "wsl://") {
|
||||||
distroName := strings.TrimPrefix(connName, "wsl://")
|
distroName := strings.TrimPrefix(connName, "wsl://")
|
||||||
conn := wsl.GetWslConn(ctx, distroName, false)
|
conn := wslconn.GetWslConn(ctx, distroName, false)
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
return fmt.Errorf("connection not found: %s", connName)
|
return fmt.Errorf("connection not found: %s", connName)
|
||||||
}
|
}
|
||||||
return conn.CheckAndInstallWsh(ctx, connName, &wsl.WshInstallOpts{Force: true, NoUserPrompt: true})
|
return conn.InstallWsh(ctx, "")
|
||||||
}
|
}
|
||||||
connOpts, err := remote.ParseOpts(connName)
|
connOpts, err := remote.ParseOpts(connName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -4,11 +4,10 @@
|
|||||||
package wshutil
|
package wshutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"time"
|
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
||||||
)
|
)
|
||||||
|
|
||||||
// special I/O wrappers for wshrpc
|
// special I/O wrappers for wshrpc
|
||||||
@ -16,81 +15,8 @@ import (
|
|||||||
// * stream (json lines)
|
// * stream (json lines)
|
||||||
// * websocket (json packets)
|
// * websocket (json packets)
|
||||||
|
|
||||||
type lineBuf struct {
|
|
||||||
buf []byte
|
|
||||||
inLongLine bool
|
|
||||||
}
|
|
||||||
|
|
||||||
const maxLineLength = 128 * 1024
|
|
||||||
|
|
||||||
func streamToLines_processBuf(lineBuf *lineBuf, readBuf []byte, lineFn func([]byte)) {
|
|
||||||
for len(readBuf) > 0 {
|
|
||||||
nlIdx := bytes.IndexByte(readBuf, '\n')
|
|
||||||
if nlIdx == -1 {
|
|
||||||
if lineBuf.inLongLine || len(lineBuf.buf)+len(readBuf) > maxLineLength {
|
|
||||||
lineBuf.buf = nil
|
|
||||||
lineBuf.inLongLine = true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
lineBuf.buf = append(lineBuf.buf, readBuf...)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !lineBuf.inLongLine && len(lineBuf.buf)+nlIdx <= maxLineLength {
|
|
||||||
line := append(lineBuf.buf, readBuf[:nlIdx]...)
|
|
||||||
lineFn(line)
|
|
||||||
}
|
|
||||||
lineBuf.buf = nil
|
|
||||||
lineBuf.inLongLine = false
|
|
||||||
readBuf = readBuf[nlIdx+1:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func StreamToLines(input io.Reader, lineFn func([]byte)) error {
|
|
||||||
var lineBuf lineBuf
|
|
||||||
readBuf := make([]byte, 16*1024)
|
|
||||||
for {
|
|
||||||
n, err := input.Read(readBuf)
|
|
||||||
streamToLines_processBuf(&lineBuf, readBuf[:n], lineFn)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type LineOutput struct {
|
|
||||||
Line string
|
|
||||||
Error error
|
|
||||||
}
|
|
||||||
|
|
||||||
// starts a goroutine to drive the channel
|
|
||||||
func StreamToLinesChan(input io.Reader) chan LineOutput {
|
|
||||||
ch := make(chan LineOutput)
|
|
||||||
go func() {
|
|
||||||
defer close(ch)
|
|
||||||
err := StreamToLines(input, func(line []byte) {
|
|
||||||
ch <- LineOutput{Line: string(line)}
|
|
||||||
})
|
|
||||||
if err != nil && err != io.EOF {
|
|
||||||
ch <- LineOutput{Error: err}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReadLineWithTimeout(ch chan LineOutput, timeout time.Duration) (string, error) {
|
|
||||||
select {
|
|
||||||
case output := <-ch:
|
|
||||||
if output.Error != nil {
|
|
||||||
return "", output.Error
|
|
||||||
}
|
|
||||||
return output.Line, nil
|
|
||||||
case <-time.After(timeout):
|
|
||||||
return "", context.DeadlineExceeded
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func AdaptStreamToMsgCh(input io.Reader, output chan []byte) error {
|
func AdaptStreamToMsgCh(input io.Reader, output chan []byte) error {
|
||||||
return StreamToLines(input, func(line []byte) {
|
return utilfn.StreamToLines(input, func(line []byte) {
|
||||||
output <- line
|
output <- line
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -25,6 +25,7 @@ import (
|
|||||||
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
||||||
"github.com/wavetermdev/waveterm/pkg/util/packetparser"
|
"github.com/wavetermdev/waveterm/pkg/util/packetparser"
|
||||||
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
|
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||||
"golang.org/x/term"
|
"golang.org/x/term"
|
||||||
@ -418,10 +419,10 @@ type WriteFlusher interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// blocking, returns if there is an error, or on EOF of input
|
// blocking, returns if there is an error, or on EOF of input
|
||||||
func HandleStdIOClient(logName string, input io.Reader, output io.Writer) {
|
func HandleStdIOClient(logName string, input chan utilfn.LineOutput, output io.Writer) {
|
||||||
proxy := MakeRpcMultiProxy()
|
proxy := MakeRpcMultiProxy()
|
||||||
rawCh := make(chan []byte, DefaultInputChSize)
|
rawCh := make(chan []byte, DefaultInputChSize)
|
||||||
go packetparser.Parse(input, proxy.FromRemoteRawCh, rawCh)
|
go packetparser.ParseWithLinesChan(input, proxy.FromRemoteRawCh, rawCh)
|
||||||
doneCh := make(chan struct{})
|
doneCh := make(chan struct{})
|
||||||
var doneOnce sync.Once
|
var doneOnce sync.Once
|
||||||
closeDoneCh := func() {
|
closeDoneCh := func() {
|
||||||
@ -455,6 +456,9 @@ func HandleStdIOClient(logName string, input io.Reader, output io.Writer) {
|
|||||||
}()
|
}()
|
||||||
defer closeDoneCh()
|
defer closeDoneCh()
|
||||||
for msg := range rawCh {
|
for msg := range rawCh {
|
||||||
|
if !bytes.HasSuffix(msg, []byte{'\n'}) {
|
||||||
|
msg = append(msg, '\n')
|
||||||
|
}
|
||||||
log.Printf("[%s:stdout] %s", logName, msg)
|
log.Printf("[%s:stdout] %s", logName, msg)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -13,6 +13,10 @@ import (
|
|||||||
"os/exec"
|
"os/exec"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type WslName struct {
|
||||||
|
Distro string `json:"distro"`
|
||||||
|
}
|
||||||
|
|
||||||
func RegisteredDistros(ctx context.Context) (distros []Distro, err error) {
|
func RegisteredDistros(ctx context.Context) (distros []Distro, err error) {
|
||||||
return nil, fmt.Errorf("RegisteredDistros not implemented on this system")
|
return nil, fmt.Errorf("RegisteredDistros not implemented on this system")
|
||||||
}
|
}
|
||||||
|
@ -1,289 +0,0 @@
|
|||||||
// Copyright 2025, Command Line Inc.
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
package wsl
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"text/template"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
|
||||||
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
|
||||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
|
||||||
)
|
|
||||||
|
|
||||||
func DetectShell(ctx context.Context, client *Distro) (string, error) {
|
|
||||||
wshPath := GetWshPath(ctx, client)
|
|
||||||
|
|
||||||
cmd := client.WslCommand(ctx, wshPath+" shell")
|
|
||||||
log.Printf("shell detecting using command: %s shell", wshPath)
|
|
||||||
out, err := cmd.Output()
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("unable to determine shell. defaulting to /bin/bash: %s", err)
|
|
||||||
return "/bin/bash", nil
|
|
||||||
}
|
|
||||||
log.Printf("detecting shell: %s", out)
|
|
||||||
|
|
||||||
// quoting breaks this particular case
|
|
||||||
return strings.TrimSpace(string(out)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetWshVersion(ctx context.Context, client *Distro) (string, error) {
|
|
||||||
wshPath := GetWshPath(ctx, client)
|
|
||||||
|
|
||||||
cmd := client.WslCommand(ctx, wshPath+" version")
|
|
||||||
out, err := cmd.Output()
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return strings.TrimSpace(string(out)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetWshPath(ctx context.Context, client *Distro) string {
|
|
||||||
defaultPath := wavebase.RemoteFullWshBinPath
|
|
||||||
|
|
||||||
cmd := client.WslCommand(ctx, "which wsh")
|
|
||||||
out, whichErr := cmd.Output()
|
|
||||||
if whichErr == nil {
|
|
||||||
return strings.TrimSpace(string(out))
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd = client.WslCommand(ctx, "where.exe wsh")
|
|
||||||
out, whereErr := cmd.Output()
|
|
||||||
if whereErr == nil {
|
|
||||||
return strings.TrimSpace(string(out))
|
|
||||||
}
|
|
||||||
|
|
||||||
// no custom install, use default path
|
|
||||||
return defaultPath
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasBashInstalled(ctx context.Context, client *Distro) (bool, error) {
|
|
||||||
cmd := client.WslCommand(ctx, "which bash")
|
|
||||||
out, whichErr := cmd.Output()
|
|
||||||
if whichErr == nil && len(out) != 0 {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd = client.WslCommand(ctx, "where.exe bash")
|
|
||||||
out, whereErr := cmd.Output()
|
|
||||||
if whereErr == nil && len(out) != 0 {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// note: we could also check in /bin/bash explicitly
|
|
||||||
// just in case that wasn't added to the path. but if
|
|
||||||
// that's true, we will most likely have worse
|
|
||||||
// problems going forward
|
|
||||||
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetClientOs(ctx context.Context, client *Distro) (string, error) {
|
|
||||||
cmd := client.WslCommand(ctx, "uname -s")
|
|
||||||
out, unixErr := cmd.CombinedOutput()
|
|
||||||
if unixErr == nil {
|
|
||||||
formatted := strings.ToLower(string(out))
|
|
||||||
formatted = strings.TrimSpace(formatted)
|
|
||||||
return formatted, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd = client.WslCommand(ctx, "echo %OS%")
|
|
||||||
out, cmdErr := cmd.Output()
|
|
||||||
if cmdErr == nil && strings.TrimSpace(string(out)) != "%OS%" {
|
|
||||||
formatted := strings.ToLower(string(out))
|
|
||||||
formatted = strings.TrimSpace(formatted)
|
|
||||||
return strings.Split(formatted, "_")[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd = client.WslCommand(ctx, "echo $env:OS")
|
|
||||||
out, psErr := cmd.Output()
|
|
||||||
if psErr == nil && strings.TrimSpace(string(out)) != "$env:OS" {
|
|
||||||
formatted := strings.ToLower(string(out))
|
|
||||||
formatted = strings.TrimSpace(formatted)
|
|
||||||
return strings.Split(formatted, "_")[0], nil
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("unable to determine os: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetClientArch(ctx context.Context, client *Distro) (string, error) {
|
|
||||||
cmd := client.WslCommand(ctx, "uname -m")
|
|
||||||
out, unixErr := cmd.Output()
|
|
||||||
if unixErr == nil {
|
|
||||||
return utilfn.FilterValidArch(string(out))
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd = client.WslCommand(ctx, "echo %PROCESSOR_ARCHITECTURE%")
|
|
||||||
out, cmdErr := cmd.CombinedOutput()
|
|
||||||
if cmdErr == nil && strings.TrimSpace(string(out)) != "%PROCESSOR_ARCHITECTURE%" {
|
|
||||||
return utilfn.FilterValidArch(string(out))
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd = client.WslCommand(ctx, "echo $env:PROCESSOR_ARCHITECTURE")
|
|
||||||
out, psErr := cmd.CombinedOutput()
|
|
||||||
if psErr == nil && strings.TrimSpace(string(out)) != "$env:PROCESSOR_ARCHITECTURE" {
|
|
||||||
return utilfn.FilterValidArch(string(out))
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("unable to determine architecture: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
type CancellableCmd struct {
|
|
||||||
Cmd *WslCmd
|
|
||||||
Cancel func()
|
|
||||||
}
|
|
||||||
|
|
||||||
var installTemplatesRawBash = map[string]string{
|
|
||||||
"mkdir": `bash -c 'mkdir -p {{.installDir}}'`,
|
|
||||||
"cat": `bash -c 'cat > {{.tempPath}}'`,
|
|
||||||
"mv": `bash -c 'mv {{.tempPath}} {{.installPath}}'`,
|
|
||||||
"chmod": `bash -c 'chmod a+x {{.installPath}}'`,
|
|
||||||
}
|
|
||||||
|
|
||||||
var installTemplatesRawDefault = map[string]string{
|
|
||||||
"mkdir": `mkdir -p {{.installDir}}`,
|
|
||||||
"cat": `cat > {{.tempPath}}`,
|
|
||||||
"mv": `mv {{.tempPath}} {{.installPath}}`,
|
|
||||||
"chmod": `chmod a+x {{.installPath}}`,
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeCancellableCommand(ctx context.Context, client *Distro, cmdTemplateRaw string, words map[string]string) (*CancellableCmd, error) {
|
|
||||||
cmdContext, cmdCancel := context.WithCancel(ctx)
|
|
||||||
|
|
||||||
cmdStr := &bytes.Buffer{}
|
|
||||||
cmdTemplate, err := template.New("").Parse(cmdTemplateRaw)
|
|
||||||
if err != nil {
|
|
||||||
cmdCancel()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
cmdTemplate.Execute(cmdStr, words)
|
|
||||||
|
|
||||||
cmd := client.WslCommand(cmdContext, cmdStr.String())
|
|
||||||
return &CancellableCmd{cmd, cmdCancel}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func CpHostToRemote(ctx context.Context, client *Distro, sourcePath string, destPath string) error {
|
|
||||||
// warning: does not work on windows remote yet
|
|
||||||
bashInstalled, err := hasBashInstalled(ctx, client)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var selectedTemplatesRaw map[string]string
|
|
||||||
if bashInstalled {
|
|
||||||
selectedTemplatesRaw = installTemplatesRawBash
|
|
||||||
} else {
|
|
||||||
log.Printf("bash is not installed on remote. attempting with default shell")
|
|
||||||
selectedTemplatesRaw = installTemplatesRawDefault
|
|
||||||
}
|
|
||||||
|
|
||||||
// I need to use toSlash here to force unix keybindings
|
|
||||||
// this means we can't guarantee it will work on a remote windows machine
|
|
||||||
var installWords = map[string]string{
|
|
||||||
"installDir": filepath.ToSlash(filepath.Dir(destPath)),
|
|
||||||
"tempPath": destPath + ".temp",
|
|
||||||
"installPath": destPath,
|
|
||||||
}
|
|
||||||
|
|
||||||
installStepCmds := make(map[string]*CancellableCmd)
|
|
||||||
for cmdName, selectedTemplateRaw := range selectedTemplatesRaw {
|
|
||||||
cancellableCmd, err := makeCancellableCommand(ctx, client, selectedTemplateRaw, installWords)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
installStepCmds[cmdName] = cancellableCmd
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = installStepCmds["mkdir"].Cmd.Output()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// the cat part of this is complicated since it requires stdin
|
|
||||||
catCmd := installStepCmds["cat"].Cmd
|
|
||||||
catStdin, err := catCmd.StdinPipe()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = catCmd.Start()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
input, err := os.Open(sourcePath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("cannot open local file %s to send to host: %v", sourcePath, err)
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
defer func() {
|
|
||||||
panichandler.PanicHandler("wslutil:cpHostToRemote:catStdin", recover())
|
|
||||||
}()
|
|
||||||
io.Copy(catStdin, input)
|
|
||||||
installStepCmds["cat"].Cancel()
|
|
||||||
|
|
||||||
// backup just in case something weird happens
|
|
||||||
// could cause potential race condition, but very
|
|
||||||
// unlikely
|
|
||||||
time.Sleep(time.Second * 1)
|
|
||||||
process := catCmd.GetProcess()
|
|
||||||
if process != nil {
|
|
||||||
process.Kill()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
catErr := catCmd.Wait()
|
|
||||||
if catErr != nil && !errors.Is(catErr, context.Canceled) {
|
|
||||||
return catErr
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = installStepCmds["mv"].Cmd.Output()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = installStepCmds["chmod"].Cmd.Output()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func InstallClientRcFiles(ctx context.Context, client *Distro) error {
|
|
||||||
path := GetWshPath(ctx, client)
|
|
||||||
log.Printf("path to wsh searched is: %s", path)
|
|
||||||
|
|
||||||
cmd := client.WslCommand(ctx, path+" rcfiles")
|
|
||||||
_, err := cmd.Output()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetHomeDir(ctx context.Context, client *Distro) string {
|
|
||||||
// note: also works for powershell
|
|
||||||
cmd := client.WslCommand(ctx, `echo "$HOME"`)
|
|
||||||
out, err := cmd.Output()
|
|
||||||
if err == nil {
|
|
||||||
return strings.TrimSpace(string(out))
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd = client.WslCommand(ctx, `echo %userprofile%`)
|
|
||||||
out, err = cmd.Output()
|
|
||||||
if err == nil {
|
|
||||||
return strings.TrimSpace(string(out))
|
|
||||||
}
|
|
||||||
|
|
||||||
return "~"
|
|
||||||
}
|
|
||||||
|
|
||||||
func IsPowershell(shellPath string) bool {
|
|
||||||
// get the base path, and then check contains
|
|
||||||
shellBase := filepath.Base(shellPath)
|
|
||||||
return strings.Contains(shellBase, "powershell") || strings.Contains(shellBase, "pwsh")
|
|
||||||
}
|
|
@ -18,6 +18,10 @@ import (
|
|||||||
var RegisteredDistros = gowsl.RegisteredDistros
|
var RegisteredDistros = gowsl.RegisteredDistros
|
||||||
var DefaultDistro = gowsl.DefaultDistro
|
var DefaultDistro = gowsl.DefaultDistro
|
||||||
|
|
||||||
|
type WslName struct {
|
||||||
|
Distro string `json:"distro"`
|
||||||
|
}
|
||||||
|
|
||||||
type Distro struct {
|
type Distro struct {
|
||||||
gowsl.Distro
|
gowsl.Distro
|
||||||
}
|
}
|
||||||
|
533
pkg/wsl/wsl.go
533
pkg/wsl/wsl.go
@ -1,533 +0,0 @@
|
|||||||
// Copyright 2025, Command Line Inc.
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
package wsl
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
|
||||||
"github.com/wavetermdev/waveterm/pkg/telemetry"
|
|
||||||
"github.com/wavetermdev/waveterm/pkg/userinput"
|
|
||||||
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
|
|
||||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
|
||||||
"github.com/wavetermdev/waveterm/pkg/waveobj"
|
|
||||||
"github.com/wavetermdev/waveterm/pkg/wconfig"
|
|
||||||
"github.com/wavetermdev/waveterm/pkg/wps"
|
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
Status_Init = "init"
|
|
||||||
Status_Connecting = "connecting"
|
|
||||||
Status_Connected = "connected"
|
|
||||||
Status_Disconnected = "disconnected"
|
|
||||||
Status_Error = "error"
|
|
||||||
)
|
|
||||||
|
|
||||||
const DefaultConnectionTimeout = 60 * time.Second
|
|
||||||
|
|
||||||
var globalLock = &sync.Mutex{}
|
|
||||||
var clientControllerMap = make(map[string]*WslConn)
|
|
||||||
var activeConnCounter = &atomic.Int32{}
|
|
||||||
|
|
||||||
type WslConn struct {
|
|
||||||
Lock *sync.Mutex
|
|
||||||
Status string
|
|
||||||
Name WslName
|
|
||||||
Client *Distro
|
|
||||||
SockName string
|
|
||||||
DomainSockListener net.Listener
|
|
||||||
ConnController *WslCmd
|
|
||||||
Error string
|
|
||||||
HasWaiter *atomic.Bool
|
|
||||||
LastConnectTime int64
|
|
||||||
ActiveConnNum int
|
|
||||||
cancelFn func()
|
|
||||||
}
|
|
||||||
|
|
||||||
type WslName struct {
|
|
||||||
Distro string `json:"distro"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetAllConnStatus() []wshrpc.ConnStatus {
|
|
||||||
globalLock.Lock()
|
|
||||||
defer globalLock.Unlock()
|
|
||||||
|
|
||||||
var connStatuses []wshrpc.ConnStatus
|
|
||||||
for _, conn := range clientControllerMap {
|
|
||||||
connStatuses = append(connStatuses, conn.DeriveConnStatus())
|
|
||||||
}
|
|
||||||
return connStatuses
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetNumWSLHasConnected() int {
|
|
||||||
globalLock.Lock()
|
|
||||||
defer globalLock.Unlock()
|
|
||||||
|
|
||||||
var connectedCount int
|
|
||||||
for _, conn := range clientControllerMap {
|
|
||||||
if conn.LastConnectTime > 0 {
|
|
||||||
connectedCount++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return connectedCount
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *WslConn) DeriveConnStatus() wshrpc.ConnStatus {
|
|
||||||
conn.Lock.Lock()
|
|
||||||
defer conn.Lock.Unlock()
|
|
||||||
return wshrpc.ConnStatus{
|
|
||||||
Status: conn.Status,
|
|
||||||
Connected: conn.Status == Status_Connected,
|
|
||||||
WshEnabled: true, // always use wsh for wsl connections (temporary)
|
|
||||||
Connection: conn.GetName(),
|
|
||||||
HasConnected: (conn.LastConnectTime > 0),
|
|
||||||
ActiveConnNum: conn.ActiveConnNum,
|
|
||||||
Error: conn.Error,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *WslConn) FireConnChangeEvent() {
|
|
||||||
status := conn.DeriveConnStatus()
|
|
||||||
event := wps.WaveEvent{
|
|
||||||
Event: wps.Event_ConnChange,
|
|
||||||
Scopes: []string{
|
|
||||||
fmt.Sprintf("connection:%s", conn.GetName()),
|
|
||||||
},
|
|
||||||
Data: status,
|
|
||||||
}
|
|
||||||
log.Printf("sending event: %+#v", event)
|
|
||||||
wps.Broker.Publish(event)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *WslConn) Close() error {
|
|
||||||
defer conn.FireConnChangeEvent()
|
|
||||||
conn.WithLock(func() {
|
|
||||||
if conn.Status == Status_Connected || conn.Status == Status_Connecting {
|
|
||||||
// if status is init, disconnected, or error don't change it
|
|
||||||
conn.Status = Status_Disconnected
|
|
||||||
}
|
|
||||||
conn.close_nolock()
|
|
||||||
})
|
|
||||||
// we must wait for the waiter to complete
|
|
||||||
startTime := time.Now()
|
|
||||||
for conn.HasWaiter.Load() {
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
if time.Since(startTime) > 2*time.Second {
|
|
||||||
return fmt.Errorf("timeout waiting for waiter to complete")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *WslConn) close_nolock() {
|
|
||||||
// does not set status (that should happen at another level)
|
|
||||||
if conn.DomainSockListener != nil {
|
|
||||||
conn.DomainSockListener.Close()
|
|
||||||
conn.DomainSockListener = nil
|
|
||||||
}
|
|
||||||
if conn.ConnController != nil {
|
|
||||||
conn.cancelFn() // this suspends the conn controller
|
|
||||||
conn.ConnController = nil
|
|
||||||
}
|
|
||||||
if conn.Client != nil {
|
|
||||||
// conn.Client.Close() is not relevant here
|
|
||||||
// we do not want to completely close the wsl in case
|
|
||||||
// other applications are using it
|
|
||||||
conn.Client = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *WslConn) GetDomainSocketName() string {
|
|
||||||
conn.Lock.Lock()
|
|
||||||
defer conn.Lock.Unlock()
|
|
||||||
return conn.SockName
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *WslConn) GetStatus() string {
|
|
||||||
conn.Lock.Lock()
|
|
||||||
defer conn.Lock.Unlock()
|
|
||||||
return conn.Status
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *WslConn) GetName() string {
|
|
||||||
// no lock required because opts is immutable
|
|
||||||
return "wsl://" + conn.Name.Distro
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This function is does not set a listener for WslConn
|
|
||||||
* It is still required in order to set SockName
|
|
||||||
**/
|
|
||||||
func (conn *WslConn) OpenDomainSocketListener() error {
|
|
||||||
var allowed bool
|
|
||||||
conn.WithLock(func() {
|
|
||||||
if conn.Status != Status_Connecting {
|
|
||||||
allowed = false
|
|
||||||
} else {
|
|
||||||
allowed = true
|
|
||||||
}
|
|
||||||
})
|
|
||||||
if !allowed {
|
|
||||||
return fmt.Errorf("cannot open domain socket for %q when status is %q", conn.GetName(), conn.GetStatus())
|
|
||||||
}
|
|
||||||
conn.WithLock(func() {
|
|
||||||
conn.SockName = wavebase.RemoteFullDomainSocketPath
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *WslConn) StartConnServer() error {
|
|
||||||
utilCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancelFn()
|
|
||||||
var allowed bool
|
|
||||||
conn.WithLock(func() {
|
|
||||||
if conn.Status != Status_Connecting {
|
|
||||||
allowed = false
|
|
||||||
} else {
|
|
||||||
allowed = true
|
|
||||||
}
|
|
||||||
})
|
|
||||||
if !allowed {
|
|
||||||
return fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus())
|
|
||||||
}
|
|
||||||
client := conn.GetClient()
|
|
||||||
wshPath := GetWshPath(utilCtx, client)
|
|
||||||
rpcCtx := wshrpc.RpcContext{
|
|
||||||
ClientType: wshrpc.ClientType_ConnServer,
|
|
||||||
Conn: conn.GetName(),
|
|
||||||
}
|
|
||||||
sockName := conn.GetDomainSocketName()
|
|
||||||
jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to create jwt token for conn controller: %w", err)
|
|
||||||
}
|
|
||||||
shellPath, err := DetectShell(utilCtx, client)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var cmdStr string
|
|
||||||
if IsPowershell(shellPath) {
|
|
||||||
cmdStr = fmt.Sprintf("$env:%s=\"%s\"; %s connserver --router", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
|
|
||||||
} else {
|
|
||||||
cmdStr = fmt.Sprintf("%s=\"%s\" %s connserver --router", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
|
|
||||||
}
|
|
||||||
log.Printf("starting conn controller: %s\n", cmdStr)
|
|
||||||
connServerCtx, cancelFn := context.WithCancel(context.Background())
|
|
||||||
conn.WithLock(func() {
|
|
||||||
if conn.cancelFn != nil {
|
|
||||||
conn.cancelFn()
|
|
||||||
}
|
|
||||||
conn.cancelFn = cancelFn
|
|
||||||
})
|
|
||||||
cmd := client.WslCommand(connServerCtx, cmdStr)
|
|
||||||
pipeRead, pipeWrite := io.Pipe()
|
|
||||||
inputPipeRead, inputPipeWrite := io.Pipe()
|
|
||||||
cmd.SetStdout(pipeWrite)
|
|
||||||
cmd.SetStderr(pipeWrite)
|
|
||||||
cmd.SetStdin(inputPipeRead)
|
|
||||||
err = cmd.Start()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to start conn controller: %w", err)
|
|
||||||
}
|
|
||||||
conn.WithLock(func() {
|
|
||||||
conn.ConnController = cmd
|
|
||||||
})
|
|
||||||
// service the I/O
|
|
||||||
go func() {
|
|
||||||
defer func() {
|
|
||||||
panichandler.PanicHandler("wsl:StartConnServer:wait", recover())
|
|
||||||
}()
|
|
||||||
// wait for termination, clear the controller
|
|
||||||
defer conn.WithLock(func() {
|
|
||||||
conn.ConnController = nil
|
|
||||||
})
|
|
||||||
waitErr := cmd.Wait()
|
|
||||||
log.Printf("conn controller (%q) terminated: %v", conn.GetName(), waitErr)
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
defer func() {
|
|
||||||
panichandler.PanicHandler("wsl:StartConnServer:handleStdIOClient", recover())
|
|
||||||
}()
|
|
||||||
logName := fmt.Sprintf("conncontroller:%s", conn.GetName())
|
|
||||||
wshutil.HandleStdIOClient(logName, pipeRead, inputPipeWrite)
|
|
||||||
}()
|
|
||||||
regCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancelFn()
|
|
||||||
err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(rpcCtx.Conn))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("timeout waiting for connserver to register")
|
|
||||||
}
|
|
||||||
time.Sleep(300 * time.Millisecond) // TODO remove this sleep (but we need to wait until connserver is "ready")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type WshInstallOpts struct {
|
|
||||||
Force bool
|
|
||||||
NoUserPrompt bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *WslConn) CheckAndInstallWsh(ctx context.Context, clientDisplayName string, opts *WshInstallOpts) error {
|
|
||||||
if opts == nil {
|
|
||||||
opts = &WshInstallOpts{}
|
|
||||||
}
|
|
||||||
client := conn.GetClient()
|
|
||||||
if client == nil {
|
|
||||||
return fmt.Errorf("client is nil")
|
|
||||||
}
|
|
||||||
// check that correct wsh extensions are installed
|
|
||||||
expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
|
|
||||||
clientVersion, err := GetWshVersion(ctx, client)
|
|
||||||
if err == nil && clientVersion == expectedVersion && !opts.Force {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var queryText string
|
|
||||||
var title string
|
|
||||||
if opts.Force {
|
|
||||||
queryText = fmt.Sprintf("ReInstalling Wave Shell Extensions (%s) on `%s`\n", wavebase.WaveVersion, clientDisplayName)
|
|
||||||
title = "Install Wave Shell Extensions"
|
|
||||||
} else if err != nil {
|
|
||||||
queryText = fmt.Sprintf("Wave requires Wave Shell Extensions to be \n"+
|
|
||||||
"installed on `%s` \n"+
|
|
||||||
"to ensure a seamless experience. \n\n"+
|
|
||||||
"Would you like to install them?", clientDisplayName)
|
|
||||||
title = "Install Wave Shell Extensions"
|
|
||||||
} else {
|
|
||||||
// don't ask for upgrading the version
|
|
||||||
opts.NoUserPrompt = true
|
|
||||||
}
|
|
||||||
if !opts.NoUserPrompt {
|
|
||||||
request := &userinput.UserInputRequest{
|
|
||||||
ResponseType: "confirm",
|
|
||||||
QueryText: queryText,
|
|
||||||
Title: title,
|
|
||||||
Markdown: true,
|
|
||||||
CheckBoxMsg: "Don't show me this again",
|
|
||||||
}
|
|
||||||
response, err := userinput.GetUserInput(ctx, request)
|
|
||||||
if err != nil || !response.Confirm {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if response.CheckboxStat {
|
|
||||||
meta := waveobj.MetaMapType{
|
|
||||||
wconfig.ConfigKey_ConnAskBeforeWshInstall: false,
|
|
||||||
}
|
|
||||||
err := wconfig.SetBaseConfigValue(meta)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error setting conn:askbeforewshinstall value: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
log.Printf("attempting to install wsh to `%s`", clientDisplayName)
|
|
||||||
clientOs, err := GetClientOs(ctx, client)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
clientArch, err := GetClientArch(ctx, client)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// attempt to install extension
|
|
||||||
wshLocalPath, err := shellutil.GetLocalWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = CpHostToRemote(ctx, client, wshLocalPath, wavebase.RemoteFullWshBinPath)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.Printf("successfully installed wsh on %s\n", conn.GetName())
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *WslConn) GetClient() *Distro {
|
|
||||||
conn.Lock.Lock()
|
|
||||||
defer conn.Lock.Unlock()
|
|
||||||
return conn.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *WslConn) Reconnect(ctx context.Context) error {
|
|
||||||
err := conn.Close()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return conn.Connect(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *WslConn) WaitForConnect(ctx context.Context) error {
|
|
||||||
for {
|
|
||||||
status := conn.DeriveConnStatus()
|
|
||||||
if status.Status == Status_Connected {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if status.Status == Status_Connecting {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return fmt.Errorf("context timeout")
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if status.Status == Status_Init || status.Status == Status_Disconnected {
|
|
||||||
return fmt.Errorf("disconnected")
|
|
||||||
}
|
|
||||||
if status.Status == Status_Error {
|
|
||||||
return fmt.Errorf("error: %v", status.Error)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("unknown status: %q", status.Status)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// does not return an error since that error is stored inside of WslConn
|
|
||||||
func (conn *WslConn) Connect(ctx context.Context) error {
|
|
||||||
var connectAllowed bool
|
|
||||||
conn.WithLock(func() {
|
|
||||||
if conn.Status == Status_Connecting || conn.Status == Status_Connected {
|
|
||||||
connectAllowed = false
|
|
||||||
} else {
|
|
||||||
conn.Status = Status_Connecting
|
|
||||||
conn.Error = ""
|
|
||||||
connectAllowed = true
|
|
||||||
}
|
|
||||||
})
|
|
||||||
log.Printf("Connect %s\n", conn.GetName())
|
|
||||||
if !connectAllowed {
|
|
||||||
return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus())
|
|
||||||
}
|
|
||||||
conn.FireConnChangeEvent()
|
|
||||||
err := conn.connectInternal(ctx)
|
|
||||||
conn.WithLock(func() {
|
|
||||||
if err != nil {
|
|
||||||
conn.Status = Status_Error
|
|
||||||
conn.Error = err.Error()
|
|
||||||
conn.close_nolock()
|
|
||||||
telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{
|
|
||||||
Conn: map[string]int{"wsl:connecterror": 1},
|
|
||||||
}, "wsl-connconnect")
|
|
||||||
} else {
|
|
||||||
conn.Status = Status_Connected
|
|
||||||
conn.LastConnectTime = time.Now().UnixMilli()
|
|
||||||
if conn.ActiveConnNum == 0 {
|
|
||||||
conn.ActiveConnNum = int(activeConnCounter.Add(1))
|
|
||||||
}
|
|
||||||
telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{
|
|
||||||
Conn: map[string]int{"wsl:connect": 1},
|
|
||||||
}, "wsl-connconnect")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
conn.FireConnChangeEvent()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *WslConn) WithLock(fn func()) {
|
|
||||||
conn.Lock.Lock()
|
|
||||||
defer conn.Lock.Unlock()
|
|
||||||
fn()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *WslConn) connectInternal(ctx context.Context) error {
|
|
||||||
client, err := GetDistro(ctx, conn.Name)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
conn.WithLock(func() {
|
|
||||||
conn.Client = client
|
|
||||||
})
|
|
||||||
err = conn.OpenDomainSocketListener()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
config := wconfig.GetWatcher().GetFullConfig()
|
|
||||||
wshAsk := wconfig.DefaultBoolPtr(config.Settings.ConnAskBeforeWshInstall, true)
|
|
||||||
installErr := conn.CheckAndInstallWsh(ctx, conn.GetName(), &WshInstallOpts{NoUserPrompt: !wshAsk})
|
|
||||||
if installErr != nil {
|
|
||||||
return fmt.Errorf("conncontroller %s wsh install error: %v", conn.GetName(), installErr)
|
|
||||||
}
|
|
||||||
csErr := conn.StartConnServer()
|
|
||||||
if csErr != nil {
|
|
||||||
return fmt.Errorf("conncontroller %s start wsh connserver error: %v", conn.GetName(), csErr)
|
|
||||||
}
|
|
||||||
conn.HasWaiter.Store(true)
|
|
||||||
go conn.waitForDisconnect()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *WslConn) waitForDisconnect() {
|
|
||||||
defer conn.FireConnChangeEvent()
|
|
||||||
defer conn.HasWaiter.Store(false)
|
|
||||||
err := conn.ConnController.Wait()
|
|
||||||
conn.WithLock(func() {
|
|
||||||
// disconnects happen for a variety of reasons (like network, etc. and are typically transient)
|
|
||||||
// so we just set the status to "disconnected" here (not error)
|
|
||||||
// don't overwrite any existing error (or error status)
|
|
||||||
if err != nil && conn.Error == "" {
|
|
||||||
conn.Error = err.Error()
|
|
||||||
}
|
|
||||||
if conn.Status != Status_Error {
|
|
||||||
conn.Status = Status_Disconnected
|
|
||||||
}
|
|
||||||
conn.close_nolock()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func getConnInternal(name string) *WslConn {
|
|
||||||
globalLock.Lock()
|
|
||||||
defer globalLock.Unlock()
|
|
||||||
connName := WslName{Distro: name}
|
|
||||||
rtn := clientControllerMap[name]
|
|
||||||
if rtn == nil {
|
|
||||||
rtn = &WslConn{Lock: &sync.Mutex{}, Status: Status_Init, Name: connName, HasWaiter: &atomic.Bool{}, cancelFn: nil}
|
|
||||||
clientControllerMap[name] = rtn
|
|
||||||
}
|
|
||||||
return rtn
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetWslConn(ctx context.Context, name string, shouldConnect bool) *WslConn {
|
|
||||||
conn := getConnInternal(name)
|
|
||||||
if conn.Client == nil && shouldConnect {
|
|
||||||
conn.Connect(ctx)
|
|
||||||
}
|
|
||||||
return conn
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convenience function for ensuring a connection is established
|
|
||||||
func EnsureConnection(ctx context.Context, connName string) error {
|
|
||||||
if connName == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
conn := GetWslConn(ctx, connName, false)
|
|
||||||
if conn == nil {
|
|
||||||
return fmt.Errorf("connection not found: %s", connName)
|
|
||||||
}
|
|
||||||
connStatus := conn.DeriveConnStatus()
|
|
||||||
switch connStatus.Status {
|
|
||||||
case Status_Connected:
|
|
||||||
return nil
|
|
||||||
case Status_Connecting:
|
|
||||||
return conn.WaitForConnect(ctx)
|
|
||||||
case Status_Init, Status_Disconnected:
|
|
||||||
return conn.Connect(ctx)
|
|
||||||
case Status_Error:
|
|
||||||
return fmt.Errorf("connection error: %s", connStatus.Error)
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unknown connection status %q", connStatus.Status)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func DisconnectClient(connName string) error {
|
|
||||||
conn := getConnInternal(connName)
|
|
||||||
if conn == nil {
|
|
||||||
return fmt.Errorf("client %q not found", connName)
|
|
||||||
}
|
|
||||||
err := conn.Close()
|
|
||||||
return err
|
|
||||||
}
|
|
226
pkg/wslconn/wsl-util.go
Normal file
226
pkg/wslconn/wsl-util.go
Normal file
@ -0,0 +1,226 @@
|
|||||||
|
// Copyright 2025, Command Line Inc.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package wslconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"text/template"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/blocklogger"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/genconn"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||||
|
)
|
||||||
|
|
||||||
|
func hasBashInstalled(ctx context.Context, client *wsl.Distro) (bool, error) {
|
||||||
|
cmd := client.WslCommand(ctx, "which bash")
|
||||||
|
out, whichErr := cmd.Output()
|
||||||
|
if whichErr == nil && len(out) != 0 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd = client.WslCommand(ctx, "where.exe bash")
|
||||||
|
out, whereErr := cmd.Output()
|
||||||
|
if whereErr == nil && len(out) != 0 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// note: we could also check in /bin/bash explicitly
|
||||||
|
// just in case that wasn't added to the path. but if
|
||||||
|
// that's true, we will most likely have worse
|
||||||
|
// problems going forward
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeOs(os string) string {
|
||||||
|
os = strings.ToLower(strings.TrimSpace(os))
|
||||||
|
return os
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeArch(arch string) string {
|
||||||
|
arch = strings.ToLower(strings.TrimSpace(arch))
|
||||||
|
switch arch {
|
||||||
|
case "x86_64", "amd64":
|
||||||
|
arch = "x64"
|
||||||
|
case "arm64", "aarch64":
|
||||||
|
arch = "arm64"
|
||||||
|
}
|
||||||
|
return arch
|
||||||
|
}
|
||||||
|
|
||||||
|
// returns (os, arch, error)
|
||||||
|
// guaranteed to return a supported platform
|
||||||
|
func GetClientPlatform(ctx context.Context, shell genconn.ShellClient) (string, string, error) {
|
||||||
|
blocklogger.Infof(ctx, "[conndebug] running `uname -sm` to detect client platform\n")
|
||||||
|
stdout, stderr, err := genconn.RunSimpleCommand(ctx, shell, genconn.CommandSpec{
|
||||||
|
Cmd: "uname -sm",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("error running uname -sm: %w, stderr: %s", err, stderr)
|
||||||
|
}
|
||||||
|
// Parse and normalize output
|
||||||
|
parts := strings.Fields(strings.ToLower(strings.TrimSpace(stdout)))
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return "", "", fmt.Errorf("unexpected output from uname: %s", stdout)
|
||||||
|
}
|
||||||
|
os, arch := normalizeOs(parts[0]), normalizeArch(parts[1])
|
||||||
|
if err := wavebase.ValidateWshSupportedArch(os, arch); err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
return os, arch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetClientPlatformFromOsArchStr(ctx context.Context, osArchStr string) (string, string, error) {
|
||||||
|
parts := strings.Fields(strings.TrimSpace(osArchStr))
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return "", "", fmt.Errorf("unexpected output from uname: %s", osArchStr)
|
||||||
|
}
|
||||||
|
os, arch := normalizeOs(parts[0]), normalizeArch(parts[1])
|
||||||
|
if err := wavebase.ValidateWshSupportedArch(os, arch); err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
return os, arch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type CancellableCmd struct {
|
||||||
|
Cmd *wsl.WslCmd
|
||||||
|
Cancel func()
|
||||||
|
}
|
||||||
|
|
||||||
|
var installTemplatesRawBash = map[string]string{
|
||||||
|
"mkdir": `bash -c 'mkdir -p {{.installDir}}'`,
|
||||||
|
"cat": `bash -c 'cat > {{.tempPath}}'`,
|
||||||
|
"mv": `bash -c 'mv {{.tempPath}} {{.installPath}}'`,
|
||||||
|
"chmod": `bash -c 'chmod a+x {{.installPath}}'`,
|
||||||
|
}
|
||||||
|
|
||||||
|
var installTemplatesRawDefault = map[string]string{
|
||||||
|
"mkdir": `mkdir -p {{.installDir}}`,
|
||||||
|
"cat": `cat > {{.tempPath}}`,
|
||||||
|
"mv": `mv {{.tempPath}} {{.installPath}}`,
|
||||||
|
"chmod": `chmod a+x {{.installPath}}`,
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeCancellableCommand(ctx context.Context, client *wsl.Distro, cmdTemplateRaw string, words map[string]string) (*CancellableCmd, error) {
|
||||||
|
cmdContext, cmdCancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
cmdStr := &bytes.Buffer{}
|
||||||
|
cmdTemplate, err := template.New("").Parse(cmdTemplateRaw)
|
||||||
|
if err != nil {
|
||||||
|
cmdCancel()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cmdTemplate.Execute(cmdStr, words)
|
||||||
|
|
||||||
|
cmd := client.WslCommand(cmdContext, cmdStr.String())
|
||||||
|
return &CancellableCmd{cmd, cmdCancel}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CpWshToRemote(ctx context.Context, client *wsl.Distro, clientOs string, clientArch string) error {
|
||||||
|
wshLocalPath, err := shellutil.GetLocalWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// warning: does not work on windows remote yet
|
||||||
|
bashInstalled, err := hasBashInstalled(ctx, client)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var selectedTemplatesRaw map[string]string
|
||||||
|
if bashInstalled {
|
||||||
|
selectedTemplatesRaw = installTemplatesRawBash
|
||||||
|
} else {
|
||||||
|
log.Printf("bash is not installed on remote. attempting with default shell")
|
||||||
|
selectedTemplatesRaw = installTemplatesRawDefault
|
||||||
|
}
|
||||||
|
|
||||||
|
// I need to use toSlash here to force unix keybindings
|
||||||
|
// this means we can't guarantee it will work on a remote windows machine
|
||||||
|
var installWords = map[string]string{
|
||||||
|
"installDir": filepath.ToSlash(filepath.Dir(wavebase.RemoteFullWshBinPath)),
|
||||||
|
"tempPath": wavebase.RemoteFullWshBinPath + ".temp",
|
||||||
|
"installPath": wavebase.RemoteFullWshBinPath,
|
||||||
|
}
|
||||||
|
|
||||||
|
blocklogger.Infof(ctx, "[conndebug] copying %q to remote server %q\n", wshLocalPath, wavebase.RemoteFullWshBinPath)
|
||||||
|
installStepCmds := make(map[string]*CancellableCmd)
|
||||||
|
for cmdName, selectedTemplateRaw := range selectedTemplatesRaw {
|
||||||
|
cancellableCmd, err := makeCancellableCommand(ctx, client, selectedTemplateRaw, installWords)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
installStepCmds[cmdName] = cancellableCmd
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = installStepCmds["mkdir"].Cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// the cat part of this is complicated since it requires stdin
|
||||||
|
catCmd := installStepCmds["cat"].Cmd
|
||||||
|
catStdin, err := catCmd.StdinPipe()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = catCmd.Start()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
input, err := os.Open(wshLocalPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot open local file %s to send to host: %v", wshLocalPath, err)
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
panichandler.PanicHandler("wslutil:cpHostToRemote:catStdin", recover())
|
||||||
|
}()
|
||||||
|
io.Copy(catStdin, input)
|
||||||
|
installStepCmds["cat"].Cancel()
|
||||||
|
|
||||||
|
// backup just in case something weird happens
|
||||||
|
// could cause potential race condition, but very
|
||||||
|
// unlikely
|
||||||
|
time.Sleep(time.Second * 1)
|
||||||
|
process := catCmd.GetProcess()
|
||||||
|
if process != nil {
|
||||||
|
process.Kill()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
catErr := catCmd.Wait()
|
||||||
|
if catErr != nil && !errors.Is(catErr, context.Canceled) {
|
||||||
|
return catErr
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = installStepCmds["mv"].Cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = installStepCmds["chmod"].Cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsPowershell(shellPath string) bool {
|
||||||
|
// get the base path, and then check contains
|
||||||
|
shellBase := filepath.Base(shellPath)
|
||||||
|
return strings.Contains(shellBase, "powershell") || strings.Contains(shellBase, "pwsh")
|
||||||
|
}
|
787
pkg/wslconn/wslconn.go
Normal file
787
pkg/wslconn/wslconn.go
Normal file
@ -0,0 +1,787 @@
|
|||||||
|
// Copyright 2025, Command Line Inc.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package wslconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/blocklogger"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/genconn"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/telemetry"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/userinput"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/waveobj"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wconfig"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wps"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
Status_Init = "init"
|
||||||
|
Status_Connecting = "connecting"
|
||||||
|
Status_Connected = "connected"
|
||||||
|
Status_Disconnected = "disconnected"
|
||||||
|
Status_Error = "error"
|
||||||
|
)
|
||||||
|
|
||||||
|
const DefaultConnectionTimeout = 60 * time.Second
|
||||||
|
|
||||||
|
var globalLock = &sync.Mutex{}
|
||||||
|
var clientControllerMap = make(map[string]*WslConn)
|
||||||
|
var activeConnCounter = &atomic.Int32{}
|
||||||
|
|
||||||
|
type WslConn struct {
|
||||||
|
Lock *sync.Mutex
|
||||||
|
Status string
|
||||||
|
WshEnabled *atomic.Bool
|
||||||
|
Name wsl.WslName
|
||||||
|
Client *wsl.Distro
|
||||||
|
DomainSockName string // if "", then no domain socket
|
||||||
|
DomainSockListener net.Listener
|
||||||
|
ConnController *wsl.WslCmd
|
||||||
|
Error string
|
||||||
|
WshError string
|
||||||
|
NoWshReason string
|
||||||
|
WshVersion string
|
||||||
|
HasWaiter *atomic.Bool
|
||||||
|
LastConnectTime int64
|
||||||
|
ActiveConnNum int
|
||||||
|
cancelFn func()
|
||||||
|
}
|
||||||
|
|
||||||
|
var ConnServerCmdTemplate = strings.TrimSpace(
|
||||||
|
strings.Join([]string{
|
||||||
|
"%s version 2> /dev/null || (echo -n \"not-installed \"; uname -sm);",
|
||||||
|
"exec %s connserver --router",
|
||||||
|
}, "\n"))
|
||||||
|
|
||||||
|
func GetAllConnStatus() []wshrpc.ConnStatus {
|
||||||
|
globalLock.Lock()
|
||||||
|
defer globalLock.Unlock()
|
||||||
|
|
||||||
|
var connStatuses []wshrpc.ConnStatus
|
||||||
|
for _, conn := range clientControllerMap {
|
||||||
|
connStatuses = append(connStatuses, conn.DeriveConnStatus())
|
||||||
|
}
|
||||||
|
return connStatuses
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetNumWSLHasConnected() int {
|
||||||
|
globalLock.Lock()
|
||||||
|
defer globalLock.Unlock()
|
||||||
|
|
||||||
|
var connectedCount int
|
||||||
|
for _, conn := range clientControllerMap {
|
||||||
|
if conn.LastConnectTime > 0 {
|
||||||
|
connectedCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return connectedCount
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) DeriveConnStatus() wshrpc.ConnStatus {
|
||||||
|
conn.Lock.Lock()
|
||||||
|
defer conn.Lock.Unlock()
|
||||||
|
return wshrpc.ConnStatus{
|
||||||
|
Status: conn.Status,
|
||||||
|
Connected: conn.Status == Status_Connected,
|
||||||
|
WshEnabled: true, // always use wsh for wsl connections (temporary)
|
||||||
|
Connection: conn.GetName(),
|
||||||
|
HasConnected: (conn.LastConnectTime > 0),
|
||||||
|
ActiveConnNum: conn.ActiveConnNum,
|
||||||
|
Error: conn.Error,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) Infof(ctx context.Context, format string, args ...any) {
|
||||||
|
log.Print(fmt.Sprintf("[conn:%s] ", conn.GetName()) + fmt.Sprintf(format, args...))
|
||||||
|
blocklogger.Infof(ctx, "[conndebug] "+format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) FireConnChangeEvent() {
|
||||||
|
status := conn.DeriveConnStatus()
|
||||||
|
event := wps.WaveEvent{
|
||||||
|
Event: wps.Event_ConnChange,
|
||||||
|
Scopes: []string{
|
||||||
|
fmt.Sprintf("connection:%s", conn.GetName()),
|
||||||
|
},
|
||||||
|
Data: status,
|
||||||
|
}
|
||||||
|
log.Printf("sending event: %+#v", event)
|
||||||
|
wps.Broker.Publish(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) Close() error {
|
||||||
|
defer conn.FireConnChangeEvent()
|
||||||
|
conn.WithLock(func() {
|
||||||
|
if conn.Status == Status_Connected || conn.Status == Status_Connecting {
|
||||||
|
// if status is init, disconnected, or error don't change it
|
||||||
|
conn.Status = Status_Disconnected
|
||||||
|
}
|
||||||
|
conn.close_nolock()
|
||||||
|
})
|
||||||
|
// we must wait for the waiter to complete
|
||||||
|
startTime := time.Now()
|
||||||
|
for conn.HasWaiter.Load() {
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
if time.Since(startTime) > 2*time.Second {
|
||||||
|
return fmt.Errorf("timeout waiting for waiter to complete")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) close_nolock() {
|
||||||
|
// does not set status (that should happen at another level)
|
||||||
|
if conn.DomainSockListener != nil {
|
||||||
|
conn.DomainSockListener.Close()
|
||||||
|
conn.DomainSockListener = nil
|
||||||
|
conn.DomainSockName = ""
|
||||||
|
}
|
||||||
|
if conn.ConnController != nil {
|
||||||
|
conn.cancelFn() // this suspends the conn controller
|
||||||
|
conn.ConnController = nil
|
||||||
|
}
|
||||||
|
if conn.Client != nil {
|
||||||
|
// conn.Client.Close() is not relevant here
|
||||||
|
// we do not want to completely close the wsl in case
|
||||||
|
// other applications are using it
|
||||||
|
conn.Client = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) GetDomainSocketName() string {
|
||||||
|
conn.Lock.Lock()
|
||||||
|
defer conn.Lock.Unlock()
|
||||||
|
return conn.DomainSockName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) GetStatus() string {
|
||||||
|
conn.Lock.Lock()
|
||||||
|
defer conn.Lock.Unlock()
|
||||||
|
return conn.Status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) GetName() string {
|
||||||
|
// no lock required because opts is immutable
|
||||||
|
return "wsl://" + conn.Name.Distro
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This function is does not set a listener for WslConn
|
||||||
|
* It is still required in order to set SockName
|
||||||
|
**/
|
||||||
|
func (conn *WslConn) OpenDomainSocketListener(ctx context.Context) error {
|
||||||
|
conn.Infof(ctx, "running OpenDomainSocketListener...\n")
|
||||||
|
allowed := WithLockRtn(conn, func() bool {
|
||||||
|
return conn.Status == Status_Connecting
|
||||||
|
})
|
||||||
|
if !allowed {
|
||||||
|
return fmt.Errorf("cannot open domain socket for %q when status is %q", conn.GetName(), conn.GetStatus())
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
listener, err := client.ListenUnix(sockName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to request connection domain socket: %v", err)
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
conn.Infof(ctx, "setting domain socket to %s\n", wavebase.RemoteFullDomainSocketPath)
|
||||||
|
conn.WithLock(func() {
|
||||||
|
conn.DomainSockName = wavebase.RemoteFullDomainSocketPath
|
||||||
|
//conn.DomainSockListener = listener
|
||||||
|
})
|
||||||
|
conn.Infof(ctx, "successfully connected domain socket\n")
|
||||||
|
/*
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
panichandler.PanicHandler("wslconn:OpenDomainSocketListener", recover())
|
||||||
|
}()
|
||||||
|
defer conn.WithLock(func() {
|
||||||
|
conn.DomainSockListener = nil
|
||||||
|
conn.DomainSockName = ""
|
||||||
|
})
|
||||||
|
wshutil.RunWshRpcOverListener(listener)
|
||||||
|
}()
|
||||||
|
*/
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) getWshPath() string {
|
||||||
|
config, ok := conn.getConnectionConfig()
|
||||||
|
if ok && config.ConnWshPath != "" {
|
||||||
|
return config.ConnWshPath
|
||||||
|
}
|
||||||
|
return wavebase.RemoteFullWshBinPath
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) GetConfigShellPath() string {
|
||||||
|
config, ok := conn.getConnectionConfig()
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return config.ConnShellPath
|
||||||
|
}
|
||||||
|
|
||||||
|
// returns (needsInstall, clientVersion, osArchStr, error)
|
||||||
|
// if wsh is not installed, the clientVersion will be "not-installed", and it will also return an osArchStr
|
||||||
|
// if clientVersion is set, then no osArchStr will be returned
|
||||||
|
func (conn *WslConn) StartConnServer(ctx context.Context, afterUpdate bool) (bool, string, string, error) {
|
||||||
|
conn.Infof(ctx, "running StartConnServer...\n")
|
||||||
|
allowed := WithLockRtn(conn, func() bool {
|
||||||
|
return conn.Status == Status_Connecting
|
||||||
|
})
|
||||||
|
if !allowed {
|
||||||
|
return false, "", "", fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus())
|
||||||
|
}
|
||||||
|
client := conn.GetClient()
|
||||||
|
wshPath := conn.getWshPath()
|
||||||
|
rpcCtx := wshrpc.RpcContext{
|
||||||
|
ClientType: wshrpc.ClientType_ConnServer,
|
||||||
|
Conn: conn.GetName(),
|
||||||
|
}
|
||||||
|
sockName := conn.GetDomainSocketName()
|
||||||
|
jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName)
|
||||||
|
if err != nil {
|
||||||
|
return false, "", "", fmt.Errorf("unable to create jwt token for conn controller: %w", err)
|
||||||
|
}
|
||||||
|
conn.Infof(ctx, "WSL-NEWSESSION (StartConnServer)\n")
|
||||||
|
connServerCtx, cancelFn := context.WithCancel(context.Background())
|
||||||
|
conn.WithLock(func() {
|
||||||
|
if conn.cancelFn != nil {
|
||||||
|
conn.cancelFn()
|
||||||
|
}
|
||||||
|
conn.cancelFn = cancelFn
|
||||||
|
})
|
||||||
|
cmdStr := fmt.Sprintf(ConnServerCmdTemplate, wshPath, wshPath)
|
||||||
|
shWrappedCmdStr := fmt.Sprintf("sh -c %s", shellutil.HardQuote(cmdStr))
|
||||||
|
cmd := client.WslCommand(connServerCtx, shWrappedCmdStr)
|
||||||
|
pipeRead, pipeWrite := io.Pipe()
|
||||||
|
inputPipeRead, inputPipeWrite := io.Pipe()
|
||||||
|
cmd.SetStdout(pipeWrite)
|
||||||
|
cmd.SetStderr(pipeWrite)
|
||||||
|
cmd.SetStdin(inputPipeRead)
|
||||||
|
log.Printf("starting conn controller: %q\n", cmdStr)
|
||||||
|
blocklogger.Debugf(ctx, "[conndebug] wrapped command:\n%s\n", shWrappedCmdStr)
|
||||||
|
err = cmd.Start()
|
||||||
|
if err != nil {
|
||||||
|
return false, "", "", fmt.Errorf("unable to start conn controller cmd: %w", err)
|
||||||
|
}
|
||||||
|
linesChan := utilfn.StreamToLinesChan(pipeRead)
|
||||||
|
versionLine, err := utilfn.ReadLineWithTimeout(linesChan, 2*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
cancelFn()
|
||||||
|
return false, "", "", fmt.Errorf("error reading wsh version: %w", err)
|
||||||
|
}
|
||||||
|
conn.Infof(ctx, "got connserver version: %s\n", strings.TrimSpace(versionLine))
|
||||||
|
isUpToDate, clientVersion, osArchStr, err := conncontroller.IsWshVersionUpToDate(ctx, versionLine)
|
||||||
|
if err != nil {
|
||||||
|
cancelFn()
|
||||||
|
return false, "", "", fmt.Errorf("error checking wsh version: %w", err)
|
||||||
|
}
|
||||||
|
if isUpToDate && !afterUpdate && os.Getenv(wavebase.WaveWshForceUpdateVarName) != "" {
|
||||||
|
isUpToDate = false
|
||||||
|
conn.Infof(ctx, "%s set, forcing wsh update\n", wavebase.WaveWshForceUpdateVarName)
|
||||||
|
}
|
||||||
|
conn.Infof(ctx, "connserver up-to-date: %v\n", isUpToDate)
|
||||||
|
if !isUpToDate {
|
||||||
|
cancelFn()
|
||||||
|
return true, clientVersion, osArchStr, nil
|
||||||
|
}
|
||||||
|
jwtLine, err := utilfn.ReadLineWithTimeout(linesChan, 3*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
cancelFn()
|
||||||
|
return false, clientVersion, "", fmt.Errorf("error reading jwt status line: %w", err)
|
||||||
|
}
|
||||||
|
conn.Infof(ctx, "got jwt status line: %s\n", jwtLine)
|
||||||
|
if strings.TrimSpace(jwtLine) == wavebase.NeedJwtConst {
|
||||||
|
// write the jwt
|
||||||
|
conn.Infof(ctx, "writing jwt token to connserver\n")
|
||||||
|
_, err = fmt.Fprintf(inputPipeWrite, "%s\n", jwtToken)
|
||||||
|
if err != nil {
|
||||||
|
cancelFn()
|
||||||
|
return false, clientVersion, "", fmt.Errorf("failed to write JWT token: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
conn.WithLock(func() {
|
||||||
|
conn.ConnController = cmd
|
||||||
|
})
|
||||||
|
// service the I/O
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
panichandler.PanicHandler("wslconn:cmd.Wait", recover())
|
||||||
|
}()
|
||||||
|
// wait for termination, clear the controller
|
||||||
|
var waitErr error
|
||||||
|
defer conn.WithLock(func() {
|
||||||
|
if conn.ConnController != nil {
|
||||||
|
conn.WshEnabled.Store(false)
|
||||||
|
conn.NoWshReason = "connserver terminated"
|
||||||
|
if waitErr != nil {
|
||||||
|
conn.WshError = fmt.Sprintf("connserver terminated unexpectedly with error: %v", waitErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
conn.ConnController = nil
|
||||||
|
})
|
||||||
|
waitErr = cmd.Wait()
|
||||||
|
log.Printf("conn controller (%q) terminated: %v", conn.GetName(), waitErr)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
panichandler.PanicHandler("wsl:StartConnServer:handleStdIOClient", recover())
|
||||||
|
}()
|
||||||
|
logName := fmt.Sprintf("wslconn:%s", conn.GetName())
|
||||||
|
wshutil.HandleStdIOClient(logName, linesChan, inputPipeWrite)
|
||||||
|
}()
|
||||||
|
conn.Infof(ctx, "connserver started, waiting for route to be registered\n")
|
||||||
|
regCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancelFn()
|
||||||
|
err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(rpcCtx.Conn))
|
||||||
|
if err != nil {
|
||||||
|
return false, clientVersion, "", fmt.Errorf("timeout waiting for connserver to register")
|
||||||
|
}
|
||||||
|
time.Sleep(300 * time.Millisecond) // TODO remove this sleep (but we need to wait until connserver is "ready")
|
||||||
|
conn.Infof(ctx, "connserver is registered and ready\n")
|
||||||
|
return false, clientVersion, "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type WshInstallOpts struct {
|
||||||
|
Force bool
|
||||||
|
NoUserPrompt bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var queryTextTemplate = strings.TrimSpace(`
|
||||||
|
Wave requires Wave Shell Extensions to be
|
||||||
|
installed on %q
|
||||||
|
to ensure a seamless experience.
|
||||||
|
|
||||||
|
Would you like to install them?
|
||||||
|
`)
|
||||||
|
|
||||||
|
func (conn *WslConn) UpdateWsh(ctx context.Context, clientDisplayName string, remoteInfo *wshrpc.RemoteInfo) error {
|
||||||
|
conn.Infof(ctx, "attempting to update wsh for connection %s (os:%s arch:%s version:%s)\n",
|
||||||
|
conn.GetName(), remoteInfo.ClientOs, remoteInfo.ClientArch, remoteInfo.ClientVersion)
|
||||||
|
client := conn.GetClient()
|
||||||
|
if client == nil {
|
||||||
|
return fmt.Errorf("cannot update wsh: ssh client is not connected")
|
||||||
|
}
|
||||||
|
err := CpWshToRemote(ctx, client, remoteInfo.ClientOs, remoteInfo.ClientArch)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error installing wsh to remote: %w", err)
|
||||||
|
}
|
||||||
|
conn.Infof(ctx, "successfully updated wsh on %s\n", conn.GetName())
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// returns (allowed, error)
|
||||||
|
func (conn *WslConn) getPermissionToInstallWsh(ctx context.Context, clientDisplayName string) (bool, error) {
|
||||||
|
conn.Infof(ctx, "running getPermissionToInstallWsh...\n")
|
||||||
|
queryText := fmt.Sprintf(queryTextTemplate, clientDisplayName)
|
||||||
|
title := "Install Wave Shell Extensions"
|
||||||
|
request := &userinput.UserInputRequest{
|
||||||
|
ResponseType: "confirm",
|
||||||
|
QueryText: queryText,
|
||||||
|
Title: title,
|
||||||
|
Markdown: true,
|
||||||
|
CheckBoxMsg: "Automatically install for all connections",
|
||||||
|
OkLabel: "Install wsh",
|
||||||
|
CancelLabel: "No wsh",
|
||||||
|
}
|
||||||
|
conn.Infof(ctx, "requesting user confirmation...\n")
|
||||||
|
response, err := userinput.GetUserInput(ctx, request)
|
||||||
|
if err != nil {
|
||||||
|
conn.Infof(ctx, "error getting user input: %v\n", err)
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
conn.Infof(ctx, "user response to allowing wsh: %v\n", response.Confirm)
|
||||||
|
meta := make(map[string]any)
|
||||||
|
meta["conn:wshenabled"] = response.Confirm
|
||||||
|
conn.Infof(ctx, "writing conn:wshenabled=%v to connections.json\n", response.Confirm)
|
||||||
|
err = wconfig.SetConnectionsConfigValue(conn.GetName(), meta)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("warning: error writing to connections file: %v", err)
|
||||||
|
}
|
||||||
|
if !response.Confirm {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if response.CheckboxStat {
|
||||||
|
conn.Infof(ctx, "writing conn:askbeforewshinstall=false to settings.json\n")
|
||||||
|
meta := waveobj.MetaMapType{
|
||||||
|
wconfig.ConfigKey_ConnAskBeforeWshInstall: false,
|
||||||
|
}
|
||||||
|
setConfigErr := wconfig.SetBaseConfigValue(meta)
|
||||||
|
if setConfigErr != nil {
|
||||||
|
// this is not a critical error, just log and continue
|
||||||
|
log.Printf("warning: error writing to base config file: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) InstallWsh(ctx context.Context, osArchStr string) error {
|
||||||
|
conn.Infof(ctx, "running installWsh...\n")
|
||||||
|
client := conn.GetClient()
|
||||||
|
if client == nil {
|
||||||
|
conn.Infof(ctx, "ERROR ssh client is not connected, cannot install\n")
|
||||||
|
return fmt.Errorf("ssh client is not connected, cannot install")
|
||||||
|
}
|
||||||
|
var clientOs, clientArch string
|
||||||
|
var err error
|
||||||
|
if osArchStr != "" {
|
||||||
|
clientOs, clientArch, err = GetClientPlatformFromOsArchStr(ctx, osArchStr)
|
||||||
|
} else {
|
||||||
|
clientOs, clientArch, err = GetClientPlatform(ctx, genconn.MakeWSLShellClient(client))
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
conn.Infof(ctx, "ERROR detecting client platform: %v\n", err)
|
||||||
|
}
|
||||||
|
conn.Infof(ctx, "detected remote platform os:%s arch:%s\n", clientOs, clientArch)
|
||||||
|
err = CpWshToRemote(ctx, client, clientOs, clientArch)
|
||||||
|
if err != nil {
|
||||||
|
conn.Infof(ctx, "ERROR copying wsh binary to remote: %v\n", err)
|
||||||
|
return fmt.Errorf("error copying wsh binary to remote: %w", err)
|
||||||
|
}
|
||||||
|
conn.Infof(ctx, "successfully installed wsh\n")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) GetClient() *wsl.Distro {
|
||||||
|
conn.Lock.Lock()
|
||||||
|
defer conn.Lock.Unlock()
|
||||||
|
return conn.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) Reconnect(ctx context.Context) error {
|
||||||
|
err := conn.Close()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return conn.Connect(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) WaitForConnect(ctx context.Context) error {
|
||||||
|
for {
|
||||||
|
status := conn.DeriveConnStatus()
|
||||||
|
if status.Status == Status_Connected {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if status.Status == Status_Connecting {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return fmt.Errorf("context timeout")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if status.Status == Status_Init || status.Status == Status_Disconnected {
|
||||||
|
return fmt.Errorf("disconnected")
|
||||||
|
}
|
||||||
|
if status.Status == Status_Error {
|
||||||
|
return fmt.Errorf("error: %v", status.Error)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("unknown status: %q", status.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// does not return an error since that error is stored inside of WslConn
|
||||||
|
func (conn *WslConn) Connect(ctx context.Context) error {
|
||||||
|
var connectAllowed bool
|
||||||
|
conn.WithLock(func() {
|
||||||
|
if conn.Status == Status_Connecting || conn.Status == Status_Connected {
|
||||||
|
connectAllowed = false
|
||||||
|
} else {
|
||||||
|
conn.Status = Status_Connecting
|
||||||
|
conn.Error = ""
|
||||||
|
connectAllowed = true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
log.Printf("Connect %s\n", conn.GetName())
|
||||||
|
if !connectAllowed {
|
||||||
|
conn.Infof(ctx, "cannot connect to %q when status is %q\n", conn.GetName(), conn.GetStatus())
|
||||||
|
return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus())
|
||||||
|
}
|
||||||
|
conn.FireConnChangeEvent()
|
||||||
|
err := conn.connectInternal(ctx)
|
||||||
|
conn.WithLock(func() {
|
||||||
|
if err != nil {
|
||||||
|
conn.Infof(ctx, "ERROR %v\n\n", err)
|
||||||
|
conn.Status = Status_Error
|
||||||
|
conn.Error = err.Error()
|
||||||
|
conn.close_nolock()
|
||||||
|
telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{
|
||||||
|
Conn: map[string]int{"wsl:connecterror": 1},
|
||||||
|
}, "wsl-connconnect")
|
||||||
|
} else {
|
||||||
|
conn.Infof(ctx, "successfully connected (wsh:%v)\n\n", conn.WshEnabled.Load())
|
||||||
|
conn.Status = Status_Connected
|
||||||
|
conn.LastConnectTime = time.Now().UnixMilli()
|
||||||
|
if conn.ActiveConnNum == 0 {
|
||||||
|
conn.ActiveConnNum = int(activeConnCounter.Add(1))
|
||||||
|
}
|
||||||
|
telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{
|
||||||
|
Conn: map[string]int{"wsl:connect": 1},
|
||||||
|
}, "wsl-connconnect")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
conn.FireConnChangeEvent()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) WithLock(fn func()) {
|
||||||
|
conn.Lock.Lock()
|
||||||
|
defer conn.Lock.Unlock()
|
||||||
|
fn()
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithLockRtn[T any](conn *WslConn, fn func() T) T {
|
||||||
|
conn.Lock.Lock()
|
||||||
|
defer conn.Lock.Unlock()
|
||||||
|
return fn()
|
||||||
|
}
|
||||||
|
|
||||||
|
// returns (enable-wsh, ask-before-install)
|
||||||
|
func (conn *WslConn) getConnWshSettings() (bool, bool) {
|
||||||
|
config := wconfig.GetWatcher().GetFullConfig()
|
||||||
|
enableWsh := config.Settings.ConnWshEnabled
|
||||||
|
askBeforeInstall := wconfig.DefaultBoolPtr(config.Settings.ConnAskBeforeWshInstall, true)
|
||||||
|
connSettings, ok := conn.getConnectionConfig()
|
||||||
|
if ok {
|
||||||
|
if connSettings.ConnWshEnabled != nil {
|
||||||
|
enableWsh = *connSettings.ConnWshEnabled
|
||||||
|
}
|
||||||
|
// if the connection object exists, and conn:askbeforewshinstall is not set, the user must have allowed it
|
||||||
|
// TODO: in v0.12+ this should be removed. we'll explicitly write a "false" into the connection object on successful connection
|
||||||
|
if connSettings.ConnAskBeforeWshInstall == nil {
|
||||||
|
askBeforeInstall = false
|
||||||
|
} else {
|
||||||
|
askBeforeInstall = *connSettings.ConnAskBeforeWshInstall
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return enableWsh, askBeforeInstall
|
||||||
|
}
|
||||||
|
|
||||||
|
type WshCheckResult struct {
|
||||||
|
WshEnabled bool
|
||||||
|
ClientVersion string
|
||||||
|
NoWshReason string
|
||||||
|
WshError error
|
||||||
|
}
|
||||||
|
|
||||||
|
// returns (wsh-enabled, clientVersion, text-reason, wshError)
|
||||||
|
func (conn *WslConn) tryEnableWsh(ctx context.Context, clientDisplayName string) WshCheckResult {
|
||||||
|
conn.Infof(ctx, "running tryEnableWsh...\n")
|
||||||
|
enableWsh, askBeforeInstall := conn.getConnWshSettings()
|
||||||
|
conn.Infof(ctx, "wsh settings enable:%v ask:%v\n", enableWsh, askBeforeInstall)
|
||||||
|
if !enableWsh {
|
||||||
|
return WshCheckResult{NoWshReason: "conn:wshenabled set to false"}
|
||||||
|
}
|
||||||
|
if askBeforeInstall {
|
||||||
|
allowInstall, err := conn.getPermissionToInstallWsh(ctx, clientDisplayName)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("error getting permission to install wsh: %v\n", err)
|
||||||
|
return WshCheckResult{NoWshReason: "error getting user permission to install", WshError: err}
|
||||||
|
}
|
||||||
|
if !allowInstall {
|
||||||
|
return WshCheckResult{NoWshReason: "user selected not to install wsh extensions"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err := conn.OpenDomainSocketListener(ctx)
|
||||||
|
if err != nil {
|
||||||
|
conn.Infof(ctx, "ERROR opening domain socket listener: %v\n", err)
|
||||||
|
err = fmt.Errorf("error opening domain socket listener: %w", err)
|
||||||
|
return WshCheckResult{NoWshReason: "error opening domain socket", WshError: err}
|
||||||
|
}
|
||||||
|
needsInstall, clientVersion, osArchStr, err := conn.StartConnServer(ctx, false)
|
||||||
|
if err != nil {
|
||||||
|
conn.Infof(ctx, "ERROR starting conn server: %v\n", err)
|
||||||
|
err = fmt.Errorf("error starting conn server: %w", err)
|
||||||
|
return WshCheckResult{NoWshReason: "error starting connserver", WshError: err}
|
||||||
|
}
|
||||||
|
if needsInstall {
|
||||||
|
conn.Infof(ctx, "connserver needs to be (re)installed\n")
|
||||||
|
err = conn.InstallWsh(ctx, osArchStr)
|
||||||
|
if err != nil {
|
||||||
|
conn.Infof(ctx, "ERROR installing wsh: %v\n", err)
|
||||||
|
err = fmt.Errorf("error installing wsh: %w", err)
|
||||||
|
return WshCheckResult{NoWshReason: "error installing wsh/connserver", WshError: err}
|
||||||
|
}
|
||||||
|
needsInstall, clientVersion, _, err = conn.StartConnServer(ctx, true)
|
||||||
|
if err != nil {
|
||||||
|
conn.Infof(ctx, "ERROR starting conn server (after install): %v\n", err)
|
||||||
|
err = fmt.Errorf("error starting conn server (after install): %w", err)
|
||||||
|
return WshCheckResult{NoWshReason: "error starting connserver", WshError: err}
|
||||||
|
}
|
||||||
|
if needsInstall {
|
||||||
|
conn.Infof(ctx, "conn server not installed correctly (after install)\n")
|
||||||
|
err = fmt.Errorf("conn server not installed correctly (after install)")
|
||||||
|
return WshCheckResult{NoWshReason: "connserver not installed properly", WshError: err}
|
||||||
|
}
|
||||||
|
return WshCheckResult{WshEnabled: true, ClientVersion: clientVersion}
|
||||||
|
} else {
|
||||||
|
return WshCheckResult{WshEnabled: true, ClientVersion: clientVersion}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) getConnectionConfig() (wshrpc.ConnKeywords, bool) {
|
||||||
|
config := wconfig.GetWatcher().GetFullConfig()
|
||||||
|
connSettings, ok := config.Connections[conn.GetName()]
|
||||||
|
if !ok {
|
||||||
|
return wshrpc.ConnKeywords{}, false
|
||||||
|
}
|
||||||
|
return connSettings, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) persistWshInstalled(ctx context.Context, result WshCheckResult) {
|
||||||
|
conn.WshEnabled.Store(result.WshEnabled)
|
||||||
|
conn.SetWshError(result.WshError)
|
||||||
|
conn.WithLock(func() {
|
||||||
|
conn.NoWshReason = result.NoWshReason
|
||||||
|
conn.WshVersion = result.ClientVersion
|
||||||
|
})
|
||||||
|
connConfig, ok := conn.getConnectionConfig()
|
||||||
|
if ok && connConfig.ConnWshEnabled != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
meta := make(map[string]any)
|
||||||
|
meta["conn:wshenabled"] = result.WshEnabled
|
||||||
|
err := wconfig.SetConnectionsConfigValue(conn.GetName(), meta)
|
||||||
|
if err != nil {
|
||||||
|
conn.Infof(ctx, "WARN could not write conn:wshenabled=%v to connections.json: %v\n", result.WshEnabled, err)
|
||||||
|
log.Printf("warning: error writing to connections file: %v", err)
|
||||||
|
}
|
||||||
|
// doesn't return an error since none of this is required for connection to work
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) connectInternal(ctx context.Context) error {
|
||||||
|
conn.Infof(ctx, "connectInternal %s\n", conn.GetName())
|
||||||
|
client, err := wsl.GetDistro(ctx, conn.Name)
|
||||||
|
if err != nil {
|
||||||
|
conn.Infof(ctx, "ERROR GetDistro: %s\n", err)
|
||||||
|
log.Printf("error: failed to get distro %s: %s\n", conn.GetName(), err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
conn.WithLock(func() {
|
||||||
|
conn.Client = client
|
||||||
|
})
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
panichandler.PanicHandler("wsl-waitForDisconnect", recover())
|
||||||
|
}()
|
||||||
|
conn.waitForDisconnect()
|
||||||
|
}()
|
||||||
|
wshResult := conn.tryEnableWsh(ctx, conn.GetName())
|
||||||
|
if !wshResult.WshEnabled {
|
||||||
|
if wshResult.WshError != nil {
|
||||||
|
conn.Infof(ctx, "ERROR enabling wsh: %v\n", wshResult.WshError)
|
||||||
|
conn.Infof(ctx, "will connect with wsh disabled\n")
|
||||||
|
} else {
|
||||||
|
conn.Infof(ctx, "wsh not enabled: %s\n", wshResult.NoWshReason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
conn.persistWshInstalled(ctx, wshResult)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) waitForDisconnect() {
|
||||||
|
log.Printf("wait for disconnect in %+#v", conn)
|
||||||
|
defer conn.FireConnChangeEvent()
|
||||||
|
defer conn.HasWaiter.Store(false)
|
||||||
|
err := conn.ConnController.Wait()
|
||||||
|
conn.WithLock(func() {
|
||||||
|
// disconnects happen for a variety of reasons (like network, etc. and are typically transient)
|
||||||
|
// so we just set the status to "disconnected" here (not error)
|
||||||
|
// don't overwrite any existing error (or error status)
|
||||||
|
if err != nil && conn.Error == "" {
|
||||||
|
conn.Error = err.Error()
|
||||||
|
}
|
||||||
|
if conn.Status != Status_Error {
|
||||||
|
conn.Status = Status_Disconnected
|
||||||
|
}
|
||||||
|
conn.close_nolock()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) SetWshError(err error) {
|
||||||
|
conn.WithLock(func() {
|
||||||
|
if err == nil {
|
||||||
|
conn.WshError = ""
|
||||||
|
} else {
|
||||||
|
conn.WshError = err.Error()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) ClearWshError() {
|
||||||
|
conn.WithLock(func() {
|
||||||
|
conn.WshError = ""
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func getConnInternal(name string) *WslConn {
|
||||||
|
globalLock.Lock()
|
||||||
|
defer globalLock.Unlock()
|
||||||
|
connName := wsl.WslName{Distro: name}
|
||||||
|
rtn := clientControllerMap[name]
|
||||||
|
if rtn == nil {
|
||||||
|
rtn = &WslConn{Lock: &sync.Mutex{}, Status: Status_Init, Name: connName, WshEnabled: &atomic.Bool{}, HasWaiter: &atomic.Bool{}, cancelFn: nil}
|
||||||
|
clientControllerMap[name] = rtn
|
||||||
|
}
|
||||||
|
return rtn
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetWslConn(ctx context.Context, name string, shouldConnect bool) *WslConn {
|
||||||
|
conn := getConnInternal(name)
|
||||||
|
if conn.Client == nil && shouldConnect {
|
||||||
|
conn.Connect(ctx)
|
||||||
|
}
|
||||||
|
return conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convenience function for ensuring a connection is established
|
||||||
|
func EnsureConnection(ctx context.Context, connName string) error {
|
||||||
|
if connName == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
conn := GetWslConn(ctx, connName, false)
|
||||||
|
if conn == nil {
|
||||||
|
return fmt.Errorf("connection not found: %s", connName)
|
||||||
|
}
|
||||||
|
connStatus := conn.DeriveConnStatus()
|
||||||
|
switch connStatus.Status {
|
||||||
|
case Status_Connected:
|
||||||
|
return nil
|
||||||
|
case Status_Connecting:
|
||||||
|
return conn.WaitForConnect(ctx)
|
||||||
|
case Status_Init, Status_Disconnected:
|
||||||
|
return conn.Connect(ctx)
|
||||||
|
case Status_Error:
|
||||||
|
return fmt.Errorf("connection error: %s", connStatus.Error)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown connection status %q", connStatus.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func DisconnectClient(connName string) error {
|
||||||
|
conn := getConnInternal(connName)
|
||||||
|
if conn == nil {
|
||||||
|
return fmt.Errorf("client %q not found", connName)
|
||||||
|
}
|
||||||
|
err := conn.Close()
|
||||||
|
return err
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user