WSL Updates for New Architecture (#1756)

This adapts most of the WSL code to follow the new architecture that ssh
uses.

---------

Co-authored-by: sawka <mike@commandline.dev>
This commit is contained in:
Sylvie Crowe 2025-01-16 15:54:58 -08:00 committed by GitHub
parent b7dca41b9c
commit ff5f26709c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 1256 additions and 994 deletions

View File

@ -35,7 +35,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshremote"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshserver"
"github.com/wavetermdev/waveterm/pkg/wshutil"
"github.com/wavetermdev/waveterm/pkg/wsl"
"github.com/wavetermdev/waveterm/pkg/wslconn"
"github.com/wavetermdev/waveterm/pkg/wstore"
)
@ -145,7 +145,7 @@ func beforeSendActivityUpdate(ctx context.Context) {
activity.Blocks, _ = wstore.DBGetBlockViewCounts(ctx)
activity.NumWindows, _ = wstore.DBGetCount[*waveobj.Window](ctx)
activity.NumSSHConn = conncontroller.GetNumSSHHasConnected()
activity.NumWSLConn = wsl.GetNumWSLHasConnected()
activity.NumWSLConn = wslconn.GetNumWSLHasConnected()
activity.NumWSNamed, activity.NumWS, _ = wstore.DBGetWSCounts(ctx)
err := telemetry.UpdateActivity(ctx, activity)
if err != nil {

View File

@ -33,7 +33,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
"github.com/wavetermdev/waveterm/pkg/wshutil"
"github.com/wavetermdev/waveterm/pkg/wsl"
"github.com/wavetermdev/waveterm/pkg/wslconn"
"github.com/wavetermdev/waveterm/pkg/wstore"
)
@ -369,7 +369,7 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc
credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
defer cancelFunc()
wslConn := wsl.GetWslConn(credentialCtx, wslName, false)
wslConn := wslconn.GetWslConn(credentialCtx, wslName, false)
connStatus := wslConn.DeriveConnStatus()
if connStatus.Status != conncontroller.Status_Connected {
return nil, fmt.Errorf("not connected, cannot start shellproc")
@ -377,10 +377,14 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc
// create jwt
if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) {
jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: wslConn.GetName()}, wslConn.GetDomainSocketName())
sockName := wslConn.GetDomainSocketName()
rpcContext := wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: wslConn.GetName()}
jwtStr, err := wshutil.MakeClientJWTToken(rpcContext, sockName)
if err != nil {
return nil, fmt.Errorf("error making jwt token: %w", err)
}
swapToken.SockName = sockName
swapToken.RpcContext = &rpcContext
swapToken.Env[wshutil.WaveJwtTokenVarName] = jwtStr
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
}
@ -747,7 +751,7 @@ func CheckConnStatus(blockId string) error {
}
if strings.HasPrefix(connName, "wsl://") {
distroName := strings.TrimPrefix(connName, "wsl://")
conn := wsl.GetWslConn(context.Background(), distroName, false)
conn := wslconn.GetWslConn(context.Background(), distroName, false)
connStatus := conn.DeriveConnStatus()
if connStatus.Status != conncontroller.Status_Connected {
return fmt.Errorf("not connected: %s", connStatus.Status)

View File

@ -1,25 +1,24 @@
//go:build windows
// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package genconn
import (
"context"
"fmt"
"io"
"sync"
"github.com/ubuntu/gowsl"
"github.com/wavetermdev/waveterm/pkg/wsl"
)
var _ ShellClient = (*WSLShellClient)(nil)
type WSLShellClient struct {
distro *gowsl.Distro
distro *wsl.Distro
}
func MakeWSLShellClient(distro *gowsl.Distro) *WSLShellClient {
func MakeWSLShellClient(distro *wsl.Distro) *WSLShellClient {
return &WSLShellClient{distro: distro}
}
@ -28,8 +27,8 @@ func (c *WSLShellClient) MakeProcessController(cmdSpec CommandSpec) (ShellProces
}
type WSLProcessController struct {
distro *gowsl.Distro
cmd *gowsl.Cmd
distro *wsl.Distro
cmd *wsl.WslCmd
lock *sync.Mutex
once *sync.Once
stdinPiped bool
@ -40,13 +39,13 @@ type WSLProcessController struct {
cmdSpec CommandSpec
}
func MakeWSLProcessController(distro *gowsl.Distro, cmdSpec CommandSpec) (*WSLProcessController, error) {
func MakeWSLProcessController(distro *wsl.Distro, cmdSpec CommandSpec) (*WSLProcessController, error) {
fullCmd, err := BuildShellCommand(cmdSpec)
if err != nil {
return nil, fmt.Errorf("failed to build shell command: %w", err)
}
cmd := distro.Command(nil, fullCmd)
cmd := distro.WslCommand(context.Background(), fullCmd)
if cmd == nil {
return nil, fmt.Errorf("failed to create WSL command")
}
@ -87,9 +86,14 @@ func (w *WSLProcessController) Kill() {
w.lock.Lock()
defer w.lock.Unlock()
if w.cmd != nil && w.cmd.Process != nil {
w.cmd.Process.Kill()
if w.cmd == nil {
return
}
process := w.cmd.GetProcess()
if process == nil {
return
}
process.Kill()
}
func (w *WSLProcessController) StdinPipe() (io.WriteCloser, error) {

View File

@ -308,12 +308,13 @@ func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool) (boo
if err != nil {
return false, "", "", fmt.Errorf("unable to start conn controller command: %w", err)
}
linesChan := wshutil.StreamToLinesChan(pipeRead)
versionLine, err := wshutil.ReadLineWithTimeout(linesChan, 2*time.Second)
linesChan := utilfn.StreamToLinesChan(pipeRead)
versionLine, err := utilfn.ReadLineWithTimeout(linesChan, 2*time.Second)
if err != nil {
sshSession.Close()
return false, "", "", fmt.Errorf("error reading wsh version: %w", err)
}
conn.Infof(ctx, "actual connnserverversion: %q\n", versionLine)
conn.Infof(ctx, "got connserver version: %s\n", strings.TrimSpace(versionLine))
isUpToDate, clientVersion, osArchStr, err := IsWshVersionUpToDate(ctx, versionLine)
if err != nil {
@ -326,11 +327,10 @@ func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool) (boo
}
conn.Infof(ctx, "connserver up-to-date: %v\n", isUpToDate)
if !isUpToDate {
sshSession.Close()
return true, clientVersion, osArchStr, nil
}
jwtLine, err := wshutil.ReadLineWithTimeout(linesChan, 3*time.Second)
jwtLine, err := utilfn.ReadLineWithTimeout(linesChan, 3*time.Second)
if err != nil {
sshSession.Close()
return false, clientVersion, "", fmt.Errorf("error reading jwt status line: %w", err)
@ -401,12 +401,6 @@ type WshInstallOpts struct {
NoUserPrompt bool
}
type WshInstallSkipError struct{}
func (wise *WshInstallSkipError) Error() string {
return "skipping wsh installation"
}
var queryTextTemplate = strings.TrimSpace(`
Wave requires Wave Shell Extensions to be
installed on %q
@ -555,7 +549,7 @@ func (conn *SSHConn) Connect(ctx context.Context, connFlags *wshrpc.ConnKeywords
}
})
if !connectAllowed {
conn.Infof(ctx, "cannot connect to when status is %q\n", conn.GetStatus())
conn.Infof(ctx, "cannot connect to %q when status is %q\n", conn.GetName(), conn.GetStatus())
return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus())
}
conn.Infof(ctx, "trying to connect to %q...\n", conn.GetName())
@ -754,7 +748,12 @@ func (conn *SSHConn) connectInternal(ctx context.Context, connFlags *wshrpc.Conn
conn.WithLock(func() {
conn.Client = client
})
go conn.waitForDisconnect()
go func() {
defer func() {
panichandler.PanicHandler("conncontroller:waitForDisconnect", recover())
}()
conn.waitForDisconnect()
}()
fmtAddr := knownhosts.Normalize(fmt.Sprintf("%s@%s", client.User(), client.RemoteAddr().String()))
conn.Infof(ctx, "normalized knownhosts address: %s\n", fmtAddr)
clientDisplayName := fmt.Sprintf("%s (%s)", conn.GetName(), fmtAddr)

View File

@ -15,7 +15,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/wconfig"
"github.com/wavetermdev/waveterm/pkg/wcore"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wsl"
"github.com/wavetermdev/waveterm/pkg/wslconn"
"github.com/wavetermdev/waveterm/pkg/wstore"
)
@ -42,7 +42,7 @@ func (cs *ClientService) GetTab(tabId string) (*waveobj.Tab, error) {
func (cs *ClientService) GetAllConnStatus(ctx context.Context) ([]wshrpc.ConnStatus, error) {
sshStatuses := conncontroller.GetAllConnStatus()
wslStatuses := wsl.GetAllConnStatus()
wslStatuses := wslconn.GetAllConnStatus()
return append(sshStatuses, wslStatuses...), nil
}

View File

@ -21,7 +21,6 @@ import (
"github.com/creack/pty"
"github.com/wavetermdev/waveterm/pkg/blocklogger"
"github.com/wavetermdev/waveterm/pkg/panichandler"
"github.com/wavetermdev/waveterm/pkg/remote"
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
"github.com/wavetermdev/waveterm/pkg/util/pamparse"
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
@ -30,7 +29,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
"github.com/wavetermdev/waveterm/pkg/wshutil"
"github.com/wavetermdev/waveterm/pkg/wsl"
"github.com/wavetermdev/waveterm/pkg/wslconn"
)
const DefaultGracefulKillWait = 400 * time.Millisecond
@ -151,85 +150,100 @@ func (pp *PipePty) WriteString(s string) (n int, err error) {
return pp.Write([]byte(s))
}
func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wsl.WslConn) (*ShellProc, error) {
utilCtx, cancelFn := context.WithTimeout(ctx, 2*time.Second)
defer cancelFn()
func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wslconn.WslConn) (*ShellProc, error) {
client := conn.GetClient()
shellPath := cmdOpts.ShellPath
conn.Infof(ctx, "WSL-NEWSESSION (StartWslShellProc)")
connRoute := wshutil.MakeConnectionRouteId(conn.GetName())
rpcClient := wshclient.GetBareRpcClient()
remoteInfo, err := wshclient.RemoteGetInfoCommand(rpcClient, &wshrpc.RpcOpts{Route: connRoute, Timeout: 2000})
if err != nil {
return nil, fmt.Errorf("unable to obtain client info: %w", err)
}
log.Printf("client info collected: %+#v", remoteInfo)
var shellPath string
if cmdOpts.ShellPath != "" {
conn.Infof(ctx, "using shell path from command opts: %s\n", cmdOpts.ShellPath)
shellPath = cmdOpts.ShellPath
}
configShellPath := conn.GetConfigShellPath()
if shellPath == "" && configShellPath != "" {
conn.Infof(ctx, "using shell path from config (conn:shellpath): %s\n", configShellPath)
shellPath = configShellPath
}
if shellPath == "" && remoteInfo.Shell != "" {
conn.Infof(ctx, "using shell path detected on remote machine: %s\n", remoteInfo.Shell)
shellPath = remoteInfo.Shell
}
if shellPath == "" {
remoteShellPath, err := wsl.DetectShell(utilCtx, client)
if err != nil {
return nil, err
}
shellPath = remoteShellPath
conn.Infof(ctx, "no shell path detected, using default (/bin/bash)\n")
shellPath = "/bin/bash"
}
var shellOpts []string
log.Printf("detected shell: %s", shellPath)
var cmdCombined string
log.Printf("detected shell %q for conn %q\n", shellPath, conn.GetName())
err := wsl.InstallClientRcFiles(utilCtx, client)
err = wshclient.RemoteInstallRcFilesCommand(rpcClient, &wshrpc.RpcOpts{Route: connRoute, Timeout: 2000})
if err != nil {
log.Printf("error installing rc files: %v", err)
return nil, err
}
homeDir := wsl.GetHomeDir(utilCtx, client)
shellOpts = append(shellOpts, "~", "-d", client.Name())
var subShellOpts []string
shellOpts = append(shellOpts, cmdOpts.ShellOpts...)
shellType := shellutil.GetShellTypeFromShellPath(shellPath)
conn.Infof(ctx, "detected shell type: %s\n", shellType)
if cmdStr == "" {
/* transform command in order to inject environment vars */
if isBashShell(shellPath) {
log.Printf("recognized as bash shell")
if shellType == shellutil.ShellType_bash {
// add --rcfile
// cant set -l or -i with --rcfile
subShellOpts = append(subShellOpts, "--rcfile", fmt.Sprintf(`%s/.waveterm/%s/.bashrc`, homeDir, shellutil.BashIntegrationDir))
} else if isFishShell(shellPath) {
carg := fmt.Sprintf(`"set -x PATH \"%s\"/.waveterm/%s $PATH"`, homeDir, shellutil.WaveHomeBinDir)
subShellOpts = append(subShellOpts, "-C", carg)
} else if wsl.IsPowershell(shellPath) {
bashPath := fmt.Sprintf("~/.waveterm/%s/.bashrc", shellutil.BashIntegrationDir)
shellOpts = append(shellOpts, "--rcfile", bashPath)
} else if shellType == shellutil.ShellType_fish {
if cmdOpts.Login {
shellOpts = append(shellOpts, "-l")
}
// source the wave.fish file
waveFishPath := fmt.Sprintf("~/.waveterm/%s/wave.fish", shellutil.FishIntegrationDir)
carg := fmt.Sprintf(`"source %s"`, waveFishPath)
shellOpts = append(shellOpts, "-C", carg)
} else if shellType == shellutil.ShellType_pwsh {
pwshPath := fmt.Sprintf("~/.waveterm/%s/wavepwsh.ps1", shellutil.PwshIntegrationDir)
// powershell is weird about quoted path executables and requires an ampersand first
shellPath = "& " + shellPath
subShellOpts = append(subShellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", fmt.Sprintf("%s/.waveterm/%s/wavepwsh.ps1", homeDir, shellutil.PwshIntegrationDir))
shellOpts = append(shellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", pwshPath)
} else {
if cmdOpts.Login {
subShellOpts = append(subShellOpts, "-l")
shellOpts = append(shellOpts, "-l")
}
if cmdOpts.Interactive {
subShellOpts = append(subShellOpts, "-i")
shellOpts = append(shellOpts, "-i")
}
// can't set environment vars this way
// will try to do later if possible
// zdotdir setting moved to after session is created
}
cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
} else {
// TODO check quoting of cmdStr
shellPath = cmdStr
if cmdOpts.Login {
subShellOpts = append(subShellOpts, "-l")
}
if cmdOpts.Interactive {
subShellOpts = append(subShellOpts, "-i")
}
subShellOpts = append(subShellOpts, "-c", cmdStr)
shellOpts = append(shellOpts, "-c", cmdStr)
cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
}
conn.Infof(ctx, "starting shell, using command: %s\n", cmdCombined)
conn.Infof(ctx, "WSL-NEWSESSION (StartWslShellProc)\n")
if shellType == shellutil.ShellType_zsh {
zshDir := fmt.Sprintf("~/.waveterm/%s", shellutil.ZshIntegrationDir)
conn.Infof(ctx, "setting ZDOTDIR to %s\n", zshDir)
cmdCombined = fmt.Sprintf(`ZDOTDIR=%s %s`, zshDir, cmdCombined)
}
jwtToken, ok := cmdOpts.Env[wshutil.WaveJwtTokenVarName]
if !ok {
return nil, fmt.Errorf("no jwt token provided to connection")
}
if remote.IsPowershell(shellPath) {
shellOpts = append(shellOpts, "--", fmt.Sprintf(`$env:%s=%s;`, wshutil.WaveJwtTokenVarName, jwtToken))
} else {
shellOpts = append(shellOpts, "--", fmt.Sprintf(`%s=%s`, wshutil.WaveJwtTokenVarName, jwtToken))
}
cmdCombined = fmt.Sprintf(`%s=%s %s`, wshutil.WaveJwtTokenVarName, jwtToken, cmdCombined)
if isZshShell(shellPath) {
shellOpts = append(shellOpts, fmt.Sprintf(`ZDOTDIR=%s/.waveterm/%s`, homeDir, shellutil.ZshIntegrationDir))
}
shellOpts = append(shellOpts, shellPath)
shellOpts = append(shellOpts, subShellOpts...)
log.Printf("full cmd is: %s %s", "wsl.exe", strings.Join(shellOpts, " "))
ecmd := exec.Command("wsl.exe", shellOpts...)
log.Printf("full combined command: %s", cmdCombined)
ecmd := exec.Command("wsl.exe", "~", "-d", client.Name(), "--", "sh", "-c", cmdCombined)
if termSize.Rows == 0 || termSize.Cols == 0 {
termSize.Rows = shellutil.DefaultTermRows
termSize.Cols = shellutil.DefaultTermCols
@ -237,6 +251,7 @@ func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr st
if termSize.Rows <= 0 || termSize.Cols <= 0 {
return nil, fmt.Errorf("invalid term size: %v", termSize)
}
shellutil.AddTokenSwapEntry(cmdOpts.SwapToken)
cmdPty, err := pty.StartWithSize(ecmd, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)})
if err != nil {
return nil, err

View File

@ -8,6 +8,9 @@ import (
"bytes"
"fmt"
"io"
"log"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
)
type PacketParser struct {
@ -15,11 +18,38 @@ type PacketParser struct {
Ch chan []byte
}
func ParseWithLinesChan(input chan utilfn.LineOutput, packetCh chan []byte, rawCh chan []byte) {
defer close(packetCh)
defer close(rawCh)
for {
// note this line doesn't have a trailing newline
line, ok := <-input
if !ok {
return
}
if line.Error != nil {
log.Printf("ParseWithLinesChan: error reading line: %v", line.Error)
return
}
if len(line.Line) <= 1 {
// just a blank line
continue
}
if bytes.HasPrefix([]byte(line.Line), []byte{'#', '#', 'N', '{'}) && bytes.HasSuffix([]byte(line.Line), []byte{'}'}) {
// strip off the leading "##"
packetCh <- []byte(line.Line[3:len(line.Line)])
} else {
rawCh <- []byte(line.Line)
}
}
}
func Parse(input io.Reader, packetCh chan []byte, rawCh chan []byte) error {
bufReader := bufio.NewReader(input)
defer close(packetCh)
defer close(rawCh)
for {
// note this line does have a trailing newline
line, err := bufReader.ReadBytes('\n')
if err == io.EOF {
return nil

View File

@ -7,7 +7,6 @@ import "regexp"
var (
safePattern = regexp.MustCompile(`^[a-zA-Z0-9_/.-]+$`)
psSafePattern = regexp.MustCompile(`^[a-zA-Z0-9_.-]+$`)
envVarNamePattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`)
)
@ -73,10 +72,6 @@ func HardQuotePowerShell(s string) string {
return "\"\""
}
if psSafePattern.MatchString(s) {
return s
}
buf := make([]byte, 0, len(s)+5)
buf = append(buf, '"')

View File

@ -153,7 +153,7 @@ wsh completion fish | source
$env:PATH = {{.WSHBINDIR_PWSH}} + "{{.PATHSEP}}" + $env:PATH
# Source dynamic script from wsh token
$waveterm_swaptoken_output = wsh token $env:WAVETERM_SWAPTOKEN pwsh 2>$null
$waveterm_swaptoken_output = wsh token $env:WAVETERM_SWAPTOKEN pwsh 2>$null | Out-String
if ($waveterm_swaptoken_output -and $waveterm_swaptoken_output -ne "") {
Invoke-Expression $waveterm_swaptoken_output
}

View File

@ -0,0 +1,85 @@
// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package utilfn
import (
"bytes"
"context"
"io"
"time"
)
type LineOutput struct {
Line string
Error error
}
type lineBuf struct {
buf []byte
inLongLine bool
}
const maxLineLength = 128 * 1024
func ReadLineWithTimeout(ch chan LineOutput, timeout time.Duration) (string, error) {
select {
case output := <-ch:
if output.Error != nil {
return "", output.Error
}
return output.Line, nil
case <-time.After(timeout):
return "", context.DeadlineExceeded
}
}
func streamToLines_processBuf(lineBuf *lineBuf, readBuf []byte, lineFn func([]byte)) {
for len(readBuf) > 0 {
nlIdx := bytes.IndexByte(readBuf, '\n')
if nlIdx == -1 {
if lineBuf.inLongLine || len(lineBuf.buf)+len(readBuf) > maxLineLength {
lineBuf.buf = nil
lineBuf.inLongLine = true
return
}
lineBuf.buf = append(lineBuf.buf, readBuf...)
return
}
if !lineBuf.inLongLine && len(lineBuf.buf)+nlIdx <= maxLineLength {
line := append(lineBuf.buf, readBuf[:nlIdx]...)
lineFn(line)
}
lineBuf.buf = nil
lineBuf.inLongLine = false
readBuf = readBuf[nlIdx+1:]
}
}
func StreamToLines(input io.Reader, lineFn func([]byte)) error {
var lineBuf lineBuf
readBuf := make([]byte, 16*1024)
for {
n, err := input.Read(readBuf)
streamToLines_processBuf(&lineBuf, readBuf[:n], lineFn)
if err != nil {
return err
}
}
}
// starts a goroutine to drive the channel
// line output does not include the trailing newline
func StreamToLinesChan(input io.Reader) chan LineOutput {
ch := make(chan LineOutput)
go func() {
defer close(ch)
err := StreamToLines(input, func(line []byte) {
ch <- LineOutput{Line: string(line)}
})
if err != nil && err != io.EOF {
ch <- LineOutput{Error: err}
}
}()
return ch
}

View File

@ -37,6 +37,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshutil"
"github.com/wavetermdev/waveterm/pkg/wsl"
"github.com/wavetermdev/waveterm/pkg/wslconn"
"github.com/wavetermdev/waveterm/pkg/wstore"
)
@ -609,7 +610,7 @@ func (ws *WshServer) ConnStatusCommand(ctx context.Context) ([]wshrpc.ConnStatus
}
func (ws *WshServer) WslStatusCommand(ctx context.Context) ([]wshrpc.ConnStatus, error) {
rtn := wsl.GetAllConnStatus()
rtn := wslconn.GetAllConnStatus()
return rtn, nil
}
@ -633,7 +634,7 @@ func (ws *WshServer) ConnEnsureCommand(ctx context.Context, data wshrpc.ConnExtD
ctx = termCtxWithLogBlockId(ctx, data.LogBlockId)
if strings.HasPrefix(data.ConnName, "wsl://") {
distroName := strings.TrimPrefix(data.ConnName, "wsl://")
return wsl.EnsureConnection(ctx, distroName)
return wslconn.EnsureConnection(ctx, distroName)
}
return conncontroller.EnsureConnection(ctx, data.ConnName)
}
@ -641,7 +642,7 @@ func (ws *WshServer) ConnEnsureCommand(ctx context.Context, data wshrpc.ConnExtD
func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string) error {
if strings.HasPrefix(connName, "wsl://") {
distroName := strings.TrimPrefix(connName, "wsl://")
conn := wsl.GetWslConn(ctx, distroName, false)
conn := wslconn.GetWslConn(ctx, distroName, false)
if conn == nil {
return fmt.Errorf("distro not found: %s", connName)
}
@ -664,7 +665,7 @@ func (ws *WshServer) ConnConnectCommand(ctx context.Context, connRequest wshrpc.
connName := connRequest.Host
if strings.HasPrefix(connName, "wsl://") {
distroName := strings.TrimPrefix(connName, "wsl://")
conn := wsl.GetWslConn(ctx, distroName, false)
conn := wslconn.GetWslConn(ctx, distroName, false)
if conn == nil {
return fmt.Errorf("connection not found: %s", connName)
}
@ -687,11 +688,11 @@ func (ws *WshServer) ConnReinstallWshCommand(ctx context.Context, data wshrpc.Co
connName := data.ConnName
if strings.HasPrefix(connName, "wsl://") {
distroName := strings.TrimPrefix(connName, "wsl://")
conn := wsl.GetWslConn(ctx, distroName, false)
conn := wslconn.GetWslConn(ctx, distroName, false)
if conn == nil {
return fmt.Errorf("connection not found: %s", connName)
}
return conn.CheckAndInstallWsh(ctx, connName, &wsl.WshInstallOpts{Force: true, NoUserPrompt: true})
return conn.InstallWsh(ctx, "")
}
connOpts, err := remote.ParseOpts(connName)
if err != nil {

View File

@ -4,11 +4,10 @@
package wshutil
import (
"bytes"
"context"
"fmt"
"io"
"time"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
)
// special I/O wrappers for wshrpc
@ -16,81 +15,8 @@ import (
// * stream (json lines)
// * websocket (json packets)
type lineBuf struct {
buf []byte
inLongLine bool
}
const maxLineLength = 128 * 1024
func streamToLines_processBuf(lineBuf *lineBuf, readBuf []byte, lineFn func([]byte)) {
for len(readBuf) > 0 {
nlIdx := bytes.IndexByte(readBuf, '\n')
if nlIdx == -1 {
if lineBuf.inLongLine || len(lineBuf.buf)+len(readBuf) > maxLineLength {
lineBuf.buf = nil
lineBuf.inLongLine = true
return
}
lineBuf.buf = append(lineBuf.buf, readBuf...)
return
}
if !lineBuf.inLongLine && len(lineBuf.buf)+nlIdx <= maxLineLength {
line := append(lineBuf.buf, readBuf[:nlIdx]...)
lineFn(line)
}
lineBuf.buf = nil
lineBuf.inLongLine = false
readBuf = readBuf[nlIdx+1:]
}
}
func StreamToLines(input io.Reader, lineFn func([]byte)) error {
var lineBuf lineBuf
readBuf := make([]byte, 16*1024)
for {
n, err := input.Read(readBuf)
streamToLines_processBuf(&lineBuf, readBuf[:n], lineFn)
if err != nil {
return err
}
}
}
type LineOutput struct {
Line string
Error error
}
// starts a goroutine to drive the channel
func StreamToLinesChan(input io.Reader) chan LineOutput {
ch := make(chan LineOutput)
go func() {
defer close(ch)
err := StreamToLines(input, func(line []byte) {
ch <- LineOutput{Line: string(line)}
})
if err != nil && err != io.EOF {
ch <- LineOutput{Error: err}
}
}()
return ch
}
func ReadLineWithTimeout(ch chan LineOutput, timeout time.Duration) (string, error) {
select {
case output := <-ch:
if output.Error != nil {
return "", output.Error
}
return output.Line, nil
case <-time.After(timeout):
return "", context.DeadlineExceeded
}
}
func AdaptStreamToMsgCh(input io.Reader, output chan []byte) error {
return StreamToLines(input, func(line []byte) {
return utilfn.StreamToLines(input, func(line []byte) {
output <- line
})
}

View File

@ -25,6 +25,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/panichandler"
"github.com/wavetermdev/waveterm/pkg/util/packetparser"
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"golang.org/x/term"
@ -418,10 +419,10 @@ type WriteFlusher interface {
}
// blocking, returns if there is an error, or on EOF of input
func HandleStdIOClient(logName string, input io.Reader, output io.Writer) {
func HandleStdIOClient(logName string, input chan utilfn.LineOutput, output io.Writer) {
proxy := MakeRpcMultiProxy()
rawCh := make(chan []byte, DefaultInputChSize)
go packetparser.Parse(input, proxy.FromRemoteRawCh, rawCh)
go packetparser.ParseWithLinesChan(input, proxy.FromRemoteRawCh, rawCh)
doneCh := make(chan struct{})
var doneOnce sync.Once
closeDoneCh := func() {
@ -455,6 +456,9 @@ func HandleStdIOClient(logName string, input io.Reader, output io.Writer) {
}()
defer closeDoneCh()
for msg := range rawCh {
if !bytes.HasSuffix(msg, []byte{'\n'}) {
msg = append(msg, '\n')
}
log.Printf("[%s:stdout] %s", logName, msg)
}
}()

View File

@ -13,6 +13,10 @@ import (
"os/exec"
)
type WslName struct {
Distro string `json:"distro"`
}
func RegisteredDistros(ctx context.Context) (distros []Distro, err error) {
return nil, fmt.Errorf("RegisteredDistros not implemented on this system")
}

View File

@ -1,289 +0,0 @@
// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wsl
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"log"
"os"
"path/filepath"
"strings"
"text/template"
"time"
"github.com/wavetermdev/waveterm/pkg/panichandler"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/wavebase"
)
func DetectShell(ctx context.Context, client *Distro) (string, error) {
wshPath := GetWshPath(ctx, client)
cmd := client.WslCommand(ctx, wshPath+" shell")
log.Printf("shell detecting using command: %s shell", wshPath)
out, err := cmd.Output()
if err != nil {
log.Printf("unable to determine shell. defaulting to /bin/bash: %s", err)
return "/bin/bash", nil
}
log.Printf("detecting shell: %s", out)
// quoting breaks this particular case
return strings.TrimSpace(string(out)), nil
}
func GetWshVersion(ctx context.Context, client *Distro) (string, error) {
wshPath := GetWshPath(ctx, client)
cmd := client.WslCommand(ctx, wshPath+" version")
out, err := cmd.Output()
if err != nil {
return "", err
}
return strings.TrimSpace(string(out)), nil
}
func GetWshPath(ctx context.Context, client *Distro) string {
defaultPath := wavebase.RemoteFullWshBinPath
cmd := client.WslCommand(ctx, "which wsh")
out, whichErr := cmd.Output()
if whichErr == nil {
return strings.TrimSpace(string(out))
}
cmd = client.WslCommand(ctx, "where.exe wsh")
out, whereErr := cmd.Output()
if whereErr == nil {
return strings.TrimSpace(string(out))
}
// no custom install, use default path
return defaultPath
}
func hasBashInstalled(ctx context.Context, client *Distro) (bool, error) {
cmd := client.WslCommand(ctx, "which bash")
out, whichErr := cmd.Output()
if whichErr == nil && len(out) != 0 {
return true, nil
}
cmd = client.WslCommand(ctx, "where.exe bash")
out, whereErr := cmd.Output()
if whereErr == nil && len(out) != 0 {
return true, nil
}
// note: we could also check in /bin/bash explicitly
// just in case that wasn't added to the path. but if
// that's true, we will most likely have worse
// problems going forward
return false, nil
}
func GetClientOs(ctx context.Context, client *Distro) (string, error) {
cmd := client.WslCommand(ctx, "uname -s")
out, unixErr := cmd.CombinedOutput()
if unixErr == nil {
formatted := strings.ToLower(string(out))
formatted = strings.TrimSpace(formatted)
return formatted, nil
}
cmd = client.WslCommand(ctx, "echo %OS%")
out, cmdErr := cmd.Output()
if cmdErr == nil && strings.TrimSpace(string(out)) != "%OS%" {
formatted := strings.ToLower(string(out))
formatted = strings.TrimSpace(formatted)
return strings.Split(formatted, "_")[0], nil
}
cmd = client.WslCommand(ctx, "echo $env:OS")
out, psErr := cmd.Output()
if psErr == nil && strings.TrimSpace(string(out)) != "$env:OS" {
formatted := strings.ToLower(string(out))
formatted = strings.TrimSpace(formatted)
return strings.Split(formatted, "_")[0], nil
}
return "", fmt.Errorf("unable to determine os: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
}
func GetClientArch(ctx context.Context, client *Distro) (string, error) {
cmd := client.WslCommand(ctx, "uname -m")
out, unixErr := cmd.Output()
if unixErr == nil {
return utilfn.FilterValidArch(string(out))
}
cmd = client.WslCommand(ctx, "echo %PROCESSOR_ARCHITECTURE%")
out, cmdErr := cmd.CombinedOutput()
if cmdErr == nil && strings.TrimSpace(string(out)) != "%PROCESSOR_ARCHITECTURE%" {
return utilfn.FilterValidArch(string(out))
}
cmd = client.WslCommand(ctx, "echo $env:PROCESSOR_ARCHITECTURE")
out, psErr := cmd.CombinedOutput()
if psErr == nil && strings.TrimSpace(string(out)) != "$env:PROCESSOR_ARCHITECTURE" {
return utilfn.FilterValidArch(string(out))
}
return "", fmt.Errorf("unable to determine architecture: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
}
type CancellableCmd struct {
Cmd *WslCmd
Cancel func()
}
var installTemplatesRawBash = map[string]string{
"mkdir": `bash -c 'mkdir -p {{.installDir}}'`,
"cat": `bash -c 'cat > {{.tempPath}}'`,
"mv": `bash -c 'mv {{.tempPath}} {{.installPath}}'`,
"chmod": `bash -c 'chmod a+x {{.installPath}}'`,
}
var installTemplatesRawDefault = map[string]string{
"mkdir": `mkdir -p {{.installDir}}`,
"cat": `cat > {{.tempPath}}`,
"mv": `mv {{.tempPath}} {{.installPath}}`,
"chmod": `chmod a+x {{.installPath}}`,
}
func makeCancellableCommand(ctx context.Context, client *Distro, cmdTemplateRaw string, words map[string]string) (*CancellableCmd, error) {
cmdContext, cmdCancel := context.WithCancel(ctx)
cmdStr := &bytes.Buffer{}
cmdTemplate, err := template.New("").Parse(cmdTemplateRaw)
if err != nil {
cmdCancel()
return nil, err
}
cmdTemplate.Execute(cmdStr, words)
cmd := client.WslCommand(cmdContext, cmdStr.String())
return &CancellableCmd{cmd, cmdCancel}, nil
}
func CpHostToRemote(ctx context.Context, client *Distro, sourcePath string, destPath string) error {
// warning: does not work on windows remote yet
bashInstalled, err := hasBashInstalled(ctx, client)
if err != nil {
return err
}
var selectedTemplatesRaw map[string]string
if bashInstalled {
selectedTemplatesRaw = installTemplatesRawBash
} else {
log.Printf("bash is not installed on remote. attempting with default shell")
selectedTemplatesRaw = installTemplatesRawDefault
}
// I need to use toSlash here to force unix keybindings
// this means we can't guarantee it will work on a remote windows machine
var installWords = map[string]string{
"installDir": filepath.ToSlash(filepath.Dir(destPath)),
"tempPath": destPath + ".temp",
"installPath": destPath,
}
installStepCmds := make(map[string]*CancellableCmd)
for cmdName, selectedTemplateRaw := range selectedTemplatesRaw {
cancellableCmd, err := makeCancellableCommand(ctx, client, selectedTemplateRaw, installWords)
if err != nil {
return err
}
installStepCmds[cmdName] = cancellableCmd
}
_, err = installStepCmds["mkdir"].Cmd.Output()
if err != nil {
return err
}
// the cat part of this is complicated since it requires stdin
catCmd := installStepCmds["cat"].Cmd
catStdin, err := catCmd.StdinPipe()
if err != nil {
return err
}
err = catCmd.Start()
if err != nil {
return err
}
input, err := os.Open(sourcePath)
if err != nil {
return fmt.Errorf("cannot open local file %s to send to host: %v", sourcePath, err)
}
go func() {
defer func() {
panichandler.PanicHandler("wslutil:cpHostToRemote:catStdin", recover())
}()
io.Copy(catStdin, input)
installStepCmds["cat"].Cancel()
// backup just in case something weird happens
// could cause potential race condition, but very
// unlikely
time.Sleep(time.Second * 1)
process := catCmd.GetProcess()
if process != nil {
process.Kill()
}
}()
catErr := catCmd.Wait()
if catErr != nil && !errors.Is(catErr, context.Canceled) {
return catErr
}
_, err = installStepCmds["mv"].Cmd.Output()
if err != nil {
return err
}
_, err = installStepCmds["chmod"].Cmd.Output()
if err != nil {
return err
}
return nil
}
func InstallClientRcFiles(ctx context.Context, client *Distro) error {
path := GetWshPath(ctx, client)
log.Printf("path to wsh searched is: %s", path)
cmd := client.WslCommand(ctx, path+" rcfiles")
_, err := cmd.Output()
return err
}
func GetHomeDir(ctx context.Context, client *Distro) string {
// note: also works for powershell
cmd := client.WslCommand(ctx, `echo "$HOME"`)
out, err := cmd.Output()
if err == nil {
return strings.TrimSpace(string(out))
}
cmd = client.WslCommand(ctx, `echo %userprofile%`)
out, err = cmd.Output()
if err == nil {
return strings.TrimSpace(string(out))
}
return "~"
}
func IsPowershell(shellPath string) bool {
// get the base path, and then check contains
shellBase := filepath.Base(shellPath)
return strings.Contains(shellBase, "powershell") || strings.Contains(shellBase, "pwsh")
}

View File

@ -18,6 +18,10 @@ import (
var RegisteredDistros = gowsl.RegisteredDistros
var DefaultDistro = gowsl.DefaultDistro
type WslName struct {
Distro string `json:"distro"`
}
type Distro struct {
gowsl.Distro
}

View File

@ -1,533 +0,0 @@
// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wsl
import (
"context"
"fmt"
"io"
"log"
"net"
"sync"
"sync/atomic"
"time"
"github.com/wavetermdev/waveterm/pkg/panichandler"
"github.com/wavetermdev/waveterm/pkg/telemetry"
"github.com/wavetermdev/waveterm/pkg/userinput"
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wconfig"
"github.com/wavetermdev/waveterm/pkg/wps"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshutil"
)
const (
Status_Init = "init"
Status_Connecting = "connecting"
Status_Connected = "connected"
Status_Disconnected = "disconnected"
Status_Error = "error"
)
const DefaultConnectionTimeout = 60 * time.Second
var globalLock = &sync.Mutex{}
var clientControllerMap = make(map[string]*WslConn)
var activeConnCounter = &atomic.Int32{}
type WslConn struct {
Lock *sync.Mutex
Status string
Name WslName
Client *Distro
SockName string
DomainSockListener net.Listener
ConnController *WslCmd
Error string
HasWaiter *atomic.Bool
LastConnectTime int64
ActiveConnNum int
cancelFn func()
}
type WslName struct {
Distro string `json:"distro"`
}
func GetAllConnStatus() []wshrpc.ConnStatus {
globalLock.Lock()
defer globalLock.Unlock()
var connStatuses []wshrpc.ConnStatus
for _, conn := range clientControllerMap {
connStatuses = append(connStatuses, conn.DeriveConnStatus())
}
return connStatuses
}
func GetNumWSLHasConnected() int {
globalLock.Lock()
defer globalLock.Unlock()
var connectedCount int
for _, conn := range clientControllerMap {
if conn.LastConnectTime > 0 {
connectedCount++
}
}
return connectedCount
}
func (conn *WslConn) DeriveConnStatus() wshrpc.ConnStatus {
conn.Lock.Lock()
defer conn.Lock.Unlock()
return wshrpc.ConnStatus{
Status: conn.Status,
Connected: conn.Status == Status_Connected,
WshEnabled: true, // always use wsh for wsl connections (temporary)
Connection: conn.GetName(),
HasConnected: (conn.LastConnectTime > 0),
ActiveConnNum: conn.ActiveConnNum,
Error: conn.Error,
}
}
func (conn *WslConn) FireConnChangeEvent() {
status := conn.DeriveConnStatus()
event := wps.WaveEvent{
Event: wps.Event_ConnChange,
Scopes: []string{
fmt.Sprintf("connection:%s", conn.GetName()),
},
Data: status,
}
log.Printf("sending event: %+#v", event)
wps.Broker.Publish(event)
}
func (conn *WslConn) Close() error {
defer conn.FireConnChangeEvent()
conn.WithLock(func() {
if conn.Status == Status_Connected || conn.Status == Status_Connecting {
// if status is init, disconnected, or error don't change it
conn.Status = Status_Disconnected
}
conn.close_nolock()
})
// we must wait for the waiter to complete
startTime := time.Now()
for conn.HasWaiter.Load() {
time.Sleep(10 * time.Millisecond)
if time.Since(startTime) > 2*time.Second {
return fmt.Errorf("timeout waiting for waiter to complete")
}
}
return nil
}
func (conn *WslConn) close_nolock() {
// does not set status (that should happen at another level)
if conn.DomainSockListener != nil {
conn.DomainSockListener.Close()
conn.DomainSockListener = nil
}
if conn.ConnController != nil {
conn.cancelFn() // this suspends the conn controller
conn.ConnController = nil
}
if conn.Client != nil {
// conn.Client.Close() is not relevant here
// we do not want to completely close the wsl in case
// other applications are using it
conn.Client = nil
}
}
func (conn *WslConn) GetDomainSocketName() string {
conn.Lock.Lock()
defer conn.Lock.Unlock()
return conn.SockName
}
func (conn *WslConn) GetStatus() string {
conn.Lock.Lock()
defer conn.Lock.Unlock()
return conn.Status
}
func (conn *WslConn) GetName() string {
// no lock required because opts is immutable
return "wsl://" + conn.Name.Distro
}
/**
* This function is does not set a listener for WslConn
* It is still required in order to set SockName
**/
func (conn *WslConn) OpenDomainSocketListener() error {
var allowed bool
conn.WithLock(func() {
if conn.Status != Status_Connecting {
allowed = false
} else {
allowed = true
}
})
if !allowed {
return fmt.Errorf("cannot open domain socket for %q when status is %q", conn.GetName(), conn.GetStatus())
}
conn.WithLock(func() {
conn.SockName = wavebase.RemoteFullDomainSocketPath
})
return nil
}
func (conn *WslConn) StartConnServer() error {
utilCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFn()
var allowed bool
conn.WithLock(func() {
if conn.Status != Status_Connecting {
allowed = false
} else {
allowed = true
}
})
if !allowed {
return fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus())
}
client := conn.GetClient()
wshPath := GetWshPath(utilCtx, client)
rpcCtx := wshrpc.RpcContext{
ClientType: wshrpc.ClientType_ConnServer,
Conn: conn.GetName(),
}
sockName := conn.GetDomainSocketName()
jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName)
if err != nil {
return fmt.Errorf("unable to create jwt token for conn controller: %w", err)
}
shellPath, err := DetectShell(utilCtx, client)
if err != nil {
return err
}
var cmdStr string
if IsPowershell(shellPath) {
cmdStr = fmt.Sprintf("$env:%s=\"%s\"; %s connserver --router", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
} else {
cmdStr = fmt.Sprintf("%s=\"%s\" %s connserver --router", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
}
log.Printf("starting conn controller: %s\n", cmdStr)
connServerCtx, cancelFn := context.WithCancel(context.Background())
conn.WithLock(func() {
if conn.cancelFn != nil {
conn.cancelFn()
}
conn.cancelFn = cancelFn
})
cmd := client.WslCommand(connServerCtx, cmdStr)
pipeRead, pipeWrite := io.Pipe()
inputPipeRead, inputPipeWrite := io.Pipe()
cmd.SetStdout(pipeWrite)
cmd.SetStderr(pipeWrite)
cmd.SetStdin(inputPipeRead)
err = cmd.Start()
if err != nil {
return fmt.Errorf("unable to start conn controller: %w", err)
}
conn.WithLock(func() {
conn.ConnController = cmd
})
// service the I/O
go func() {
defer func() {
panichandler.PanicHandler("wsl:StartConnServer:wait", recover())
}()
// wait for termination, clear the controller
defer conn.WithLock(func() {
conn.ConnController = nil
})
waitErr := cmd.Wait()
log.Printf("conn controller (%q) terminated: %v", conn.GetName(), waitErr)
}()
go func() {
defer func() {
panichandler.PanicHandler("wsl:StartConnServer:handleStdIOClient", recover())
}()
logName := fmt.Sprintf("conncontroller:%s", conn.GetName())
wshutil.HandleStdIOClient(logName, pipeRead, inputPipeWrite)
}()
regCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFn()
err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(rpcCtx.Conn))
if err != nil {
return fmt.Errorf("timeout waiting for connserver to register")
}
time.Sleep(300 * time.Millisecond) // TODO remove this sleep (but we need to wait until connserver is "ready")
return nil
}
type WshInstallOpts struct {
Force bool
NoUserPrompt bool
}
func (conn *WslConn) CheckAndInstallWsh(ctx context.Context, clientDisplayName string, opts *WshInstallOpts) error {
if opts == nil {
opts = &WshInstallOpts{}
}
client := conn.GetClient()
if client == nil {
return fmt.Errorf("client is nil")
}
// check that correct wsh extensions are installed
expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
clientVersion, err := GetWshVersion(ctx, client)
if err == nil && clientVersion == expectedVersion && !opts.Force {
return nil
}
var queryText string
var title string
if opts.Force {
queryText = fmt.Sprintf("ReInstalling Wave Shell Extensions (%s) on `%s`\n", wavebase.WaveVersion, clientDisplayName)
title = "Install Wave Shell Extensions"
} else if err != nil {
queryText = fmt.Sprintf("Wave requires Wave Shell Extensions to be \n"+
"installed on `%s` \n"+
"to ensure a seamless experience. \n\n"+
"Would you like to install them?", clientDisplayName)
title = "Install Wave Shell Extensions"
} else {
// don't ask for upgrading the version
opts.NoUserPrompt = true
}
if !opts.NoUserPrompt {
request := &userinput.UserInputRequest{
ResponseType: "confirm",
QueryText: queryText,
Title: title,
Markdown: true,
CheckBoxMsg: "Don't show me this again",
}
response, err := userinput.GetUserInput(ctx, request)
if err != nil || !response.Confirm {
return err
}
if response.CheckboxStat {
meta := waveobj.MetaMapType{
wconfig.ConfigKey_ConnAskBeforeWshInstall: false,
}
err := wconfig.SetBaseConfigValue(meta)
if err != nil {
return fmt.Errorf("error setting conn:askbeforewshinstall value: %w", err)
}
}
}
log.Printf("attempting to install wsh to `%s`", clientDisplayName)
clientOs, err := GetClientOs(ctx, client)
if err != nil {
return err
}
clientArch, err := GetClientArch(ctx, client)
if err != nil {
return err
}
// attempt to install extension
wshLocalPath, err := shellutil.GetLocalWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
if err != nil {
return err
}
err = CpHostToRemote(ctx, client, wshLocalPath, wavebase.RemoteFullWshBinPath)
if err != nil {
return err
}
log.Printf("successfully installed wsh on %s\n", conn.GetName())
return nil
}
func (conn *WslConn) GetClient() *Distro {
conn.Lock.Lock()
defer conn.Lock.Unlock()
return conn.Client
}
func (conn *WslConn) Reconnect(ctx context.Context) error {
err := conn.Close()
if err != nil {
return err
}
return conn.Connect(ctx)
}
func (conn *WslConn) WaitForConnect(ctx context.Context) error {
for {
status := conn.DeriveConnStatus()
if status.Status == Status_Connected {
return nil
}
if status.Status == Status_Connecting {
select {
case <-ctx.Done():
return fmt.Errorf("context timeout")
case <-time.After(100 * time.Millisecond):
continue
}
}
if status.Status == Status_Init || status.Status == Status_Disconnected {
return fmt.Errorf("disconnected")
}
if status.Status == Status_Error {
return fmt.Errorf("error: %v", status.Error)
}
return fmt.Errorf("unknown status: %q", status.Status)
}
}
// does not return an error since that error is stored inside of WslConn
func (conn *WslConn) Connect(ctx context.Context) error {
var connectAllowed bool
conn.WithLock(func() {
if conn.Status == Status_Connecting || conn.Status == Status_Connected {
connectAllowed = false
} else {
conn.Status = Status_Connecting
conn.Error = ""
connectAllowed = true
}
})
log.Printf("Connect %s\n", conn.GetName())
if !connectAllowed {
return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus())
}
conn.FireConnChangeEvent()
err := conn.connectInternal(ctx)
conn.WithLock(func() {
if err != nil {
conn.Status = Status_Error
conn.Error = err.Error()
conn.close_nolock()
telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{
Conn: map[string]int{"wsl:connecterror": 1},
}, "wsl-connconnect")
} else {
conn.Status = Status_Connected
conn.LastConnectTime = time.Now().UnixMilli()
if conn.ActiveConnNum == 0 {
conn.ActiveConnNum = int(activeConnCounter.Add(1))
}
telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{
Conn: map[string]int{"wsl:connect": 1},
}, "wsl-connconnect")
}
})
conn.FireConnChangeEvent()
return err
}
func (conn *WslConn) WithLock(fn func()) {
conn.Lock.Lock()
defer conn.Lock.Unlock()
fn()
}
func (conn *WslConn) connectInternal(ctx context.Context) error {
client, err := GetDistro(ctx, conn.Name)
if err != nil {
return err
}
conn.WithLock(func() {
conn.Client = client
})
err = conn.OpenDomainSocketListener()
if err != nil {
return err
}
config := wconfig.GetWatcher().GetFullConfig()
wshAsk := wconfig.DefaultBoolPtr(config.Settings.ConnAskBeforeWshInstall, true)
installErr := conn.CheckAndInstallWsh(ctx, conn.GetName(), &WshInstallOpts{NoUserPrompt: !wshAsk})
if installErr != nil {
return fmt.Errorf("conncontroller %s wsh install error: %v", conn.GetName(), installErr)
}
csErr := conn.StartConnServer()
if csErr != nil {
return fmt.Errorf("conncontroller %s start wsh connserver error: %v", conn.GetName(), csErr)
}
conn.HasWaiter.Store(true)
go conn.waitForDisconnect()
return nil
}
func (conn *WslConn) waitForDisconnect() {
defer conn.FireConnChangeEvent()
defer conn.HasWaiter.Store(false)
err := conn.ConnController.Wait()
conn.WithLock(func() {
// disconnects happen for a variety of reasons (like network, etc. and are typically transient)
// so we just set the status to "disconnected" here (not error)
// don't overwrite any existing error (or error status)
if err != nil && conn.Error == "" {
conn.Error = err.Error()
}
if conn.Status != Status_Error {
conn.Status = Status_Disconnected
}
conn.close_nolock()
})
}
func getConnInternal(name string) *WslConn {
globalLock.Lock()
defer globalLock.Unlock()
connName := WslName{Distro: name}
rtn := clientControllerMap[name]
if rtn == nil {
rtn = &WslConn{Lock: &sync.Mutex{}, Status: Status_Init, Name: connName, HasWaiter: &atomic.Bool{}, cancelFn: nil}
clientControllerMap[name] = rtn
}
return rtn
}
func GetWslConn(ctx context.Context, name string, shouldConnect bool) *WslConn {
conn := getConnInternal(name)
if conn.Client == nil && shouldConnect {
conn.Connect(ctx)
}
return conn
}
// Convenience function for ensuring a connection is established
func EnsureConnection(ctx context.Context, connName string) error {
if connName == "" {
return nil
}
conn := GetWslConn(ctx, connName, false)
if conn == nil {
return fmt.Errorf("connection not found: %s", connName)
}
connStatus := conn.DeriveConnStatus()
switch connStatus.Status {
case Status_Connected:
return nil
case Status_Connecting:
return conn.WaitForConnect(ctx)
case Status_Init, Status_Disconnected:
return conn.Connect(ctx)
case Status_Error:
return fmt.Errorf("connection error: %s", connStatus.Error)
default:
return fmt.Errorf("unknown connection status %q", connStatus.Status)
}
}
func DisconnectClient(connName string) error {
conn := getConnInternal(connName)
if conn == nil {
return fmt.Errorf("client %q not found", connName)
}
err := conn.Close()
return err
}

226
pkg/wslconn/wsl-util.go Normal file
View File

@ -0,0 +1,226 @@
// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wslconn
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"log"
"os"
"path/filepath"
"strings"
"text/template"
"time"
"github.com/wavetermdev/waveterm/pkg/blocklogger"
"github.com/wavetermdev/waveterm/pkg/genconn"
"github.com/wavetermdev/waveterm/pkg/panichandler"
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/wsl"
)
func hasBashInstalled(ctx context.Context, client *wsl.Distro) (bool, error) {
cmd := client.WslCommand(ctx, "which bash")
out, whichErr := cmd.Output()
if whichErr == nil && len(out) != 0 {
return true, nil
}
cmd = client.WslCommand(ctx, "where.exe bash")
out, whereErr := cmd.Output()
if whereErr == nil && len(out) != 0 {
return true, nil
}
// note: we could also check in /bin/bash explicitly
// just in case that wasn't added to the path. but if
// that's true, we will most likely have worse
// problems going forward
return false, nil
}
func normalizeOs(os string) string {
os = strings.ToLower(strings.TrimSpace(os))
return os
}
func normalizeArch(arch string) string {
arch = strings.ToLower(strings.TrimSpace(arch))
switch arch {
case "x86_64", "amd64":
arch = "x64"
case "arm64", "aarch64":
arch = "arm64"
}
return arch
}
// returns (os, arch, error)
// guaranteed to return a supported platform
func GetClientPlatform(ctx context.Context, shell genconn.ShellClient) (string, string, error) {
blocklogger.Infof(ctx, "[conndebug] running `uname -sm` to detect client platform\n")
stdout, stderr, err := genconn.RunSimpleCommand(ctx, shell, genconn.CommandSpec{
Cmd: "uname -sm",
})
if err != nil {
return "", "", fmt.Errorf("error running uname -sm: %w, stderr: %s", err, stderr)
}
// Parse and normalize output
parts := strings.Fields(strings.ToLower(strings.TrimSpace(stdout)))
if len(parts) != 2 {
return "", "", fmt.Errorf("unexpected output from uname: %s", stdout)
}
os, arch := normalizeOs(parts[0]), normalizeArch(parts[1])
if err := wavebase.ValidateWshSupportedArch(os, arch); err != nil {
return "", "", err
}
return os, arch, nil
}
func GetClientPlatformFromOsArchStr(ctx context.Context, osArchStr string) (string, string, error) {
parts := strings.Fields(strings.TrimSpace(osArchStr))
if len(parts) != 2 {
return "", "", fmt.Errorf("unexpected output from uname: %s", osArchStr)
}
os, arch := normalizeOs(parts[0]), normalizeArch(parts[1])
if err := wavebase.ValidateWshSupportedArch(os, arch); err != nil {
return "", "", err
}
return os, arch, nil
}
type CancellableCmd struct {
Cmd *wsl.WslCmd
Cancel func()
}
var installTemplatesRawBash = map[string]string{
"mkdir": `bash -c 'mkdir -p {{.installDir}}'`,
"cat": `bash -c 'cat > {{.tempPath}}'`,
"mv": `bash -c 'mv {{.tempPath}} {{.installPath}}'`,
"chmod": `bash -c 'chmod a+x {{.installPath}}'`,
}
var installTemplatesRawDefault = map[string]string{
"mkdir": `mkdir -p {{.installDir}}`,
"cat": `cat > {{.tempPath}}`,
"mv": `mv {{.tempPath}} {{.installPath}}`,
"chmod": `chmod a+x {{.installPath}}`,
}
func makeCancellableCommand(ctx context.Context, client *wsl.Distro, cmdTemplateRaw string, words map[string]string) (*CancellableCmd, error) {
cmdContext, cmdCancel := context.WithCancel(ctx)
cmdStr := &bytes.Buffer{}
cmdTemplate, err := template.New("").Parse(cmdTemplateRaw)
if err != nil {
cmdCancel()
return nil, err
}
cmdTemplate.Execute(cmdStr, words)
cmd := client.WslCommand(cmdContext, cmdStr.String())
return &CancellableCmd{cmd, cmdCancel}, nil
}
func CpWshToRemote(ctx context.Context, client *wsl.Distro, clientOs string, clientArch string) error {
wshLocalPath, err := shellutil.GetLocalWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
if err != nil {
return err
}
// warning: does not work on windows remote yet
bashInstalled, err := hasBashInstalled(ctx, client)
if err != nil {
return err
}
var selectedTemplatesRaw map[string]string
if bashInstalled {
selectedTemplatesRaw = installTemplatesRawBash
} else {
log.Printf("bash is not installed on remote. attempting with default shell")
selectedTemplatesRaw = installTemplatesRawDefault
}
// I need to use toSlash here to force unix keybindings
// this means we can't guarantee it will work on a remote windows machine
var installWords = map[string]string{
"installDir": filepath.ToSlash(filepath.Dir(wavebase.RemoteFullWshBinPath)),
"tempPath": wavebase.RemoteFullWshBinPath + ".temp",
"installPath": wavebase.RemoteFullWshBinPath,
}
blocklogger.Infof(ctx, "[conndebug] copying %q to remote server %q\n", wshLocalPath, wavebase.RemoteFullWshBinPath)
installStepCmds := make(map[string]*CancellableCmd)
for cmdName, selectedTemplateRaw := range selectedTemplatesRaw {
cancellableCmd, err := makeCancellableCommand(ctx, client, selectedTemplateRaw, installWords)
if err != nil {
return err
}
installStepCmds[cmdName] = cancellableCmd
}
_, err = installStepCmds["mkdir"].Cmd.Output()
if err != nil {
return err
}
// the cat part of this is complicated since it requires stdin
catCmd := installStepCmds["cat"].Cmd
catStdin, err := catCmd.StdinPipe()
if err != nil {
return err
}
err = catCmd.Start()
if err != nil {
return err
}
input, err := os.Open(wshLocalPath)
if err != nil {
return fmt.Errorf("cannot open local file %s to send to host: %v", wshLocalPath, err)
}
go func() {
defer func() {
panichandler.PanicHandler("wslutil:cpHostToRemote:catStdin", recover())
}()
io.Copy(catStdin, input)
installStepCmds["cat"].Cancel()
// backup just in case something weird happens
// could cause potential race condition, but very
// unlikely
time.Sleep(time.Second * 1)
process := catCmd.GetProcess()
if process != nil {
process.Kill()
}
}()
catErr := catCmd.Wait()
if catErr != nil && !errors.Is(catErr, context.Canceled) {
return catErr
}
_, err = installStepCmds["mv"].Cmd.Output()
if err != nil {
return err
}
_, err = installStepCmds["chmod"].Cmd.Output()
if err != nil {
return err
}
return nil
}
func IsPowershell(shellPath string) bool {
// get the base path, and then check contains
shellBase := filepath.Base(shellPath)
return strings.Contains(shellBase, "powershell") || strings.Contains(shellBase, "pwsh")
}

787
pkg/wslconn/wslconn.go Normal file
View File

@ -0,0 +1,787 @@
// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wslconn
import (
"context"
"fmt"
"io"
"log"
"net"
"os"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/wavetermdev/waveterm/pkg/blocklogger"
"github.com/wavetermdev/waveterm/pkg/genconn"
"github.com/wavetermdev/waveterm/pkg/panichandler"
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
"github.com/wavetermdev/waveterm/pkg/telemetry"
"github.com/wavetermdev/waveterm/pkg/userinput"
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wconfig"
"github.com/wavetermdev/waveterm/pkg/wps"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshutil"
"github.com/wavetermdev/waveterm/pkg/wsl"
)
const (
Status_Init = "init"
Status_Connecting = "connecting"
Status_Connected = "connected"
Status_Disconnected = "disconnected"
Status_Error = "error"
)
const DefaultConnectionTimeout = 60 * time.Second
var globalLock = &sync.Mutex{}
var clientControllerMap = make(map[string]*WslConn)
var activeConnCounter = &atomic.Int32{}
type WslConn struct {
Lock *sync.Mutex
Status string
WshEnabled *atomic.Bool
Name wsl.WslName
Client *wsl.Distro
DomainSockName string // if "", then no domain socket
DomainSockListener net.Listener
ConnController *wsl.WslCmd
Error string
WshError string
NoWshReason string
WshVersion string
HasWaiter *atomic.Bool
LastConnectTime int64
ActiveConnNum int
cancelFn func()
}
var ConnServerCmdTemplate = strings.TrimSpace(
strings.Join([]string{
"%s version 2> /dev/null || (echo -n \"not-installed \"; uname -sm);",
"exec %s connserver --router",
}, "\n"))
func GetAllConnStatus() []wshrpc.ConnStatus {
globalLock.Lock()
defer globalLock.Unlock()
var connStatuses []wshrpc.ConnStatus
for _, conn := range clientControllerMap {
connStatuses = append(connStatuses, conn.DeriveConnStatus())
}
return connStatuses
}
func GetNumWSLHasConnected() int {
globalLock.Lock()
defer globalLock.Unlock()
var connectedCount int
for _, conn := range clientControllerMap {
if conn.LastConnectTime > 0 {
connectedCount++
}
}
return connectedCount
}
func (conn *WslConn) DeriveConnStatus() wshrpc.ConnStatus {
conn.Lock.Lock()
defer conn.Lock.Unlock()
return wshrpc.ConnStatus{
Status: conn.Status,
Connected: conn.Status == Status_Connected,
WshEnabled: true, // always use wsh for wsl connections (temporary)
Connection: conn.GetName(),
HasConnected: (conn.LastConnectTime > 0),
ActiveConnNum: conn.ActiveConnNum,
Error: conn.Error,
}
}
func (conn *WslConn) Infof(ctx context.Context, format string, args ...any) {
log.Print(fmt.Sprintf("[conn:%s] ", conn.GetName()) + fmt.Sprintf(format, args...))
blocklogger.Infof(ctx, "[conndebug] "+format, args...)
}
func (conn *WslConn) FireConnChangeEvent() {
status := conn.DeriveConnStatus()
event := wps.WaveEvent{
Event: wps.Event_ConnChange,
Scopes: []string{
fmt.Sprintf("connection:%s", conn.GetName()),
},
Data: status,
}
log.Printf("sending event: %+#v", event)
wps.Broker.Publish(event)
}
func (conn *WslConn) Close() error {
defer conn.FireConnChangeEvent()
conn.WithLock(func() {
if conn.Status == Status_Connected || conn.Status == Status_Connecting {
// if status is init, disconnected, or error don't change it
conn.Status = Status_Disconnected
}
conn.close_nolock()
})
// we must wait for the waiter to complete
startTime := time.Now()
for conn.HasWaiter.Load() {
time.Sleep(10 * time.Millisecond)
if time.Since(startTime) > 2*time.Second {
return fmt.Errorf("timeout waiting for waiter to complete")
}
}
return nil
}
func (conn *WslConn) close_nolock() {
// does not set status (that should happen at another level)
if conn.DomainSockListener != nil {
conn.DomainSockListener.Close()
conn.DomainSockListener = nil
conn.DomainSockName = ""
}
if conn.ConnController != nil {
conn.cancelFn() // this suspends the conn controller
conn.ConnController = nil
}
if conn.Client != nil {
// conn.Client.Close() is not relevant here
// we do not want to completely close the wsl in case
// other applications are using it
conn.Client = nil
}
}
func (conn *WslConn) GetDomainSocketName() string {
conn.Lock.Lock()
defer conn.Lock.Unlock()
return conn.DomainSockName
}
func (conn *WslConn) GetStatus() string {
conn.Lock.Lock()
defer conn.Lock.Unlock()
return conn.Status
}
func (conn *WslConn) GetName() string {
// no lock required because opts is immutable
return "wsl://" + conn.Name.Distro
}
/**
* This function is does not set a listener for WslConn
* It is still required in order to set SockName
**/
func (conn *WslConn) OpenDomainSocketListener(ctx context.Context) error {
conn.Infof(ctx, "running OpenDomainSocketListener...\n")
allowed := WithLockRtn(conn, func() bool {
return conn.Status == Status_Connecting
})
if !allowed {
return fmt.Errorf("cannot open domain socket for %q when status is %q", conn.GetName(), conn.GetStatus())
}
/*
listener, err := client.ListenUnix(sockName)
if err != nil {
return fmt.Errorf("unable to request connection domain socket: %v", err)
}
*/
conn.Infof(ctx, "setting domain socket to %s\n", wavebase.RemoteFullDomainSocketPath)
conn.WithLock(func() {
conn.DomainSockName = wavebase.RemoteFullDomainSocketPath
//conn.DomainSockListener = listener
})
conn.Infof(ctx, "successfully connected domain socket\n")
/*
go func() {
defer func() {
panichandler.PanicHandler("wslconn:OpenDomainSocketListener", recover())
}()
defer conn.WithLock(func() {
conn.DomainSockListener = nil
conn.DomainSockName = ""
})
wshutil.RunWshRpcOverListener(listener)
}()
*/
return nil
}
func (conn *WslConn) getWshPath() string {
config, ok := conn.getConnectionConfig()
if ok && config.ConnWshPath != "" {
return config.ConnWshPath
}
return wavebase.RemoteFullWshBinPath
}
func (conn *WslConn) GetConfigShellPath() string {
config, ok := conn.getConnectionConfig()
if !ok {
return ""
}
return config.ConnShellPath
}
// returns (needsInstall, clientVersion, osArchStr, error)
// if wsh is not installed, the clientVersion will be "not-installed", and it will also return an osArchStr
// if clientVersion is set, then no osArchStr will be returned
func (conn *WslConn) StartConnServer(ctx context.Context, afterUpdate bool) (bool, string, string, error) {
conn.Infof(ctx, "running StartConnServer...\n")
allowed := WithLockRtn(conn, func() bool {
return conn.Status == Status_Connecting
})
if !allowed {
return false, "", "", fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus())
}
client := conn.GetClient()
wshPath := conn.getWshPath()
rpcCtx := wshrpc.RpcContext{
ClientType: wshrpc.ClientType_ConnServer,
Conn: conn.GetName(),
}
sockName := conn.GetDomainSocketName()
jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName)
if err != nil {
return false, "", "", fmt.Errorf("unable to create jwt token for conn controller: %w", err)
}
conn.Infof(ctx, "WSL-NEWSESSION (StartConnServer)\n")
connServerCtx, cancelFn := context.WithCancel(context.Background())
conn.WithLock(func() {
if conn.cancelFn != nil {
conn.cancelFn()
}
conn.cancelFn = cancelFn
})
cmdStr := fmt.Sprintf(ConnServerCmdTemplate, wshPath, wshPath)
shWrappedCmdStr := fmt.Sprintf("sh -c %s", shellutil.HardQuote(cmdStr))
cmd := client.WslCommand(connServerCtx, shWrappedCmdStr)
pipeRead, pipeWrite := io.Pipe()
inputPipeRead, inputPipeWrite := io.Pipe()
cmd.SetStdout(pipeWrite)
cmd.SetStderr(pipeWrite)
cmd.SetStdin(inputPipeRead)
log.Printf("starting conn controller: %q\n", cmdStr)
blocklogger.Debugf(ctx, "[conndebug] wrapped command:\n%s\n", shWrappedCmdStr)
err = cmd.Start()
if err != nil {
return false, "", "", fmt.Errorf("unable to start conn controller cmd: %w", err)
}
linesChan := utilfn.StreamToLinesChan(pipeRead)
versionLine, err := utilfn.ReadLineWithTimeout(linesChan, 2*time.Second)
if err != nil {
cancelFn()
return false, "", "", fmt.Errorf("error reading wsh version: %w", err)
}
conn.Infof(ctx, "got connserver version: %s\n", strings.TrimSpace(versionLine))
isUpToDate, clientVersion, osArchStr, err := conncontroller.IsWshVersionUpToDate(ctx, versionLine)
if err != nil {
cancelFn()
return false, "", "", fmt.Errorf("error checking wsh version: %w", err)
}
if isUpToDate && !afterUpdate && os.Getenv(wavebase.WaveWshForceUpdateVarName) != "" {
isUpToDate = false
conn.Infof(ctx, "%s set, forcing wsh update\n", wavebase.WaveWshForceUpdateVarName)
}
conn.Infof(ctx, "connserver up-to-date: %v\n", isUpToDate)
if !isUpToDate {
cancelFn()
return true, clientVersion, osArchStr, nil
}
jwtLine, err := utilfn.ReadLineWithTimeout(linesChan, 3*time.Second)
if err != nil {
cancelFn()
return false, clientVersion, "", fmt.Errorf("error reading jwt status line: %w", err)
}
conn.Infof(ctx, "got jwt status line: %s\n", jwtLine)
if strings.TrimSpace(jwtLine) == wavebase.NeedJwtConst {
// write the jwt
conn.Infof(ctx, "writing jwt token to connserver\n")
_, err = fmt.Fprintf(inputPipeWrite, "%s\n", jwtToken)
if err != nil {
cancelFn()
return false, clientVersion, "", fmt.Errorf("failed to write JWT token: %w", err)
}
}
conn.WithLock(func() {
conn.ConnController = cmd
})
// service the I/O
go func() {
defer func() {
panichandler.PanicHandler("wslconn:cmd.Wait", recover())
}()
// wait for termination, clear the controller
var waitErr error
defer conn.WithLock(func() {
if conn.ConnController != nil {
conn.WshEnabled.Store(false)
conn.NoWshReason = "connserver terminated"
if waitErr != nil {
conn.WshError = fmt.Sprintf("connserver terminated unexpectedly with error: %v", waitErr)
}
}
conn.ConnController = nil
})
waitErr = cmd.Wait()
log.Printf("conn controller (%q) terminated: %v", conn.GetName(), waitErr)
}()
go func() {
defer func() {
panichandler.PanicHandler("wsl:StartConnServer:handleStdIOClient", recover())
}()
logName := fmt.Sprintf("wslconn:%s", conn.GetName())
wshutil.HandleStdIOClient(logName, linesChan, inputPipeWrite)
}()
conn.Infof(ctx, "connserver started, waiting for route to be registered\n")
regCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFn()
err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(rpcCtx.Conn))
if err != nil {
return false, clientVersion, "", fmt.Errorf("timeout waiting for connserver to register")
}
time.Sleep(300 * time.Millisecond) // TODO remove this sleep (but we need to wait until connserver is "ready")
conn.Infof(ctx, "connserver is registered and ready\n")
return false, clientVersion, "", nil
}
type WshInstallOpts struct {
Force bool
NoUserPrompt bool
}
var queryTextTemplate = strings.TrimSpace(`
Wave requires Wave Shell Extensions to be
installed on %q
to ensure a seamless experience.
Would you like to install them?
`)
func (conn *WslConn) UpdateWsh(ctx context.Context, clientDisplayName string, remoteInfo *wshrpc.RemoteInfo) error {
conn.Infof(ctx, "attempting to update wsh for connection %s (os:%s arch:%s version:%s)\n",
conn.GetName(), remoteInfo.ClientOs, remoteInfo.ClientArch, remoteInfo.ClientVersion)
client := conn.GetClient()
if client == nil {
return fmt.Errorf("cannot update wsh: ssh client is not connected")
}
err := CpWshToRemote(ctx, client, remoteInfo.ClientOs, remoteInfo.ClientArch)
if err != nil {
return fmt.Errorf("error installing wsh to remote: %w", err)
}
conn.Infof(ctx, "successfully updated wsh on %s\n", conn.GetName())
return nil
}
// returns (allowed, error)
func (conn *WslConn) getPermissionToInstallWsh(ctx context.Context, clientDisplayName string) (bool, error) {
conn.Infof(ctx, "running getPermissionToInstallWsh...\n")
queryText := fmt.Sprintf(queryTextTemplate, clientDisplayName)
title := "Install Wave Shell Extensions"
request := &userinput.UserInputRequest{
ResponseType: "confirm",
QueryText: queryText,
Title: title,
Markdown: true,
CheckBoxMsg: "Automatically install for all connections",
OkLabel: "Install wsh",
CancelLabel: "No wsh",
}
conn.Infof(ctx, "requesting user confirmation...\n")
response, err := userinput.GetUserInput(ctx, request)
if err != nil {
conn.Infof(ctx, "error getting user input: %v\n", err)
return false, err
}
conn.Infof(ctx, "user response to allowing wsh: %v\n", response.Confirm)
meta := make(map[string]any)
meta["conn:wshenabled"] = response.Confirm
conn.Infof(ctx, "writing conn:wshenabled=%v to connections.json\n", response.Confirm)
err = wconfig.SetConnectionsConfigValue(conn.GetName(), meta)
if err != nil {
log.Printf("warning: error writing to connections file: %v", err)
}
if !response.Confirm {
return false, nil
}
if response.CheckboxStat {
conn.Infof(ctx, "writing conn:askbeforewshinstall=false to settings.json\n")
meta := waveobj.MetaMapType{
wconfig.ConfigKey_ConnAskBeforeWshInstall: false,
}
setConfigErr := wconfig.SetBaseConfigValue(meta)
if setConfigErr != nil {
// this is not a critical error, just log and continue
log.Printf("warning: error writing to base config file: %v", err)
}
}
return true, nil
}
func (conn *WslConn) InstallWsh(ctx context.Context, osArchStr string) error {
conn.Infof(ctx, "running installWsh...\n")
client := conn.GetClient()
if client == nil {
conn.Infof(ctx, "ERROR ssh client is not connected, cannot install\n")
return fmt.Errorf("ssh client is not connected, cannot install")
}
var clientOs, clientArch string
var err error
if osArchStr != "" {
clientOs, clientArch, err = GetClientPlatformFromOsArchStr(ctx, osArchStr)
} else {
clientOs, clientArch, err = GetClientPlatform(ctx, genconn.MakeWSLShellClient(client))
}
if err != nil {
conn.Infof(ctx, "ERROR detecting client platform: %v\n", err)
}
conn.Infof(ctx, "detected remote platform os:%s arch:%s\n", clientOs, clientArch)
err = CpWshToRemote(ctx, client, clientOs, clientArch)
if err != nil {
conn.Infof(ctx, "ERROR copying wsh binary to remote: %v\n", err)
return fmt.Errorf("error copying wsh binary to remote: %w", err)
}
conn.Infof(ctx, "successfully installed wsh\n")
return nil
}
func (conn *WslConn) GetClient() *wsl.Distro {
conn.Lock.Lock()
defer conn.Lock.Unlock()
return conn.Client
}
func (conn *WslConn) Reconnect(ctx context.Context) error {
err := conn.Close()
if err != nil {
return err
}
return conn.Connect(ctx)
}
func (conn *WslConn) WaitForConnect(ctx context.Context) error {
for {
status := conn.DeriveConnStatus()
if status.Status == Status_Connected {
return nil
}
if status.Status == Status_Connecting {
select {
case <-ctx.Done():
return fmt.Errorf("context timeout")
case <-time.After(100 * time.Millisecond):
continue
}
}
if status.Status == Status_Init || status.Status == Status_Disconnected {
return fmt.Errorf("disconnected")
}
if status.Status == Status_Error {
return fmt.Errorf("error: %v", status.Error)
}
return fmt.Errorf("unknown status: %q", status.Status)
}
}
// does not return an error since that error is stored inside of WslConn
func (conn *WslConn) Connect(ctx context.Context) error {
var connectAllowed bool
conn.WithLock(func() {
if conn.Status == Status_Connecting || conn.Status == Status_Connected {
connectAllowed = false
} else {
conn.Status = Status_Connecting
conn.Error = ""
connectAllowed = true
}
})
log.Printf("Connect %s\n", conn.GetName())
if !connectAllowed {
conn.Infof(ctx, "cannot connect to %q when status is %q\n", conn.GetName(), conn.GetStatus())
return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus())
}
conn.FireConnChangeEvent()
err := conn.connectInternal(ctx)
conn.WithLock(func() {
if err != nil {
conn.Infof(ctx, "ERROR %v\n\n", err)
conn.Status = Status_Error
conn.Error = err.Error()
conn.close_nolock()
telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{
Conn: map[string]int{"wsl:connecterror": 1},
}, "wsl-connconnect")
} else {
conn.Infof(ctx, "successfully connected (wsh:%v)\n\n", conn.WshEnabled.Load())
conn.Status = Status_Connected
conn.LastConnectTime = time.Now().UnixMilli()
if conn.ActiveConnNum == 0 {
conn.ActiveConnNum = int(activeConnCounter.Add(1))
}
telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{
Conn: map[string]int{"wsl:connect": 1},
}, "wsl-connconnect")
}
})
conn.FireConnChangeEvent()
return err
}
func (conn *WslConn) WithLock(fn func()) {
conn.Lock.Lock()
defer conn.Lock.Unlock()
fn()
}
func WithLockRtn[T any](conn *WslConn, fn func() T) T {
conn.Lock.Lock()
defer conn.Lock.Unlock()
return fn()
}
// returns (enable-wsh, ask-before-install)
func (conn *WslConn) getConnWshSettings() (bool, bool) {
config := wconfig.GetWatcher().GetFullConfig()
enableWsh := config.Settings.ConnWshEnabled
askBeforeInstall := wconfig.DefaultBoolPtr(config.Settings.ConnAskBeforeWshInstall, true)
connSettings, ok := conn.getConnectionConfig()
if ok {
if connSettings.ConnWshEnabled != nil {
enableWsh = *connSettings.ConnWshEnabled
}
// if the connection object exists, and conn:askbeforewshinstall is not set, the user must have allowed it
// TODO: in v0.12+ this should be removed. we'll explicitly write a "false" into the connection object on successful connection
if connSettings.ConnAskBeforeWshInstall == nil {
askBeforeInstall = false
} else {
askBeforeInstall = *connSettings.ConnAskBeforeWshInstall
}
}
return enableWsh, askBeforeInstall
}
type WshCheckResult struct {
WshEnabled bool
ClientVersion string
NoWshReason string
WshError error
}
// returns (wsh-enabled, clientVersion, text-reason, wshError)
func (conn *WslConn) tryEnableWsh(ctx context.Context, clientDisplayName string) WshCheckResult {
conn.Infof(ctx, "running tryEnableWsh...\n")
enableWsh, askBeforeInstall := conn.getConnWshSettings()
conn.Infof(ctx, "wsh settings enable:%v ask:%v\n", enableWsh, askBeforeInstall)
if !enableWsh {
return WshCheckResult{NoWshReason: "conn:wshenabled set to false"}
}
if askBeforeInstall {
allowInstall, err := conn.getPermissionToInstallWsh(ctx, clientDisplayName)
if err != nil {
log.Printf("error getting permission to install wsh: %v\n", err)
return WshCheckResult{NoWshReason: "error getting user permission to install", WshError: err}
}
if !allowInstall {
return WshCheckResult{NoWshReason: "user selected not to install wsh extensions"}
}
}
err := conn.OpenDomainSocketListener(ctx)
if err != nil {
conn.Infof(ctx, "ERROR opening domain socket listener: %v\n", err)
err = fmt.Errorf("error opening domain socket listener: %w", err)
return WshCheckResult{NoWshReason: "error opening domain socket", WshError: err}
}
needsInstall, clientVersion, osArchStr, err := conn.StartConnServer(ctx, false)
if err != nil {
conn.Infof(ctx, "ERROR starting conn server: %v\n", err)
err = fmt.Errorf("error starting conn server: %w", err)
return WshCheckResult{NoWshReason: "error starting connserver", WshError: err}
}
if needsInstall {
conn.Infof(ctx, "connserver needs to be (re)installed\n")
err = conn.InstallWsh(ctx, osArchStr)
if err != nil {
conn.Infof(ctx, "ERROR installing wsh: %v\n", err)
err = fmt.Errorf("error installing wsh: %w", err)
return WshCheckResult{NoWshReason: "error installing wsh/connserver", WshError: err}
}
needsInstall, clientVersion, _, err = conn.StartConnServer(ctx, true)
if err != nil {
conn.Infof(ctx, "ERROR starting conn server (after install): %v\n", err)
err = fmt.Errorf("error starting conn server (after install): %w", err)
return WshCheckResult{NoWshReason: "error starting connserver", WshError: err}
}
if needsInstall {
conn.Infof(ctx, "conn server not installed correctly (after install)\n")
err = fmt.Errorf("conn server not installed correctly (after install)")
return WshCheckResult{NoWshReason: "connserver not installed properly", WshError: err}
}
return WshCheckResult{WshEnabled: true, ClientVersion: clientVersion}
} else {
return WshCheckResult{WshEnabled: true, ClientVersion: clientVersion}
}
}
func (conn *WslConn) getConnectionConfig() (wshrpc.ConnKeywords, bool) {
config := wconfig.GetWatcher().GetFullConfig()
connSettings, ok := config.Connections[conn.GetName()]
if !ok {
return wshrpc.ConnKeywords{}, false
}
return connSettings, true
}
func (conn *WslConn) persistWshInstalled(ctx context.Context, result WshCheckResult) {
conn.WshEnabled.Store(result.WshEnabled)
conn.SetWshError(result.WshError)
conn.WithLock(func() {
conn.NoWshReason = result.NoWshReason
conn.WshVersion = result.ClientVersion
})
connConfig, ok := conn.getConnectionConfig()
if ok && connConfig.ConnWshEnabled != nil {
return
}
meta := make(map[string]any)
meta["conn:wshenabled"] = result.WshEnabled
err := wconfig.SetConnectionsConfigValue(conn.GetName(), meta)
if err != nil {
conn.Infof(ctx, "WARN could not write conn:wshenabled=%v to connections.json: %v\n", result.WshEnabled, err)
log.Printf("warning: error writing to connections file: %v", err)
}
// doesn't return an error since none of this is required for connection to work
}
func (conn *WslConn) connectInternal(ctx context.Context) error {
conn.Infof(ctx, "connectInternal %s\n", conn.GetName())
client, err := wsl.GetDistro(ctx, conn.Name)
if err != nil {
conn.Infof(ctx, "ERROR GetDistro: %s\n", err)
log.Printf("error: failed to get distro %s: %s\n", conn.GetName(), err)
return err
}
conn.WithLock(func() {
conn.Client = client
})
go func() {
defer func() {
panichandler.PanicHandler("wsl-waitForDisconnect", recover())
}()
conn.waitForDisconnect()
}()
wshResult := conn.tryEnableWsh(ctx, conn.GetName())
if !wshResult.WshEnabled {
if wshResult.WshError != nil {
conn.Infof(ctx, "ERROR enabling wsh: %v\n", wshResult.WshError)
conn.Infof(ctx, "will connect with wsh disabled\n")
} else {
conn.Infof(ctx, "wsh not enabled: %s\n", wshResult.NoWshReason)
}
}
conn.persistWshInstalled(ctx, wshResult)
return nil
}
func (conn *WslConn) waitForDisconnect() {
log.Printf("wait for disconnect in %+#v", conn)
defer conn.FireConnChangeEvent()
defer conn.HasWaiter.Store(false)
err := conn.ConnController.Wait()
conn.WithLock(func() {
// disconnects happen for a variety of reasons (like network, etc. and are typically transient)
// so we just set the status to "disconnected" here (not error)
// don't overwrite any existing error (or error status)
if err != nil && conn.Error == "" {
conn.Error = err.Error()
}
if conn.Status != Status_Error {
conn.Status = Status_Disconnected
}
conn.close_nolock()
})
}
func (conn *WslConn) SetWshError(err error) {
conn.WithLock(func() {
if err == nil {
conn.WshError = ""
} else {
conn.WshError = err.Error()
}
})
}
func (conn *WslConn) ClearWshError() {
conn.WithLock(func() {
conn.WshError = ""
})
}
func getConnInternal(name string) *WslConn {
globalLock.Lock()
defer globalLock.Unlock()
connName := wsl.WslName{Distro: name}
rtn := clientControllerMap[name]
if rtn == nil {
rtn = &WslConn{Lock: &sync.Mutex{}, Status: Status_Init, Name: connName, WshEnabled: &atomic.Bool{}, HasWaiter: &atomic.Bool{}, cancelFn: nil}
clientControllerMap[name] = rtn
}
return rtn
}
func GetWslConn(ctx context.Context, name string, shouldConnect bool) *WslConn {
conn := getConnInternal(name)
if conn.Client == nil && shouldConnect {
conn.Connect(ctx)
}
return conn
}
// Convenience function for ensuring a connection is established
func EnsureConnection(ctx context.Context, connName string) error {
if connName == "" {
return nil
}
conn := GetWslConn(ctx, connName, false)
if conn == nil {
return fmt.Errorf("connection not found: %s", connName)
}
connStatus := conn.DeriveConnStatus()
switch connStatus.Status {
case Status_Connected:
return nil
case Status_Connecting:
return conn.WaitForConnect(ctx)
case Status_Init, Status_Disconnected:
return conn.Connect(ctx)
case Status_Error:
return fmt.Errorf("connection error: %s", connStatus.Error)
default:
return fmt.Errorf("unknown connection status %q", connStatus.Status)
}
}
func DisconnectClient(connName string) error {
conn := getConnInternal(connName)
if conn == nil {
return fmt.Errorf("client %q not found", connName)
}
err := conn.Close()
return err
}