// Copyright 2025, Command Line Inc. // SPDX-License-Identifier: Apache-2.0 package wslconn import ( "context" "fmt" "io" "log" "net" "os" "strings" "sync" "sync/atomic" "time" "github.com/wavetermdev/waveterm/pkg/blocklogger" "github.com/wavetermdev/waveterm/pkg/genconn" "github.com/wavetermdev/waveterm/pkg/panichandler" "github.com/wavetermdev/waveterm/pkg/remote/conncontroller" "github.com/wavetermdev/waveterm/pkg/telemetry" "github.com/wavetermdev/waveterm/pkg/userinput" "github.com/wavetermdev/waveterm/pkg/util/shellutil" "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/wconfig" "github.com/wavetermdev/waveterm/pkg/wps" "github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshutil" "github.com/wavetermdev/waveterm/pkg/wsl" ) const ( Status_Init = "init" Status_Connecting = "connecting" Status_Connected = "connected" Status_Disconnected = "disconnected" Status_Error = "error" ) const DefaultConnectionTimeout = 60 * time.Second var globalLock = &sync.Mutex{} var clientControllerMap = make(map[string]*WslConn) var activeConnCounter = &atomic.Int32{} type WslConn struct { Lock *sync.Mutex Status string WshEnabled *atomic.Bool Name wsl.WslName Client *wsl.Distro DomainSockName string // if "", then no domain socket DomainSockListener net.Listener ConnController *wsl.WslCmd Error string WshError string NoWshReason string WshVersion string HasWaiter *atomic.Bool LastConnectTime int64 ActiveConnNum int cancelFn func() } var ConnServerCmdTemplate = strings.TrimSpace( strings.Join([]string{ "%s version 2> /dev/null || (echo -n \"not-installed \"; uname -sm);", "exec %s connserver --router", }, "\n")) func GetAllConnStatus() []wshrpc.ConnStatus { globalLock.Lock() defer globalLock.Unlock() var connStatuses []wshrpc.ConnStatus for _, conn := range clientControllerMap { connStatuses = append(connStatuses, conn.DeriveConnStatus()) } return connStatuses } func GetNumWSLHasConnected() int { globalLock.Lock() defer globalLock.Unlock() var connectedCount int for _, conn := range clientControllerMap { if conn.LastConnectTime > 0 { connectedCount++ } } return connectedCount } func (conn *WslConn) DeriveConnStatus() wshrpc.ConnStatus { conn.Lock.Lock() defer conn.Lock.Unlock() return wshrpc.ConnStatus{ Status: conn.Status, Connected: conn.Status == Status_Connected, WshEnabled: conn.WshEnabled.Load(), Connection: conn.GetName(), HasConnected: (conn.LastConnectTime > 0), ActiveConnNum: conn.ActiveConnNum, Error: conn.Error, WshError: conn.WshError, NoWshReason: conn.NoWshReason, WshVersion: conn.WshVersion, } } func (conn *WslConn) Infof(ctx context.Context, format string, args ...any) { log.Print(fmt.Sprintf("[conn:%s] ", conn.GetName()) + fmt.Sprintf(format, args...)) blocklogger.Infof(ctx, "[conndebug] "+format, args...) } func (conn *WslConn) FireConnChangeEvent() { status := conn.DeriveConnStatus() event := wps.WaveEvent{ Event: wps.Event_ConnChange, Scopes: []string{ fmt.Sprintf("connection:%s", conn.GetName()), }, Data: status, } log.Printf("sending event: %+#v", event) wps.Broker.Publish(event) } func (conn *WslConn) Close() error { defer conn.FireConnChangeEvent() conn.WithLock(func() { if conn.Status == Status_Connected || conn.Status == Status_Connecting { // if status is init, disconnected, or error don't change it conn.Status = Status_Disconnected } conn.close_nolock() }) // we must wait for the waiter to complete startTime := time.Now() for conn.HasWaiter.Load() { time.Sleep(10 * time.Millisecond) if time.Since(startTime) > 2*time.Second { return fmt.Errorf("timeout waiting for waiter to complete") } } return nil } func (conn *WslConn) close_nolock() { // does not set status (that should happen at another level) if conn.DomainSockListener != nil { conn.DomainSockListener.Close() conn.DomainSockListener = nil conn.DomainSockName = "" } if conn.ConnController != nil { conn.cancelFn() // this suspends the conn controller conn.ConnController = nil } if conn.Client != nil { // conn.Client.Close() is not relevant here // we do not want to completely close the wsl in case // other applications are using it conn.Client = nil } } func (conn *WslConn) GetDomainSocketName() string { conn.Lock.Lock() defer conn.Lock.Unlock() return conn.DomainSockName } func (conn *WslConn) GetStatus() string { conn.Lock.Lock() defer conn.Lock.Unlock() return conn.Status } func (conn *WslConn) GetName() string { // no lock required because opts is immutable return "wsl://" + conn.Name.Distro } /** * This function is does not set a listener for WslConn * It is still required in order to set SockName **/ func (conn *WslConn) OpenDomainSocketListener(ctx context.Context) error { conn.Infof(ctx, "running OpenDomainSocketListener...\n") allowed := WithLockRtn(conn, func() bool { return conn.Status == Status_Connecting }) if !allowed { return fmt.Errorf("cannot open domain socket for %q when status is %q", conn.GetName(), conn.GetStatus()) } /* listener, err := client.ListenUnix(sockName) if err != nil { return fmt.Errorf("unable to request connection domain socket: %v", err) } */ conn.Infof(ctx, "setting domain socket to %s\n", wavebase.RemoteFullDomainSocketPath) conn.WithLock(func() { conn.DomainSockName = wavebase.RemoteFullDomainSocketPath //conn.DomainSockListener = listener }) conn.Infof(ctx, "successfully connected domain socket\n") /* go func() { defer func() { panichandler.PanicHandler("wslconn:OpenDomainSocketListener", recover()) }() defer conn.WithLock(func() { conn.DomainSockListener = nil conn.DomainSockName = "" }) wshutil.RunWshRpcOverListener(listener) }() */ return nil } func (conn *WslConn) getWshPath() string { config, ok := conn.getConnectionConfig() if ok && config.ConnWshPath != "" { return config.ConnWshPath } return wavebase.RemoteFullWshBinPath } func (conn *WslConn) GetConfigShellPath() string { config, ok := conn.getConnectionConfig() if !ok { return "" } return config.ConnShellPath } // returns (needsInstall, clientVersion, osArchStr, error) // if wsh is not installed, the clientVersion will be "not-installed", and it will also return an osArchStr // if clientVersion is set, then no osArchStr will be returned func (conn *WslConn) StartConnServer(ctx context.Context, afterUpdate bool) (bool, string, string, error) { conn.Infof(ctx, "running StartConnServer...\n") allowed := WithLockRtn(conn, func() bool { return conn.Status == Status_Connecting }) if !allowed { return false, "", "", fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus()) } client := conn.GetClient() wshPath := conn.getWshPath() rpcCtx := wshrpc.RpcContext{ ClientType: wshrpc.ClientType_ConnServer, Conn: conn.GetName(), } sockName := conn.GetDomainSocketName() jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName) if err != nil { return false, "", "", fmt.Errorf("unable to create jwt token for conn controller: %w", err) } conn.Infof(ctx, "WSL-NEWSESSION (StartConnServer)\n") connServerCtx, cancelFn := context.WithCancel(context.Background()) conn.WithLock(func() { if conn.cancelFn != nil { conn.cancelFn() } conn.cancelFn = cancelFn }) cmdStr := fmt.Sprintf(ConnServerCmdTemplate, wshPath, wshPath) shWrappedCmdStr := fmt.Sprintf("sh -c %s", shellutil.HardQuote(cmdStr)) cmd := client.WslCommand(connServerCtx, shWrappedCmdStr) pipeRead, pipeWrite := io.Pipe() inputPipeRead, inputPipeWrite := io.Pipe() cmd.SetStdout(pipeWrite) cmd.SetStderr(pipeWrite) cmd.SetStdin(inputPipeRead) log.Printf("starting conn controller: %q\n", cmdStr) blocklogger.Debugf(ctx, "[conndebug] wrapped command:\n%s\n", shWrappedCmdStr) err = cmd.Start() if err != nil { return false, "", "", fmt.Errorf("unable to start conn controller cmd: %w", err) } linesChan := utilfn.StreamToLinesChan(pipeRead) versionLine, err := utilfn.ReadLineWithTimeout(linesChan, 5*time.Second) if err != nil { cancelFn() return false, "", "", fmt.Errorf("error reading wsh version: %w", err) } conn.Infof(ctx, "got connserver version: %s\n", strings.TrimSpace(versionLine)) isUpToDate, clientVersion, osArchStr, err := conncontroller.IsWshVersionUpToDate(ctx, versionLine) if err != nil { cancelFn() return false, "", "", fmt.Errorf("error checking wsh version: %w", err) } if isUpToDate && !afterUpdate && os.Getenv(wavebase.WaveWshForceUpdateVarName) != "" { isUpToDate = false conn.Infof(ctx, "%s set, forcing wsh update\n", wavebase.WaveWshForceUpdateVarName) } conn.Infof(ctx, "connserver up-to-date: %v\n", isUpToDate) if !isUpToDate { cancelFn() return true, clientVersion, osArchStr, nil } jwtLine, err := utilfn.ReadLineWithTimeout(linesChan, 3*time.Second) if err != nil { cancelFn() return false, clientVersion, "", fmt.Errorf("error reading jwt status line: %w", err) } conn.Infof(ctx, "got jwt status line: %s\n", jwtLine) if strings.TrimSpace(jwtLine) == wavebase.NeedJwtConst { // write the jwt conn.Infof(ctx, "writing jwt token to connserver\n") _, err = fmt.Fprintf(inputPipeWrite, "%s\n", jwtToken) if err != nil { cancelFn() return false, clientVersion, "", fmt.Errorf("failed to write JWT token: %w", err) } } conn.WithLock(func() { conn.ConnController = cmd }) // service the I/O go func() { defer func() { panichandler.PanicHandler("wslconn:cmd.Wait", recover()) }() // wait for termination, clear the controller var waitErr error defer conn.WithLock(func() { if conn.ConnController != nil { conn.WshEnabled.Store(false) conn.NoWshReason = "connserver terminated" if waitErr != nil { conn.WshError = fmt.Sprintf("connserver terminated unexpectedly with error: %v", waitErr) } } conn.ConnController = nil }) waitErr = cmd.Wait() log.Printf("conn controller (%q) terminated: %v", conn.GetName(), waitErr) }() go func() { defer func() { panichandler.PanicHandler("wsl:StartConnServer:handleStdIOClient", recover()) }() logName := fmt.Sprintf("wslconn:%s", conn.GetName()) wshutil.HandleStdIOClient(logName, linesChan, inputPipeWrite) }() conn.Infof(ctx, "connserver started, waiting for route to be registered\n") regCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) defer cancelFn() err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(rpcCtx.Conn)) if err != nil { return false, clientVersion, "", fmt.Errorf("timeout waiting for connserver to register") } time.Sleep(300 * time.Millisecond) // TODO remove this sleep (but we need to wait until connserver is "ready") conn.Infof(ctx, "connserver is registered and ready\n") return false, clientVersion, "", nil } type WshInstallOpts struct { Force bool NoUserPrompt bool } var queryTextTemplate = strings.TrimSpace(` Wave requires Wave Shell Extensions to be installed on %q to ensure a seamless experience. Would you like to install them? `) func (conn *WslConn) UpdateWsh(ctx context.Context, clientDisplayName string, remoteInfo *wshrpc.RemoteInfo) error { conn.Infof(ctx, "attempting to update wsh for connection %s (os:%s arch:%s version:%s)\n", conn.GetName(), remoteInfo.ClientOs, remoteInfo.ClientArch, remoteInfo.ClientVersion) client := conn.GetClient() if client == nil { return fmt.Errorf("cannot update wsh: ssh client is not connected") } err := CpWshToRemote(ctx, client, remoteInfo.ClientOs, remoteInfo.ClientArch) if err != nil { return fmt.Errorf("error installing wsh to remote: %w", err) } conn.Infof(ctx, "successfully updated wsh on %s\n", conn.GetName()) return nil } // returns (allowed, error) func (conn *WslConn) getPermissionToInstallWsh(ctx context.Context, clientDisplayName string) (bool, error) { conn.Infof(ctx, "running getPermissionToInstallWsh...\n") queryText := fmt.Sprintf(queryTextTemplate, clientDisplayName) title := "Install Wave Shell Extensions" request := &userinput.UserInputRequest{ ResponseType: "confirm", QueryText: queryText, Title: title, Markdown: true, CheckBoxMsg: "Automatically install for all connections", OkLabel: "Install wsh", CancelLabel: "No wsh", } conn.Infof(ctx, "requesting user confirmation...\n") response, err := userinput.GetUserInput(ctx, request) if err != nil { conn.Infof(ctx, "error getting user input: %v\n", err) return false, err } conn.Infof(ctx, "user response to allowing wsh: %v\n", response.Confirm) meta := make(map[string]any) meta["conn:wshenabled"] = response.Confirm conn.Infof(ctx, "writing conn:wshenabled=%v to connections.json\n", response.Confirm) err = wconfig.SetConnectionsConfigValue(conn.GetName(), meta) if err != nil { log.Printf("warning: error writing to connections file: %v", err) } if !response.Confirm { return false, nil } if response.CheckboxStat { conn.Infof(ctx, "writing conn:askbeforewshinstall=false to settings.json\n") meta := waveobj.MetaMapType{ wconfig.ConfigKey_ConnAskBeforeWshInstall: false, } setConfigErr := wconfig.SetBaseConfigValue(meta) if setConfigErr != nil { // this is not a critical error, just log and continue log.Printf("warning: error writing to base config file: %v", err) } } return true, nil } func (conn *WslConn) InstallWsh(ctx context.Context, osArchStr string) error { conn.Infof(ctx, "running installWsh...\n") client := conn.GetClient() if client == nil { conn.Infof(ctx, "ERROR ssh client is not connected, cannot install\n") return fmt.Errorf("ssh client is not connected, cannot install") } var clientOs, clientArch string var err error if osArchStr != "" { clientOs, clientArch, err = GetClientPlatformFromOsArchStr(ctx, osArchStr) } else { clientOs, clientArch, err = GetClientPlatform(ctx, genconn.MakeWSLShellClient(client)) } if err != nil { conn.Infof(ctx, "ERROR detecting client platform: %v\n", err) } conn.Infof(ctx, "detected remote platform os:%s arch:%s\n", clientOs, clientArch) err = CpWshToRemote(ctx, client, clientOs, clientArch) if err != nil { conn.Infof(ctx, "ERROR copying wsh binary to remote: %v\n", err) return fmt.Errorf("error copying wsh binary to remote: %w", err) } conn.Infof(ctx, "successfully installed wsh\n") return nil } func (conn *WslConn) GetClient() *wsl.Distro { conn.Lock.Lock() defer conn.Lock.Unlock() return conn.Client } func (conn *WslConn) Reconnect(ctx context.Context) error { err := conn.Close() if err != nil { return err } return conn.Connect(ctx) } func (conn *WslConn) WaitForConnect(ctx context.Context) error { for { status := conn.DeriveConnStatus() if status.Status == Status_Connected { return nil } if status.Status == Status_Connecting { select { case <-ctx.Done(): return fmt.Errorf("context timeout") case <-time.After(100 * time.Millisecond): continue } } if status.Status == Status_Init || status.Status == Status_Disconnected { return fmt.Errorf("disconnected") } if status.Status == Status_Error { return fmt.Errorf("error: %v", status.Error) } return fmt.Errorf("unknown status: %q", status.Status) } } // does not return an error since that error is stored inside of WslConn func (conn *WslConn) Connect(ctx context.Context) error { var connectAllowed bool conn.WithLock(func() { if conn.Status == Status_Connecting || conn.Status == Status_Connected { connectAllowed = false } else { conn.Status = Status_Connecting conn.Error = "" connectAllowed = true } }) log.Printf("Connect %s\n", conn.GetName()) if !connectAllowed { conn.Infof(ctx, "cannot connect to %q when status is %q\n", conn.GetName(), conn.GetStatus()) return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus()) } conn.FireConnChangeEvent() err := conn.connectInternal(ctx) conn.WithLock(func() { if err != nil { conn.Infof(ctx, "ERROR %v\n\n", err) conn.Status = Status_Error conn.Error = err.Error() conn.close_nolock() telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{ Conn: map[string]int{"wsl:connecterror": 1}, }, "wsl-connconnect") } else { conn.Infof(ctx, "successfully connected (wsh:%v)\n\n", conn.WshEnabled.Load()) conn.Status = Status_Connected conn.LastConnectTime = time.Now().UnixMilli() if conn.ActiveConnNum == 0 { conn.ActiveConnNum = int(activeConnCounter.Add(1)) } telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{ Conn: map[string]int{"wsl:connect": 1}, }, "wsl-connconnect") } }) conn.FireConnChangeEvent() return err } func (conn *WslConn) WithLock(fn func()) { conn.Lock.Lock() defer conn.Lock.Unlock() fn() } func WithLockRtn[T any](conn *WslConn, fn func() T) T { conn.Lock.Lock() defer conn.Lock.Unlock() return fn() } // returns (enable-wsh, ask-before-install) func (conn *WslConn) getConnWshSettings() (bool, bool) { config := wconfig.GetWatcher().GetFullConfig() enableWsh := config.Settings.ConnWshEnabled askBeforeInstall := wconfig.DefaultBoolPtr(config.Settings.ConnAskBeforeWshInstall, true) connSettings, ok := conn.getConnectionConfig() if ok { if connSettings.ConnWshEnabled != nil { enableWsh = *connSettings.ConnWshEnabled } // if the connection object exists, and conn:askbeforewshinstall is not set, the user must have allowed it // TODO: in v0.12+ this should be removed. we'll explicitly write a "false" into the connection object on successful connection if connSettings.ConnAskBeforeWshInstall == nil { askBeforeInstall = false } else { askBeforeInstall = *connSettings.ConnAskBeforeWshInstall } } return enableWsh, askBeforeInstall } type WshCheckResult struct { WshEnabled bool ClientVersion string NoWshReason string WshError error } // returns (wsh-enabled, clientVersion, text-reason, wshError) func (conn *WslConn) tryEnableWsh(ctx context.Context, clientDisplayName string) WshCheckResult { conn.Infof(ctx, "running tryEnableWsh...\n") enableWsh, askBeforeInstall := conn.getConnWshSettings() conn.Infof(ctx, "wsh settings enable:%v ask:%v\n", enableWsh, askBeforeInstall) if !enableWsh { return WshCheckResult{NoWshReason: "conn:wshenabled set to false"} } if askBeforeInstall { allowInstall, err := conn.getPermissionToInstallWsh(ctx, clientDisplayName) if err != nil { log.Printf("error getting permission to install wsh: %v\n", err) return WshCheckResult{NoWshReason: "error getting user permission to install", WshError: err} } if !allowInstall { return WshCheckResult{NoWshReason: "user selected not to install wsh extensions"} } } err := conn.OpenDomainSocketListener(ctx) if err != nil { conn.Infof(ctx, "ERROR opening domain socket listener: %v\n", err) err = fmt.Errorf("error opening domain socket listener: %w", err) return WshCheckResult{NoWshReason: "error opening domain socket", WshError: err} } needsInstall, clientVersion, osArchStr, err := conn.StartConnServer(ctx, false) if err != nil { conn.Infof(ctx, "ERROR starting conn server: %v\n", err) err = fmt.Errorf("error starting conn server: %w", err) return WshCheckResult{NoWshReason: "error starting connserver", WshError: err} } if needsInstall { conn.Infof(ctx, "connserver needs to be (re)installed\n") err = conn.InstallWsh(ctx, osArchStr) if err != nil { conn.Infof(ctx, "ERROR installing wsh: %v\n", err) err = fmt.Errorf("error installing wsh: %w", err) return WshCheckResult{NoWshReason: "error installing wsh/connserver", WshError: err} } needsInstall, clientVersion, _, err = conn.StartConnServer(ctx, true) if err != nil { conn.Infof(ctx, "ERROR starting conn server (after install): %v\n", err) err = fmt.Errorf("error starting conn server (after install): %w", err) return WshCheckResult{NoWshReason: "error starting connserver", WshError: err} } if needsInstall { conn.Infof(ctx, "conn server not installed correctly (after install)\n") err = fmt.Errorf("conn server not installed correctly (after install)") return WshCheckResult{NoWshReason: "connserver not installed properly", WshError: err} } return WshCheckResult{WshEnabled: true, ClientVersion: clientVersion} } else { return WshCheckResult{WshEnabled: true, ClientVersion: clientVersion} } } func (conn *WslConn) getConnectionConfig() (wshrpc.ConnKeywords, bool) { config := wconfig.GetWatcher().GetFullConfig() connSettings, ok := config.Connections[conn.GetName()] if !ok { return wshrpc.ConnKeywords{}, false } return connSettings, true } func (conn *WslConn) persistWshInstalled(ctx context.Context, result WshCheckResult) { conn.WshEnabled.Store(result.WshEnabled) conn.SetWshError(result.WshError) conn.WithLock(func() { conn.NoWshReason = result.NoWshReason conn.WshVersion = result.ClientVersion }) connConfig, ok := conn.getConnectionConfig() if ok && connConfig.ConnWshEnabled != nil { return } meta := make(map[string]any) meta["conn:wshenabled"] = result.WshEnabled err := wconfig.SetConnectionsConfigValue(conn.GetName(), meta) if err != nil { conn.Infof(ctx, "WARN could not write conn:wshenabled=%v to connections.json: %v\n", result.WshEnabled, err) log.Printf("warning: error writing to connections file: %v", err) } // doesn't return an error since none of this is required for connection to work } func (conn *WslConn) connectInternal(ctx context.Context) error { conn.Infof(ctx, "connectInternal %s\n", conn.GetName()) client, err := wsl.GetDistro(ctx, conn.Name) if err != nil { conn.Infof(ctx, "ERROR GetDistro: %s\n", err) log.Printf("error: failed to get distro %s: %s\n", conn.GetName(), err) return err } conn.WithLock(func() { conn.Client = client }) go func() { defer func() { panichandler.PanicHandler("wsl-waitForDisconnect", recover()) }() conn.waitForDisconnect() }() wshResult := conn.tryEnableWsh(ctx, conn.GetName()) if !wshResult.WshEnabled { if wshResult.WshError != nil { conn.Infof(ctx, "ERROR enabling wsh: %v\n", wshResult.WshError) conn.Infof(ctx, "will connect with wsh disabled\n") } else { conn.Infof(ctx, "wsh not enabled: %s\n", wshResult.NoWshReason) } } conn.persistWshInstalled(ctx, wshResult) return nil } func (conn *WslConn) waitForDisconnect() { log.Printf("wait for disconnect in %+#v", conn) defer conn.FireConnChangeEvent() defer conn.HasWaiter.Store(false) if conn.ConnController == nil { return } err := conn.ConnController.Wait() conn.WithLock(func() { // disconnects happen for a variety of reasons (like network, etc. and are typically transient) // so we just set the status to "disconnected" here (not error) // don't overwrite any existing error (or error status) if err != nil && conn.Error == "" { conn.Error = err.Error() } if conn.Status != Status_Error { conn.Status = Status_Disconnected } conn.close_nolock() }) } func (conn *WslConn) SetWshError(err error) { conn.WithLock(func() { if err == nil { conn.WshError = "" } else { conn.WshError = err.Error() } }) } func (conn *WslConn) ClearWshError() { conn.WithLock(func() { conn.WshError = "" }) } func getConnInternal(name string) *WslConn { globalLock.Lock() defer globalLock.Unlock() connName := wsl.WslName{Distro: name} rtn := clientControllerMap[name] if rtn == nil { rtn = &WslConn{Lock: &sync.Mutex{}, Status: Status_Init, Name: connName, WshEnabled: &atomic.Bool{}, HasWaiter: &atomic.Bool{}, cancelFn: nil} clientControllerMap[name] = rtn } return rtn } func GetWslConn(ctx context.Context, name string, shouldConnect bool) *WslConn { conn := getConnInternal(name) if conn.Client == nil && shouldConnect { conn.Connect(ctx) } return conn } // Convenience function for ensuring a connection is established func EnsureConnection(ctx context.Context, connName string) error { if connName == "" { return nil } conn := GetWslConn(ctx, connName, false) if conn == nil { return fmt.Errorf("connection not found: %s", connName) } connStatus := conn.DeriveConnStatus() switch connStatus.Status { case Status_Connected: return nil case Status_Connecting: return conn.WaitForConnect(ctx) case Status_Init, Status_Disconnected: return conn.Connect(ctx) case Status_Error: return fmt.Errorf("connection error: %s", connStatus.Error) default: return fmt.Errorf("unknown connection status %q", connStatus.Status) } } func DisconnectClient(connName string) error { conn := getConnInternal(connName) if conn == nil { return fmt.Errorf("client %q not found", connName) } err := conn.Close() return err }