more zsh reinitialization fixes (allow user input during initialization process) (#480)

* fix error logs in scws

* new RpcFollowUpPacketType

* make the rpc/followup handlers generic on the server side -- using new RpcHandlers map and RpcFollowUpPacketType

* rpcinputpacket for passing user input back through to reinit command

* add WAVETERM_DEV env var in dev mode

* remove unused code, ensure mshell and rcfile directory on startup (prevent root clobber with sudo)

* combine all feinput into one function msh.HandleFeInput, and add a new concept of input sinks for special cases (like reinit)

* allow reset to accept user input (to get around interactive initialization problems)

* tone down the selection background highlight color on dark mode.  easier to read selected text

* fix command focus and done focus issues with dynamic (non-run) commands

* add 'module' as a 'rtnstate' command (#478)

* reinitialize shells in parallel, fix timeouts, better error messages
This commit is contained in:
Mike Sawka 2024-03-20 23:38:05 -07:00 committed by GitHub
parent fb59e094e4
commit 0781e6e821
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 384 additions and 225 deletions

View File

@ -27,6 +27,6 @@
--term-cmdtext: #ffffff; --term-cmdtext: #ffffff;
--term-foreground: #d3d7cf; --term-foreground: #d3d7cf;
--term-background: #000000; --term-background: #000000;
--term-selection-background: #ffffff90; --term-selection-background: #ffffff60;
--term-cursor-accent: #000000; --term-cursor-accent: #000000;
} }

View File

@ -135,6 +135,7 @@ func main() {
base.ProcessType = base.ProcessType_WaveShellServer base.ProcessType = base.ProcessType_WaveShellServer
wlog.GlobalSubsystem = base.ProcessType_WaveShellServer wlog.GlobalSubsystem = base.ProcessType_WaveShellServer
base.InitDebugLog("server") base.InitDebugLog("server")
base.EnsureRcFilesDir()
rtnCode, err := server.RunServer() rtnCode, err := server.RunServer()
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "[error] %v\n", err) fmt.Fprintf(os.Stderr, "[error] %v\n", err)

View File

