mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-30 23:01:30 +01:00
WSL Updates for New Architecture (#1756)
This adapts most of the WSL code to follow the new architecture that ssh uses. --------- Co-authored-by: sawka <mike@commandline.dev>
This commit is contained in:
parent
b7dca41b9c
commit
ff5f26709c
@ -35,7 +35,7 @@ import (
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshremote"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshserver"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||
"github.com/wavetermdev/waveterm/pkg/wslconn"
|
||||
"github.com/wavetermdev/waveterm/pkg/wstore"
|
||||
)
|
||||
|
||||
@ -145,7 +145,7 @@ func beforeSendActivityUpdate(ctx context.Context) {
|
||||
activity.Blocks, _ = wstore.DBGetBlockViewCounts(ctx)
|
||||
activity.NumWindows, _ = wstore.DBGetCount[*waveobj.Window](ctx)
|
||||
activity.NumSSHConn = conncontroller.GetNumSSHHasConnected()
|
||||
activity.NumWSLConn = wsl.GetNumWSLHasConnected()
|
||||
activity.NumWSLConn = wslconn.GetNumWSLHasConnected()
|
||||
activity.NumWSNamed, activity.NumWS, _ = wstore.DBGetWSCounts(ctx)
|
||||
err := telemetry.UpdateActivity(ctx, activity)
|
||||
if err != nil {
|
||||
|
@ -33,7 +33,7 @@ import (
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||
"github.com/wavetermdev/waveterm/pkg/wslconn"
|
||||
"github.com/wavetermdev/waveterm/pkg/wstore"
|
||||
)
|
||||
|
||||
@ -369,7 +369,7 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc
|
||||
credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
wslConn := wsl.GetWslConn(credentialCtx, wslName, false)
|
||||
wslConn := wslconn.GetWslConn(credentialCtx, wslName, false)
|
||||
connStatus := wslConn.DeriveConnStatus()
|
||||
if connStatus.Status != conncontroller.Status_Connected {
|
||||
return nil, fmt.Errorf("not connected, cannot start shellproc")
|
||||
@ -377,10 +377,14 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc
|
||||
|
||||
// create jwt
|
||||
if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) {
|
||||
jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: wslConn.GetName()}, wslConn.GetDomainSocketName())
|
||||
sockName := wslConn.GetDomainSocketName()
|
||||
rpcContext := wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: wslConn.GetName()}
|
||||
jwtStr, err := wshutil.MakeClientJWTToken(rpcContext, sockName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making jwt token: %w", err)
|
||||
}
|
||||
swapToken.SockName = sockName
|
||||
swapToken.RpcContext = &rpcContext
|
||||
swapToken.Env[wshutil.WaveJwtTokenVarName] = jwtStr
|
||||
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
|
||||
}
|
||||
@ -747,7 +751,7 @@ func CheckConnStatus(blockId string) error {
|
||||
}
|
||||
if strings.HasPrefix(connName, "wsl://") {
|
||||
distroName := strings.TrimPrefix(connName, "wsl://")
|
||||
conn := wsl.GetWslConn(context.Background(), distroName, false)
|
||||
conn := wslconn.GetWslConn(context.Background(), distroName, false)
|
||||
connStatus := conn.DeriveConnStatus()
|
||||
if connStatus.Status != conncontroller.Status_Connected {
|
||||
return fmt.Errorf("not connected: %s", connStatus.Status)
|
||||
|
@ -1,25 +1,24 @@
|
||||
//go:build windows
|
||||
|
||||
// Copyright 2025, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package genconn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/ubuntu/gowsl"
|
||||
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||
)
|
||||
|
||||
var _ ShellClient = (*WSLShellClient)(nil)
|
||||
|
||||
type WSLShellClient struct {
|
||||
distro *gowsl.Distro
|
||||
distro *wsl.Distro
|
||||
}
|
||||
|
||||
func MakeWSLShellClient(distro *gowsl.Distro) *WSLShellClient {
|
||||
func MakeWSLShellClient(distro *wsl.Distro) *WSLShellClient {
|
||||
return &WSLShellClient{distro: distro}
|
||||
}
|
||||
|
||||
@ -28,8 +27,8 @@ func (c *WSLShellClient) MakeProcessController(cmdSpec CommandSpec) (ShellProces
|
||||
}
|
||||
|
||||
type WSLProcessController struct {
|
||||
distro *gowsl.Distro
|
||||
cmd *gowsl.Cmd
|
||||
distro *wsl.Distro
|
||||
cmd *wsl.WslCmd
|
||||
lock *sync.Mutex
|
||||
once *sync.Once
|
||||
stdinPiped bool
|
||||
@ -40,13 +39,13 @@ type WSLProcessController struct {
|
||||
cmdSpec CommandSpec
|
||||
}
|
||||
|
||||
func MakeWSLProcessController(distro *gowsl.Distro, cmdSpec CommandSpec) (*WSLProcessController, error) {
|
||||
func MakeWSLProcessController(distro *wsl.Distro, cmdSpec CommandSpec) (*WSLProcessController, error) {
|
||||
fullCmd, err := BuildShellCommand(cmdSpec)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build shell command: %w", err)
|
||||
}
|
||||
|
||||
cmd := distro.Command(nil, fullCmd)
|
||||
cmd := distro.WslCommand(context.Background(), fullCmd)
|
||||
if cmd == nil {
|
||||
return nil, fmt.Errorf("failed to create WSL command")
|
||||
}
|
||||
@ -87,9 +86,14 @@ func (w *WSLProcessController) Kill() {
|
||||
w.lock.Lock()
|
||||
defer w.lock.Unlock()
|
||||
|
||||
if w.cmd != nil && w.cmd.Process != nil {
|
||||
w.cmd.Process.Kill()
|
||||
if w.cmd == nil {
|
||||
return
|
||||
}
|
||||
process := w.cmd.GetProcess()
|
||||
if process == nil {
|
||||
return
|
||||
}
|
||||
process.Kill()
|
||||
}
|
||||
|
||||
func (w *WSLProcessController) StdinPipe() (io.WriteCloser, error) {
|
||||
|
@ -308,12 +308,13 @@ func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool) (boo
|
||||
if err != nil {
|
||||
return false, "", "", fmt.Errorf("unable to start conn controller command: %w", err)
|
||||
}
|
||||
linesChan := wshutil.StreamToLinesChan(pipeRead)
|
||||
versionLine, err := wshutil.ReadLineWithTimeout(linesChan, 2*time.Second)
|
||||
linesChan := utilfn.StreamToLinesChan(pipeRead)
|
||||
versionLine, err := utilfn.ReadLineWithTimeout(linesChan, 2*time.Second)
|
||||
if err != nil {
|
||||
sshSession.Close()
|
||||
return false, "", "", fmt.Errorf("error reading wsh version: %w", err)
|
||||
}
|
||||
conn.Infof(ctx, "actual connnserverversion: %q\n", versionLine)
|
||||
conn.Infof(ctx, "got connserver version: %s\n", strings.TrimSpace(versionLine))
|
||||
isUpToDate, clientVersion, osArchStr, err := IsWshVersionUpToDate(ctx, versionLine)
|
||||
if err != nil {
|
||||
@ -326,11 +327,10 @@ func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool) (boo
|
||||
}
|
||||
conn.Infof(ctx, "connserver up-to-date: %v\n", isUpToDate)
|
||||
if !isUpToDate {
|
||||
|
||||
sshSession.Close()
|
||||
return true, clientVersion, osArchStr, nil
|
||||
}
|
||||
jwtLine, err := wshutil.ReadLineWithTimeout(linesChan, 3*time.Second)
|
||||
jwtLine, err := utilfn.ReadLineWithTimeout(linesChan, 3*time.Second)
|
||||
if err != nil {
|
||||
sshSession.Close()
|
||||
return false, clientVersion, "", fmt.Errorf("error reading jwt status line: %w", err)
|
||||
@ -401,12 +401,6 @@ type WshInstallOpts struct {
|
||||
NoUserPrompt bool
|
||||
}
|
||||
|
||||
type WshInstallSkipError struct{}
|
||||
|
||||
func (wise *WshInstallSkipError) Error() string {
|
||||
return "skipping wsh installation"
|
||||
}
|
||||
|
||||
var queryTextTemplate = strings.TrimSpace(`
|
||||
Wave requires Wave Shell Extensions to be
|
||||
installed on %q
|
||||
@ -555,7 +549,7 @@ func (conn *SSHConn) Connect(ctx context.Context, connFlags *wshrpc.ConnKeywords
|
||||
}
|
||||
})
|
||||
if !connectAllowed {
|
||||
conn.Infof(ctx, "cannot connect to when status is %q\n", conn.GetStatus())
|
||||
conn.Infof(ctx, "cannot connect to %q when status is %q\n", conn.GetName(), conn.GetStatus())
|
||||
return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus())
|
||||
}
|
||||
conn.Infof(ctx, "trying to connect to %q...\n", conn.GetName())
|
||||
@ -754,7 +748,12 @@ func (conn *SSHConn) connectInternal(ctx context.Context, connFlags *wshrpc.Conn
|
||||
conn.WithLock(func() {
|
||||
conn.Client = client
|
||||
})
|
||||
go conn.waitForDisconnect()
|
||||
go func() {
|
||||
defer func() {
|
||||
panichandler.PanicHandler("conncontroller:waitForDisconnect", recover())
|
||||
}()
|
||||
conn.waitForDisconnect()
|
||||
}()
|
||||
fmtAddr := knownhosts.Normalize(fmt.Sprintf("%s@%s", client.User(), client.RemoteAddr().String()))
|
||||
conn.Infof(ctx, "normalized knownhosts address: %s\n", fmtAddr)
|
||||
clientDisplayName := fmt.Sprintf("%s (%s)", conn.GetName(), fmtAddr)
|
||||
|
@ -15,7 +15,7 @@ import (
|
||||
"github.com/wavetermdev/waveterm/pkg/wconfig"
|
||||
"github.com/wavetermdev/waveterm/pkg/wcore"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||
"github.com/wavetermdev/waveterm/pkg/wslconn"
|
||||
"github.com/wavetermdev/waveterm/pkg/wstore"
|
||||
)
|
||||
|
||||
@ -42,7 +42,7 @@ func (cs *ClientService) GetTab(tabId string) (*waveobj.Tab, error) {
|
||||
|
||||
func (cs *ClientService) GetAllConnStatus(ctx context.Context) ([]wshrpc.ConnStatus, error) {
|
||||
sshStatuses := conncontroller.GetAllConnStatus()
|
||||
wslStatuses := wsl.GetAllConnStatus()
|
||||
wslStatuses := wslconn.GetAllConnStatus()
|
||||
return append(sshStatuses, wslStatuses...), nil
|
||||
}
|
||||
|
||||
|
@ -21,7 +21,6 @@ import (
|
||||
"github.com/creack/pty"
|
||||
"github.com/wavetermdev/waveterm/pkg/blocklogger"
|
||||
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
||||
"github.com/wavetermdev/waveterm/pkg/remote"
|
||||
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/pamparse"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
|
||||
@ -30,7 +29,7 @@ import (
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||
"github.com/wavetermdev/waveterm/pkg/wslconn"
|
||||
)
|
||||
|
||||
const DefaultGracefulKillWait = 400 * time.Millisecond
|
||||
@ -151,85 +150,100 @@ func (pp *PipePty) WriteString(s string) (n int, err error) {
|
||||
return pp.Write([]byte(s))
|
||||
}
|
||||
|
||||
func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wsl.WslConn) (*ShellProc, error) {
|
||||
utilCtx, cancelFn := context.WithTimeout(ctx, 2*time.Second)
|
||||
defer cancelFn()
|
||||
func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wslconn.WslConn) (*ShellProc, error) {
|
||||
client := conn.GetClient()
|
||||
shellPath := cmdOpts.ShellPath
|
||||
conn.Infof(ctx, "WSL-NEWSESSION (StartWslShellProc)")
|
||||
connRoute := wshutil.MakeConnectionRouteId(conn.GetName())
|
||||
rpcClient := wshclient.GetBareRpcClient()
|
||||
remoteInfo, err := wshclient.RemoteGetInfoCommand(rpcClient, &wshrpc.RpcOpts{Route: connRoute, Timeout: 2000})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to obtain client info: %w", err)
|
||||
}
|
||||
log.Printf("client info collected: %+#v", remoteInfo)
|
||||
var shellPath string
|
||||
if cmdOpts.ShellPath != "" {
|
||||
conn.Infof(ctx, "using shell path from command opts: %s\n", cmdOpts.ShellPath)
|
||||
shellPath = cmdOpts.ShellPath
|
||||
}
|
||||
configShellPath := conn.GetConfigShellPath()
|
||||
if shellPath == "" && configShellPath != "" {
|
||||
conn.Infof(ctx, "using shell path from config (conn:shellpath): %s\n", configShellPath)
|
||||
shellPath = configShellPath
|
||||
}
|
||||
if shellPath == "" && remoteInfo.Shell != "" {
|
||||
conn.Infof(ctx, "using shell path detected on remote machine: %s\n", remoteInfo.Shell)
|
||||
shellPath = remoteInfo.Shell
|
||||
}
|
||||
if shellPath == "" {
|
||||
remoteShellPath, err := wsl.DetectShell(utilCtx, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shellPath = remoteShellPath
|
||||
conn.Infof(ctx, "no shell path detected, using default (/bin/bash)\n")
|
||||
shellPath = "/bin/bash"
|
||||
}
|
||||
var shellOpts []string
|
||||
log.Printf("detected shell: %s", shellPath)
|
||||
var cmdCombined string
|
||||
log.Printf("detected shell %q for conn %q\n", shellPath, conn.GetName())
|
||||
|
||||
err := wsl.InstallClientRcFiles(utilCtx, client)
|
||||
err = wshclient.RemoteInstallRcFilesCommand(rpcClient, &wshrpc.RpcOpts{Route: connRoute, Timeout: 2000})
|
||||
if err != nil {
|
||||
log.Printf("error installing rc files: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
homeDir := wsl.GetHomeDir(utilCtx, client)
|
||||
shellOpts = append(shellOpts, "~", "-d", client.Name())
|
||||
|
||||
var subShellOpts []string
|
||||
shellOpts = append(shellOpts, cmdOpts.ShellOpts...)
|
||||
shellType := shellutil.GetShellTypeFromShellPath(shellPath)
|
||||
conn.Infof(ctx, "detected shell type: %s\n", shellType)
|
||||
|
||||
if cmdStr == "" {
|
||||
/* transform command in order to inject environment vars */
|
||||
if isBashShell(shellPath) {
|
||||
log.Printf("recognized as bash shell")
|
||||
if shellType == shellutil.ShellType_bash {
|
||||
// add --rcfile
|
||||
// cant set -l or -i with --rcfile
|
||||
subShellOpts = append(subShellOpts, "--rcfile", fmt.Sprintf(`%s/.waveterm/%s/.bashrc`, homeDir, shellutil.BashIntegrationDir))
|
||||
} else if isFishShell(shellPath) {
|
||||
carg := fmt.Sprintf(`"set -x PATH \"%s\"/.waveterm/%s $PATH"`, homeDir, shellutil.WaveHomeBinDir)
|
||||
subShellOpts = append(subShellOpts, "-C", carg)
|
||||
} else if wsl.IsPowershell(shellPath) {
|
||||
bashPath := fmt.Sprintf("~/.waveterm/%s/.bashrc", shellutil.BashIntegrationDir)
|
||||
shellOpts = append(shellOpts, "--rcfile", bashPath)
|
||||
} else if shellType == shellutil.ShellType_fish {
|
||||
if cmdOpts.Login {
|
||||
shellOpts = append(shellOpts, "-l")
|
||||
}
|
||||
// source the wave.fish file
|
||||
waveFishPath := fmt.Sprintf("~/.waveterm/%s/wave.fish", shellutil.FishIntegrationDir)
|
||||
carg := fmt.Sprintf(`"source %s"`, waveFishPath)
|
||||
shellOpts = append(shellOpts, "-C", carg)
|
||||
} else if shellType == shellutil.ShellType_pwsh {
|
||||
pwshPath := fmt.Sprintf("~/.waveterm/%s/wavepwsh.ps1", shellutil.PwshIntegrationDir)
|
||||
// powershell is weird about quoted path executables and requires an ampersand first
|
||||
shellPath = "& " + shellPath
|
||||
subShellOpts = append(subShellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", fmt.Sprintf("%s/.waveterm/%s/wavepwsh.ps1", homeDir, shellutil.PwshIntegrationDir))
|
||||
shellOpts = append(shellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", pwshPath)
|
||||
} else {
|
||||
if cmdOpts.Login {
|
||||
subShellOpts = append(subShellOpts, "-l")
|
||||
shellOpts = append(shellOpts, "-l")
|
||||
}
|
||||
if cmdOpts.Interactive {
|
||||
subShellOpts = append(subShellOpts, "-i")
|
||||
shellOpts = append(shellOpts, "-i")
|
||||
}
|
||||
// can't set environment vars this way
|
||||
// will try to do later if possible
|
||||
// zdotdir setting moved to after session is created
|
||||
}
|
||||
cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
|
||||
} else {
|
||||
// TODO check quoting of cmdStr
|
||||
shellPath = cmdStr
|
||||
if cmdOpts.Login {
|
||||
subShellOpts = append(subShellOpts, "-l")
|
||||
}
|
||||
if cmdOpts.Interactive {
|
||||
subShellOpts = append(subShellOpts, "-i")
|
||||
}
|
||||
subShellOpts = append(subShellOpts, "-c", cmdStr)
|
||||
shellOpts = append(shellOpts, "-c", cmdStr)
|
||||
cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
|
||||
}
|
||||
conn.Infof(ctx, "starting shell, using command: %s\n", cmdCombined)
|
||||
conn.Infof(ctx, "WSL-NEWSESSION (StartWslShellProc)\n")
|
||||
|
||||
if shellType == shellutil.ShellType_zsh {
|
||||
zshDir := fmt.Sprintf("~/.waveterm/%s", shellutil.ZshIntegrationDir)
|
||||
conn.Infof(ctx, "setting ZDOTDIR to %s\n", zshDir)
|
||||
cmdCombined = fmt.Sprintf(`ZDOTDIR=%s %s`, zshDir, cmdCombined)
|
||||
}
|
||||
|
||||
jwtToken, ok := cmdOpts.Env[wshutil.WaveJwtTokenVarName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no jwt token provided to connection")
|
||||
}
|
||||
if remote.IsPowershell(shellPath) {
|
||||
shellOpts = append(shellOpts, "--", fmt.Sprintf(`$env:%s=%s;`, wshutil.WaveJwtTokenVarName, jwtToken))
|
||||
} else {
|
||||
shellOpts = append(shellOpts, "--", fmt.Sprintf(`%s=%s`, wshutil.WaveJwtTokenVarName, jwtToken))
|
||||
}
|
||||
cmdCombined = fmt.Sprintf(`%s=%s %s`, wshutil.WaveJwtTokenVarName, jwtToken, cmdCombined)
|
||||
|
||||
if isZshShell(shellPath) {
|
||||
shellOpts = append(shellOpts, fmt.Sprintf(`ZDOTDIR=%s/.waveterm/%s`, homeDir, shellutil.ZshIntegrationDir))
|
||||
}
|
||||
shellOpts = append(shellOpts, shellPath)
|
||||
shellOpts = append(shellOpts, subShellOpts...)
|
||||
log.Printf("full cmd is: %s %s", "wsl.exe", strings.Join(shellOpts, " "))
|
||||
|
||||
ecmd := exec.Command("wsl.exe", shellOpts...)
|
||||
log.Printf("full combined command: %s", cmdCombined)
|
||||
ecmd := exec.Command("wsl.exe", "~", "-d", client.Name(), "--", "sh", "-c", cmdCombined)
|
||||
if termSize.Rows == 0 || termSize.Cols == 0 {
|
||||
termSize.Rows = shellutil.DefaultTermRows
|
||||
termSize.Cols = shellutil.DefaultTermCols
|
||||
@ -237,6 +251,7 @@ func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr st
|
||||
if termSize.Rows <= 0 || termSize.Cols <= 0 {
|
||||
return nil, fmt.Errorf("invalid term size: %v", termSize)
|
||||
}
|
||||
shellutil.AddTokenSwapEntry(cmdOpts.SwapToken)
|
||||
cmdPty, err := pty.StartWithSize(ecmd, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -8,6 +8,9 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
|
||||
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
||||
)
|
||||
|
||||
type PacketParser struct {
|
||||
@ -15,11 +18,38 @@ type PacketParser struct {
|
||||
Ch chan []byte
|
||||
}
|
||||
|
||||
func ParseWithLinesChan(input chan utilfn.LineOutput, packetCh chan []byte, rawCh chan []byte) {
|
||||
defer close(packetCh)
|
||||
defer close(rawCh)
|
||||
for {
|
||||
// note this line doesn't have a trailing newline
|
||||
line, ok := <-input
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if line.Error != nil {
|
||||
log.Printf("ParseWithLinesChan: error reading line: %v", line.Error)
|
||||
return
|
||||
}
|
||||
if len(line.Line) <= 1 {
|
||||
// just a blank line
|
||||
continue
|
||||
}
|
||||
if bytes.HasPrefix([]byte(line.Line), []byte{'#', '#', 'N', '{'}) && bytes.HasSuffix([]byte(line.Line), []byte{'}'}) {
|
||||
// strip off the leading "##"
|
||||
packetCh <- []byte(line.Line[3:len(line.Line)])
|
||||
} else {
|
||||
rawCh <- []byte(line.Line)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Parse(input io.Reader, packetCh chan []byte, rawCh chan []byte) error {
|
||||
bufReader := bufio.NewReader(input)
|
||||
defer close(packetCh)
|
||||
defer close(rawCh)
|
||||
for {
|
||||
// note this line does have a trailing newline
|
||||
line, err := bufReader.ReadBytes('\n')
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
|
@ -7,7 +7,6 @@ import "regexp"
|
||||
|
||||
var (
|
||||
safePattern = regexp.MustCompile(`^[a-zA-Z0-9_/.-]+$`)
|
||||
psSafePattern = regexp.MustCompile(`^[a-zA-Z0-9_.-]+$`)
|
||||
envVarNamePattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`)
|
||||
)
|
||||
|
||||
@ -73,10 +72,6 @@ func HardQuotePowerShell(s string) string {
|
||||
return "\"\""
|
||||
}
|
||||
|
||||
if psSafePattern.MatchString(s) {
|
||||
return s
|
||||
}
|
||||
|
||||
buf := make([]byte, 0, len(s)+5)
|
||||
buf = append(buf, '"')
|
||||
|
||||
|
@ -153,7 +153,7 @@ wsh completion fish | source
|
||||
$env:PATH = {{.WSHBINDIR_PWSH}} + "{{.PATHSEP}}" + $env:PATH
|
||||
|
||||
# Source dynamic script from wsh token
|
||||
$waveterm_swaptoken_output = wsh token $env:WAVETERM_SWAPTOKEN pwsh 2>$null
|
||||
$waveterm_swaptoken_output = wsh token $env:WAVETERM_SWAPTOKEN pwsh 2>$null | Out-String
|
||||
if ($waveterm_swaptoken_output -and $waveterm_swaptoken_output -ne "") {
|
||||
Invoke-Expression $waveterm_swaptoken_output
|
||||
}
|
||||
|
85
pkg/util/utilfn/streamtolines.go
Normal file
85
pkg/util/utilfn/streamtolines.go
Normal file
@ -0,0 +1,85 @@
|
||||
// Copyright 2025, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package utilfn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
type LineOutput struct {
|
||||
Line string
|
||||
Error error
|
||||
}
|
||||
|
||||
type lineBuf struct {
|
||||
buf []byte
|
||||
inLongLine bool
|
||||
}
|
||||
|
||||
const maxLineLength = 128 * 1024
|
||||
|
||||
func ReadLineWithTimeout(ch chan LineOutput, timeout time.Duration) (string, error) {
|
||||
select {
|
||||
case output := <-ch:
|
||||
if output.Error != nil {
|
||||
return "", output.Error
|
||||
}
|
||||
return output.Line, nil
|
||||
case <-time.After(timeout):
|
||||
return "", context.DeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
func streamToLines_processBuf(lineBuf *lineBuf, readBuf []byte, lineFn func([]byte)) {
|
||||
for len(readBuf) > 0 {
|
||||
nlIdx := bytes.IndexByte(readBuf, '\n')
|
||||
if nlIdx == -1 {
|
||||
if lineBuf.inLongLine || len(lineBuf.buf)+len(readBuf) > maxLineLength {
|
||||
lineBuf.buf = nil
|
||||
lineBuf.inLongLine = true
|
||||
return
|
||||
}
|
||||
lineBuf.buf = append(lineBuf.buf, readBuf...)
|
||||
return
|
||||
}
|
||||
if !lineBuf.inLongLine && len(lineBuf.buf)+nlIdx <= maxLineLength {
|
||||
line := append(lineBuf.buf, readBuf[:nlIdx]...)
|
||||
lineFn(line)
|
||||
}
|
||||
lineBuf.buf = nil
|
||||
lineBuf.inLongLine = false
|
||||
readBuf = readBuf[nlIdx+1:]
|
||||
}
|
||||
}
|
||||
|
||||
func StreamToLines(input io.Reader, lineFn func([]byte)) error {
|
||||
var lineBuf lineBuf
|
||||
readBuf := make([]byte, 16*1024)
|
||||
for {
|
||||
n, err := input.Read(readBuf)
|
||||
streamToLines_processBuf(&lineBuf, readBuf[:n], lineFn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// starts a goroutine to drive the channel
|
||||
// line output does not include the trailing newline
|
||||
func StreamToLinesChan(input io.Reader) chan LineOutput {
|
||||
ch := make(chan LineOutput)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamToLines(input, func(line []byte) {
|
||||
ch <- LineOutput{Line: string(line)}
|
||||
})
|
||||
if err != nil && err != io.EOF {
|
||||
ch <- LineOutput{Error: err}
|
||||
}
|
||||
}()
|
||||
return ch
|
||||
}
|
@ -37,6 +37,7 @@ import (
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||
"github.com/wavetermdev/waveterm/pkg/wslconn"
|
||||
"github.com/wavetermdev/waveterm/pkg/wstore"
|
||||
)
|
||||
|
||||
@ -609,7 +610,7 @@ func (ws *WshServer) ConnStatusCommand(ctx context.Context) ([]wshrpc.ConnStatus
|
||||
}
|
||||
|
||||
func (ws *WshServer) WslStatusCommand(ctx context.Context) ([]wshrpc.ConnStatus, error) {
|
||||
rtn := wsl.GetAllConnStatus()
|
||||
rtn := wslconn.GetAllConnStatus()
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
@ -633,7 +634,7 @@ func (ws *WshServer) ConnEnsureCommand(ctx context.Context, data wshrpc.ConnExtD
|
||||
ctx = termCtxWithLogBlockId(ctx, data.LogBlockId)
|
||||
if strings.HasPrefix(data.ConnName, "wsl://") {
|
||||
distroName := strings.TrimPrefix(data.ConnName, "wsl://")
|
||||
return wsl.EnsureConnection(ctx, distroName)
|
||||
return wslconn.EnsureConnection(ctx, distroName)
|
||||
}
|
||||
return conncontroller.EnsureConnection(ctx, data.ConnName)
|
||||
}
|
||||
@ -641,7 +642,7 @@ func (ws *WshServer) ConnEnsureCommand(ctx context.Context, data wshrpc.ConnExtD
|
||||
func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string) error {
|
||||
if strings.HasPrefix(connName, "wsl://") {
|
||||
distroName := strings.TrimPrefix(connName, "wsl://")
|
||||
conn := wsl.GetWslConn(ctx, distroName, false)
|
||||
conn := wslconn.GetWslConn(ctx, distroName, false)
|
||||
if conn == nil {
|
||||
return fmt.Errorf("distro not found: %s", connName)
|
||||
}
|
||||
@ -664,7 +665,7 @@ func (ws *WshServer) ConnConnectCommand(ctx context.Context, connRequest wshrpc.
|
||||
connName := connRequest.Host
|
||||
if strings.HasPrefix(connName, "wsl://") {
|
||||
distroName := strings.TrimPrefix(connName, "wsl://")
|
||||
conn := wsl.GetWslConn(ctx, distroName, false)
|
||||
conn := wslconn.GetWslConn(ctx, distroName, false)
|
||||
if conn == nil {
|
||||
return fmt.Errorf("connection not found: %s", connName)
|
||||
}
|
||||
@ -687,11 +688,11 @@ func (ws *WshServer) ConnReinstallWshCommand(ctx context.Context, data wshrpc.Co
|
||||
connName := data.ConnName
|
||||
if strings.HasPrefix(connName, "wsl://") {
|
||||
distroName := strings.TrimPrefix(connName, "wsl://")
|
||||
conn := wsl.GetWslConn(ctx, distroName, false)
|
||||
conn := wslconn.GetWslConn(ctx, distroName, false)
|
||||
if conn == nil {
|
||||
return fmt.Errorf("connection not found: %s", connName)
|
||||
}
|
||||
return conn.CheckAndInstallWsh(ctx, connName, &wsl.WshInstallOpts{Force: true, NoUserPrompt: true})
|
||||
return conn.InstallWsh(ctx, "")
|
||||
}
|
||||
connOpts, err := remote.ParseOpts(connName)
|
||||
if err != nil {
|
||||
|
@ -4,11 +4,10 @@
|
||||
package wshutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
||||
)
|
||||
|
||||
// special I/O wrappers for wshrpc
|
||||
@ -16,81 +15,8 @@ import (
|
||||
// * stream (json lines)
|
||||
// * websocket (json packets)
|
||||
|
||||
type lineBuf struct {
|
||||
buf []byte
|
||||
inLongLine bool
|
||||
}
|
||||
|
||||
const maxLineLength = 128 * 1024
|
||||
|
||||
func streamToLines_processBuf(lineBuf *lineBuf, readBuf []byte, lineFn func([]byte)) {
|
||||
for len(readBuf) > 0 {
|
||||
nlIdx := bytes.IndexByte(readBuf, '\n')
|
||||
if nlIdx == -1 {
|
||||
if lineBuf.inLongLine || len(lineBuf.buf)+len(readBuf) > maxLineLength {
|
||||
lineBuf.buf = nil
|
||||
lineBuf.inLongLine = true
|
||||
return
|
||||
}
|
||||
lineBuf.buf = append(lineBuf.buf, readBuf...)
|
||||
return
|
||||
}
|
||||
if !lineBuf.inLongLine && len(lineBuf.buf)+nlIdx <= maxLineLength {
|
||||
line := append(lineBuf.buf, readBuf[:nlIdx]...)
|
||||
lineFn(line)
|
||||
}
|
||||
lineBuf.buf = nil
|
||||
lineBuf.inLongLine = false
|
||||
readBuf = readBuf[nlIdx+1:]
|
||||
}
|
||||
}
|
||||
|
||||
func StreamToLines(input io.Reader, lineFn func([]byte)) error {
|
||||
var lineBuf lineBuf
|
||||
readBuf := make([]byte, 16*1024)
|
||||
for {
|
||||
n, err := input.Read(readBuf)
|
||||
streamToLines_processBuf(&lineBuf, readBuf[:n], lineFn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type LineOutput struct {
|
||||
Line string
|
||||
Error error
|
||||
}
|
||||
|
||||
// starts a goroutine to drive the channel
|
||||
func StreamToLinesChan(input io.Reader) chan LineOutput {
|
||||
ch := make(chan LineOutput)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamToLines(input, func(line []byte) {
|
||||
ch <- LineOutput{Line: string(line)}
|
||||
})
|
||||
if err != nil && err != io.EOF {
|
||||
ch <- LineOutput{Error: err}
|
||||
}
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
func ReadLineWithTimeout(ch chan LineOutput, timeout time.Duration) (string, error) {
|
||||
select {
|
||||
case output := <-ch:
|
||||
if output.Error != nil {
|
||||
return "", output.Error
|
||||
}
|
||||
return output.Line, nil
|
||||
case <-time.After(timeout):
|
||||
return "", context.DeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
func AdaptStreamToMsgCh(input io.Reader, output chan []byte) error {
|
||||
return StreamToLines(input, func(line []byte) {
|
||||
return utilfn.StreamToLines(input, func(line []byte) {
|
||||
output <- line
|
||||
})
|
||||
}
|
||||
|
@ -25,6 +25,7 @@ import (
|
||||
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/packetparser"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"golang.org/x/term"
|
||||
@ -418,10 +419,10 @@ type WriteFlusher interface {
|
||||
}
|
||||
|
||||
// blocking, returns if there is an error, or on EOF of input
|
||||
func HandleStdIOClient(logName string, input io.Reader, output io.Writer) {
|
||||
func HandleStdIOClient(logName string, input chan utilfn.LineOutput, output io.Writer) {
|
||||
proxy := MakeRpcMultiProxy()
|
||||
rawCh := make(chan []byte, DefaultInputChSize)
|
||||
go packetparser.Parse(input, proxy.FromRemoteRawCh, rawCh)
|
||||
go packetparser.ParseWithLinesChan(input, proxy.FromRemoteRawCh, rawCh)
|
||||
doneCh := make(chan struct{})
|
||||
var doneOnce sync.Once
|
||||
closeDoneCh := func() {
|
||||
@ -455,6 +456,9 @@ func HandleStdIOClient(logName string, input io.Reader, output io.Writer) {
|
||||
}()
|
||||
defer closeDoneCh()
|
||||
for msg := range rawCh {
|
||||
if !bytes.HasSuffix(msg, []byte{'\n'}) {
|
||||
msg = append(msg, '\n')
|
||||
}
|
||||
log.Printf("[%s:stdout] %s", logName, msg)
|
||||
}
|
||||
}()
|
||||
|
@ -13,6 +13,10 @@ import (
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
type WslName struct {
|
||||
Distro string `json:"distro"`
|
||||
}
|
||||
|
||||
func RegisteredDistros(ctx context.Context) (distros []Distro, err error) {
|
||||
return nil, fmt.Errorf("RegisteredDistros not implemented on this system")
|
||||
}
|
||||
|
@ -1,289 +0,0 @@
|
||||
// Copyright 2025, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package wsl
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||
)
|
||||
|
||||
func DetectShell(ctx context.Context, client *Distro) (string, error) {
|
||||
wshPath := GetWshPath(ctx, client)
|
||||
|
||||
cmd := client.WslCommand(ctx, wshPath+" shell")
|
||||
log.Printf("shell detecting using command: %s shell", wshPath)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
log.Printf("unable to determine shell. defaulting to /bin/bash: %s", err)
|
||||
return "/bin/bash", nil
|
||||
}
|
||||
log.Printf("detecting shell: %s", out)
|
||||
|
||||
// quoting breaks this particular case
|
||||
return strings.TrimSpace(string(out)), nil
|
||||
}
|
||||
|
||||
func GetWshVersion(ctx context.Context, client *Distro) (string, error) {
|
||||
wshPath := GetWshPath(ctx, client)
|
||||
|
||||
cmd := client.WslCommand(ctx, wshPath+" version")
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return strings.TrimSpace(string(out)), nil
|
||||
}
|
||||
|
||||
func GetWshPath(ctx context.Context, client *Distro) string {
|
||||
defaultPath := wavebase.RemoteFullWshBinPath
|
||||
|
||||
cmd := client.WslCommand(ctx, "which wsh")
|
||||
out, whichErr := cmd.Output()
|
||||
if whichErr == nil {
|
||||
return strings.TrimSpace(string(out))
|
||||
}
|
||||
|
||||
cmd = client.WslCommand(ctx, "where.exe wsh")
|
||||
out, whereErr := cmd.Output()
|
||||
if whereErr == nil {
|
||||
return strings.TrimSpace(string(out))
|
||||
}
|
||||
|
||||
// no custom install, use default path
|
||||
return defaultPath
|
||||
}
|
||||
|
||||
func hasBashInstalled(ctx context.Context, client *Distro) (bool, error) {
|
||||
cmd := client.WslCommand(ctx, "which bash")
|
||||
out, whichErr := cmd.Output()
|
||||
if whichErr == nil && len(out) != 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
cmd = client.WslCommand(ctx, "where.exe bash")
|
||||
out, whereErr := cmd.Output()
|
||||
if whereErr == nil && len(out) != 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// note: we could also check in /bin/bash explicitly
|
||||
// just in case that wasn't added to the path. but if
|
||||
// that's true, we will most likely have worse
|
||||
// problems going forward
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func GetClientOs(ctx context.Context, client *Distro) (string, error) {
|
||||
cmd := client.WslCommand(ctx, "uname -s")
|
||||
out, unixErr := cmd.CombinedOutput()
|
||||
if unixErr == nil {
|
||||
formatted := strings.ToLower(string(out))
|
||||
formatted = strings.TrimSpace(formatted)
|
||||
return formatted, nil
|
||||
}
|
||||
|
||||
cmd = client.WslCommand(ctx, "echo %OS%")
|
||||
out, cmdErr := cmd.Output()
|
||||
if cmdErr == nil && strings.TrimSpace(string(out)) != "%OS%" {
|
||||
formatted := strings.ToLower(string(out))
|
||||
formatted = strings.TrimSpace(formatted)
|
||||
return strings.Split(formatted, "_")[0], nil
|
||||
}
|
||||
|
||||
cmd = client.WslCommand(ctx, "echo $env:OS")
|
||||
out, psErr := cmd.Output()
|
||||
if psErr == nil && strings.TrimSpace(string(out)) != "$env:OS" {
|
||||
formatted := strings.ToLower(string(out))
|
||||
formatted = strings.TrimSpace(formatted)
|
||||
return strings.Split(formatted, "_")[0], nil
|
||||
}
|
||||
return "", fmt.Errorf("unable to determine os: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
|
||||
}
|
||||
|
||||
func GetClientArch(ctx context.Context, client *Distro) (string, error) {
|
||||
cmd := client.WslCommand(ctx, "uname -m")
|
||||
out, unixErr := cmd.Output()
|
||||
if unixErr == nil {
|
||||
return utilfn.FilterValidArch(string(out))
|
||||
}
|
||||
|
||||
cmd = client.WslCommand(ctx, "echo %PROCESSOR_ARCHITECTURE%")
|
||||
out, cmdErr := cmd.CombinedOutput()
|
||||
if cmdErr == nil && strings.TrimSpace(string(out)) != "%PROCESSOR_ARCHITECTURE%" {
|
||||
return utilfn.FilterValidArch(string(out))
|
||||
}
|
||||
|
||||
cmd = client.WslCommand(ctx, "echo $env:PROCESSOR_ARCHITECTURE")
|
||||
out, psErr := cmd.CombinedOutput()
|
||||
if psErr == nil && strings.TrimSpace(string(out)) != "$env:PROCESSOR_ARCHITECTURE" {
|
||||
return utilfn.FilterValidArch(string(out))
|
||||
}
|
||||
return "", fmt.Errorf("unable to determine architecture: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
|
||||
}
|
||||
|
||||
type CancellableCmd struct {
|
||||
Cmd *WslCmd
|
||||
Cancel func()
|
||||
}
|
||||
|
||||
var installTemplatesRawBash = map[string]string{
|
||||
"mkdir": `bash -c 'mkdir -p {{.installDir}}'`,
|
||||
"cat": `bash -c 'cat > {{.tempPath}}'`,
|
||||
"mv": `bash -c 'mv {{.tempPath}} {{.installPath}}'`,
|
||||
"chmod": `bash -c 'chmod a+x {{.installPath}}'`,
|
||||
}
|
||||
|
||||
var installTemplatesRawDefault = map[string]string{
|
||||
"mkdir": `mkdir -p {{.installDir}}`,
|
||||
"cat": `cat > {{.tempPath}}`,
|
||||
"mv": `mv {{.tempPath}} {{.installPath}}`,
|
||||
"chmod": `chmod a+x {{.installPath}}`,
|
||||
}
|
||||
|
||||
func makeCancellableCommand(ctx context.Context, client *Distro, cmdTemplateRaw string, words map[string]string) (*CancellableCmd, error) {
|
||||
cmdContext, cmdCancel := context.WithCancel(ctx)
|
||||
|
||||
cmdStr := &bytes.Buffer{}
|
||||
cmdTemplate, err := template.New("").Parse(cmdTemplateRaw)
|
||||
if err != nil {
|
||||
cmdCancel()
|
||||
return nil, err
|
||||
}
|
||||
cmdTemplate.Execute(cmdStr, words)
|
||||
|
||||
cmd := client.WslCommand(cmdContext, cmdStr.String())
|
||||
return &CancellableCmd{cmd, cmdCancel}, nil
|
||||
}
|
||||
|
||||
func CpHostToRemote(ctx context.Context, client *Distro, sourcePath string, destPath string) error {
|
||||
// warning: does not work on windows remote yet
|
||||
bashInstalled, err := hasBashInstalled(ctx, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var selectedTemplatesRaw map[string]string
|
||||
if bashInstalled {
|
||||
selectedTemplatesRaw = installTemplatesRawBash
|
||||
} else {
|
||||
log.Printf("bash is not installed on remote. attempting with default shell")
|
||||
selectedTemplatesRaw = installTemplatesRawDefault
|
||||
}
|
||||
|
||||
// I need to use toSlash here to force unix keybindings
|
||||
// this means we can't guarantee it will work on a remote windows machine
|
||||
var installWords = map[string]string{
|
||||
"installDir": filepath.ToSlash(filepath.Dir(destPath)),
|
||||
"tempPath": destPath + ".temp",
|
||||
"installPath": destPath,
|
||||
}
|
||||
|
||||
installStepCmds := make(map[string]*CancellableCmd)
|
||||
for cmdName, selectedTemplateRaw := range selectedTemplatesRaw {
|
||||
cancellableCmd, err := makeCancellableCommand(ctx, client, selectedTemplateRaw, installWords)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
installStepCmds[cmdName] = cancellableCmd
|
||||
}
|
||||
|
||||
_, err = installStepCmds["mkdir"].Cmd.Output()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// the cat part of this is complicated since it requires stdin
|
||||
catCmd := installStepCmds["cat"].Cmd
|
||||
catStdin, err := catCmd.StdinPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = catCmd.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input, err := os.Open(sourcePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot open local file %s to send to host: %v", sourcePath, err)
|
||||
}
|
||||
go func() {
|
||||
defer func() {
|
||||
panichandler.PanicHandler("wslutil:cpHostToRemote:catStdin", recover())
|
||||
}()
|
||||
io.Copy(catStdin, input)
|
||||
installStepCmds["cat"].Cancel()
|
||||
|
||||
// backup just in case something weird happens
|
||||
// could cause potential race condition, but very
|
||||
// unlikely
|
||||
time.Sleep(time.Second * 1)
|
||||
process := catCmd.GetProcess()
|
||||
if process != nil {
|
||||
process.Kill()
|
||||
}
|
||||
}()
|
||||
catErr := catCmd.Wait()
|
||||
if catErr != nil && !errors.Is(catErr, context.Canceled) {
|
||||
return catErr
|
||||
}
|
||||
|
||||
_, err = installStepCmds["mv"].Cmd.Output()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = installStepCmds["chmod"].Cmd.Output()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func InstallClientRcFiles(ctx context.Context, client *Distro) error {
|
||||
path := GetWshPath(ctx, client)
|
||||
log.Printf("path to wsh searched is: %s", path)
|
||||
|
||||
cmd := client.WslCommand(ctx, path+" rcfiles")
|
||||
_, err := cmd.Output()
|
||||
return err
|
||||
}
|
||||
|
||||
func GetHomeDir(ctx context.Context, client *Distro) string {
|
||||
// note: also works for powershell
|
||||
cmd := client.WslCommand(ctx, `echo "$HOME"`)
|
||||
out, err := cmd.Output()
|
||||
if err == nil {
|
||||
return strings.TrimSpace(string(out))
|
||||
}
|
||||
|
||||
cmd = client.WslCommand(ctx, `echo %userprofile%`)
|
||||
out, err = cmd.Output()
|
||||
if err == nil {
|
||||
return strings.TrimSpace(string(out))
|
||||
}
|
||||
|
||||
return "~"
|
||||
}
|
||||
|
||||
func IsPowershell(shellPath string) bool {
|
||||
// get the base path, and then check contains
|
||||
shellBase := filepath.Base(shellPath)
|
||||
return strings.Contains(shellBase, "powershell") || strings.Contains(shellBase, "pwsh")
|
||||
}
|
@ -18,6 +18,10 @@ import (
|
||||
var RegisteredDistros = gowsl.RegisteredDistros
|
||||
var DefaultDistro = gowsl.DefaultDistro
|
||||
|
||||
type WslName struct {
|
||||
Distro string `json:"distro"`
|
||||
}
|
||||
|
||||
type Distro struct {
|
||||
gowsl.Distro
|
||||
}
|
||||
|
533
pkg/wsl/wsl.go
533
pkg/wsl/wsl.go
@ -1,533 +0,0 @@
|
||||
// Copyright 2025, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package wsl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
||||
"github.com/wavetermdev/waveterm/pkg/telemetry"
|
||||
"github.com/wavetermdev/waveterm/pkg/userinput"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
|
||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||
"github.com/wavetermdev/waveterm/pkg/waveobj"
|
||||
"github.com/wavetermdev/waveterm/pkg/wconfig"
|
||||
"github.com/wavetermdev/waveterm/pkg/wps"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||
)
|
||||
|
||||
const (
|
||||
Status_Init = "init"
|
||||
Status_Connecting = "connecting"
|
||||
Status_Connected = "connected"
|
||||
Status_Disconnected = "disconnected"
|
||||
Status_Error = "error"
|
||||
)
|
||||
|
||||
const DefaultConnectionTimeout = 60 * time.Second
|
||||
|
||||
var globalLock = &sync.Mutex{}
|
||||
var clientControllerMap = make(map[string]*WslConn)
|
||||
var activeConnCounter = &atomic.Int32{}
|
||||
|
||||
type WslConn struct {
|
||||
Lock *sync.Mutex
|
||||
Status string
|
||||
Name WslName
|
||||
Client *Distro
|
||||
SockName string
|
||||
DomainSockListener net.Listener
|
||||
ConnController *WslCmd
|
||||
Error string
|
||||
HasWaiter *atomic.Bool
|
||||
LastConnectTime int64
|
||||
ActiveConnNum int
|
||||
cancelFn func()
|
||||
}
|
||||
|
||||
type WslName struct {
|
||||
Distro string `json:"distro"`
|
||||
}
|
||||
|
||||
func GetAllConnStatus() []wshrpc.ConnStatus {
|
||||
globalLock.Lock()
|
||||
defer globalLock.Unlock()
|
||||
|
||||
var connStatuses []wshrpc.ConnStatus
|
||||
for _, conn := range clientControllerMap {
|
||||
connStatuses = append(connStatuses, conn.DeriveConnStatus())
|
||||
}
|
||||
return connStatuses
|
||||
}
|
||||
|
||||
func GetNumWSLHasConnected() int {
|
||||
globalLock.Lock()
|
||||
defer globalLock.Unlock()
|
||||
|
||||
var connectedCount int
|
||||
for _, conn := range clientControllerMap {
|
||||
if conn.LastConnectTime > 0 {
|
||||
connectedCount++
|
||||
}
|
||||
}
|
||||
return connectedCount
|
||||
}
|
||||
|
||||
func (conn *WslConn) DeriveConnStatus() wshrpc.ConnStatus {
|
||||
conn.Lock.Lock()
|
||||
defer conn.Lock.Unlock()
|
||||
return wshrpc.ConnStatus{
|
||||
Status: conn.Status,
|
||||
Connected: conn.Status == Status_Connected,
|
||||
WshEnabled: true, // always use wsh for wsl connections (temporary)
|
||||
Connection: conn.GetName(),
|
||||
HasConnected: (conn.LastConnectTime > 0),
|
||||
ActiveConnNum: conn.ActiveConnNum,
|
||||
Error: conn.Error,
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *WslConn) FireConnChangeEvent() {
|
||||
status := conn.DeriveConnStatus()
|
||||
event := wps.WaveEvent{
|
||||
Event: wps.Event_ConnChange,
|
||||
Scopes: []string{
|
||||
fmt.Sprintf("connection:%s", conn.GetName()),
|
||||
},
|
||||
Data: status,
|
||||
}
|
||||
log.Printf("sending event: %+#v", event)
|
||||
wps.Broker.Publish(event)
|
||||
}
|
||||
|
||||
func (conn *WslConn) Close() error {
|
||||
defer conn.FireConnChangeEvent()
|
||||
conn.WithLock(func() {
|
||||
if conn.Status == Status_Connected || conn.Status == Status_Connecting {
|
||||
// if status is init, disconnected, or error don't change it
|
||||
conn.Status = Status_Disconnected
|
||||
}
|
||||
conn.close_nolock()
|
||||
})
|
||||
// we must wait for the waiter to complete
|
||||
startTime := time.Now()
|
||||
for conn.HasWaiter.Load() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
if time.Since(startTime) > 2*time.Second {
|
||||
return fmt.Errorf("timeout waiting for waiter to complete")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *WslConn) close_nolock() {
|
||||
// does not set status (that should happen at another level)
|
||||
if conn.DomainSockListener != nil {
|
||||
conn.DomainSockListener.Close()
|
||||
conn.DomainSockListener = nil
|
||||
}
|
||||
if conn.ConnController != nil {
|
||||
conn.cancelFn() // this suspends the conn controller
|
||||
conn.ConnController = nil
|
||||
}
|
||||
if conn.Client != nil {
|
||||
// conn.Client.Close() is not relevant here
|
||||
// we do not want to completely close the wsl in case
|
||||
// other applications are using it
|
||||
conn.Client = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *WslConn) GetDomainSocketName() string {
|
||||
conn.Lock.Lock()
|
||||
defer conn.Lock.Unlock()
|
||||
return conn.SockName
|
||||
}
|
||||
|
||||
func (conn *WslConn) GetStatus() string {
|
||||
conn.Lock.Lock()
|
||||
defer conn.Lock.Unlock()
|
||||
return conn.Status
|
||||
}
|
||||
|
||||
func (conn *WslConn) GetName() string {
|
||||
// no lock required because opts is immutable
|
||||
return "wsl://" + conn.Name.Distro
|
||||
}
|
||||
|
||||
/**
|
||||
* This function is does not set a listener for WslConn
|
||||
* It is still required in order to set SockName
|
||||
**/
|
||||
func (conn *WslConn) OpenDomainSocketListener() error {
|
||||
var allowed bool
|
||||
conn.WithLock(func() {
|
||||
if conn.Status != Status_Connecting {
|
||||
allowed = false
|
||||
} else {
|
||||
allowed = true
|
||||
}
|
||||
})
|
||||
if !allowed {
|
||||
return fmt.Errorf("cannot open domain socket for %q when status is %q", conn.GetName(), conn.GetStatus())
|
||||
}
|
||||
conn.WithLock(func() {
|
||||
conn.SockName = wavebase.RemoteFullDomainSocketPath
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *WslConn) StartConnServer() error {
|
||||
utilCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancelFn()
|
||||
var allowed bool
|
||||
conn.WithLock(func() {
|
||||
if conn.Status != Status_Connecting {
|
||||
allowed = false
|
||||
} else {
|
||||
allowed = true
|
||||
}
|
||||
})
|
||||
if !allowed {
|
||||
return fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus())
|
||||
}
|
||||
client := conn.GetClient()
|
||||
wshPath := GetWshPath(utilCtx, client)
|
||||
rpcCtx := wshrpc.RpcContext{
|
||||
ClientType: wshrpc.ClientType_ConnServer,
|
||||
Conn: conn.GetName(),
|
||||
}
|
||||
sockName := conn.GetDomainSocketName()
|
||||
jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create jwt token for conn controller: %w", err)
|
||||
}
|
||||
shellPath, err := DetectShell(utilCtx, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var cmdStr string
|
||||
if IsPowershell(shellPath) {
|
||||
cmdStr = fmt.Sprintf("$env:%s=\"%s\"; %s connserver --router", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
|
||||
} else {
|
||||
cmdStr = fmt.Sprintf("%s=\"%s\" %s connserver --router", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
|
||||
}
|
||||
log.Printf("starting conn controller: %s\n", cmdStr)
|
||||
connServerCtx, cancelFn := context.WithCancel(context.Background())
|
||||
conn.WithLock(func() {
|
||||
if conn.cancelFn != nil {
|
||||
conn.cancelFn()
|
||||
}
|
||||
conn.cancelFn = cancelFn
|
||||
})
|
||||
cmd := client.WslCommand(connServerCtx, cmdStr)
|
||||
pipeRead, pipeWrite := io.Pipe()
|
||||
inputPipeRead, inputPipeWrite := io.Pipe()
|
||||
cmd.SetStdout(pipeWrite)
|
||||
cmd.SetStderr(pipeWrite)
|
||||
cmd.SetStdin(inputPipeRead)
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to start conn controller: %w", err)
|
||||
}
|
||||
conn.WithLock(func() {
|
||||
conn.ConnController = cmd
|
||||
})
|
||||
// service the I/O
|
||||
go func() {
|
||||
defer func() {
|
||||
panichandler.PanicHandler("wsl:StartConnServer:wait", recover())
|
||||
}()
|
||||
// wait for termination, clear the controller
|
||||
defer conn.WithLock(func() {
|
||||
conn.ConnController = nil
|
||||
})
|
||||
waitErr := cmd.Wait()
|
||||
log.Printf("conn controller (%q) terminated: %v", conn.GetName(), waitErr)
|
||||
}()
|
||||
go func() {
|
||||
defer func() {
|
||||
panichandler.PanicHandler("wsl:StartConnServer:handleStdIOClient", recover())
|
||||
}()
|
||||
logName := fmt.Sprintf("conncontroller:%s", conn.GetName())
|
||||
wshutil.HandleStdIOClient(logName, pipeRead, inputPipeWrite)
|
||||
}()
|
||||
regCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancelFn()
|
||||
err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(rpcCtx.Conn))
|
||||
if err != nil {
|
||||
return fmt.Errorf("timeout waiting for connserver to register")
|
||||
}
|
||||
time.Sleep(300 * time.Millisecond) // TODO remove this sleep (but we need to wait until connserver is "ready")
|
||||
return nil
|
||||
}
|
||||
|
||||
type WshInstallOpts struct {
|
||||
Force bool
|
||||
NoUserPrompt bool
|
||||
}
|
||||
|
||||
func (conn *WslConn) CheckAndInstallWsh(ctx context.Context, clientDisplayName string, opts *WshInstallOpts) error {
|
||||
if opts == nil {
|
||||
opts = &WshInstallOpts{}
|
||||
}
|
||||
client := conn.GetClient()
|
||||
if client == nil {
|
||||
return fmt.Errorf("client is nil")
|
||||
}
|
||||
// check that correct wsh extensions are installed
|
||||
expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
|
||||
clientVersion, err := GetWshVersion(ctx, client)
|
||||
if err == nil && clientVersion == expectedVersion && !opts.Force {
|
||||
return nil
|
||||
}
|
||||
var queryText string
|
||||
var title string
|
||||
if opts.Force {
|
||||
queryText = fmt.Sprintf("ReInstalling Wave Shell Extensions (%s) on `%s`\n", wavebase.WaveVersion, clientDisplayName)
|
||||
title = "Install Wave Shell Extensions"
|
||||
} else if err != nil {
|
||||
queryText = fmt.Sprintf("Wave requires Wave Shell Extensions to be \n"+
|
||||
"installed on `%s` \n"+
|
||||
"to ensure a seamless experience. \n\n"+
|
||||
"Would you like to install them?", clientDisplayName)
|
||||
title = "Install Wave Shell Extensions"
|
||||
} else {
|
||||
// don't ask for upgrading the version
|
||||
opts.NoUserPrompt = true
|
||||
}
|
||||
if !opts.NoUserPrompt {
|
||||
request := &userinput.UserInputRequest{
|
||||
ResponseType: "confirm",
|
||||
QueryText: queryText,
|
||||
Title: title,
|
||||
Markdown: true,
|
||||
CheckBoxMsg: "Don't show me this again",
|
||||
}
|
||||
response, err := userinput.GetUserInput(ctx, request)
|
||||
if err != nil || !response.Confirm {
|
||||
return err
|
||||
}
|
||||
if response.CheckboxStat {
|
||||
meta := waveobj.MetaMapType{
|
||||
wconfig.ConfigKey_ConnAskBeforeWshInstall: false,
|
||||
}
|
||||
err := wconfig.SetBaseConfigValue(meta)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error setting conn:askbeforewshinstall value: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Printf("attempting to install wsh to `%s`", clientDisplayName)
|
||||
clientOs, err := GetClientOs(ctx, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
clientArch, err := GetClientArch(ctx, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// attempt to install extension
|
||||
wshLocalPath, err := shellutil.GetLocalWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = CpHostToRemote(ctx, client, wshLocalPath, wavebase.RemoteFullWshBinPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("successfully installed wsh on %s\n", conn.GetName())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *WslConn) GetClient() *Distro {
|
||||
conn.Lock.Lock()
|
||||
defer conn.Lock.Unlock()
|
||||
return conn.Client
|
||||
}
|
||||
|
||||
func (conn *WslConn) Reconnect(ctx context.Context) error {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.Connect(ctx)
|
||||
}
|
||||
|
||||
func (conn *WslConn) WaitForConnect(ctx context.Context) error {
|
||||
for {
|
||||
status := conn.DeriveConnStatus()
|
||||
if status.Status == Status_Connected {
|
||||
return nil
|
||||
}
|
||||
if status.Status == Status_Connecting {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("context timeout")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
continue
|
||||
}
|
||||
}
|
||||
if status.Status == Status_Init || status.Status == Status_Disconnected {
|
||||
return fmt.Errorf("disconnected")
|
||||
}
|
||||
if status.Status == Status_Error {
|
||||
return fmt.Errorf("error: %v", status.Error)
|
||||
}
|
||||
return fmt.Errorf("unknown status: %q", status.Status)
|
||||
}
|
||||
}
|
||||
|
||||
// does not return an error since that error is stored inside of WslConn
|
||||
func (conn *WslConn) Connect(ctx context.Context) error {
|
||||
var connectAllowed bool
|
||||
conn.WithLock(func() {
|
||||
if conn.Status == Status_Connecting || conn.Status == Status_Connected {
|
||||
connectAllowed = false
|
||||
} else {
|
||||
conn.Status = Status_Connecting
|
||||
conn.Error = ""
|
||||
connectAllowed = true
|
||||
}
|
||||
})
|
||||
log.Printf("Connect %s\n", conn.GetName())
|
||||
if !connectAllowed {
|
||||
return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus())
|
||||
}
|
||||
conn.FireConnChangeEvent()
|
||||
err := conn.connectInternal(ctx)
|
||||
conn.WithLock(func() {
|
||||
if err != nil {
|
||||
conn.Status = Status_Error
|
||||
conn.Error = err.Error()
|
||||
conn.close_nolock()
|
||||
telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{
|
||||
Conn: map[string]int{"wsl:connecterror": 1},
|
||||
}, "wsl-connconnect")
|
||||
} else {
|
||||
conn.Status = Status_Connected
|
||||
conn.LastConnectTime = time.Now().UnixMilli()
|
||||
if conn.ActiveConnNum == 0 {
|
||||
conn.ActiveConnNum = int(activeConnCounter.Add(1))
|
||||
}
|
||||
telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{
|
||||
Conn: map[string]int{"wsl:connect": 1},
|
||||
}, "wsl-connconnect")
|
||||
}
|
||||
})
|
||||
conn.FireConnChangeEvent()
|
||||
return err
|
||||
}
|
||||
|
||||
func (conn *WslConn) WithLock(fn func()) {
|
||||
conn.Lock.Lock()
|
||||
defer conn.Lock.Unlock()
|
||||
fn()
|
||||
}
|
||||
|
||||
func (conn *WslConn) connectInternal(ctx context.Context) error {
|
||||
client, err := GetDistro(ctx, conn.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn.WithLock(func() {
|
||||
conn.Client = client
|
||||
})
|
||||
err = conn.OpenDomainSocketListener()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config := wconfig.GetWatcher().GetFullConfig()
|
||||
wshAsk := wconfig.DefaultBoolPtr(config.Settings.ConnAskBeforeWshInstall, true)
|
||||
installErr := conn.CheckAndInstallWsh(ctx, conn.GetName(), &WshInstallOpts{NoUserPrompt: !wshAsk})
|
||||
if installErr != nil {
|
||||
return fmt.Errorf("conncontroller %s wsh install error: %v", conn.GetName(), installErr)
|
||||
}
|
||||
csErr := conn.StartConnServer()
|
||||
if csErr != nil {
|
||||
return fmt.Errorf("conncontroller %s start wsh connserver error: %v", conn.GetName(), csErr)
|
||||
}
|
||||
conn.HasWaiter.Store(true)
|
||||
go conn.waitForDisconnect()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *WslConn) waitForDisconnect() {
|
||||
defer conn.FireConnChangeEvent()
|
||||
defer conn.HasWaiter.Store(false)
|
||||
err := conn.ConnController.Wait()
|
||||
conn.WithLock(func() {
|
||||
// disconnects happen for a variety of reasons (like network, etc. and are typically transient)
|
||||
// so we just set the status to "disconnected" here (not error)
|
||||
// don't overwrite any existing error (or error status)
|
||||
if err != nil && conn.Error == "" {
|
||||
conn.Error = err.Error()
|
||||
}
|
||||
if conn.Status != Status_Error {
|
||||
conn.Status = Status_Disconnected
|
||||
}
|
||||
conn.close_nolock()
|
||||
})
|
||||
}
|
||||
|
||||
func getConnInternal(name string) *WslConn {
|
||||
globalLock.Lock()
|
||||
defer globalLock.Unlock()
|
||||
connName := WslName{Distro: name}
|
||||
rtn := clientControllerMap[name]
|
||||
if rtn == nil {
|
||||
rtn = &WslConn{Lock: &sync.Mutex{}, Status: Status_Init, Name: connName, HasWaiter: &atomic.Bool{}, cancelFn: nil}
|
||||
clientControllerMap[name] = rtn
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func GetWslConn(ctx context.Context, name string, shouldConnect bool) *WslConn {
|
||||
conn := getConnInternal(name)
|
||||
if conn.Client == nil && shouldConnect {
|
||||
conn.Connect(ctx)
|
||||
}
|
||||
return conn
|
||||
}
|
||||
|
||||
// Convenience function for ensuring a connection is established
|
||||
func EnsureConnection(ctx context.Context, connName string) error {
|
||||
if connName == "" {
|
||||
return nil
|
||||
}
|
||||
conn := GetWslConn(ctx, connName, false)
|
||||
if conn == nil {
|
||||
return fmt.Errorf("connection not found: %s", connName)
|
||||
}
|
||||
connStatus := conn.DeriveConnStatus()
|
||||
switch connStatus.Status {
|
||||
case Status_Connected:
|
||||
return nil
|
||||
case Status_Connecting:
|
||||
return conn.WaitForConnect(ctx)
|
||||
case Status_Init, Status_Disconnected:
|
||||
return conn.Connect(ctx)
|
||||
case Status_Error:
|
||||
return fmt.Errorf("connection error: %s", connStatus.Error)
|
||||
default:
|
||||
return fmt.Errorf("unknown connection status %q", connStatus.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func DisconnectClient(connName string) error {
|
||||
conn := getConnInternal(connName)
|
||||
if conn == nil {
|
||||
return fmt.Errorf("client %q not found", connName)
|
||||
}
|
||||
err := conn.Close()
|
||||
return err
|
||||
}
|
226
pkg/wslconn/wsl-util.go
Normal file
226
pkg/wslconn/wsl-util.go
Normal file
@ -0,0 +1,226 @@
|
||||
// Copyright 2025, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package wslconn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/wavetermdev/waveterm/pkg/blocklogger"
|
||||
"github.com/wavetermdev/waveterm/pkg/genconn"
|
||||
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
|
||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||
)
|
||||
|
||||
func hasBashInstalled(ctx context.Context, client *wsl.Distro) (bool, error) {
|
||||
cmd := client.WslCommand(ctx, "which bash")
|
||||
out, whichErr := cmd.Output()
|
||||
if whichErr == nil && len(out) != 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
cmd = client.WslCommand(ctx, "where.exe bash")
|
||||
out, whereErr := cmd.Output()
|
||||
if whereErr == nil && len(out) != 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// note: we could also check in /bin/bash explicitly
|
||||
// just in case that wasn't added to the path. but if
|
||||
// that's true, we will most likely have worse
|
||||
// problems going forward
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func normalizeOs(os string) string {
|
||||
os = strings.ToLower(strings.TrimSpace(os))
|
||||
return os
|
||||
}
|
||||
|
||||
func normalizeArch(arch string) string {
|
||||
arch = strings.ToLower(strings.TrimSpace(arch))
|
||||
switch arch {
|
||||
case "x86_64", "amd64":
|
||||
arch = "x64"
|
||||
case "arm64", "aarch64":
|
||||
arch = "arm64"
|
||||
}
|
||||
return arch
|
||||
}
|
||||
|
||||
// returns (os, arch, error)
|
||||
// guaranteed to return a supported platform
|
||||
func GetClientPlatform(ctx context.Context, shell genconn.ShellClient) (string, string, error) {
|
||||
blocklogger.Infof(ctx, "[conndebug] running `uname -sm` to detect client platform\n")
|
||||
stdout, stderr, err := genconn.RunSimpleCommand(ctx, shell, genconn.CommandSpec{
|
||||
Cmd: "uname -sm",
|
||||
})
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("error running uname -sm: %w, stderr: %s", err, stderr)
|
||||
}
|
||||
// Parse and normalize output
|
||||
parts := strings.Fields(strings.ToLower(strings.TrimSpace(stdout)))
|
||||
if len(parts) != 2 {
|
||||
return "", "", fmt.Errorf("unexpected output from uname: %s", stdout)
|
||||
}
|
||||
os, arch := normalizeOs(parts[0]), normalizeArch(parts[1])
|
||||
if err := wavebase.ValidateWshSupportedArch(os, arch); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return os, arch, nil
|
||||
}
|
||||
|
||||
func GetClientPlatformFromOsArchStr(ctx context.Context, osArchStr string) (string, string, error) {
|
||||
parts := strings.Fields(strings.TrimSpace(osArchStr))
|
||||
if len(parts) != 2 {
|
||||
return "", "", fmt.Errorf("unexpected output from uname: %s", osArchStr)
|
||||
}
|
||||
os, arch := normalizeOs(parts[0]), normalizeArch(parts[1])
|
||||
if err := wavebase.ValidateWshSupportedArch(os, arch); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return os, arch, nil
|
||||
}
|
||||
|
||||
type CancellableCmd struct {
|
||||
Cmd *wsl.WslCmd
|
||||
Cancel func()
|
||||
}
|
||||
|
||||
var installTemplatesRawBash = map[string]string{
|
||||
"mkdir": `bash -c 'mkdir -p {{.installDir}}'`,
|
||||
"cat": `bash -c 'cat > {{.tempPath}}'`,
|
||||
"mv": `bash -c 'mv {{.tempPath}} {{.installPath}}'`,
|
||||
"chmod": `bash -c 'chmod a+x {{.installPath}}'`,
|
||||
}
|
||||
|
||||
var installTemplatesRawDefault = map[string]string{
|
||||
"mkdir": `mkdir -p {{.installDir}}`,
|
||||
"cat": `cat > {{.tempPath}}`,
|
||||
"mv": `mv {{.tempPath}} {{.installPath}}`,
|
||||
"chmod": `chmod a+x {{.installPath}}`,
|
||||
}
|
||||
|
||||
func makeCancellableCommand(ctx context.Context, client *wsl.Distro, cmdTemplateRaw string, words map[string]string) (*CancellableCmd, error) {
|
||||
cmdContext, cmdCancel := context.WithCancel(ctx)
|
||||
|
||||
cmdStr := &bytes.Buffer{}
|
||||
cmdTemplate, err := template.New("").Parse(cmdTemplateRaw)
|
||||
if err != nil {
|
||||
cmdCancel()
|
||||
return nil, err
|
||||
}
|
||||
cmdTemplate.Execute(cmdStr, words)
|
||||
|
||||
cmd := client.WslCommand(cmdContext, cmdStr.String())
|
||||
return &CancellableCmd{cmd, cmdCancel}, nil
|
||||
}
|
||||
|
||||
func CpWshToRemote(ctx context.Context, client *wsl.Distro, clientOs string, clientArch string) error {
|
||||
wshLocalPath, err := shellutil.GetLocalWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// warning: does not work on windows remote yet
|
||||
bashInstalled, err := hasBashInstalled(ctx, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var selectedTemplatesRaw map[string]string
|
||||
if bashInstalled {
|
||||
selectedTemplatesRaw = installTemplatesRawBash
|
||||
} else {
|
||||
log.Printf("bash is not installed on remote. attempting with default shell")
|
||||
selectedTemplatesRaw = installTemplatesRawDefault
|
||||
}
|
||||
|
||||
// I need to use toSlash here to force unix keybindings
|
||||
// this means we can't guarantee it will work on a remote windows machine
|
||||
var installWords = map[string]string{
|
||||
"installDir": filepath.ToSlash(filepath.Dir(wavebase.RemoteFullWshBinPath)),
|
||||
"tempPath": wavebase.RemoteFullWshBinPath + ".temp",
|
||||
"installPath": wavebase.RemoteFullWshBinPath,
|
||||
}
|
||||
|
||||
blocklogger.Infof(ctx, "[conndebug] copying %q to remote server %q\n", wshLocalPath, wavebase.RemoteFullWshBinPath)
|
||||
installStepCmds := make(map[string]*CancellableCmd)
|
||||
for cmdName, selectedTemplateRaw := range selectedTemplatesRaw {
|
||||
cancellableCmd, err := makeCancellableCommand(ctx, client, selectedTemplateRaw, installWords)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
installStepCmds[cmdName] = cancellableCmd
|
||||
}
|
||||
|
||||
_, err = installStepCmds["mkdir"].Cmd.Output()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// the cat part of this is complicated since it requires stdin
|
||||
catCmd := installStepCmds["cat"].Cmd
|
||||
catStdin, err := catCmd.StdinPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = catCmd.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input, err := os.Open(wshLocalPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot open local file %s to send to host: %v", wshLocalPath, err)
|
||||
}
|
||||
go func() {
|
||||
defer func() {
|
||||
panichandler.PanicHandler("wslutil:cpHostToRemote:catStdin", recover())
|
||||
}()
|
||||
io.Copy(catStdin, input)
|
||||
installStepCmds["cat"].Cancel()
|
||||
|
||||
// backup just in case something weird happens
|
||||
// could cause potential race condition, but very
|
||||
// unlikely
|
||||
time.Sleep(time.Second * 1)
|
||||
process := catCmd.GetProcess()
|
||||
if process != nil {
|
||||
process.Kill()
|
||||
}
|
||||
}()
|
||||
catErr := catCmd.Wait()
|
||||
if catErr != nil && !errors.Is(catErr, context.Canceled) {
|
||||
return catErr
|
||||
}
|
||||
|
||||
_, err = installStepCmds["mv"].Cmd.Output()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = installStepCmds["chmod"].Cmd.Output()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func IsPowershell(shellPath string) bool {
|
||||
// get the base path, and then check contains
|
||||
shellBase := filepath.Base(shellPath)
|
||||
return strings.Contains(shellBase, "powershell") || strings.Contains(shellBase, "pwsh")
|
||||
}
|
787
pkg/wslconn/wslconn.go
Normal file
787
pkg/wslconn/wslconn.go
Normal file
@ -0,0 +1,787 @@
|
||||
// Copyright 2025, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package wslconn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/wavetermdev/waveterm/pkg/blocklogger"
|
||||
"github.com/wavetermdev/waveterm/pkg/genconn"
|
||||
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
||||
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
|
||||
"github.com/wavetermdev/waveterm/pkg/telemetry"
|
||||
"github.com/wavetermdev/waveterm/pkg/userinput"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||
"github.com/wavetermdev/waveterm/pkg/waveobj"
|
||||
"github.com/wavetermdev/waveterm/pkg/wconfig"
|
||||
"github.com/wavetermdev/waveterm/pkg/wps"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||
)
|
||||
|
||||
const (
|
||||
Status_Init = "init"
|
||||
Status_Connecting = "connecting"
|
||||
Status_Connected = "connected"
|
||||
Status_Disconnected = "disconnected"
|
||||
Status_Error = "error"
|
||||
)
|
||||
|
||||
const DefaultConnectionTimeout = 60 * time.Second
|
||||
|
||||
var globalLock = &sync.Mutex{}
|
||||
var clientControllerMap = make(map[string]*WslConn)
|
||||
var activeConnCounter = &atomic.Int32{}
|
||||
|
||||
type WslConn struct {
|
||||
Lock *sync.Mutex
|
||||
Status string
|
||||
WshEnabled *atomic.Bool
|
||||
Name wsl.WslName
|
||||
Client *wsl.Distro
|
||||
DomainSockName string // if "", then no domain socket
|
||||
DomainSockListener net.Listener
|
||||
ConnController *wsl.WslCmd
|
||||
Error string
|
||||
WshError string
|
||||
NoWshReason string
|
||||
WshVersion string
|
||||
HasWaiter *atomic.Bool
|
||||
LastConnectTime int64
|
||||
ActiveConnNum int
|
||||
cancelFn func()
|
||||
}
|
||||
|
||||
var ConnServerCmdTemplate = strings.TrimSpace(
|
||||
strings.Join([]string{
|
||||
"%s version 2> /dev/null || (echo -n \"not-installed \"; uname -sm);",
|
||||
"exec %s connserver --router",
|
||||
}, "\n"))
|
||||
|
||||
func GetAllConnStatus() []wshrpc.ConnStatus {
|
||||
globalLock.Lock()
|
||||
defer globalLock.Unlock()
|
||||
|
||||
var connStatuses []wshrpc.ConnStatus
|
||||
for _, conn := range clientControllerMap {
|
||||
connStatuses = append(connStatuses, conn.DeriveConnStatus())
|
||||
}
|
||||
return connStatuses
|
||||
}
|
||||
|
||||
func GetNumWSLHasConnected() int {
|
||||
globalLock.Lock()
|
||||
defer globalLock.Unlock()
|
||||
|
||||
var connectedCount int
|
||||
for _, conn := range clientControllerMap {
|
||||
if conn.LastConnectTime > 0 {
|
||||
connectedCount++
|
||||
}
|
||||
}
|
||||
return connectedCount
|
||||
}
|
||||
|
||||
func (conn *WslConn) DeriveConnStatus() wshrpc.ConnStatus {
|
||||
conn.Lock.Lock()
|
||||
defer conn.Lock.Unlock()
|
||||
return wshrpc.ConnStatus{
|
||||
Status: conn.Status,
|
||||
Connected: conn.Status == Status_Connected,
|
||||
WshEnabled: true, // always use wsh for wsl connections (temporary)
|
||||
Connection: conn.GetName(),
|
||||
HasConnected: (conn.LastConnectTime > 0),
|
||||
ActiveConnNum: conn.ActiveConnNum,
|
||||
Error: conn.Error,
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *WslConn) Infof(ctx context.Context, format string, args ...any) {
|
||||
log.Print(fmt.Sprintf("[conn:%s] ", conn.GetName()) + fmt.Sprintf(format, args...))
|
||||
blocklogger.Infof(ctx, "[conndebug] "+format, args...)
|
||||
}
|
||||
|
||||
func (conn *WslConn) FireConnChangeEvent() {
|
||||
status := conn.DeriveConnStatus()
|
||||
event := wps.WaveEvent{
|
||||
Event: wps.Event_ConnChange,
|
||||
Scopes: []string{
|
||||
fmt.Sprintf("connection:%s", conn.GetName()),
|
||||
},
|
||||
Data: status,
|
||||
}
|
||||
log.Printf("sending event: %+#v", event)
|
||||
wps.Broker.Publish(event)
|
||||
}
|
||||
|
||||
func (conn *WslConn) Close() error {
|
||||
defer conn.FireConnChangeEvent()
|
||||
conn.WithLock(func() {
|
||||
if conn.Status == Status_Connected || conn.Status == Status_Connecting {
|
||||
// if status is init, disconnected, or error don't change it
|
||||
conn.Status = Status_Disconnected
|
||||
}
|
||||
conn.close_nolock()
|
||||
})
|
||||
// we must wait for the waiter to complete
|
||||
startTime := time.Now()
|
||||
for conn.HasWaiter.Load() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
if time.Since(startTime) > 2*time.Second {
|
||||
return fmt.Errorf("timeout waiting for waiter to complete")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *WslConn) close_nolock() {
|
||||
// does not set status (that should happen at another level)
|
||||
if conn.DomainSockListener != nil {
|
||||
conn.DomainSockListener.Close()
|
||||
conn.DomainSockListener = nil
|
||||
conn.DomainSockName = ""
|
||||
}
|
||||
if conn.ConnController != nil {
|
||||
conn.cancelFn() // this suspends the conn controller
|
||||
conn.ConnController = nil
|
||||
}
|
||||
if conn.Client != nil {
|
||||
// conn.Client.Close() is not relevant here
|
||||
// we do not want to completely close the wsl in case
|
||||
// other applications are using it
|
||||
conn.Client = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *WslConn) GetDomainSocketName() string {
|
||||
conn.Lock.Lock()
|
||||
defer conn.Lock.Unlock()
|
||||
return conn.DomainSockName
|
||||
}
|
||||
|
||||
func (conn *WslConn) GetStatus() string {
|
||||
conn.Lock.Lock()
|
||||
defer conn.Lock.Unlock()
|
||||
return conn.Status
|
||||
}
|
||||
|
||||
func (conn *WslConn) GetName() string {
|
||||
// no lock required because opts is immutable
|
||||
return "wsl://" + conn.Name.Distro
|
||||
}
|
||||
|
||||
/**
|
||||
* This function is does not set a listener for WslConn
|
||||
* It is still required in order to set SockName
|
||||
**/
|
||||
func (conn *WslConn) OpenDomainSocketListener(ctx context.Context) error {
|
||||
conn.Infof(ctx, "running OpenDomainSocketListener...\n")
|
||||
allowed := WithLockRtn(conn, func() bool {
|
||||
return conn.Status == Status_Connecting
|
||||
})
|
||||
if !allowed {
|
||||
return fmt.Errorf("cannot open domain socket for %q when status is %q", conn.GetName(), conn.GetStatus())
|
||||
}
|
||||
/*
|
||||
listener, err := client.ListenUnix(sockName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to request connection domain socket: %v", err)
|
||||
}
|
||||
*/
|
||||
conn.Infof(ctx, "setting domain socket to %s\n", wavebase.RemoteFullDomainSocketPath)
|
||||
conn.WithLock(func() {
|
||||
conn.DomainSockName = wavebase.RemoteFullDomainSocketPath
|
||||
//conn.DomainSockListener = listener
|
||||
})
|
||||
conn.Infof(ctx, "successfully connected domain socket\n")
|
||||
/*
|
||||
go func() {
|
||||
defer func() {
|
||||
panichandler.PanicHandler("wslconn:OpenDomainSocketListener", recover())
|
||||
}()
|
||||
defer conn.WithLock(func() {
|
||||
conn.DomainSockListener = nil
|
||||
conn.DomainSockName = ""
|
||||
})
|
||||
wshutil.RunWshRpcOverListener(listener)
|
||||
}()
|
||||
*/
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *WslConn) getWshPath() string {
|
||||
config, ok := conn.getConnectionConfig()
|
||||
if ok && config.ConnWshPath != "" {
|
||||
return config.ConnWshPath
|
||||
}
|
||||
return wavebase.RemoteFullWshBinPath
|
||||
}
|
||||
|
||||
func (conn *WslConn) GetConfigShellPath() string {
|
||||
config, ok := conn.getConnectionConfig()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return config.ConnShellPath
|
||||
}
|
||||
|
||||
// returns (needsInstall, clientVersion, osArchStr, error)
|
||||
// if wsh is not installed, the clientVersion will be "not-installed", and it will also return an osArchStr
|
||||
// if clientVersion is set, then no osArchStr will be returned
|
||||
func (conn *WslConn) StartConnServer(ctx context.Context, afterUpdate bool) (bool, string, string, error) {
|
||||
conn.Infof(ctx, "running StartConnServer...\n")
|
||||
allowed := WithLockRtn(conn, func() bool {
|
||||
return conn.Status == Status_Connecting
|
||||
})
|
||||
if !allowed {
|
||||
return false, "", "", fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus())
|
||||
}
|
||||
client := conn.GetClient()
|
||||
wshPath := conn.getWshPath()
|
||||
rpcCtx := wshrpc.RpcContext{
|
||||
ClientType: wshrpc.ClientType_ConnServer,
|
||||
Conn: conn.GetName(),
|
||||
}
|
||||
sockName := conn.GetDomainSocketName()
|
||||
jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName)
|
||||
if err != nil {
|
||||
return false, "", "", fmt.Errorf("unable to create jwt token for conn controller: %w", err)
|
||||
}
|
||||
conn.Infof(ctx, "WSL-NEWSESSION (StartConnServer)\n")
|
||||
connServerCtx, cancelFn := context.WithCancel(context.Background())
|
||||
conn.WithLock(func() {
|
||||
if conn.cancelFn != nil {
|
||||
conn.cancelFn()
|
||||
}
|
||||
conn.cancelFn = cancelFn
|
||||
})
|
||||
cmdStr := fmt.Sprintf(ConnServerCmdTemplate, wshPath, wshPath)
|
||||
shWrappedCmdStr := fmt.Sprintf("sh -c %s", shellutil.HardQuote(cmdStr))
|
||||
cmd := client.WslCommand(connServerCtx, shWrappedCmdStr)
|
||||
pipeRead, pipeWrite := io.Pipe()
|
||||
inputPipeRead, inputPipeWrite := io.Pipe()
|
||||
cmd.SetStdout(pipeWrite)
|
||||
cmd.SetStderr(pipeWrite)
|
||||
cmd.SetStdin(inputPipeRead)
|
||||
log.Printf("starting conn controller: %q\n", cmdStr)
|
||||
blocklogger.Debugf(ctx, "[conndebug] wrapped command:\n%s\n", shWrappedCmdStr)
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return false, "", "", fmt.Errorf("unable to start conn controller cmd: %w", err)
|
||||
}
|
||||
linesChan := utilfn.StreamToLinesChan(pipeRead)
|
||||
versionLine, err := utilfn.ReadLineWithTimeout(linesChan, 2*time.Second)
|
||||
if err != nil {
|
||||
cancelFn()
|
||||
return false, "", "", fmt.Errorf("error reading wsh version: %w", err)
|
||||
}
|
||||
conn.Infof(ctx, "got connserver version: %s\n", strings.TrimSpace(versionLine))
|
||||
isUpToDate, clientVersion, osArchStr, err := conncontroller.IsWshVersionUpToDate(ctx, versionLine)
|
||||
if err != nil {
|
||||
cancelFn()
|
||||
return false, "", "", fmt.Errorf("error checking wsh version: %w", err)
|
||||
}
|
||||
if isUpToDate && !afterUpdate && os.Getenv(wavebase.WaveWshForceUpdateVarName) != "" {
|
||||
isUpToDate = false
|
||||
conn.Infof(ctx, "%s set, forcing wsh update\n", wavebase.WaveWshForceUpdateVarName)
|
||||
}
|
||||
conn.Infof(ctx, "connserver up-to-date: %v\n", isUpToDate)
|
||||
if !isUpToDate {
|
||||
cancelFn()
|
||||
return true, clientVersion, osArchStr, nil
|
||||
}
|
||||
jwtLine, err := utilfn.ReadLineWithTimeout(linesChan, 3*time.Second)
|
||||
if err != nil {
|
||||
cancelFn()
|
||||
return false, clientVersion, "", fmt.Errorf("error reading jwt status line: %w", err)
|
||||
}
|
||||
conn.Infof(ctx, "got jwt status line: %s\n", jwtLine)
|
||||
if strings.TrimSpace(jwtLine) == wavebase.NeedJwtConst {
|
||||
// write the jwt
|
||||
conn.Infof(ctx, "writing jwt token to connserver\n")
|
||||
_, err = fmt.Fprintf(inputPipeWrite, "%s\n", jwtToken)
|
||||
if err != nil {
|
||||
cancelFn()
|
||||
return false, clientVersion, "", fmt.Errorf("failed to write JWT token: %w", err)
|
||||
}
|
||||
}
|
||||
conn.WithLock(func() {
|
||||
conn.ConnController = cmd
|
||||
})
|
||||
// service the I/O
|
||||
go func() {
|
||||
defer func() {
|
||||
panichandler.PanicHandler("wslconn:cmd.Wait", recover())
|
||||
}()
|
||||
// wait for termination, clear the controller
|
||||
var waitErr error
|
||||
defer conn.WithLock(func() {
|
||||
if conn.ConnController != nil {
|
||||
conn.WshEnabled.Store(false)
|
||||
conn.NoWshReason = "connserver terminated"
|
||||
if waitErr != nil {
|
||||
conn.WshError = fmt.Sprintf("connserver terminated unexpectedly with error: %v", waitErr)
|
||||
}
|
||||
}
|
||||
conn.ConnController = nil
|
||||
})
|
||||
waitErr = cmd.Wait()
|
||||
log.Printf("conn controller (%q) terminated: %v", conn.GetName(), waitErr)
|
||||
}()
|
||||
go func() {
|
||||
defer func() {
|
||||
panichandler.PanicHandler("wsl:StartConnServer:handleStdIOClient", recover())
|
||||
}()
|
||||
logName := fmt.Sprintf("wslconn:%s", conn.GetName())
|
||||
wshutil.HandleStdIOClient(logName, linesChan, inputPipeWrite)
|
||||
}()
|
||||
conn.Infof(ctx, "connserver started, waiting for route to be registered\n")
|
||||
regCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancelFn()
|
||||
err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(rpcCtx.Conn))
|
||||
if err != nil {
|
||||
return false, clientVersion, "", fmt.Errorf("timeout waiting for connserver to register")
|
||||
}
|
||||
time.Sleep(300 * time.Millisecond) // TODO remove this sleep (but we need to wait until connserver is "ready")
|
||||
conn.Infof(ctx, "connserver is registered and ready\n")
|
||||
return false, clientVersion, "", nil
|
||||
}
|
||||
|
||||
type WshInstallOpts struct {
|
||||
Force bool
|
||||
NoUserPrompt bool
|
||||
}
|
||||
|
||||
var queryTextTemplate = strings.TrimSpace(`
|
||||
Wave requires Wave Shell Extensions to be
|
||||
installed on %q
|
||||
to ensure a seamless experience.
|
||||
|
||||
Would you like to install them?
|
||||
`)
|
||||
|
||||
func (conn *WslConn) UpdateWsh(ctx context.Context, clientDisplayName string, remoteInfo *wshrpc.RemoteInfo) error {
|
||||
conn.Infof(ctx, "attempting to update wsh for connection %s (os:%s arch:%s version:%s)\n",
|
||||
conn.GetName(), remoteInfo.ClientOs, remoteInfo.ClientArch, remoteInfo.ClientVersion)
|
||||
client := conn.GetClient()
|
||||
if client == nil {
|
||||
return fmt.Errorf("cannot update wsh: ssh client is not connected")
|
||||
}
|
||||
err := CpWshToRemote(ctx, client, remoteInfo.ClientOs, remoteInfo.ClientArch)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error installing wsh to remote: %w", err)
|
||||
}
|
||||
conn.Infof(ctx, "successfully updated wsh on %s\n", conn.GetName())
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// returns (allowed, error)
|
||||
func (conn *WslConn) getPermissionToInstallWsh(ctx context.Context, clientDisplayName string) (bool, error) {
|
||||
conn.Infof(ctx, "running getPermissionToInstallWsh...\n")
|
||||
queryText := fmt.Sprintf(queryTextTemplate, clientDisplayName)
|
||||
title := "Install Wave Shell Extensions"
|
||||
request := &userinput.UserInputRequest{
|
||||
ResponseType: "confirm",
|
||||
QueryText: queryText,
|
||||
Title: title,
|
||||
Markdown: true,
|
||||
CheckBoxMsg: "Automatically install for all connections",
|
||||
OkLabel: "Install wsh",
|
||||
CancelLabel: "No wsh",
|
||||
}
|
||||
conn.Infof(ctx, "requesting user confirmation...\n")
|
||||
response, err := userinput.GetUserInput(ctx, request)
|
||||
if err != nil {
|
||||
conn.Infof(ctx, "error getting user input: %v\n", err)
|
||||
return false, err
|
||||
}
|
||||
conn.Infof(ctx, "user response to allowing wsh: %v\n", response.Confirm)
|
||||
meta := make(map[string]any)
|
||||
meta["conn:wshenabled"] = response.Confirm
|
||||
conn.Infof(ctx, "writing conn:wshenabled=%v to connections.json\n", response.Confirm)
|
||||
err = wconfig.SetConnectionsConfigValue(conn.GetName(), meta)
|
||||
if err != nil {
|
||||
log.Printf("warning: error writing to connections file: %v", err)
|
||||
}
|
||||
if !response.Confirm {
|
||||
return false, nil
|
||||
}
|
||||
if response.CheckboxStat {
|
||||
conn.Infof(ctx, "writing conn:askbeforewshinstall=false to settings.json\n")
|
||||
meta := waveobj.MetaMapType{
|
||||
wconfig.ConfigKey_ConnAskBeforeWshInstall: false,
|
||||
}
|
||||
setConfigErr := wconfig.SetBaseConfigValue(meta)
|
||||
if setConfigErr != nil {
|
||||
// this is not a critical error, just log and continue
|
||||
log.Printf("warning: error writing to base config file: %v", err)
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (conn *WslConn) InstallWsh(ctx context.Context, osArchStr string) error {
|
||||
conn.Infof(ctx, "running installWsh...\n")
|
||||
client := conn.GetClient()
|
||||
if client == nil {
|
||||
conn.Infof(ctx, "ERROR ssh client is not connected, cannot install\n")
|
||||
return fmt.Errorf("ssh client is not connected, cannot install")
|
||||
}
|
||||
var clientOs, clientArch string
|
||||
var err error
|
||||
if osArchStr != "" {
|
||||
clientOs, clientArch, err = GetClientPlatformFromOsArchStr(ctx, osArchStr)
|
||||
} else {
|
||||
clientOs, clientArch, err = GetClientPlatform(ctx, genconn.MakeWSLShellClient(client))
|
||||
}
|
||||
if err != nil {
|
||||
conn.Infof(ctx, "ERROR detecting client platform: %v\n", err)
|
||||
}
|
||||
conn.Infof(ctx, "detected remote platform os:%s arch:%s\n", clientOs, clientArch)
|
||||
err = CpWshToRemote(ctx, client, clientOs, clientArch)
|
||||
if err != nil {
|
||||
conn.Infof(ctx, "ERROR copying wsh binary to remote: %v\n", err)
|
||||
return fmt.Errorf("error copying wsh binary to remote: %w", err)
|
||||
}
|
||||
conn.Infof(ctx, "successfully installed wsh\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *WslConn) GetClient() *wsl.Distro {
|
||||
conn.Lock.Lock()
|
||||
defer conn.Lock.Unlock()
|
||||
return conn.Client
|
||||
}
|
||||
|
||||
func (conn *WslConn) Reconnect(ctx context.Context) error {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.Connect(ctx)
|
||||
}
|
||||
|
||||
func (conn *WslConn) WaitForConnect(ctx context.Context) error {
|
||||
for {
|
||||
status := conn.DeriveConnStatus()
|
||||
if status.Status == Status_Connected {
|
||||
return nil
|
||||
}
|
||||
if status.Status == Status_Connecting {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("context timeout")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
continue
|
||||
}
|
||||
}
|
||||
if status.Status == Status_Init || status.Status == Status_Disconnected {
|
||||
return fmt.Errorf("disconnected")
|
||||
}
|
||||
if status.Status == Status_Error {
|
||||
return fmt.Errorf("error: %v", status.Error)
|
||||
}
|
||||
return fmt.Errorf("unknown status: %q", status.Status)
|
||||
}
|
||||
}
|
||||
|
||||
// does not return an error since that error is stored inside of WslConn
|
||||
func (conn *WslConn) Connect(ctx context.Context) error {
|
||||
var connectAllowed bool
|
||||
conn.WithLock(func() {
|
||||
if conn.Status == Status_Connecting || conn.Status == Status_Connected {
|
||||
connectAllowed = false
|
||||
} else {
|
||||
conn.Status = Status_Connecting
|
||||
conn.Error = ""
|
||||
connectAllowed = true
|
||||
}
|
||||
})
|
||||
log.Printf("Connect %s\n", conn.GetName())
|
||||
if !connectAllowed {
|
||||
conn.Infof(ctx, "cannot connect to %q when status is %q\n", conn.GetName(), conn.GetStatus())
|
||||
return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus())
|
||||
}
|
||||
conn.FireConnChangeEvent()
|
||||
err := conn.connectInternal(ctx)
|
||||
conn.WithLock(func() {
|
||||
if err != nil {
|
||||
conn.Infof(ctx, "ERROR %v\n\n", err)
|
||||
conn.Status = Status_Error
|
||||
conn.Error = err.Error()
|
||||
conn.close_nolock()
|
||||
telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{
|
||||
Conn: map[string]int{"wsl:connecterror": 1},
|
||||
}, "wsl-connconnect")
|
||||
} else {
|
||||
conn.Infof(ctx, "successfully connected (wsh:%v)\n\n", conn.WshEnabled.Load())
|
||||
conn.Status = Status_Connected
|
||||
conn.LastConnectTime = time.Now().UnixMilli()
|
||||
if conn.ActiveConnNum == 0 {
|
||||
conn.ActiveConnNum = int(activeConnCounter.Add(1))
|
||||
}
|
||||
telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{
|
||||
Conn: map[string]int{"wsl:connect": 1},
|
||||
}, "wsl-connconnect")
|
||||
}
|
||||
})
|
||||
conn.FireConnChangeEvent()
|
||||
return err
|
||||
}
|
||||
|
||||
func (conn *WslConn) WithLock(fn func()) {
|
||||
conn.Lock.Lock()
|
||||
defer conn.Lock.Unlock()
|
||||
fn()
|
||||
}
|
||||
|
||||
func WithLockRtn[T any](conn *WslConn, fn func() T) T {
|
||||
conn.Lock.Lock()
|
||||
defer conn.Lock.Unlock()
|
||||
return fn()
|
||||
}
|
||||
|
||||
// returns (enable-wsh, ask-before-install)
|
||||
func (conn *WslConn) getConnWshSettings() (bool, bool) {
|
||||
config := wconfig.GetWatcher().GetFullConfig()
|
||||
enableWsh := config.Settings.ConnWshEnabled
|
||||
askBeforeInstall := wconfig.DefaultBoolPtr(config.Settings.ConnAskBeforeWshInstall, true)
|
||||
connSettings, ok := conn.getConnectionConfig()
|
||||
if ok {
|
||||
if connSettings.ConnWshEnabled != nil {
|
||||
enableWsh = *connSettings.ConnWshEnabled
|
||||
}
|
||||
// if the connection object exists, and conn:askbeforewshinstall is not set, the user must have allowed it
|
||||
// TODO: in v0.12+ this should be removed. we'll explicitly write a "false" into the connection object on successful connection
|
||||
if connSettings.ConnAskBeforeWshInstall == nil {
|
||||
askBeforeInstall = false
|
||||
} else {
|
||||
askBeforeInstall = *connSettings.ConnAskBeforeWshInstall
|
||||
}
|
||||
}
|
||||
return enableWsh, askBeforeInstall
|
||||
}
|
||||
|
||||
type WshCheckResult struct {
|
||||
WshEnabled bool
|
||||
ClientVersion string
|
||||
NoWshReason string
|
||||
WshError error
|
||||
}
|
||||
|
||||
// returns (wsh-enabled, clientVersion, text-reason, wshError)
|
||||
func (conn *WslConn) tryEnableWsh(ctx context.Context, clientDisplayName string) WshCheckResult {
|
||||
conn.Infof(ctx, "running tryEnableWsh...\n")
|
||||
enableWsh, askBeforeInstall := conn.getConnWshSettings()
|
||||
conn.Infof(ctx, "wsh settings enable:%v ask:%v\n", enableWsh, askBeforeInstall)
|
||||
if !enableWsh {
|
||||
return WshCheckResult{NoWshReason: "conn:wshenabled set to false"}
|
||||
}
|
||||
if askBeforeInstall {
|
||||
allowInstall, err := conn.getPermissionToInstallWsh(ctx, clientDisplayName)
|
||||
if err != nil {
|
||||
log.Printf("error getting permission to install wsh: %v\n", err)
|
||||
return WshCheckResult{NoWshReason: "error getting user permission to install", WshError: err}
|
||||
}
|
||||
if !allowInstall {
|
||||
return WshCheckResult{NoWshReason: "user selected not to install wsh extensions"}
|
||||
}
|
||||
}
|
||||
err := conn.OpenDomainSocketListener(ctx)
|
||||
if err != nil {
|
||||
conn.Infof(ctx, "ERROR opening domain socket listener: %v\n", err)
|
||||
err = fmt.Errorf("error opening domain socket listener: %w", err)
|
||||
return WshCheckResult{NoWshReason: "error opening domain socket", WshError: err}
|
||||
}
|
||||
needsInstall, clientVersion, osArchStr, err := conn.StartConnServer(ctx, false)
|
||||
if err != nil {
|
||||
conn.Infof(ctx, "ERROR starting conn server: %v\n", err)
|
||||
err = fmt.Errorf("error starting conn server: %w", err)
|
||||
return WshCheckResult{NoWshReason: "error starting connserver", WshError: err}
|
||||
}
|
||||
if needsInstall {
|
||||
conn.Infof(ctx, "connserver needs to be (re)installed\n")
|
||||
err = conn.InstallWsh(ctx, osArchStr)
|
||||
if err != nil {
|
||||
conn.Infof(ctx, "ERROR installing wsh: %v\n", err)
|
||||
err = fmt.Errorf("error installing wsh: %w", err)
|
||||
return WshCheckResult{NoWshReason: "error installing wsh/connserver", WshError: err}
|
||||
}
|
||||
needsInstall, clientVersion, _, err = conn.StartConnServer(ctx, true)
|
||||
if err != nil {
|
||||
conn.Infof(ctx, "ERROR starting conn server (after install): %v\n", err)
|
||||
err = fmt.Errorf("error starting conn server (after install): %w", err)
|
||||
return WshCheckResult{NoWshReason: "error starting connserver", WshError: err}
|
||||
}
|
||||
if needsInstall {
|
||||
conn.Infof(ctx, "conn server not installed correctly (after install)\n")
|
||||
err = fmt.Errorf("conn server not installed correctly (after install)")
|
||||
return WshCheckResult{NoWshReason: "connserver not installed properly", WshError: err}
|
||||
}
|
||||
return WshCheckResult{WshEnabled: true, ClientVersion: clientVersion}
|
||||
} else {
|
||||
return WshCheckResult{WshEnabled: true, ClientVersion: clientVersion}
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *WslConn) getConnectionConfig() (wshrpc.ConnKeywords, bool) {
|
||||
config := wconfig.GetWatcher().GetFullConfig()
|
||||
connSettings, ok := config.Connections[conn.GetName()]
|
||||
if !ok {
|
||||
return wshrpc.ConnKeywords{}, false
|
||||
}
|
||||
return connSettings, true
|
||||
}
|
||||
|
||||
func (conn *WslConn) persistWshInstalled(ctx context.Context, result WshCheckResult) {
|
||||
conn.WshEnabled.Store(result.WshEnabled)
|
||||
conn.SetWshError(result.WshError)
|
||||
conn.WithLock(func() {
|
||||
conn.NoWshReason = result.NoWshReason
|
||||
conn.WshVersion = result.ClientVersion
|
||||
})
|
||||
connConfig, ok := conn.getConnectionConfig()
|
||||
if ok && connConfig.ConnWshEnabled != nil {
|
||||
return
|
||||
}
|
||||
meta := make(map[string]any)
|
||||
meta["conn:wshenabled"] = result.WshEnabled
|
||||
err := wconfig.SetConnectionsConfigValue(conn.GetName(), meta)
|
||||
if err != nil {
|
||||
conn.Infof(ctx, "WARN could not write conn:wshenabled=%v to connections.json: %v\n", result.WshEnabled, err)
|
||||
log.Printf("warning: error writing to connections file: %v", err)
|
||||
}
|
||||
// doesn't return an error since none of this is required for connection to work
|
||||
}
|
||||
|
||||
func (conn *WslConn) connectInternal(ctx context.Context) error {
|
||||
conn.Infof(ctx, "connectInternal %s\n", conn.GetName())
|
||||
client, err := wsl.GetDistro(ctx, conn.Name)
|
||||
if err != nil {
|
||||
conn.Infof(ctx, "ERROR GetDistro: %s\n", err)
|
||||
log.Printf("error: failed to get distro %s: %s\n", conn.GetName(), err)
|
||||
return err
|
||||
}
|
||||
conn.WithLock(func() {
|
||||
conn.Client = client
|
||||
})
|
||||
go func() {
|
||||
defer func() {
|
||||
panichandler.PanicHandler("wsl-waitForDisconnect", recover())
|
||||
}()
|
||||
conn.waitForDisconnect()
|
||||
}()
|
||||
wshResult := conn.tryEnableWsh(ctx, conn.GetName())
|
||||
if !wshResult.WshEnabled {
|
||||
if wshResult.WshError != nil {
|
||||
conn.Infof(ctx, "ERROR enabling wsh: %v\n", wshResult.WshError)
|
||||
conn.Infof(ctx, "will connect with wsh disabled\n")
|
||||
} else {
|
||||
conn.Infof(ctx, "wsh not enabled: %s\n", wshResult.NoWshReason)
|
||||
}
|
||||
}
|
||||
conn.persistWshInstalled(ctx, wshResult)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *WslConn) waitForDisconnect() {
|
||||
log.Printf("wait for disconnect in %+#v", conn)
|
||||
defer conn.FireConnChangeEvent()
|
||||
defer conn.HasWaiter.Store(false)
|
||||
err := conn.ConnController.Wait()
|
||||
conn.WithLock(func() {
|
||||
// disconnects happen for a variety of reasons (like network, etc. and are typically transient)
|
||||
// so we just set the status to "disconnected" here (not error)
|
||||
// don't overwrite any existing error (or error status)
|
||||
if err != nil && conn.Error == "" {
|
||||
conn.Error = err.Error()
|
||||
}
|
||||
if conn.Status != Status_Error {
|
||||
conn.Status = Status_Disconnected
|
||||
}
|
||||
conn.close_nolock()
|
||||
})
|
||||
}
|
||||
|
||||
func (conn *WslConn) SetWshError(err error) {
|
||||
conn.WithLock(func() {
|
||||
if err == nil {
|
||||
conn.WshError = ""
|
||||
} else {
|
||||
conn.WshError = err.Error()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (conn *WslConn) ClearWshError() {
|
||||
conn.WithLock(func() {
|
||||
conn.WshError = ""
|
||||
})
|
||||
}
|
||||
|
||||
func getConnInternal(name string) *WslConn {
|
||||
globalLock.Lock()
|
||||
defer globalLock.Unlock()
|
||||
connName := wsl.WslName{Distro: name}
|
||||
rtn := clientControllerMap[name]
|
||||
if rtn == nil {
|
||||
rtn = &WslConn{Lock: &sync.Mutex{}, Status: Status_Init, Name: connName, WshEnabled: &atomic.Bool{}, HasWaiter: &atomic.Bool{}, cancelFn: nil}
|
||||
clientControllerMap[name] = rtn
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func GetWslConn(ctx context.Context, name string, shouldConnect bool) *WslConn {
|
||||
conn := getConnInternal(name)
|
||||
if conn.Client == nil && shouldConnect {
|
||||
conn.Connect(ctx)
|
||||
}
|
||||
return conn
|
||||
}
|
||||
|
||||
// Convenience function for ensuring a connection is established
|
||||
func EnsureConnection(ctx context.Context, connName string) error {
|
||||
if connName == "" {
|
||||
return nil
|
||||
}
|
||||
conn := GetWslConn(ctx, connName, false)
|
||||
if conn == nil {
|
||||
return fmt.Errorf("connection not found: %s", connName)
|
||||
}
|
||||
connStatus := conn.DeriveConnStatus()
|
||||
switch connStatus.Status {
|
||||
case Status_Connected:
|
||||
return nil
|
||||
case Status_Connecting:
|
||||
return conn.WaitForConnect(ctx)
|
||||
case Status_Init, Status_Disconnected:
|
||||
return conn.Connect(ctx)
|
||||
case Status_Error:
|
||||
return fmt.Errorf("connection error: %s", connStatus.Error)
|
||||
default:
|
||||
return fmt.Errorf("unknown connection status %q", connStatus.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func DisconnectClient(connName string) error {
|
||||
conn := getConnInternal(connName)
|
||||
if conn == nil {
|
||||
return fmt.Errorf("client %q not found", connName)
|
||||
}
|
||||
err := conn.Close()
|
||||
return err
|
||||
}
|
Loading…
Reference in New Issue
Block a user