Break model update code out of sstore (#290)

* Break update code out of sstore

* add license disclaimers

* missed one

* add another

* fix regression in openai updates, remove unnecessary functions

* another copyright

* update casts

* fix issue with variadic updates

* remove logs

* remove log

* remove unnecessary log

* save work

* moved a bunch of stuff to scbus

* make modelupdate an object

* fix new screen not updating active screen

* add comment

* make updates into packet types

* different cast

* update comments, remove unused methods

* add one more comment

* add an IsEmpty() on model updates to prevent sending empty updates to client
This commit is contained in:
Evan Simkowitz 2024-02-15 16:45:47 -08:00 committed by GitHub
parent 158378a7ad
commit 8acda3525b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 1072 additions and 836 deletions

View File

@ -715,7 +715,7 @@ class Model {
return this.ws.open.get();
}
runUpdate(genUpdate: UpdateMessage, interactive: boolean) {
runUpdate(genUpdate: UpdatePacket, interactive: boolean) {
mobx.action(() => {
const oldContext = this.getUIContext();
try {
@ -727,8 +727,9 @@ class Model {
const newContext = this.getUIContext();
if (oldContext.sessionid != newContext.sessionid || oldContext.screenid != newContext.screenid) {
this.inputModel.resetInput();
if (!("ptydata64" in genUpdate)) {
const reversedGenUpdate = genUpdate.slice().reverse();
if (genUpdate.type == "model") {
const modelUpdate = genUpdate as ModelUpdatePacket;
const reversedGenUpdate = modelUpdate.data.slice().reverse();
const lastCmdLine = reversedGenUpdate.find((update) => "cmdline" in update);
if (lastCmdLine) {
// TODO a bit of a hack since this update gets applied in runUpdate_internal.
@ -768,20 +769,12 @@ class Model {
}
updateActiveSession(sessionId: string): void {
const [oldActiveSessionId, oldActiveScreenId] = this.getActiveIds();
if (sessionId != null) {
const newSessionId = sessionId;
if (this.activeSessionId.get() != newSessionId) {
this.activeSessionId.set(newSessionId);
}
}
const [newActiveSessionId, newActiveScreenId] = this.getActiveIds();
if (oldActiveSessionId != newActiveSessionId || oldActiveScreenId != newActiveScreenId) {
this.activeMainView.set("session");
this.deactivateScreenLines();
this.ws.watchScreen(newActiveSessionId, newActiveScreenId);
}
}
updateScreenNumRunningCommands(numRunningCommandUpdates: ScreenNumRunningCommandsUpdateType[]) {
@ -796,9 +789,9 @@ class Model {
}
}
runUpdate_internal(genUpdate: UpdateMessage, uiContext: UIContextType, interactive: boolean) {
if ("ptydata64" in genUpdate) {
const ptyMsg: PtyDataUpdateType = genUpdate;
runUpdate_internal(genUpdate: UpdatePacket, uiContext: UIContextType, interactive: boolean) {
if (genUpdate.type == "pty") {
const ptyMsg = genUpdate.data as PtyDataUpdateType;
if (isBlank(ptyMsg.remoteid)) {
// regular update
this.updatePtyData(ptyMsg);
@ -807,125 +800,138 @@ class Model {
const ptyData = base64ToArray(ptyMsg.ptydata64);
this.remotesModel.receiveData(ptyMsg.remoteid, ptyMsg.ptypos, ptyData);
}
return;
}
let showedRemotesModal = false;
genUpdate.forEach((update) => {
if (update.connect != null) {
if (update.connect.screens != null) {
this.screenMap.clear();
this.updateScreens(update.connect.screens);
}
if (update.connect.sessions != null) {
this.sessionList.clear();
this.updateSessions(update.connect.sessions);
}
if (update.connect.remotes != null) {
this.remotes.clear();
this.updateRemotes(update.connect.remotes);
}
if (update.connect.activesessionid != null) {
this.updateActiveSession(update.connect.activesessionid);
}
if (update.connect.screennumrunningcommands != null) {
this.updateScreenNumRunningCommands(update.connect.screennumrunningcommands);
}
if (update.connect.screenstatusindicators != null) {
this.updateScreenStatusIndicators(update.connect.screenstatusindicators);
}
} else if (genUpdate.type == "model") {
const modelUpdateItems = genUpdate.data as ModelUpdateItemType[];
this.sessionListLoaded.set(true);
this.remotesLoaded.set(true);
} else if (update.screen != null) {
this.updateScreens([update.screen]);
} else if (update.session != null) {
this.updateSessions([update.session]);
} else if (update.activesessionid != null) {
this.updateActiveSession(update.activesessionid);
} else if (update.line != null) {
this.addLineCmd(update.line.line, update.line.cmd, interactive);
} else if (update.cmd != null) {
this.updateCmd(update.cmd);
} else if (update.screenlines != null) {
this.updateScreenLines(update.screenlines, false);
} else if (update.remote != null) {
this.updateRemotes([update.remote]);
// This code's purpose is to show view remote connection modal when a new connection is added
if (!showedRemotesModal && this.remotesModel.recentConnAddedState.get()) {
showedRemotesModal = true;
this.remotesModel.openReadModal(update.remote.remoteid);
}
} else if (update.mainview != null) {
switch (update.mainview.mainview) {
case "session":
this.activeMainView.set("session");
break;
case "history":
if (update.mainview.historyview != null) {
this.historyViewModel.showHistoryView(update.mainview.historyview);
} else {
console.warn("invalid historyview in update:", update.mainview);
}
break;
case "bookmarks":
if (update.mainview.bookmarksview != null) {
this.bookmarksModel.showBookmarksView(
update.mainview.bookmarksview?.bookmarks ?? [],
update.mainview.bookmarksview?.selectedbookmark
);
} else {
console.warn("invalid bookmarksview in update:", update.mainview);
}
break;
case "plugins":
this.pluginsModel.showPluginsView();
break;
default:
console.warn("invalid mainview in update:", update.mainview);
}
} else if (update.bookmarks != null) {
if (update.bookmarks.bookmarks != null) {
this.bookmarksModel.mergeBookmarks(update.bookmarks.bookmarks);
}
} else if (update.clientdata != null) {
this.setClientData(update.clientdata);
} else if (update.cmdline != null) {
this.inputModel.updateCmdLine(update.cmdline);
} else if (update.openaicmdinfochat != null) {
this.inputModel.setOpenAICmdInfoChat(update.openaicmdinfochat);
} else if (update.screenstatusindicator != null) {
this.updateScreenStatusIndicators([update.screenstatusindicator]);
} else if (update.screennumrunningcommands != null) {
this.updateScreenNumRunningCommands([update.screennumrunningcommands]);
} else if (update.userinputrequest != null) {
let userInputRequest: UserInputRequest = update.userinputrequest;
this.modalsModel.pushModal(appconst.USER_INPUT, userInputRequest);
} else if (interactive) {
if (update.info != null) {
const info: InfoType = update.info;
this.inputModel.flashInfoMsg(info, info.timeoutms);
} else if (update.remoteview != null) {
const rview: RemoteViewType = update.remoteview;
if (rview.remoteedit != null) {
this.remotesModel.openEditModal({ ...rview.remoteedit });
let showedRemotesModal = false;
const [oldActiveSessionId, oldActiveScreenId] = this.getActiveIds();
modelUpdateItems.forEach((update) => {
if (update.connect != null) {
if (update.connect.screens != null) {
this.screenMap.clear();
this.updateScreens(update.connect.screens);
}
} else if (update.alertmessage != null) {
const alertMessage: AlertMessageType = update.alertmessage;
this.showAlert(alertMessage);
} else if (update.history != null) {
if (
uiContext.sessionid == update.history.sessionid &&
uiContext.screenid == update.history.screenid
) {
this.inputModel.setHistoryInfo(update.history);
if (update.connect.sessions != null) {
this.sessionList.clear();
this.updateSessions(update.connect.sessions);
}
if (update.connect.remotes != null) {
this.remotes.clear();
this.updateRemotes(update.connect.remotes);
}
if (update.connect.activesessionid != null) {
this.updateActiveSession(update.connect.activesessionid);
}
if (update.connect.screennumrunningcommands != null) {
this.updateScreenNumRunningCommands(update.connect.screennumrunningcommands);
}
if (update.connect.screenstatusindicators != null) {
this.updateScreenStatusIndicators(update.connect.screenstatusindicators);
}
this.sessionListLoaded.set(true);
this.remotesLoaded.set(true);
} else if (update.screen != null) {
this.updateScreens([update.screen]);
} else if (update.session != null) {
this.updateSessions([update.session]);
} else if (update.activesessionid != null) {
this.updateActiveSession(update.activesessionid);
} else if (update.line != null) {
this.addLineCmd(update.line.line, update.line.cmd, interactive);
} else if (update.cmd != null) {
this.updateCmd(update.cmd);
} else if (update.screenlines != null) {
this.updateScreenLines(update.screenlines, false);
} else if (update.remote != null) {
this.updateRemotes([update.remote]);
// This code's purpose is to show view remote connection modal when a new connection is added
if (!showedRemotesModal && this.remotesModel.recentConnAddedState.get()) {
showedRemotesModal = true;
this.remotesModel.openReadModal(update.remote.remoteid);
}
} else if (update.mainview != null) {
switch (update.mainview.mainview) {
case "session":
this.activeMainView.set("session");
break;
case "history":
if (update.mainview.historyview != null) {
this.historyViewModel.showHistoryView(update.mainview.historyview);
} else {
console.warn("invalid historyview in update:", update.mainview);
}
break;
case "bookmarks":
if (update.mainview.bookmarksview != null) {
this.bookmarksModel.showBookmarksView(
update.mainview.bookmarksview?.bookmarks ?? [],
update.mainview.bookmarksview?.selectedbookmark
);
} else {
console.warn("invalid bookmarksview in update:", update.mainview);
}
break;
case "plugins":
this.pluginsModel.showPluginsView();
break;
default:
console.warn("invalid mainview in update:", update.mainview);
}
} else if (update.bookmarks != null) {
if (update.bookmarks.bookmarks != null) {
this.bookmarksModel.mergeBookmarks(update.bookmarks.bookmarks);
}
} else if (update.clientdata != null) {
this.setClientData(update.clientdata);
} else if (update.cmdline != null) {
this.inputModel.updateCmdLine(update.cmdline);
} else if (update.openaicmdinfochat != null) {
this.inputModel.setOpenAICmdInfoChat(update.openaicmdinfochat);
} else if (update.screenstatusindicator != null) {
this.updateScreenStatusIndicators([update.screenstatusindicator]);
} else if (update.screennumrunningcommands != null) {
this.updateScreenNumRunningCommands([update.screennumrunningcommands]);
} else if (update.userinputrequest != null) {
const userInputRequest: UserInputRequest = update.userinputrequest;
this.modalsModel.pushModal(appconst.USER_INPUT, userInputRequest);
} else if (interactive) {
if (update.info != null) {
const info: InfoType = update.info;
this.inputModel.flashInfoMsg(info, info.timeoutms);
} else if (update.remoteview != null) {
const rview: RemoteViewType = update.remoteview;
if (rview.remoteedit != null) {
this.remotesModel.openEditModal({ ...rview.remoteedit });
}
} else if (update.alertmessage != null) {
const alertMessage: AlertMessageType = update.alertmessage;
this.showAlert(alertMessage);
} else if (update.history != null) {
if (
uiContext.sessionid == update.history.sessionid &&
uiContext.screenid == update.history.screenid
) {
this.inputModel.setHistoryInfo(update.history);
}
} else if (this.isDev) {
console.log("did not match update", update);
}
} else if (this.isDev) {
console.log("did not match update", update);
}
} else if (this.isDev) {
console.log("did not match update", update);
});
// Check if the active session or screen has changed, and if so, watch the new screen
const [newActiveSessionId, newActiveScreenId] = this.getActiveIds();
if (oldActiveSessionId != newActiveSessionId || oldActiveScreenId != newActiveScreenId) {
this.activeMainView.set("session");
this.deactivateScreenLines();
this.ws.watchScreen(newActiveSessionId, newActiveScreenId);
}
});
} else {
console.warn("unknown update", genUpdate);
}
}
updateRemotes(remotes: RemoteType[]): void {
@ -1064,11 +1070,13 @@ class Model {
this.handleCmdRestart(cmd);
}
isInfoUpdate(update: UpdateMessage): boolean {
if (update == null || "ptydata64" in update) {
isInfoUpdate(update: UpdatePacket): boolean {
if (update.type == "model") {
const modelUpdate = update as ModelUpdatePacket;
return modelUpdate.data.some((u) => u.info != null || u.history != null);
} else {
return false;
}
return update.some((u) => u.info != null || u.history != null);
}
getClientDataLoop(loopNum: number): void {

16
src/types/custom.d.ts vendored
View File

@ -338,6 +338,10 @@ declare global {
};
type ModelUpdateType = {
items?: ModelUpdateItemType[];
};
type ModelUpdateItemType = {
interactive: boolean;
session?: SessionDataType;
activesessionid?: string;
@ -440,7 +444,17 @@ declare global {
showCut?: boolean;
};
type UpdateMessage = PtyDataUpdateType | ModelUpdateType[];
type ModelUpdatePacket = {
type: "model";
data: ModelUpdateItemType[];
};
type PtyDataUpdatePacket = {
type: "pty";
data: PtyDataUpdateType;
};
type UpdatePacket = ModelUpdatePacket | PtyDataUpdatePacket;
type RendererContext = {
screenId: string;

File diff suppressed because it is too large Load Diff

View File

@ -1,3 +1,6 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package releasechecker
import (
@ -8,6 +11,7 @@ import (
"golang.org/x/mod/semver"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbus"
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
)
@ -66,9 +70,9 @@ func CheckNewRelease(ctx context.Context, force bool) (ReleaseCheckResult, error
return Failure, fmt.Errorf("error getting updated client data: %w", err)
}
update := &sstore.ModelUpdate{}
sstore.AddUpdate(update, *clientData)
sstore.MainBus.SendUpdate(update)
update := scbus.MakeUpdatePacket()
update.AddUpdate(clientData)
scbus.MainUpdateBus.DoUpdate(update)
return Success, nil
}

View File

@ -34,8 +34,10 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/statediff"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbus"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
"golang.org/x/crypto/ssh"
"golang.org/x/mod/semver"
)
@ -681,9 +683,9 @@ func (msh *MShellProc) GetRemoteRuntimeState() RemoteRuntimeState {
func (msh *MShellProc) NotifyRemoteUpdate() {
rstate := msh.GetRemoteRuntimeState()
update := &sstore.ModelUpdate{}
sstore.AddUpdate(update, rstate)
sstore.MainBus.SendUpdate(update)
update := scbus.MakeUpdatePacket()
update.AddUpdate(rstate)
scbus.MainUpdateBus.DoUpdate(update)
}
func GetAllRemoteRuntimeState() []*RemoteRuntimeState {
@ -943,13 +945,13 @@ func (msh *MShellProc) writeToPtyBuffer_nolock(strFmt string, args ...interface{
func sendRemotePtyUpdate(remoteId string, dataOffset int64, data []byte) {
data64 := base64.StdEncoding.EncodeToString(data)
update := &sstore.PtyDataUpdate{
update := scbus.MakePtyDataUpdate(&scbus.PtyDataUpdate{
RemoteId: remoteId,
PtyPos: dataOffset,
PtyData64: data64,
PtyDataLen: int64(len(data)),
}
sstore.MainBus.SendUpdate(update)
})
scbus.MainUpdateBus.DoUpdate(update)
}
func (msh *MShellProc) isWaitingForPassword_nolock() bool {
@ -2016,9 +2018,9 @@ func (msh *MShellProc) notifyHangups_nolock() {
if err != nil {
continue
}
update := &sstore.ModelUpdate{}
sstore.AddUpdate(update, *cmd)
sstore.MainBus.SendScreenUpdate(ck.GetGroupId(), update)
update := scbus.MakeUpdatePacket()
update.AddUpdate(*cmd)
scbus.MainUpdateBus.DoScreenUpdate(ck.GetGroupId(), update)
go pushNumRunningCmdsUpdate(&ck, -1)
}
msh.RunningCmds = make(map[base.CommandKey]RunCmdType)
@ -2047,7 +2049,7 @@ func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
// fall-through (nothing to do)
}
if screen != nil {
sstore.AddUpdate(update, *screen)
update.AddUpdate(*screen)
}
rct := msh.GetRunningCmd(donePk.CK)
var statePtr *sstore.ShellStatePtr
@ -2059,7 +2061,7 @@ func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
// fall-through (nothing to do)
}
if remoteInst != nil {
sstore.AddUpdate(update, sstore.MakeSessionUpdateForRemote(rct.SessionId, remoteInst))
update.AddUpdate(sstore.MakeSessionUpdateForRemote(rct.SessionId, remoteInst))
}
statePtr = &sstore.ShellStatePtr{BaseHash: donePk.FinalState.GetHashVal(false)}
} else if donePk.FinalStateDiff != nil && rct != nil {
@ -2079,7 +2081,7 @@ func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
// fall-through (nothing to do)
}
if remoteInst != nil {
sstore.AddUpdate(update, sstore.MakeSessionUpdateForRemote(rct.SessionId, remoteInst))
update.AddUpdate(sstore.MakeSessionUpdateForRemote(rct.SessionId, remoteInst))
}
diffHashArr := append(([]string)(nil), donePk.FinalStateDiff.DiffHashArr...)
diffHashArr = append(diffHashArr, donePk.FinalStateDiff.GetHashVal(false))
@ -2093,7 +2095,7 @@ func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
// fall-through (nothing to do)
}
}
sstore.MainBus.SendUpdate(update)
scbus.MainUpdateBus.DoUpdate(update)
return
}
@ -2122,13 +2124,13 @@ func (msh *MShellProc) handleCmdFinalPacket(finalPk *packet.CmdFinalPacketType)
log.Printf("error getting cmd(2) in handleCmdFinalPacket (not found)\n")
return
}
update := &sstore.ModelUpdate{}
sstore.AddUpdate(update, *rtnCmd)
update := scbus.MakeUpdatePacket()
update.AddUpdate(*rtnCmd)
if screen != nil {
sstore.AddUpdate(update, *screen)
update.AddUpdate(*screen)
}
go pushNumRunningCmdsUpdate(&finalPk.CK, -1)
sstore.MainBus.SendUpdate(update)
scbus.MainUpdateBus.DoUpdate(update)
}
// TODO notify FE about cmd errors
@ -2164,7 +2166,7 @@ func (msh *MShellProc) handleDataPacket(dataPk *packet.DataPacketType, dataPosMa
}
utilfn.IncSyncMap(dataPosMap, dataPk.CK, int64(len(realData)))
if update != nil {
sstore.MainBus.SendScreenUpdate(dataPk.CK.GetGroupId(), update)
scbus.MainUpdateBus.DoScreenUpdate(dataPk.CK.GetGroupId(), update)
}
}
if ack != nil {
@ -2193,9 +2195,9 @@ func (msh *MShellProc) makeHandleCmdFinalPacketClosure(finalPk *packet.CmdFinalP
func sendScreenUpdates(screens []*sstore.ScreenType) {
for _, screen := range screens {
update := &sstore.ModelUpdate{}
sstore.AddUpdate(update, *screen)
sstore.MainBus.SendUpdate(update)
update := scbus.MakeUpdatePacket()
update.AddUpdate(*screen)
scbus.MainUpdateBus.DoUpdate(update)
}
}

View File

@ -22,8 +22,9 @@ import (
"github.com/kevinburke/ssh_config"
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbus"
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
"github.com/wavetermdev/waveterm/wavesrv/pkg/userinput"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
)
@ -104,14 +105,13 @@ func createPublicKeyCallback(sshKeywords *SshKeywords, passphrase string) func()
return createDummySigner()
}
request := &sstore.UserInputRequestType{
request := &userinput.UserInputRequestType{
ResponseType: "text",
QueryText: fmt.Sprintf("Enter passphrase for the SSH key: %s", identityFile),
Title: "Publickey Auth + Passphrase",
}
ctx, cancelFn := context.WithTimeout(context.Background(), 60*time.Second)
defer cancelFn()
response, err := sstore.MainBus.GetUserInput(ctx, request)
ctx, _ := context.WithTimeout(context.Background(), 60*time.Second)
response, err := userinput.GetUserInput(ctx, scbus.MainRpcBus, request)
if err != nil {
// this is an error where we actually do want to stop
// trying keys
@ -141,12 +141,12 @@ func createInteractivePasswordCallbackPrompt() func() (secret string, err error)
// in the future
ctx, cancelFn := context.WithTimeout(context.Background(), 60*time.Second)
defer cancelFn()
request := &sstore.UserInputRequestType{
request := &userinput.UserInputRequestType{
ResponseType: "text",
QueryText: "Password:",
Title: "Password Authentication",
}
response, err := sstore.MainBus.GetUserInput(ctx, request)
response, err := userinput.GetUserInput(ctx, scbus.MainRpcBus, request)
if err != nil {
return "", err
}
@ -201,12 +201,12 @@ func promptChallengeQuestion(question string, echo bool) (answer string, err err
// in the future
ctx, cancelFn := context.WithTimeout(context.Background(), 60*time.Second)
defer cancelFn()
request := &sstore.UserInputRequestType{
request := &userinput.UserInputRequestType{
ResponseType: "text",
QueryText: question,
Title: "Keyboard Interactive Authentication",
}
response, err := sstore.MainBus.GetUserInput(ctx, request)
response, err := userinput.GetUserInput(ctx, scbus.MainRpcBus, request)
if err != nil {
return "", err
}
@ -234,10 +234,10 @@ func openKnownHostsForEdit(knownHostsFilename string) (*os.File, error) {
return os.OpenFile(knownHostsFilename, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644)
}
func writeToKnownHosts(knownHostsFile string, newLine string, getUserVerification func() (*scpacket.UserInputResponsePacketType, error)) error {
func writeToKnownHosts(knownHostsFile string, newLine string, getUserVerification func() (*userinput.UserInputResponsePacketType, error)) error {
if getUserVerification == nil {
getUserVerification = func() (*scpacket.UserInputResponsePacketType, error) {
return &scpacket.UserInputResponsePacketType{
getUserVerification = func() (*userinput.UserInputResponsePacketType, error) {
return &userinput.UserInputResponsePacketType{
Type: "confirm",
Confirm: true,
}, nil
@ -270,7 +270,7 @@ func writeToKnownHosts(knownHostsFile string, newLine string, getUserVerificatio
return f.Close()
}
func createUnknownKeyVerifier(knownHostsFile string, hostname string, remote string, key ssh.PublicKey) func() (*scpacket.UserInputResponsePacketType, error) {
func createUnknownKeyVerifier(knownHostsFile string, hostname string, remote string, key ssh.PublicKey) func() (*userinput.UserInputResponsePacketType, error) {
base64Key := base64.StdEncoding.EncodeToString(key.Marshal())
queryText := fmt.Sprintf(
"The authenticity of host '%s (%s)' can't be established "+
@ -280,20 +280,20 @@ func createUnknownKeyVerifier(knownHostsFile string, hostname string, remote str
"**Would you like to continue connecting?** If so, the key will be permanently "+
"added to the file %s "+
"to protect from future man-in-the-middle attacks.", hostname, remote, key.Type(), base64Key, knownHostsFile)
request := &sstore.UserInputRequestType{
request := &userinput.UserInputRequestType{
ResponseType: "confirm",
QueryText: queryText,
Markdown: true,
Title: "Known Hosts Key Missing",
}
return func() (*scpacket.UserInputResponsePacketType, error) {
return func() (*userinput.UserInputResponsePacketType, error) {
ctx, cancelFn := context.WithTimeout(context.Background(), 60*time.Second)
defer cancelFn()
return sstore.MainBus.GetUserInput(ctx, request)
return userinput.GetUserInput(ctx, scbus.MainRpcBus, request)
}
}
func createMissingKnownHostsVerifier(knownHostsFile string, hostname string, remote string, key ssh.PublicKey) func() (*scpacket.UserInputResponsePacketType, error) {
func createMissingKnownHostsVerifier(knownHostsFile string, hostname string, remote string, key ssh.PublicKey) func() (*userinput.UserInputResponsePacketType, error) {
base64Key := base64.StdEncoding.EncodeToString(key.Marshal())
queryText := fmt.Sprintf(
"The authenticity of host '%s (%s)' can't be established "+
@ -304,16 +304,16 @@ func createMissingKnownHostsVerifier(knownHostsFile string, hostname string, rem
"- %s will be created \n"+
"- the key will be added to %s\n\n"+
"This will protect from future man-in-the-middle attacks.", hostname, remote, key.Type(), base64Key, knownHostsFile, knownHostsFile)
request := &sstore.UserInputRequestType{
request := &userinput.UserInputRequestType{
ResponseType: "confirm",
QueryText: queryText,
Markdown: true,
Title: "Known Hosts File Missing",
}
return func() (*scpacket.UserInputResponsePacketType, error) {
return func() (*userinput.UserInputResponsePacketType, error) {
ctx, cancelFn := context.WithTimeout(context.Background(), 60*time.Second)
defer cancelFn()
return sstore.MainBus.GetUserInput(ctx, request)
return userinput.GetUserInput(ctx, scbus.MainRpcBus, request)
}
}
@ -444,13 +444,13 @@ func createHostKeyCallback(opts *sstore.SSHOpts) (ssh.HostKeyCallback, error) {
"%s\n\n"+
"**Offending Keys** \n"+
"%s", key.Type(), correctKeyFingerprint, strings.Join(bulletListKnownHosts, " \n"), strings.Join(offendingKeysFmt, " \n"))
update := &sstore.ModelUpdate{}
sstore.AddUpdate(update, sstore.AlertMessageType{
update := scbus.MakeUpdatePacket()
update.AddUpdate(sstore.AlertMessageType{
Markdown: true,
Title: "Known Hosts Key Changed",
Message: alertText,
})
sstore.MainBus.SendUpdate(update)
scbus.MainUpdateBus.DoUpdate(update)
return fmt.Errorf("remote host identification has changed")
}

View File

@ -0,0 +1,126 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package scbus
import (
"encoding/json"
"reflect"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
)
const ModelUpdateStr = "model"
// A channel for sending model updates to the client
type ModelUpdateChannel[J any] struct {
ScreenId string
ClientId string
ch chan J
}
func (uch *ModelUpdateChannel[J]) GetChannel() chan J {
return uch.ch
}
func (uch *ModelUpdateChannel[J]) SetChannel(ch chan J) {
uch.ch = ch
}
// Match the screenId to the channel
func (sch *ModelUpdateChannel[J]) Match(screenId string) bool {
if screenId == "" {
return true
}
return screenId == sch.ScreenId
}
// An interface for all model updates
type ModelUpdateItem interface {
// The key to use when marshalling to JSON and interpreting in the client
GetType() string
}
// An inner data type for the ModelUpdatePacketType. Stores a collection of model updates to be sent to the client.
type ModelUpdate []ModelUpdateItem
func (mu *ModelUpdate) IsEmpty() bool {
if mu == nil {
return true
}
muArr := []ModelUpdateItem(*mu)
return len(muArr) == 0
}
func (mu *ModelUpdate) MarshalJSON() ([]byte, error) {
rtn := make([]map[string]any, 0)
for _, u := range *mu {
m := make(map[string]any)
m[(u).GetType()] = u
rtn = append(rtn, m)
}
return json.Marshal(rtn)
}
// An UpdatePacket for sending model updates to the client
type ModelUpdatePacketType struct {
Type string `json:"type"`
Data *ModelUpdate `json:"data"`
}
func (*ModelUpdatePacketType) GetType() string {
return ModelUpdateStr
}
func (mu *ModelUpdatePacketType) IsEmpty() bool {
if mu == nil || mu.Data == nil {
return true
}
return mu.Data.IsEmpty()
}
// Clean the ClientData in an update, if present
func (upk *ModelUpdatePacketType) Clean() {
if upk == nil || upk.Data == nil {
return
}
for _, item := range *(upk.Data) {
if i, ok := (item).(CleanableUpdateItem); ok {
i.Clean()
}
}
}
// Add a collection of model updates to the update
func (upk *ModelUpdatePacketType) AddUpdate(items ...ModelUpdateItem) {
*(upk.Data) = append(*(upk.Data), items...)
}
// Create a new model update packet
func MakeUpdatePacket() *ModelUpdatePacketType {
return &ModelUpdatePacketType{
Type: ModelUpdateStr,
Data: &ModelUpdate{},
}
}
// Returns the items in the update that are of type I
func GetUpdateItems[I ModelUpdateItem](upk *ModelUpdatePacketType) []*I {
ret := make([]*I, 0)
for _, item := range *(upk.Data) {
if i, ok := (item).(I); ok {
ret = append(ret, &i)
}
}
return ret
}
// An interface for model updates that can be cleaned
type CleanableUpdateItem interface {
Clean()
}
func init() {
// Register the model update packet type
packet.RegisterPacketType(ModelUpdateStr, reflect.TypeOf(ModelUpdatePacketType{}))
}

View File

@ -0,0 +1,50 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package scbus
import (
"reflect"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
)
const PtyDataUpdateStr = "pty"
// The inner data type for the PtyDataUpdatePacketType. Stores the pty data to be sent to the client.
type PtyDataUpdate struct {
ScreenId string `json:"screenid,omitempty"`
LineId string `json:"lineid,omitempty"`
RemoteId string `json:"remoteid,omitempty"`
PtyPos int64 `json:"ptypos"`
PtyData64 string `json:"ptydata64"`
PtyDataLen int64 `json:"ptydatalen"`
}
// An UpdatePacket for sending pty data to the client
type PtyDataUpdatePacketType struct {
Type string `json:"type"`
Data *PtyDataUpdate `json:"data"`
}
func (*PtyDataUpdatePacketType) GetType() string {
return PtyDataUpdateStr
}
func (pdu *PtyDataUpdatePacketType) Clean() {
// This is a no-op for PtyDataUpdatePacketType, but it is required to satisfy the UpdatePacket interface
}
func (pdu *PtyDataUpdatePacketType) IsEmpty() bool {
return pdu == nil || pdu.Data == nil || pdu.Data.PtyDataLen == 0
}
// Create a new PtyDataUpdatePacketType
func MakePtyDataUpdate(update *PtyDataUpdate) *PtyDataUpdatePacketType {
return &PtyDataUpdatePacketType{Type: PtyDataUpdateStr, Data: update}
}
func init() {
// Register the PtyDataUpdatePacketType with the packet package
packet.RegisterPacketType(PtyDataUpdateStr, reflect.TypeOf(PtyDataUpdatePacketType{}))
}

240
wavesrv/pkg/scbus/scbus.go Normal file
View File

@ -0,0 +1,240 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
// Defines interfaces for creating communciation channels between server and clients
package scbus
import (
"context"
"fmt"
"log"
"reflect"
"sync"
"time"
"github.com/google/uuid"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
)
var MainUpdateBus *UpdateBus = MakeUpdateBus()
var MainRpcBus *RpcBus = MakeRpcBus()
// The default channel size
const ChSize = 100
type Channel[I packet.PacketType] interface {
GetChannel() chan I
SetChannel(chan I)
Match(string) bool
}
// A concurrent bus for registering and managing channels
type Bus[I packet.PacketType] struct {
Lock *sync.Mutex
Channels map[string]Channel[I]
}
// Opens new channel and registers it with the bus. If a channel exists, it is closed and replaced.
func (bus *Bus[I]) RegisterChannel(key string, channelEntry Channel[I]) chan I {
bus.Lock.Lock()
defer bus.Lock.Unlock()
uch, found := bus.Channels[key]
ch := make(chan I, ChSize)
log.Printf("registering channel key=%s ch=%v\n", key, ch)
channelEntry.SetChannel(ch)
if found {
close(uch.GetChannel())
}
bus.Channels[key] = channelEntry
return channelEntry.GetChannel()
}
// Closes the channel matching the provided key and removes it from the bus
func (bus *Bus[I]) UnregisterChannel(key string) {
bus.Lock.Lock()
defer bus.Lock.Unlock()
uch, found := bus.Channels[key]
if found {
close(uch.GetChannel())
delete(bus.Channels, key)
}
}
// An interface for updates to be sent over an UpdateChannel
type UpdatePacket interface {
// The key to use when marshalling to JSON and interpreting in the client
GetType() string
Clean()
IsEmpty() bool
}
// A channel for sending model updates to the client
type UpdateChannel struct {
ScreenId string
ch chan UpdatePacket
}
func (uch *UpdateChannel) GetChannel() chan UpdatePacket {
return uch.ch
}
func (uch *UpdateChannel) SetChannel(ch chan UpdatePacket) {
uch.ch = ch
}
// Match the screenId to the channel
func (sch *UpdateChannel) Match(screenId string) bool {
if screenId == "" {
return true
}
return screenId == sch.ScreenId
}
// A collection of channels that can transmit updates
type UpdateBus struct {
Bus[UpdatePacket]
}
func (bus *UpdateBus) GetLock() *sync.Mutex {
return bus.Lock
}
// Create a new UpdateBus
func MakeUpdateBus() *UpdateBus {
return &UpdateBus{
Bus[UpdatePacket]{
Lock: &sync.Mutex{},
Channels: make(map[string]Channel[UpdatePacket]),
},
}
}
// Send an update to all channels in the collection
func (bus *UpdateBus) DoUpdate(update UpdatePacket) {
if update == nil || update.IsEmpty() {
return
}
update.Clean()
bus.Lock.Lock()
defer bus.Lock.Unlock()
for key, uch := range bus.Channels {
select {
case uch.GetChannel() <- update:
default:
log.Printf("[error] dropped update on %s updatebus uch key=%s\n", reflect.TypeOf(uch), key)
}
}
}
// Send a model update to only clients that are subscribed to the given screenId
func (bus *UpdateBus) DoScreenUpdate(screenId string, update UpdatePacket) {
if update == nil {
return
}
update.Clean()
bus.Lock.Lock()
defer bus.Lock.Unlock()
for id, uch := range bus.Channels {
if uch.Match(screenId) {
select {
case uch.GetChannel() <- update:
default:
log.Printf("[error] dropped update on updatebus uch id=%s\n", id)
}
}
}
}
// An interface for rpc requests
// This is separate from the RpcPacketType defined in the waveshell/pkg/packet package, as that one is intended for use communicating between wavesrv and waveshell. It is has a different set of required methods.
type RpcPacket interface {
SetReqId(string)
SetTimeoutMs(int)
GetType() string
}
// An interface for rpc responses
// This is separate from the RpcResponsePacketType defined in the waveshell/pkg/packet package, as that one is intended for use communicating between wavesrv and waveshell. It is has a different set of required methods.
type RpcResponse interface {
SetError(string)
GetError() string
GetType() string
}
// A collection of channels that can receive rpc responses
type RpcBus struct {
Bus[RpcResponse]
}
// Create a new RpcBus
func MakeRpcBus() *RpcBus {
return &RpcBus{
Bus[RpcResponse]{
Lock: &sync.Mutex{},
Channels: make(map[string]Channel[RpcResponse]),
},
}
}
// Get the user input channel for the given request id
func (bus *RpcBus) GetRpcChannel(id string) (chan RpcResponse, bool) {
bus.Lock.Lock()
defer bus.Lock.Unlock()
if ch, ok := bus.Channels[id]; ok {
return ch.GetChannel(), ok
}
return nil, false
}
// Implements the Channel interface to allow receiving rpc responses
type RpcChannel struct {
ch chan RpcResponse
}
func (ch *RpcChannel) GetChannel() chan RpcResponse {
return ch.ch
}
func (ch *RpcChannel) SetChannel(newCh chan RpcResponse) {
ch.ch = newCh
}
// This is a no-op, only used to satisfy the Channel interface
func (ch *RpcChannel) Match(string) bool {
return true
}
// Send a user input request to the frontend and wait for a response
func (bus *RpcBus) DoRpc(ctx context.Context, pk RpcPacket) (RpcResponse, error) {
id := uuid.New().String()
ch := bus.RegisterChannel(id, &RpcChannel{})
pk.SetReqId(id)
defer bus.UnregisterChannel(id)
deadline, _ := ctx.Deadline()
pk.SetTimeoutMs(int(time.Until(deadline).Milliseconds()) - 500)
// Send the request to the frontend
mu := MakeUpdatePacket()
mu.AddUpdate(pk)
MainUpdateBus.DoUpdate(mu)
var response RpcResponse
var err error
// prepare to receive response
select {
case resp := <-ch:
response = resp
case <-ctx.Done():
return nil, fmt.Errorf("timed out waiting for rpc response")
}
if response.GetError() != "" {
err = fmt.Errorf(response.GetError())
}
return response, err
}

View File

@ -83,7 +83,6 @@ const WatchScreenPacketStr = "watchscreen"
const FeInputPacketStr = "feinput"
const RemoteInputPacketStr = "remoteinput"
const CmdInputTextPacketStr = "cmdinputtext"
const UserInputResponsePacketStr = "userinputresp"
type FeCommandPacketType struct {
Type string `json:"type"`
@ -156,21 +155,16 @@ type CmdInputTextPacketType struct {
Text utilfn.StrWithPos `json:"text"`
}
type UserInputResponsePacketType struct {
Type string `json:"type"`
RequestId string `json:"requestid"`
Text string `json:"text,omitempty"`
Confirm bool `json:"confirm,omitempty"`
ErrorMsg string `json:"errormsg,omitempty"`
}
func init() {
packet.RegisterPacketType(FeCommandPacketStr, reflect.TypeOf(FeCommandPacketType{}))
packet.RegisterPacketType(WatchScreenPacketStr, reflect.TypeOf(WatchScreenPacketType{}))
packet.RegisterPacketType(FeInputPacketStr, reflect.TypeOf(FeInputPacketType{}))
packet.RegisterPacketType(RemoteInputPacketStr, reflect.TypeOf(RemoteInputPacketType{}))
packet.RegisterPacketType(CmdInputTextPacketStr, reflect.TypeOf(CmdInputTextPacketType{}))
packet.RegisterPacketType(UserInputResponsePacketStr, reflect.TypeOf(UserInputResponsePacketType{}))
}
type PacketType interface {
GetType() string
}
func (*CmdInputTextPacketType) GetType() string {
@ -212,7 +206,3 @@ func MakeRemoteInputPacket() *RemoteInputPacketType {
func (*RemoteInputPacketType) GetType() string {
return RemoteInputPacketStr
}
func (*UserInputResponsePacketType) GetType() string {
return UserInputResponsePacketStr
}

View File

@ -15,8 +15,10 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/wavesrv/pkg/mapqueue"
"github.com/wavetermdev/waveterm/wavesrv/pkg/remote"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbus"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
"github.com/wavetermdev/waveterm/wavesrv/pkg/userinput"
"github.com/wavetermdev/waveterm/wavesrv/pkg/wsshell"
)
@ -35,8 +37,8 @@ type WSState struct {
ClientId string
ConnectTime time.Time
Shell *wsshell.WSShell
UpdateCh chan interface{}
UpdateQueue []interface{}
UpdateCh chan scbus.UpdatePacket
UpdateQueue []any
Authenticated bool
AuthKey string
@ -71,7 +73,7 @@ func (ws *WSState) GetShell() *wsshell.WSShell {
return ws.Shell
}
func (ws *WSState) WriteUpdate(update interface{}) error {
func (ws *WSState) WriteUpdate(update any) error {
shell := ws.GetShell()
if shell == nil {
return fmt.Errorf("cannot write update, empty shell")
@ -103,26 +105,21 @@ func (ws *WSState) WatchScreen(sessionId string, screenId string) {
}
ws.SessionId = sessionId
ws.ScreenId = screenId
ws.UpdateCh = sstore.MainBus.RegisterChannel(ws.ClientId, ws.ScreenId)
ws.UpdateCh = scbus.MainUpdateBus.RegisterChannel(ws.ClientId, &scbus.UpdateChannel{ScreenId: ws.ScreenId})
log.Printf("[ws] watch screen clientid=%s sessionid=%s screenid=%s, updateCh=%v\n", ws.ClientId, sessionId, screenId, ws.UpdateCh)
go ws.RunUpdates(ws.UpdateCh)
}
func (ws *WSState) UnWatchScreen() {
ws.Lock.Lock()
defer ws.Lock.Unlock()
sstore.MainBus.UnregisterChannel(ws.ClientId)
scbus.MainUpdateBus.UnregisterChannel(ws.ClientId)
ws.SessionId = ""
ws.ScreenId = ""
log.Printf("[ws] unwatch screen clientid=%s\n", ws.ClientId)
}
func (ws *WSState) getUpdateCh() chan interface{} {
ws.Lock.Lock()
defer ws.Lock.Unlock()
return ws.UpdateCh
}
func (ws *WSState) RunUpdates(updateCh chan interface{}) {
func (ws *WSState) RunUpdates(updateCh chan scbus.UpdatePacket) {
if updateCh == nil {
panic("invalid nil updateCh passed to RunUpdates")
}
@ -141,7 +138,6 @@ func writeJsonProtected(shell *wsshell.WSShell, update any) {
return
}
log.Printf("[error] in scws RunUpdates WriteJson: %v\n", r)
return
}()
shell.WriteJson(update)
}
@ -155,7 +151,6 @@ func (ws *WSState) ReplaceShell(shell *wsshell.WSShell) {
}
ws.Shell.Conn.Close()
ws.Shell = shell
return
}
// returns all state required to display current UI
@ -170,8 +165,8 @@ func (ws *WSState) handleConnection() error {
connectUpdate.Remotes = remotes
// restore status indicators
connectUpdate.ScreenStatusIndicators, connectUpdate.ScreenNumRunningCommands = sstore.GetCurrentIndicatorState()
mu := &sstore.ModelUpdate{}
sstore.AddUpdate(mu, *connectUpdate)
mu := scbus.MakeUpdatePacket()
mu.AddUpdate(*connectUpdate)
err = ws.Shell.WriteJson(mu)
if err != nil {
return err
@ -282,11 +277,11 @@ func (ws *WSState) processMessage(msgBytes []byte) error {
sstore.ScreenMemSetCmdInputText(cmdInputPk.ScreenId, cmdInputPk.Text, cmdInputPk.SeqNum)
return nil
}
if pk.GetType() == scpacket.UserInputResponsePacketStr {
userInputRespPk := pk.(*scpacket.UserInputResponsePacketType)
uich, ok := sstore.MainBus.GetUserInputChannel(userInputRespPk.RequestId)
if pk.GetType() == userinput.UserInputResponsePacketStr {
userInputRespPk := pk.(*userinput.UserInputResponsePacketType)
uich, ok := scbus.MainRpcBus.GetRpcChannel(userInputRespPk.RequestId)
if !ok {
return fmt.Errorf("received User Input Response with invalid Id (%s): %v\n", userInputRespPk.RequestId, err)
return fmt.Errorf("received User Input Response with invalid Id (%s): %v", userInputRespPk.RequestId, err)
}
select {
case uich <- userInputRespPk:
@ -302,7 +297,7 @@ func (ws *WSState) RunWSRead() {
if shell == nil {
return
}
shell.WriteJson(map[string]interface{}{"type": "hello"}) // let client know we accepted this connection, ignore error
shell.WriteJson(map[string]any{"type": "hello"}) // let client know we accepted this connection, ignore error
for msgBytes := range shell.ReadChan {
err := ws.processMessage(msgBytes)
if err != nil {

View File

@ -22,6 +22,7 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
"github.com/wavetermdev/waveterm/wavesrv/pkg/dbutil"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbus"
)
const HistoryCols = "h.historyid, h.ts, h.userid, h.sessionid, h.screenid, h.lineid, h.haderror, h.cmdstr, h.remoteownerid, h.remoteid, h.remotename, h.ismetacmd, h.linenum, h.exitcode, h.durationms, h.festate, h.tags, h.status"
@ -563,7 +564,7 @@ func GetSessionByName(ctx context.Context, name string) (*SessionType, error) {
// returns sessionId
// if sessionName == "", it will be generated
func InsertSessionWithName(ctx context.Context, sessionName string, activate bool) (*ModelUpdate, error) {
func InsertSessionWithName(ctx context.Context, sessionName string, activate bool) (*scbus.ModelUpdatePacketType, error) {
var newScreen *ScreenType
newSessionId := scbase.GenWaveUUID()
txErr := WithTx(ctx, func(tx *TxWrap) error {
@ -577,7 +578,7 @@ func InsertSessionWithName(ctx context.Context, sessionName string, activate boo
if err != nil {
return err
}
screenUpdateItems := GetUpdateItems[ScreenType](screenUpdate)
screenUpdateItems := scbus.GetUpdateItems[ScreenType](screenUpdate)
if len(screenUpdateItems) < 1 {
return fmt.Errorf("no screen update items")
}
@ -595,11 +596,11 @@ func InsertSessionWithName(ctx context.Context, sessionName string, activate boo
if err != nil {
return nil, err
}
update := &ModelUpdate{}
AddUpdate(update, *session)
AddUpdate(update, *newScreen)
update := scbus.MakeUpdatePacket()
update.AddUpdate(*session)
update.AddUpdate(*newScreen)
if activate {
AddUpdate(update, ActiveSessionIdUpdate(newSessionId))
update.AddUpdate(ActiveSessionIdUpdate(newSessionId))
}
return update, nil
}
@ -687,7 +688,7 @@ func fmtUniqueName(name string, defaultFmtStr string, startIdx int, strs []strin
}
}
func InsertScreen(ctx context.Context, sessionId string, origScreenName string, opts ScreenCreateOpts, activate bool) (*ModelUpdate, error) {
func InsertScreen(ctx context.Context, sessionId string, origScreenName string, opts ScreenCreateOpts, activate bool) (*scbus.ModelUpdatePacketType, error) {
var newScreenId string
txErr := WithTx(ctx, func(tx *TxWrap) error {
query := `SELECT sessionid FROM session WHERE sessionid = ? AND NOT archived`
@ -753,14 +754,14 @@ func InsertScreen(ctx context.Context, sessionId string, origScreenName string,
if err != nil {
return nil, err
}
update := &ModelUpdate{}
AddUpdate(update, *newScreen)
update := scbus.MakeUpdatePacket()
update.AddUpdate(*newScreen)
if activate {
bareSession, err := GetBareSessionById(ctx, sessionId)
if err != nil {
return nil, txErr
}
AddUpdate(update, *bareSession)
update.AddUpdate(*bareSession)
UpdateWithCurrentOpenAICmdInfoChat(newScreenId, update)
}
return update, nil
@ -875,28 +876,30 @@ func GetCmdByScreenId(ctx context.Context, screenId string, lineId string) (*Cmd
})
}
func UpdateWithClearOpenAICmdInfo(screenId string) (*ModelUpdate, error) {
func UpdateWithClearOpenAICmdInfo(screenId string) *scbus.ModelUpdatePacketType {
ScreenMemClearCmdInfoChat(screenId)
return UpdateWithCurrentOpenAICmdInfoChat(screenId, nil)
}
func UpdateWithAddNewOpenAICmdInfoPacket(ctx context.Context, screenId string, pk *packet.OpenAICmdInfoChatMessage) (*ModelUpdate, error) {
func UpdateWithAddNewOpenAICmdInfoPacket(ctx context.Context, screenId string, pk *packet.OpenAICmdInfoChatMessage) *scbus.ModelUpdatePacketType {
ScreenMemAddCmdInfoChatMessage(screenId, pk)
return UpdateWithCurrentOpenAICmdInfoChat(screenId, nil)
}
func UpdateWithCurrentOpenAICmdInfoChat(screenId string, update *ModelUpdate) (*ModelUpdate, error) {
ret := &ModelUpdate{}
AddOpenAICmdInfoChatUpdate(ret, ScreenMemGetCmdInfoChat(screenId).Messages)
return ret, nil
func UpdateWithCurrentOpenAICmdInfoChat(screenId string, update *scbus.ModelUpdatePacketType) *scbus.ModelUpdatePacketType {
if update == nil {
update = scbus.MakeUpdatePacket()
}
update.AddUpdate(OpenAICmdInfoChatUpdate(ScreenMemGetCmdInfoChat(screenId).Messages))
return update
}
func UpdateWithUpdateOpenAICmdInfoPacket(ctx context.Context, screenId string, messageID int, pk *packet.OpenAICmdInfoChatMessage) (*ModelUpdate, error) {
func UpdateWithUpdateOpenAICmdInfoPacket(ctx context.Context, screenId string, messageID int, pk *packet.OpenAICmdInfoChatMessage) (*scbus.ModelUpdatePacketType, error) {
err := ScreenMemUpdateCmdInfoChatMessage(screenId, messageID, pk)
if err != nil {
return nil, err
}
return UpdateWithCurrentOpenAICmdInfoChat(screenId, nil)
return UpdateWithCurrentOpenAICmdInfoChat(screenId, nil), nil
}
func UpdateCmdForRestart(ctx context.Context, ck base.CommandKey, ts int64, cmdPid int, remotePid int, termOpts *TermOpts) error {
@ -913,7 +916,7 @@ func UpdateCmdForRestart(ctx context.Context, ck base.CommandKey, ts int64, cmdP
})
}
func UpdateCmdDoneInfo(ctx context.Context, ck base.CommandKey, donePk *packet.CmdDonePacketType, status string) (*ModelUpdate, error) {
func UpdateCmdDoneInfo(ctx context.Context, ck base.CommandKey, donePk *packet.CmdDonePacketType, status string) (*scbus.ModelUpdatePacketType, error) {
if donePk == nil {
return nil, fmt.Errorf("invalid cmddone packet")
}
@ -947,8 +950,8 @@ func UpdateCmdDoneInfo(ctx context.Context, ck base.CommandKey, donePk *packet.C
return nil, fmt.Errorf("cmd data not found for ck[%s]", ck)
}
update := &ModelUpdate{}
AddUpdate(update, *rtnCmd)
update := scbus.MakeUpdatePacket()
update.AddUpdate(*rtnCmd)
// Update in-memory screen indicator status
var indicator StatusIndicatorLevel
@ -1096,7 +1099,7 @@ func getNextId(ids []string, delId string) string {
return ids[0]
}
func SwitchScreenById(ctx context.Context, sessionId string, screenId string) (*ModelUpdate, error) {
func SwitchScreenById(ctx context.Context, sessionId string, screenId string) (*scbus.ModelUpdatePacketType, error) {
SetActiveSessionId(ctx, sessionId)
txErr := WithTx(ctx, func(tx *TxWrap) error {
query := `SELECT screenid FROM screen WHERE sessionid = ? AND screenid = ?`
@ -1114,12 +1117,12 @@ func SwitchScreenById(ctx context.Context, sessionId string, screenId string) (*
if err != nil {
return nil, err
}
update := &ModelUpdate{}
AddUpdate(update, (ActiveSessionIdUpdate)(sessionId))
AddUpdate(update, *bareSession)
update := scbus.MakeUpdatePacket()
update.AddUpdate(ActiveSessionIdUpdate(sessionId))
update.AddUpdate(*bareSession)
memState := GetScreenMemState(screenId)
if memState != nil {
AddCmdLineUpdate(update, memState.CmdInputText)
update.AddUpdate(CmdLineUpdate(memState.CmdInputText))
UpdateWithCurrentOpenAICmdInfoChat(screenId, update)
// Clear any previous status indicator for this screen
@ -1151,7 +1154,7 @@ func cleanScreenCmds(ctx context.Context, screenId string) error {
return nil
}
func ArchiveScreen(ctx context.Context, sessionId string, screenId string) (UpdatePacket, error) {
func ArchiveScreen(ctx context.Context, sessionId string, screenId string) (scbus.UpdatePacket, error) {
var isActive bool
txErr := WithTx(ctx, func(tx *TxWrap) error {
query := `SELECT screenid FROM screen WHERE sessionid = ? AND screenid = ?`
@ -1188,14 +1191,14 @@ func ArchiveScreen(ctx context.Context, sessionId string, screenId string) (Upda
if err != nil {
return nil, fmt.Errorf("cannot retrive archived screen: %w", err)
}
update := &ModelUpdate{}
AddUpdate(update, *newScreen)
update := scbus.MakeUpdatePacket()
update.AddUpdate(*newScreen)
if isActive {
bareSession, err := GetBareSessionById(ctx, sessionId)
if err != nil {
return nil, err
}
AddUpdate(update, *bareSession)
update.AddUpdate(*bareSession)
}
return update, nil
}
@ -1215,7 +1218,7 @@ func UnArchiveScreen(ctx context.Context, sessionId string, screenId string) err
}
// if sessionDel is passed, we do *not* delete the screen directory (session delete will handle that)
func DeleteScreen(ctx context.Context, screenId string, sessionDel bool, update *ModelUpdate) (*ModelUpdate, error) {
func DeleteScreen(ctx context.Context, screenId string, sessionDel bool, update *scbus.ModelUpdatePacketType) (*scbus.ModelUpdatePacketType, error) {
var sessionId string
var isActive bool
var screenTombstone *ScreenTombstoneType
@ -1276,16 +1279,16 @@ func DeleteScreen(ctx context.Context, screenId string, sessionDel bool, update
GoDeleteScreenDirs(screenId)
}
if update == nil {
update = &ModelUpdate{}
update = scbus.MakeUpdatePacket()
}
AddUpdate(update, *screenTombstone)
AddUpdate(update, ScreenType{SessionId: sessionId, ScreenId: screenId, Remove: true})
update.AddUpdate(*screenTombstone)
update.AddUpdate(ScreenType{SessionId: sessionId, ScreenId: screenId, Remove: true})
if isActive {
bareSession, err := GetBareSessionById(ctx, sessionId)
if err != nil {
return nil, err
}
AddUpdate(update, *bareSession)
update.AddUpdate(*bareSession)
}
return update, nil
}
@ -1516,7 +1519,7 @@ func SetScreenName(ctx context.Context, sessionId string, screenId string, name
return txErr
}
func ArchiveScreenLines(ctx context.Context, screenId string) (*ModelUpdate, error) {
func ArchiveScreenLines(ctx context.Context, screenId string) (*scbus.ModelUpdatePacketType, error) {
txErr := WithTx(ctx, func(tx *TxWrap) error {
query := `SELECT screenid FROM screen WHERE screenid = ?`
if !tx.Exists(query, screenId) {
@ -1535,12 +1538,12 @@ func ArchiveScreenLines(ctx context.Context, screenId string) (*ModelUpdate, err
if err != nil {
return nil, err
}
ret := &ModelUpdate{}
AddUpdate(ret, *screenLines)
ret := scbus.MakeUpdatePacket()
ret.AddUpdate(*screenLines)
return ret, nil
}
func DeleteScreenLines(ctx context.Context, screenId string) (*ModelUpdate, error) {
func DeleteScreenLines(ctx context.Context, screenId string) (*scbus.ModelUpdatePacketType, error) {
var lineIds []string
txErr := WithTx(ctx, func(tx *TxWrap) error {
query := `SELECT lineid FROM line
@ -1579,9 +1582,9 @@ func DeleteScreenLines(ctx context.Context, screenId string) (*ModelUpdate, erro
}
screenLines.Lines = append(screenLines.Lines, line)
}
ret := &ModelUpdate{}
AddUpdate(ret, *screen)
AddUpdate(ret, *screenLines)
ret := scbus.MakeUpdatePacket()
ret.AddUpdate(*screen)
ret.AddUpdate(*screenLines)
return ret, nil
}
@ -1629,11 +1632,11 @@ func ScreenReset(ctx context.Context, screenId string) ([]*RemoteInstance, error
})
}
func DeleteSession(ctx context.Context, sessionId string) (UpdatePacket, error) {
func DeleteSession(ctx context.Context, sessionId string) (scbus.UpdatePacket, error) {
var newActiveSessionId string
var screenIds []string
var sessionTombstone *SessionTombstoneType
update := &ModelUpdate{}
update := scbus.MakeUpdatePacket()
txErr := WithTx(ctx, func(tx *TxWrap) error {
bareSession, err := GetBareSessionById(tx.Context(), sessionId)
if err != nil {
@ -1668,11 +1671,11 @@ func DeleteSession(ctx context.Context, sessionId string) (UpdatePacket, error)
}
GoDeleteScreenDirs(screenIds...)
if newActiveSessionId != "" {
AddUpdate(update, (ActiveSessionIdUpdate)(newActiveSessionId))
update.AddUpdate(ActiveSessionIdUpdate(newActiveSessionId))
}
AddUpdate(update, SessionType{SessionId: sessionId, Remove: true})
update.AddUpdate(SessionType{SessionId: sessionId, Remove: true})
if sessionTombstone != nil {
AddUpdate(update, *sessionTombstone)
update.AddUpdate(*sessionTombstone)
}
return update, nil
}
@ -1699,7 +1702,7 @@ func fixActiveSessionId(ctx context.Context) (string, error) {
return newActiveSessionId, nil
}
func ArchiveSession(ctx context.Context, sessionId string) (*ModelUpdate, error) {
func ArchiveSession(ctx context.Context, sessionId string) (*scbus.ModelUpdatePacketType, error) {
if sessionId == "" {
return nil, fmt.Errorf("invalid blank sessionid")
}
@ -1723,17 +1726,17 @@ func ArchiveSession(ctx context.Context, sessionId string) (*ModelUpdate, error)
return nil, txErr
}
bareSession, _ := GetBareSessionById(ctx, sessionId)
update := &ModelUpdate{}
update := scbus.MakeUpdatePacket()
if bareSession != nil {
AddUpdate(update, *bareSession)
update.AddUpdate(*bareSession)
}
if newActiveSessionId != "" {
AddUpdate(update, (ActiveSessionIdUpdate)(newActiveSessionId))
update.AddUpdate(ActiveSessionIdUpdate(newActiveSessionId))
}
return update, nil
}
func UnArchiveSession(ctx context.Context, sessionId string, activate bool) (*ModelUpdate, error) {
func UnArchiveSession(ctx context.Context, sessionId string, activate bool) (*scbus.ModelUpdatePacketType, error) {
if sessionId == "" {
return nil, fmt.Errorf("invalid blank sessionid")
}
@ -1759,13 +1762,13 @@ func UnArchiveSession(ctx context.Context, sessionId string, activate bool) (*Mo
return nil, txErr
}
bareSession, _ := GetBareSessionById(ctx, sessionId)
update := &ModelUpdate{}
update := scbus.MakeUpdatePacket()
if bareSession != nil {
AddUpdate(update, *bareSession)
update.AddUpdate(*bareSession)
}
if activate {
AddUpdate(update, (ActiveSessionIdUpdate)(sessionId))
update.AddUpdate(ActiveSessionIdUpdate(sessionId))
}
return update, nil
}

View File

@ -18,6 +18,7 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/cirfile"
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbus"
)
func CreateCmdPtyFile(ctx context.Context, screenId string, lineId string, maxSize int64) error {
@ -61,7 +62,7 @@ func ClearCmdPtyFile(ctx context.Context, screenId string, lineId string) error
return nil
}
func AppendToCmdPtyBlob(ctx context.Context, screenId string, lineId string, data []byte, pos int64) (*PtyDataUpdate, error) {
func AppendToCmdPtyBlob(ctx context.Context, screenId string, lineId string, data []byte, pos int64) (*scbus.PtyDataUpdatePacketType, error) {
if screenId == "" {
return nil, fmt.Errorf("cannot append to PtyBlob, screenid is not set")
}
@ -82,13 +83,13 @@ func AppendToCmdPtyBlob(ctx context.Context, screenId string, lineId string, dat
return nil, err
}
data64 := base64.StdEncoding.EncodeToString(data)
update := &PtyDataUpdate{
update := scbus.MakePtyDataUpdate(&scbus.PtyDataUpdate{
ScreenId: screenId,
LineId: lineId,
PtyPos: pos,
PtyData64: data64,
PtyDataLen: int64(len(data)),
}
})
err = MaybeInsertPtyPosUpdate(ctx, screenId, lineId)
if err != nil {
// just log

View File

@ -27,6 +27,7 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
"github.com/wavetermdev/waveterm/wavesrv/pkg/dbutil"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbus"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
_ "github.com/mattn/go-sqlite3"
@ -338,7 +339,7 @@ func (cdata *ClientData) Clean() *ClientData {
return &rtn
}
func (ClientData) UpdateType() string {
func (ClientData) GetType() string {
return "clientdata"
}
@ -357,7 +358,7 @@ type SessionType struct {
Remove bool `json:"remove,omitempty"`
}
func (SessionType) UpdateType() string {
func (SessionType) GetType() string {
return "session"
}
@ -376,7 +377,7 @@ type SessionTombstoneType struct {
func (SessionTombstoneType) UseDBMap() {}
func (SessionTombstoneType) UpdateType() string {
func (SessionTombstoneType) GetType() string {
return "sessiontombstone"
}
@ -449,7 +450,7 @@ type ScreenLinesType struct {
func (ScreenLinesType) UseDBMap() {}
func (ScreenLinesType) UpdateType() string {
func (ScreenLinesType) GetType() string {
return "screenlines"
}
@ -548,22 +549,22 @@ func (s *ScreenType) FromMap(m map[string]interface{}) bool {
return true
}
func (ScreenType) UpdateType() string {
func (ScreenType) GetType() string {
return "screen"
}
func AddScreenUpdate(update *ModelUpdate, newScreen *ScreenType) {
func AddScreenUpdate(update *scbus.ModelUpdatePacketType, newScreen *ScreenType) {
if newScreen == nil {
return
}
screenUpdates := GetUpdateItems[ScreenType](update)
screenUpdates := scbus.GetUpdateItems[ScreenType](update)
for _, screenUpdate := range screenUpdates {
if screenUpdate.ScreenId == newScreen.ScreenId {
screenUpdate = newScreen
return
}
}
AddUpdate(update, newScreen)
update.AddUpdate(newScreen)
}
type ScreenTombstoneType struct {
@ -576,7 +577,7 @@ type ScreenTombstoneType struct {
func (ScreenTombstoneType) UseDBMap() {}
func (ScreenTombstoneType) UpdateType() string {
func (ScreenTombstoneType) GetType() string {
return "screentombstone"
}
@ -1060,7 +1061,7 @@ func (state RemoteRuntimeState) ExpandHomeDir(pathStr string) (string, error) {
return path.Join(homeDir, pathStr[2:]), nil
}
func (RemoteRuntimeState) UpdateType() string {
func (RemoteRuntimeState) GetType() string {
return "remote"
}
@ -1128,7 +1129,7 @@ type CmdType struct {
Restarted bool `json:"restarted,omitempty"` // not persisted to DB
}
func (CmdType) UpdateType() string {
func (CmdType) GetType() string {
return "cmd"
}
@ -1479,7 +1480,7 @@ func SetReleaseInfo(ctx context.Context, releaseInfo ReleaseInfoType) error {
}
// Sets the in-memory status indicator for the given screenId to the given value and adds it to the ModelUpdate. By default, the active screen will be ignored when updating status. To force a status update for the active screen, set force=true.
func SetStatusIndicatorLevel_Update(ctx context.Context, update *ModelUpdate, screenId string, level StatusIndicatorLevel, force bool) error {
func SetStatusIndicatorLevel_Update(ctx context.Context, update *scbus.ModelUpdatePacketType, screenId string, level StatusIndicatorLevel, force bool) error {
var newStatus StatusIndicatorLevel
if force {
// Force the update and set the new status to the given level, regardless of the current status or the active screen
@ -1509,7 +1510,7 @@ func SetStatusIndicatorLevel_Update(ctx context.Context, update *ModelUpdate, sc
}
}
AddUpdate(update, ScreenStatusIndicatorType{
update.AddUpdate(ScreenStatusIndicatorType{
ScreenId: screenId,
Status: newStatus,
})
@ -1518,17 +1519,17 @@ func SetStatusIndicatorLevel_Update(ctx context.Context, update *ModelUpdate, sc
// Sets the in-memory status indicator for the given screenId to the given value and pushes the new value to the FE
func SetStatusIndicatorLevel(ctx context.Context, screenId string, level StatusIndicatorLevel, force bool) error {
update := &ModelUpdate{}
update := scbus.MakeUpdatePacket()
err := SetStatusIndicatorLevel_Update(ctx, update, screenId, level, false)
if err != nil {
return err
}
MainBus.SendUpdate(update)
scbus.MainUpdateBus.DoUpdate(update)
return nil
}
// Resets the in-memory status indicator for the given screenId to StatusIndicatorLevel_None and adds it to the ModelUpdate
func ResetStatusIndicator_Update(update *ModelUpdate, screenId string) error {
func ResetStatusIndicator_Update(update *scbus.ModelUpdatePacketType, screenId string) error {
// We do not need to set context when resetting the status indicator because we will not need to call the DB
return SetStatusIndicatorLevel_Update(context.TODO(), update, screenId, StatusIndicatorLevel_None, true)
}
@ -1539,9 +1540,9 @@ func ResetStatusIndicator(screenId string) error {
return SetStatusIndicatorLevel(context.TODO(), screenId, StatusIndicatorLevel_None, true)
}
func IncrementNumRunningCmds_Update(update *ModelUpdate, screenId string, delta int) {
func IncrementNumRunningCmds_Update(update *scbus.ModelUpdatePacketType, screenId string, delta int) {
newNum := ScreenMemIncrementNumRunningCommands(screenId, delta)
AddUpdate(update, ScreenNumRunningCommandsType{
update.AddUpdate(ScreenNumRunningCommandsType{
ScreenId: screenId,
Num: newNum,
})
@ -1549,7 +1550,7 @@ func IncrementNumRunningCmds_Update(update *ModelUpdate, screenId string, delta
}
func IncrementNumRunningCmds(screenId string, delta int) {
update := &ModelUpdate{}
update := scbus.MakeUpdatePacket()
IncrementNumRunningCmds_Update(update, screenId, delta)
MainBus.SendUpdate(update)
scbus.MainUpdateBus.DoUpdate(update)
}

View File

@ -1,247 +0,0 @@
// Copyright 2023, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package sstore
import (
"context"
"encoding/json"
"fmt"
"log"
"sync"
"time"
"github.com/google/uuid"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
)
var MainBus *UpdateBus = MakeUpdateBus()
const PtyDataUpdateStr = "pty"
const ModelUpdateStr = "model"
const UpdateChSize = 100
type UpdatePacket interface {
// The key to use when marshalling to JSON and interpreting in the client
UpdateType() string
Clean()
}
type PtyDataUpdate struct {
ScreenId string `json:"screenid,omitempty"`
LineId string `json:"lineid,omitempty"`
RemoteId string `json:"remoteid,omitempty"`
PtyPos int64 `json:"ptypos"`
PtyData64 string `json:"ptydata64"`
PtyDataLen int64 `json:"ptydatalen"`
}
func (*PtyDataUpdate) UpdateType() string {
return PtyDataUpdateStr
}
func (pdu *PtyDataUpdate) Clean() {}
// A collection of independent model updates to be sent to the client. Will be evaluated in order on the client.
type ModelUpdate []*ModelUpdateItem
func (*ModelUpdate) UpdateType() string {
return ModelUpdateStr
}
func (mu *ModelUpdate) MarshalJSON() ([]byte, error) {
rtn := make([]map[string]any, 0)
for _, u := range *mu {
m := make(map[string]any)
m[(*u).UpdateType()] = u
rtn = append(rtn, m)
}
return json.Marshal(rtn)
}
// An interface for all model updates
type ModelUpdateItem interface {
// The key to use when marshalling to JSON and interpreting in the client
UpdateType() string
}
// Clean the ClientData in an update, if present
func (update *ModelUpdate) Clean() {
if update == nil {
return
}
clientDataUpdates := GetUpdateItems[ClientData](update)
if len(clientDataUpdates) > 0 {
lastUpdate := clientDataUpdates[len(clientDataUpdates)-1]
lastUpdate.Clean()
}
}
func (update *ModelUpdate) append(item *ModelUpdateItem) {
*update = append(*update, item)
}
// Add a collection of model updates to the update
func AddUpdate(update *ModelUpdate, item ...ModelUpdateItem) {
for _, i := range item {
update.append(&i)
}
}
// Returns the items in the update that are of type I
func GetUpdateItems[I ModelUpdateItem](update *ModelUpdate) []*I {
ret := make([]*I, 0)
for _, item := range *update {
if i, ok := (*item).(I); ok {
ret = append(ret, &i)
}
}
return ret
}
type UpdateChannel struct {
ScreenId string
ClientId string
Ch chan interface{}
}
func (uch UpdateChannel) Match(screenId string) bool {
if screenId == "" {
return true
}
return screenId == uch.ScreenId
}
type UpdateBus struct {
Lock *sync.Mutex
Channels map[string]UpdateChannel
UserInputCh map[string](chan *scpacket.UserInputResponsePacketType)
}
func MakeUpdateBus() *UpdateBus {
return &UpdateBus{
Lock: &sync.Mutex{},
Channels: make(map[string]UpdateChannel),
UserInputCh: make(map[string](chan *scpacket.UserInputResponsePacketType)),
}
}
// always returns a new channel
func (bus *UpdateBus) RegisterChannel(clientId string, screenId string) chan interface{} {
bus.Lock.Lock()
defer bus.Lock.Unlock()
uch, found := bus.Channels[clientId]
if found {
close(uch.Ch)
uch.ScreenId = screenId
uch.Ch = make(chan interface{}, UpdateChSize)
} else {
uch = UpdateChannel{
ClientId: clientId,
ScreenId: screenId,
Ch: make(chan interface{}, UpdateChSize),
}
}
bus.Channels[clientId] = uch
return uch.Ch
}
func (bus *UpdateBus) UnregisterChannel(clientId string) {
bus.Lock.Lock()
defer bus.Lock.Unlock()
uch, found := bus.Channels[clientId]
if found {
close(uch.Ch)
delete(bus.Channels, clientId)
}
}
func (bus *UpdateBus) SendUpdate(update UpdatePacket) {
if update == nil {
return
}
update.Clean()
bus.Lock.Lock()
defer bus.Lock.Unlock()
for _, uch := range bus.Channels {
select {
case uch.Ch <- update:
default:
log.Printf("[error] dropped update on updatebus uch clientid=%s\n", uch.ClientId)
}
}
}
func (bus *UpdateBus) SendScreenUpdate(screenId string, update UpdatePacket) {
if update == nil {
return
}
update.Clean()
bus.Lock.Lock()
defer bus.Lock.Unlock()
for _, uch := range bus.Channels {
if uch.Match(screenId) {
select {
case uch.Ch <- update:
default:
log.Printf("[error] dropped update on updatebus uch clientid=%s\n", uch.ClientId)
}
}
}
}
func (bus *UpdateBus) registerUserInputChannel() (string, chan *scpacket.UserInputResponsePacketType) {
bus.Lock.Lock()
defer bus.Lock.Unlock()
id := uuid.New().String()
uich := make(chan *scpacket.UserInputResponsePacketType, 1)
bus.UserInputCh[id] = uich
return id, uich
}
func (bus *UpdateBus) unregisterUserInputChannel(id string) {
bus.Lock.Lock()
defer bus.Lock.Unlock()
delete(bus.UserInputCh, id)
}
func (bus *UpdateBus) GetUserInputChannel(id string) (chan *scpacket.UserInputResponsePacketType, bool) {
bus.Lock.Lock()
defer bus.Lock.Unlock()
uich, ok := bus.UserInputCh[id]
return uich, ok
}
func (bus *UpdateBus) GetUserInput(ctx context.Context, userInputRequest *UserInputRequestType) (*scpacket.UserInputResponsePacketType, error) {
id, uich := bus.registerUserInputChannel()
defer bus.unregisterUserInputChannel(id)
userInputRequest.RequestId = id
deadline, _ := ctx.Deadline()
userInputRequest.TimeoutMs = int(time.Until(deadline).Milliseconds()) - 500
update := &ModelUpdate{}
AddUpdate(update, *userInputRequest)
bus.SendUpdate(update)
var response *scpacket.UserInputResponsePacketType
var err error
// prepare to receive response
select {
case resp := <-uich:
response = resp
case <-ctx.Done():
return nil, fmt.Errorf("Timed out waiting for user input")
}
if response.ErrorMsg != "" {
err = fmt.Errorf(response.ErrorMsg)
}
return response, err
}

View File

@ -1,3 +1,6 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package sstore
import (
@ -5,11 +8,12 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbus"
)
type ActiveSessionIdUpdate string
func (ActiveSessionIdUpdate) UpdateType() string {
func (ActiveSessionIdUpdate) GetType() string {
return "activesessionid"
}
@ -18,11 +22,11 @@ type LineUpdate struct {
Cmd CmdType `json:"cmd,omitempty"`
}
func (LineUpdate) UpdateType() string {
func (LineUpdate) GetType() string {
return "line"
}
func AddLineUpdate(update *ModelUpdate, newLine *LineType, newCmd *CmdType) {
func AddLineUpdate(update *scbus.ModelUpdatePacketType, newLine *LineType, newCmd *CmdType) {
if newLine == nil {
return
}
@ -32,19 +36,15 @@ func AddLineUpdate(update *ModelUpdate, newLine *LineType, newCmd *CmdType) {
if newCmd != nil {
newLineUpdate.Cmd = *newCmd
}
AddUpdate(update, newLineUpdate)
update.AddUpdate(newLineUpdate)
}
type CmdLineUpdate utilfn.StrWithPos
func (CmdLineUpdate) UpdateType() string {
func (CmdLineUpdate) GetType() string {
return "cmdline"
}
func AddCmdLineUpdate(update *ModelUpdate, cmdLine utilfn.StrWithPos) {
AddUpdate(update, CmdLineUpdate(cmdLine))
}
type InfoMsgType struct {
InfoTitle string `json:"infotitle"`
InfoError string `json:"infoerror,omitempty"`
@ -57,21 +57,21 @@ type InfoMsgType struct {
TimeoutMs int64 `json:"timeoutms,omitempty"`
}
func (InfoMsgType) UpdateType() string {
func (InfoMsgType) GetType() string {
return "info"
}
func InfoMsgUpdate(infoMsgFmt string, args ...interface{}) *ModelUpdate {
func InfoMsgUpdate(infoMsgFmt string, args ...interface{}) *scbus.ModelUpdatePacketType {
msg := fmt.Sprintf(infoMsgFmt, args...)
ret := &ModelUpdate{}
ret := scbus.MakeUpdatePacket()
newInfoUpdate := InfoMsgType{InfoMsg: msg}
AddUpdate(ret, newInfoUpdate)
ret.AddUpdate(newInfoUpdate)
return ret
}
// only sets InfoError if InfoError is not already set
func AddInfoMsgUpdateError(update *ModelUpdate, errStr string) {
infoUpdates := GetUpdateItems[InfoMsgType](update)
func AddInfoMsgUpdateError(update *scbus.ModelUpdatePacketType, errStr string) {
infoUpdates := scbus.GetUpdateItems[InfoMsgType](update)
if len(infoUpdates) > 0 {
lastUpdate := infoUpdates[len(infoUpdates)-1]
@ -80,13 +80,13 @@ func AddInfoMsgUpdateError(update *ModelUpdate, errStr string) {
return
}
} else {
AddUpdate(update, InfoMsgType{InfoError: errStr})
update.AddUpdate(InfoMsgType{InfoError: errStr})
}
}
type ClearInfoUpdate bool
func (ClearInfoUpdate) UpdateType() string {
func (ClearInfoUpdate) GetType() string {
return "clearinfo"
}
@ -98,20 +98,16 @@ type HistoryInfoType struct {
Show bool `json:"show"`
}
func (HistoryInfoType) UpdateType() string {
func (HistoryInfoType) GetType() string {
return "history"
}
type InteractiveUpdate bool
func (InteractiveUpdate) UpdateType() string {
func (InteractiveUpdate) GetType() string {
return "interactive"
}
func AddInteractiveUpdate(update *ModelUpdate, interactive bool) {
AddUpdate(update, InteractiveUpdate(interactive))
}
type ConnectUpdate struct {
Sessions []*SessionType `json:"sessions,omitempty"`
Screens []*ScreenType `json:"screens,omitempty"`
@ -121,7 +117,7 @@ type ConnectUpdate struct {
ActiveSessionId string `json:"activesessionid,omitempty"`
}
func (ConnectUpdate) UpdateType() string {
func (ConnectUpdate) GetType() string {
return "connect"
}
@ -131,7 +127,7 @@ type MainViewUpdate struct {
BookmarksView *BookmarksUpdate `json:"bookmarksview,omitempty"`
}
func (MainViewUpdate) UpdateType() string {
func (MainViewUpdate) GetType() string {
return "mainview"
}
@ -140,15 +136,15 @@ type BookmarksUpdate struct {
SelectedBookmark string `json:"selectedbookmark,omitempty"`
}
func (BookmarksUpdate) UpdateType() string {
func (BookmarksUpdate) GetType() string {
return "bookmarks"
}
func AddBookmarksUpdate(update *ModelUpdate, bookmarks []*BookmarkType, selectedBookmark *string) {
func AddBookmarksUpdate(update *scbus.ModelUpdatePacketType, bookmarks []*BookmarkType, selectedBookmark *string) {
if selectedBookmark == nil {
AddUpdate(update, BookmarksUpdate{Bookmarks: bookmarks})
update.AddUpdate(BookmarksUpdate{Bookmarks: bookmarks})
} else {
AddUpdate(update, BookmarksUpdate{Bookmarks: bookmarks, SelectedBookmark: *selectedBookmark})
update.AddUpdate(BookmarksUpdate{Bookmarks: bookmarks, SelectedBookmark: *selectedBookmark})
}
}
@ -177,20 +173,16 @@ type RemoteViewType struct {
RemoteEdit *RemoteEditType `json:"remoteedit,omitempty"`
}
func (RemoteViewType) UpdateType() string {
func (RemoteViewType) GetType() string {
return "remoteview"
}
type OpenAICmdInfoChatUpdate []*packet.OpenAICmdInfoChatMessage
func (OpenAICmdInfoChatUpdate) UpdateType() string {
func (OpenAICmdInfoChatUpdate) GetType() string {
return "openaicmdinfochat"
}
func AddOpenAICmdInfoChatUpdate(update *ModelUpdate, chatMessages []*packet.OpenAICmdInfoChatMessage) {
AddUpdate(update, OpenAICmdInfoChatUpdate(chatMessages))
}
type AlertMessageType struct {
Title string `json:"title,omitempty"`
Message string `json:"message"`
@ -198,7 +190,7 @@ type AlertMessageType struct {
Markdown bool `json:"markdown,omitempty"`
}
func (AlertMessageType) UpdateType() string {
func (AlertMessageType) GetType() string {
return "alertmessage"
}
@ -207,7 +199,7 @@ type ScreenStatusIndicatorType struct {
Status StatusIndicatorLevel `json:"status"`
}
func (ScreenStatusIndicatorType) UpdateType() string {
func (ScreenStatusIndicatorType) GetType() string {
return "screenstatusindicator"
}
@ -216,19 +208,6 @@ type ScreenNumRunningCommandsType struct {
Num int `json:"num"`
}
func (ScreenNumRunningCommandsType) UpdateType() string {
func (ScreenNumRunningCommandsType) GetType() string {
return "screennumrunningcommands"
}
type UserInputRequestType struct {
RequestId string `json:"requestid"`
QueryText string `json:"querytext"`
ResponseType string `json:"responsetype"`
Title string `json:"title"`
Markdown bool `json:"markdown"`
TimeoutMs int `json:"timeoutms"`
}
func (UserInputRequestType) UpdateType() string {
return "userinputrequest"
}

View File

@ -0,0 +1,77 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
// Provides a mechanism for the backend to request user input from the frontend.
package userinput
import (
"context"
"fmt"
"reflect"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbus"
)
// An RpcPacket for requesting user input from the client
type UserInputRequestType struct {
RequestId string `json:"requestid"`
QueryText string `json:"querytext"`
ResponseType string `json:"responsetype"`
Title string `json:"title"`
Markdown bool `json:"markdown"`
TimeoutMs int `json:"timeoutms"`
}
func (*UserInputRequestType) GetType() string {
return "userinputrequest"
}
func (req *UserInputRequestType) SetReqId(reqId string) {
req.RequestId = reqId
}
func (req *UserInputRequestType) SetTimeoutMs(timeoutMs int) {
req.TimeoutMs = timeoutMs
}
const UserInputResponsePacketStr = "userinputresp"
// An RpcResponse for user input requests
type UserInputResponsePacketType struct {
Type string `json:"type"`
RequestId string `json:"requestid"`
Text string `json:"text,omitempty"`
Confirm bool `json:"confirm,omitempty"`
ErrorMsg string `json:"errormsg,omitempty"`
}
func (*UserInputResponsePacketType) GetType() string {
return UserInputResponsePacketStr
}
func (pk *UserInputResponsePacketType) GetError() string {
return pk.ErrorMsg
}
func (pk *UserInputResponsePacketType) SetError(err string) {
pk.ErrorMsg = err
}
// Send a user input request to the frontend and wait for a response
func GetUserInput(ctx context.Context, bus *scbus.RpcBus, userInputRequest *UserInputRequestType) (*UserInputResponsePacketType, error) {
resp, err := scbus.MainRpcBus.DoRpc(ctx, userInputRequest)
if err != nil {
return nil, err
}
if ret, ok := resp.(*UserInputResponsePacketType); !ok {
return nil, fmt.Errorf("unexpected response type: %v", reflect.TypeOf(resp))
} else {
return ret, nil
}
}
func init() {
// Register the user input request packet type
packet.RegisterPacketType(UserInputResponsePacketStr, reflect.TypeOf(UserInputResponsePacketType{}))
}