diff --git a/pkg/blockcontroller/blockcontroller.go b/pkg/blockcontroller/blockcontroller.go index 4f34e746b..7d03133f3 100644 --- a/pkg/blockcontroller/blockcontroller.go +++ b/pkg/blockcontroller/blockcontroller.go @@ -388,9 +388,23 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc swapToken.Env[wshutil.WaveJwtTokenVarName] = jwtStr cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr } - shellProc, err = shellexec.StartWslShellProc(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn) - if err != nil { - return nil, err + if !wslConn.WshEnabled.Load() { + shellProc, err = shellexec.StartWslShellProcNoWsh(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn) + if err != nil { + return nil, err + } + } else { + shellProc, err = shellexec.StartWslShellProc(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn) + if err != nil { + wslConn.SetWshError(err) + wslConn.WshEnabled.Store(false) + log.Printf("error starting wsl shell proc with wsh: %v", err) + log.Print("attempting install without wsh") + shellProc, err = shellexec.StartWslShellProcNoWsh(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn) + if err != nil { + return nil, err + } + } } } else if remoteName != "" { credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second) diff --git a/pkg/shellexec/shellexec.go b/pkg/shellexec/shellexec.go index 81ce89a36..394f2fbe7 100644 --- a/pkg/shellexec/shellexec.go +++ b/pkg/shellexec/shellexec.go @@ -150,6 +150,27 @@ func (pp *PipePty) WriteString(s string) (n int, err error) { return pp.Write([]byte(s)) } +func StartWslShellProcNoWsh(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wslconn.WslConn) (*ShellProc, error) { + client := conn.GetClient() + conn.Infof(ctx, "WSL-NEWSESSION (StartWslShellProcNoWsh)") + + ecmd := exec.Command("wsl.exe", "~", "-d", client.Name()) + + if termSize.Rows == 0 || termSize.Cols == 0 { + termSize.Rows = shellutil.DefaultTermRows + termSize.Cols = shellutil.DefaultTermCols + } + if termSize.Rows <= 0 || termSize.Cols <= 0 { + return nil, fmt.Errorf("invalid term size: %v", termSize) + } + cmdPty, err := pty.StartWithSize(ecmd, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)}) + if err != nil { + return nil, err + } + cmdWrap := MakeCmdWrap(ecmd, cmdPty) + return &ShellProc{Cmd: cmdWrap, ConnName: conn.GetName(), CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil +} + func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wslconn.WslConn) (*ShellProc, error) { client := conn.GetClient() conn.Infof(ctx, "WSL-NEWSESSION (StartWslShellProc)") diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go index e421bec61..75a155d70 100644 --- a/pkg/wshrpc/wshserver/wshserver.go +++ b/pkg/wshrpc/wshserver/wshserver.go @@ -783,6 +783,16 @@ func (ws *WshServer) WslDefaultDistroCommand(ctx context.Context) (string, error * Dismisses the WshFail Command in runtime memory on the backend */ func (ws *WshServer) DismissWshFailCommand(ctx context.Context, connName string) error { + if strings.HasPrefix(connName, "wsl://") { + distroName := strings.TrimPrefix(connName, "wsl://") + conn := wslconn.GetWslConn(ctx, distroName, false) + if conn == nil { + return fmt.Errorf("connection not found: %s", connName) + } + conn.ClearWshError() + conn.FireConnChangeEvent() + return nil + } opts, err := remote.ParseOpts(connName) if err != nil { return err diff --git a/pkg/wslconn/wslconn.go b/pkg/wslconn/wslconn.go index 1f561c826..42ec61118 100644 --- a/pkg/wslconn/wslconn.go +++ b/pkg/wslconn/wslconn.go @@ -101,11 +101,14 @@ func (conn *WslConn) DeriveConnStatus() wshrpc.ConnStatus { return wshrpc.ConnStatus{ Status: conn.Status, Connected: conn.Status == Status_Connected, - WshEnabled: true, // always use wsh for wsl connections (temporary) + WshEnabled: conn.WshEnabled.Load(), Connection: conn.GetName(), HasConnected: (conn.LastConnectTime > 0), ActiveConnNum: conn.ActiveConnNum, Error: conn.Error, + WshError: conn.WshError, + NoWshReason: conn.NoWshReason, + WshVersion: conn.WshVersion, } } @@ -702,6 +705,9 @@ func (conn *WslConn) waitForDisconnect() { log.Printf("wait for disconnect in %+#v", conn) defer conn.FireConnChangeEvent() defer conn.HasWaiter.Store(false) + if conn.ConnController == nil { + return + } err := conn.ConnController.Wait() conn.WithLock(func() { // disconnects happen for a variety of reasons (like network, etc. and are typically transient)