// Copyright 2024, Command Line Inc. // SPDX-License-Identifier: Apache-2.0 package wsl import ( "context" "fmt" "io" "log" "net" "sync" "sync/atomic" "time" "github.com/wavetermdev/waveterm/pkg/telemetry" "github.com/wavetermdev/waveterm/pkg/userinput" "github.com/wavetermdev/waveterm/pkg/util/shellutil" "github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/wconfig" "github.com/wavetermdev/waveterm/pkg/wps" "github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshutil" ) const ( Status_Init = "init" Status_Connecting = "connecting" Status_Connected = "connected" Status_Disconnected = "disconnected" Status_Error = "error" ) const DefaultConnectionTimeout = 60 * time.Second var globalLock = &sync.Mutex{} var clientControllerMap = make(map[string]*WslConn) var activeConnCounter = &atomic.Int32{} type WslConn struct { Lock *sync.Mutex Status string Name WslName Client *Distro SockName string DomainSockListener net.Listener ConnController *WslCmd Error string HasWaiter *atomic.Bool LastConnectTime int64 ActiveConnNum int Context context.Context cancelFn func() } type WslName struct { Distro string `json:"distro"` } func GetAllConnStatus() []wshrpc.ConnStatus { globalLock.Lock() defer globalLock.Unlock() var connStatuses []wshrpc.ConnStatus for _, conn := range clientControllerMap { connStatuses = append(connStatuses, conn.DeriveConnStatus()) } return connStatuses } func GetNumWSLHasConnected() int { globalLock.Lock() defer globalLock.Unlock() var connectedCount int for _, conn := range clientControllerMap { if conn.LastConnectTime > 0 { connectedCount++ } } return connectedCount } func (conn *WslConn) DeriveConnStatus() wshrpc.ConnStatus { conn.Lock.Lock() defer conn.Lock.Unlock() return wshrpc.ConnStatus{ Status: conn.Status, Connected: conn.Status == Status_Connected, Connection: conn.GetName(), HasConnected: (conn.LastConnectTime > 0), ActiveConnNum: conn.ActiveConnNum, Error: conn.Error, } } func (conn *WslConn) FireConnChangeEvent() { status := conn.DeriveConnStatus() event := wps.WaveEvent{ Event: wps.Event_ConnChange, Scopes: []string{ fmt.Sprintf("connection:%s", conn.GetName()), }, Data: status, } log.Printf("sending event: %+#v", event) wps.Broker.Publish(event) } func (conn *WslConn) Close() error { defer conn.FireConnChangeEvent() conn.WithLock(func() { if conn.Status == Status_Connected || conn.Status == Status_Connecting { // if status is init, disconnected, or error don't change it conn.Status = Status_Disconnected } conn.close_nolock() }) // we must wait for the waiter to complete startTime := time.Now() for conn.HasWaiter.Load() { time.Sleep(10 * time.Millisecond) if time.Since(startTime) > 2*time.Second { return fmt.Errorf("timeout waiting for waiter to complete") } } return nil } func (conn *WslConn) close_nolock() { // does not set status (that should happen at another level) if conn.DomainSockListener != nil { conn.DomainSockListener.Close() conn.DomainSockListener = nil } if conn.ConnController != nil { conn.cancelFn() // this suspends the conn controller conn.ConnController = nil } if conn.Client != nil { // conn.Client.Close() is not relevant here // we do not want to completely close the wsl in case // other applications are using it conn.Client = nil } } func (conn *WslConn) GetDomainSocketName() string { conn.Lock.Lock() defer conn.Lock.Unlock() return conn.SockName } func (conn *WslConn) GetStatus() string { conn.Lock.Lock() defer conn.Lock.Unlock() return conn.Status } func (conn *WslConn) GetName() string { // no lock required because opts is immutable return "wsl://" + conn.Name.Distro } /** * This function is does not set a listener for WslConn * It is still required in order to set SockName **/ func (conn *WslConn) OpenDomainSocketListener() error { var allowed bool conn.WithLock(func() { if conn.Status != Status_Connecting { allowed = false } else { allowed = true } }) if !allowed { return fmt.Errorf("cannot open domain socket for %q when status is %q", conn.GetName(), conn.GetStatus()) } conn.WithLock(func() { conn.SockName = "~/.waveterm/wave-remote.sock" }) return nil } func (conn *WslConn) StartConnServer() error { var allowed bool conn.WithLock(func() { if conn.Status != Status_Connecting { allowed = false } else { allowed = true } }) if !allowed { return fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus()) } client := conn.GetClient() wshPath := GetWshPath(conn.Context, client) rpcCtx := wshrpc.RpcContext{ ClientType: wshrpc.ClientType_ConnServer, Conn: conn.GetName(), } sockName := conn.GetDomainSocketName() jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName) if err != nil { return fmt.Errorf("unable to create jwt token for conn controller: %w", err) } shellPath, err := DetectShell(conn.Context, client) if err != nil { return err } var cmdStr string if IsPowershell(shellPath) { cmdStr = fmt.Sprintf("$env:%s=\"%s\"; %s connserver --router", wshutil.WaveJwtTokenVarName, jwtToken, wshPath) } else { cmdStr = fmt.Sprintf("%s=\"%s\" %s connserver --router", wshutil.WaveJwtTokenVarName, jwtToken, wshPath) } log.Printf("starting conn controller: %s\n", cmdStr) cmd := client.WslCommand(conn.Context, cmdStr) pipeRead, pipeWrite := io.Pipe() inputPipeRead, inputPipeWrite := io.Pipe() cmd.SetStdout(pipeWrite) cmd.SetStderr(pipeWrite) cmd.SetStdin(inputPipeRead) err = cmd.Start() if err != nil { return fmt.Errorf("unable to start conn controller: %w", err) } conn.WithLock(func() { conn.ConnController = cmd }) // service the I/O go func() { // wait for termination, clear the controller defer conn.WithLock(func() { conn.ConnController = nil }) waitErr := cmd.Wait() log.Printf("conn controller (%q) terminated: %v", conn.GetName(), waitErr) }() go func() { logName := fmt.Sprintf("conncontroller:%s", conn.GetName()) wshutil.HandleStdIOClient(logName, pipeRead, inputPipeWrite) }() regCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) defer cancelFn() err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(rpcCtx.Conn)) if err != nil { return fmt.Errorf("timeout waiting for connserver to register") } time.Sleep(300 * time.Millisecond) // TODO remove this sleep (but we need to wait until connserver is "ready") return nil } type WshInstallOpts struct { Force bool NoUserPrompt bool } func (conn *WslConn) CheckAndInstallWsh(ctx context.Context, clientDisplayName string, opts *WshInstallOpts) error { if opts == nil { opts = &WshInstallOpts{} } client := conn.GetClient() if client == nil { return fmt.Errorf("client is nil") } // check that correct wsh extensions are installed expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion) clientVersion, err := GetWshVersion(ctx, client) if err == nil && clientVersion == expectedVersion && !opts.Force { return nil } var queryText string var title string if opts.Force { queryText = fmt.Sprintf("ReInstalling Wave Shell Extensions (%s) on `%s`\n", wavebase.WaveVersion, clientDisplayName) title = "Install Wave Shell Extensions" } else if err != nil { queryText = fmt.Sprintf("Wave requires Wave Shell Extensions to be \n"+ "installed on `%s` \n"+ "to ensure a seamless experience. \n\n"+ "Would you like to install them?", clientDisplayName) title = "Install Wave Shell Extensions" } else { // don't ask for upgrading the version opts.NoUserPrompt = true } if !opts.NoUserPrompt { request := &userinput.UserInputRequest{ ResponseType: "confirm", QueryText: queryText, Title: title, Markdown: true, CheckBoxMsg: "Don't show me this again", } response, err := userinput.GetUserInput(ctx, request) if err != nil || !response.Confirm { return err } if response.CheckboxStat { meta := waveobj.MetaMapType{ wconfig.ConfigKey_ConnAskBeforeWshInstall: false, } err := wconfig.SetBaseConfigValue(meta) if err != nil { return fmt.Errorf("error setting conn:askbeforewshinstall value: %w", err) } } } log.Printf("attempting to install wsh to `%s`", clientDisplayName) clientOs, err := GetClientOs(ctx, client) if err != nil { return err } clientArch, err := GetClientArch(ctx, client) if err != nil { return err } // attempt to install extension wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch) err = CpHostToRemote(ctx, client, wshLocalPath, "~/.waveterm/bin/wsh") if err != nil { return err } log.Printf("successfully installed wsh on %s\n", conn.GetName()) return nil } func (conn *WslConn) GetClient() *Distro { conn.Lock.Lock() defer conn.Lock.Unlock() return conn.Client } func (conn *WslConn) Reconnect(ctx context.Context) error { err := conn.Close() if err != nil { return err } return conn.Connect(ctx) } func (conn *WslConn) WaitForConnect(ctx context.Context) error { for { status := conn.DeriveConnStatus() if status.Status == Status_Connected { return nil } if status.Status == Status_Connecting { select { case <-ctx.Done(): return fmt.Errorf("context timeout") case <-time.After(100 * time.Millisecond): continue } } if status.Status == Status_Init || status.Status == Status_Disconnected { return fmt.Errorf("disconnected") } if status.Status == Status_Error { return fmt.Errorf("error: %v", status.Error) } return fmt.Errorf("unknown status: %q", status.Status) } } // does not return an error since that error is stored inside of WslConn func (conn *WslConn) Connect(ctx context.Context) error { var connectAllowed bool conn.WithLock(func() { if conn.Status == Status_Connecting || conn.Status == Status_Connected { connectAllowed = false } else { conn.Status = Status_Connecting conn.Error = "" connectAllowed = true } }) log.Printf("Connect %s\n", conn.GetName()) if !connectAllowed { return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus()) } conn.FireConnChangeEvent() err := conn.connectInternal(ctx) conn.WithLock(func() { if err != nil { conn.Status = Status_Error conn.Error = err.Error() conn.close_nolock() telemetry.GoUpdateActivityWrap(telemetry.ActivityUpdate{ Conn: map[string]int{"wsl:connecterror": 1}, }, "wsl-connconnect") } else { conn.Status = Status_Connected conn.LastConnectTime = time.Now().UnixMilli() if conn.ActiveConnNum == 0 { conn.ActiveConnNum = int(activeConnCounter.Add(1)) } telemetry.GoUpdateActivityWrap(telemetry.ActivityUpdate{ Conn: map[string]int{"wsl:connect": 1}, }, "wsl-connconnect") } }) conn.FireConnChangeEvent() return err } func (conn *WslConn) WithLock(fn func()) { conn.Lock.Lock() defer conn.Lock.Unlock() fn() } func (conn *WslConn) connectInternal(ctx context.Context) error { client, err := GetDistro(ctx, conn.Name) if err != nil { return err } conn.WithLock(func() { conn.Client = client }) err = conn.OpenDomainSocketListener() if err != nil { return err } config := wconfig.ReadFullConfig() installErr := conn.CheckAndInstallWsh(ctx, conn.GetName(), &WshInstallOpts{NoUserPrompt: !config.Settings.ConnAskBeforeWshInstall}) if installErr != nil { return fmt.Errorf("conncontroller %s wsh install error: %v", conn.GetName(), installErr) } csErr := conn.StartConnServer() if csErr != nil { return fmt.Errorf("conncontroller %s start wsh connserver error: %v", conn.GetName(), csErr) } conn.HasWaiter.Store(true) go conn.waitForDisconnect() return nil } func (conn *WslConn) waitForDisconnect() { defer conn.FireConnChangeEvent() defer conn.HasWaiter.Store(false) err := conn.ConnController.Wait() conn.WithLock(func() { // disconnects happen for a variety of reasons (like network, etc. and are typically transient) // so we just set the status to "disconnected" here (not error) // don't overwrite any existing error (or error status) if err != nil && conn.Error == "" { conn.Error = err.Error() } if conn.Status != Status_Error { conn.Status = Status_Disconnected } conn.close_nolock() }) } func getConnInternal(name string) *WslConn { globalLock.Lock() defer globalLock.Unlock() connName := WslName{Distro: name} rtn := clientControllerMap[name] if rtn == nil { ctx, cancelFn := context.WithCancel(context.Background()) rtn = &WslConn{Lock: &sync.Mutex{}, Status: Status_Init, Name: connName, HasWaiter: &atomic.Bool{}, Context: ctx, cancelFn: cancelFn} clientControllerMap[name] = rtn } return rtn } func GetWslConn(ctx context.Context, name string, shouldConnect bool) *WslConn { conn := getConnInternal(name) if conn.Client == nil && shouldConnect { conn.Connect(ctx) } return conn } // Convenience function for ensuring a connection is established func EnsureConnection(ctx context.Context, connName string) error { if connName == "" { return nil } conn := GetWslConn(ctx, connName, false) if conn == nil { return fmt.Errorf("connection not found: %s", connName) } connStatus := conn.DeriveConnStatus() switch connStatus.Status { case Status_Connected: return nil case Status_Connecting: return conn.WaitForConnect(ctx) case Status_Init, Status_Disconnected: return conn.Connect(ctx) case Status_Error: return fmt.Errorf("connection error: %s", connStatus.Error) default: return fmt.Errorf("unknown connection status %q", connStatus.Status) } } func DisconnectClient(connName string) error { conn := getConnInternal(connName) if conn == nil { return fmt.Errorf("client %q not found", connName) } err := conn.Close() return err }