more zsh reinitialization fixes (allow user input during initialization process) (#480)

* fix error logs in scws

* new RpcFollowUpPacketType

* make the rpc/followup handlers generic on the server side -- using new RpcHandlers map and RpcFollowUpPacketType

* rpcinputpacket for passing user input back through to reinit command

* add WAVETERM_DEV env var in dev mode

* remove unused code, ensure mshell and rcfile directory on startup (prevent root clobber with sudo)

* combine all feinput into one function msh.HandleFeInput, and add a new concept of input sinks for special cases (like reinit)

* allow reset to accept user input (to get around interactive initialization problems)

* tone down the selection background highlight color on dark mode.  easier to read selected text

* fix command focus and done focus issues with dynamic (non-run) commands

* add 'module' as a 'rtnstate' command (#478)

* reinitialize shells in parallel, fix timeouts, better error messages
This commit is contained in:
Mike Sawka 2024-03-20 23:38:05 -07:00 committed by GitHub
parent fb59e094e4
commit 0781e6e821
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 384 additions and 225 deletions

View File

@ -27,6 +27,6 @@
--term-cmdtext: #ffffff;
--term-foreground: #d3d7cf;
--term-background: #000000;
--term-selection-background: #ffffff90;
--term-selection-background: #ffffff60;
--term-cursor-accent: #000000;
}

View File

@ -135,6 +135,7 @@ func main() {
base.ProcessType = base.ProcessType_WaveShellServer
wlog.GlobalSubsystem = base.ProcessType_WaveShellServer
base.InitDebugLog("server")
base.EnsureRcFilesDir()
rtnCode, err := server.RunServer()
if err != nil {
fmt.Fprintf(os.Stderr, "[error] %v\n", err)

View File

@ -12,7 +12,6 @@ import (
"os"
"os/exec"
"path"
"path/filepath"
"strings"
"sync"
@ -72,11 +71,11 @@ func IsWaveSrv() bool {
return ProcessType == ProcessType_WaveSrv
}
func MakeCommandKey(sessionId string, cmdId string) CommandKey {
if sessionId == "" && cmdId == "" {
func MakeCommandKey(screenId string, lineId string) CommandKey {
if screenId == "" && lineId == "" {
return CommandKey("")
}
return CommandKey(fmt.Sprintf("%s/%s", sessionId, cmdId))
return CommandKey(fmt.Sprintf("%s/%s", screenId, lineId))
}
func (ckey CommandKey) IsEmpty() bool {
@ -200,51 +199,6 @@ func GetMShellHomeDir() string {
return ExpandHomeDir(DefaultMShellHome)
}
func GetCommandFileNames(ck CommandKey) (*CommandFileNames, error) {
if err := ck.Validate("ck"); err != nil {
return nil, fmt.Errorf("cannot get command files: %w", err)
}
sessionId, cmdId := ck.Split()
sdir, err := EnsureSessionDir(sessionId)
if err != nil {
return nil, err
}
base := path.Join(sdir, cmdId)
return &CommandFileNames{
PtyOutFile: base + ".ptyout",
StdinFifo: base + ".stdin",
RunnerOutFile: base + ".runout",
}, nil
}
func CleanUpCmdFiles(sessionId string, cmdId string) error {
if cmdId == "" {
return fmt.Errorf("bad cmdid, cannot clean up")
}
sdir, err := EnsureSessionDir(sessionId)
if err != nil {
return err
}
cmdFileGlob := path.Join(sdir, cmdId+".*")
matches, err := filepath.Glob(cmdFileGlob)
if err != nil {
return err
}
for _, file := range matches {
rmErr := os.Remove(file)
if err == nil && rmErr != nil {
err = rmErr
}
}
return err
}
func GetSessionsDir() string {
mhome := GetMShellHomeDir()
sdir := path.Join(mhome, SessionsDirBaseName)
return sdir
}
func EnsureRcFilesDir() (string, error) {
mhome := GetMShellHomeDir()
dirName := path.Join(mhome, RcFilesDirBaseName)
@ -255,19 +209,6 @@ func EnsureRcFilesDir() (string, error) {
return dirName, nil
}
func EnsureSessionDir(sessionId string) (string, error) {
if sessionId == "" {
return "", fmt.Errorf("Bad sessionid, cannot be empty")
}
mhome := GetMShellHomeDir()
sdir := path.Join(mhome, SessionsDirBaseName, sessionId)
err := CacheEnsureDir(sdir, sessionId, 0777, "mshell session dir")
if err != nil {
return "", err
}
return sdir, nil
}
func GetMShellPath() (string, error) {
msPath := os.Getenv(MShellPathVarName) // use MSHELL_PATH
if msPath != "" {
@ -282,11 +223,6 @@ func GetMShellPath() (string, error) {
return exec.LookPath(DefaultMShellName) // standard path lookup for 'mshell'
}
func GetMShellSessionsDir() (string, error) {
mhome := GetMShellHomeDir()
return path.Join(mhome, SessionsDirBaseName), nil
}
func ExpandHomeDir(pathStr string) string {
if pathStr != "~" && !strings.HasPrefix(pathStr, "~/") {
return pathStr
@ -315,22 +251,6 @@ func GoArchOptFile(version string, goos string, goarch string) string {
return fmt.Sprintf(path.Join(installBinDir, binBaseName))
}
func MShellBinaryFromOptDir(version string, goos string, goarch string) (io.ReadCloser, error) {
if !ValidGoArch(goos, goarch) {
return nil, fmt.Errorf("invalid goos/goarch combination: %s/%s", goos, goarch)
}
versionStr := semver.MajorMinor(version)
if versionStr == "" {
return nil, fmt.Errorf("invalid mshell version: %q", version)
}
fileName := GoArchOptFile(version, goos, goarch)
fd, err := os.Open(fileName)
if err != nil {
return nil, fmt.Errorf("cannot open mshell binary %q: %v", fileName, err)
}
return fd, nil
}
func GetRemoteId() (string, error) {
mhome := GetMShellHomeDir()
homeInfo, err := os.Stat(mhome)

View File

@ -12,6 +12,7 @@ import (
"fmt"
"io"
"io/fs"
"log"
"os"
"reflect"
"sync"
@ -63,6 +64,7 @@ const (
FileStatPacketStr = "filestat"
LogPacketStr = "log" // logging packet (sent from waveshell back to server)
ShellStatePacketStr = "shellstate"
RpcInputPacketStr = "rpcinput" // rpc-followup
OpenAIPacketStr = "openai" // other
OpenAICloudReqStr = "openai-cloudreq"
@ -116,6 +118,7 @@ func init() {
TypeStrToFactory[LogPacketStr] = reflect.TypeOf(LogPacketType{})
TypeStrToFactory[ShellStatePacketStr] = reflect.TypeOf(ShellStatePacketType{})
TypeStrToFactory[FileStatPacketStr] = reflect.TypeOf(FileStatPacketType{})
TypeStrToFactory[RpcInputPacketStr] = reflect.TypeOf(RpcInputPacketType{})
var _ RpcPacketType = (*RunPacketType)(nil)
var _ RpcPacketType = (*GetCmdPacketType)(nil)
@ -134,6 +137,9 @@ func init() {
var _ RpcResponsePacketType = (*WriteFileDonePacketType)(nil)
var _ RpcResponsePacketType = (*ShellStatePacketType)(nil)
var _ RpcFollowUpPacketType = (*FileDataPacketType)(nil)
var _ RpcFollowUpPacketType = (*RpcInputPacketType)(nil)
var _ CommandPacketType = (*DataPacketType)(nil)
var _ CommandPacketType = (*DataAckPacketType)(nil)
var _ CommandPacketType = (*CmdDonePacketType)(nil)
@ -166,6 +172,26 @@ func MakePingPacket() *PingPacketType {
return &PingPacketType{Type: PingPacketStr}
}
type RpcInputPacketType struct {
Type string `json:"type"`
ReqId string `json:"reqid"`
Data []byte `json:"data"`
}
func (*RpcInputPacketType) GetType() string {
return RpcInputPacketStr
}
func (p *RpcInputPacketType) GetAssociatedReqId() string {
return p.ReqId
}
func MakeRpcInputPacket(reqId string) *RpcInputPacketType {
return &RpcInputPacketType{Type: RpcInputPacketStr, ReqId: reqId}
}
// these packets can travel either direction
// so it is both a RpcResponsePacketType and an RpcFollowUpPacketType
type FileDataPacketType struct {
Type string `json:"type"`
RespId string `json:"respid"`
@ -185,6 +211,10 @@ func MakeFileDataPacket(reqId string) *FileDataPacketType {
}
}
func (p *FileDataPacketType) GetAssociatedReqId() string {
return p.RespId
}
func (p *FileDataPacketType) GetResponseId() string {
return p.RespId
}
@ -976,6 +1006,12 @@ type CommandPacketType interface {
GetCK() base.CommandKey
}
// RpcPackets initiate an Rpc. these can be part of the data passed back and forth
type RpcFollowUpPacketType interface {
GetType() string
GetAssociatedReqId() string
}
type ModelUpdatePacketType struct {
Type string `json:"type"`
Updates []any `json:"updates"`
@ -1178,6 +1214,14 @@ func (sender *PacketSender) SendPacketCtx(ctx context.Context, pk PacketType) er
}
func (sender *PacketSender) SendPacket(pk PacketType) error {
if pk == nil {
log.Printf("tried to send nil packet\n")
return fmt.Errorf("tried to send nil packet")
}
if pk.GetType() == "" {
log.Printf("tried to send invalid packet: %T\n", pk)
return fmt.Errorf("tried to send packet without a type: %T", pk)
}
err := sender.checkStatus()
if err != nil {
return err

View File

@ -30,6 +30,7 @@ const MaxFileDataPacketSize = 16 * 1024
const WriteFileContextTimeout = 30 * time.Second
const cleanLoopTime = 5 * time.Second
const MaxWriteFileContextData = 100
const InboundRpcErrorTimeoutTime = 30 * time.Second
type shellStateMapKey struct {
ShellType string
@ -52,8 +53,89 @@ type MServer struct {
StateMap *ShellStateMap
WriteErrorCh chan bool // closed if there is a I/O write error
WriteErrorChOnce *sync.Once
WriteFileContextMap map[string]*WriteFileContext
Done bool
InboundRpcHandlers map[string]RpcHandler
InboundRpcErrorSent map[string]time.Time // limits the amount of error messages sent back to the client
}
var _ RpcHandler = (*WriteFileContext)(nil)
type RpcHandler interface {
GetTimeoutTime() time.Time
DispatchPacket(reqId string, pk packet.RpcFollowUpPacketType)
UnRegisterCallback()
}
func (m *MServer) registerRpcHandler(reqId string, handler RpcHandler) error {
if handler == nil {
return errors.New("handler is nil")
}
m.Lock.Lock()
defer m.Lock.Unlock()
if m.InboundRpcHandlers[reqId] != nil {
return errors.New("handler already registered")
}
delete(m.InboundRpcErrorSent, reqId)
m.InboundRpcHandlers[reqId] = handler
return nil
}
func (m *MServer) unregisterRpcHandler(reqId string) {
m.Lock.Lock()
defer m.Lock.Unlock()
handler := m.InboundRpcHandlers[reqId]
delete(m.InboundRpcHandlers, reqId)
if handler != nil {
handler.UnRegisterCallback()
}
}
// limits the number of error responses that can be sent
func (m *MServer) sendInboundRpcError(reqId string, err error) {
m.Lock.Lock()
defer m.Lock.Unlock()
if _, ok := m.InboundRpcErrorSent[reqId]; ok {
return
}
m.InboundRpcErrorSent[reqId] = time.Now().Add(InboundRpcErrorTimeoutTime)
m.Sender.SendErrorResponse(reqId, err)
}
// returns true if dispatched to a waiting RPC handler
func (m *MServer) dispatchRpcFollowUp(pk packet.RpcFollowUpPacketType) bool {
if pk == nil {
return true
}
reqId := pk.GetAssociatedReqId()
m.Lock.Lock()
defer m.Lock.Unlock()
if rh := m.InboundRpcHandlers[reqId]; rh != nil {
rh.DispatchPacket(reqId, pk)
return true
}
return false
}
func (m *MServer) cleanRpcHandlers() {
var staleHandlers []RpcHandler
now := time.Now()
m.Lock.Lock()
for reqId, rh := range m.InboundRpcHandlers {
if now.After(rh.GetTimeoutTime()) {
staleHandlers = append(staleHandlers, rh)
delete(m.InboundRpcHandlers, reqId)
}
}
for reqId, timeoutTime := range m.InboundRpcErrorSent {
if now.After(timeoutTime) {
delete(m.InboundRpcErrorSent, reqId)
}
}
m.Lock.Unlock()
// we do this outside of m.Lock just in case there is some lock contention (UnRegisterCallback might be slow)
for _, rh := range staleHandlers {
rh.UnRegisterCallback()
}
}
type WriteFileContext struct {
@ -78,25 +160,13 @@ func (m *MServer) checkDone() bool {
return m.Done
}
func (m *MServer) getWriteFileContext(reqId string) *WriteFileContext {
m.Lock.Lock()
defer m.Lock.Unlock()
wfc := m.WriteFileContextMap[reqId]
if wfc == nil {
wfc = &WriteFileContext{
CVar: sync.NewCond(&sync.Mutex{}),
LastActive: time.Now(),
}
m.WriteFileContextMap[reqId] = wfc
}
return wfc
func (wfc *WriteFileContext) GetTimeoutTime() time.Time {
return wfc.LastActive.Add(WriteFileContextTimeout)
}
func (m *MServer) addFileDataPacket(pk *packet.FileDataPacketType) {
m.Lock.Lock()
wfc := m.WriteFileContextMap[pk.RespId]
m.Lock.Unlock()
if wfc == nil {
func (wfc *WriteFileContext) DispatchPacket(reqId string, pkArg packet.RpcFollowUpPacketType) {
dataPk, ok := pkArg.(*packet.FileDataPacketType)
if !ok {
return
}
wfc.CVar.L.Lock()
@ -111,11 +181,11 @@ func (m *MServer) addFileDataPacket(pk *packet.FileDataPacketType) {
return
}
wfc.LastActive = time.Now()
wfc.Data = append(wfc.Data, pk)
wfc.Data = append(wfc.Data, dataPk)
wfc.CVar.Signal()
}
func (wfc *WriteFileContext) setDone() {
func (wfc *WriteFileContext) UnRegisterCallback() {
wfc.CVar.L.Lock()
defer wfc.CVar.L.Unlock()
wfc.Done = true
@ -123,24 +193,6 @@ func (wfc *WriteFileContext) setDone() {
wfc.CVar.Broadcast()
}
func (m *MServer) cleanWriteFileContexts() {
now := time.Now()
var staleWfcs []*WriteFileContext
m.Lock.Lock()
for reqId, wfc := range m.WriteFileContextMap {
if now.Sub(wfc.LastActive) > WriteFileContextTimeout {
staleWfcs = append(staleWfcs, wfc)
delete(m.WriteFileContextMap, reqId)
}
}
m.Lock.Unlock()
// we do this outside of m.Lock just in case there is some lock contention (end of WriteFile could theoretically be slow)
for _, wfc := range staleWfcs {
wfc.setDone()
}
}
func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) {
ck := pk.GetCK()
if ck == "" {
@ -155,7 +207,6 @@ func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) {
return
}
cproc.Input.SendPacket(pk)
return
}
func runSingleCompGen(cwd string, compType string, prefix string) ([]string, bool, error) {
@ -226,7 +277,6 @@ func (m *MServer) runMixedCompGen(compPk *packet.CompGenPacketType) {
}
sort.Strings(comps) // resort
m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": (hasMoreFiles || hasMoreDirs)})
return
}
func (m *MServer) runCompGen(compPk *packet.CompGenPacketType) {
@ -246,10 +296,46 @@ func (m *MServer) runCompGen(compPk *packet.CompGenPacketType) {
m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": hasMore})
}
type ReinitRpcHandler struct {
ReqId string
TimeoutTime time.Time
StdinDataCh chan []byte
}
func (rh *ReinitRpcHandler) GetTimeoutTime() time.Time {
return rh.TimeoutTime
}
func (rh *ReinitRpcHandler) DispatchPacket(reqId string, pkArg packet.RpcFollowUpPacketType) {
rpcInput, ok := pkArg.(*packet.RpcInputPacketType)
if !ok {
wlog.Logf("reinit rpc handler: invalid packet type: %T", pkArg)
return
}
// nonblocking send
select {
case rh.StdinDataCh <- rpcInput.Data:
default:
wlog.Logf("reinit rpc handler: stdin data channel full, dropping data")
}
}
func (rh *ReinitRpcHandler) UnRegisterCallback() {
close(rh.StdinDataCh)
}
func (m *MServer) reinit(reqId string, shellType string) {
ssPk, err := m.MakeShellStatePacket(reqId, shellType)
stdinDataCh := make(chan []byte, 10)
rh := &ReinitRpcHandler{
ReqId: reqId,
TimeoutTime: time.Now().Add(30 * time.Second),
StdinDataCh: stdinDataCh,
}
m.registerRpcHandler(reqId, rh)
defer m.unregisterRpcHandler(reqId)
ssPk, err := m.MakeShellStatePacket(reqId, shellType, stdinDataCh)
if err != nil {
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error creating init packet: %w", err))
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error initializing shell: %w", err))
return
}
err = m.StateMap.SetCurrentState(ssPk.State.GetShellType(), ssPk.State)
@ -261,13 +347,13 @@ func (m *MServer) reinit(reqId string, shellType string) {
m.Sender.SendPacket(ssPk)
}
func (m *MServer) MakeShellStatePacket(reqId string, shellType string) (*packet.ShellStatePacketType, error) {
func (m *MServer) MakeShellStatePacket(reqId string, shellType string, stdinDataCh chan []byte) (*packet.ShellStatePacketType, error) {
sapi, err := shellapi.MakeShellApi(shellType)
if err != nil {
return nil, err
}
rtnCh := make(chan shellapi.ShellStateOutput, 1)
go sapi.GetShellState(rtnCh)
go sapi.GetShellState(rtnCh, stdinDataCh)
for ssOutput := range rtnCh {
if ssOutput.Error != "" {
return nil, errors.New(ssOutput.Error)
@ -357,7 +443,7 @@ func copyFile(dstName string, srcName string) error {
}
func (m *MServer) writeFile(pk *packet.WriteFilePacketType, wfc *WriteFileContext) {
defer wfc.setDone()
defer m.unregisterRpcHandler(pk.ReqId)
if pk.Path == "" {
resp := packet.MakeWriteFileReadyPacket(pk.ReqId)
resp.Error = "invalid write-file request, no path specified"
@ -478,7 +564,6 @@ func (m *MServer) returnStreamFileNewFileResponse(pk *packet.StreamFilePacketTyp
Perm: int(dirInfo.Mode().Perm()),
NotFound: true,
}
return
}
func (m *MServer) streamFile(pk *packet.StreamFilePacketType) {
@ -608,12 +693,19 @@ func (m *MServer) ProcessRpcPacket(pk packet.RpcPacketType) {
return
}
if writePk, ok := pk.(*packet.WriteFilePacketType); ok {
wfc := m.getWriteFileContext(writePk.ReqId)
wfc := &WriteFileContext{
CVar: sync.NewCond(&sync.Mutex{}),
LastActive: time.Now(),
}
err := m.registerRpcHandler(writePk.ReqId, wfc)
if err != nil {
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error registering write-file handler: %w", err))
return
}
go m.writeFile(writePk, wfc)
return
}
m.Sender.SendErrorResponse(reqId, fmt.Errorf("invalid rpc type '%s'", pk.GetType()))
return
}
func (m *MServer) clientPacketCallback(shellType string, pk packet.PacketType) {
@ -726,7 +818,7 @@ func (server *MServer) runReadLoop() {
builder := packet.MakeRunPacketBuilder()
for pk := range server.MainInput.MainCh {
if server.Debug {
fmt.Printf("PK> %s\n", packet.AsString(pk))
wlog.Logf("runReadLoop got packet %s\n", packet.AsString(pk))
}
ok, runPacket := builder.ProcessPacket(pk)
if ok {
@ -736,16 +828,19 @@ func (server *MServer) runReadLoop() {
}
continue
}
if cmdPk, ok := pk.(packet.CommandPacketType); ok {
if cmdPk, ok := pk.(packet.CommandPacketType); ok && cmdPk.GetCK() != "" {
server.ProcessCommandPacket(cmdPk)
continue
}
if rpcPk, ok := pk.(packet.RpcPacketType); ok {
if rpcPk, ok := pk.(packet.RpcPacketType); ok && rpcPk.GetReqId() != "" {
server.ProcessRpcPacket(rpcPk)
continue
}
if fileDataPk, ok := pk.(*packet.FileDataPacketType); ok {
server.addFileDataPacket(fileDataPk)
if rpcFollowUp, ok := pk.(packet.RpcFollowUpPacketType); ok && rpcFollowUp.GetAssociatedReqId() != "" {
ok := server.dispatchRpcFollowUp(rpcFollowUp)
if !ok {
server.sendInboundRpcError(rpcFollowUp.GetAssociatedReqId(), fmt.Errorf("no handler for rpc follow-up packet"))
}
continue
}
server.Sender.SendMessageFmt("invalid packet '%s' sent to mshell server", packet.AsString(pk))
@ -765,7 +860,8 @@ func RunServer() (int, error) {
Debug: debug,
WriteErrorCh: make(chan bool),
WriteErrorChOnce: &sync.Once{},
WriteFileContextMap: make(map[string]*WriteFileContext),
InboundRpcHandlers: make(map[string]RpcHandler),
InboundRpcErrorSent: make(map[string]time.Time),
}
if debug {
packet.GlobalDebug = true
@ -780,7 +876,7 @@ func RunServer() (int, error) {
return
}
time.Sleep(cleanLoopTime)
server.cleanWriteFileContexts()
server.cleanRpcHandlers()
}
}()
var err error

View File

@ -80,8 +80,8 @@ func (b bashShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty
return MakeBashShExecCommand(cmdStr, rcFileName, usePty)
}
func (b bashShellApi) GetShellState(outCh chan ShellStateOutput) {
GetBashShellState(outCh)
func (b bashShellApi) GetShellState(outCh chan ShellStateOutput, stdinDataCh chan []byte) {
GetBashShellState(outCh, stdinDataCh)
}
func (b bashShellApi) GetBaseShellOpts() string {
@ -169,7 +169,7 @@ func GetLocalBashMajorVersion() string {
return localBashMajorVersion
}
func GetBashShellState(outCh chan ShellStateOutput) {
func GetBashShellState(outCh chan ShellStateOutput, stdinDataCh chan []byte) {
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
defer cancelFn()
defer close(outCh)
@ -185,7 +185,7 @@ func GetBashShellState(outCh chan ShellStateOutput) {
outCh <- ShellStateOutput{Output: outputBytes}
}
}()
outputBytes, err := StreamCommandWithExtraFd(ecmd, outputCh, StateOutputFdNum, endBytes)
outputBytes, err := StreamCommandWithExtraFd(ctx, ecmd, outputCh, StateOutputFdNum, endBytes, stdinDataCh)
outputWg.Wait()
if err != nil {
outCh <- ShellStateOutput{Error: err.Error()}

View File

@ -25,9 +25,11 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellutil"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
"github.com/wavetermdev/waveterm/waveshell/pkg/wlog"
)
const GetStateTimeout = 15 * time.Second
const GetStateTimeout = 10 * time.Second
const ReInitTimeout = GetStateTimeout + 2*time.Second
const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"`
const GetK8sContextCmdStr = `printf "K8SCONTEXT %s\x00" "$(kubectl config current-context 2>/dev/null)"`
const GetK8sNamespaceCmdStr = `printf "K8SNAMESPACE %s\x00" "$(kubectl config view --minify --output 'jsonpath={..namespace}' 2>/dev/null)"`
@ -71,7 +73,7 @@ type ShellApi interface {
GetRemoteShellPath() string
MakeRunCommand(cmdStr string, opts RunCommandOpts) string
MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd
GetShellState(chan ShellStateOutput)
GetShellState(outCh chan ShellStateOutput, stdinDataCh chan []byte)
GetBaseShellOpts() string
ParseShellStateOutput(output []byte) (*packet.ShellState, *packet.ShellStateStats, error)
MakeRcFileStr(pk *packet.RunPacketType) string
@ -154,7 +156,7 @@ func internalMacUserShell() string {
const FirstExtraFilesFdNum = 3
// returns output(stdout+stderr), extraFdOutput, error
func StreamCommandWithExtraFd(ecmd *exec.Cmd, outputCh chan []byte, extraFdNum int, endBytes []byte) ([]byte, error) {
func StreamCommandWithExtraFd(ctx context.Context, ecmd *exec.Cmd, outputCh chan []byte, extraFdNum int, endBytes []byte, stdinDataCh chan []byte) ([]byte, error) {
defer close(outputCh)
ecmd.Env = os.Environ()
shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType))
@ -202,8 +204,27 @@ func StreamCommandWithExtraFd(ecmd *exec.Cmd, outputCh chan []byte, extraFdNum i
defer outputWg.Done()
utilfn.CopyWithEndBytes(&extraFdOutputBuf, pipeReader, endBytes)
}()
if stdinDataCh != nil {
go func() {
// continue this loop even after an error to drain stdinDataCh
hadErr := false
for stdinData := range stdinDataCh {
if hadErr {
continue
}
_, err := cmdPty.Write(stdinData)
if err != nil {
wlog.Logf("error writing to shellstate cmdpty (stdin): %v\n", err)
hadErr = true
}
}
}()
}
exitErr := ecmd.Wait()
if exitErr != nil {
if ctx.Err() != nil {
return nil, fmt.Errorf("%w (%w)", ctx.Err(), exitErr)
}
return nil, exitErr
}
outputWg.Wait()

View File

@ -246,7 +246,7 @@ func (z zshShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty
return exec.Command(GetLocalZshPath(), "-l", "-i", "-c", cmdStr)
}
func (z zshShellApi) GetShellState(outCh chan ShellStateOutput) {
func (z zshShellApi) GetShellState(outCh chan ShellStateOutput, stdinDataCh chan []byte) {
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
defer cancelFn()
defer close(outCh)
@ -262,7 +262,7 @@ func (z zshShellApi) GetShellState(outCh chan ShellStateOutput) {
outCh <- ShellStateOutput{Output: outputBytes}
}
}()
outputBytes, err := StreamCommandWithExtraFd(ecmd, outputCh, StateOutputFdNum, endBytes)
outputBytes, err := StreamCommandWithExtraFd(ctx, ecmd, outputCh, StateOutputFdNum, endBytes, stdinDataCh)
outputWg.Wait()
if err != nil {
outCh <- ShellStateOutput{Error: err.Error()}
@ -726,7 +726,6 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta
// sections: see ZshSection_* consts
sections := bytes.Split(outputBytes, sectionSeparator)
if len(sections) != ZshSection_NumFieldsExpected {
base.Logf("invalid -- numfields\n")
return nil, nil, fmt.Errorf("invalid zsh shell state output, wrong number of sections, section=%d", len(sections))
}
rtn := &packet.ShellState{}

View File

@ -30,6 +30,7 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/server"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellapi"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellutil"
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
@ -1186,9 +1187,17 @@ func deferWriteCmdStatus(ctx context.Context, cmd *sstore.CmdType, startTime tim
err := sstore.UpdateCmdDoneInfo(context.Background(), update, ck, donePk, cmdStatus)
if err != nil {
// nothing to do
log.Printf("error updating cmddoneinfo (in openai): %v\n", err)
log.Printf("error updating cmddoneinfo: %v\n", err)
return
}
screen, err := sstore.UpdateScreenFocusForDoneCmd(ctx, cmd.ScreenId, cmd.LineId)
if err != nil {
log.Printf("error trying to update screen focus type: %v\n", err)
// fall-through (nothing to do)
}
if screen != nil {
update.AddUpdate(*screen)
}
scbus.MainUpdateBus.DoScreenUpdate(cmd.ScreenId, update)
}
@ -3686,7 +3695,7 @@ func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (
if err != nil {
return nil, err
}
update, err := addLineForCmd(ctx, "/reset", false, ids, cmd, "", nil)
update, err := addLineForCmd(ctx, "/reset", true, ids, cmd, "", nil)
if err != nil {
return nil, err
}
@ -3695,7 +3704,7 @@ func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (
}
func doResetCommand(ids resolvedIds, shellType string, cmd *sstore.CmdType, verbose bool) {
ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second)
ctx, cancelFn := context.WithTimeout(context.Background(), shellapi.ReInitTimeout)
defer cancelFn()
startTime := time.Now()
var outputPos int64
@ -3712,7 +3721,7 @@ func doResetCommand(ids resolvedIds, shellType string, cmd *sstore.CmdType, verb
writeStringToPty(ctx, cmd, string(data), &outputPos)
}
origStatePtr := ids.Remote.MShell.GetDefaultStatePtr(shellType)
ssPk, err := ids.Remote.MShell.ReInit(ctx, shellType, dataFn, verbose)
ssPk, err := ids.Remote.MShell.ReInit(ctx, base.MakeCommandKey(cmd.ScreenId, cmd.LineId), shellType, dataFn, verbose)
if err != nil {
rtnErr = err
return
@ -3975,14 +3984,14 @@ func splitLinesForInfo(str string) []string {
}
func resizeRunningCommand(ctx context.Context, cmd *sstore.CmdType, newCols int) error {
siPk := packet.MakeSpecialInputPacket()
siPk.CK = base.MakeCommandKey(cmd.ScreenId, cmd.LineId)
siPk.WinSize = &packet.WinSize{Rows: int(cmd.TermOpts.Rows), Cols: newCols}
feInput := scpacket.MakeFeInputPacket()
feInput.CK = base.MakeCommandKey(cmd.ScreenId, cmd.LineId)
feInput.WinSize = &packet.WinSize{Rows: int(cmd.TermOpts.Rows), Cols: newCols}
msh := remote.GetRemoteById(cmd.Remote.RemoteId)
if msh == nil {
return fmt.Errorf("cannot resize, cmd remote not found")
}
err := msh.SendSpecialInput(siPk)
err := msh.HandleFeInput(feInput)
if err != nil {
return err
}
@ -5178,10 +5187,10 @@ func SignalCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbus
if !msh.IsConnected() {
return nil, fmt.Errorf("cannot send signal, remote is not connected")
}
siPk := packet.MakeSpecialInputPacket()
siPk.CK = base.MakeCommandKey(cmd.ScreenId, cmd.LineId)
siPk.SigName = sigArg
err = msh.SendSpecialInput(siPk)
inputPk := scpacket.MakeFeInputPacket()
inputPk.CK = base.MakeCommandKey(cmd.ScreenId, cmd.LineId)
inputPk.SigName = sigArg
err = msh.HandleFeInput(inputPk)
if err != nil {
return nil, fmt.Errorf("cannot send signal: %v", err)
}

View File

@ -215,6 +215,7 @@ var literalRtnStateCommands = []string{
"disable",
"function",
"zmodload",
"module",
}
func getCallExprLitArg(callExpr *syntax.CallExpr, argNum int) string {

View File

@ -54,6 +54,7 @@ const RemoteTermCols = 80
const PtyReadBufSize = 100
const RemoteConnectTimeout = 15 * time.Second
const RpcIterChannelSize = 100
const MaxInputDataSize = 1000
var envVarsToStrip map[string]bool = map[string]bool{
"PROMPT": true,
@ -128,6 +129,7 @@ type pendingStateKey struct {
RemotePtr sstore.RemotePtrType
}
// provides state, acccess, and control for a waveshell server process
type MShellProc struct {
Lock *sync.Mutex
Remote *sstore.RemoteType
@ -135,7 +137,7 @@ type MShellProc struct {
// runtime
RemoteId string // can be read without a lock
Status string
ServerProc *shexec.ClientProc
ServerProc *shexec.ClientProc // the server process
UName string
Err error
ErrNoInitPk bool
@ -154,9 +156,18 @@ type MShellProc struct {
InstallCancelFn context.CancelFunc
InstallErr error
// for synthetic commands (not run through RunCommand), this provides a way for them
// to register to receive input events from the frontend (e.g. ReInit)
CommandInputMap map[base.CommandKey]CommandInputSink
RunningCmds map[base.CommandKey]*RunCmdType
PendingStateCmds map[pendingStateKey]base.CommandKey // key=[remoteinstance name]
Client *ssh.Client
PendingStateCmds map[pendingStateKey]base.CommandKey // key=[remoteinstance name] (in progress commands that might update the state)
Client *ssh.Client
}
type CommandInputSink interface {
HandleInput(feInput *scpacket.FeInputPacketType) error
}
type RunCmdType struct {
@ -169,6 +180,22 @@ type RunCmdType struct {
EphCancled atomic.Bool // only for Ephemeral commands, if true, then the command result should be discarded
}
type ReinitCommandSink struct {
Remote *MShellProc
ReqId string
}
func (rcs *ReinitCommandSink) HandleInput(feInput *scpacket.FeInputPacketType) error {
realData, err := base64.StdEncoding.DecodeString(feInput.InputData64)
if err != nil {
return fmt.Errorf("error decoding input data: %v", err)
}
inputPk := packet.MakeRpcInputPacket(rcs.ReqId)
inputPk.Data = realData
rcs.Remote.ServerProc.Input.SendPacket(inputPk)
return nil
}
type RemoteRuntimeState = sstore.RemoteRuntimeState
func CanComplete(remoteType string) bool {
@ -196,7 +223,7 @@ func (msh *MShellProc) EnsureShellType(ctx context.Context, shellType string) er
return nil
}
// try to reinit the shell
_, err := msh.ReInit(ctx, shellType, nil, false)
_, err := msh.ReInit(ctx, base.CommandKey(""), shellType, nil, false)
if err != nil {
return fmt.Errorf("error trying to initialize shell %q: %v", shellType, err)
}
@ -694,6 +721,7 @@ func MakeMShell(r *sstore.RemoteType) *MShellProc {
Status: StatusDisconnected,
PtyBuffer: buf,
InstallStatus: StatusDisconnected,
CommandInputMap: make(map[base.CommandKey]CommandInputSink),
RunningCmds: make(map[base.CommandKey]*RunCmdType),
PendingStateCmds: make(map[pendingStateKey]base.CommandKey),
StateMap: server.MakeShellStateMap(),
@ -1401,7 +1429,7 @@ func makeReinitErrorUpdate(shellType string) sstore.ActivityUpdate {
return rtn
}
func (msh *MShellProc) ReInit(ctx context.Context, shellType string, dataFn func([]byte), verbose bool) (rtnPk *packet.ShellStatePacketType, rtnErr error) {
func (msh *MShellProc) ReInit(ctx context.Context, ck base.CommandKey, shellType string, dataFn func([]byte), verbose bool) (rtnPk *packet.ShellStatePacketType, rtnErr error) {
if !msh.IsConnected() {
return nil, fmt.Errorf("cannot reinit, remote is not connected")
}
@ -1425,6 +1453,14 @@ func (msh *MShellProc) ReInit(ctx context.Context, shellType string, dataFn func
return nil, err
}
defer rpcIter.Close()
if ck != "" {
reinitSink := &ReinitCommandSink{
Remote: msh,
ReqId: reinitPk.ReqId,
}
msh.registerInputSink(ck, reinitSink)
defer msh.unregisterInputSink(ck)
}
var ssPk *packet.ShellStatePacketType
for {
resp, err := rpcIter.Next(ctx)
@ -1506,6 +1542,9 @@ func addScVarsToState(state *packet.ShellState) *packet.ShellState {
envMap["WAVETERM_VERSION"] = &shellenv.DeclareDeclType{Name: "WAVETERM_VERSION", Value: scbase.WaveVersion, Args: "x"}
envMap["TERM_PROGRAM"] = &shellenv.DeclareDeclType{Name: "TERM_PROGRAM", Value: "waveterm", Args: "x"}
envMap["TERM_PROGRAM_VERSION"] = &shellenv.DeclareDeclType{Name: "TERM_PROGRAM_VERSION", Value: scbase.WaveVersion, Args: "x"}
if scbase.IsDevMode() {
envMap["WAVETERM_DEV"] = &shellenv.DeclareDeclType{Name: "WAVETERM_DEV", Value: "1", Args: "x"}
}
if _, exists := envMap["LANG"]; !exists {
envMap["LANG"] = &shellenv.DeclareDeclType{Name: "LANG", Value: scbase.DetermineLang(), Args: "x"}
}
@ -1727,20 +1766,28 @@ func (msh *MShellProc) Launch(interactive bool) {
}
func (msh *MShellProc) initActiveShells() {
ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
gasCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFn()
activeShells, err := msh.getActiveShellTypes(ctx)
activeShells, err := msh.getActiveShellTypes(gasCtx)
if err != nil {
// we're not going to fail the connect for this error (it will be unusable, but technically connected)
msh.WriteToPtyBuffer("*error getting active shells: %v\n", err)
return
}
for _, shellType := range activeShells {
_, err = msh.ReInit(ctx, shellType, nil, false)
if err != nil {
msh.WriteToPtyBuffer("*error reiniting shell %q: %v\n", shellType, err)
}
var wg sync.WaitGroup
for _, shellTypeForVar := range activeShells {
wg.Add(1)
go func(shellType string) {
defer wg.Done()
reinitCtx, cancelFn := context.WithTimeout(context.Background(), shellapi.ReInitTimeout)
defer cancelFn()
_, err = msh.ReInit(reinitCtx, base.CommandKey(""), shellType, nil, false)
if err != nil {
msh.WriteToPtyBuffer("*error reiniting shell %q: %v\n", shellType, err)
}
}(shellTypeForVar)
}
wg.Wait()
}
func (msh *MShellProc) IsConnected() bool {
@ -1775,24 +1822,14 @@ func (msh *MShellProc) IsCmdRunning(ck base.CommandKey) bool {
return ok
}
func (msh *MShellProc) SendInput(dataPk *packet.DataPacketType) error {
if !msh.IsConnected() {
return fmt.Errorf("remote is not connected, cannot send input")
}
if !msh.IsCmdRunning(dataPk.CK) {
return fmt.Errorf("cannot send input, cmd is not running")
}
return msh.ServerProc.Input.SendPacket(dataPk)
}
func (msh *MShellProc) KillRunningCommandAndWait(ctx context.Context, ck base.CommandKey) error {
if !msh.IsCmdRunning(ck) {
return nil
}
siPk := packet.MakeSpecialInputPacket()
siPk.CK = ck
siPk.SigName = "SIGTERM"
err := msh.SendSpecialInput(siPk)
feiPk := scpacket.MakeFeInputPacket()
feiPk.CK = ck
feiPk.SigName = "SIGTERM"
err := msh.HandleFeInput(feiPk)
if err != nil {
return fmt.Errorf("error trying to kill running cmd: %w", err)
}
@ -1809,16 +1846,6 @@ func (msh *MShellProc) KillRunningCommandAndWait(ctx context.Context, ck base.Co
}
}
func (msh *MShellProc) SendSpecialInput(siPk *packet.SpecialInputPacketType) error {
if !msh.IsConnected() {
return fmt.Errorf("remote is not connected, cannot send input")
}
if !msh.IsCmdRunning(siPk.CK) {
return fmt.Errorf("cannot send input, cmd is not running")
}
return msh.ServerProc.Input.SendPacket(siPk)
}
func (msh *MShellProc) SendFileData(dataPk *packet.FileDataPacketType) error {
if !msh.IsConnected() {
return fmt.Errorf("remote is not connected, cannot send input")
@ -2065,6 +2092,62 @@ func makePSCLineError(existingPSC base.CommandKey, line *sstore.LineType, lineEr
return fmt.Errorf("cannot run command while a stateful command (linenum=%d) is still running", line.LineNum)
}
func (msh *MShellProc) registerInputSink(ck base.CommandKey, sink CommandInputSink) {
msh.Lock.Lock()
defer msh.Lock.Unlock()
msh.CommandInputMap[ck] = sink
}
func (msh *MShellProc) unregisterInputSink(ck base.CommandKey) {
msh.Lock.Lock()
defer msh.Lock.Unlock()
delete(msh.CommandInputMap, ck)
}
func (msh *MShellProc) HandleFeInput(inputPk *scpacket.FeInputPacketType) error {
if inputPk == nil {
return nil
}
if !msh.IsConnected() {
return fmt.Errorf("connection is not connected, cannot send input")
}
if msh.IsCmdRunning(inputPk.CK) {
if len(inputPk.InputData64) > 0 {
inputLen := packet.B64DecodedLen(inputPk.InputData64)
if inputLen > MaxInputDataSize {
return fmt.Errorf("input data size too large, len=%d (max=%d)", inputLen, MaxInputDataSize)
}
dataPk := packet.MakeDataPacket()
dataPk.CK = inputPk.CK
dataPk.FdNum = 0 // stdin
dataPk.Data64 = inputPk.InputData64
err := msh.ServerProc.Input.SendPacket(dataPk)
if err != nil {
return err
}
}
if inputPk.SigName != "" || inputPk.WinSize != nil {
siPk := packet.MakeSpecialInputPacket()
siPk.CK = inputPk.CK
siPk.SigName = inputPk.SigName
siPk.WinSize = inputPk.WinSize
err := msh.ServerProc.Input.SendPacket(siPk)
if err != nil {
return err
}
}
return nil
}
msh.Lock.Lock()
sink := msh.CommandInputMap[inputPk.CK]
msh.Lock.Unlock()
if sink == nil {
// no sink and no running command
return fmt.Errorf("cannot send input, cmd is not running")
}
return sink.HandleInput(inputPk)
}
func (msh *MShellProc) AddRunningCmd(rct *RunCmdType) {
msh.Lock.Lock()
defer msh.Lock.Unlock()

View File

@ -115,6 +115,7 @@ var IgnoreVars = map[string]bool{
"WAVESHELL_VERSION": true,
"WAVETERM": true,
"WAVETERM_VERSION": true,
"WAVETERM_DEV": true,
"TERM_PROGRAM": true,
"TERM_PROGRAM_VERSION": true,
"TERM_SESSION_ID": true,

View File

@ -4,6 +4,7 @@
package scpacket
import (
"encoding/base64"
"fmt"
"reflect"
"regexp"
@ -191,6 +192,14 @@ func MakeFeInputPacket() *FeInputPacketType {
return &FeInputPacketType{Type: FeInputPacketStr}
}
func (pk *FeInputPacketType) DecodeData() ([]byte, error) {
return base64.StdEncoding.DecodeString(pk.InputData64)
}
func (pk *FeInputPacketType) SetData(data []byte) {
pk.InputData64 = base64.StdEncoding.EncodeToString(data)
}
func (*WatchScreenPacketType) GetType() string {
return WatchScreenPacketStr
}

View File

@ -23,7 +23,6 @@ import (
)
const WSStatePacketChSize = 20
const MaxInputDataSize = 1000
const RemoteInputQueueSize = 100
var RemoteInputMapQueue *mapqueue.MapQueue
@ -247,7 +246,7 @@ func (ws *WSState) processMessage(msgBytes []byte) error {
err := RemoteInputMapQueue.Enqueue(feInputPk.Remote.RemoteId, func() {
sendErr := sendCmdInput(feInputPk)
if sendErr != nil {
log.Printf("[scws] sending command input: %v\n", err)
log.Printf("[scws] sending command input: %v\n", sendErr)
}
})
if err != nil {
@ -263,7 +262,7 @@ func (ws *WSState) processMessage(msgBytes []byte) error {
go func() {
sendErr := remote.SendRemoteInput(inputPk)
if sendErr != nil {
log.Printf("[scws] error processing remote input: %v\n", err)
log.Printf("[scws] error processing remote input: %v\n", sendErr)
}
}()
return nil
@ -319,29 +318,5 @@ func sendCmdInput(pk *scpacket.FeInputPacketType) error {
if msh == nil {
return fmt.Errorf("remote %s not found", pk.Remote.RemoteId)
}
if len(pk.InputData64) > 0 {
inputLen := packet.B64DecodedLen(pk.InputData64)
if inputLen > MaxInputDataSize {
return fmt.Errorf("input data size too large, len=%d (max=%d)", inputLen, MaxInputDataSize)
}
dataPk := packet.MakeDataPacket()
dataPk.CK = pk.CK
dataPk.FdNum = 0 // stdin
dataPk.Data64 = pk.InputData64
err = msh.SendInput(dataPk)
if err != nil {
return err
}
}
if pk.SigName != "" || pk.WinSize != nil {
siPk := packet.MakeSpecialInputPacket()
siPk.CK = pk.CK
siPk.SigName = pk.SigName
siPk.WinSize = pk.WinSize
err = msh.SendSpecialInput(siPk)
if err != nil {
return err
}
}
return nil
return msh.HandleFeInput(pk)
}