diff --git a/frontend/app/block/block.less b/frontend/app/block/block.less index 8d0d16dc3..8bab8ccfb 100644 --- a/frontend/app/block/block.less +++ b/frontend/app/block/block.less @@ -137,6 +137,34 @@ } } + .connection-button { + display: flex; + align-items: center; + gap: 2px; + flex-wrap: nowrap; + overflow: hidden; + text-overflow: ellipsis; + min-width: 0; + font-weight: 400; + color: var(--main-text-color); + border-radius: 2px; + padding-right: 6px; + + &:hover { + background-color: var(--highlight-bg-color); + } + + .connection-icon-box { + flex: 1 1 auto; + overflow: hidden; + } + + .connection-name { + flex: 1 100 auto; + overflow: hidden; + } + } + .block-frame-textelems-wrapper { display: flex; flex: 1 100 auto; diff --git a/frontend/app/block/blockframe.tsx b/frontend/app/block/blockframe.tsx index f84323e64..36e2e3735 100644 --- a/frontend/app/block/blockframe.tsx +++ b/frontend/app/block/blockframe.tsx @@ -1,7 +1,14 @@ // Copyright 2024, Command Line Inc. // SPDX-License-Identifier: Apache-2.0 -import { blockViewToIcon, blockViewToName, getBlockHeaderIcon, IconButton, Input } from "@/app/block/blockutil"; +import { + blockViewToIcon, + blockViewToName, + ConnectionButton, + getBlockHeaderIcon, + IconButton, + Input, +} from "@/app/block/blockutil"; import { Button } from "@/app/element/button"; import { ContextMenuModel } from "@/app/store/contextmenu"; import { atoms, globalStore, useBlockAtom, WOS } from "@/app/store/global"; @@ -170,6 +177,8 @@ const HeaderTextElem = React.memo(({ elem }: { elem: HeaderElem }) => { {elem.text} ); + } else if (elem.elemtype == "connectionbutton") { + return ; } else if (elem.elemtype == "div") { return (
{ + const buttonRef = React.useRef(null); + return ( +
+ + {typeof decl.icon === "string" ? ( + + ) : ( + decl.icon + )} + + +
{decl.text}
+
+ ); +}); + export const Input = React.memo(({ decl, className }: { decl: HeaderInput; className: string }) => { const { value, ref, isDisabled, onChange, onKeyDown, onFocus, onBlur } = decl; return ( diff --git a/frontend/app/modals/typeaheadmodal.tsx b/frontend/app/modals/typeaheadmodal.tsx index 8af77ca15..27d7a432f 100644 --- a/frontend/app/modals/typeaheadmodal.tsx +++ b/frontend/app/modals/typeaheadmodal.tsx @@ -104,8 +104,9 @@ interface TypeAheadModalProps { suggestions?: SuggestionType[]; label?: string; className?: string; - onSelect?: (_: string) => void; + value?: string; onChange?: (_: string) => void; + onSelect?: (_: string) => void; onClickBackdrop?: () => void; onKeyDown?: (_) => void; } @@ -115,6 +116,7 @@ const TypeAheadModal = ({ suggestions = dummy, label, anchor, + value, onChange, onSelect, onClickBackdrop, @@ -167,6 +169,7 @@ const TypeAheadModal = ({ (); const Counters = new Map(); +const ConnStatusMap = new Map>(); type GlobalInitOptions = { platform: NodeJS.Platform; @@ -143,9 +145,16 @@ function initGlobalAtoms(initOpts: GlobalInitOptions) { }; } +type WaveEventSubjectContainer = { + id: string; + handler: (event: WaveEvent) => void; + scope: string; +}; + // key is "eventType" or "eventType|oref" const eventSubjects = new Map>(); const fileSubjects = new Map>(); +const waveEventSubjects = new Map(); function getSubjectInternal(subjectKey: string): SubjectWithRef { let subject = eventSubjects.get(subjectKey); @@ -173,6 +182,61 @@ function getEventORefSubject(eventType: string, oref: string): SubjectWithRef void): () => void { + if (handler == null) { + return; + } + const id = crypto.randomUUID(); + const subject = new rxjs.Subject() as any; + const scont: WaveEventSubjectContainer = { id, scope, handler }; + let subjects = waveEventSubjects.get(eventType); + if (subjects == null) { + subjects = []; + waveEventSubjects.set(eventType, subjects); + } + subjects.push(scont); + updateWaveEventSub(eventType); + return () => waveEventUnsubscribe(eventType, id); +} + +function waveEventUnsubscribe(eventType: string, id: string) { + let subjects = waveEventSubjects.get(eventType); + if (subjects == null) { + return; + } + const idx = subjects.findIndex((s) => s.id === id); + if (idx === -1) { + return; + } + subjects.splice(idx, 1); + if (subjects.length === 0) { + waveEventSubjects.delete(eventType); + } + updateWaveEventSub(eventType); +} + function getFileSubject(zoneId: string, fileName: string): SubjectWithRef { const subjectKey = zoneId + "|" + fileName; let subject = fileSubjects.get(subjectKey); @@ -251,6 +315,25 @@ function useBlockDataLoaded(blockId: string): boolean { let globalWS: WSControl = null; +function handleWaveEvent(event: WaveEvent) { + const subjects = waveEventSubjects.get(event.event); + if (subjects == null) { + return; + } + for (const scont of subjects) { + if (util.isBlank(scont.scope)) { + scont.handler(event); + continue; + } + if (event.scopes == null) { + continue; + } + if (event.scopes.includes(scont.scope)) { + scont.handler(event); + } + } +} + function handleWSEventMessage(msg: WSEventType) { if (msg.eventtype == null) { console.log("unsupported event", msg); @@ -275,7 +358,7 @@ function handleWSEventMessage(msg: WSEventType) { } if (msg.eventtype == "rpc") { const rpcMsg: RpcMessage = msg.data; - handleIncomingRpcMessage(rpcMsg); + handleIncomingRpcMessage(rpcMsg, handleWaveEvent); return; } if (msg.eventtype == "layoutaction") { @@ -496,6 +579,38 @@ function countersPrint() { console.log(outStr); } +async function loadConnStatus() { + const connStatusArr = await services.ClientService.GetAllConnStatus(); + if (connStatusArr == null) { + return; + } + for (const connStatus of connStatusArr) { + const curAtom = getConnStatusAtom(connStatus.connection); + globalStore.set(curAtom, connStatus); + } +} + +function subscribeToConnEvents() { + waveEventSubscribe("connchange", null, (event: WaveEvent) => { + const connStatus = event.data as ConnStatus; + if (connStatus == null || util.isBlank(connStatus.connection)) { + return; + } + let curAtom = ConnStatusMap.get(connStatus.connection); + globalStore.set(curAtom, connStatus); + }); +} + +function getConnStatusAtom(conn: string): jotai.PrimitiveAtom { + let rtn = ConnStatusMap.get(conn); + if (rtn == null) { + const connStatus: ConnStatus = { connection: conn, connected: false, error: null }; + rtn = jotai.atom(connStatus); + ConnStatusMap.set(conn, rtn); + } + return rtn; +} + export { atoms, counterInc, @@ -504,6 +619,7 @@ export { createBlock, fetchWaveFile, getApi, + getConnStatusAtom, getEventORefSubject, getEventSubject, getFileSubject, @@ -514,16 +630,20 @@ export { initGlobal, initWS, isDev, + loadConnStatus, openLink, PLATFORM, registerViewModel, sendWSCommand, setBlockFocus, setPlatform, + subscribeToConnEvents, unregisterViewModel, useBlockAtom, useBlockCache, useBlockDataLoaded, useSettingsAtom, + waveEventSubscribe, + waveEventUnsubscribe, WOS, }; diff --git a/frontend/app/store/services.ts b/frontend/app/store/services.ts index c3c63653b..8e58090c4 100644 --- a/frontend/app/store/services.ts +++ b/frontend/app/store/services.ts @@ -32,6 +32,9 @@ class ClientServiceType { FocusWindow(arg2: string): Promise { return WOS.callBackendService("client", "FocusWindow", Array.from(arguments)) } + GetAllConnStatus(): Promise { + return WOS.callBackendService("client", "GetAllConnStatus", Array.from(arguments)) + } GetClientData(): Promise { return WOS.callBackendService("client", "GetClientData", Array.from(arguments)) } diff --git a/frontend/app/store/wshrpc.ts b/frontend/app/store/wshrpc.ts index bdcac511f..91ea2b375 100644 --- a/frontend/app/store/wshrpc.ts +++ b/frontend/app/store/wshrpc.ts @@ -10,7 +10,7 @@ type RpcEntry = { msgFn: (msg: RpcMessage) => void; }; -let openRpcs = new Map(); +const openRpcs = new Map(); async function* rpcResponseGenerator( command: string, @@ -86,10 +86,23 @@ function sendRpcCommand(msg: RpcMessage): AsyncGenerator void) { const isRequest = msg.command != null || msg.reqid != null; if (isRequest) { - console.log("rpc request not supported", msg); + // handle events + if (msg.command == "eventrecv") { + if (eventHandlerFn != null) { + eventHandlerFn(msg.data); + } + return; + } + + console.log("rpc command not supported", msg); return; } if (msg.resid == null) { @@ -122,4 +135,4 @@ if (globalThis.window != null) { globalThis["consumeGenerator"] = consumeGenerator; } -export { handleIncomingRpcMessage, sendRpcCommand }; +export { handleIncomingRpcMessage, sendRawRpcMessage, sendRpcCommand }; diff --git a/frontend/app/store/wshserver.ts b/frontend/app/store/wshserver.ts index df43efb76..cc74e3b6a 100644 --- a/frontend/app/store/wshserver.ts +++ b/frontend/app/store/wshserver.ts @@ -53,7 +53,7 @@ class WshServerType { } // command "eventunsub" [call] - EventUnsubCommand(data: SubscriptionRequest, opts?: RpcOpts): Promise { + EventUnsubCommand(data: string, opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("eventunsub", data, opts); } diff --git a/frontend/app/view/term/term.less b/frontend/app/view/term/term.less index 2559d6e11..16dfbef5b 100644 --- a/frontend/app/view/term/term.less +++ b/frontend/app/view/term/term.less @@ -1,6 +1,16 @@ // Copyright 2024, Command Line Inc. // SPDX-License-Identifier: Apache-2.0 +.connection-btn { + min-height: 0; + overflow: hidden; + line-height: 1; + display: flex; + background-color: orangered; + justify-content: flex-start; + width: 200px; +} + .view-term { display: flex; flex-direction: column; diff --git a/frontend/app/view/term/term.tsx b/frontend/app/view/term/term.tsx index b39806810..dca1d0a19 100644 --- a/frontend/app/view/term/term.tsx +++ b/frontend/app/view/term/term.tsx @@ -1,9 +1,18 @@ // Copyright 2024, Command Line Inc. // SPDX-License-Identifier: Apache-2.0 +import { TypeAheadModal } from "@/app/modals/typeaheadmodal"; import { WshServer } from "@/app/store/wshserver"; import { VDomView } from "@/app/view/term/vdom"; -import { WOS, atoms, getEventORefSubject, globalStore, useBlockAtom, useSettingsAtom } from "@/store/global"; +import { + WOS, + atoms, + getConnStatusAtom, + getEventORefSubject, + globalStore, + useBlockAtom, + useSettingsAtom, +} from "@/store/global"; import * as services from "@/store/services"; import * as keyutil from "@/util/keyutil"; import * as util from "@/util/util"; @@ -109,13 +118,16 @@ function setBlockFocus(blockId: string) { class TermViewModel { viewType: string; + connected: boolean; termRef: React.RefObject; blockAtom: jotai.Atom; termMode: jotai.Atom; + connectedAtom: jotai.Atom; + typeahead: boolean; htmlElemFocusRef: React.RefObject; blockId: string; viewIcon: jotai.Atom; - viewText: jotai.Atom; + viewText: jotai.Atom; viewName: jotai.Atom; blockBg: jotai.Atom; @@ -123,6 +135,14 @@ class TermViewModel { this.viewType = "term"; this.blockId = blockId; this.blockAtom = WOS.getWaveObjectAtom(`block:${blockId}`); + this.connectedAtom = jotai.atom((get) => { + const connectionName = get(this.blockAtom).meta?.connection || ""; + if (connectionName == "") { + return true; + } + const status = get(getConnStatusAtom(connectionName)); + return status.connected; + }); this.termMode = jotai.atom((get) => { const blockData = get(this.blockAtom); return blockData?.meta?.["term:mode"] ?? "term"; @@ -139,7 +159,30 @@ class TermViewModel { }); this.viewText = jotai.atom((get) => { const blockData = get(this.blockAtom); - return blockData?.meta?.title ?? ""; + const titleText: HeaderText = { elemtype: "text", text: blockData?.meta?.title ?? "" }; + const typeAhead = get(atoms.typeAheadModalAtom); + const connectionName = blockData?.meta?.connection || ""; + const isConnected = get(this.connectedAtom); + let iconColor: string; + if (connectionName != "") { + iconColor = "#53b4ea"; + } else { + iconColor = "var(--grey-text-color)"; + } + const connButton: ConnectionButton = { + elemtype: "connectionbutton", + icon: "arrow-right-arrow-left", + iconColor: iconColor, + text: connectionName, + connected: isConnected, + onClick: () => { + globalStore.set(atoms.typeAheadModalAtom, { + ...(typeAhead as TypeAheadModalType), + [blockId]: true, + }); + }, + }; + return [connButton, titleText] as HeaderElem[]; }); this.blockBg = jotai.atom((get) => { const blockData = get(this.blockAtom); @@ -152,6 +195,10 @@ class TermViewModel { }); } + resetConnection() { + WshServer.ControllerRestartCommand({ blockid: this.blockId }); + } + giveFocus(): boolean { let termMode = globalStore.get(this.termMode); if (termMode == "term") { @@ -196,6 +243,9 @@ interface TerminalViewProps { } const TerminalView = ({ blockId, model }: TerminalViewProps) => { + const typeAhead = jotai.useAtomValue(atoms.typeAheadModalAtom); + const viewRef = React.createRef(); + const [connSelected, setConnSelected] = React.useState(""); const connectElemRef = React.useRef(null); const termRef = React.useRef(null); model.termRef = termRef; @@ -371,11 +421,57 @@ const TerminalView = ({ blockId, model }: TerminalViewProps) => { } } + const changeConnection = React.useCallback( + async (connName: string) => { + await WshServer.SetMetaCommand({ oref: WOS.makeORef("block", blockId), meta: { connection: connName } }); + await WshServer.ControllerRestartCommand({ blockid: blockId }); + }, + [blockId] + ); + + const handleTypeAheadKeyDown = React.useCallback( + (waveEvent: WaveKeyboardEvent): boolean => { + if (keyutil.checkKeyPressed(waveEvent, "Enter")) { + changeConnection(connSelected); + globalStore.set(atoms.typeAheadModalAtom, { + ...(typeAhead as TypeAheadModalType), + [blockId]: false, + }); + setConnSelected(""); + return true; + } + if (keyutil.checkKeyPressed(waveEvent, "Escape")) { + globalStore.set(atoms.typeAheadModalAtom, { + ...(typeAhead as TypeAheadModalType), + [blockId]: false, + }); + setConnSelected(""); + model.giveFocus(); + return true; + } + }, + [typeAhead, model, blockId, connSelected] + ); + return (
+ {typeAhead[blockId] && ( + { + changeConnection(selected); + }} + onKeyDown={(e) => keyutil.keydownWrapper(handleTypeAheadKeyDown)(e)} + onChange={(current: string) => setConnSelected(current)} + value={connSelected} + label="Switch Connection" + /> + )}
diff --git a/frontend/types/custom.d.ts b/frontend/types/custom.d.ts index 0917c0e81..1f2ac7a83 100644 --- a/frontend/types/custom.d.ts +++ b/frontend/types/custom.d.ts @@ -140,7 +140,7 @@ declare global { type SubjectWithRef = rxjs.Subject & { refCount: number; release: () => void }; - type HeaderElem = HeaderIconButton | HeaderText | HeaderInput | HeaderDiv | HeaderTextButton; + type HeaderElem = HeaderIconButton | HeaderText | HeaderInput | HeaderDiv | HeaderTextButton | ConnectionButton; type HeaderIconButton = { elemtype: "iconbutton"; @@ -181,6 +181,16 @@ declare global { children: HeaderElem[]; onMouseOver?: (e: React.MouseEvent) => void; onMouseOut?: (e: React.MouseEvent) => void; + onClick?: (e: React.MouseEvent) => void; + }; + + type ConnectionButton = { + elemtype: "connectionbutton"; + icon: string; + text: string; + iconColor: string; + onClick?: (e: React.MouseEvent) => void; + connected: boolean; }; interface ViewModel { diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts index df52f7872..b5419adfe 100644 --- a/frontend/types/gotypes.d.ts +++ b/frontend/types/gotypes.d.ts @@ -158,6 +158,14 @@ declare global { meta: MetaType; }; + // wshrpc.ConnStatus + type ConnStatus = { + status: string; + connection: string; + connected: boolean; + error?: string; + }; + // wshrpc.CpuDataRequest type CpuDataRequest = { id: string; diff --git a/frontend/wave.ts b/frontend/wave.ts index 8f800cc9c..e87468b58 100644 --- a/frontend/wave.ts +++ b/frontend/wave.ts @@ -2,7 +2,18 @@ // SPDX-License-Identifier: Apache-2.0 import { WshServer } from "@/app/store/wshserver"; -import { atoms, countersClear, countersPrint, getApi, globalStore, globalWS, initGlobal, initWS } from "@/store/global"; +import { + atoms, + countersClear, + countersPrint, + getApi, + globalStore, + globalWS, + initGlobal, + initWS, + loadConnStatus, + subscribeToConnEvents, +} from "@/store/global"; import * as services from "@/store/services"; import * as WOS from "@/store/wos"; import * as keyutil from "@/util/keyutil"; @@ -44,6 +55,8 @@ document.addEventListener("DOMContentLoaded", async () => { const initialTab = await WOS.loadAndPinWaveObject(WOS.makeORef("tab", waveWindow.activetabid)); await WOS.loadAndPinWaveObject(WOS.makeORef("layout", initialTab.layoutstate)); initWS(); + await loadConnStatus(); + subscribeToConnEvents(); const settings = await services.FileService.GetSettingsConfig(); console.log("settings", settings); globalStore.set(atoms.settingsConfigAtom, settings); diff --git a/pkg/blockcontroller/blockcontroller.go b/pkg/blockcontroller/blockcontroller.go index 6e2bc222b..ee2c04d97 100644 --- a/pkg/blockcontroller/blockcontroller.go +++ b/pkg/blockcontroller/blockcontroller.go @@ -278,12 +278,13 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj if err != nil { return err } - conn, err := conncontroller.GetConn(credentialCtx, opts) - if err != nil { - return err + conn := conncontroller.GetConn(credentialCtx, opts, true) + connStatus := conn.DeriveConnStatus() + if connStatus.Error != "" { + return fmt.Errorf("error connecting to remote: %s", connStatus.Error) } if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) { - jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: conn.Opts.String()}, conn.SockName) + jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: conn.Opts.String()}, conn.GetDomainSocketName()) if err != nil { return fmt.Errorf("error making jwt token: %w", err) } @@ -385,10 +386,11 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj log.Printf("[shellproc] shell process wait loop done\n") }() waitErr := shellProc.Cmd.Wait() - shellProc.SetWaitErrorAndSignalDone(waitErr) exitCode := shellexec.ExitCodeFromWaitErr(waitErr) termMsg := fmt.Sprintf("\r\nprocess finished with exit code = %d\r\n\r\n", exitCode) + //HandleAppendBlockFile(bc.BlockId, BlockFile_Term, []byte("\r\n")) HandleAppendBlockFile(bc.BlockId, BlockFile_Term, []byte(termMsg)) + shellProc.SetWaitErrorAndSignalDone(waitErr) }() return nil } @@ -464,8 +466,21 @@ func (bc *BlockController) SendInput(inputUnion *BlockInputUnion) error { } func (bc *BlockController) RestartController() error { - // TODO: if shell command is already running - // we probably want to kill it off, wait, and then restart it + + // kill the command if it's running + bc.Lock.Lock() + if bc.ShellProc != nil { + bc.ShellProc.Close() + } + bc.Lock.Unlock() + + // wait for process to complete + if bc.ShellProc != nil { + doneCh := bc.ShellProc.DoneCh + <-doneCh + } + + // restart controller bdata, err := wstore.DBMustGet[*waveobj.Block](context.Background(), bc.BlockId) if err != nil { return fmt.Errorf("error getting block: %w", err) diff --git a/pkg/remote/conncontroller/conncontroller.go b/pkg/remote/conncontroller/conncontroller.go index 90b95b748..870dd4030 100644 --- a/pkg/remote/conncontroller/conncontroller.go +++ b/pkg/remote/conncontroller/conncontroller.go @@ -5,36 +5,111 @@ package conncontroller import ( "context" + "errors" "fmt" "io" + "io/fs" "log" "net" + "os" + "path/filepath" "strings" "sync" + "sync/atomic" + "time" + "github.com/kevinburke/ssh_config" "github.com/wavetermdev/thenextwave/pkg/remote" "github.com/wavetermdev/thenextwave/pkg/userinput" "github.com/wavetermdev/thenextwave/pkg/util/shellutil" "github.com/wavetermdev/thenextwave/pkg/util/utilfn" "github.com/wavetermdev/thenextwave/pkg/wavebase" + "github.com/wavetermdev/thenextwave/pkg/wps" "github.com/wavetermdev/thenextwave/pkg/wshrpc" "github.com/wavetermdev/thenextwave/pkg/wshutil" "golang.org/x/crypto/ssh" ) +const ( + Status_Init = "init" + Status_Connecting = "connecting" + Status_Connected = "connected" + Status_Disconnected = "disconnected" + Status_Error = "error" +) + var globalLock = &sync.Mutex{} var clientControllerMap = make(map[remote.SSHOpts]*SSHConn) type SSHConn struct { Lock *sync.Mutex + Status string Opts *remote.SSHOpts Client *ssh.Client SockName string DomainSockListener net.Listener ConnController *ssh.Session + Error string + HasWaiter *atomic.Bool +} + +func GetAllConnStatus() []wshrpc.ConnStatus { + globalLock.Lock() + defer globalLock.Unlock() + + var connStatuses []wshrpc.ConnStatus + for _, conn := range clientControllerMap { + connStatuses = append(connStatuses, conn.DeriveConnStatus()) + } + return connStatuses +} + +func (conn *SSHConn) DeriveConnStatus() wshrpc.ConnStatus { + conn.Lock.Lock() + defer conn.Lock.Unlock() + return wshrpc.ConnStatus{ + Status: conn.Status, + Connection: conn.Opts.String(), + Connected: conn.Client != nil, + Error: conn.Error, + } +} + +func (conn *SSHConn) FireConnChangeEvent() { + status := conn.DeriveConnStatus() + event := wshrpc.WaveEvent{ + Event: wshrpc.Event_ConnChange, + Scopes: []string{ + fmt.Sprintf("connection:%s", conn.GetName()), + }, + Data: status, + } + log.Printf("sending event: %+#v", event) + wps.Broker.Publish(event) } func (conn *SSHConn) Close() error { + defer conn.FireConnChangeEvent() + conn.WithLock(func() { + if conn.Status == Status_Connected || conn.Status == Status_Connecting { + // if status is init, disconnected, or error don't change it + conn.Status = Status_Disconnected + } + conn.close_nolock() + }) + // we must wait for the waiter to complete + startTime := time.Now() + for conn.HasWaiter.Load() { + time.Sleep(10 * time.Millisecond) + if time.Since(startTime) > 2*time.Second { + return fmt.Errorf("timeout waiting for waiter to complete") + } + } + return nil +} + +func (conn *SSHConn) close_nolock() { + // does not set status (that should happen at another level) if conn.DomainSockListener != nil { conn.DomainSockListener.Close() conn.DomainSockListener = nil @@ -43,75 +118,113 @@ func (conn *SSHConn) Close() error { conn.ConnController.Close() conn.ConnController = nil } - err := conn.Client.Close() - conn.Client = nil - return err + if conn.Client != nil { + conn.Client.Close() + conn.Client = nil + } +} + +func (conn *SSHConn) GetDomainSocketName() string { + conn.Lock.Lock() + defer conn.Lock.Unlock() + return conn.SockName +} + +func (conn *SSHConn) GetStatus() string { + conn.Lock.Lock() + defer conn.Lock.Unlock() + return conn.Status +} + +func (conn *SSHConn) GetName() string { + // no lock required because opts is immutable + return conn.Opts.String() } func (conn *SSHConn) OpenDomainSocketListener() error { - if conn.DomainSockListener != nil { - return nil + var allowed bool + conn.WithLock(func() { + if conn.Status != Status_Connecting { + allowed = false + } else { + allowed = true + } + }) + if !allowed { + return fmt.Errorf("cannot open domain socket for %q when status is %q", conn.GetName(), conn.GetStatus()) } + client := conn.GetClient() randStr, err := utilfn.RandomHexString(16) // 64-bits of randomness if err != nil { return fmt.Errorf("error generating random string: %w", err) } sockName := fmt.Sprintf("/tmp/waveterm-%s.sock", randStr) - log.Printf("remote domain socket %s %q\n", conn.Opts.String(), sockName) - listener, err := conn.Client.ListenUnix(sockName) + log.Printf("remote domain socket %s %q\n", conn.GetName(), sockName) + listener, err := client.ListenUnix(sockName) if err != nil { return fmt.Errorf("unable to request connection domain socket: %v", err) } - conn.SockName = sockName - conn.DomainSockListener = listener + conn.WithLock(func() { + conn.SockName = sockName + conn.DomainSockListener = listener + }) go func() { - defer func() { - conn.Lock.Lock() - defer conn.Lock.Unlock() + defer conn.WithLock(func() { conn.DomainSockListener = nil - }() + conn.SockName = "" + }) wshutil.RunWshRpcOverListener(listener) }() return nil } func (conn *SSHConn) StartConnServer() error { - conn.Lock.Lock() - defer conn.Lock.Unlock() - if conn.ConnController != nil { - return nil + var allowed bool + conn.WithLock(func() { + if conn.Status != Status_Connecting { + allowed = false + } else { + allowed = true + } + }) + if !allowed { + return fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus()) } - wshPath := remote.GetWshPath(conn.Client) + client := conn.GetClient() + wshPath := remote.GetWshPath(client) rpcCtx := wshrpc.RpcContext{ ClientType: wshrpc.ClientType_ConnServer, - Conn: conn.Opts.String(), + Conn: conn.GetName(), } - jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, conn.SockName) + sockName := conn.GetDomainSocketName() + jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName) if err != nil { return fmt.Errorf("unable to create jwt token for conn controller: %w", err) } - sshSession, err := conn.Client.NewSession() + sshSession, err := client.NewSession() if err != nil { return fmt.Errorf("unable to create ssh session for conn controller: %w", err) } pipeRead, pipeWrite := io.Pipe() sshSession.Stdout = pipeWrite sshSession.Stderr = pipeWrite - conn.ConnController = sshSession cmdStr := fmt.Sprintf("%s=\"%s\" %s connserver", wshutil.WaveJwtTokenVarName, jwtToken, wshPath) log.Printf("starting conn controller: %s\n", cmdStr) err = sshSession.Start(cmdStr) if err != nil { return fmt.Errorf("unable to start conn controller: %w", err) } + conn.WithLock(func() { + conn.ConnController = sshSession + }) // service the I/O go func() { // wait for termination, clear the controller + defer conn.WithLock(func() { + conn.ConnController = nil + }) waitErr := sshSession.Wait() - log.Printf("conn controller (%q) terminated: %v", conn.Opts.String(), waitErr) - conn.Lock.Lock() - defer conn.Lock.Unlock() - conn.ConnController = nil + log.Printf("conn controller (%q) terminated: %v", conn.GetName(), waitErr) }() go func() { readErr := wshutil.StreamToLines(pipeRead, func(line []byte) { @@ -119,23 +232,27 @@ func (conn *SSHConn) StartConnServer() error { if !strings.HasSuffix(lineStr, "\n") { lineStr += "\n" } - log.Printf("[conncontroller:%s:output] %s", conn.Opts.String(), lineStr) + log.Printf("[conncontroller:%s:output] %s", conn.GetName(), lineStr) }) if readErr != nil && readErr != io.EOF { - log.Printf("[conncontroller:%s] error reading output: %v\n", conn.Opts.String(), readErr) + log.Printf("[conncontroller:%s] error reading output: %v\n", conn.GetName(), readErr) } }() return nil } func (conn *SSHConn) checkAndInstallWsh(ctx context.Context) error { - client := conn.Client + client := conn.GetClient() + if client == nil { + return fmt.Errorf("client is nil") + } // check that correct wsh extensions are installed expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion) clientVersion, err := remote.GetWshVersion(client) if err == nil && clientVersion == expectedVersion { return nil } + // TODO add some progress to SSHConn about install status var queryText string var title string if err != nil { @@ -170,56 +287,189 @@ func (conn *SSHConn) checkAndInstallWsh(ctx context.Context) error { if err != nil { return err } - log.Printf("successfully installed wsh on %s\n", conn.Opts.String()) + log.Printf("successfully installed wsh on %s\n", conn.GetName()) return nil } -func GetConn(ctx context.Context, opts *remote.SSHOpts) (*SSHConn, error) { - globalLock.Lock() - defer globalLock.Unlock() +func (conn *SSHConn) GetClient() *ssh.Client { + conn.Lock.Lock() + defer conn.Lock.Unlock() + return conn.Client +} - // attempt to retrieve if already opened - conn, ok := clientControllerMap[*opts] - if ok { - return conn, nil - } - - client, err := remote.ConnectToClient(ctx, opts) //todo specify or remove opts +func (conn *SSHConn) Reconnect(ctx context.Context) error { + err := conn.Close() if err != nil { - return nil, err + return err } - conn = &SSHConn{Lock: &sync.Mutex{}, Opts: opts, Client: client} + return conn.Connect(ctx) +} + +// does not return an error since that error is stored inside of SSHConn +func (conn *SSHConn) Connect(ctx context.Context) error { + var connectAllowed bool + conn.WithLock(func() { + if conn.Status == Status_Connecting || conn.Status == Status_Connected { + connectAllowed = false + } else { + conn.Status = Status_Connecting + conn.Error = "" + connectAllowed = true + } + }) + if !connectAllowed { + return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus()) + } + conn.FireConnChangeEvent() + err := conn.connectInternal(ctx) + conn.WithLock(func() { + if err != nil { + conn.Status = Status_Error + conn.Error = err.Error() + conn.close_nolock() + } else { + conn.Status = Status_Connected + } + }) + conn.FireConnChangeEvent() + return err +} + +func (conn *SSHConn) WithLock(fn func()) { + conn.Lock.Lock() + defer conn.Lock.Unlock() + fn() +} + +func (conn *SSHConn) connectInternal(ctx context.Context) error { + client, err := remote.ConnectToClient(ctx, conn.Opts) //todo specify or remove opts + if err != nil { + return err + } + conn.WithLock(func() { + conn.Client = client + }) err = conn.OpenDomainSocketListener() if err != nil { - conn.Close() - return nil, err + return err } - installErr := conn.checkAndInstallWsh(ctx) if installErr != nil { - conn.Close() - return nil, fmt.Errorf("conncontroller %s wsh install error: %v", conn.Opts.String(), installErr) + return fmt.Errorf("conncontroller %s wsh install error: %v", conn.GetName(), installErr) } - csErr := conn.StartConnServer() if csErr != nil { - conn.Close() - return nil, fmt.Errorf("conncontroller %s start wsh connserver error: %v", conn.Opts.String(), csErr) + return fmt.Errorf("conncontroller %s start wsh connserver error: %v", conn.GetName(), csErr) } + conn.HasWaiter.Store(true) + go conn.waitForDisconnect() + return nil +} - // save successful connection to map - clientControllerMap[*opts] = conn +func (conn *SSHConn) waitForDisconnect() { + defer conn.FireConnChangeEvent() + defer conn.HasWaiter.Store(false) + client := conn.GetClient() + if client == nil { + return + } + err := client.Wait() + conn.WithLock(func() { + if err != nil { + if conn.Status != Status_Disconnected { + // don't set the error if our status is disconnected (because this error was caused by an explicit close) + conn.Status = Status_Error + conn.Error = err.Error() + } + } else { + // not sure if this is possible, because I think Wait() always returns an error (although that's not in the docs) + conn.Status = Status_Disconnected + } + conn.close_nolock() + }) +} - return conn, nil +func getConnInternal(opts *remote.SSHOpts) *SSHConn { + globalLock.Lock() + defer globalLock.Unlock() + rtn := clientControllerMap[*opts] + if rtn == nil { + rtn = &SSHConn{Lock: &sync.Mutex{}, Status: Status_Init, Opts: opts, HasWaiter: &atomic.Bool{}} + clientControllerMap[*opts] = rtn + } + return rtn +} + +func GetConn(ctx context.Context, opts *remote.SSHOpts, shouldConnect bool) *SSHConn { + conn := getConnInternal(opts) + if conn.Client == nil && shouldConnect { + conn.Connect(ctx) + } + return conn } func DisconnectClient(opts *remote.SSHOpts) error { - globalLock.Lock() - defer globalLock.Unlock() - - client, ok := clientControllerMap[*opts] - if ok { - return client.Close() + conn := getConnInternal(opts) + if conn == nil { + return fmt.Errorf("client %q not found", opts.String()) } - return fmt.Errorf("client %v not found", opts) + err := conn.Close() + return err +} + +func resolveSshConfigPatterns(configFiles []string) ([]string, error) { + // using two separate containers to track order and have O(1) lookups + // since go does not have an ordered map primitive + var discoveredPatterns []string + alreadyUsed := make(map[string]bool) + alreadyUsed[""] = true // this excludes the empty string from potential alias + var openedFiles []fs.File + + defer func() { + for _, openedFile := range openedFiles { + openedFile.Close() + } + }() + + var errs []error + for _, configFile := range configFiles { + fd, openErr := os.Open(configFile) + openedFiles = append(openedFiles, fd) + if fd == nil { + errs = append(errs, openErr) + continue + } + + cfg, _ := ssh_config.Decode(fd) + for _, host := range cfg.Hosts { + // for each host, find the first good alias + for _, hostPattern := range host.Patterns { + hostPatternStr := hostPattern.String() + if !strings.Contains(hostPatternStr, "*") || alreadyUsed[hostPatternStr] { + discoveredPatterns = append(discoveredPatterns, hostPatternStr) + alreadyUsed[hostPatternStr] = true + break + } + } + } + } + if len(errs) == len(configFiles) { + errs = append([]error{fmt.Errorf("no ssh config files could be opened:\n")}, errs...) + return nil, errors.Join(errs...) + } + if len(discoveredPatterns) == 0 { + return nil, fmt.Errorf("no compatible hostnames found in ssh config files") + } + + return discoveredPatterns, nil +} + +func GetConnectionsFromConfig() ([]string, error) { + home := wavebase.GetHomeDir() + localConfig := filepath.Join(home, ".ssh", "config") + systemConfig := filepath.Join("/etc", "ssh", "config") + sshConfigFiles := []string{localConfig, systemConfig} + ssh_config.ReloadConfigs() + + return resolveSshConfigPatterns(sshConfigFiles) } diff --git a/pkg/remote/sshclient.go b/pkg/remote/sshclient.go index 8829424c9..160fc65e2 100644 --- a/pkg/remote/sshclient.go +++ b/pkg/remote/sshclient.go @@ -709,8 +709,13 @@ type SSHOpts struct { } func (opts SSHOpts) String() string { - if opts.SSHPort == 0 { - return fmt.Sprintf("%s@%s", opts.SSHUser, opts.SSHHost) + stringRepr := "" + if opts.SSHUser != "" { + stringRepr = opts.SSHUser + "@" } - return fmt.Sprintf("%s@%s:%d", opts.SSHUser, opts.SSHHost, opts.SSHPort) + stringRepr = stringRepr + opts.SSHHost + if opts.SSHPort != 0 { + stringRepr = stringRepr + ":" + fmt.Sprint(opts.SSHPort) + } + return stringRepr } diff --git a/pkg/service/clientservice/clientservice.go b/pkg/service/clientservice/clientservice.go index ffefb2246..86332b3ec 100644 --- a/pkg/service/clientservice/clientservice.go +++ b/pkg/service/clientservice/clientservice.go @@ -10,9 +10,11 @@ import ( "time" "github.com/wavetermdev/thenextwave/pkg/eventbus" + "github.com/wavetermdev/thenextwave/pkg/remote/conncontroller" "github.com/wavetermdev/thenextwave/pkg/service/objectservice" "github.com/wavetermdev/thenextwave/pkg/util/utilfn" "github.com/wavetermdev/thenextwave/pkg/waveobj" + "github.com/wavetermdev/thenextwave/pkg/wshrpc" "github.com/wavetermdev/thenextwave/pkg/wstore" ) @@ -64,6 +66,10 @@ func (cs *ClientService) MakeWindow(ctx context.Context) (*waveobj.Window, error return wstore.CreateWindow(ctx, nil) } +func (cs *ClientService) GetAllConnStatus(ctx context.Context) ([]wshrpc.ConnStatus, error) { + return conncontroller.GetAllConnStatus(), nil +} + // moves the window to the front of the windowId stack func (cs *ClientService) FocusWindow(ctx context.Context, windowId string) error { client, err := cs.GetClientData() diff --git a/pkg/wps/wps.go b/pkg/wps/wps.go index 058235eb7..79797b2ff 100644 --- a/pkg/wps/wps.go +++ b/pkg/wps/wps.go @@ -58,9 +58,14 @@ func (b *BrokerType) GetClient() Client { return b.Client } +// if already subscribed, this will *resubscribe* with the new subscription (remove the old one, and replace with this one) func (b *BrokerType) Subscribe(subRouteId string, sub wshrpc.SubscriptionRequest) { + if sub.Event == "" { + return + } b.Lock.Lock() defer b.Lock.Unlock() + b.unsubscribe_nolock(subRouteId, sub.Event) bs := b.SubMap[sub.Event] if bs == nil { bs = &BrokerSubscription{ @@ -72,6 +77,7 @@ func (b *BrokerType) Subscribe(subRouteId string, sub wshrpc.SubscriptionRequest } if sub.AllScopes { bs.AllSubs = utilfn.AddElemToSliceUniq(bs.AllSubs, subRouteId) + return } for _, scope := range sub.Scopes { starMatch := scopeHasStarMatch(scope) @@ -114,26 +120,26 @@ func addStrToScopeMap(scopeMap map[string][]string, scope string, routeId string scopeMap[scope] = scopeSubs } -func (b *BrokerType) Unsubscribe(subRouteId string, sub wshrpc.SubscriptionRequest) { +func (b *BrokerType) Unsubscribe(subRouteId string, eventName string) { b.Lock.Lock() defer b.Lock.Unlock() - bs := b.SubMap[sub.Event] + b.unsubscribe_nolock(subRouteId, eventName) +} + +func (b *BrokerType) unsubscribe_nolock(subRouteId string, eventName string) { + bs := b.SubMap[eventName] if bs == nil { return } - if sub.AllScopes { - bs.AllSubs = utilfn.RemoveElemFromSlice(bs.AllSubs, subRouteId) + bs.AllSubs = utilfn.RemoveElemFromSlice(bs.AllSubs, subRouteId) + for scope := range bs.ScopeSubs { + removeStrFromScopeMap(bs.ScopeSubs, scope, subRouteId) } - for _, scope := range sub.Scopes { - starMatch := scopeHasStarMatch(scope) - if starMatch { - removeStrFromScopeMap(bs.StarSubs, scope, subRouteId) - } else { - removeStrFromScopeMap(bs.ScopeSubs, scope, subRouteId) - } + for scope := range bs.StarSubs { + removeStrFromScopeMap(bs.StarSubs, scope, subRouteId) } if bs.IsEmpty() { - delete(b.SubMap, sub.Event) + delete(b.SubMap, eventName) } } diff --git a/pkg/wshrpc/wshclient/wshclient.go b/pkg/wshrpc/wshclient/wshclient.go index 7eea9d7cb..1c06237c7 100644 --- a/pkg/wshrpc/wshclient/wshclient.go +++ b/pkg/wshrpc/wshclient/wshclient.go @@ -66,7 +66,7 @@ func EventSubCommand(w *wshutil.WshRpc, data wshrpc.SubscriptionRequest, opts *w } // command "eventunsub", wshserver.EventUnsubCommand -func EventUnsubCommand(w *wshutil.WshRpc, data wshrpc.SubscriptionRequest, opts *wshrpc.RpcOpts) error { +func EventUnsubCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "eventunsub", data, opts) return err } diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go index dcd7d7cb4..349467120 100644 --- a/pkg/wshrpc/wshrpctypes.go +++ b/pkg/wshrpc/wshrpctypes.go @@ -26,6 +26,7 @@ const ( const ( Event_BlockClose = "blockclose" + Event_ConnChange = "connchange" ) const ( @@ -83,7 +84,7 @@ type WshRpcInterface interface { FileReadCommand(ctx context.Context, data CommandFileData) (string, error) EventPublishCommand(ctx context.Context, data WaveEvent) error EventSubCommand(ctx context.Context, data SubscriptionRequest) error - EventUnsubCommand(ctx context.Context, data SubscriptionRequest) error + EventUnsubCommand(ctx context.Context, data string) error EventUnsubAllCommand(ctx context.Context) error StreamTestCommand(ctx context.Context) chan RespOrErrorUnion[int] StreamWaveAiCommand(ctx context.Context, request OpenAiStreamRequest) chan RespOrErrorUnion[OpenAIPacketType] @@ -324,3 +325,10 @@ type TimeSeriesData struct { Ts int64 `json:"ts"` Values map[string]float64 `json:"values"` } + +type ConnStatus struct { + Status string `json:"status"` + Connection string `json:"connection"` + Connected bool `json:"connected"` + Error string `json:"error,omitempty"` +} diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go index 392f41342..6f2deef1b 100644 --- a/pkg/wshrpc/wshserver/wshserver.go +++ b/pkg/wshrpc/wshserver/wshserver.go @@ -478,7 +478,7 @@ func (ws *WshServer) EventSubCommand(ctx context.Context, data wshrpc.Subscripti return nil } -func (ws *WshServer) EventUnsubCommand(ctx context.Context, data wshrpc.SubscriptionRequest) error { +func (ws *WshServer) EventUnsubCommand(ctx context.Context, data string) error { rpcSource := wshutil.GetRpcSourceFromContext(ctx) if rpcSource == "" { return fmt.Errorf("no rpc source set") diff --git a/pkg/wshutil/wshrouter.go b/pkg/wshutil/wshrouter.go index cd0ab4ea6..ccb849451 100644 --- a/pkg/wshutil/wshrouter.go +++ b/pkg/wshutil/wshrouter.go @@ -10,6 +10,7 @@ import ( "log" "sync" + "github.com/wavetermdev/thenextwave/pkg/wps" "github.com/wavetermdev/thenextwave/pkg/wshrpc" ) @@ -18,6 +19,8 @@ const SysRoute = "sys" // this route doesn't exist, just a placeholder for syste // this works like a network switch +// TODO maybe move the wps integration here instead of in wshserver + type routeInfo struct { RpcId string SourceRouteId string @@ -285,6 +288,9 @@ func (router *WshRouter) UnregisterRoute(routeId string) { router.Lock.Lock() defer router.Lock.Unlock() delete(router.RouteMap, routeId) + go func() { + wps.Broker.UnsubscribeAll(routeId) + }() } // this may return nil (returns default only for empty routeId)