diff --git a/frontend/app/block/blockframe.tsx b/frontend/app/block/blockframe.tsx index 22a8fc2ed..a8c88251e 100644 --- a/frontend/app/block/blockframe.tsx +++ b/frontend/app/block/blockframe.tsx @@ -353,12 +353,21 @@ const ChangeConnectionBlockModal = React.memo( }) => { const [connSelected, setConnSelected] = React.useState(""); const changeConnModalOpen = jotai.useAtomValue(changeConnModalAtom); + const [blockData] = WOS.useWaveObjectValue(WOS.makeORef("block", blockId)); const changeConnection = React.useCallback( async (connName: string) => { + const oldCwd = blockData?.meta?.file ?? ""; + let newCwd: string; + if (oldCwd == "") { + newCwd = ""; + } else { + newCwd = "~"; + } await WshServer.SetMetaCommand({ oref: WOS.makeORef("block", blockId), - meta: { connection: connName }, + meta: { connection: connName, file: newCwd }, }); + await services.BlockService.EnsureConnection(blockId).catch((e) => console.log(e)); await WshServer.ControllerRestartCommand({ blockid: blockId }); }, [blockId] diff --git a/frontend/app/store/services.ts b/frontend/app/store/services.ts index 78e07d166..915db790c 100644 --- a/frontend/app/store/services.ts +++ b/frontend/app/store/services.ts @@ -7,6 +7,9 @@ import * as WOS from "./wos"; // blockservice.BlockService (block) class BlockServiceType { + EnsureConnection(arg2: string): Promise { + return WOS.callBackendService("block", "EnsureConnection", Array.from(arguments)) + } GetControllerStatus(arg2: string): Promise { return WOS.callBackendService("block", "GetControllerStatus", Array.from(arguments)) } diff --git a/frontend/app/view/preview/preview.tsx b/frontend/app/view/preview/preview.tsx index f15949337..e1c6614bf 100644 --- a/frontend/app/view/preview/preview.tsx +++ b/frontend/app/view/preview/preview.tsx @@ -52,6 +52,7 @@ export class PreviewModel implements ViewModel { previewTextRef: React.RefObject; editMode: jotai.Atom; canPreview: jotai.PrimitiveAtom; + manageConnection: jotai.Atom; fileName: jotai.Atom; connection: jotai.Atom; @@ -81,6 +82,7 @@ export class PreviewModel implements ViewModel { this.ceReadOnly = jotai.atom(true); this.canPreview = jotai.atom(false); this.openFileModal = jotai.atom(false); + this.manageConnection = jotai.atom(true); this.blockAtom = WOS.getWaveObjectAtom(`block:${blockId}`); this.viewIcon = jotai.atom((get) => { let blockData = get(this.blockAtom); diff --git a/pkg/remote/conncontroller/conncontroller.go b/pkg/remote/conncontroller/conncontroller.go index ab2685999..154d51ce9 100644 --- a/pkg/remote/conncontroller/conncontroller.go +++ b/pkg/remote/conncontroller/conncontroller.go @@ -19,11 +19,13 @@ import ( "time" "github.com/kevinburke/ssh_config" + "github.com/skeema/knownhosts" "github.com/wavetermdev/thenextwave/pkg/remote" "github.com/wavetermdev/thenextwave/pkg/userinput" "github.com/wavetermdev/thenextwave/pkg/util/shellutil" "github.com/wavetermdev/thenextwave/pkg/util/utilfn" "github.com/wavetermdev/thenextwave/pkg/wavebase" + "github.com/wavetermdev/thenextwave/pkg/waveobj" "github.com/wavetermdev/thenextwave/pkg/wps" "github.com/wavetermdev/thenextwave/pkg/wshrpc" "github.com/wavetermdev/thenextwave/pkg/wshutil" @@ -84,7 +86,7 @@ func (conn *SSHConn) FireConnChangeEvent() { }, Data: status, } - log.Printf("connstatus change %q => %s\n", conn.GetName(), status.Status) + log.Printf("sending event: %+#v", event) wps.Broker.Publish(event) } @@ -241,7 +243,7 @@ func (conn *SSHConn) StartConnServer() error { return nil } -func (conn *SSHConn) checkAndInstallWsh(ctx context.Context) error { +func (conn *SSHConn) checkAndInstallWsh(ctx context.Context, clientDisplayName string) error { client := conn.GetClient() if client == nil { return fmt.Errorf("client is nil") @@ -252,27 +254,33 @@ func (conn *SSHConn) checkAndInstallWsh(ctx context.Context) error { if err == nil && clientVersion == expectedVersion { return nil } - // TODO add some progress to SSHConn about install status var queryText string var title string if err != nil { - queryText = "Waveterm requires `wsh` shell extensions installed on your client to ensure a seamless experience. Would you like to install them?" - title = "Install Wsh Shell Extensions" + queryText = fmt.Sprintf("Wave requires Wave Shell Extensions to be \n"+ + "installed on `%s` \n"+ + "to ensure a seamless experience. \n\n"+ + "Would you like to install them?", clientDisplayName) + title = "Install Wave Shell Extensions" } else { - queryText = fmt.Sprintf("Waveterm requires `wsh` shell extensions installed on your client to be updated from %s to %s. Would you like to update?", clientVersion, expectedVersion) - title = "Update Wsh Shell Extensions" + queryText = fmt.Sprintf("Wave requires the Wave Shell Extensions \n"+ + "installed on `%s` \n"+ + "to be updated from %s to %s. \n\n"+ + "Would you like to update?", clientDisplayName, clientVersion, expectedVersion) + title = "Update Wave Shell Extensions" } request := &userinput.UserInputRequest{ ResponseType: "confirm", QueryText: queryText, Title: title, + Markdown: true, CheckBoxMsg: "Don't show me this again", } response, err := userinput.GetUserInput(ctx, request) if err != nil || !response.Confirm { return err } - log.Printf("attempting to install wsh to `%s@%s`", client.User(), client.RemoteAddr().String()) + log.Printf("attempting to install wsh to `%s`", clientDisplayName) clientOs, err := remote.GetClientOs(client) if err != nil { return err @@ -346,6 +354,8 @@ func (conn *SSHConn) connectInternal(ctx context.Context) error { if err != nil { return err } + fmtAddr := knownhosts.Normalize(fmt.Sprintf("%s@%s", client.User(), client.RemoteAddr().String())) + clientDisplayName := fmt.Sprintf("%s (%s)", conn.GetName(), fmtAddr) conn.WithLock(func() { conn.Client = client }) @@ -353,7 +363,7 @@ func (conn *SSHConn) connectInternal(ctx context.Context) error { if err != nil { return err } - installErr := conn.checkAndInstallWsh(ctx) + installErr := conn.checkAndInstallWsh(ctx, clientDisplayName) if installErr != nil { return fmt.Errorf("conncontroller %s wsh install error: %v", conn.GetName(), installErr) } @@ -408,6 +418,48 @@ func GetConn(ctx context.Context, opts *remote.SSHOpts, shouldConnect bool) *SSH return conn } +// Convenience function for ensuring a connection is established +func EnsureConnection(ctx context.Context, blockData *waveobj.Block) error { + connectionName := blockData.Meta.GetString(waveobj.MetaKey_Connection, "") + if connectionName == "" { + return nil + } + credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second) + defer cancelFunc() + + opts, err := remote.ParseOpts(connectionName) + if err != nil { + return err + } + conn := GetConn(credentialCtx, opts, true) + statusChan := make(chan string, 1) + go func() { + // we need to wait for connected/disconnected/error + // to ensure the connection has been established before + // continuing in the original thread + for { + // GetStatus has a lock which makes this reasonable to loop over + status := conn.GetStatus() + if credentialCtx.Err() != nil { + // prevent infinite loop from context + statusChan <- Status_Error + return + } + if status == Status_Connected || status == Status_Disconnected || status == Status_Error { + statusChan <- status + return + } + } + }() + status := <-statusChan + if status == Status_Error { + return fmt.Errorf("connection error: %v", conn.Error) + } else if status == Status_Disconnected { + return fmt.Errorf("disconnected: %v", conn.Error) + } + return nil +} + func DisconnectClient(opts *remote.SSHOpts) error { conn := getConnInternal(opts) if conn == nil { diff --git a/pkg/service/blockservice/blockservice.go b/pkg/service/blockservice/blockservice.go index 73acbeea6..0555e0c34 100644 --- a/pkg/service/blockservice/blockservice.go +++ b/pkg/service/blockservice/blockservice.go @@ -11,6 +11,7 @@ import ( "github.com/wavetermdev/thenextwave/pkg/blockcontroller" "github.com/wavetermdev/thenextwave/pkg/filestore" + "github.com/wavetermdev/thenextwave/pkg/remote/conncontroller" "github.com/wavetermdev/thenextwave/pkg/tsgen/tsgenmeta" "github.com/wavetermdev/thenextwave/pkg/waveobj" "github.com/wavetermdev/thenextwave/pkg/wshrpc" @@ -83,3 +84,11 @@ func (bs *BlockService) SaveWaveAiData(ctx context.Context, blockId string, hist } return nil } + +func (bs *BlockService) EnsureConnection(ctx context.Context, blockId string) error { + block, err := wstore.DBMustGet[*waveobj.Block](ctx, blockId) + if err != nil { + return err + } + return conncontroller.EnsureConnection(ctx, block) +} diff --git a/pkg/service/objectservice/objectservice.go b/pkg/service/objectservice/objectservice.go index 3188cad91..365f623d7 100644 --- a/pkg/service/objectservice/objectservice.go +++ b/pkg/service/objectservice/objectservice.go @@ -11,6 +11,7 @@ import ( "time" "github.com/wavetermdev/thenextwave/pkg/blockcontroller" + "github.com/wavetermdev/thenextwave/pkg/remote/conncontroller" "github.com/wavetermdev/thenextwave/pkg/tsgen/tsgenmeta" "github.com/wavetermdev/thenextwave/pkg/waveobj" "github.com/wavetermdev/thenextwave/pkg/wcore" @@ -133,6 +134,14 @@ func (svc *ObjectService) SetActiveTab(uiContext waveobj.UIContext, tabId string return nil, fmt.Errorf("error getting tab: %w", err) } for _, blockId := range tab.BlockIds { + blockData, err := wstore.DBMustGet[*waveobj.Block](ctx, blockId) + if err != nil { + return nil, fmt.Errorf("error getting block: %w", err) + } + err = conncontroller.EnsureConnection(ctx, blockData) + if err != nil { + return nil, fmt.Errorf("unable to ensure connection: %v", err) + } blockErr := blockcontroller.StartBlockController(ctx, tabId, blockId) if blockErr != nil { // we don't want to fail the set active tab operation if a block controller fails to start diff --git a/pkg/wcore/wcore.go b/pkg/wcore/wcore.go index adc542966..cb4d9e881 100644 --- a/pkg/wcore/wcore.go +++ b/pkg/wcore/wcore.go @@ -11,6 +11,7 @@ import ( "github.com/google/uuid" "github.com/wavetermdev/thenextwave/pkg/blockcontroller" + "github.com/wavetermdev/thenextwave/pkg/remote/conncontroller" "github.com/wavetermdev/thenextwave/pkg/waveobj" "github.com/wavetermdev/thenextwave/pkg/wps" "github.com/wavetermdev/thenextwave/pkg/wshrpc" @@ -173,6 +174,10 @@ func CreateBlock(ctx context.Context, tabId string, blockDef *waveobj.BlockDef, if err != nil { return nil, fmt.Errorf("error creating block: %w", err) } + err = conncontroller.EnsureConnection(ctx, blockData) + if err != nil { + return nil, fmt.Errorf("unable to ensure connection: %v", err) + } controllerName := blockData.Meta.GetString(waveobj.MetaKey_Controller, "") if controllerName != "" { err = blockcontroller.StartBlockController(ctx, tabId, blockData.OID)