@ -12,7 +12,6 @@ import (
"os" "os"
"os/exec" "os/exec"
"path" "path"
"path/filepath"
"strings" "strings"
"sync" "sync"
@ -72,11 +71,11 @@ func IsWaveSrv() bool {
return ProcessType == ProcessType_WaveSrv return ProcessType == ProcessType_WaveSrv
} }
func MakeCommandKey(sessionId string, cmdId string) CommandKey { func MakeCommandKey(screenId string, lineId string) CommandKey {
if sessionId == "" && cmdId == "" { if screenId == "" && lineId == "" {
return CommandKey("") return CommandKey("")
} }
return CommandKey(fmt.Sprintf("%s/%s", sessionId, cmdId)) return CommandKey(fmt.Sprintf("%s/%s", screenId, lineId))
} }
func (ckey CommandKey) IsEmpty() bool { func (ckey CommandKey) IsEmpty() bool {
@ -200,51 +199,6 @@ func GetMShellHomeDir() string {
return ExpandHomeDir(DefaultMShellHome) return ExpandHomeDir(DefaultMShellHome)
} }
func GetCommandFileNames(ck CommandKey) (*CommandFileNames, error) {
if err := ck.Validate("ck"); err != nil {
return nil, fmt.Errorf("cannot get command files: %w", err)
}
sessionId, cmdId := ck.Split()
sdir, err := EnsureSessionDir(sessionId)
if err != nil {
return nil, err
}
base := path.Join(sdir, cmdId)
return &CommandFileNames{
PtyOutFile: base + ".ptyout",
StdinFifo: base + ".stdin",
RunnerOutFile: base + ".runout",
}, nil
}
func CleanUpCmdFiles(sessionId string, cmdId string) error {
if cmdId == "" {
return fmt.Errorf("bad cmdid, cannot clean up")
}
sdir, err := EnsureSessionDir(sessionId)
if err != nil {
return err
}
cmdFileGlob := path.Join(sdir, cmdId+".*")
matches, err := filepath.Glob(cmdFileGlob)
if err != nil {
return err
}
for _, file := range matches {
rmErr := os.Remove(file)
if err == nil && rmErr != nil {
err = rmErr
}
}
return err
}
func GetSessionsDir() string {
mhome := GetMShellHomeDir()
sdir := path.Join(mhome, SessionsDirBaseName)
return sdir
}
func EnsureRcFilesDir() (string, error) { func EnsureRcFilesDir() (string, error) {
mhome := GetMShellHomeDir() mhome := GetMShellHomeDir()
dirName := path.Join(mhome, RcFilesDirBaseName) dirName := path.Join(mhome, RcFilesDirBaseName)
@ -255,19 +209,6 @@ func EnsureRcFilesDir() (string, error) {
return dirName, nil return dirName, nil
} }
func EnsureSessionDir(sessionId string) (string, error) {
if sessionId == "" {
return "", fmt.Errorf("Bad sessionid, cannot be empty")
}
mhome := GetMShellHomeDir()
sdir := path.Join(mhome, SessionsDirBaseName, sessionId)
err := CacheEnsureDir(sdir, sessionId, 0777, "mshell session dir")
if err != nil {
return "", err
}
return sdir, nil
}
func GetMShellPath() (string, error) { func GetMShellPath() (string, error) {
msPath := os.Getenv(MShellPathVarName) // use MSHELL_PATH msPath := os.Getenv(MShellPathVarName) // use MSHELL_PATH
if msPath != "" { if msPath != "" {
@ -282,11 +223,6 @@ func GetMShellPath() (string, error) {
return exec.LookPath(DefaultMShellName) // standard path lookup for 'mshell' return exec.LookPath(DefaultMShellName) // standard path lookup for 'mshell'
} }
func GetMShellSessionsDir() (string, error) {
mhome := GetMShellHomeDir()
return path.Join(mhome, SessionsDirBaseName), nil
}
func ExpandHomeDir(pathStr string) string { func ExpandHomeDir(pathStr string) string {
if pathStr != "~" && !strings.HasPrefix(pathStr, "~/") { if pathStr != "~" && !strings.HasPrefix(pathStr, "~/") {
return pathStr return pathStr
@ -315,22 +251,6 @@ func GoArchOptFile(version string, goos string, goarch string) string {
return fmt.Sprintf(path.Join(installBinDir, binBaseName)) return fmt.Sprintf(path.Join(installBinDir, binBaseName))
} }
func MShellBinaryFromOptDir(version string, goos string, goarch string) (io.ReadCloser, error) {
if !ValidGoArch(goos, goarch) {
return nil, fmt.Errorf("invalid goos/goarch combination: %s/%s", goos, goarch)
}
versionStr := semver.MajorMinor(version)
if versionStr == "" {
return nil, fmt.Errorf("invalid mshell version: %q", version)
}
fileName := GoArchOptFile(version, goos, goarch)
fd, err := os.Open(fileName)
if err != nil {
return nil, fmt.Errorf("cannot open mshell binary %q: %v", fileName, err)
}
return fd, nil
}
func GetRemoteId() (string, error) { func GetRemoteId() (string, error) {
mhome := GetMShellHomeDir() mhome := GetMShellHomeDir()
homeInfo, err := os.Stat(mhome) homeInfo, err := os.Stat(mhome)

View File

@ -12,6 +12,7 @@ import (
"fmt" "fmt"
"io" "io"
"io/fs" "io/fs"
"log"
"os" "os"
"reflect" "reflect"
"sync" "sync"
@ -63,6 +64,7 @@ const (
FileStatPacketStr = "filestat" FileStatPacketStr = "filestat"
LogPacketStr = "log" // logging packet (sent from waveshell back to server) LogPacketStr = "log" // logging packet (sent from waveshell back to server)
ShellStatePacketStr = "shellstate" ShellStatePacketStr = "shellstate"
RpcInputPacketStr = "rpcinput" // rpc-followup
OpenAIPacketStr = "openai" // other OpenAIPacketStr = "openai" // other
OpenAICloudReqStr = "openai-cloudreq" OpenAICloudReqStr = "openai-cloudreq"
@ -116,6 +118,7 @@ func init() {
TypeStrToFactory[LogPacketStr] = reflect.TypeOf(LogPacketType{}) TypeStrToFactory[LogPacketStr] = reflect.TypeOf(LogPacketType{})
TypeStrToFactory[ShellStatePacketStr] = reflect.TypeOf(ShellStatePacketType{}) TypeStrToFactory[ShellStatePacketStr] = reflect.TypeOf(ShellStatePacketType{})
TypeStrToFactory[FileStatPacketStr] = reflect.TypeOf(FileStatPacketType{}) TypeStrToFactory[FileStatPacketStr] = reflect.TypeOf(FileStatPacketType{})
TypeStrToFactory[RpcInputPacketStr] = reflect.TypeOf(RpcInputPacketType{})
var _ RpcPacketType = (*RunPacketType)(nil) var _ RpcPacketType = (*RunPacketType)(nil)
var _ RpcPacketType = (*GetCmdPacketType)(nil) var _ RpcPacketType = (*GetCmdPacketType)(nil)
@ -134,6 +137,9 @@ func init() {
var _ RpcResponsePacketType = (*WriteFileDonePacketType)(nil) var _ RpcResponsePacketType = (*WriteFileDonePacketType)(nil)
var _ RpcResponsePacketType = (*ShellStatePacketType)(nil) var _ RpcResponsePacketType = (*ShellStatePacketType)(nil)
var _ RpcFollowUpPacketType = (*FileDataPacketType)(nil)
var _ RpcFollowUpPacketType = (*RpcInputPacketType)(nil)
var _ CommandPacketType = (*DataPacketType)(nil) var _ CommandPacketType = (*DataPacketType)(nil)
var _ CommandPacketType = (*DataAckPacketType)(nil) var _ CommandPacketType = (*DataAckPacketType)(nil)
var _ CommandPacketType = (*CmdDonePacketType)(nil) var _ CommandPacketType = (*CmdDonePacketType)(nil)
@ -166,6 +172,26 @@ func MakePingPacket() *PingPacketType {
return &PingPacketType{Type: PingPacketStr} return &PingPacketType{Type: PingPacketStr}
} }
type RpcInputPacketType struct {
Type string `json:"type"`
ReqId string `json:"reqid"`
Data []byte `json:"data"`
}
func (*RpcInputPacketType) GetType() string {
return RpcInputPacketStr
}
func (p *RpcInputPacketType) GetAssociatedReqId() string {
return p.ReqId
}
func MakeRpcInputPacket(reqId string) *RpcInputPacketType {
return &RpcInputPacketType{Type: RpcInputPacketStr, ReqId: reqId}
}
// these packets can travel either direction
// so it is both a RpcResponsePacketType and an RpcFollowUpPacketType
type FileDataPacketType struct { type FileDataPacketType struct {
Type string `json:"type"` Type string `json:"type"`
RespId string `json:"respid"` RespId string `json:"respid"`
@ -185,6 +211,10 @@ func MakeFileDataPacket(reqId string) *FileDataPacketType {
} }
} }
func (p *FileDataPacketType) GetAssociatedReqId() string {
return p.RespId
}
func (p *FileDataPacketType) GetResponseId() string { func (p *FileDataPacketType) GetResponseId() string {
return p.RespId return p.RespId
} }
@ -976,6 +1006,12 @@ type CommandPacketType interface {
GetCK() base.CommandKey GetCK() base.CommandKey
} }
// RpcPackets initiate an Rpc. these can be part of the data passed back and forth
type RpcFollowUpPacketType interface {
GetType() string
GetAssociatedReqId() string
}
type ModelUpdatePacketType struct { type ModelUpdatePacketType struct {
Type string `json:"type"` Type string `json:"type"`
Updates []any `json:"updates"` Updates []any `json:"updates"`
@ -1178,6 +1214,14 @@ func (sender *PacketSender) SendPacketCtx(ctx context.Context, pk PacketType) er
} }
func (sender *PacketSender) SendPacket(pk PacketType) error { func (sender *PacketSender) SendPacket(pk PacketType) error {
if pk == nil {
log.Printf("tried to send nil packet\n")
return fmt.Errorf("tried to send nil packet")
}
if pk.GetType() == "" {
log.Printf("tried to send invalid packet: %T\n", pk)
return fmt.Errorf("tried to send packet without a type: %T", pk)
}
err := sender.checkStatus() err := sender.checkStatus()
if err != nil { if err != nil {
return err return err

View File

@ -30,6 +30,7 @@ const MaxFileDataPacketSize = 16 * 1024
const WriteFileContextTimeout = 30 * time.Second const WriteFileContextTimeout = 30 * time.Second
const cleanLoopTime = 5 * time.Second const cleanLoopTime = 5 * time.Second
const MaxWriteFileContextData = 100 const MaxWriteFileContextData = 100
const InboundRpcErrorTimeoutTime = 30 * time.Second
type shellStateMapKey struct { type shellStateMapKey struct {
ShellType string ShellType string
@ -52,8 +53,89 @@ type MServer struct {
StateMap *ShellStateMap StateMap *ShellStateMap
WriteErrorCh chan bool // closed if there is a I/O write error WriteErrorCh chan bool // closed if there is a I/O write error
WriteErrorChOnce *sync.Once WriteErrorChOnce *sync.Once
WriteFileContextMap map[string]*WriteFileContext
Done bool Done bool
InboundRpcHandlers map[string]RpcHandler
InboundRpcErrorSent map[string]time.Time // limits the amount of error messages sent back to the client
}
var _ RpcHandler = (*WriteFileContext)(nil)
type RpcHandler interface {
GetTimeoutTime() time.Time
DispatchPacket(reqId string, pk packet.RpcFollowUpPacketType)
UnRegisterCallback()
}
func (m *MServer) registerRpcHandler(reqId string, handler RpcHandler) error {
if handler == nil {
return errors.New("handler is nil")
}
m.Lock.Lock()
defer m.Lock.Unlock()
if m.InboundRpcHandlers[reqId] != nil {
return errors.New("handler already registered")
}
delete(m.InboundRpcErrorSent, reqId)
m.InboundRpcHandlers[reqId] = handler
return nil
}
func (m *MServer) unregisterRpcHandler(reqId string) {
m.Lock.Lock()
defer m.Lock.Unlock()
handler := m.InboundRpcHandlers[reqId]
delete(m.InboundRpcHandlers, reqId)
if handler != nil {
handler.UnRegisterCallback()
}
}
// limits the number of error responses that can be sent
func (m *MServer) sendInboundRpcError(reqId string, err error) {
m.Lock.Lock()
defer m.Lock.Unlock()
if _, ok := m.InboundRpcErrorSent[reqId]; ok {
return
}
m.InboundRpcErrorSent[reqId] = time.Now().Add(InboundRpcErrorTimeoutTime)
m.Sender.SendErrorResponse(reqId, err)
}
// returns true if dispatched to a waiting RPC handler
func (m *MServer) dispatchRpcFollowUp(pk packet.RpcFollowUpPacketType) bool {
if pk == nil {
return true
}
reqId := pk.GetAssociatedReqId()
m.Lock.Lock()
defer m.Lock.Unlock()
if rh := m.InboundRpcHandlers[reqId]; rh != nil {
rh.DispatchPacket(reqId, pk)
return true
}
return false
}
func (m *MServer) cleanRpcHandlers() {
var staleHandlers []RpcHandler
now := time.Now()
m.Lock.Lock()
for reqId, rh := range m.InboundRpcHandlers {
if now.After(rh.GetTimeoutTime()) {
staleHandlers = append(staleHandlers, rh)
delete(m.InboundRpcHandlers, reqId)
}
}
for reqId, timeoutTime := range m.InboundRpcErrorSent {
if now.After(timeoutTime) {
delete(m.InboundRpcErrorSent, reqId)
}
}
m.Lock.Unlock()
// we do this outside of m.Lock just in case there is some lock contention (UnRegisterCallback might be slow)
for _, rh := range staleHandlers {
rh.UnRegisterCallback()
}
} }
type WriteFileContext struct { type WriteFileContext struct {
@ -78,25 +160,13 @@ func (m *MServer) checkDone() bool {
return m.Done return m.Done
} }
func (m *MServer) getWriteFileContext(reqId string) *WriteFileContext { func (wfc *WriteFileContext) GetTimeoutTime() time.Time {
m.Lock.Lock() return wfc.LastActive.Add(WriteFileContextTimeout)
defer m.Lock.Unlock()
wfc := m.WriteFileContextMap[reqId]
if wfc == nil {
wfc = &WriteFileContext{
CVar: sync.NewCond(&sync.Mutex{}),
LastActive: time.Now(),
}
m.WriteFileContextMap[reqId] = wfc
}
return wfc
} }
func (m *MServer) addFileDataPacket(pk *packet.FileDataPacketType) { func (wfc *WriteFileContext) DispatchPacket(reqId string, pkArg packet.RpcFollowUpPacketType) {
m.Lock.Lock() dataPk, ok := pkArg.(*packet.FileDataPacketType)
wfc := m.WriteFileContextMap[pk.RespId] if !ok {
m.Lock.Unlock()
if wfc == nil {
return return
} }
wfc.CVar.L.Lock() wfc.CVar.L.Lock()
@ -111,11 +181,11 @@ func (m *MServer) addFileDataPacket(pk *packet.FileDataPacketType) {
return return
} }
wfc.LastActive = time.Now() wfc.LastActive = time.Now()
wfc.Data = append(wfc.Data, pk) wfc.Data = append(wfc.Data, dataPk)
wfc.CVar.Signal() wfc.CVar.Signal()
} }
func (wfc *WriteFileContext) setDone() { func (wfc *WriteFileContext) UnRegisterCallback() {
wfc.CVar.L.Lock() wfc.CVar.L.Lock()
defer wfc.CVar.L.Unlock() defer wfc.CVar.L.Unlock()
wfc.Done = true wfc.Done = true
@ -123,24 +193,6 @@ func (wfc *WriteFileContext) setDone() {
wfc.CVar.Broadcast() wfc.CVar.Broadcast()
} }
func (m *MServer) cleanWriteFileContexts() {
now := time.Now()
var staleWfcs []*WriteFileContext
m.Lock.Lock()
for reqId, wfc := range m.WriteFileContextMap {
if now.Sub(wfc.LastActive) > WriteFileContextTimeout {
staleWfcs = append(staleWfcs, wfc)
delete(m.WriteFileContextMap, reqId)
}
}
m.Lock.Unlock()
// we do this outside of m.Lock just in case there is some lock contention (end of WriteFile could theoretically be slow)
for _, wfc := range staleWfcs {
wfc.setDone()
}
}
func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) { func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) {
ck := pk.GetCK() ck := pk.GetCK()
if ck == "" { if ck == "" {
@ -155,7 +207,6 @@ func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) {
return return
} }
cproc.Input.SendPacket(pk) cproc.Input.SendPacket(pk)
return
} }
func runSingleCompGen(cwd string, compType string, prefix string) ([]string, bool, error) { func runSingleCompGen(cwd string, compType string, prefix string) ([]string, bool, error) {
@ -226,7 +277,6 @@ func (m *MServer) runMixedCompGen(compPk *packet.CompGenPacketType) {
} }
sort.Strings(comps) // resort sort.Strings(comps) // resort
m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": (hasMoreFiles || hasMoreDirs)}) m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": (hasMoreFiles || hasMoreDirs)})
return
} }
func (m *MServer) runCompGen(compPk *packet.CompGenPacketType) { func (m *MServer) runCompGen(compPk *packet.CompGenPacketType) {
@ -246,10 +296,46 @@ func (m *MServer) runCompGen(compPk *packet.CompGenPacketType) {
m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": hasMore}) m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": hasMore})
} }
type ReinitRpcHandler struct {
ReqId string
TimeoutTime time.Time
StdinDataCh chan []byte
}
func (rh *ReinitRpcHandler) GetTimeoutTime() time.Time {
return rh.TimeoutTime
}
func (rh *ReinitRpcHandler) DispatchPacket(reqId string, pkArg packet.RpcFollowUpPacketType) {
rpcInput, ok := pkArg.(*packet.RpcInputPacketType)
if !ok {
wlog.Logf("reinit rpc handler: invalid packet type: %T", pkArg)
return
}
// nonblocking send
select {
case rh.StdinDataCh <- rpcInput.Data:
default:
wlog.Logf("reinit rpc handler: stdin data channel full, dropping data")
}
}
func (rh *ReinitRpcHandler) UnRegisterCallback() {
close(rh.StdinDataCh)
}
func (m *MServer) reinit(reqId string, shellType string) { func (m *MServer) reinit(reqId string, shellType string) {
ssPk, err := m.MakeShellStatePacket(reqId, shellType) stdinDataCh := make(chan []byte, 10)
rh := &ReinitRpcHandler{
ReqId: reqId,
TimeoutTime: time.Now().Add(30 * time.Second),
StdinDataCh: stdinDataCh,
}
m.registerRpcHandler(reqId, rh)
defer m.unregisterRpcHandler(reqId)
ssPk, err := m.MakeShellStatePacket(reqId, shellType, stdinDataCh)
if err != nil { if err != nil {
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error creating init packet: %w", err)) m.Sender.SendErrorResponse(reqId, fmt.Errorf("error initializing shell: %w", err))
return return
} }
err = m.StateMap.SetCurrentState(ssPk.State.GetShellType(), ssPk.State) err = m.StateMap.SetCurrentState(ssPk.State.GetShellType(), ssPk.State)
@ -261,13 +347,13 @@ func (m *MServer) reinit(reqId string, shellType string) {
m.Sender.SendPacket(ssPk) m.Sender.SendPacket(ssPk)
} }
func (m *MServer) MakeShellStatePacket(reqId string, shellType string) (*packet.ShellStatePacketType, error) { func (m *MServer) MakeShellStatePacket(reqId string, shellType string, stdinDataCh chan []byte) (*packet.ShellStatePacketType, error) {
sapi, err := shellapi.MakeShellApi(shellType) sapi, err := shellapi.MakeShellApi(shellType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
rtnCh := make(chan shellapi.ShellStateOutput, 1) rtnCh := make(chan shellapi.ShellStateOutput, 1)
go sapi.GetShellState(rtnCh) go sapi.GetShellState(rtnCh, stdinDataCh)
for ssOutput := range rtnCh { for ssOutput := range rtnCh {
if ssOutput.Error != "" { if ssOutput.Error != "" {
return nil, errors.New(ssOutput.Error) return nil, errors.New(ssOutput.Error)
@ -357,7 +443,7 @@ func copyFile(dstName string, srcName string) error {
} }
func (m *MServer) writeFile(pk *packet.WriteFilePacketType, wfc *WriteFileContext) { func (m *MServer) writeFile(pk *packet.WriteFilePacketType, wfc *WriteFileContext) {
defer wfc.setDone() defer m.unregisterRpcHandler(pk.ReqId)
if pk.Path == "" { if pk.Path == "" {
resp := packet.MakeWriteFileReadyPacket(pk.ReqId) resp := packet.MakeWriteFileReadyPacket(pk.ReqId)
resp.Error = "invalid write-file request, no path specified" resp.Error = "invalid write-file request, no path specified"
@ -478,7 +564,6 @@ func (m *MServer) returnStreamFileNewFileResponse(pk *packet.StreamFilePacketTyp
Perm: int(dirInfo.Mode().Perm()), Perm: int(dirInfo.Mode().Perm()),
NotFound: true, NotFound: true,
} }
return
} }
func (m *MServer) streamFile(pk *packet.StreamFilePacketType) { func (m *MServer) streamFile(pk *packet.StreamFilePacketType) {
@ -608,12 +693,19 @@ func (m *MServer) ProcessRpcPacket(pk packet.RpcPacketType) {
return return
} }
if writePk, ok := pk.(*packet.WriteFilePacketType); ok { if writePk, ok := pk.(*packet.WriteFilePacketType); ok {
wfc := m.getWriteFileContext(writePk.ReqId) wfc := &WriteFileContext{
CVar: sync.NewCond(&sync.Mutex{}),
LastActive: time.Now(),
}
err := m.registerRpcHandler(writePk.ReqId, wfc)
if err != nil {
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error registering write-file handler: %w", err))
return
}
go m.writeFile(writePk, wfc) go m.writeFile(writePk, wfc)
return return
} }
m.Sender.SendErrorResponse(reqId, fmt.Errorf("invalid rpc type '%s'", pk.GetType())) m.Sender.SendErrorResponse(reqId, fmt.Errorf("invalid rpc type '%s'", pk.GetType()))
return
} }
func (m *MServer) clientPacketCallback(shellType string, pk packet.PacketType) { func (m *MServer) clientPacketCallback(shellType string, pk packet.PacketType) {
@ -726,7 +818,7 @@ func (server *MServer) runReadLoop() {
builder := packet.MakeRunPacketBuilder() builder := packet.MakeRunPacketBuilder()
for pk := range server.MainInput.MainCh { for pk := range server.MainInput.MainCh {
if server.Debug { if server.Debug {
fmt.Printf("PK> %s\n", packet.AsString(pk)) wlog.Logf("runReadLoop got packet %s\n", packet.AsString(pk))
} }
ok, runPacket := builder.ProcessPacket(pk) ok, runPacket := builder.ProcessPacket(pk)
if ok { if ok {
@ -736,16 +828,19 @@ func (server *MServer) runReadLoop() {
} }
continue continue
} }
if cmdPk, ok := pk.(packet.CommandPacketType); ok { if cmdPk, ok := pk.(packet.CommandPacketType); ok && cmdPk.GetCK() != "" {
server.ProcessCommandPacket(cmdPk) server.ProcessCommandPacket(cmdPk)
continue continue
} }
if rpcPk, ok := pk.(packet.RpcPacketType); ok { if rpcPk, ok := pk.(packet.RpcPacketType); ok && rpcPk.GetReqId() != "" {
server.ProcessRpcPacket(rpcPk) server.ProcessRpcPacket(rpcPk)
continue continue
} }
if fileDataPk, ok := pk.(*packet.FileDataPacketType); ok { if rpcFollowUp, ok := pk.(packet.RpcFollowUpPacketType); ok && rpcFollowUp.GetAssociatedReqId() != "" {
server.addFileDataPacket(fileDataPk) ok := server.dispatchRpcFollowUp(rpcFollowUp)
if !ok {
server.sendInboundRpcError(rpcFollowUp.GetAssociatedReqId(), fmt.Errorf("no handler for rpc follow-up packet"))
}
continue continue
} }
server.Sender.SendMessageFmt("invalid packet '%s' sent to mshell server", packet.AsString(pk)) server.Sender.SendMessageFmt("invalid packet '%s' sent to mshell server", packet.AsString(pk))
@ -765,7 +860,8 @@ func RunServer() (int, error) {
Debug: debug, Debug: debug,
WriteErrorCh: make(chan bool), WriteErrorCh: make(chan bool),
WriteErrorChOnce: &sync.Once{}, WriteErrorChOnce: &sync.Once{},
WriteFileContextMap: make(map[string]*WriteFileContext), InboundRpcHandlers: make(map[string]RpcHandler),
InboundRpcErrorSent: make(map[string]time.Time),
} }
if debug { if debug {
packet.GlobalDebug = true packet.GlobalDebug = true
@ -780,7 +876,7 @@ func RunServer() (int, error) {
return return
} }
time.Sleep(cleanLoopTime) time.Sleep(cleanLoopTime)
server.cleanWriteFileContexts() server.cleanRpcHandlers()
} }
}() }()
var err error var err error

View File

@ -80,8 +80,8 @@ func (b bashShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty
return MakeBashShExecCommand(cmdStr, rcFileName, usePty) return MakeBashShExecCommand(cmdStr, rcFileName, usePty)
} }
func (b bashShellApi) GetShellState(outCh chan ShellStateOutput) { func (b bashShellApi) GetShellState(outCh chan ShellStateOutput, stdinDataCh chan []byte) {
GetBashShellState(outCh) GetBashShellState(outCh, stdinDataCh)
} }
func (b bashShellApi) GetBaseShellOpts() string { func (b bashShellApi) GetBaseShellOpts() string {
@ -169,7 +169,7 @@ func GetLocalBashMajorVersion() string {
return localBashMajorVersion return localBashMajorVersion
} }
func GetBashShellState(outCh chan ShellStateOutput) { func GetBashShellState(outCh chan ShellStateOutput, stdinDataCh chan []byte) {
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout) ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
defer cancelFn() defer cancelFn()
defer close(outCh) defer close(outCh)
@ -185,7 +185,7 @@ func GetBashShellState(outCh chan ShellStateOutput) {
outCh <- ShellStateOutput{Output: outputBytes} outCh <- ShellStateOutput{Output: outputBytes}
} }
}() }()
outputBytes, err := StreamCommandWithExtraFd(ecmd, outputCh, StateOutputFdNum, endBytes) outputBytes, err := StreamCommandWithExtraFd(ctx, ecmd, outputCh, StateOutputFdNum, endBytes, stdinDataCh)
outputWg.Wait() outputWg.Wait()
if err != nil { if err != nil {
outCh <- ShellStateOutput{Error: err.Error()} outCh <- ShellStateOutput{Error: err.Error()}

View File

@ -25,9 +25,11 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/packet" "github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellutil" "github.com/wavetermdev/waveterm/waveshell/pkg/shellutil"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn" "github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
"github.com/wavetermdev/waveterm/waveshell/pkg/wlog"
) )
const GetStateTimeout = 15 * time.Second const GetStateTimeout = 10 * time.Second
const ReInitTimeout = GetStateTimeout + 2*time.Second
const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"` const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"`
const GetK8sContextCmdStr = `printf "K8SCONTEXT %s\x00" "$(kubectl config current-context 2>/dev/null)"` const GetK8sContextCmdStr = `printf "K8SCONTEXT %s\x00" "$(kubectl config current-context 2>/dev/null)"`
const GetK8sNamespaceCmdStr = `printf "K8SNAMESPACE %s\x00" "$(kubectl config view --minify --output 'jsonpath={..namespace}' 2>/dev/null)"` const GetK8sNamespaceCmdStr = `printf "K8SNAMESPACE %s\x00" "$(kubectl config view --minify --output 'jsonpath={..namespace}' 2>/dev/null)"`
@ -71,7 +73,7 @@ type ShellApi interface {
GetRemoteShellPath() string GetRemoteShellPath() string
MakeRunCommand(cmdStr string, opts RunCommandOpts) string MakeRunCommand(cmdStr string, opts RunCommandOpts) string
MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd
GetShellState(chan ShellStateOutput) GetShellState(outCh chan ShellStateOutput, stdinDataCh chan []byte)
GetBaseShellOpts() string GetBaseShellOpts() string
ParseShellStateOutput(output []byte) (*packet.ShellState, *packet.ShellStateStats, error) ParseShellStateOutput(output []byte) (*packet.ShellState, *packet.ShellStateStats, error)
MakeRcFileStr(pk *packet.RunPacketType) string MakeRcFileStr(pk *packet.RunPacketType) string
@ -154,7 +156,7 @@ func internalMacUserShell() string {
const FirstExtraFilesFdNum = 3 const FirstExtraFilesFdNum = 3
// returns output(stdout+stderr), extraFdOutput, error // returns output(stdout+stderr), extraFdOutput, error
func StreamCommandWithExtraFd(ecmd *exec.Cmd, outputCh chan []byte, extraFdNum int, endBytes []byte) ([]byte, error) { func StreamCommandWithExtraFd(ctx context.Context, ecmd *exec.Cmd, outputCh chan []byte, extraFdNum int, endBytes []byte, stdinDataCh chan []byte) ([]byte, error) {
defer close(outputCh) defer close(outputCh)
ecmd.Env = os.Environ() ecmd.Env = os.Environ()
shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType)) shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType))
@ -202,8 +204,27 @@ func StreamCommandWithExtraFd(ecmd *exec.Cmd, outputCh chan []byte, extraFdNum i
defer outputWg.Done() defer outputWg.Done()
utilfn.CopyWithEndBytes(&extraFdOutputBuf, pipeReader, endBytes) utilfn.CopyWithEndBytes(&extraFdOutputBuf, pipeReader, endBytes)
}() }()
if stdinDataCh != nil {
go func() {
// continue this loop even after an error to drain stdinDataCh
hadErr := false
for stdinData := range stdinDataCh {
if hadErr {
continue
}
_, err := cmdPty.Write(stdinData)
if err != nil {
wlog.Logf("error writing to shellstate cmdpty (stdin): %v\n", err)
hadErr = true
}
}
}()
}
exitErr := ecmd.Wait() exitErr := ecmd.Wait()
if exitErr != nil { if exitErr != nil {
if ctx.Err() != nil {
return nil, fmt.Errorf("%w (%w)", ctx.Err(), exitErr)
}
return nil, exitErr return nil, exitErr
} }
outputWg.Wait() outputWg.Wait()

View File

@ -246,7 +246,7 @@ func (z zshShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty
return exec.Command(GetLocalZshPath(), "-l", "-i", "-c", cmdStr) return exec.Command(GetLocalZshPath(), "-l", "-i", "-c", cmdStr)
} }
func (z zshShellApi) GetShellState(outCh chan ShellStateOutput) { func (z zshShellApi) GetShellState(outCh chan ShellStateOutput, stdinDataCh chan []byte) {
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout) ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
defer cancelFn() defer cancelFn()
defer close(outCh) defer close(outCh)
@ -262,7 +262,7 @@ func (z zshShellApi) GetShellState(outCh chan ShellStateOutput) {
outCh <- ShellStateOutput{Output: outputBytes} outCh <- ShellStateOutput{Output: outputBytes}
} }
}() }()
outputBytes, err := StreamCommandWithExtraFd(ecmd, outputCh, StateOutputFdNum, endBytes) outputBytes, err := StreamCommandWithExtraFd(ctx, ecmd, outputCh, StateOutputFdNum, endBytes, stdinDataCh)
outputWg.Wait() outputWg.Wait()
if err != nil { if err != nil {
outCh <- ShellStateOutput{Error: err.Error()} outCh <- ShellStateOutput{Error: err.Error()}
@ -726,7 +726,6 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta
// sections: see ZshSection_* consts // sections: see ZshSection_* consts
sections := bytes.Split(outputBytes, sectionSeparator) sections := bytes.Split(outputBytes, sectionSeparator)
if len(sections) != ZshSection_NumFieldsExpected { if len(sections) != ZshSection_NumFieldsExpected {
base.Logf("invalid -- numfields\n")
return nil, nil, fmt.Errorf("invalid zsh shell state output, wrong number of sections, section=%d", len(sections)) return nil, nil, fmt.Errorf("invalid zsh shell state output, wrong number of sections, section=%d", len(sections))
} }
rtn := &packet.ShellState{} rtn := &packet.ShellState{}

View File

@ -30,6 +30,7 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/base" "github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet" "github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/server" "github.com/wavetermdev/waveterm/waveshell/pkg/server"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellapi"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv" "github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellutil" "github.com/wavetermdev/waveterm/waveshell/pkg/shellutil"
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec" "github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
@ -1186,9 +1187,17 @@ func deferWriteCmdStatus(ctx context.Context, cmd *sstore.CmdType, startTime tim
err := sstore.UpdateCmdDoneInfo(context.Background(), update, ck, donePk, cmdStatus) err := sstore.UpdateCmdDoneInfo(context.Background(), update, ck, donePk, cmdStatus)
if err != nil { if err != nil {
// nothing to do // nothing to do
log.Printf("error updating cmddoneinfo (in openai): %v\n", err) log.Printf("error updating cmddoneinfo: %v\n", err)
return return
} }
screen, err := sstore.UpdateScreenFocusForDoneCmd(ctx, cmd.ScreenId, cmd.LineId)
if err != nil {
log.Printf("error trying to update screen focus type: %v\n", err)
// fall-through (nothing to do)
}
if screen != nil {
update.AddUpdate(*screen)
}
scbus.MainUpdateBus.DoScreenUpdate(cmd.ScreenId, update) scbus.MainUpdateBus.DoScreenUpdate(cmd.ScreenId, update)
} }
@ -3686,7 +3695,7 @@ func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (
if err != nil { if err != nil {
return nil, err return nil, err
} }
update, err := addLineForCmd(ctx, "/reset", false, ids, cmd, "", nil) update, err := addLineForCmd(ctx, "/reset", true, ids, cmd, "", nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -3695,7 +3704,7 @@ func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (
} }
func doResetCommand(ids resolvedIds, shellType string, cmd *sstore.CmdType, verbose bool) { func doResetCommand(ids resolvedIds, shellType string, cmd *sstore.CmdType, verbose bool) {
ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancelFn := context.WithTimeout(context.Background(), shellapi.ReInitTimeout)
defer cancelFn() defer cancelFn()
startTime := time.Now() startTime := time.Now()
var outputPos int64 var outputPos int64
@ -3712,7 +3721,7 @@ func doResetCommand(ids resolvedIds, shellType string, cmd *sstore.CmdType, verb
writeStringToPty(ctx, cmd, string(data), &outputPos) writeStringToPty(ctx, cmd, string(data), &outputPos)
} }
origStatePtr := ids.Remote.MShell.GetDefaultStatePtr(shellType) origStatePtr := ids.Remote.MShell.GetDefaultStatePtr(shellType)
ssPk, err := ids.Remote.MShell.ReInit(ctx, shellType, dataFn, verbose) ssPk, err := ids.Remote.MShell.ReInit(ctx, base.MakeCommandKey(cmd.ScreenId, cmd.LineId), shellType, dataFn, verbose)
if err != nil { if err != nil {
rtnErr = err rtnErr = err
return return
@ -3975,14 +3984,14 @@ func splitLinesForInfo(str string) []string {
} }
func resizeRunningCommand(ctx context.Context, cmd *sstore.CmdType, newCols int) error { func resizeRunningCommand(ctx context.Context, cmd *sstore.CmdType, newCols int) error {
siPk := packet.MakeSpecialInputPacket() feInput := scpacket.MakeFeInputPacket()
siPk.CK = base.MakeCommandKey(cmd.ScreenId, cmd.LineId) feInput.CK = base.MakeCommandKey(cmd.ScreenId, cmd.LineId)
siPk.WinSize = &packet.WinSize{Rows: int(cmd.TermOpts.Rows), Cols: newCols} feInput.WinSize = &packet.WinSize{Rows: int(cmd.TermOpts.Rows), Cols: newCols}
msh := remote.GetRemoteById(cmd.Remote.RemoteId) msh := remote.GetRemoteById(cmd.Remote.RemoteId)
if msh == nil { if msh == nil {
return fmt.Errorf("cannot resize, cmd remote not found") return fmt.Errorf("cannot resize, cmd remote not found")
} }
err := msh.SendSpecialInput(siPk) err := msh.HandleFeInput(feInput)
if err != nil { if err != nil {
return err return err
} }
@ -5178,10 +5187,10 @@ func SignalCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbus
if !msh.IsConnected() { if !msh.IsConnected() {
return nil, fmt.Errorf("cannot send signal, remote is not connected") return nil, fmt.Errorf("cannot send signal, remote is not connected")
} }
siPk := packet.MakeSpecialInputPacket() inputPk := scpacket.MakeFeInputPacket()
siPk.CK = base.MakeCommandKey(cmd.ScreenId, cmd.LineId) inputPk.CK = base.MakeCommandKey(cmd.ScreenId, cmd.LineId)
siPk.SigName = sigArg inputPk.SigName = sigArg
err = msh.SendSpecialInput(siPk) err = msh.HandleFeInput(inputPk)
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot send signal: %v", err) return nil, fmt.Errorf("cannot send signal: %v", err)
} }

View File

@ -215,6 +215,7 @@ var literalRtnStateCommands = []string{
"disable", "disable",
"function", "function",
"zmodload", "zmodload",
"module",
} }
func getCallExprLitArg(callExpr *syntax.CallExpr, argNum int) string { func getCallExprLitArg(callExpr *syntax.CallExpr, argNum int) string {

View File

@ -54,6 +54,7 @@ const RemoteTermCols = 80
const PtyReadBufSize = 100 const PtyReadBufSize = 100
const RemoteConnectTimeout = 15 * time.Second const RemoteConnectTimeout = 15 * time.Second
const RpcIterChannelSize = 100 const RpcIterChannelSize = 100
const MaxInputDataSize = 1000
var envVarsToStrip map[string]bool = map[string]bool{ var envVarsToStrip map[string]bool = map[string]bool{
"PROMPT": true, "PROMPT": true,
@ -128,6 +129,7 @@ type pendingStateKey struct {
RemotePtr sstore.RemotePtrType RemotePtr sstore.RemotePtrType
} }
// provides state, acccess, and control for a waveshell server process
type MShellProc struct { type MShellProc struct {
Lock *sync.Mutex Lock *sync.Mutex
Remote *sstore.RemoteType Remote *sstore.RemoteType
@ -135,7 +137,7 @@ type MShellProc struct {
// runtime // runtime
RemoteId string // can be read without a lock RemoteId string // can be read without a lock
Status string Status string
ServerProc *shexec.ClientProc ServerProc *shexec.ClientProc // the server process
UName string UName string
Err error Err error
ErrNoInitPk bool ErrNoInitPk bool
@ -154,9 +156,18 @@ type MShellProc struct {
InstallCancelFn context.CancelFunc InstallCancelFn context.CancelFunc
InstallErr error InstallErr error
// for synthetic commands (not run through RunCommand), this provides a way for them
// to register to receive input events from the frontend (e.g. ReInit)
CommandInputMap map[base.CommandKey]CommandInputSink
RunningCmds map[base.CommandKey]*RunCmdType RunningCmds map[base.CommandKey]*RunCmdType
PendingStateCmds map[pendingStateKey]base.CommandKey // key=[remoteinstance name] PendingStateCmds map[pendingStateKey]base.CommandKey // key=[remoteinstance name] (in progress commands that might update the state)
Client *ssh.Client
Client *ssh.Client
}
type CommandInputSink interface {
HandleInput(feInput *scpacket.FeInputPacketType) error
} }
type RunCmdType struct { type RunCmdType struct {
@ -169,6 +180,22 @@ type RunCmdType struct {
EphCancled atomic.Bool // only for Ephemeral commands, if true, then the command result should be discarded EphCancled atomic.Bool // only for Ephemeral commands, if true, then the command result should be discarded
} }
type ReinitCommandSink struct {
Remote *MShellProc
ReqId string
}
func (rcs *ReinitCommandSink) HandleInput(feInput *scpacket.FeInputPacketType) error {
realData, err := base64.StdEncoding.DecodeString(feInput.InputData64)
if err != nil {
return fmt.Errorf("error decoding input data: %v", err)
}
inputPk := packet.MakeRpcInputPacket(rcs.ReqId)
inputPk.Data = realData
rcs.Remote.ServerProc.Input.SendPacket(inputPk)
return nil
}
type RemoteRuntimeState = sstore.RemoteRuntimeState type RemoteRuntimeState = sstore.RemoteRuntimeState
func CanComplete(remoteType string) bool { func CanComplete(remoteType string) bool {
@ -196,7 +223,7 @@ func (msh *MShellProc) EnsureShellType(ctx context.Context, shellType string) er
return nil return nil
} }
// try to reinit the shell // try to reinit the shell
_, err := msh.ReInit(ctx, shellType, nil, false) _, err := msh.ReInit(ctx, base.CommandKey(""), shellType, nil, false)
if err != nil { if err != nil {
return fmt.Errorf("error trying to initialize shell %q: %v", shellType, err) return fmt.Errorf("error trying to initialize shell %q: %v", shellType, err)
} }
@ -694,6 +721,7 @@ func MakeMShell(r *sstore.RemoteType) *MShellProc {
Status: StatusDisconnected, Status: StatusDisconnected,
PtyBuffer: buf, PtyBuffer: buf,
InstallStatus: StatusDisconnected, InstallStatus: StatusDisconnected,
CommandInputMap: make(map[base.CommandKey]CommandInputSink),
RunningCmds: make(map[base.CommandKey]*RunCmdType), RunningCmds: make(map[base.CommandKey]*RunCmdType),
PendingStateCmds: make(map[pendingStateKey]base.CommandKey), PendingStateCmds: make(map[pendingStateKey]base.CommandKey),
StateMap: server.MakeShellStateMap(), StateMap: server.MakeShellStateMap(),
@ -1401,7 +1429,7 @@ func makeReinitErrorUpdate(shellType string) sstore.ActivityUpdate {
return rtn return rtn
} }
func (msh *MShellProc) ReInit(ctx context.Context, shellType string, dataFn func([]byte), verbose bool) (rtnPk *packet.ShellStatePacketType, rtnErr error) { func (msh *MShellProc) ReInit(ctx context.Context, ck base.CommandKey, shellType string, dataFn func([]byte), verbose bool) (rtnPk *packet.ShellStatePacketType, rtnErr error) {
if !msh.IsConnected() { if !msh.IsConnected() {
return nil, fmt.Errorf("cannot reinit, remote is not connected") return nil, fmt.Errorf("cannot reinit, remote is not connected")
} }
@ -1425,6 +1453,14 @@ func (msh *MShellProc) ReInit(ctx context.Context, shellType string, dataFn func
return nil, err return nil, err
} }
defer rpcIter.Close() defer rpcIter.Close()
if ck != "" {
reinitSink := &ReinitCommandSink{
Remote: msh,
ReqId: reinitPk.ReqId,
}
msh.registerInputSink(ck, reinitSink)
defer msh.unregisterInputSink(ck)
}
var ssPk *packet.ShellStatePacketType var ssPk *packet.ShellStatePacketType
for { for {
resp, err := rpcIter.Next(ctx) resp, err := rpcIter.Next(ctx)
@ -1506,6 +1542,9 @@ func addScVarsToState(state *packet.ShellState) *packet.ShellState {
envMap["WAVETERM_VERSION"] = &shellenv.DeclareDeclType{Name: "WAVETERM_VERSION", Value: scbase.WaveVersion, Args: "x"} envMap["WAVETERM_VERSION"] = &shellenv.DeclareDeclType{Name: "WAVETERM_VERSION", Value: scbase.WaveVersion, Args: "x"}
envMap["TERM_PROGRAM"] = &shellenv.DeclareDeclType{Name: "TERM_PROGRAM", Value: "waveterm", Args: "x"} envMap["TERM_PROGRAM"] = &shellenv.DeclareDeclType{Name: "TERM_PROGRAM", Value: "waveterm", Args: "x"}
envMap["TERM_PROGRAM_VERSION"] = &shellenv.DeclareDeclType{Name: "TERM_PROGRAM_VERSION", Value: scbase.WaveVersion, Args: "x"} envMap["TERM_PROGRAM_VERSION"] = &shellenv.DeclareDeclType{Name: "TERM_PROGRAM_VERSION", Value: scbase.WaveVersion, Args: "x"}
if scbase.IsDevMode() {
envMap["WAVETERM_DEV"] = &shellenv.DeclareDeclType{Name: "WAVETERM_DEV", Value: "1", Args: "x"}
}
if _, exists := envMap["LANG"]; !exists { if _, exists := envMap["LANG"]; !exists {
envMap["LANG"] = &shellenv.DeclareDeclType{Name: "LANG", Value: scbase.DetermineLang(), Args: "x"} envMap["LANG"] = &shellenv.DeclareDeclType{Name: "LANG", Value: scbase.DetermineLang(), Args: "x"}
} }
@ -1727,20 +1766,28 @@ func (msh *MShellProc) Launch(interactive bool) {
} }
func (msh *MShellProc) initActiveShells() { func (msh *MShellProc) initActiveShells() {
ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) gasCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFn() defer cancelFn()
activeShells, err := msh.getActiveShellTypes(ctx) activeShells, err := msh.getActiveShellTypes(gasCtx)
if err != nil { if err != nil {
// we're not going to fail the connect for this error (it will be unusable, but technically connected) // we're not going to fail the connect for this error (it will be unusable, but technically connected)
msh.WriteToPtyBuffer("*error getting active shells: %v\n", err) msh.WriteToPtyBuffer("*error getting active shells: %v\n", err)
return return
} }
for _, shellType := range activeShells { var wg sync.WaitGroup
_, err = msh.ReInit(ctx, shellType, nil, false) for _, shellTypeForVar := range activeShells {
if err != nil { wg.Add(1)
msh.WriteToPtyBuffer("*error reiniting shell %q: %v\n", shellType, err) go func(shellType string) {
} defer wg.Done()
reinitCtx, cancelFn := context.WithTimeout(context.Background(), shellapi.ReInitTimeout)
defer cancelFn()
_, err = msh.ReInit(reinitCtx, base.CommandKey(""), shellType, nil, false)
if err != nil {
msh.WriteToPtyBuffer("*error reiniting shell %q: %v\n", shellType, err)
}
}(shellTypeForVar)
} }
wg.Wait()
} }
func (msh *MShellProc) IsConnected() bool { func (msh *MShellProc) IsConnected() bool {
@ -1775,24 +1822,14 @@ func (msh *MShellProc) IsCmdRunning(ck base.CommandKey) bool {
return ok return ok
} }
func (msh *MShellProc) SendInput(dataPk *packet.DataPacketType) error {
if !msh.IsConnected() {
return fmt.Errorf("remote is not connected, cannot send input")
}
if !msh.IsCmdRunning(dataPk.CK) {
return fmt.Errorf("cannot send input, cmd is not running")
}
return msh.ServerProc.Input.SendPacket(dataPk)
}
func (msh *MShellProc) KillRunningCommandAndWait(ctx context.Context, ck base.CommandKey) error { func (msh *MShellProc) KillRunningCommandAndWait(ctx context.Context, ck base.CommandKey) error {
if !msh.IsCmdRunning(ck) { if !msh.IsCmdRunning(ck) {
return nil return nil
} }
siPk := packet.MakeSpecialInputPacket() feiPk := scpacket.MakeFeInputPacket()
siPk.CK = ck feiPk.CK = ck
siPk.SigName = "SIGTERM" feiPk.SigName = "SIGTERM"
err := msh.SendSpecialInput(siPk) err := msh.HandleFeInput(feiPk)
if err != nil { if err != nil {
return fmt.Errorf("error trying to kill running cmd: %w", err) return fmt.Errorf("error trying to kill running cmd: %w", err)
} }
@ -1809,16 +1846,6 @@ func (msh *MShellProc) KillRunningCommandAndWait(ctx context.Context, ck base.Co
} }
} }
func (msh *MShellProc) SendSpecialInput(siPk *packet.SpecialInputPacketType) error {
if !msh.IsConnected() {
return fmt.Errorf("remote is not connected, cannot send input")
}
if !msh.IsCmdRunning(siPk.CK) {
return fmt.Errorf("cannot send input, cmd is not running")
}
return msh.ServerProc.Input.SendPacket(siPk)
}
func (msh *MShellProc) SendFileData(dataPk *packet.FileDataPacketType) error { func (msh *MShellProc) SendFileData(dataPk *packet.FileDataPacketType) error {
if !msh.IsConnected() { if !msh.IsConnected() {
return fmt.Errorf("remote is not connected, cannot send input") return fmt.Errorf("remote is not connected, cannot send input")
@ -2065,6 +2092,62 @@ func makePSCLineError(existingPSC base.CommandKey, line *sstore.LineType, lineEr
return fmt.Errorf("cannot run command while a stateful command (linenum=%d) is still running", line.LineNum) return fmt.Errorf("cannot run command while a stateful command (linenum=%d) is still running", line.LineNum)
} }
func (msh *MShellProc) registerInputSink(ck base.CommandKey, sink CommandInputSink) {
msh.Lock.Lock()
defer msh.Lock.Unlock()
msh.CommandInputMap[ck] = sink
}
func (msh *MShellProc) unregisterInputSink(ck base.CommandKey) {
msh.Lock.Lock()
defer msh.Lock.Unlock()
delete(msh.CommandInputMap, ck)
}
func (msh *MShellProc) HandleFeInput(inputPk *scpacket.FeInputPacketType) error {
if inputPk == nil {
return nil
}
if !msh.IsConnected() {
return fmt.Errorf("connection is not connected, cannot send input")
}
if msh.IsCmdRunning(inputPk.CK) {
if len(inputPk.InputData64) > 0 {
inputLen := packet.B64DecodedLen(inputPk.InputData64)
if inputLen > MaxInputDataSize {
return fmt.Errorf("input data size too large, len=%d (max=%d)", inputLen, MaxInputDataSize)
}
dataPk := packet.MakeDataPacket()
dataPk.CK = inputPk.CK
dataPk.FdNum = 0 // stdin
dataPk.Data64 = inputPk.InputData64
err := msh.ServerProc.Input.SendPacket(dataPk)
if err != nil {
return err
}
}
if inputPk.SigName != "" || inputPk.WinSize != nil {
siPk := packet.MakeSpecialInputPacket()
siPk.CK = inputPk.CK
siPk.SigName = inputPk.SigName
siPk.WinSize = inputPk.WinSize
err := msh.ServerProc.Input.SendPacket(siPk)
if err != nil {
return err
}
}
return nil
}
msh.Lock.Lock()
sink := msh.CommandInputMap[inputPk.CK]
msh.Lock.Unlock()
if sink == nil {
// no sink and no running command
return fmt.Errorf("cannot send input, cmd is not running")
}
return sink.HandleInput(inputPk)
}
func (msh *MShellProc) AddRunningCmd(rct *RunCmdType) { func (msh *MShellProc) AddRunningCmd(rct *RunCmdType) {
msh.Lock.Lock() msh.Lock.Lock()
defer msh.Lock.Unlock() defer msh.Lock.Unlock()

View File

@ -115,6 +115,7 @@ var IgnoreVars = map[string]bool{
"WAVESHELL_VERSION": true, "WAVESHELL_VERSION": true,
"WAVETERM": true, "WAVETERM": true,
"WAVETERM_VERSION": true, "WAVETERM_VERSION": true,
"WAVETERM_DEV": true,
"TERM_PROGRAM": true, "TERM_PROGRAM": true,
"TERM_PROGRAM_VERSION": true, "TERM_PROGRAM_VERSION": true,
"TERM_SESSION_ID": true, "TERM_SESSION_ID": true,

View File

@ -4,6 +4,7 @@
package scpacket package scpacket
import ( import (
"encoding/base64"
"fmt" "fmt"
"reflect" "reflect"
"regexp" "regexp"
@ -191,6 +192,14 @@ func MakeFeInputPacket() *FeInputPacketType {
return &FeInputPacketType{Type: FeInputPacketStr} return &FeInputPacketType{Type: FeInputPacketStr}
} }
func (pk *FeInputPacketType) DecodeData() ([]byte, error) {
return base64.StdEncoding.DecodeString(pk.InputData64)
}
func (pk *FeInputPacketType) SetData(data []byte) {
pk.InputData64 = base64.StdEncoding.EncodeToString(data)
}
func (*WatchScreenPacketType) GetType() string { func (*WatchScreenPacketType) GetType() string {
return WatchScreenPacketStr return WatchScreenPacketStr
} }

View File

@ -23,7 +23,6 @@ import (
) )
const WSStatePacketChSize = 20 const WSStatePacketChSize = 20
const MaxInputDataSize = 1000
const RemoteInputQueueSize = 100 const RemoteInputQueueSize = 100
var RemoteInputMapQueue *mapqueue.MapQueue var RemoteInputMapQueue *mapqueue.MapQueue
@ -247,7 +246,7 @@ func (ws *WSState) processMessage(msgBytes []byte) error {
err := RemoteInputMapQueue.Enqueue(feInputPk.Remote.RemoteId, func() { err := RemoteInputMapQueue.Enqueue(feInputPk.Remote.RemoteId, func() {
sendErr := sendCmdInput(feInputPk) sendErr := sendCmdInput(feInputPk)
if sendErr != nil { if sendErr != nil {
log.Printf("[scws] sending command input: %v\n", err) log.Printf("[scws] sending command input: %v\n", sendErr)
} }
}) })
if err != nil { if err != nil {
@ -263,7 +262,7 @@ func (ws *WSState) processMessage(msgBytes []byte) error {
go func() { go func() {
sendErr := remote.SendRemoteInput(inputPk) sendErr := remote.SendRemoteInput(inputPk)
if sendErr != nil { if sendErr != nil {
log.Printf("[scws] error processing remote input: %v\n", err) log.Printf("[scws] error processing remote input: %v\n", sendErr)
} }
}() }()
return nil return nil
@ -319,29 +318,5 @@ func sendCmdInput(pk *scpacket.FeInputPacketType) error {
if msh == nil { if msh == nil {
return fmt.Errorf("remote %s not found", pk.Remote.RemoteId) return fmt.Errorf("remote %s not found", pk.Remote.RemoteId)
} }
if len(pk.InputData64) > 0 { return msh.HandleFeInput(pk)
inputLen := packet.B64DecodedLen(pk.InputData64)
if inputLen > MaxInputDataSize {
return fmt.Errorf("input data size too large, len=%d (max=%d)", inputLen, MaxInputDataSize)
}
dataPk := packet.MakeDataPacket()
dataPk.CK = pk.CK
dataPk.FdNum = 0 // stdin
dataPk.Data64 = pk.InputData64
err = msh.SendInput(dataPk)
if err != nil {
return err
}
}
if pk.SigName != "" || pk.WinSize != nil {
siPk := packet.MakeSpecialInputPacket()
siPk.CK = pk.CK
siPk.SigName = pk.SigName
siPk.WinSize = pk.WinSize
err = msh.SendSpecialInput(siPk)
if err != nil {
return err
}
}
return nil
} }