diff --git a/public/themes/term-default.css b/public/themes/term-default.css index 9e3fbc31d..6f2f990cc 100644 --- a/public/themes/term-default.css +++ b/public/themes/term-default.css @@ -27,6 +27,6 @@ --term-cmdtext: #ffffff; --term-foreground: #d3d7cf; --term-background: #000000; - --term-selection-background: #ffffff90; + --term-selection-background: #ffffff60; --term-cursor-accent: #000000; } diff --git a/waveshell/main-waveshell.go b/waveshell/main-waveshell.go index d032b194a..8e1228852 100644 --- a/waveshell/main-waveshell.go +++ b/waveshell/main-waveshell.go @@ -135,6 +135,7 @@ func main() { base.ProcessType = base.ProcessType_WaveShellServer wlog.GlobalSubsystem = base.ProcessType_WaveShellServer base.InitDebugLog("server") + base.EnsureRcFilesDir() rtnCode, err := server.RunServer() if err != nil { fmt.Fprintf(os.Stderr, "[error] %v\n", err) diff --git a/waveshell/pkg/base/base.go b/waveshell/pkg/base/base.go index a571f189e..a66d4c4ab 100644 --- a/waveshell/pkg/base/base.go +++ b/waveshell/pkg/base/base.go @@ -12,7 +12,6 @@ import ( "os" "os/exec" "path" - "path/filepath" "strings" "sync" @@ -72,11 +71,11 @@ func IsWaveSrv() bool { return ProcessType == ProcessType_WaveSrv } -func MakeCommandKey(sessionId string, cmdId string) CommandKey { - if sessionId == "" && cmdId == "" { +func MakeCommandKey(screenId string, lineId string) CommandKey { + if screenId == "" && lineId == "" { return CommandKey("") } - return CommandKey(fmt.Sprintf("%s/%s", sessionId, cmdId)) + return CommandKey(fmt.Sprintf("%s/%s", screenId, lineId)) } func (ckey CommandKey) IsEmpty() bool { @@ -200,51 +199,6 @@ func GetMShellHomeDir() string { return ExpandHomeDir(DefaultMShellHome) } -func GetCommandFileNames(ck CommandKey) (*CommandFileNames, error) { - if err := ck.Validate("ck"); err != nil { - return nil, fmt.Errorf("cannot get command files: %w", err) - } - sessionId, cmdId := ck.Split() - sdir, err := EnsureSessionDir(sessionId) - if err != nil { - return nil, err - } - base := path.Join(sdir, cmdId) - return &CommandFileNames{ - PtyOutFile: base + ".ptyout", - StdinFifo: base + ".stdin", - RunnerOutFile: base + ".runout", - }, nil -} - -func CleanUpCmdFiles(sessionId string, cmdId string) error { - if cmdId == "" { - return fmt.Errorf("bad cmdid, cannot clean up") - } - sdir, err := EnsureSessionDir(sessionId) - if err != nil { - return err - } - cmdFileGlob := path.Join(sdir, cmdId+".*") - matches, err := filepath.Glob(cmdFileGlob) - if err != nil { - return err - } - for _, file := range matches { - rmErr := os.Remove(file) - if err == nil && rmErr != nil { - err = rmErr - } - } - return err -} - -func GetSessionsDir() string { - mhome := GetMShellHomeDir() - sdir := path.Join(mhome, SessionsDirBaseName) - return sdir -} - func EnsureRcFilesDir() (string, error) { mhome := GetMShellHomeDir() dirName := path.Join(mhome, RcFilesDirBaseName) @@ -255,19 +209,6 @@ func EnsureRcFilesDir() (string, error) { return dirName, nil } -func EnsureSessionDir(sessionId string) (string, error) { - if sessionId == "" { - return "", fmt.Errorf("Bad sessionid, cannot be empty") - } - mhome := GetMShellHomeDir() - sdir := path.Join(mhome, SessionsDirBaseName, sessionId) - err := CacheEnsureDir(sdir, sessionId, 0777, "mshell session dir") - if err != nil { - return "", err - } - return sdir, nil -} - func GetMShellPath() (string, error) { msPath := os.Getenv(MShellPathVarName) // use MSHELL_PATH if msPath != "" { @@ -282,11 +223,6 @@ func GetMShellPath() (string, error) { return exec.LookPath(DefaultMShellName) // standard path lookup for 'mshell' } -func GetMShellSessionsDir() (string, error) { - mhome := GetMShellHomeDir() - return path.Join(mhome, SessionsDirBaseName), nil -} - func ExpandHomeDir(pathStr string) string { if pathStr != "~" && !strings.HasPrefix(pathStr, "~/") { return pathStr @@ -315,22 +251,6 @@ func GoArchOptFile(version string, goos string, goarch string) string { return fmt.Sprintf(path.Join(installBinDir, binBaseName)) } -func MShellBinaryFromOptDir(version string, goos string, goarch string) (io.ReadCloser, error) { - if !ValidGoArch(goos, goarch) { - return nil, fmt.Errorf("invalid goos/goarch combination: %s/%s", goos, goarch) - } - versionStr := semver.MajorMinor(version) - if versionStr == "" { - return nil, fmt.Errorf("invalid mshell version: %q", version) - } - fileName := GoArchOptFile(version, goos, goarch) - fd, err := os.Open(fileName) - if err != nil { - return nil, fmt.Errorf("cannot open mshell binary %q: %v", fileName, err) - } - return fd, nil -} - func GetRemoteId() (string, error) { mhome := GetMShellHomeDir() homeInfo, err := os.Stat(mhome) diff --git a/waveshell/pkg/packet/packet.go b/waveshell/pkg/packet/packet.go index d204075eb..b081368f9 100644 --- a/waveshell/pkg/packet/packet.go +++ b/waveshell/pkg/packet/packet.go @@ -12,6 +12,7 @@ import ( "fmt" "io" "io/fs" + "log" "os" "reflect" "sync" @@ -63,6 +64,7 @@ const ( FileStatPacketStr = "filestat" LogPacketStr = "log" // logging packet (sent from waveshell back to server) ShellStatePacketStr = "shellstate" + RpcInputPacketStr = "rpcinput" // rpc-followup OpenAIPacketStr = "openai" // other OpenAICloudReqStr = "openai-cloudreq" @@ -116,6 +118,7 @@ func init() { TypeStrToFactory[LogPacketStr] = reflect.TypeOf(LogPacketType{}) TypeStrToFactory[ShellStatePacketStr] = reflect.TypeOf(ShellStatePacketType{}) TypeStrToFactory[FileStatPacketStr] = reflect.TypeOf(FileStatPacketType{}) + TypeStrToFactory[RpcInputPacketStr] = reflect.TypeOf(RpcInputPacketType{}) var _ RpcPacketType = (*RunPacketType)(nil) var _ RpcPacketType = (*GetCmdPacketType)(nil) @@ -134,6 +137,9 @@ func init() { var _ RpcResponsePacketType = (*WriteFileDonePacketType)(nil) var _ RpcResponsePacketType = (*ShellStatePacketType)(nil) + var _ RpcFollowUpPacketType = (*FileDataPacketType)(nil) + var _ RpcFollowUpPacketType = (*RpcInputPacketType)(nil) + var _ CommandPacketType = (*DataPacketType)(nil) var _ CommandPacketType = (*DataAckPacketType)(nil) var _ CommandPacketType = (*CmdDonePacketType)(nil) @@ -166,6 +172,26 @@ func MakePingPacket() *PingPacketType { return &PingPacketType{Type: PingPacketStr} } +type RpcInputPacketType struct { + Type string `json:"type"` + ReqId string `json:"reqid"` + Data []byte `json:"data"` +} + +func (*RpcInputPacketType) GetType() string { + return RpcInputPacketStr +} + +func (p *RpcInputPacketType) GetAssociatedReqId() string { + return p.ReqId +} + +func MakeRpcInputPacket(reqId string) *RpcInputPacketType { + return &RpcInputPacketType{Type: RpcInputPacketStr, ReqId: reqId} +} + +// these packets can travel either direction +// so it is both a RpcResponsePacketType and an RpcFollowUpPacketType type FileDataPacketType struct { Type string `json:"type"` RespId string `json:"respid"` @@ -185,6 +211,10 @@ func MakeFileDataPacket(reqId string) *FileDataPacketType { } } +func (p *FileDataPacketType) GetAssociatedReqId() string { + return p.RespId +} + func (p *FileDataPacketType) GetResponseId() string { return p.RespId } @@ -976,6 +1006,12 @@ type CommandPacketType interface { GetCK() base.CommandKey } +// RpcPackets initiate an Rpc. these can be part of the data passed back and forth +type RpcFollowUpPacketType interface { + GetType() string + GetAssociatedReqId() string +} + type ModelUpdatePacketType struct { Type string `json:"type"` Updates []any `json:"updates"` @@ -1178,6 +1214,14 @@ func (sender *PacketSender) SendPacketCtx(ctx context.Context, pk PacketType) er } func (sender *PacketSender) SendPacket(pk PacketType) error { + if pk == nil { + log.Printf("tried to send nil packet\n") + return fmt.Errorf("tried to send nil packet") + } + if pk.GetType() == "" { + log.Printf("tried to send invalid packet: %T\n", pk) + return fmt.Errorf("tried to send packet without a type: %T", pk) + } err := sender.checkStatus() if err != nil { return err diff --git a/waveshell/pkg/server/server.go b/waveshell/pkg/server/server.go index 6c956e32a..39ee26b4e 100644 --- a/waveshell/pkg/server/server.go +++ b/waveshell/pkg/server/server.go @@ -30,6 +30,7 @@ const MaxFileDataPacketSize = 16 * 1024 const WriteFileContextTimeout = 30 * time.Second const cleanLoopTime = 5 * time.Second const MaxWriteFileContextData = 100 +const InboundRpcErrorTimeoutTime = 30 * time.Second type shellStateMapKey struct { ShellType string @@ -52,8 +53,89 @@ type MServer struct { StateMap *ShellStateMap WriteErrorCh chan bool // closed if there is a I/O write error WriteErrorChOnce *sync.Once - WriteFileContextMap map[string]*WriteFileContext Done bool + InboundRpcHandlers map[string]RpcHandler + InboundRpcErrorSent map[string]time.Time // limits the amount of error messages sent back to the client +} + +var _ RpcHandler = (*WriteFileContext)(nil) + +type RpcHandler interface { + GetTimeoutTime() time.Time + DispatchPacket(reqId string, pk packet.RpcFollowUpPacketType) + UnRegisterCallback() +} + +func (m *MServer) registerRpcHandler(reqId string, handler RpcHandler) error { + if handler == nil { + return errors.New("handler is nil") + } + m.Lock.Lock() + defer m.Lock.Unlock() + if m.InboundRpcHandlers[reqId] != nil { + return errors.New("handler already registered") + } + delete(m.InboundRpcErrorSent, reqId) + m.InboundRpcHandlers[reqId] = handler + return nil +} + +func (m *MServer) unregisterRpcHandler(reqId string) { + m.Lock.Lock() + defer m.Lock.Unlock() + handler := m.InboundRpcHandlers[reqId] + delete(m.InboundRpcHandlers, reqId) + if handler != nil { + handler.UnRegisterCallback() + } +} + +// limits the number of error responses that can be sent +func (m *MServer) sendInboundRpcError(reqId string, err error) { + m.Lock.Lock() + defer m.Lock.Unlock() + if _, ok := m.InboundRpcErrorSent[reqId]; ok { + return + } + m.InboundRpcErrorSent[reqId] = time.Now().Add(InboundRpcErrorTimeoutTime) + m.Sender.SendErrorResponse(reqId, err) +} + +// returns true if dispatched to a waiting RPC handler +func (m *MServer) dispatchRpcFollowUp(pk packet.RpcFollowUpPacketType) bool { + if pk == nil { + return true + } + reqId := pk.GetAssociatedReqId() + m.Lock.Lock() + defer m.Lock.Unlock() + if rh := m.InboundRpcHandlers[reqId]; rh != nil { + rh.DispatchPacket(reqId, pk) + return true + } + return false +} + +func (m *MServer) cleanRpcHandlers() { + var staleHandlers []RpcHandler + now := time.Now() + m.Lock.Lock() + for reqId, rh := range m.InboundRpcHandlers { + if now.After(rh.GetTimeoutTime()) { + staleHandlers = append(staleHandlers, rh) + delete(m.InboundRpcHandlers, reqId) + } + } + for reqId, timeoutTime := range m.InboundRpcErrorSent { + if now.After(timeoutTime) { + delete(m.InboundRpcErrorSent, reqId) + } + } + m.Lock.Unlock() + // we do this outside of m.Lock just in case there is some lock contention (UnRegisterCallback might be slow) + for _, rh := range staleHandlers { + rh.UnRegisterCallback() + } } type WriteFileContext struct { @@ -78,25 +160,13 @@ func (m *MServer) checkDone() bool { return m.Done } -func (m *MServer) getWriteFileContext(reqId string) *WriteFileContext { - m.Lock.Lock() - defer m.Lock.Unlock() - wfc := m.WriteFileContextMap[reqId] - if wfc == nil { - wfc = &WriteFileContext{ - CVar: sync.NewCond(&sync.Mutex{}), - LastActive: time.Now(), - } - m.WriteFileContextMap[reqId] = wfc - } - return wfc +func (wfc *WriteFileContext) GetTimeoutTime() time.Time { + return wfc.LastActive.Add(WriteFileContextTimeout) } -func (m *MServer) addFileDataPacket(pk *packet.FileDataPacketType) { - m.Lock.Lock() - wfc := m.WriteFileContextMap[pk.RespId] - m.Lock.Unlock() - if wfc == nil { +func (wfc *WriteFileContext) DispatchPacket(reqId string, pkArg packet.RpcFollowUpPacketType) { + dataPk, ok := pkArg.(*packet.FileDataPacketType) + if !ok { return } wfc.CVar.L.Lock() @@ -111,11 +181,11 @@ func (m *MServer) addFileDataPacket(pk *packet.FileDataPacketType) { return } wfc.LastActive = time.Now() - wfc.Data = append(wfc.Data, pk) + wfc.Data = append(wfc.Data, dataPk) wfc.CVar.Signal() } -func (wfc *WriteFileContext) setDone() { +func (wfc *WriteFileContext) UnRegisterCallback() { wfc.CVar.L.Lock() defer wfc.CVar.L.Unlock() wfc.Done = true @@ -123,24 +193,6 @@ func (wfc *WriteFileContext) setDone() { wfc.CVar.Broadcast() } -func (m *MServer) cleanWriteFileContexts() { - now := time.Now() - var staleWfcs []*WriteFileContext - m.Lock.Lock() - for reqId, wfc := range m.WriteFileContextMap { - if now.Sub(wfc.LastActive) > WriteFileContextTimeout { - staleWfcs = append(staleWfcs, wfc) - delete(m.WriteFileContextMap, reqId) - } - } - m.Lock.Unlock() - - // we do this outside of m.Lock just in case there is some lock contention (end of WriteFile could theoretically be slow) - for _, wfc := range staleWfcs { - wfc.setDone() - } -} - func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) { ck := pk.GetCK() if ck == "" { @@ -155,7 +207,6 @@ func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) { return } cproc.Input.SendPacket(pk) - return } func runSingleCompGen(cwd string, compType string, prefix string) ([]string, bool, error) { @@ -226,7 +277,6 @@ func (m *MServer) runMixedCompGen(compPk *packet.CompGenPacketType) { } sort.Strings(comps) // resort m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": (hasMoreFiles || hasMoreDirs)}) - return } func (m *MServer) runCompGen(compPk *packet.CompGenPacketType) { @@ -246,10 +296,46 @@ func (m *MServer) runCompGen(compPk *packet.CompGenPacketType) { m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": hasMore}) } +type ReinitRpcHandler struct { + ReqId string + TimeoutTime time.Time + StdinDataCh chan []byte +} + +func (rh *ReinitRpcHandler) GetTimeoutTime() time.Time { + return rh.TimeoutTime +} + +func (rh *ReinitRpcHandler) DispatchPacket(reqId string, pkArg packet.RpcFollowUpPacketType) { + rpcInput, ok := pkArg.(*packet.RpcInputPacketType) + if !ok { + wlog.Logf("reinit rpc handler: invalid packet type: %T", pkArg) + return + } + // nonblocking send + select { + case rh.StdinDataCh <- rpcInput.Data: + default: + wlog.Logf("reinit rpc handler: stdin data channel full, dropping data") + } +} + +func (rh *ReinitRpcHandler) UnRegisterCallback() { + close(rh.StdinDataCh) +} + func (m *MServer) reinit(reqId string, shellType string) { - ssPk, err := m.MakeShellStatePacket(reqId, shellType) + stdinDataCh := make(chan []byte, 10) + rh := &ReinitRpcHandler{ + ReqId: reqId, + TimeoutTime: time.Now().Add(30 * time.Second), + StdinDataCh: stdinDataCh, + } + m.registerRpcHandler(reqId, rh) + defer m.unregisterRpcHandler(reqId) + ssPk, err := m.MakeShellStatePacket(reqId, shellType, stdinDataCh) if err != nil { - m.Sender.SendErrorResponse(reqId, fmt.Errorf("error creating init packet: %w", err)) + m.Sender.SendErrorResponse(reqId, fmt.Errorf("error initializing shell: %w", err)) return } err = m.StateMap.SetCurrentState(ssPk.State.GetShellType(), ssPk.State) @@ -261,13 +347,13 @@ func (m *MServer) reinit(reqId string, shellType string) { m.Sender.SendPacket(ssPk) } -func (m *MServer) MakeShellStatePacket(reqId string, shellType string) (*packet.ShellStatePacketType, error) { +func (m *MServer) MakeShellStatePacket(reqId string, shellType string, stdinDataCh chan []byte) (*packet.ShellStatePacketType, error) { sapi, err := shellapi.MakeShellApi(shellType) if err != nil { return nil, err } rtnCh := make(chan shellapi.ShellStateOutput, 1) - go sapi.GetShellState(rtnCh) + go sapi.GetShellState(rtnCh, stdinDataCh) for ssOutput := range rtnCh { if ssOutput.Error != "" { return nil, errors.New(ssOutput.Error) @@ -357,7 +443,7 @@ func copyFile(dstName string, srcName string) error { } func (m *MServer) writeFile(pk *packet.WriteFilePacketType, wfc *WriteFileContext) { - defer wfc.setDone() + defer m.unregisterRpcHandler(pk.ReqId) if pk.Path == "" { resp := packet.MakeWriteFileReadyPacket(pk.ReqId) resp.Error = "invalid write-file request, no path specified" @@ -478,7 +564,6 @@ func (m *MServer) returnStreamFileNewFileResponse(pk *packet.StreamFilePacketTyp Perm: int(dirInfo.Mode().Perm()), NotFound: true, } - return } func (m *MServer) streamFile(pk *packet.StreamFilePacketType) { @@ -608,12 +693,19 @@ func (m *MServer) ProcessRpcPacket(pk packet.RpcPacketType) { return } if writePk, ok := pk.(*packet.WriteFilePacketType); ok { - wfc := m.getWriteFileContext(writePk.ReqId) + wfc := &WriteFileContext{ + CVar: sync.NewCond(&sync.Mutex{}), + LastActive: time.Now(), + } + err := m.registerRpcHandler(writePk.ReqId, wfc) + if err != nil { + m.Sender.SendErrorResponse(reqId, fmt.Errorf("error registering write-file handler: %w", err)) + return + } go m.writeFile(writePk, wfc) return } m.Sender.SendErrorResponse(reqId, fmt.Errorf("invalid rpc type '%s'", pk.GetType())) - return } func (m *MServer) clientPacketCallback(shellType string, pk packet.PacketType) { @@ -726,7 +818,7 @@ func (server *MServer) runReadLoop() { builder := packet.MakeRunPacketBuilder() for pk := range server.MainInput.MainCh { if server.Debug { - fmt.Printf("PK> %s\n", packet.AsString(pk)) + wlog.Logf("runReadLoop got packet %s\n", packet.AsString(pk)) } ok, runPacket := builder.ProcessPacket(pk) if ok { @@ -736,16 +828,19 @@ func (server *MServer) runReadLoop() { } continue } - if cmdPk, ok := pk.(packet.CommandPacketType); ok { + if cmdPk, ok := pk.(packet.CommandPacketType); ok && cmdPk.GetCK() != "" { server.ProcessCommandPacket(cmdPk) continue } - if rpcPk, ok := pk.(packet.RpcPacketType); ok { + if rpcPk, ok := pk.(packet.RpcPacketType); ok && rpcPk.GetReqId() != "" { server.ProcessRpcPacket(rpcPk) continue } - if fileDataPk, ok := pk.(*packet.FileDataPacketType); ok { - server.addFileDataPacket(fileDataPk) + if rpcFollowUp, ok := pk.(packet.RpcFollowUpPacketType); ok && rpcFollowUp.GetAssociatedReqId() != "" { + ok := server.dispatchRpcFollowUp(rpcFollowUp) + if !ok { + server.sendInboundRpcError(rpcFollowUp.GetAssociatedReqId(), fmt.Errorf("no handler for rpc follow-up packet")) + } continue } server.Sender.SendMessageFmt("invalid packet '%s' sent to mshell server", packet.AsString(pk)) @@ -765,7 +860,8 @@ func RunServer() (int, error) { Debug: debug, WriteErrorCh: make(chan bool), WriteErrorChOnce: &sync.Once{}, - WriteFileContextMap: make(map[string]*WriteFileContext), + InboundRpcHandlers: make(map[string]RpcHandler), + InboundRpcErrorSent: make(map[string]time.Time), } if debug { packet.GlobalDebug = true @@ -780,7 +876,7 @@ func RunServer() (int, error) { return } time.Sleep(cleanLoopTime) - server.cleanWriteFileContexts() + server.cleanRpcHandlers() } }() var err error diff --git a/waveshell/pkg/shellapi/bashapi.go b/waveshell/pkg/shellapi/bashapi.go index 567d3fb79..96efb7948 100644 --- a/waveshell/pkg/shellapi/bashapi.go +++ b/waveshell/pkg/shellapi/bashapi.go @@ -80,8 +80,8 @@ func (b bashShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty return MakeBashShExecCommand(cmdStr, rcFileName, usePty) } -func (b bashShellApi) GetShellState(outCh chan ShellStateOutput) { - GetBashShellState(outCh) +func (b bashShellApi) GetShellState(outCh chan ShellStateOutput, stdinDataCh chan []byte) { + GetBashShellState(outCh, stdinDataCh) } func (b bashShellApi) GetBaseShellOpts() string { @@ -169,7 +169,7 @@ func GetLocalBashMajorVersion() string { return localBashMajorVersion } -func GetBashShellState(outCh chan ShellStateOutput) { +func GetBashShellState(outCh chan ShellStateOutput, stdinDataCh chan []byte) { ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout) defer cancelFn() defer close(outCh) @@ -185,7 +185,7 @@ func GetBashShellState(outCh chan ShellStateOutput) { outCh <- ShellStateOutput{Output: outputBytes} } }() - outputBytes, err := StreamCommandWithExtraFd(ecmd, outputCh, StateOutputFdNum, endBytes) + outputBytes, err := StreamCommandWithExtraFd(ctx, ecmd, outputCh, StateOutputFdNum, endBytes, stdinDataCh) outputWg.Wait() if err != nil { outCh <- ShellStateOutput{Error: err.Error()} diff --git a/waveshell/pkg/shellapi/shellapi.go b/waveshell/pkg/shellapi/shellapi.go index c9e7fc018..3b28a7ee2 100644 --- a/waveshell/pkg/shellapi/shellapi.go +++ b/waveshell/pkg/shellapi/shellapi.go @@ -25,9 +25,11 @@ import ( "github.com/wavetermdev/waveterm/waveshell/pkg/packet" "github.com/wavetermdev/waveterm/waveshell/pkg/shellutil" "github.com/wavetermdev/waveterm/waveshell/pkg/utilfn" + "github.com/wavetermdev/waveterm/waveshell/pkg/wlog" ) -const GetStateTimeout = 15 * time.Second +const GetStateTimeout = 10 * time.Second +const ReInitTimeout = GetStateTimeout + 2*time.Second const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"` const GetK8sContextCmdStr = `printf "K8SCONTEXT %s\x00" "$(kubectl config current-context 2>/dev/null)"` const GetK8sNamespaceCmdStr = `printf "K8SNAMESPACE %s\x00" "$(kubectl config view --minify --output 'jsonpath={..namespace}' 2>/dev/null)"` @@ -71,7 +73,7 @@ type ShellApi interface { GetRemoteShellPath() string MakeRunCommand(cmdStr string, opts RunCommandOpts) string MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd - GetShellState(chan ShellStateOutput) + GetShellState(outCh chan ShellStateOutput, stdinDataCh chan []byte) GetBaseShellOpts() string ParseShellStateOutput(output []byte) (*packet.ShellState, *packet.ShellStateStats, error) MakeRcFileStr(pk *packet.RunPacketType) string @@ -154,7 +156,7 @@ func internalMacUserShell() string { const FirstExtraFilesFdNum = 3 // returns output(stdout+stderr), extraFdOutput, error -func StreamCommandWithExtraFd(ecmd *exec.Cmd, outputCh chan []byte, extraFdNum int, endBytes []byte) ([]byte, error) { +func StreamCommandWithExtraFd(ctx context.Context, ecmd *exec.Cmd, outputCh chan []byte, extraFdNum int, endBytes []byte, stdinDataCh chan []byte) ([]byte, error) { defer close(outputCh) ecmd.Env = os.Environ() shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType)) @@ -202,8 +204,27 @@ func StreamCommandWithExtraFd(ecmd *exec.Cmd, outputCh chan []byte, extraFdNum i defer outputWg.Done() utilfn.CopyWithEndBytes(&extraFdOutputBuf, pipeReader, endBytes) }() + if stdinDataCh != nil { + go func() { + // continue this loop even after an error to drain stdinDataCh + hadErr := false + for stdinData := range stdinDataCh { + if hadErr { + continue + } + _, err := cmdPty.Write(stdinData) + if err != nil { + wlog.Logf("error writing to shellstate cmdpty (stdin): %v\n", err) + hadErr = true + } + } + }() + } exitErr := ecmd.Wait() if exitErr != nil { + if ctx.Err() != nil { + return nil, fmt.Errorf("%w (%w)", ctx.Err(), exitErr) + } return nil, exitErr } outputWg.Wait() diff --git a/waveshell/pkg/shellapi/zshapi.go b/waveshell/pkg/shellapi/zshapi.go index aa6ad1c70..43bfeed0f 100644 --- a/waveshell/pkg/shellapi/zshapi.go +++ b/waveshell/pkg/shellapi/zshapi.go @@ -246,7 +246,7 @@ func (z zshShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty return exec.Command(GetLocalZshPath(), "-l", "-i", "-c", cmdStr) } -func (z zshShellApi) GetShellState(outCh chan ShellStateOutput) { +func (z zshShellApi) GetShellState(outCh chan ShellStateOutput, stdinDataCh chan []byte) { ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout) defer cancelFn() defer close(outCh) @@ -262,7 +262,7 @@ func (z zshShellApi) GetShellState(outCh chan ShellStateOutput) { outCh <- ShellStateOutput{Output: outputBytes} } }() - outputBytes, err := StreamCommandWithExtraFd(ecmd, outputCh, StateOutputFdNum, endBytes) + outputBytes, err := StreamCommandWithExtraFd(ctx, ecmd, outputCh, StateOutputFdNum, endBytes, stdinDataCh) outputWg.Wait() if err != nil { outCh <- ShellStateOutput{Error: err.Error()} @@ -726,7 +726,6 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta // sections: see ZshSection_* consts sections := bytes.Split(outputBytes, sectionSeparator) if len(sections) != ZshSection_NumFieldsExpected { - base.Logf("invalid -- numfields\n") return nil, nil, fmt.Errorf("invalid zsh shell state output, wrong number of sections, section=%d", len(sections)) } rtn := &packet.ShellState{} diff --git a/wavesrv/pkg/cmdrunner/cmdrunner.go b/wavesrv/pkg/cmdrunner/cmdrunner.go index 5b45c2166..a2c1152b2 100644 --- a/wavesrv/pkg/cmdrunner/cmdrunner.go +++ b/wavesrv/pkg/cmdrunner/cmdrunner.go @@ -30,6 +30,7 @@ import ( "github.com/wavetermdev/waveterm/waveshell/pkg/base" "github.com/wavetermdev/waveterm/waveshell/pkg/packet" "github.com/wavetermdev/waveterm/waveshell/pkg/server" + "github.com/wavetermdev/waveterm/waveshell/pkg/shellapi" "github.com/wavetermdev/waveterm/waveshell/pkg/shellenv" "github.com/wavetermdev/waveterm/waveshell/pkg/shellutil" "github.com/wavetermdev/waveterm/waveshell/pkg/shexec" @@ -1186,9 +1187,17 @@ func deferWriteCmdStatus(ctx context.Context, cmd *sstore.CmdType, startTime tim err := sstore.UpdateCmdDoneInfo(context.Background(), update, ck, donePk, cmdStatus) if err != nil { // nothing to do - log.Printf("error updating cmddoneinfo (in openai): %v\n", err) + log.Printf("error updating cmddoneinfo: %v\n", err) return } + screen, err := sstore.UpdateScreenFocusForDoneCmd(ctx, cmd.ScreenId, cmd.LineId) + if err != nil { + log.Printf("error trying to update screen focus type: %v\n", err) + // fall-through (nothing to do) + } + if screen != nil { + update.AddUpdate(*screen) + } scbus.MainUpdateBus.DoScreenUpdate(cmd.ScreenId, update) } @@ -3686,7 +3695,7 @@ func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) ( if err != nil { return nil, err } - update, err := addLineForCmd(ctx, "/reset", false, ids, cmd, "", nil) + update, err := addLineForCmd(ctx, "/reset", true, ids, cmd, "", nil) if err != nil { return nil, err } @@ -3695,7 +3704,7 @@ func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) ( } func doResetCommand(ids resolvedIds, shellType string, cmd *sstore.CmdType, verbose bool) { - ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancelFn := context.WithTimeout(context.Background(), shellapi.ReInitTimeout) defer cancelFn() startTime := time.Now() var outputPos int64 @@ -3712,7 +3721,7 @@ func doResetCommand(ids resolvedIds, shellType string, cmd *sstore.CmdType, verb writeStringToPty(ctx, cmd, string(data), &outputPos) } origStatePtr := ids.Remote.MShell.GetDefaultStatePtr(shellType) - ssPk, err := ids.Remote.MShell.ReInit(ctx, shellType, dataFn, verbose) + ssPk, err := ids.Remote.MShell.ReInit(ctx, base.MakeCommandKey(cmd.ScreenId, cmd.LineId), shellType, dataFn, verbose) if err != nil { rtnErr = err return @@ -3975,14 +3984,14 @@ func splitLinesForInfo(str string) []string { } func resizeRunningCommand(ctx context.Context, cmd *sstore.CmdType, newCols int) error { - siPk := packet.MakeSpecialInputPacket() - siPk.CK = base.MakeCommandKey(cmd.ScreenId, cmd.LineId) - siPk.WinSize = &packet.WinSize{Rows: int(cmd.TermOpts.Rows), Cols: newCols} + feInput := scpacket.MakeFeInputPacket() + feInput.CK = base.MakeCommandKey(cmd.ScreenId, cmd.LineId) + feInput.WinSize = &packet.WinSize{Rows: int(cmd.TermOpts.Rows), Cols: newCols} msh := remote.GetRemoteById(cmd.Remote.RemoteId) if msh == nil { return fmt.Errorf("cannot resize, cmd remote not found") } - err := msh.SendSpecialInput(siPk) + err := msh.HandleFeInput(feInput) if err != nil { return err } @@ -5178,10 +5187,10 @@ func SignalCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbus if !msh.IsConnected() { return nil, fmt.Errorf("cannot send signal, remote is not connected") } - siPk := packet.MakeSpecialInputPacket() - siPk.CK = base.MakeCommandKey(cmd.ScreenId, cmd.LineId) - siPk.SigName = sigArg - err = msh.SendSpecialInput(siPk) + inputPk := scpacket.MakeFeInputPacket() + inputPk.CK = base.MakeCommandKey(cmd.ScreenId, cmd.LineId) + inputPk.SigName = sigArg + err = msh.HandleFeInput(inputPk) if err != nil { return nil, fmt.Errorf("cannot send signal: %v", err) } diff --git a/wavesrv/pkg/cmdrunner/shparse.go b/wavesrv/pkg/cmdrunner/shparse.go index f28ee28da..1aa816bfd 100644 --- a/wavesrv/pkg/cmdrunner/shparse.go +++ b/wavesrv/pkg/cmdrunner/shparse.go @@ -215,6 +215,7 @@ var literalRtnStateCommands = []string{ "disable", "function", "zmodload", + "module", } func getCallExprLitArg(callExpr *syntax.CallExpr, argNum int) string { diff --git a/wavesrv/pkg/remote/remote.go b/wavesrv/pkg/remote/remote.go index a1c658c7f..8956a198c 100644 --- a/wavesrv/pkg/remote/remote.go +++ b/wavesrv/pkg/remote/remote.go @@ -54,6 +54,7 @@ const RemoteTermCols = 80 const PtyReadBufSize = 100 const RemoteConnectTimeout = 15 * time.Second const RpcIterChannelSize = 100 +const MaxInputDataSize = 1000 var envVarsToStrip map[string]bool = map[string]bool{ "PROMPT": true, @@ -128,6 +129,7 @@ type pendingStateKey struct { RemotePtr sstore.RemotePtrType } +// provides state, acccess, and control for a waveshell server process type MShellProc struct { Lock *sync.Mutex Remote *sstore.RemoteType @@ -135,7 +137,7 @@ type MShellProc struct { // runtime RemoteId string // can be read without a lock Status string - ServerProc *shexec.ClientProc + ServerProc *shexec.ClientProc // the server process UName string Err error ErrNoInitPk bool @@ -154,9 +156,18 @@ type MShellProc struct { InstallCancelFn context.CancelFunc InstallErr error + // for synthetic commands (not run through RunCommand), this provides a way for them + // to register to receive input events from the frontend (e.g. ReInit) + CommandInputMap map[base.CommandKey]CommandInputSink + RunningCmds map[base.CommandKey]*RunCmdType - PendingStateCmds map[pendingStateKey]base.CommandKey // key=[remoteinstance name] - Client *ssh.Client + PendingStateCmds map[pendingStateKey]base.CommandKey // key=[remoteinstance name] (in progress commands that might update the state) + + Client *ssh.Client +} + +type CommandInputSink interface { + HandleInput(feInput *scpacket.FeInputPacketType) error } type RunCmdType struct { @@ -169,6 +180,22 @@ type RunCmdType struct { EphCancled atomic.Bool // only for Ephemeral commands, if true, then the command result should be discarded } +type ReinitCommandSink struct { + Remote *MShellProc + ReqId string +} + +func (rcs *ReinitCommandSink) HandleInput(feInput *scpacket.FeInputPacketType) error { + realData, err := base64.StdEncoding.DecodeString(feInput.InputData64) + if err != nil { + return fmt.Errorf("error decoding input data: %v", err) + } + inputPk := packet.MakeRpcInputPacket(rcs.ReqId) + inputPk.Data = realData + rcs.Remote.ServerProc.Input.SendPacket(inputPk) + return nil +} + type RemoteRuntimeState = sstore.RemoteRuntimeState func CanComplete(remoteType string) bool { @@ -196,7 +223,7 @@ func (msh *MShellProc) EnsureShellType(ctx context.Context, shellType string) er return nil } // try to reinit the shell - _, err := msh.ReInit(ctx, shellType, nil, false) + _, err := msh.ReInit(ctx, base.CommandKey(""), shellType, nil, false) if err != nil { return fmt.Errorf("error trying to initialize shell %q: %v", shellType, err) } @@ -694,6 +721,7 @@ func MakeMShell(r *sstore.RemoteType) *MShellProc { Status: StatusDisconnected, PtyBuffer: buf, InstallStatus: StatusDisconnected, + CommandInputMap: make(map[base.CommandKey]CommandInputSink), RunningCmds: make(map[base.CommandKey]*RunCmdType), PendingStateCmds: make(map[pendingStateKey]base.CommandKey), StateMap: server.MakeShellStateMap(), @@ -1401,7 +1429,7 @@ func makeReinitErrorUpdate(shellType string) sstore.ActivityUpdate { return rtn } -func (msh *MShellProc) ReInit(ctx context.Context, shellType string, dataFn func([]byte), verbose bool) (rtnPk *packet.ShellStatePacketType, rtnErr error) { +func (msh *MShellProc) ReInit(ctx context.Context, ck base.CommandKey, shellType string, dataFn func([]byte), verbose bool) (rtnPk *packet.ShellStatePacketType, rtnErr error) { if !msh.IsConnected() { return nil, fmt.Errorf("cannot reinit, remote is not connected") } @@ -1425,6 +1453,14 @@ func (msh *MShellProc) ReInit(ctx context.Context, shellType string, dataFn func return nil, err } defer rpcIter.Close() + if ck != "" { + reinitSink := &ReinitCommandSink{ + Remote: msh, + ReqId: reinitPk.ReqId, + } + msh.registerInputSink(ck, reinitSink) + defer msh.unregisterInputSink(ck) + } var ssPk *packet.ShellStatePacketType for { resp, err := rpcIter.Next(ctx) @@ -1506,6 +1542,9 @@ func addScVarsToState(state *packet.ShellState) *packet.ShellState { envMap["WAVETERM_VERSION"] = &shellenv.DeclareDeclType{Name: "WAVETERM_VERSION", Value: scbase.WaveVersion, Args: "x"} envMap["TERM_PROGRAM"] = &shellenv.DeclareDeclType{Name: "TERM_PROGRAM", Value: "waveterm", Args: "x"} envMap["TERM_PROGRAM_VERSION"] = &shellenv.DeclareDeclType{Name: "TERM_PROGRAM_VERSION", Value: scbase.WaveVersion, Args: "x"} + if scbase.IsDevMode() { + envMap["WAVETERM_DEV"] = &shellenv.DeclareDeclType{Name: "WAVETERM_DEV", Value: "1", Args: "x"} + } if _, exists := envMap["LANG"]; !exists { envMap["LANG"] = &shellenv.DeclareDeclType{Name: "LANG", Value: scbase.DetermineLang(), Args: "x"} } @@ -1727,20 +1766,28 @@ func (msh *MShellProc) Launch(interactive bool) { } func (msh *MShellProc) initActiveShells() { - ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) + gasCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) defer cancelFn() - activeShells, err := msh.getActiveShellTypes(ctx) + activeShells, err := msh.getActiveShellTypes(gasCtx) if err != nil { // we're not going to fail the connect for this error (it will be unusable, but technically connected) msh.WriteToPtyBuffer("*error getting active shells: %v\n", err) return } - for _, shellType := range activeShells { - _, err = msh.ReInit(ctx, shellType, nil, false) - if err != nil { - msh.WriteToPtyBuffer("*error reiniting shell %q: %v\n", shellType, err) - } + var wg sync.WaitGroup + for _, shellTypeForVar := range activeShells { + wg.Add(1) + go func(shellType string) { + defer wg.Done() + reinitCtx, cancelFn := context.WithTimeout(context.Background(), shellapi.ReInitTimeout) + defer cancelFn() + _, err = msh.ReInit(reinitCtx, base.CommandKey(""), shellType, nil, false) + if err != nil { + msh.WriteToPtyBuffer("*error reiniting shell %q: %v\n", shellType, err) + } + }(shellTypeForVar) } + wg.Wait() } func (msh *MShellProc) IsConnected() bool { @@ -1775,24 +1822,14 @@ func (msh *MShellProc) IsCmdRunning(ck base.CommandKey) bool { return ok } -func (msh *MShellProc) SendInput(dataPk *packet.DataPacketType) error { - if !msh.IsConnected() { - return fmt.Errorf("remote is not connected, cannot send input") - } - if !msh.IsCmdRunning(dataPk.CK) { - return fmt.Errorf("cannot send input, cmd is not running") - } - return msh.ServerProc.Input.SendPacket(dataPk) -} - func (msh *MShellProc) KillRunningCommandAndWait(ctx context.Context, ck base.CommandKey) error { if !msh.IsCmdRunning(ck) { return nil } - siPk := packet.MakeSpecialInputPacket() - siPk.CK = ck - siPk.SigName = "SIGTERM" - err := msh.SendSpecialInput(siPk) + feiPk := scpacket.MakeFeInputPacket() + feiPk.CK = ck + feiPk.SigName = "SIGTERM" + err := msh.HandleFeInput(feiPk) if err != nil { return fmt.Errorf("error trying to kill running cmd: %w", err) } @@ -1809,16 +1846,6 @@ func (msh *MShellProc) KillRunningCommandAndWait(ctx context.Context, ck base.Co } } -func (msh *MShellProc) SendSpecialInput(siPk *packet.SpecialInputPacketType) error { - if !msh.IsConnected() { - return fmt.Errorf("remote is not connected, cannot send input") - } - if !msh.IsCmdRunning(siPk.CK) { - return fmt.Errorf("cannot send input, cmd is not running") - } - return msh.ServerProc.Input.SendPacket(siPk) -} - func (msh *MShellProc) SendFileData(dataPk *packet.FileDataPacketType) error { if !msh.IsConnected() { return fmt.Errorf("remote is not connected, cannot send input") @@ -2065,6 +2092,62 @@ func makePSCLineError(existingPSC base.CommandKey, line *sstore.LineType, lineEr return fmt.Errorf("cannot run command while a stateful command (linenum=%d) is still running", line.LineNum) } +func (msh *MShellProc) registerInputSink(ck base.CommandKey, sink CommandInputSink) { + msh.Lock.Lock() + defer msh.Lock.Unlock() + msh.CommandInputMap[ck] = sink +} + +func (msh *MShellProc) unregisterInputSink(ck base.CommandKey) { + msh.Lock.Lock() + defer msh.Lock.Unlock() + delete(msh.CommandInputMap, ck) +} + +func (msh *MShellProc) HandleFeInput(inputPk *scpacket.FeInputPacketType) error { + if inputPk == nil { + return nil + } + if !msh.IsConnected() { + return fmt.Errorf("connection is not connected, cannot send input") + } + if msh.IsCmdRunning(inputPk.CK) { + if len(inputPk.InputData64) > 0 { + inputLen := packet.B64DecodedLen(inputPk.InputData64) + if inputLen > MaxInputDataSize { + return fmt.Errorf("input data size too large, len=%d (max=%d)", inputLen, MaxInputDataSize) + } + dataPk := packet.MakeDataPacket() + dataPk.CK = inputPk.CK + dataPk.FdNum = 0 // stdin + dataPk.Data64 = inputPk.InputData64 + err := msh.ServerProc.Input.SendPacket(dataPk) + if err != nil { + return err + } + } + if inputPk.SigName != "" || inputPk.WinSize != nil { + siPk := packet.MakeSpecialInputPacket() + siPk.CK = inputPk.CK + siPk.SigName = inputPk.SigName + siPk.WinSize = inputPk.WinSize + err := msh.ServerProc.Input.SendPacket(siPk) + if err != nil { + return err + } + } + return nil + } + msh.Lock.Lock() + sink := msh.CommandInputMap[inputPk.CK] + msh.Lock.Unlock() + if sink == nil { + // no sink and no running command + return fmt.Errorf("cannot send input, cmd is not running") + } + return sink.HandleInput(inputPk) +} + func (msh *MShellProc) AddRunningCmd(rct *RunCmdType) { msh.Lock.Lock() defer msh.Lock.Unlock() diff --git a/wavesrv/pkg/rtnstate/rtnstate.go b/wavesrv/pkg/rtnstate/rtnstate.go index 76492bc47..447cc047c 100644 --- a/wavesrv/pkg/rtnstate/rtnstate.go +++ b/wavesrv/pkg/rtnstate/rtnstate.go @@ -115,6 +115,7 @@ var IgnoreVars = map[string]bool{ "WAVESHELL_VERSION": true, "WAVETERM": true, "WAVETERM_VERSION": true, + "WAVETERM_DEV": true, "TERM_PROGRAM": true, "TERM_PROGRAM_VERSION": true, "TERM_SESSION_ID": true, diff --git a/wavesrv/pkg/scpacket/scpacket.go b/wavesrv/pkg/scpacket/scpacket.go index 52c8a9ffa..24dad7c22 100644 --- a/wavesrv/pkg/scpacket/scpacket.go +++ b/wavesrv/pkg/scpacket/scpacket.go @@ -4,6 +4,7 @@ package scpacket import ( + "encoding/base64" "fmt" "reflect" "regexp" @@ -191,6 +192,14 @@ func MakeFeInputPacket() *FeInputPacketType { return &FeInputPacketType{Type: FeInputPacketStr} } +func (pk *FeInputPacketType) DecodeData() ([]byte, error) { + return base64.StdEncoding.DecodeString(pk.InputData64) +} + +func (pk *FeInputPacketType) SetData(data []byte) { + pk.InputData64 = base64.StdEncoding.EncodeToString(data) +} + func (*WatchScreenPacketType) GetType() string { return WatchScreenPacketStr } diff --git a/wavesrv/pkg/scws/scws.go b/wavesrv/pkg/scws/scws.go index e8e1c2269..aedd54f81 100644 --- a/wavesrv/pkg/scws/scws.go +++ b/wavesrv/pkg/scws/scws.go @@ -23,7 +23,6 @@ import ( ) const WSStatePacketChSize = 20 -const MaxInputDataSize = 1000 const RemoteInputQueueSize = 100 var RemoteInputMapQueue *mapqueue.MapQueue @@ -247,7 +246,7 @@ func (ws *WSState) processMessage(msgBytes []byte) error { err := RemoteInputMapQueue.Enqueue(feInputPk.Remote.RemoteId, func() { sendErr := sendCmdInput(feInputPk) if sendErr != nil { - log.Printf("[scws] sending command input: %v\n", err) + log.Printf("[scws] sending command input: %v\n", sendErr) } }) if err != nil { @@ -263,7 +262,7 @@ func (ws *WSState) processMessage(msgBytes []byte) error { go func() { sendErr := remote.SendRemoteInput(inputPk) if sendErr != nil { - log.Printf("[scws] error processing remote input: %v\n", err) + log.Printf("[scws] error processing remote input: %v\n", sendErr) } }() return nil @@ -319,29 +318,5 @@ func sendCmdInput(pk *scpacket.FeInputPacketType) error { if msh == nil { return fmt.Errorf("remote %s not found", pk.Remote.RemoteId) } - if len(pk.InputData64) > 0 { - inputLen := packet.B64DecodedLen(pk.InputData64) - if inputLen > MaxInputDataSize { - return fmt.Errorf("input data size too large, len=%d (max=%d)", inputLen, MaxInputDataSize) - } - dataPk := packet.MakeDataPacket() - dataPk.CK = pk.CK - dataPk.FdNum = 0 // stdin - dataPk.Data64 = pk.InputData64 - err = msh.SendInput(dataPk) - if err != nil { - return err - } - } - if pk.SigName != "" || pk.WinSize != nil { - siPk := packet.MakeSpecialInputPacket() - siPk.CK = pk.CK - siPk.SigName = pk.SigName - siPk.WinSize = pk.WinSize - err = msh.SendSpecialInput(siPk) - if err != nil { - return err - } - } - return nil + return msh.HandleFeInput(pk) }