diff --git a/frontend/app/block/block.less b/frontend/app/block/block.less
index 8d0d16dc3..8bab8ccfb 100644
--- a/frontend/app/block/block.less
+++ b/frontend/app/block/block.less
@@ -137,6 +137,34 @@
}
}
+ .connection-button {
+ display: flex;
+ align-items: center;
+ gap: 2px;
+ flex-wrap: nowrap;
+ overflow: hidden;
+ text-overflow: ellipsis;
+ min-width: 0;
+ font-weight: 400;
+ color: var(--main-text-color);
+ border-radius: 2px;
+ padding-right: 6px;
+
+ &:hover {
+ background-color: var(--highlight-bg-color);
+ }
+
+ .connection-icon-box {
+ flex: 1 1 auto;
+ overflow: hidden;
+ }
+
+ .connection-name {
+ flex: 1 100 auto;
+ overflow: hidden;
+ }
+ }
+
.block-frame-textelems-wrapper {
display: flex;
flex: 1 100 auto;
diff --git a/frontend/app/block/blockframe.tsx b/frontend/app/block/blockframe.tsx
index f84323e64..36e2e3735 100644
--- a/frontend/app/block/blockframe.tsx
+++ b/frontend/app/block/blockframe.tsx
@@ -1,7 +1,14 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
-import { blockViewToIcon, blockViewToName, getBlockHeaderIcon, IconButton, Input } from "@/app/block/blockutil";
+import {
+ blockViewToIcon,
+ blockViewToName,
+ ConnectionButton,
+ getBlockHeaderIcon,
+ IconButton,
+ Input,
+} from "@/app/block/blockutil";
import { Button } from "@/app/element/button";
import { ContextMenuModel } from "@/app/store/contextmenu";
import { atoms, globalStore, useBlockAtom, WOS } from "@/app/store/global";
@@ -170,6 +177,8 @@ const HeaderTextElem = React.memo(({ elem }: { elem: HeaderElem }) => {
{elem.text}
);
+ } else if (elem.elemtype == "connectionbutton") {
+ return ;
} else if (elem.elemtype == "div") {
return (
{
+ const buttonRef = React.useRef
(null);
+ return (
+
+
+ {typeof decl.icon === "string" ? (
+
+ ) : (
+ decl.icon
+ )}
+
+
+
{decl.text}
+
+ );
+});
+
export const Input = React.memo(({ decl, className }: { decl: HeaderInput; className: string }) => {
const { value, ref, isDisabled, onChange, onKeyDown, onFocus, onBlur } = decl;
return (
diff --git a/frontend/app/modals/typeaheadmodal.tsx b/frontend/app/modals/typeaheadmodal.tsx
index 8af77ca15..27d7a432f 100644
--- a/frontend/app/modals/typeaheadmodal.tsx
+++ b/frontend/app/modals/typeaheadmodal.tsx
@@ -104,8 +104,9 @@ interface TypeAheadModalProps {
suggestions?: SuggestionType[];
label?: string;
className?: string;
- onSelect?: (_: string) => void;
+ value?: string;
onChange?: (_: string) => void;
+ onSelect?: (_: string) => void;
onClickBackdrop?: () => void;
onKeyDown?: (_) => void;
}
@@ -115,6 +116,7 @@ const TypeAheadModal = ({
suggestions = dummy,
label,
anchor,
+ value,
onChange,
onSelect,
onClickBackdrop,
@@ -167,6 +169,7 @@ const TypeAheadModal = ({
();
const Counters = new Map();
+const ConnStatusMap = new Map>();
type GlobalInitOptions = {
platform: NodeJS.Platform;
@@ -143,9 +145,16 @@ function initGlobalAtoms(initOpts: GlobalInitOptions) {
};
}
+type WaveEventSubjectContainer = {
+ id: string;
+ handler: (event: WaveEvent) => void;
+ scope: string;
+};
+
// key is "eventType" or "eventType|oref"
const eventSubjects = new Map>();
const fileSubjects = new Map>();
+const waveEventSubjects = new Map();
function getSubjectInternal(subjectKey: string): SubjectWithRef {
let subject = eventSubjects.get(subjectKey);
@@ -173,6 +182,61 @@ function getEventORefSubject(eventType: string, oref: string): SubjectWithRef void): () => void {
+ if (handler == null) {
+ return;
+ }
+ const id = crypto.randomUUID();
+ const subject = new rxjs.Subject() as any;
+ const scont: WaveEventSubjectContainer = { id, scope, handler };
+ let subjects = waveEventSubjects.get(eventType);
+ if (subjects == null) {
+ subjects = [];
+ waveEventSubjects.set(eventType, subjects);
+ }
+ subjects.push(scont);
+ updateWaveEventSub(eventType);
+ return () => waveEventUnsubscribe(eventType, id);
+}
+
+function waveEventUnsubscribe(eventType: string, id: string) {
+ let subjects = waveEventSubjects.get(eventType);
+ if (subjects == null) {
+ return;
+ }
+ const idx = subjects.findIndex((s) => s.id === id);
+ if (idx === -1) {
+ return;
+ }
+ subjects.splice(idx, 1);
+ if (subjects.length === 0) {
+ waveEventSubjects.delete(eventType);
+ }
+ updateWaveEventSub(eventType);
+}
+
function getFileSubject(zoneId: string, fileName: string): SubjectWithRef {
const subjectKey = zoneId + "|" + fileName;
let subject = fileSubjects.get(subjectKey);
@@ -251,6 +315,25 @@ function useBlockDataLoaded(blockId: string): boolean {
let globalWS: WSControl = null;
+function handleWaveEvent(event: WaveEvent) {
+ const subjects = waveEventSubjects.get(event.event);
+ if (subjects == null) {
+ return;
+ }
+ for (const scont of subjects) {
+ if (util.isBlank(scont.scope)) {
+ scont.handler(event);
+ continue;
+ }
+ if (event.scopes == null) {
+ continue;
+ }
+ if (event.scopes.includes(scont.scope)) {
+ scont.handler(event);
+ }
+ }
+}
+
function handleWSEventMessage(msg: WSEventType) {
if (msg.eventtype == null) {
console.log("unsupported event", msg);
@@ -275,7 +358,7 @@ function handleWSEventMessage(msg: WSEventType) {
}
if (msg.eventtype == "rpc") {
const rpcMsg: RpcMessage = msg.data;
- handleIncomingRpcMessage(rpcMsg);
+ handleIncomingRpcMessage(rpcMsg, handleWaveEvent);
return;
}
if (msg.eventtype == "layoutaction") {
@@ -496,6 +579,38 @@ function countersPrint() {
console.log(outStr);
}
+async function loadConnStatus() {
+ const connStatusArr = await services.ClientService.GetAllConnStatus();
+ if (connStatusArr == null) {
+ return;
+ }
+ for (const connStatus of connStatusArr) {
+ const curAtom = getConnStatusAtom(connStatus.connection);
+ globalStore.set(curAtom, connStatus);
+ }
+}
+
+function subscribeToConnEvents() {
+ waveEventSubscribe("connchange", null, (event: WaveEvent) => {
+ const connStatus = event.data as ConnStatus;
+ if (connStatus == null || util.isBlank(connStatus.connection)) {
+ return;
+ }
+ let curAtom = ConnStatusMap.get(connStatus.connection);
+ globalStore.set(curAtom, connStatus);
+ });
+}
+
+function getConnStatusAtom(conn: string): jotai.PrimitiveAtom {
+ let rtn = ConnStatusMap.get(conn);
+ if (rtn == null) {
+ const connStatus: ConnStatus = { connection: conn, connected: false, error: null };
+ rtn = jotai.atom(connStatus);
+ ConnStatusMap.set(conn, rtn);
+ }
+ return rtn;
+}
+
export {
atoms,
counterInc,
@@ -504,6 +619,7 @@ export {
createBlock,
fetchWaveFile,
getApi,
+ getConnStatusAtom,
getEventORefSubject,
getEventSubject,
getFileSubject,
@@ -514,16 +630,20 @@ export {
initGlobal,
initWS,
isDev,
+ loadConnStatus,
openLink,
PLATFORM,
registerViewModel,
sendWSCommand,
setBlockFocus,
setPlatform,
+ subscribeToConnEvents,
unregisterViewModel,
useBlockAtom,
useBlockCache,
useBlockDataLoaded,
useSettingsAtom,
+ waveEventSubscribe,
+ waveEventUnsubscribe,
WOS,
};
diff --git a/frontend/app/store/services.ts b/frontend/app/store/services.ts
index c3c63653b..8e58090c4 100644
--- a/frontend/app/store/services.ts
+++ b/frontend/app/store/services.ts
@@ -32,6 +32,9 @@ class ClientServiceType {
FocusWindow(arg2: string): Promise {
return WOS.callBackendService("client", "FocusWindow", Array.from(arguments))
}
+ GetAllConnStatus(): Promise {
+ return WOS.callBackendService("client", "GetAllConnStatus", Array.from(arguments))
+ }
GetClientData(): Promise {
return WOS.callBackendService("client", "GetClientData", Array.from(arguments))
}
diff --git a/frontend/app/store/wshrpc.ts b/frontend/app/store/wshrpc.ts
index bdcac511f..91ea2b375 100644
--- a/frontend/app/store/wshrpc.ts
+++ b/frontend/app/store/wshrpc.ts
@@ -10,7 +10,7 @@ type RpcEntry = {
msgFn: (msg: RpcMessage) => void;
};
-let openRpcs = new Map();
+const openRpcs = new Map();
async function* rpcResponseGenerator(
command: string,
@@ -86,10 +86,23 @@ function sendRpcCommand(msg: RpcMessage): AsyncGenerator void) {
const isRequest = msg.command != null || msg.reqid != null;
if (isRequest) {
- console.log("rpc request not supported", msg);
+ // handle events
+ if (msg.command == "eventrecv") {
+ if (eventHandlerFn != null) {
+ eventHandlerFn(msg.data);
+ }
+ return;
+ }
+
+ console.log("rpc command not supported", msg);
return;
}
if (msg.resid == null) {
@@ -122,4 +135,4 @@ if (globalThis.window != null) {
globalThis["consumeGenerator"] = consumeGenerator;
}
-export { handleIncomingRpcMessage, sendRpcCommand };
+export { handleIncomingRpcMessage, sendRawRpcMessage, sendRpcCommand };
diff --git a/frontend/app/store/wshserver.ts b/frontend/app/store/wshserver.ts
index df43efb76..cc74e3b6a 100644
--- a/frontend/app/store/wshserver.ts
+++ b/frontend/app/store/wshserver.ts
@@ -53,7 +53,7 @@ class WshServerType {
}
// command "eventunsub" [call]
- EventUnsubCommand(data: SubscriptionRequest, opts?: RpcOpts): Promise {
+ EventUnsubCommand(data: string, opts?: RpcOpts): Promise {
return WOS.wshServerRpcHelper_call("eventunsub", data, opts);
}
diff --git a/frontend/app/view/term/term.less b/frontend/app/view/term/term.less
index 2559d6e11..16dfbef5b 100644
--- a/frontend/app/view/term/term.less
+++ b/frontend/app/view/term/term.less
@@ -1,6 +1,16 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
+.connection-btn {
+ min-height: 0;
+ overflow: hidden;
+ line-height: 1;
+ display: flex;
+ background-color: orangered;
+ justify-content: flex-start;
+ width: 200px;
+}
+
.view-term {
display: flex;
flex-direction: column;
diff --git a/frontend/app/view/term/term.tsx b/frontend/app/view/term/term.tsx
index b39806810..dca1d0a19 100644
--- a/frontend/app/view/term/term.tsx
+++ b/frontend/app/view/term/term.tsx
@@ -1,9 +1,18 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
+import { TypeAheadModal } from "@/app/modals/typeaheadmodal";
import { WshServer } from "@/app/store/wshserver";
import { VDomView } from "@/app/view/term/vdom";
-import { WOS, atoms, getEventORefSubject, globalStore, useBlockAtom, useSettingsAtom } from "@/store/global";
+import {
+ WOS,
+ atoms,
+ getConnStatusAtom,
+ getEventORefSubject,
+ globalStore,
+ useBlockAtom,
+ useSettingsAtom,
+} from "@/store/global";
import * as services from "@/store/services";
import * as keyutil from "@/util/keyutil";
import * as util from "@/util/util";
@@ -109,13 +118,16 @@ function setBlockFocus(blockId: string) {
class TermViewModel {
viewType: string;
+ connected: boolean;
termRef: React.RefObject;
blockAtom: jotai.Atom;
termMode: jotai.Atom;
+ connectedAtom: jotai.Atom;
+ typeahead: boolean;
htmlElemFocusRef: React.RefObject;
blockId: string;
viewIcon: jotai.Atom;
- viewText: jotai.Atom;
+ viewText: jotai.Atom;
viewName: jotai.Atom;
blockBg: jotai.Atom;
@@ -123,6 +135,14 @@ class TermViewModel {
this.viewType = "term";
this.blockId = blockId;
this.blockAtom = WOS.getWaveObjectAtom(`block:${blockId}`);
+ this.connectedAtom = jotai.atom((get) => {
+ const connectionName = get(this.blockAtom).meta?.connection || "";
+ if (connectionName == "") {
+ return true;
+ }
+ const status = get(getConnStatusAtom(connectionName));
+ return status.connected;
+ });
this.termMode = jotai.atom((get) => {
const blockData = get(this.blockAtom);
return blockData?.meta?.["term:mode"] ?? "term";
@@ -139,7 +159,30 @@ class TermViewModel {
});
this.viewText = jotai.atom((get) => {
const blockData = get(this.blockAtom);
- return blockData?.meta?.title ?? "";
+ const titleText: HeaderText = { elemtype: "text", text: blockData?.meta?.title ?? "" };
+ const typeAhead = get(atoms.typeAheadModalAtom);
+ const connectionName = blockData?.meta?.connection || "";
+ const isConnected = get(this.connectedAtom);
+ let iconColor: string;
+ if (connectionName != "") {
+ iconColor = "#53b4ea";
+ } else {
+ iconColor = "var(--grey-text-color)";
+ }
+ const connButton: ConnectionButton = {
+ elemtype: "connectionbutton",
+ icon: "arrow-right-arrow-left",
+ iconColor: iconColor,
+ text: connectionName,
+ connected: isConnected,
+ onClick: () => {
+ globalStore.set(atoms.typeAheadModalAtom, {
+ ...(typeAhead as TypeAheadModalType),
+ [blockId]: true,
+ });
+ },
+ };
+ return [connButton, titleText] as HeaderElem[];
});
this.blockBg = jotai.atom((get) => {
const blockData = get(this.blockAtom);
@@ -152,6 +195,10 @@ class TermViewModel {
});
}
+ resetConnection() {
+ WshServer.ControllerRestartCommand({ blockid: this.blockId });
+ }
+
giveFocus(): boolean {
let termMode = globalStore.get(this.termMode);
if (termMode == "term") {
@@ -196,6 +243,9 @@ interface TerminalViewProps {
}
const TerminalView = ({ blockId, model }: TerminalViewProps) => {
+ const typeAhead = jotai.useAtomValue(atoms.typeAheadModalAtom);
+ const viewRef = React.createRef();
+ const [connSelected, setConnSelected] = React.useState("");
const connectElemRef = React.useRef(null);
const termRef = React.useRef(null);
model.termRef = termRef;
@@ -371,11 +421,57 @@ const TerminalView = ({ blockId, model }: TerminalViewProps) => {
}
}
+ const changeConnection = React.useCallback(
+ async (connName: string) => {
+ await WshServer.SetMetaCommand({ oref: WOS.makeORef("block", blockId), meta: { connection: connName } });
+ await WshServer.ControllerRestartCommand({ blockid: blockId });
+ },
+ [blockId]
+ );
+
+ const handleTypeAheadKeyDown = React.useCallback(
+ (waveEvent: WaveKeyboardEvent): boolean => {
+ if (keyutil.checkKeyPressed(waveEvent, "Enter")) {
+ changeConnection(connSelected);
+ globalStore.set(atoms.typeAheadModalAtom, {
+ ...(typeAhead as TypeAheadModalType),
+ [blockId]: false,
+ });
+ setConnSelected("");
+ return true;
+ }
+ if (keyutil.checkKeyPressed(waveEvent, "Escape")) {
+ globalStore.set(atoms.typeAheadModalAtom, {
+ ...(typeAhead as TypeAheadModalType),
+ [blockId]: false,
+ });
+ setConnSelected("");
+ model.giveFocus();
+ return true;
+ }
+ },
+ [typeAhead, model, blockId, connSelected]
+ );
+
return (
+ {typeAhead[blockId] && (
+
{
+ changeConnection(selected);
+ }}
+ onKeyDown={(e) => keyutil.keydownWrapper(handleTypeAheadKeyDown)(e)}
+ onChange={(current: string) => setConnSelected(current)}
+ value={connSelected}
+ label="Switch Connection"
+ />
+ )}
diff --git a/frontend/types/custom.d.ts b/frontend/types/custom.d.ts
index 0917c0e81..1f2ac7a83 100644
--- a/frontend/types/custom.d.ts
+++ b/frontend/types/custom.d.ts
@@ -140,7 +140,7 @@ declare global {
type SubjectWithRef = rxjs.Subject & { refCount: number; release: () => void };
- type HeaderElem = HeaderIconButton | HeaderText | HeaderInput | HeaderDiv | HeaderTextButton;
+ type HeaderElem = HeaderIconButton | HeaderText | HeaderInput | HeaderDiv | HeaderTextButton | ConnectionButton;
type HeaderIconButton = {
elemtype: "iconbutton";
@@ -181,6 +181,16 @@ declare global {
children: HeaderElem[];
onMouseOver?: (e: React.MouseEvent) => void;
onMouseOut?: (e: React.MouseEvent) => void;
+ onClick?: (e: React.MouseEvent) => void;
+ };
+
+ type ConnectionButton = {
+ elemtype: "connectionbutton";
+ icon: string;
+ text: string;
+ iconColor: string;
+ onClick?: (e: React.MouseEvent) => void;
+ connected: boolean;
};
interface ViewModel {
diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts
index df52f7872..b5419adfe 100644
--- a/frontend/types/gotypes.d.ts
+++ b/frontend/types/gotypes.d.ts
@@ -158,6 +158,14 @@ declare global {
meta: MetaType;
};
+ // wshrpc.ConnStatus
+ type ConnStatus = {
+ status: string;
+ connection: string;
+ connected: boolean;
+ error?: string;
+ };
+
// wshrpc.CpuDataRequest
type CpuDataRequest = {
id: string;
diff --git a/frontend/wave.ts b/frontend/wave.ts
index 8f800cc9c..e87468b58 100644
--- a/frontend/wave.ts
+++ b/frontend/wave.ts
@@ -2,7 +2,18 @@
// SPDX-License-Identifier: Apache-2.0
import { WshServer } from "@/app/store/wshserver";
-import { atoms, countersClear, countersPrint, getApi, globalStore, globalWS, initGlobal, initWS } from "@/store/global";
+import {
+ atoms,
+ countersClear,
+ countersPrint,
+ getApi,
+ globalStore,
+ globalWS,
+ initGlobal,
+ initWS,
+ loadConnStatus,
+ subscribeToConnEvents,
+} from "@/store/global";
import * as services from "@/store/services";
import * as WOS from "@/store/wos";
import * as keyutil from "@/util/keyutil";
@@ -44,6 +55,8 @@ document.addEventListener("DOMContentLoaded", async () => {
const initialTab = await WOS.loadAndPinWaveObject(WOS.makeORef("tab", waveWindow.activetabid));
await WOS.loadAndPinWaveObject(WOS.makeORef("layout", initialTab.layoutstate));
initWS();
+ await loadConnStatus();
+ subscribeToConnEvents();
const settings = await services.FileService.GetSettingsConfig();
console.log("settings", settings);
globalStore.set(atoms.settingsConfigAtom, settings);
diff --git a/pkg/blockcontroller/blockcontroller.go b/pkg/blockcontroller/blockcontroller.go
index 6e2bc222b..ee2c04d97 100644
--- a/pkg/blockcontroller/blockcontroller.go
+++ b/pkg/blockcontroller/blockcontroller.go
@@ -278,12 +278,13 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
if err != nil {
return err
}
- conn, err := conncontroller.GetConn(credentialCtx, opts)
- if err != nil {
- return err
+ conn := conncontroller.GetConn(credentialCtx, opts, true)
+ connStatus := conn.DeriveConnStatus()
+ if connStatus.Error != "" {
+ return fmt.Errorf("error connecting to remote: %s", connStatus.Error)
}
if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) {
- jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: conn.Opts.String()}, conn.SockName)
+ jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: conn.Opts.String()}, conn.GetDomainSocketName())
if err != nil {
return fmt.Errorf("error making jwt token: %w", err)
}
@@ -385,10 +386,11 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
log.Printf("[shellproc] shell process wait loop done\n")
}()
waitErr := shellProc.Cmd.Wait()
- shellProc.SetWaitErrorAndSignalDone(waitErr)
exitCode := shellexec.ExitCodeFromWaitErr(waitErr)
termMsg := fmt.Sprintf("\r\nprocess finished with exit code = %d\r\n\r\n", exitCode)
+ //HandleAppendBlockFile(bc.BlockId, BlockFile_Term, []byte("\r\n"))
HandleAppendBlockFile(bc.BlockId, BlockFile_Term, []byte(termMsg))
+ shellProc.SetWaitErrorAndSignalDone(waitErr)
}()
return nil
}
@@ -464,8 +466,21 @@ func (bc *BlockController) SendInput(inputUnion *BlockInputUnion) error {
}
func (bc *BlockController) RestartController() error {
- // TODO: if shell command is already running
- // we probably want to kill it off, wait, and then restart it
+
+ // kill the command if it's running
+ bc.Lock.Lock()
+ if bc.ShellProc != nil {
+ bc.ShellProc.Close()
+ }
+ bc.Lock.Unlock()
+
+ // wait for process to complete
+ if bc.ShellProc != nil {
+ doneCh := bc.ShellProc.DoneCh
+ <-doneCh
+ }
+
+ // restart controller
bdata, err := wstore.DBMustGet[*waveobj.Block](context.Background(), bc.BlockId)
if err != nil {
return fmt.Errorf("error getting block: %w", err)
diff --git a/pkg/remote/conncontroller/conncontroller.go b/pkg/remote/conncontroller/conncontroller.go
index 90b95b748..870dd4030 100644
--- a/pkg/remote/conncontroller/conncontroller.go
+++ b/pkg/remote/conncontroller/conncontroller.go
@@ -5,36 +5,111 @@ package conncontroller
import (
"context"
+ "errors"
"fmt"
"io"
+ "io/fs"
"log"
"net"
+ "os"
+ "path/filepath"
"strings"
"sync"
+ "sync/atomic"
+ "time"
+ "github.com/kevinburke/ssh_config"
"github.com/wavetermdev/thenextwave/pkg/remote"
"github.com/wavetermdev/thenextwave/pkg/userinput"
"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
+ "github.com/wavetermdev/thenextwave/pkg/wps"
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
"github.com/wavetermdev/thenextwave/pkg/wshutil"
"golang.org/x/crypto/ssh"
)
+const (
+ Status_Init = "init"
+ Status_Connecting = "connecting"
+ Status_Connected = "connected"
+ Status_Disconnected = "disconnected"
+ Status_Error = "error"
+)
+
var globalLock = &sync.Mutex{}
var clientControllerMap = make(map[remote.SSHOpts]*SSHConn)
type SSHConn struct {
Lock *sync.Mutex
+ Status string
Opts *remote.SSHOpts
Client *ssh.Client
SockName string
DomainSockListener net.Listener
ConnController *ssh.Session
+ Error string
+ HasWaiter *atomic.Bool
+}
+
+func GetAllConnStatus() []wshrpc.ConnStatus {
+ globalLock.Lock()
+ defer globalLock.Unlock()
+
+ var connStatuses []wshrpc.ConnStatus
+ for _, conn := range clientControllerMap {
+ connStatuses = append(connStatuses, conn.DeriveConnStatus())
+ }
+ return connStatuses
+}
+
+func (conn *SSHConn) DeriveConnStatus() wshrpc.ConnStatus {
+ conn.Lock.Lock()
+ defer conn.Lock.Unlock()
+ return wshrpc.ConnStatus{
+ Status: conn.Status,
+ Connection: conn.Opts.String(),
+ Connected: conn.Client != nil,
+ Error: conn.Error,
+ }
+}
+
+func (conn *SSHConn) FireConnChangeEvent() {
+ status := conn.DeriveConnStatus()
+ event := wshrpc.WaveEvent{
+ Event: wshrpc.Event_ConnChange,
+ Scopes: []string{
+ fmt.Sprintf("connection:%s", conn.GetName()),
+ },
+ Data: status,
+ }
+ log.Printf("sending event: %+#v", event)
+ wps.Broker.Publish(event)
}
func (conn *SSHConn) Close() error {
+ defer conn.FireConnChangeEvent()
+ conn.WithLock(func() {
+ if conn.Status == Status_Connected || conn.Status == Status_Connecting {
+ // if status is init, disconnected, or error don't change it
+ conn.Status = Status_Disconnected
+ }
+ conn.close_nolock()
+ })
+ // we must wait for the waiter to complete
+ startTime := time.Now()
+ for conn.HasWaiter.Load() {
+ time.Sleep(10 * time.Millisecond)
+ if time.Since(startTime) > 2*time.Second {
+ return fmt.Errorf("timeout waiting for waiter to complete")
+ }
+ }
+ return nil
+}
+
+func (conn *SSHConn) close_nolock() {
+ // does not set status (that should happen at another level)
if conn.DomainSockListener != nil {
conn.DomainSockListener.Close()
conn.DomainSockListener = nil
@@ -43,75 +118,113 @@ func (conn *SSHConn) Close() error {
conn.ConnController.Close()
conn.ConnController = nil
}
- err := conn.Client.Close()
- conn.Client = nil
- return err
+ if conn.Client != nil {
+ conn.Client.Close()
+ conn.Client = nil
+ }
+}
+
+func (conn *SSHConn) GetDomainSocketName() string {
+ conn.Lock.Lock()
+ defer conn.Lock.Unlock()
+ return conn.SockName
+}
+
+func (conn *SSHConn) GetStatus() string {
+ conn.Lock.Lock()
+ defer conn.Lock.Unlock()
+ return conn.Status
+}
+
+func (conn *SSHConn) GetName() string {
+ // no lock required because opts is immutable
+ return conn.Opts.String()
}
func (conn *SSHConn) OpenDomainSocketListener() error {
- if conn.DomainSockListener != nil {
- return nil
+ var allowed bool
+ conn.WithLock(func() {
+ if conn.Status != Status_Connecting {
+ allowed = false
+ } else {
+ allowed = true
+ }
+ })
+ if !allowed {
+ return fmt.Errorf("cannot open domain socket for %q when status is %q", conn.GetName(), conn.GetStatus())
}
+ client := conn.GetClient()
randStr, err := utilfn.RandomHexString(16) // 64-bits of randomness
if err != nil {
return fmt.Errorf("error generating random string: %w", err)
}
sockName := fmt.Sprintf("/tmp/waveterm-%s.sock", randStr)
- log.Printf("remote domain socket %s %q\n", conn.Opts.String(), sockName)
- listener, err := conn.Client.ListenUnix(sockName)
+ log.Printf("remote domain socket %s %q\n", conn.GetName(), sockName)
+ listener, err := client.ListenUnix(sockName)
if err != nil {
return fmt.Errorf("unable to request connection domain socket: %v", err)
}
- conn.SockName = sockName
- conn.DomainSockListener = listener
+ conn.WithLock(func() {
+ conn.SockName = sockName
+ conn.DomainSockListener = listener
+ })
go func() {
- defer func() {
- conn.Lock.Lock()
- defer conn.Lock.Unlock()
+ defer conn.WithLock(func() {
conn.DomainSockListener = nil
- }()
+ conn.SockName = ""
+ })
wshutil.RunWshRpcOverListener(listener)
}()
return nil
}
func (conn *SSHConn) StartConnServer() error {
- conn.Lock.Lock()
- defer conn.Lock.Unlock()
- if conn.ConnController != nil {
- return nil
+ var allowed bool
+ conn.WithLock(func() {
+ if conn.Status != Status_Connecting {
+ allowed = false
+ } else {
+ allowed = true
+ }
+ })
+ if !allowed {
+ return fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus())
}
- wshPath := remote.GetWshPath(conn.Client)
+ client := conn.GetClient()
+ wshPath := remote.GetWshPath(client)
rpcCtx := wshrpc.RpcContext{
ClientType: wshrpc.ClientType_ConnServer,
- Conn: conn.Opts.String(),
+ Conn: conn.GetName(),
}
- jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, conn.SockName)
+ sockName := conn.GetDomainSocketName()
+ jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName)
if err != nil {
return fmt.Errorf("unable to create jwt token for conn controller: %w", err)
}
- sshSession, err := conn.Client.NewSession()
+ sshSession, err := client.NewSession()
if err != nil {
return fmt.Errorf("unable to create ssh session for conn controller: %w", err)
}
pipeRead, pipeWrite := io.Pipe()
sshSession.Stdout = pipeWrite
sshSession.Stderr = pipeWrite
- conn.ConnController = sshSession
cmdStr := fmt.Sprintf("%s=\"%s\" %s connserver", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
log.Printf("starting conn controller: %s\n", cmdStr)
err = sshSession.Start(cmdStr)
if err != nil {
return fmt.Errorf("unable to start conn controller: %w", err)
}
+ conn.WithLock(func() {
+ conn.ConnController = sshSession
+ })
// service the I/O
go func() {
// wait for termination, clear the controller
+ defer conn.WithLock(func() {
+ conn.ConnController = nil
+ })
waitErr := sshSession.Wait()
- log.Printf("conn controller (%q) terminated: %v", conn.Opts.String(), waitErr)
- conn.Lock.Lock()
- defer conn.Lock.Unlock()
- conn.ConnController = nil
+ log.Printf("conn controller (%q) terminated: %v", conn.GetName(), waitErr)
}()
go func() {
readErr := wshutil.StreamToLines(pipeRead, func(line []byte) {
@@ -119,23 +232,27 @@ func (conn *SSHConn) StartConnServer() error {
if !strings.HasSuffix(lineStr, "\n") {
lineStr += "\n"
}
- log.Printf("[conncontroller:%s:output] %s", conn.Opts.String(), lineStr)
+ log.Printf("[conncontroller:%s:output] %s", conn.GetName(), lineStr)
})
if readErr != nil && readErr != io.EOF {
- log.Printf("[conncontroller:%s] error reading output: %v\n", conn.Opts.String(), readErr)
+ log.Printf("[conncontroller:%s] error reading output: %v\n", conn.GetName(), readErr)
}
}()
return nil
}
func (conn *SSHConn) checkAndInstallWsh(ctx context.Context) error {
- client := conn.Client
+ client := conn.GetClient()
+ if client == nil {
+ return fmt.Errorf("client is nil")
+ }
// check that correct wsh extensions are installed
expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
clientVersion, err := remote.GetWshVersion(client)
if err == nil && clientVersion == expectedVersion {
return nil
}
+ // TODO add some progress to SSHConn about install status
var queryText string
var title string
if err != nil {
@@ -170,56 +287,189 @@ func (conn *SSHConn) checkAndInstallWsh(ctx context.Context) error {
if err != nil {
return err
}
- log.Printf("successfully installed wsh on %s\n", conn.Opts.String())
+ log.Printf("successfully installed wsh on %s\n", conn.GetName())
return nil
}
-func GetConn(ctx context.Context, opts *remote.SSHOpts) (*SSHConn, error) {
- globalLock.Lock()
- defer globalLock.Unlock()
+func (conn *SSHConn) GetClient() *ssh.Client {
+ conn.Lock.Lock()
+ defer conn.Lock.Unlock()
+ return conn.Client
+}
- // attempt to retrieve if already opened
- conn, ok := clientControllerMap[*opts]
- if ok {
- return conn, nil
- }
-
- client, err := remote.ConnectToClient(ctx, opts) //todo specify or remove opts
+func (conn *SSHConn) Reconnect(ctx context.Context) error {
+ err := conn.Close()
if err != nil {
- return nil, err
+ return err
}
- conn = &SSHConn{Lock: &sync.Mutex{}, Opts: opts, Client: client}
+ return conn.Connect(ctx)
+}
+
+// does not return an error since that error is stored inside of SSHConn
+func (conn *SSHConn) Connect(ctx context.Context) error {
+ var connectAllowed bool
+ conn.WithLock(func() {
+ if conn.Status == Status_Connecting || conn.Status == Status_Connected {
+ connectAllowed = false
+ } else {
+ conn.Status = Status_Connecting
+ conn.Error = ""
+ connectAllowed = true
+ }
+ })
+ if !connectAllowed {
+ return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus())
+ }
+ conn.FireConnChangeEvent()
+ err := conn.connectInternal(ctx)
+ conn.WithLock(func() {
+ if err != nil {
+ conn.Status = Status_Error
+ conn.Error = err.Error()
+ conn.close_nolock()
+ } else {
+ conn.Status = Status_Connected
+ }
+ })
+ conn.FireConnChangeEvent()
+ return err
+}
+
+func (conn *SSHConn) WithLock(fn func()) {
+ conn.Lock.Lock()
+ defer conn.Lock.Unlock()
+ fn()
+}
+
+func (conn *SSHConn) connectInternal(ctx context.Context) error {
+ client, err := remote.ConnectToClient(ctx, conn.Opts) //todo specify or remove opts
+ if err != nil {
+ return err
+ }
+ conn.WithLock(func() {
+ conn.Client = client
+ })
err = conn.OpenDomainSocketListener()
if err != nil {
- conn.Close()
- return nil, err
+ return err
}
-
installErr := conn.checkAndInstallWsh(ctx)
if installErr != nil {
- conn.Close()
- return nil, fmt.Errorf("conncontroller %s wsh install error: %v", conn.Opts.String(), installErr)
+ return fmt.Errorf("conncontroller %s wsh install error: %v", conn.GetName(), installErr)
}
-
csErr := conn.StartConnServer()
if csErr != nil {
- conn.Close()
- return nil, fmt.Errorf("conncontroller %s start wsh connserver error: %v", conn.Opts.String(), csErr)
+ return fmt.Errorf("conncontroller %s start wsh connserver error: %v", conn.GetName(), csErr)
}
+ conn.HasWaiter.Store(true)
+ go conn.waitForDisconnect()
+ return nil
+}
- // save successful connection to map
- clientControllerMap[*opts] = conn
+func (conn *SSHConn) waitForDisconnect() {
+ defer conn.FireConnChangeEvent()
+ defer conn.HasWaiter.Store(false)
+ client := conn.GetClient()
+ if client == nil {
+ return
+ }
+ err := client.Wait()
+ conn.WithLock(func() {
+ if err != nil {
+ if conn.Status != Status_Disconnected {
+ // don't set the error if our status is disconnected (because this error was caused by an explicit close)
+ conn.Status = Status_Error
+ conn.Error = err.Error()
+ }
+ } else {
+ // not sure if this is possible, because I think Wait() always returns an error (although that's not in the docs)
+ conn.Status = Status_Disconnected
+ }
+ conn.close_nolock()
+ })
+}
- return conn, nil
+func getConnInternal(opts *remote.SSHOpts) *SSHConn {
+ globalLock.Lock()
+ defer globalLock.Unlock()
+ rtn := clientControllerMap[*opts]
+ if rtn == nil {
+ rtn = &SSHConn{Lock: &sync.Mutex{}, Status: Status_Init, Opts: opts, HasWaiter: &atomic.Bool{}}
+ clientControllerMap[*opts] = rtn
+ }
+ return rtn
+}
+
+func GetConn(ctx context.Context, opts *remote.SSHOpts, shouldConnect bool) *SSHConn {
+ conn := getConnInternal(opts)
+ if conn.Client == nil && shouldConnect {
+ conn.Connect(ctx)
+ }
+ return conn
}
func DisconnectClient(opts *remote.SSHOpts) error {
- globalLock.Lock()
- defer globalLock.Unlock()
-
- client, ok := clientControllerMap[*opts]
- if ok {
- return client.Close()
+ conn := getConnInternal(opts)
+ if conn == nil {
+ return fmt.Errorf("client %q not found", opts.String())
}
- return fmt.Errorf("client %v not found", opts)
+ err := conn.Close()
+ return err
+}
+
+func resolveSshConfigPatterns(configFiles []string) ([]string, error) {
+ // using two separate containers to track order and have O(1) lookups
+ // since go does not have an ordered map primitive
+ var discoveredPatterns []string
+ alreadyUsed := make(map[string]bool)
+ alreadyUsed[""] = true // this excludes the empty string from potential alias
+ var openedFiles []fs.File
+
+ defer func() {
+ for _, openedFile := range openedFiles {
+ openedFile.Close()
+ }
+ }()
+
+ var errs []error
+ for _, configFile := range configFiles {
+ fd, openErr := os.Open(configFile)
+ openedFiles = append(openedFiles, fd)
+ if fd == nil {
+ errs = append(errs, openErr)
+ continue
+ }
+
+ cfg, _ := ssh_config.Decode(fd)
+ for _, host := range cfg.Hosts {
+ // for each host, find the first good alias
+ for _, hostPattern := range host.Patterns {
+ hostPatternStr := hostPattern.String()
+ if !strings.Contains(hostPatternStr, "*") || alreadyUsed[hostPatternStr] {
+ discoveredPatterns = append(discoveredPatterns, hostPatternStr)
+ alreadyUsed[hostPatternStr] = true
+ break
+ }
+ }
+ }
+ }
+ if len(errs) == len(configFiles) {
+ errs = append([]error{fmt.Errorf("no ssh config files could be opened:\n")}, errs...)
+ return nil, errors.Join(errs...)
+ }
+ if len(discoveredPatterns) == 0 {
+ return nil, fmt.Errorf("no compatible hostnames found in ssh config files")
+ }
+
+ return discoveredPatterns, nil
+}
+
+func GetConnectionsFromConfig() ([]string, error) {
+ home := wavebase.GetHomeDir()
+ localConfig := filepath.Join(home, ".ssh", "config")
+ systemConfig := filepath.Join("/etc", "ssh", "config")
+ sshConfigFiles := []string{localConfig, systemConfig}
+ ssh_config.ReloadConfigs()
+
+ return resolveSshConfigPatterns(sshConfigFiles)
}
diff --git a/pkg/remote/sshclient.go b/pkg/remote/sshclient.go
index 8829424c9..160fc65e2 100644
--- a/pkg/remote/sshclient.go
+++ b/pkg/remote/sshclient.go
@@ -709,8 +709,13 @@ type SSHOpts struct {
}
func (opts SSHOpts) String() string {
- if opts.SSHPort == 0 {
- return fmt.Sprintf("%s@%s", opts.SSHUser, opts.SSHHost)
+ stringRepr := ""
+ if opts.SSHUser != "" {
+ stringRepr = opts.SSHUser + "@"
}
- return fmt.Sprintf("%s@%s:%d", opts.SSHUser, opts.SSHHost, opts.SSHPort)
+ stringRepr = stringRepr + opts.SSHHost
+ if opts.SSHPort != 0 {
+ stringRepr = stringRepr + ":" + fmt.Sprint(opts.SSHPort)
+ }
+ return stringRepr
}
diff --git a/pkg/service/clientservice/clientservice.go b/pkg/service/clientservice/clientservice.go
index ffefb2246..86332b3ec 100644
--- a/pkg/service/clientservice/clientservice.go
+++ b/pkg/service/clientservice/clientservice.go
@@ -10,9 +10,11 @@ import (
"time"
"github.com/wavetermdev/thenextwave/pkg/eventbus"
+ "github.com/wavetermdev/thenextwave/pkg/remote/conncontroller"
"github.com/wavetermdev/thenextwave/pkg/service/objectservice"
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
"github.com/wavetermdev/thenextwave/pkg/waveobj"
+ "github.com/wavetermdev/thenextwave/pkg/wshrpc"
"github.com/wavetermdev/thenextwave/pkg/wstore"
)
@@ -64,6 +66,10 @@ func (cs *ClientService) MakeWindow(ctx context.Context) (*waveobj.Window, error
return wstore.CreateWindow(ctx, nil)
}
+func (cs *ClientService) GetAllConnStatus(ctx context.Context) ([]wshrpc.ConnStatus, error) {
+ return conncontroller.GetAllConnStatus(), nil
+}
+
// moves the window to the front of the windowId stack
func (cs *ClientService) FocusWindow(ctx context.Context, windowId string) error {
client, err := cs.GetClientData()
diff --git a/pkg/wps/wps.go b/pkg/wps/wps.go
index 058235eb7..79797b2ff 100644
--- a/pkg/wps/wps.go
+++ b/pkg/wps/wps.go
@@ -58,9 +58,14 @@ func (b *BrokerType) GetClient() Client {
return b.Client
}
+// if already subscribed, this will *resubscribe* with the new subscription (remove the old one, and replace with this one)
func (b *BrokerType) Subscribe(subRouteId string, sub wshrpc.SubscriptionRequest) {
+ if sub.Event == "" {
+ return
+ }
b.Lock.Lock()
defer b.Lock.Unlock()
+ b.unsubscribe_nolock(subRouteId, sub.Event)
bs := b.SubMap[sub.Event]
if bs == nil {
bs = &BrokerSubscription{
@@ -72,6 +77,7 @@ func (b *BrokerType) Subscribe(subRouteId string, sub wshrpc.SubscriptionRequest
}
if sub.AllScopes {
bs.AllSubs = utilfn.AddElemToSliceUniq(bs.AllSubs, subRouteId)
+ return
}
for _, scope := range sub.Scopes {
starMatch := scopeHasStarMatch(scope)
@@ -114,26 +120,26 @@ func addStrToScopeMap(scopeMap map[string][]string, scope string, routeId string
scopeMap[scope] = scopeSubs
}
-func (b *BrokerType) Unsubscribe(subRouteId string, sub wshrpc.SubscriptionRequest) {
+func (b *BrokerType) Unsubscribe(subRouteId string, eventName string) {
b.Lock.Lock()
defer b.Lock.Unlock()
- bs := b.SubMap[sub.Event]
+ b.unsubscribe_nolock(subRouteId, eventName)
+}
+
+func (b *BrokerType) unsubscribe_nolock(subRouteId string, eventName string) {
+ bs := b.SubMap[eventName]
if bs == nil {
return
}
- if sub.AllScopes {
- bs.AllSubs = utilfn.RemoveElemFromSlice(bs.AllSubs, subRouteId)
+ bs.AllSubs = utilfn.RemoveElemFromSlice(bs.AllSubs, subRouteId)
+ for scope := range bs.ScopeSubs {
+ removeStrFromScopeMap(bs.ScopeSubs, scope, subRouteId)
}
- for _, scope := range sub.Scopes {
- starMatch := scopeHasStarMatch(scope)
- if starMatch {
- removeStrFromScopeMap(bs.StarSubs, scope, subRouteId)
- } else {
- removeStrFromScopeMap(bs.ScopeSubs, scope, subRouteId)
- }
+ for scope := range bs.StarSubs {
+ removeStrFromScopeMap(bs.StarSubs, scope, subRouteId)
}
if bs.IsEmpty() {
- delete(b.SubMap, sub.Event)
+ delete(b.SubMap, eventName)
}
}
diff --git a/pkg/wshrpc/wshclient/wshclient.go b/pkg/wshrpc/wshclient/wshclient.go
index 7eea9d7cb..1c06237c7 100644
--- a/pkg/wshrpc/wshclient/wshclient.go
+++ b/pkg/wshrpc/wshclient/wshclient.go
@@ -66,7 +66,7 @@ func EventSubCommand(w *wshutil.WshRpc, data wshrpc.SubscriptionRequest, opts *w
}
// command "eventunsub", wshserver.EventUnsubCommand
-func EventUnsubCommand(w *wshutil.WshRpc, data wshrpc.SubscriptionRequest, opts *wshrpc.RpcOpts) error {
+func EventUnsubCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "eventunsub", data, opts)
return err
}
diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go
index dcd7d7cb4..349467120 100644
--- a/pkg/wshrpc/wshrpctypes.go
+++ b/pkg/wshrpc/wshrpctypes.go
@@ -26,6 +26,7 @@ const (
const (
Event_BlockClose = "blockclose"
+ Event_ConnChange = "connchange"
)
const (
@@ -83,7 +84,7 @@ type WshRpcInterface interface {
FileReadCommand(ctx context.Context, data CommandFileData) (string, error)
EventPublishCommand(ctx context.Context, data WaveEvent) error
EventSubCommand(ctx context.Context, data SubscriptionRequest) error
- EventUnsubCommand(ctx context.Context, data SubscriptionRequest) error
+ EventUnsubCommand(ctx context.Context, data string) error
EventUnsubAllCommand(ctx context.Context) error
StreamTestCommand(ctx context.Context) chan RespOrErrorUnion[int]
StreamWaveAiCommand(ctx context.Context, request OpenAiStreamRequest) chan RespOrErrorUnion[OpenAIPacketType]
@@ -324,3 +325,10 @@ type TimeSeriesData struct {
Ts int64 `json:"ts"`
Values map[string]float64 `json:"values"`
}
+
+type ConnStatus struct {
+ Status string `json:"status"`
+ Connection string `json:"connection"`
+ Connected bool `json:"connected"`
+ Error string `json:"error,omitempty"`
+}
diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go
index 392f41342..6f2deef1b 100644
--- a/pkg/wshrpc/wshserver/wshserver.go
+++ b/pkg/wshrpc/wshserver/wshserver.go
@@ -478,7 +478,7 @@ func (ws *WshServer) EventSubCommand(ctx context.Context, data wshrpc.Subscripti
return nil
}
-func (ws *WshServer) EventUnsubCommand(ctx context.Context, data wshrpc.SubscriptionRequest) error {
+func (ws *WshServer) EventUnsubCommand(ctx context.Context, data string) error {
rpcSource := wshutil.GetRpcSourceFromContext(ctx)
if rpcSource == "" {
return fmt.Errorf("no rpc source set")
diff --git a/pkg/wshutil/wshrouter.go b/pkg/wshutil/wshrouter.go
index cd0ab4ea6..ccb849451 100644
--- a/pkg/wshutil/wshrouter.go
+++ b/pkg/wshutil/wshrouter.go
@@ -10,6 +10,7 @@ import (
"log"
"sync"
+ "github.com/wavetermdev/thenextwave/pkg/wps"
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
)
@@ -18,6 +19,8 @@ const SysRoute = "sys" // this route doesn't exist, just a placeholder for syste
// this works like a network switch
+// TODO maybe move the wps integration here instead of in wshserver
+
type routeInfo struct {
RpcId string
SourceRouteId string
@@ -285,6 +288,9 @@ func (router *WshRouter) UnregisterRoute(routeId string) {
router.Lock.Lock()
defer router.Lock.Unlock()
delete(router.RouteMap, routeId)
+ go func() {
+ wps.Broker.UnsubscribeAll(routeId)
+ }()
}
// this may return nil (returns default only for empty routeId)