waveterm/pkg/wslconn/wslconn.go
Sylvie Crowe ad3113e46e
Wshless Connections in Wsl (#1769)
This makes it possible to disable wsh for WSL connections. While this is
not recommended, it brings the code closer to the SSH connection
implementation and will make it easier to consolidate the two in the
future.
2025-01-17 18:28:46 -08:00

794 lines
24 KiB
Go

// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wslconn
import (
"context"
"fmt"
"io"
"log"
"net"
"os"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/wavetermdev/waveterm/pkg/blocklogger"
"github.com/wavetermdev/waveterm/pkg/genconn"
"github.com/wavetermdev/waveterm/pkg/panichandler"
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
"github.com/wavetermdev/waveterm/pkg/telemetry"
"github.com/wavetermdev/waveterm/pkg/userinput"
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wconfig"
"github.com/wavetermdev/waveterm/pkg/wps"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshutil"
"github.com/wavetermdev/waveterm/pkg/wsl"
)
// Connection lifecycle states stored in WslConn.Status and reported through
// wshrpc.ConnStatus.Status.
const (
// Status_Init is the status a WslConn is created with (see getConnInternal).
Status_Init = "init"
// Status_Connecting is set by Connect before connectInternal runs.
Status_Connecting = "connecting"
// Status_Connected is set by Connect when connectInternal succeeds.
Status_Connected = "connected"
// Status_Disconnected is set by Close and waitForDisconnect.
Status_Disconnected = "disconnected"
// Status_Error is set by Connect when connectInternal returns an error.
Status_Error = "error"
)
// DefaultConnectionTimeout is a 60 second duration. It is not referenced in the
// visible portion of this file.
const DefaultConnectionTimeout = 60 * time.Second
// globalLock guards clientControllerMap.
var globalLock = &sync.Mutex{}
// clientControllerMap holds every WslConn created by getConnInternal, keyed by
// the name passed to it.
var clientControllerMap = make(map[string]*WslConn)
// activeConnCounter is incremented by Connect to assign WslConn.ActiveConnNum
// the first time a connection succeeds.
var activeConnCounter = &atomic.Int32{}
// WslConn tracks the state of one connection to a WSL distro. Non-atomic
// mutable fields are accessed under Lock (see WithLock / WithLockRtn).
type WslConn struct {
// Lock guards the non-atomic mutable fields below.
Lock *sync.Mutex
// Status is one of the Status_* constants.
Status string
// WshEnabled is stored by persistWshInstalled and cleared when the
// connserver process terminates.
WshEnabled *atomic.Bool
// Name identifies the distro; GetName reads it without locking.
Name wsl.WslName
// Client is the distro handle obtained in connectInternal; nil after close_nolock.
Client *wsl.Distro
DomainSockName string // if "", then no domain socket
// DomainSockListener is closed by close_nolock when non-nil; the only
// assignment of a listener in this file is commented out.
DomainSockListener net.Listener
// ConnController is the running connserver command set by StartConnServer.
ConnController *wsl.WslCmd
// Error holds the connection error text set by Connect / waitForDisconnect.
Error string
// WshError holds the wsh-related error text (see SetWshError).
WshError string
// NoWshReason is a short explanation of why wsh is not enabled.
NoWshReason string
// WshVersion is the client version recorded by persistWshInstalled.
WshVersion string
// HasWaiter is polled by Close and cleared by waitForDisconnect.
HasWaiter *atomic.Bool
// LastConnectTime is the Unix time in milliseconds of the last successful
// Connect; 0 means the connection has never succeeded.
LastConnectTime int64
// ActiveConnNum is assigned from activeConnCounter on the first successful Connect.
ActiveConnNum int
// cancelFn cancels the context of the connserver command started by StartConnServer.
cancelFn func()
}
// ConnServerCmdTemplate is the two-line shell script used to start the
// connserver. Both %s verbs receive the wsh binary path (see StartConnServer).
// Line one prints the wsh version, or "not-installed " followed by `uname -sm`
// when the version command fails; line two execs the connserver in router mode.
var ConnServerCmdTemplate = strings.TrimSpace(
	"%s version 2> /dev/null || (echo -n \"not-installed \"; uname -sm);" +
		"\n" +
		"exec %s connserver --router",
)
// GetAllConnStatus returns a status snapshot for every connection currently in
// clientControllerMap. The result is nil when the map is empty, and its order
// follows Go's unspecified map iteration order.
func GetAllConnStatus() []wshrpc.ConnStatus {
	globalLock.Lock()
	defer globalLock.Unlock()

	var statuses []wshrpc.ConnStatus
	for name := range clientControllerMap {
		statuses = append(statuses, clientControllerMap[name].DeriveConnStatus())
	}
	return statuses
}
// GetNumWSLHasConnected returns how many known WSL connections have connected
// successfully at least once (LastConnectTime > 0).
func GetNumWSLHasConnected() int {
	globalLock.Lock()
	defer globalLock.Unlock()
	var connectedCount int
	for _, conn := range clientControllerMap {
		// LastConnectTime is written by Connect while holding conn.Lock, so it
		// must be read under the same lock. Lock order (globalLock, then
		// conn.Lock) matches GetAllConnStatus.
		lastConnect := WithLockRtn(conn, func() int64 {
			return conn.LastConnectTime
		})
		if lastConnect > 0 {
			connectedCount++
		}
	}
	return connectedCount
}
// DeriveConnStatus builds a wshrpc.ConnStatus snapshot of this connection
// while holding conn.Lock.
func (conn *WslConn) DeriveConnStatus() wshrpc.ConnStatus {
	return WithLockRtn(conn, func() wshrpc.ConnStatus {
		curStatus := conn.Status
		return wshrpc.ConnStatus{
			Status:        curStatus,
			Connected:     curStatus == Status_Connected,
			WshEnabled:    conn.WshEnabled.Load(),
			Connection:    conn.GetName(),
			HasConnected:  conn.LastConnectTime > 0,
			ActiveConnNum: conn.ActiveConnNum,
			Error:         conn.Error,
			WshError:      conn.WshError,
			NoWshReason:   conn.NoWshReason,
			WshVersion:    conn.WshVersion,
		}
	})
}
// Infof writes a formatted message to the process log, prefixed with the
// connection name, and forwards the same format/args to blocklogger with a
// "[conndebug] " prefix.
func (conn *WslConn) Infof(ctx context.Context, format string, args ...any) {
	prefix := fmt.Sprintf("[conn:%s] ", conn.GetName())
	body := fmt.Sprintf(format, args...)
	log.Print(prefix + body)
	blocklogger.Infof(ctx, "[conndebug] "+format, args...)
}
// FireConnChangeEvent publishes a conn-change event carrying the current
// connection status, scoped to "connection:<name>".
func (conn *WslConn) FireConnChangeEvent() {
	scope := "connection:" + conn.GetName()
	event := wps.WaveEvent{
		Event:  wps.Event_ConnChange,
		Scopes: []string{scope},
		Data:   conn.DeriveConnStatus(),
	}
	log.Printf("sending event: %+#v", event)
	wps.Broker.Publish(event)
}
// Close tears down the connection resources and then waits (polling every
// 10ms, for at most 2 seconds) for HasWaiter to clear. A conn-change event is
// fired when the function returns, whether or not the wait timed out.
func (conn *WslConn) Close() error {
	defer conn.FireConnChangeEvent()

	conn.WithLock(func() {
		// init, disconnected, and error statuses are left untouched
		switch conn.Status {
		case Status_Connected, Status_Connecting:
			conn.Status = Status_Disconnected
		}
		conn.close_nolock()
	})

	// we must wait for the waiter to complete
	deadline := time.Now().Add(2 * time.Second)
	for conn.HasWaiter.Load() {
		time.Sleep(10 * time.Millisecond)
		if time.Now().After(deadline) {
			return fmt.Errorf("timeout waiting for waiter to complete")
		}
	}
	return nil
}
// close_nolock releases the listener, connserver controller, and client
// references. The caller must hold conn.Lock. It does not change conn.Status;
// callers set the status themselves.
func (conn *WslConn) close_nolock() {
	if listener := conn.DomainSockListener; listener != nil {
		listener.Close()
		conn.DomainSockListener = nil
		conn.DomainSockName = ""
	}
	if conn.ConnController != nil {
		// cancelling the context suspends the conn controller
		conn.cancelFn()
		conn.ConnController = nil
	}
	// The distro handle is only dropped, never closed: the WSL instance may
	// be in use by other applications.
	conn.Client = nil
}
// GetDomainSocketName returns the current domain socket name, read under
// conn.Lock. An empty string means no domain socket is set.
func (conn *WslConn) GetDomainSocketName() string {
	return WithLockRtn(conn, func() string {
		return conn.DomainSockName
	})
}
// GetStatus returns the current connection status, read under conn.Lock.
func (conn *WslConn) GetStatus() string {
	return WithLockRtn(conn, func() string {
		return conn.Status
	})
}
// GetName returns the connection name in the form "wsl://<distro>".
// No locking is needed because conn.Name is treated as immutable.
func (conn *WslConn) GetName() string {
	const scheme = "wsl://"
	return scheme + conn.Name.Distro
}
// OpenDomainSocketListener does not actually open a listener for a WslConn
// (the listener setup and its RPC-serving goroutine are disabled). It is
// still required because it records the domain socket name on the
// connection. It fails unless the connection status is "connecting".
func (conn *WslConn) OpenDomainSocketListener(ctx context.Context) error {
	conn.Infof(ctx, "running OpenDomainSocketListener...\n")
	if conn.GetStatus() != Status_Connecting {
		return fmt.Errorf("cannot open domain socket for %q when status is %q", conn.GetName(), conn.GetStatus())
	}

	sockPath := wavebase.RemoteFullDomainSocketPath
	conn.Infof(ctx, "setting domain socket to %s\n", sockPath)
	conn.WithLock(func() {
		// conn.DomainSockListener is intentionally left unset
		conn.DomainSockName = sockPath
	})
	conn.Infof(ctx, "successfully connected domain socket\n")
	return nil
}
// getWshPath returns the wsh binary path configured for this connection, or
// wavebase.RemoteFullWshBinPath when no non-empty override exists.
func (conn *WslConn) getWshPath() string {
	if cfg, found := conn.getConnectionConfig(); found && cfg.ConnWshPath != "" {
		return cfg.ConnWshPath
	}
	return wavebase.RemoteFullWshBinPath
}
// GetConfigShellPath returns the shell path from this connection's config
// entry, or "" when the connection has no config entry.
func (conn *WslConn) GetConfigShellPath() string {
	if cfg, found := conn.getConnectionConfig(); found {
		return cfg.ConnShellPath
	}
	return ""
}
// StartConnServer launches the wsh connserver inside the distro, checks its
// reported version, performs the JWT handshake when the server asks for it,
// and waits for the connection route to be registered with the default router.
//
// afterUpdate suppresses the force-update environment override so that a
// freshly installed wsh is not reported as needing install again.
//
// returns (needsInstall, clientVersion, osArchStr, error)
// if wsh is not installed, the clientVersion will be "not-installed", and it will also return an osArchStr
// if clientVersion is set, then no osArchStr will be returned
func (conn *WslConn) StartConnServer(ctx context.Context, afterUpdate bool) (bool, string, string, error) {
	conn.Infof(ctx, "running StartConnServer...\n")
	allowed := WithLockRtn(conn, func() bool {
		return conn.Status == Status_Connecting
	})
	if !allowed {
		return false, "", "", fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus())
	}
	client := conn.GetClient()
	if client == nil {
		// a concurrent Close may have dropped the client; mirror the guard
		// used by UpdateWsh and InstallWsh
		return false, "", "", fmt.Errorf("cannot start conn server for %q: wsl client is not connected", conn.GetName())
	}
	wshPath := conn.getWshPath()
	rpcCtx := wshrpc.RpcContext{
		ClientType: wshrpc.ClientType_ConnServer,
		Conn:       conn.GetName(),
	}
	sockName := conn.GetDomainSocketName()
	jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName)
	if err != nil {
		return false, "", "", fmt.Errorf("unable to create jwt token for conn controller: %w", err)
	}
	conn.Infof(ctx, "WSL-NEWSESSION (StartConnServer)\n")
	connServerCtx, cancelFn := context.WithCancel(context.Background())
	conn.WithLock(func() {
		// stop any previously started connserver before tracking the new one
		if conn.cancelFn != nil {
			conn.cancelFn()
		}
		conn.cancelFn = cancelFn
	})
	cmdStr := fmt.Sprintf(ConnServerCmdTemplate, wshPath, wshPath)
	shWrappedCmdStr := fmt.Sprintf("sh -c %s", shellutil.HardQuote(cmdStr))
	cmd := client.WslCommand(connServerCtx, shWrappedCmdStr)
	// stdout and stderr share one pipe so the version / jwt lines and any
	// error output are read from the same line stream
	pipeRead, pipeWrite := io.Pipe()
	inputPipeRead, inputPipeWrite := io.Pipe()
	cmd.SetStdout(pipeWrite)
	cmd.SetStderr(pipeWrite)
	cmd.SetStdin(inputPipeRead)
	log.Printf("starting conn controller: %q\n", cmdStr)
	blocklogger.Debugf(ctx, "[conndebug] wrapped command:\n%s\n", shWrappedCmdStr)
	err = cmd.Start()
	if err != nil {
		// release the command context; every other failure path already does
		cancelFn()
		return false, "", "", fmt.Errorf("unable to start conn controller cmd: %w", err)
	}
	linesChan := utilfn.StreamToLinesChan(pipeRead)
	versionLine, err := utilfn.ReadLineWithTimeout(linesChan, 2*time.Second)
	if err != nil {
		cancelFn()
		return false, "", "", fmt.Errorf("error reading wsh version: %w", err)
	}
	conn.Infof(ctx, "got connserver version: %s\n", strings.TrimSpace(versionLine))
	isUpToDate, clientVersion, osArchStr, err := conncontroller.IsWshVersionUpToDate(ctx, versionLine)
	if err != nil {
		cancelFn()
		return false, "", "", fmt.Errorf("error checking wsh version: %w", err)
	}
	if isUpToDate && !afterUpdate && os.Getenv(wavebase.WaveWshForceUpdateVarName) != "" {
		isUpToDate = false
		conn.Infof(ctx, "%s set, forcing wsh update\n", wavebase.WaveWshForceUpdateVarName)
	}
	conn.Infof(ctx, "connserver up-to-date: %v\n", isUpToDate)
	if !isUpToDate {
		cancelFn()
		return true, clientVersion, osArchStr, nil
	}
	jwtLine, err := utilfn.ReadLineWithTimeout(linesChan, 3*time.Second)
	if err != nil {
		cancelFn()
		return false, clientVersion, "", fmt.Errorf("error reading jwt status line: %w", err)
	}
	conn.Infof(ctx, "got jwt status line: %s\n", jwtLine)
	if strings.TrimSpace(jwtLine) == wavebase.NeedJwtConst {
		// write the jwt
		conn.Infof(ctx, "writing jwt token to connserver\n")
		_, err = fmt.Fprintf(inputPipeWrite, "%s\n", jwtToken)
		if err != nil {
			cancelFn()
			return false, clientVersion, "", fmt.Errorf("failed to write JWT token: %w", err)
		}
	}
	conn.WithLock(func() {
		conn.ConnController = cmd
	})
	// service the I/O
	go func() {
		defer func() {
			panichandler.PanicHandler("wslconn:cmd.Wait", recover())
		}()
		// wait for termination, clear the controller
		var waitErr error
		defer conn.WithLock(func() {
			// a non-nil controller here means the termination was not
			// initiated through close_nolock
			if conn.ConnController != nil {
				conn.WshEnabled.Store(false)
				conn.NoWshReason = "connserver terminated"
				if waitErr != nil {
					conn.WshError = fmt.Sprintf("connserver terminated unexpectedly with error: %v", waitErr)
				}
			}
			conn.ConnController = nil
		})
		waitErr = cmd.Wait()
		log.Printf("conn controller (%q) terminated: %v", conn.GetName(), waitErr)
	}()
	go func() {
		defer func() {
			panichandler.PanicHandler("wsl:StartConnServer:handleStdIOClient", recover())
		}()
		logName := fmt.Sprintf("wslconn:%s", conn.GetName())
		wshutil.HandleStdIOClient(logName, linesChan, inputPipeWrite)
	}()
	conn.Infof(ctx, "connserver started, waiting for route to be registered\n")
	// use a distinct name so the connserver's cancel func is not overwritten
	regCtx, regCancelFn := context.WithTimeout(context.Background(), 5*time.Second)
	defer regCancelFn()
	err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(rpcCtx.Conn))
	if err != nil {
		return false, clientVersion, "", fmt.Errorf("timeout waiting for connserver to register: %w", err)
	}
	time.Sleep(300 * time.Millisecond) // TODO remove this sleep (but we need to wait until connserver is "ready")
	conn.Infof(ctx, "connserver is registered and ready\n")
	return false, clientVersion, "", nil
}
// WshInstallOpts carries options for installing wsh. It is not referenced in
// the visible portion of this file.
// NOTE(review): field meanings below are inferred from their names — confirm against callers.
type WshInstallOpts struct {
// Force presumably requests installation even when one is already present.
Force bool
// NoUserPrompt presumably suppresses the confirmation prompt.
NoUserPrompt bool
}
// queryTextTemplate is the confirmation text shown by getPermissionToInstallWsh.
// Its single %q verb is filled with the client display name.
var queryTextTemplate = strings.TrimSpace(`
Wave requires Wave Shell Extensions to be
installed on %q
to ensure a seamless experience.
Would you like to install them?
`)
// UpdateWsh copies the wsh binary matching remoteInfo's OS and architecture
// into the distro. clientDisplayName is accepted but not used by this
// function. It returns an error when remoteInfo is nil, when the WSL client
// is not connected, or when the copy fails.
func (conn *WslConn) UpdateWsh(ctx context.Context, clientDisplayName string, remoteInfo *wshrpc.RemoteInfo) error {
	if remoteInfo == nil {
		// guard before the dereferences below
		return fmt.Errorf("cannot update wsh: remote info is nil")
	}
	conn.Infof(ctx, "attempting to update wsh for connection %s (os:%s arch:%s version:%s)\n",
		conn.GetName(), remoteInfo.ClientOs, remoteInfo.ClientArch, remoteInfo.ClientVersion)
	client := conn.GetClient()
	if client == nil {
		return fmt.Errorf("cannot update wsh: wsl client is not connected")
	}
	err := CpWshToRemote(ctx, client, remoteInfo.ClientOs, remoteInfo.ClientArch)
	if err != nil {
		return fmt.Errorf("error installing wsh to remote: %w", err)
	}
	conn.Infof(ctx, "successfully updated wsh on %s\n", conn.GetName())
	return nil
}
// getPermissionToInstallWsh asks the user whether Wave Shell Extensions may be
// installed on the connection and records the answer as conn:wshenabled in the
// connections config. When the user confirms and ticks the checkbox, it also
// writes conn:askbeforewshinstall=false to the base config. Failures to write
// either config file are logged and do not fail the call.
//
// returns (allowed, error)
func (conn *WslConn) getPermissionToInstallWsh(ctx context.Context, clientDisplayName string) (bool, error) {
	conn.Infof(ctx, "running getPermissionToInstallWsh...\n")
	queryText := fmt.Sprintf(queryTextTemplate, clientDisplayName)
	title := "Install Wave Shell Extensions"
	request := &userinput.UserInputRequest{
		ResponseType: "confirm",
		QueryText:    queryText,
		Title:        title,
		Markdown:     true,
		CheckBoxMsg:  "Automatically install for all connections",
		OkLabel:      "Install wsh",
		CancelLabel:  "No wsh",
	}
	conn.Infof(ctx, "requesting user confirmation...\n")
	response, err := userinput.GetUserInput(ctx, request)
	if err != nil {
		conn.Infof(ctx, "error getting user input: %v\n", err)
		return false, err
	}
	conn.Infof(ctx, "user response to allowing wsh: %v\n", response.Confirm)
	connMeta := make(map[string]any)
	connMeta["conn:wshenabled"] = response.Confirm
	conn.Infof(ctx, "writing conn:wshenabled=%v to connections.json\n", response.Confirm)
	err = wconfig.SetConnectionsConfigValue(conn.GetName(), connMeta)
	if err != nil {
		log.Printf("warning: error writing to connections file: %v", err)
	}
	if !response.Confirm {
		return false, nil
	}
	if response.CheckboxStat {
		conn.Infof(ctx, "writing conn:askbeforewshinstall=false to settings.json\n")
		baseMeta := waveobj.MetaMapType{
			wconfig.ConfigKey_ConnAskBeforeWshInstall: false,
		}
		setConfigErr := wconfig.SetBaseConfigValue(baseMeta)
		if setConfigErr != nil {
			// this is not a critical error, just log and continue
			log.Printf("warning: error writing to base config file: %v", setConfigErr)
		}
	}
	return true, nil
}
// InstallWsh copies the wsh binary for the distro's platform into the distro.
// When osArchStr is non-empty the platform is parsed from it; otherwise the
// platform is detected by querying the distro through a shell client. It
// returns an error when the client is not connected, when the platform cannot
// be determined, or when the copy fails.
func (conn *WslConn) InstallWsh(ctx context.Context, osArchStr string) error {
	conn.Infof(ctx, "running installWsh...\n")
	client := conn.GetClient()
	if client == nil {
		conn.Infof(ctx, "ERROR wsl client is not connected, cannot install\n")
		return fmt.Errorf("wsl client is not connected, cannot install")
	}
	var clientOs, clientArch string
	var err error
	if osArchStr != "" {
		clientOs, clientArch, err = GetClientPlatformFromOsArchStr(ctx, osArchStr)
	} else {
		clientOs, clientArch, err = GetClientPlatform(ctx, genconn.MakeWSLShellClient(client))
	}
	if err != nil {
		// without a known os/arch there is no binary to pick, so stop here
		// instead of attempting the copy with empty values
		conn.Infof(ctx, "ERROR detecting client platform: %v\n", err)
		return fmt.Errorf("error detecting client platform: %w", err)
	}
	conn.Infof(ctx, "detected remote platform os:%s arch:%s\n", clientOs, clientArch)
	err = CpWshToRemote(ctx, client, clientOs, clientArch)
	if err != nil {
		conn.Infof(ctx, "ERROR copying wsh binary to remote: %v\n", err)
		return fmt.Errorf("error copying wsh binary to remote: %w", err)
	}
	conn.Infof(ctx, "successfully installed wsh\n")
	return nil
}
// GetClient returns the current distro handle, read under conn.Lock. The
// result is nil when the connection has no client.
func (conn *WslConn) GetClient() *wsl.Distro {
	return WithLockRtn(conn, func() *wsl.Distro {
		return conn.Client
	})
}
// Reconnect closes the connection and then connects again. If Close fails,
// its error is returned and no connect attempt is made.
func (conn *WslConn) Reconnect(ctx context.Context) error {
	if closeErr := conn.Close(); closeErr != nil {
		return closeErr
	}
	return conn.Connect(ctx)
}
// WaitForConnect polls the connection status every 100ms while it is
// "connecting". It returns nil once the status is "connected", and an error
// when ctx is done or the status is init, disconnected, error, or unknown.
func (conn *WslConn) WaitForConnect(ctx context.Context) error {
	for {
		snapshot := conn.DeriveConnStatus()
		switch snapshot.Status {
		case Status_Connected:
			return nil
		case Status_Connecting:
			select {
			case <-ctx.Done():
				return fmt.Errorf("context timeout")
			case <-time.After(100 * time.Millisecond):
				// poll again
			}
		case Status_Init, Status_Disconnected:
			return fmt.Errorf("disconnected")
		case Status_Error:
			return fmt.Errorf("error: %v", snapshot.Error)
		default:
			return fmt.Errorf("unknown status: %q", snapshot.Status)
		}
	}
}
// Connect moves the connection to "connecting", runs connectInternal, and
// records the outcome. It is refused (with an error) when the connection is
// already connecting or connected. The error from connectInternal is both
// stored on the connection (Status / Error) and returned to the caller.
// A conn-change event is fired before and after the attempt.
func (conn *WslConn) Connect(ctx context.Context) error {
	connectAllowed := WithLockRtn(conn, func() bool {
		switch conn.Status {
		case Status_Connecting, Status_Connected:
			return false
		}
		conn.Status = Status_Connecting
		conn.Error = ""
		return true
	})
	log.Printf("Connect %s\n", conn.GetName())
	if !connectAllowed {
		conn.Infof(ctx, "cannot connect to %q when status is %q\n", conn.GetName(), conn.GetStatus())
		return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus())
	}
	conn.FireConnChangeEvent()

	connectErr := conn.connectInternal(ctx)
	conn.WithLock(func() {
		if connectErr == nil {
			conn.Infof(ctx, "successfully connected (wsh:%v)\n\n", conn.WshEnabled.Load())
			conn.Status = Status_Connected
			conn.LastConnectTime = time.Now().UnixMilli()
			// the active connection number is assigned only once
			if conn.ActiveConnNum == 0 {
				conn.ActiveConnNum = int(activeConnCounter.Add(1))
			}
			telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{
				Conn: map[string]int{"wsl:connect": 1},
			}, "wsl-connconnect")
			return
		}
		conn.Infof(ctx, "ERROR %v\n\n", connectErr)
		conn.Status = Status_Error
		conn.Error = connectErr.Error()
		conn.close_nolock()
		telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{
			Conn: map[string]int{"wsl:connecterror": 1},
		}, "wsl-connconnect")
	})
	conn.FireConnChangeEvent()
	return connectErr
}
// WithLock runs fn while holding conn.Lock. The lock is released even if fn
// panics.
func (conn *WslConn) WithLock(fn func()) {
	mu := conn.Lock
	mu.Lock()
	defer mu.Unlock()
	fn()
}
// WithLockRtn runs fn while holding conn.Lock and returns fn's result.
func WithLockRtn[T any](conn *WslConn, fn func() T) T {
	var result T
	conn.WithLock(func() {
		result = fn()
	})
	return result
}
// getConnWshSettings resolves the wsh settings for this connection. Global
// settings supply the defaults (ask-before-install defaults to true), and a
// per-connection config entry overrides them.
//
// returns (enable-wsh, ask-before-install)
func (conn *WslConn) getConnWshSettings() (bool, bool) {
	fullConfig := wconfig.GetWatcher().GetFullConfig()
	enableWsh := fullConfig.Settings.ConnWshEnabled
	askBeforeInstall := wconfig.DefaultBoolPtr(fullConfig.Settings.ConnAskBeforeWshInstall, true)

	connSettings, found := conn.getConnectionConfig()
	if !found {
		return enableWsh, askBeforeInstall
	}
	if override := connSettings.ConnWshEnabled; override != nil {
		enableWsh = *override
	}
	// if the connection object exists, and conn:askbeforewshinstall is not set, the user must have allowed it
	// TODO: in v0.12+ this should be removed. we'll explicitly write a "false" into the connection object on successful connection
	askBeforeInstall = false
	if override := connSettings.ConnAskBeforeWshInstall; override != nil {
		askBeforeInstall = *override
	}
	return enableWsh, askBeforeInstall
}
// WshCheckResult is the outcome of tryEnableWsh; persistWshInstalled copies
// it onto the WslConn.
type WshCheckResult struct {
// WshEnabled is true only when the connserver started successfully.
WshEnabled bool
// ClientVersion is the wsh version reported by StartConnServer.
ClientVersion string
// NoWshReason is a short explanation set whenever WshEnabled is false.
NoWshReason string
// WshError is the underlying error, if any, that prevented enabling wsh.
WshError error
}
// tryEnableWsh attempts to bring up wsh on the connection: it checks the
// settings, optionally asks the user for permission, records the domain
// socket name, and starts the connserver, installing wsh and retrying once
// when StartConnServer reports that an install is needed. Every outcome is
// reported through the returned WshCheckResult; failures set NoWshReason and,
// where an error occurred, WshError.
func (conn *WslConn) tryEnableWsh(ctx context.Context, clientDisplayName string) WshCheckResult {
	conn.Infof(ctx, "running tryEnableWsh...\n")
	enableWsh, askBeforeInstall := conn.getConnWshSettings()
	conn.Infof(ctx, "wsh settings enable:%v ask:%v\n", enableWsh, askBeforeInstall)
	if !enableWsh {
		return WshCheckResult{NoWshReason: "conn:wshenabled set to false"}
	}

	if askBeforeInstall {
		allowInstall, permErr := conn.getPermissionToInstallWsh(ctx, clientDisplayName)
		if permErr != nil {
			log.Printf("error getting permission to install wsh: %v\n", permErr)
			return WshCheckResult{NoWshReason: "error getting user permission to install", WshError: permErr}
		}
		if !allowInstall {
			return WshCheckResult{NoWshReason: "user selected not to install wsh extensions"}
		}
	}

	if sockErr := conn.OpenDomainSocketListener(ctx); sockErr != nil {
		conn.Infof(ctx, "ERROR opening domain socket listener: %v\n", sockErr)
		return WshCheckResult{
			NoWshReason: "error opening domain socket",
			WshError:    fmt.Errorf("error opening domain socket listener: %w", sockErr),
		}
	}

	needsInstall, clientVersion, osArchStr, err := conn.StartConnServer(ctx, false)
	if err != nil {
		conn.Infof(ctx, "ERROR starting conn server: %v\n", err)
		return WshCheckResult{
			NoWshReason: "error starting connserver",
			WshError:    fmt.Errorf("error starting conn server: %w", err),
		}
	}
	if !needsInstall {
		return WshCheckResult{WshEnabled: true, ClientVersion: clientVersion}
	}

	// install path: copy wsh over, then start the connserver a second time
	conn.Infof(ctx, "connserver needs to be (re)installed\n")
	if err = conn.InstallWsh(ctx, osArchStr); err != nil {
		conn.Infof(ctx, "ERROR installing wsh: %v\n", err)
		return WshCheckResult{
			NoWshReason: "error installing wsh/connserver",
			WshError:    fmt.Errorf("error installing wsh: %w", err),
		}
	}
	needsInstall, clientVersion, _, err = conn.StartConnServer(ctx, true)
	if err != nil {
		conn.Infof(ctx, "ERROR starting conn server (after install): %v\n", err)
		return WshCheckResult{
			NoWshReason: "error starting connserver",
			WshError:    fmt.Errorf("error starting conn server (after install): %w", err),
		}
	}
	if needsInstall {
		conn.Infof(ctx, "conn server not installed correctly (after install)\n")
		return WshCheckResult{
			NoWshReason: "connserver not installed properly",
			WshError:    fmt.Errorf("conn server not installed correctly (after install)"),
		}
	}
	return WshCheckResult{WshEnabled: true, ClientVersion: clientVersion}
}
// getConnectionConfig looks up this connection's entry in the full config.
// The boolean reports whether an entry exists; when it does not, a zero
// ConnKeywords is returned.
func (conn *WslConn) getConnectionConfig() (wshrpc.ConnKeywords, bool) {
	fullConfig := wconfig.GetWatcher().GetFullConfig()
	if settings, found := fullConfig.Connections[conn.GetName()]; found {
		return settings, true
	}
	return wshrpc.ConnKeywords{}, false
}
// persistWshInstalled copies a WshCheckResult onto the connection and, when
// the connection's config has no explicit conn:wshenabled value, writes the
// result to the connections config. It returns nothing because none of this
// is required for the connection to work; write failures are only logged.
func (conn *WslConn) persistWshInstalled(ctx context.Context, result WshCheckResult) {
	conn.WshEnabled.Store(result.WshEnabled)
	conn.SetWshError(result.WshError)
	conn.WithLock(func() {
		conn.NoWshReason = result.NoWshReason
		conn.WshVersion = result.ClientVersion
	})

	// an explicit per-connection setting is never overwritten
	if connConfig, found := conn.getConnectionConfig(); found && connConfig.ConnWshEnabled != nil {
		return
	}

	meta := map[string]any{"conn:wshenabled": result.WshEnabled}
	if writeErr := wconfig.SetConnectionsConfigValue(conn.GetName(), meta); writeErr != nil {
		conn.Infof(ctx, "WARN could not write conn:wshenabled=%v to connections.json: %v\n", result.WshEnabled, writeErr)
		log.Printf("warning: error writing to connections file: %v", writeErr)
	}
}
// connectInternal obtains the distro handle, starts the disconnect waiter
// goroutine, and tries to enable wsh. Only a failure to get the distro is
// returned as an error; a wsh failure is logged and the connection proceeds
// with wsh disabled.
func (conn *WslConn) connectInternal(ctx context.Context) error {
	conn.Infof(ctx, "connectInternal %s\n", conn.GetName())
	distro, err := wsl.GetDistro(ctx, conn.Name)
	if err != nil {
		conn.Infof(ctx, "ERROR GetDistro: %s\n", err)
		log.Printf("error: failed to get distro %s: %s\n", conn.GetName(), err)
		return err
	}
	conn.WithLock(func() {
		conn.Client = distro
	})

	go func() {
		defer func() {
			panichandler.PanicHandler("wsl-waitForDisconnect", recover())
		}()
		conn.waitForDisconnect()
	}()

	wshResult := conn.tryEnableWsh(ctx, conn.GetName())
	switch {
	case wshResult.WshEnabled:
		// nothing to report
	case wshResult.WshError != nil:
		conn.Infof(ctx, "ERROR enabling wsh: %v\n", wshResult.WshError)
		conn.Infof(ctx, "will connect with wsh disabled\n")
	default:
		conn.Infof(ctx, "wsh not enabled: %s\n", wshResult.NoWshReason)
	}
	conn.persistWshInstalled(ctx, wshResult)
	return nil
}
// waitForDisconnect blocks until the connserver command exits, then marks the
// connection disconnected and releases its resources. It returns immediately
// when no conn controller is set. On return it clears HasWaiter and fires a
// conn-change event.
func (conn *WslConn) waitForDisconnect() {
	// log only the name: formatting the whole struct would read every field
	// without holding conn.Lock
	log.Printf("wait for disconnect in %q", conn.GetName())
	defer conn.FireConnChangeEvent()
	defer conn.HasWaiter.Store(false)
	// Snapshot the controller under the lock. StartConnServer and
	// close_nolock assign this field under conn.Lock, so an unlocked
	// check-then-use could race and dereference nil.
	controller := WithLockRtn(conn, func() *wsl.WslCmd {
		return conn.ConnController
	})
	if controller == nil {
		return
	}
	err := controller.Wait()
	conn.WithLock(func() {
		// disconnects happen for a variety of reasons (like network, etc. and are typically transient)
		// so we just set the status to "disconnected" here (not error)
		// don't overwrite any existing error (or error status)
		if err != nil && conn.Error == "" {
			conn.Error = err.Error()
		}
		if conn.Status != Status_Error {
			conn.Status = Status_Disconnected
		}
		conn.close_nolock()
	})
}
// SetWshError stores err's message in conn.WshError under conn.Lock; a nil
// err clears the field.
func (conn *WslConn) SetWshError(err error) {
	var msg string
	if err != nil {
		msg = err.Error()
	}
	conn.WithLock(func() {
		conn.WshError = msg
	})
}
// ClearWshError resets conn.WshError to the empty string under conn.Lock.
func (conn *WslConn) ClearWshError() {
	conn.SetWshError(nil)
}
// getConnInternal returns the WslConn registered under name, creating and
// registering a new one in the "init" state when none exists. It never
// returns nil.
func getConnInternal(name string) *WslConn {
	globalLock.Lock()
	defer globalLock.Unlock()

	if existing := clientControllerMap[name]; existing != nil {
		return existing
	}
	created := &WslConn{
		Lock:       &sync.Mutex{},
		Status:     Status_Init,
		Name:       wsl.WslName{Distro: name},
		WshEnabled: &atomic.Bool{},
		HasWaiter:  &atomic.Bool{},
		cancelFn:   nil,
	}
	clientControllerMap[name] = created
	return created
}
// GetWslConn returns the WslConn for name, creating it if needed. When
// shouldConnect is true and the connection has no client yet, a connect
// attempt is made before returning. The connection is returned regardless of
// whether that attempt succeeds.
func GetWslConn(ctx context.Context, name string, shouldConnect bool) *WslConn {
	conn := getConnInternal(name)
	// GetClient reads conn.Client under conn.Lock; other goroutines assign
	// the field under that lock.
	if shouldConnect && conn.GetClient() == nil {
		// Connect records any failure in conn.Status / conn.Error, which
		// callers inspect through the returned connection, so the returned
		// error is intentionally discarded here.
		_ = conn.Connect(ctx)
	}
	return conn
}
// EnsureConnection is a convenience function that makes sure the named
// connection is established. An empty connName is a no-op. A connection that
// is connecting is waited on, one that is init or disconnected is connected,
// and one in the error state yields an error without a reconnect attempt.
func EnsureConnection(ctx context.Context, connName string) error {
	if connName == "" {
		return nil
	}
	conn := GetWslConn(ctx, connName, false)
	if conn == nil {
		return fmt.Errorf("connection not found: %s", connName)
	}

	snapshot := conn.DeriveConnStatus()
	current := snapshot.Status
	if current == Status_Connected {
		return nil
	}
	if current == Status_Connecting {
		return conn.WaitForConnect(ctx)
	}
	if current == Status_Init || current == Status_Disconnected {
		return conn.Connect(ctx)
	}
	if current == Status_Error {
		return fmt.Errorf("connection error: %s", snapshot.Error)
	}
	return fmt.Errorf("unknown connection status %q", current)
}
// DisconnectClient closes the connection registered under connName and
// returns the result of Close. The lookup goes through getConnInternal, which
// creates a missing entry, so the nil check is purely defensive.
func DisconnectClient(connName string) error {
	target := getConnInternal(connName)
	if target == nil {
		return fmt.Errorf("client %q not found", connName)
	}
	return target.Close()
}