mirror of
https://github.com/wavetermdev/waveterm.git
synced 2024-12-21 16:38:23 +01:00
more zsh reinitialization fixes (allow user input during initialization process) (#480)
* fix error logs in scws * new RpcFollowUpPacketType * make the rpc/followup handlers generic on the server side -- using new RpcHandlers map and RpcFollowUpPacketType * rpcinputpacket for passing user input back through to reinit command * add WAVETERM_DEV env var in dev mode * remove unused code, ensure mshell and rcfile directory on startup (prevent root clobber with sudo) * combine all feinput into one function msh.HandleFeInput, and add a new concept of input sinks for special cases (like reinit) * allow reset to accept user input (to get around interactive initialization problems) * tone down the selection background highlight color on dark mode. easier to read selected text * fix command focus and done focus issues with dynamic (non-run) commands * add 'module' as a 'rtnstate' command (#478) * reinitialize shells in parallel, fix timeouts, better error messages
This commit is contained in:
parent
fb59e094e4
commit
0781e6e821
@ -27,6 +27,6 @@
|
||||
--term-cmdtext: #ffffff;
|
||||
--term-foreground: #d3d7cf;
|
||||
--term-background: #000000;
|
||||
--term-selection-background: #ffffff90;
|
||||
--term-selection-background: #ffffff60;
|
||||
--term-cursor-accent: #000000;
|
||||
}
|
||||
|
@ -135,6 +135,7 @@ func main() {
|
||||
base.ProcessType = base.ProcessType_WaveShellServer
|
||||
wlog.GlobalSubsystem = base.ProcessType_WaveShellServer
|
||||
base.InitDebugLog("server")
|
||||
base.EnsureRcFilesDir()
|
||||
rtnCode, err := server.RunServer()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[error] %v\n", err)
|
||||
|
@ -12,7 +12,6 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@ -72,11 +71,11 @@ func IsWaveSrv() bool {
|
||||
return ProcessType == ProcessType_WaveSrv
|
||||
}
|
||||
|
||||
func MakeCommandKey(sessionId string, cmdId string) CommandKey {
|
||||
if sessionId == "" && cmdId == "" {
|
||||
func MakeCommandKey(screenId string, lineId string) CommandKey {
|
||||
if screenId == "" && lineId == "" {
|
||||
return CommandKey("")
|
||||
}
|
||||
return CommandKey(fmt.Sprintf("%s/%s", sessionId, cmdId))
|
||||
return CommandKey(fmt.Sprintf("%s/%s", screenId, lineId))
|
||||
}
|
||||
|
||||
func (ckey CommandKey) IsEmpty() bool {
|
||||
@ -200,51 +199,6 @@ func GetMShellHomeDir() string {
|
||||
return ExpandHomeDir(DefaultMShellHome)
|
||||
}
|
||||
|
||||
func GetCommandFileNames(ck CommandKey) (*CommandFileNames, error) {
|
||||
if err := ck.Validate("ck"); err != nil {
|
||||
return nil, fmt.Errorf("cannot get command files: %w", err)
|
||||
}
|
||||
sessionId, cmdId := ck.Split()
|
||||
sdir, err := EnsureSessionDir(sessionId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
base := path.Join(sdir, cmdId)
|
||||
return &CommandFileNames{
|
||||
PtyOutFile: base + ".ptyout",
|
||||
StdinFifo: base + ".stdin",
|
||||
RunnerOutFile: base + ".runout",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func CleanUpCmdFiles(sessionId string, cmdId string) error {
|
||||
if cmdId == "" {
|
||||
return fmt.Errorf("bad cmdid, cannot clean up")
|
||||
}
|
||||
sdir, err := EnsureSessionDir(sessionId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmdFileGlob := path.Join(sdir, cmdId+".*")
|
||||
matches, err := filepath.Glob(cmdFileGlob)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, file := range matches {
|
||||
rmErr := os.Remove(file)
|
||||
if err == nil && rmErr != nil {
|
||||
err = rmErr
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func GetSessionsDir() string {
|
||||
mhome := GetMShellHomeDir()
|
||||
sdir := path.Join(mhome, SessionsDirBaseName)
|
||||
return sdir
|
||||
}
|
||||
|
||||
func EnsureRcFilesDir() (string, error) {
|
||||
mhome := GetMShellHomeDir()
|
||||
dirName := path.Join(mhome, RcFilesDirBaseName)
|
||||
@ -255,19 +209,6 @@ func EnsureRcFilesDir() (string, error) {
|
||||
return dirName, nil
|
||||
}
|
||||
|
||||
func EnsureSessionDir(sessionId string) (string, error) {
|
||||
if sessionId == "" {
|
||||
return "", fmt.Errorf("Bad sessionid, cannot be empty")
|
||||
}
|
||||
mhome := GetMShellHomeDir()
|
||||
sdir := path.Join(mhome, SessionsDirBaseName, sessionId)
|
||||
err := CacheEnsureDir(sdir, sessionId, 0777, "mshell session dir")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return sdir, nil
|
||||
}
|
||||
|
||||
func GetMShellPath() (string, error) {
|
||||
msPath := os.Getenv(MShellPathVarName) // use MSHELL_PATH
|
||||
if msPath != "" {
|
||||
@ -282,11 +223,6 @@ func GetMShellPath() (string, error) {
|
||||
return exec.LookPath(DefaultMShellName) // standard path lookup for 'mshell'
|
||||
}
|
||||
|
||||
func GetMShellSessionsDir() (string, error) {
|
||||
mhome := GetMShellHomeDir()
|
||||
return path.Join(mhome, SessionsDirBaseName), nil
|
||||
}
|
||||
|
||||
func ExpandHomeDir(pathStr string) string {
|
||||
if pathStr != "~" && !strings.HasPrefix(pathStr, "~/") {
|
||||
return pathStr
|
||||
@ -315,22 +251,6 @@ func GoArchOptFile(version string, goos string, goarch string) string {
|
||||
return fmt.Sprintf(path.Join(installBinDir, binBaseName))
|
||||
}
|
||||
|
||||
func MShellBinaryFromOptDir(version string, goos string, goarch string) (io.ReadCloser, error) {
|
||||
if !ValidGoArch(goos, goarch) {
|
||||
return nil, fmt.Errorf("invalid goos/goarch combination: %s/%s", goos, goarch)
|
||||
}
|
||||
versionStr := semver.MajorMinor(version)
|
||||
if versionStr == "" {
|
||||
return nil, fmt.Errorf("invalid mshell version: %q", version)
|
||||
}
|
||||
fileName := GoArchOptFile(version, goos, goarch)
|
||||
fd, err := os.Open(fileName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot open mshell binary %q: %v", fileName, err)
|
||||
}
|
||||
return fd, nil
|
||||
}
|
||||
|
||||
func GetRemoteId() (string, error) {
|
||||
mhome := GetMShellHomeDir()
|
||||
homeInfo, err := os.Stat(mhome)
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
@ -63,6 +64,7 @@ const (
|
||||
FileStatPacketStr = "filestat"
|
||||
LogPacketStr = "log" // logging packet (sent from waveshell back to server)
|
||||
ShellStatePacketStr = "shellstate"
|
||||
RpcInputPacketStr = "rpcinput" // rpc-followup
|
||||
|
||||
OpenAIPacketStr = "openai" // other
|
||||
OpenAICloudReqStr = "openai-cloudreq"
|
||||
@ -116,6 +118,7 @@ func init() {
|
||||
TypeStrToFactory[LogPacketStr] = reflect.TypeOf(LogPacketType{})
|
||||
TypeStrToFactory[ShellStatePacketStr] = reflect.TypeOf(ShellStatePacketType{})
|
||||
TypeStrToFactory[FileStatPacketStr] = reflect.TypeOf(FileStatPacketType{})
|
||||
TypeStrToFactory[RpcInputPacketStr] = reflect.TypeOf(RpcInputPacketType{})
|
||||
|
||||
var _ RpcPacketType = (*RunPacketType)(nil)
|
||||
var _ RpcPacketType = (*GetCmdPacketType)(nil)
|
||||
@ -134,6 +137,9 @@ func init() {
|
||||
var _ RpcResponsePacketType = (*WriteFileDonePacketType)(nil)
|
||||
var _ RpcResponsePacketType = (*ShellStatePacketType)(nil)
|
||||
|
||||
var _ RpcFollowUpPacketType = (*FileDataPacketType)(nil)
|
||||
var _ RpcFollowUpPacketType = (*RpcInputPacketType)(nil)
|
||||
|
||||
var _ CommandPacketType = (*DataPacketType)(nil)
|
||||
var _ CommandPacketType = (*DataAckPacketType)(nil)
|
||||
var _ CommandPacketType = (*CmdDonePacketType)(nil)
|
||||
@ -166,6 +172,26 @@ func MakePingPacket() *PingPacketType {
|
||||
return &PingPacketType{Type: PingPacketStr}
|
||||
}
|
||||
|
||||
type RpcInputPacketType struct {
|
||||
Type string `json:"type"`
|
||||
ReqId string `json:"reqid"`
|
||||
Data []byte `json:"data"`
|
||||
}
|
||||
|
||||
func (*RpcInputPacketType) GetType() string {
|
||||
return RpcInputPacketStr
|
||||
}
|
||||
|
||||
func (p *RpcInputPacketType) GetAssociatedReqId() string {
|
||||
return p.ReqId
|
||||
}
|
||||
|
||||
func MakeRpcInputPacket(reqId string) *RpcInputPacketType {
|
||||
return &RpcInputPacketType{Type: RpcInputPacketStr, ReqId: reqId}
|
||||
}
|
||||
|
||||
// these packets can travel either direction
|
||||
// so it is both a RpcResponsePacketType and an RpcFollowUpPacketType
|
||||
type FileDataPacketType struct {
|
||||
Type string `json:"type"`
|
||||
RespId string `json:"respid"`
|
||||
@ -185,6 +211,10 @@ func MakeFileDataPacket(reqId string) *FileDataPacketType {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *FileDataPacketType) GetAssociatedReqId() string {
|
||||
return p.RespId
|
||||
}
|
||||
|
||||
func (p *FileDataPacketType) GetResponseId() string {
|
||||
return p.RespId
|
||||
}
|
||||
@ -976,6 +1006,12 @@ type CommandPacketType interface {
|
||||
GetCK() base.CommandKey
|
||||
}
|
||||
|
||||
// RpcPackets initiate an Rpc. these can be part of the data passed back and forth
|
||||
type RpcFollowUpPacketType interface {
|
||||
GetType() string
|
||||
GetAssociatedReqId() string
|
||||
}
|
||||
|
||||
type ModelUpdatePacketType struct {
|
||||
Type string `json:"type"`
|
||||
Updates []any `json:"updates"`
|
||||
@ -1178,6 +1214,14 @@ func (sender *PacketSender) SendPacketCtx(ctx context.Context, pk PacketType) er
|
||||
}
|
||||
|
||||
func (sender *PacketSender) SendPacket(pk PacketType) error {
|
||||
if pk == nil {
|
||||
log.Printf("tried to send nil packet\n")
|
||||
return fmt.Errorf("tried to send nil packet")
|
||||
}
|
||||
if pk.GetType() == "" {
|
||||
log.Printf("tried to send invalid packet: %T\n", pk)
|
||||
return fmt.Errorf("tried to send packet without a type: %T", pk)
|
||||
}
|
||||
err := sender.checkStatus()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -30,6 +30,7 @@ const MaxFileDataPacketSize = 16 * 1024
|
||||
const WriteFileContextTimeout = 30 * time.Second
|
||||
const cleanLoopTime = 5 * time.Second
|
||||
const MaxWriteFileContextData = 100
|
||||
const InboundRpcErrorTimeoutTime = 30 * time.Second
|
||||
|
||||
type shellStateMapKey struct {
|
||||
ShellType string
|
||||
@ -52,8 +53,89 @@ type MServer struct {
|
||||
StateMap *ShellStateMap
|
||||
WriteErrorCh chan bool // closed if there is a I/O write error
|
||||
WriteErrorChOnce *sync.Once
|
||||
WriteFileContextMap map[string]*WriteFileContext
|
||||
Done bool
|
||||
InboundRpcHandlers map[string]RpcHandler
|
||||
InboundRpcErrorSent map[string]time.Time // limits the amount of error messages sent back to the client
|
||||
}
|
||||
|
||||
var _ RpcHandler = (*WriteFileContext)(nil)
|
||||
|
||||
type RpcHandler interface {
|
||||
GetTimeoutTime() time.Time
|
||||
DispatchPacket(reqId string, pk packet.RpcFollowUpPacketType)
|
||||
UnRegisterCallback()
|
||||
}
|
||||
|
||||
func (m *MServer) registerRpcHandler(reqId string, handler RpcHandler) error {
|
||||
if handler == nil {
|
||||
return errors.New("handler is nil")
|
||||
}
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
if m.InboundRpcHandlers[reqId] != nil {
|
||||
return errors.New("handler already registered")
|
||||
}
|
||||
delete(m.InboundRpcErrorSent, reqId)
|
||||
m.InboundRpcHandlers[reqId] = handler
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MServer) unregisterRpcHandler(reqId string) {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
handler := m.InboundRpcHandlers[reqId]
|
||||
delete(m.InboundRpcHandlers, reqId)
|
||||
if handler != nil {
|
||||
handler.UnRegisterCallback()
|
||||
}
|
||||
}
|
||||
|
||||
// limits the number of error responses that can be sent
|
||||
func (m *MServer) sendInboundRpcError(reqId string, err error) {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
if _, ok := m.InboundRpcErrorSent[reqId]; ok {
|
||||
return
|
||||
}
|
||||
m.InboundRpcErrorSent[reqId] = time.Now().Add(InboundRpcErrorTimeoutTime)
|
||||
m.Sender.SendErrorResponse(reqId, err)
|
||||
}
|
||||
|
||||
// returns true if dispatched to a waiting RPC handler
|
||||
func (m *MServer) dispatchRpcFollowUp(pk packet.RpcFollowUpPacketType) bool {
|
||||
if pk == nil {
|
||||
return true
|
||||
}
|
||||
reqId := pk.GetAssociatedReqId()
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
if rh := m.InboundRpcHandlers[reqId]; rh != nil {
|
||||
rh.DispatchPacket(reqId, pk)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *MServer) cleanRpcHandlers() {
|
||||
var staleHandlers []RpcHandler
|
||||
now := time.Now()
|
||||
m.Lock.Lock()
|
||||
for reqId, rh := range m.InboundRpcHandlers {
|
||||
if now.After(rh.GetTimeoutTime()) {
|
||||
staleHandlers = append(staleHandlers, rh)
|
||||
delete(m.InboundRpcHandlers, reqId)
|
||||
}
|
||||
}
|
||||
for reqId, timeoutTime := range m.InboundRpcErrorSent {
|
||||
if now.After(timeoutTime) {
|
||||
delete(m.InboundRpcErrorSent, reqId)
|
||||
}
|
||||
}
|
||||
m.Lock.Unlock()
|
||||
// we do this outside of m.Lock just in case there is some lock contention (UnRegisterCallback might be slow)
|
||||
for _, rh := range staleHandlers {
|
||||
rh.UnRegisterCallback()
|
||||
}
|
||||
}
|
||||
|
||||
type WriteFileContext struct {
|
||||
@ -78,25 +160,13 @@ func (m *MServer) checkDone() bool {
|
||||
return m.Done
|
||||
}
|
||||
|
||||
func (m *MServer) getWriteFileContext(reqId string) *WriteFileContext {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
wfc := m.WriteFileContextMap[reqId]
|
||||
if wfc == nil {
|
||||
wfc = &WriteFileContext{
|
||||
CVar: sync.NewCond(&sync.Mutex{}),
|
||||
LastActive: time.Now(),
|
||||
}
|
||||
m.WriteFileContextMap[reqId] = wfc
|
||||
}
|
||||
return wfc
|
||||
func (wfc *WriteFileContext) GetTimeoutTime() time.Time {
|
||||
return wfc.LastActive.Add(WriteFileContextTimeout)
|
||||
}
|
||||
|
||||
func (m *MServer) addFileDataPacket(pk *packet.FileDataPacketType) {
|
||||
m.Lock.Lock()
|
||||
wfc := m.WriteFileContextMap[pk.RespId]
|
||||
m.Lock.Unlock()
|
||||
if wfc == nil {
|
||||
func (wfc *WriteFileContext) DispatchPacket(reqId string, pkArg packet.RpcFollowUpPacketType) {
|
||||
dataPk, ok := pkArg.(*packet.FileDataPacketType)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
wfc.CVar.L.Lock()
|
||||
@ -111,11 +181,11 @@ func (m *MServer) addFileDataPacket(pk *packet.FileDataPacketType) {
|
||||
return
|
||||
}
|
||||
wfc.LastActive = time.Now()
|
||||
wfc.Data = append(wfc.Data, pk)
|
||||
wfc.Data = append(wfc.Data, dataPk)
|
||||
wfc.CVar.Signal()
|
||||
}
|
||||
|
||||
func (wfc *WriteFileContext) setDone() {
|
||||
func (wfc *WriteFileContext) UnRegisterCallback() {
|
||||
wfc.CVar.L.Lock()
|
||||
defer wfc.CVar.L.Unlock()
|
||||
wfc.Done = true
|
||||
@ -123,24 +193,6 @@ func (wfc *WriteFileContext) setDone() {
|
||||
wfc.CVar.Broadcast()
|
||||
}
|
||||
|
||||
func (m *MServer) cleanWriteFileContexts() {
|
||||
now := time.Now()
|
||||
var staleWfcs []*WriteFileContext
|
||||
m.Lock.Lock()
|
||||
for reqId, wfc := range m.WriteFileContextMap {
|
||||
if now.Sub(wfc.LastActive) > WriteFileContextTimeout {
|
||||
staleWfcs = append(staleWfcs, wfc)
|
||||
delete(m.WriteFileContextMap, reqId)
|
||||
}
|
||||
}
|
||||
m.Lock.Unlock()
|
||||
|
||||
// we do this outside of m.Lock just in case there is some lock contention (end of WriteFile could theoretically be slow)
|
||||
for _, wfc := range staleWfcs {
|
||||
wfc.setDone()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) {
|
||||
ck := pk.GetCK()
|
||||
if ck == "" {
|
||||
@ -155,7 +207,6 @@ func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) {
|
||||
return
|
||||
}
|
||||
cproc.Input.SendPacket(pk)
|
||||
return
|
||||
}
|
||||
|
||||
func runSingleCompGen(cwd string, compType string, prefix string) ([]string, bool, error) {
|
||||
@ -226,7 +277,6 @@ func (m *MServer) runMixedCompGen(compPk *packet.CompGenPacketType) {
|
||||
}
|
||||
sort.Strings(comps) // resort
|
||||
m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": (hasMoreFiles || hasMoreDirs)})
|
||||
return
|
||||
}
|
||||
|
||||
func (m *MServer) runCompGen(compPk *packet.CompGenPacketType) {
|
||||
@ -246,10 +296,46 @@ func (m *MServer) runCompGen(compPk *packet.CompGenPacketType) {
|
||||
m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": hasMore})
|
||||
}
|
||||
|
||||
type ReinitRpcHandler struct {
|
||||
ReqId string
|
||||
TimeoutTime time.Time
|
||||
StdinDataCh chan []byte
|
||||
}
|
||||
|
||||
func (rh *ReinitRpcHandler) GetTimeoutTime() time.Time {
|
||||
return rh.TimeoutTime
|
||||
}
|
||||
|
||||
func (rh *ReinitRpcHandler) DispatchPacket(reqId string, pkArg packet.RpcFollowUpPacketType) {
|
||||
rpcInput, ok := pkArg.(*packet.RpcInputPacketType)
|
||||
if !ok {
|
||||
wlog.Logf("reinit rpc handler: invalid packet type: %T", pkArg)
|
||||
return
|
||||
}
|
||||
// nonblocking send
|
||||
select {
|
||||
case rh.StdinDataCh <- rpcInput.Data:
|
||||
default:
|
||||
wlog.Logf("reinit rpc handler: stdin data channel full, dropping data")
|
||||
}
|
||||
}
|
||||
|
||||
func (rh *ReinitRpcHandler) UnRegisterCallback() {
|
||||
close(rh.StdinDataCh)
|
||||
}
|
||||
|
||||
func (m *MServer) reinit(reqId string, shellType string) {
|
||||
ssPk, err := m.MakeShellStatePacket(reqId, shellType)
|
||||
stdinDataCh := make(chan []byte, 10)
|
||||
rh := &ReinitRpcHandler{
|
||||
ReqId: reqId,
|
||||
TimeoutTime: time.Now().Add(30 * time.Second),
|
||||
StdinDataCh: stdinDataCh,
|
||||
}
|
||||
m.registerRpcHandler(reqId, rh)
|
||||
defer m.unregisterRpcHandler(reqId)
|
||||
ssPk, err := m.MakeShellStatePacket(reqId, shellType, stdinDataCh)
|
||||
if err != nil {
|
||||
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error creating init packet: %w", err))
|
||||
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error initializing shell: %w", err))
|
||||
return
|
||||
}
|
||||
err = m.StateMap.SetCurrentState(ssPk.State.GetShellType(), ssPk.State)
|
||||
@ -261,13 +347,13 @@ func (m *MServer) reinit(reqId string, shellType string) {
|
||||
m.Sender.SendPacket(ssPk)
|
||||
}
|
||||
|
||||
func (m *MServer) MakeShellStatePacket(reqId string, shellType string) (*packet.ShellStatePacketType, error) {
|
||||
func (m *MServer) MakeShellStatePacket(reqId string, shellType string, stdinDataCh chan []byte) (*packet.ShellStatePacketType, error) {
|
||||
sapi, err := shellapi.MakeShellApi(shellType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rtnCh := make(chan shellapi.ShellStateOutput, 1)
|
||||
go sapi.GetShellState(rtnCh)
|
||||
go sapi.GetShellState(rtnCh, stdinDataCh)
|
||||
for ssOutput := range rtnCh {
|
||||
if ssOutput.Error != "" {
|
||||
return nil, errors.New(ssOutput.Error)
|
||||
@ -357,7 +443,7 @@ func copyFile(dstName string, srcName string) error {
|
||||
}
|
||||
|
||||
func (m *MServer) writeFile(pk *packet.WriteFilePacketType, wfc *WriteFileContext) {
|
||||
defer wfc.setDone()
|
||||
defer m.unregisterRpcHandler(pk.ReqId)
|
||||
if pk.Path == "" {
|
||||
resp := packet.MakeWriteFileReadyPacket(pk.ReqId)
|
||||
resp.Error = "invalid write-file request, no path specified"
|
||||
@ -478,7 +564,6 @@ func (m *MServer) returnStreamFileNewFileResponse(pk *packet.StreamFilePacketTyp
|
||||
Perm: int(dirInfo.Mode().Perm()),
|
||||
NotFound: true,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *MServer) streamFile(pk *packet.StreamFilePacketType) {
|
||||
@ -608,12 +693,19 @@ func (m *MServer) ProcessRpcPacket(pk packet.RpcPacketType) {
|
||||
return
|
||||
}
|
||||
if writePk, ok := pk.(*packet.WriteFilePacketType); ok {
|
||||
wfc := m.getWriteFileContext(writePk.ReqId)
|
||||
wfc := &WriteFileContext{
|
||||
CVar: sync.NewCond(&sync.Mutex{}),
|
||||
LastActive: time.Now(),
|
||||
}
|
||||
err := m.registerRpcHandler(writePk.ReqId, wfc)
|
||||
if err != nil {
|
||||
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error registering write-file handler: %w", err))
|
||||
return
|
||||
}
|
||||
go m.writeFile(writePk, wfc)
|
||||
return
|
||||
}
|
||||
m.Sender.SendErrorResponse(reqId, fmt.Errorf("invalid rpc type '%s'", pk.GetType()))
|
||||
return
|
||||
}
|
||||
|
||||
func (m *MServer) clientPacketCallback(shellType string, pk packet.PacketType) {
|
||||
@ -726,7 +818,7 @@ func (server *MServer) runReadLoop() {
|
||||
builder := packet.MakeRunPacketBuilder()
|
||||
for pk := range server.MainInput.MainCh {
|
||||
if server.Debug {
|
||||
fmt.Printf("PK> %s\n", packet.AsString(pk))
|
||||
wlog.Logf("runReadLoop got packet %s\n", packet.AsString(pk))
|
||||
}
|
||||
ok, runPacket := builder.ProcessPacket(pk)
|
||||
if ok {
|
||||
@ -736,16 +828,19 @@ func (server *MServer) runReadLoop() {
|
||||
}
|
||||
continue
|
||||
}
|
||||
if cmdPk, ok := pk.(packet.CommandPacketType); ok {
|
||||
if cmdPk, ok := pk.(packet.CommandPacketType); ok && cmdPk.GetCK() != "" {
|
||||
server.ProcessCommandPacket(cmdPk)
|
||||
continue
|
||||
}
|
||||
if rpcPk, ok := pk.(packet.RpcPacketType); ok {
|
||||
if rpcPk, ok := pk.(packet.RpcPacketType); ok && rpcPk.GetReqId() != "" {
|
||||
server.ProcessRpcPacket(rpcPk)
|
||||
continue
|
||||
}
|
||||
if fileDataPk, ok := pk.(*packet.FileDataPacketType); ok {
|
||||
server.addFileDataPacket(fileDataPk)
|
||||
if rpcFollowUp, ok := pk.(packet.RpcFollowUpPacketType); ok && rpcFollowUp.GetAssociatedReqId() != "" {
|
||||
ok := server.dispatchRpcFollowUp(rpcFollowUp)
|
||||
if !ok {
|
||||
server.sendInboundRpcError(rpcFollowUp.GetAssociatedReqId(), fmt.Errorf("no handler for rpc follow-up packet"))
|
||||
}
|
||||
continue
|
||||
}
|
||||
server.Sender.SendMessageFmt("invalid packet '%s' sent to mshell server", packet.AsString(pk))
|
||||
@ -765,7 +860,8 @@ func RunServer() (int, error) {
|
||||
Debug: debug,
|
||||
WriteErrorCh: make(chan bool),
|
||||
WriteErrorChOnce: &sync.Once{},
|
||||
WriteFileContextMap: make(map[string]*WriteFileContext),
|
||||
InboundRpcHandlers: make(map[string]RpcHandler),
|
||||
InboundRpcErrorSent: make(map[string]time.Time),
|
||||
}
|
||||
if debug {
|
||||
packet.GlobalDebug = true
|
||||
@ -780,7 +876,7 @@ func RunServer() (int, error) {
|
||||
return
|
||||
}
|
||||
time.Sleep(cleanLoopTime)
|
||||
server.cleanWriteFileContexts()
|
||||
server.cleanRpcHandlers()
|
||||
}
|
||||
}()
|
||||
var err error
|
||||
|
@ -80,8 +80,8 @@ func (b bashShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty
|
||||
return MakeBashShExecCommand(cmdStr, rcFileName, usePty)
|
||||
}
|
||||
|
||||
func (b bashShellApi) GetShellState(outCh chan ShellStateOutput) {
|
||||
GetBashShellState(outCh)
|
||||
func (b bashShellApi) GetShellState(outCh chan ShellStateOutput, stdinDataCh chan []byte) {
|
||||
GetBashShellState(outCh, stdinDataCh)
|
||||
}
|
||||
|
||||
func (b bashShellApi) GetBaseShellOpts() string {
|
||||
@ -169,7 +169,7 @@ func GetLocalBashMajorVersion() string {
|
||||
return localBashMajorVersion
|
||||
}
|
||||
|
||||
func GetBashShellState(outCh chan ShellStateOutput) {
|
||||
func GetBashShellState(outCh chan ShellStateOutput, stdinDataCh chan []byte) {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
|
||||
defer cancelFn()
|
||||
defer close(outCh)
|
||||
@ -185,7 +185,7 @@ func GetBashShellState(outCh chan ShellStateOutput) {
|
||||
outCh <- ShellStateOutput{Output: outputBytes}
|
||||
}
|
||||
}()
|
||||
outputBytes, err := StreamCommandWithExtraFd(ecmd, outputCh, StateOutputFdNum, endBytes)
|
||||
outputBytes, err := StreamCommandWithExtraFd(ctx, ecmd, outputCh, StateOutputFdNum, endBytes, stdinDataCh)
|
||||
outputWg.Wait()
|
||||
if err != nil {
|
||||
outCh <- ShellStateOutput{Error: err.Error()}
|
||||
|
@ -25,9 +25,11 @@ import (
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellutil"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/wlog"
|
||||
)
|
||||
|
||||
const GetStateTimeout = 15 * time.Second
|
||||
const GetStateTimeout = 10 * time.Second
|
||||
const ReInitTimeout = GetStateTimeout + 2*time.Second
|
||||
const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"`
|
||||
const GetK8sContextCmdStr = `printf "K8SCONTEXT %s\x00" "$(kubectl config current-context 2>/dev/null)"`
|
||||
const GetK8sNamespaceCmdStr = `printf "K8SNAMESPACE %s\x00" "$(kubectl config view --minify --output 'jsonpath={..namespace}' 2>/dev/null)"`
|
||||
@ -71,7 +73,7 @@ type ShellApi interface {
|
||||
GetRemoteShellPath() string
|
||||
MakeRunCommand(cmdStr string, opts RunCommandOpts) string
|
||||
MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd
|
||||
GetShellState(chan ShellStateOutput)
|
||||
GetShellState(outCh chan ShellStateOutput, stdinDataCh chan []byte)
|
||||
GetBaseShellOpts() string
|
||||
ParseShellStateOutput(output []byte) (*packet.ShellState, *packet.ShellStateStats, error)
|
||||
MakeRcFileStr(pk *packet.RunPacketType) string
|
||||
@ -154,7 +156,7 @@ func internalMacUserShell() string {
|
||||
const FirstExtraFilesFdNum = 3
|
||||
|
||||
// returns output(stdout+stderr), extraFdOutput, error
|
||||
func StreamCommandWithExtraFd(ecmd *exec.Cmd, outputCh chan []byte, extraFdNum int, endBytes []byte) ([]byte, error) {
|
||||
func StreamCommandWithExtraFd(ctx context.Context, ecmd *exec.Cmd, outputCh chan []byte, extraFdNum int, endBytes []byte, stdinDataCh chan []byte) ([]byte, error) {
|
||||
defer close(outputCh)
|
||||
ecmd.Env = os.Environ()
|
||||
shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType))
|
||||
@ -202,8 +204,27 @@ func StreamCommandWithExtraFd(ecmd *exec.Cmd, outputCh chan []byte, extraFdNum i
|
||||
defer outputWg.Done()
|
||||
utilfn.CopyWithEndBytes(&extraFdOutputBuf, pipeReader, endBytes)
|
||||
}()
|
||||
if stdinDataCh != nil {
|
||||
go func() {
|
||||
// continue this loop even after an error to drain stdinDataCh
|
||||
hadErr := false
|
||||
for stdinData := range stdinDataCh {
|
||||
if hadErr {
|
||||
continue
|
||||
}
|
||||
_, err := cmdPty.Write(stdinData)
|
||||
if err != nil {
|
||||
wlog.Logf("error writing to shellstate cmdpty (stdin): %v\n", err)
|
||||
hadErr = true
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
exitErr := ecmd.Wait()
|
||||
if exitErr != nil {
|
||||
if ctx.Err() != nil {
|
||||
return nil, fmt.Errorf("%w (%w)", ctx.Err(), exitErr)
|
||||
}
|
||||
return nil, exitErr
|
||||
}
|
||||
outputWg.Wait()
|
||||
|
@ -246,7 +246,7 @@ func (z zshShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty
|
||||
return exec.Command(GetLocalZshPath(), "-l", "-i", "-c", cmdStr)
|
||||
}
|
||||
|
||||
func (z zshShellApi) GetShellState(outCh chan ShellStateOutput) {
|
||||
func (z zshShellApi) GetShellState(outCh chan ShellStateOutput, stdinDataCh chan []byte) {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
|
||||
defer cancelFn()
|
||||
defer close(outCh)
|
||||
@ -262,7 +262,7 @@ func (z zshShellApi) GetShellState(outCh chan ShellStateOutput) {
|
||||
outCh <- ShellStateOutput{Output: outputBytes}
|
||||
}
|
||||
}()
|
||||
outputBytes, err := StreamCommandWithExtraFd(ecmd, outputCh, StateOutputFdNum, endBytes)
|
||||
outputBytes, err := StreamCommandWithExtraFd(ctx, ecmd, outputCh, StateOutputFdNum, endBytes, stdinDataCh)
|
||||
outputWg.Wait()
|
||||
if err != nil {
|
||||
outCh <- ShellStateOutput{Error: err.Error()}
|
||||
@ -726,7 +726,6 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta
|
||||
// sections: see ZshSection_* consts
|
||||
sections := bytes.Split(outputBytes, sectionSeparator)
|
||||
if len(sections) != ZshSection_NumFieldsExpected {
|
||||
base.Logf("invalid -- numfields\n")
|
||||
return nil, nil, fmt.Errorf("invalid zsh shell state output, wrong number of sections, section=%d", len(sections))
|
||||
}
|
||||
rtn := &packet.ShellState{}
|
||||
|
@ -30,6 +30,7 @@ import (
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/server"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellapi"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellutil"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
|
||||
@ -1186,9 +1187,17 @@ func deferWriteCmdStatus(ctx context.Context, cmd *sstore.CmdType, startTime tim
|
||||
err := sstore.UpdateCmdDoneInfo(context.Background(), update, ck, donePk, cmdStatus)
|
||||
if err != nil {
|
||||
// nothing to do
|
||||
log.Printf("error updating cmddoneinfo (in openai): %v\n", err)
|
||||
log.Printf("error updating cmddoneinfo: %v\n", err)
|
||||
return
|
||||
}
|
||||
screen, err := sstore.UpdateScreenFocusForDoneCmd(ctx, cmd.ScreenId, cmd.LineId)
|
||||
if err != nil {
|
||||
log.Printf("error trying to update screen focus type: %v\n", err)
|
||||
// fall-through (nothing to do)
|
||||
}
|
||||
if screen != nil {
|
||||
update.AddUpdate(*screen)
|
||||
}
|
||||
scbus.MainUpdateBus.DoScreenUpdate(cmd.ScreenId, update)
|
||||
}
|
||||
|
||||
@ -3686,7 +3695,7 @@ func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
update, err := addLineForCmd(ctx, "/reset", false, ids, cmd, "", nil)
|
||||
update, err := addLineForCmd(ctx, "/reset", true, ids, cmd, "", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -3695,7 +3704,7 @@ func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (
|
||||
}
|
||||
|
||||
func doResetCommand(ids resolvedIds, shellType string, cmd *sstore.CmdType, verbose bool) {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), shellapi.ReInitTimeout)
|
||||
defer cancelFn()
|
||||
startTime := time.Now()
|
||||
var outputPos int64
|
||||
@ -3712,7 +3721,7 @@ func doResetCommand(ids resolvedIds, shellType string, cmd *sstore.CmdType, verb
|
||||
writeStringToPty(ctx, cmd, string(data), &outputPos)
|
||||
}
|
||||
origStatePtr := ids.Remote.MShell.GetDefaultStatePtr(shellType)
|
||||
ssPk, err := ids.Remote.MShell.ReInit(ctx, shellType, dataFn, verbose)
|
||||
ssPk, err := ids.Remote.MShell.ReInit(ctx, base.MakeCommandKey(cmd.ScreenId, cmd.LineId), shellType, dataFn, verbose)
|
||||
if err != nil {
|
||||
rtnErr = err
|
||||
return
|
||||
@ -3975,14 +3984,14 @@ func splitLinesForInfo(str string) []string {
|
||||
}
|
||||
|
||||
func resizeRunningCommand(ctx context.Context, cmd *sstore.CmdType, newCols int) error {
|
||||
siPk := packet.MakeSpecialInputPacket()
|
||||
siPk.CK = base.MakeCommandKey(cmd.ScreenId, cmd.LineId)
|
||||
siPk.WinSize = &packet.WinSize{Rows: int(cmd.TermOpts.Rows), Cols: newCols}
|
||||
feInput := scpacket.MakeFeInputPacket()
|
||||
feInput.CK = base.MakeCommandKey(cmd.ScreenId, cmd.LineId)
|
||||
feInput.WinSize = &packet.WinSize{Rows: int(cmd.TermOpts.Rows), Cols: newCols}
|
||||
msh := remote.GetRemoteById(cmd.Remote.RemoteId)
|
||||
if msh == nil {
|
||||
return fmt.Errorf("cannot resize, cmd remote not found")
|
||||
}
|
||||
err := msh.SendSpecialInput(siPk)
|
||||
err := msh.HandleFeInput(feInput)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -5178,10 +5187,10 @@ func SignalCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbus
|
||||
if !msh.IsConnected() {
|
||||
return nil, fmt.Errorf("cannot send signal, remote is not connected")
|
||||
}
|
||||
siPk := packet.MakeSpecialInputPacket()
|
||||
siPk.CK = base.MakeCommandKey(cmd.ScreenId, cmd.LineId)
|
||||
siPk.SigName = sigArg
|
||||
err = msh.SendSpecialInput(siPk)
|
||||
inputPk := scpacket.MakeFeInputPacket()
|
||||
inputPk.CK = base.MakeCommandKey(cmd.ScreenId, cmd.LineId)
|
||||
inputPk.SigName = sigArg
|
||||
err = msh.HandleFeInput(inputPk)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot send signal: %v", err)
|
||||
}
|
||||
|
@ -215,6 +215,7 @@ var literalRtnStateCommands = []string{
|
||||
"disable",
|
||||
"function",
|
||||
"zmodload",
|
||||
"module",
|
||||
}
|
||||
|
||||
func getCallExprLitArg(callExpr *syntax.CallExpr, argNum int) string {
|
||||
|
@ -54,6 +54,7 @@ const RemoteTermCols = 80
|
||||
const PtyReadBufSize = 100
|
||||
const RemoteConnectTimeout = 15 * time.Second
|
||||
const RpcIterChannelSize = 100
|
||||
const MaxInputDataSize = 1000
|
||||
|
||||
var envVarsToStrip map[string]bool = map[string]bool{
|
||||
"PROMPT": true,
|
||||
@ -128,6 +129,7 @@ type pendingStateKey struct {
|
||||
RemotePtr sstore.RemotePtrType
|
||||
}
|
||||
|
||||
// provides state, acccess, and control for a waveshell server process
|
||||
type MShellProc struct {
|
||||
Lock *sync.Mutex
|
||||
Remote *sstore.RemoteType
|
||||
@ -135,7 +137,7 @@ type MShellProc struct {
|
||||
// runtime
|
||||
RemoteId string // can be read without a lock
|
||||
Status string
|
||||
ServerProc *shexec.ClientProc
|
||||
ServerProc *shexec.ClientProc // the server process
|
||||
UName string
|
||||
Err error
|
||||
ErrNoInitPk bool
|
||||
@ -154,9 +156,18 @@ type MShellProc struct {
|
||||
InstallCancelFn context.CancelFunc
|
||||
InstallErr error
|
||||
|
||||
// for synthetic commands (not run through RunCommand), this provides a way for them
|
||||
// to register to receive input events from the frontend (e.g. ReInit)
|
||||
CommandInputMap map[base.CommandKey]CommandInputSink
|
||||
|
||||
RunningCmds map[base.CommandKey]*RunCmdType
|
||||
PendingStateCmds map[pendingStateKey]base.CommandKey // key=[remoteinstance name]
|
||||
Client *ssh.Client
|
||||
PendingStateCmds map[pendingStateKey]base.CommandKey // key=[remoteinstance name] (in progress commands that might update the state)
|
||||
|
||||
Client *ssh.Client
|
||||
}
|
||||
|
||||
type CommandInputSink interface {
|
||||
HandleInput(feInput *scpacket.FeInputPacketType) error
|
||||
}
|
||||
|
||||
type RunCmdType struct {
|
||||
@ -169,6 +180,22 @@ type RunCmdType struct {
|
||||
EphCancled atomic.Bool // only for Ephemeral commands, if true, then the command result should be discarded
|
||||
}
|
||||
|
||||
type ReinitCommandSink struct {
|
||||
Remote *MShellProc
|
||||
ReqId string
|
||||
}
|
||||
|
||||
func (rcs *ReinitCommandSink) HandleInput(feInput *scpacket.FeInputPacketType) error {
|
||||
realData, err := base64.StdEncoding.DecodeString(feInput.InputData64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error decoding input data: %v", err)
|
||||
}
|
||||
inputPk := packet.MakeRpcInputPacket(rcs.ReqId)
|
||||
inputPk.Data = realData
|
||||
rcs.Remote.ServerProc.Input.SendPacket(inputPk)
|
||||
return nil
|
||||
}
|
||||
|
||||
type RemoteRuntimeState = sstore.RemoteRuntimeState
|
||||
|
||||
func CanComplete(remoteType string) bool {
|
||||
@ -196,7 +223,7 @@ func (msh *MShellProc) EnsureShellType(ctx context.Context, shellType string) er
|
||||
return nil
|
||||
}
|
||||
// try to reinit the shell
|
||||
_, err := msh.ReInit(ctx, shellType, nil, false)
|
||||
_, err := msh.ReInit(ctx, base.CommandKey(""), shellType, nil, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error trying to initialize shell %q: %v", shellType, err)
|
||||
}
|
||||
@ -694,6 +721,7 @@ func MakeMShell(r *sstore.RemoteType) *MShellProc {
|
||||
Status: StatusDisconnected,
|
||||
PtyBuffer: buf,
|
||||
InstallStatus: StatusDisconnected,
|
||||
CommandInputMap: make(map[base.CommandKey]CommandInputSink),
|
||||
RunningCmds: make(map[base.CommandKey]*RunCmdType),
|
||||
PendingStateCmds: make(map[pendingStateKey]base.CommandKey),
|
||||
StateMap: server.MakeShellStateMap(),
|
||||
@ -1401,7 +1429,7 @@ func makeReinitErrorUpdate(shellType string) sstore.ActivityUpdate {
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (msh *MShellProc) ReInit(ctx context.Context, shellType string, dataFn func([]byte), verbose bool) (rtnPk *packet.ShellStatePacketType, rtnErr error) {
|
||||
func (msh *MShellProc) ReInit(ctx context.Context, ck base.CommandKey, shellType string, dataFn func([]byte), verbose bool) (rtnPk *packet.ShellStatePacketType, rtnErr error) {
|
||||
if !msh.IsConnected() {
|
||||
return nil, fmt.Errorf("cannot reinit, remote is not connected")
|
||||
}
|
||||
@ -1425,6 +1453,14 @@ func (msh *MShellProc) ReInit(ctx context.Context, shellType string, dataFn func
|
||||
return nil, err
|
||||
}
|
||||
defer rpcIter.Close()
|
||||
if ck != "" {
|
||||
reinitSink := &ReinitCommandSink{
|
||||
Remote: msh,
|
||||
ReqId: reinitPk.ReqId,
|
||||
}
|
||||
msh.registerInputSink(ck, reinitSink)
|
||||
defer msh.unregisterInputSink(ck)
|
||||
}
|
||||
var ssPk *packet.ShellStatePacketType
|
||||
for {
|
||||
resp, err := rpcIter.Next(ctx)
|
||||
@ -1506,6 +1542,9 @@ func addScVarsToState(state *packet.ShellState) *packet.ShellState {
|
||||
envMap["WAVETERM_VERSION"] = &shellenv.DeclareDeclType{Name: "WAVETERM_VERSION", Value: scbase.WaveVersion, Args: "x"}
|
||||
envMap["TERM_PROGRAM"] = &shellenv.DeclareDeclType{Name: "TERM_PROGRAM", Value: "waveterm", Args: "x"}
|
||||
envMap["TERM_PROGRAM_VERSION"] = &shellenv.DeclareDeclType{Name: "TERM_PROGRAM_VERSION", Value: scbase.WaveVersion, Args: "x"}
|
||||
if scbase.IsDevMode() {
|
||||
envMap["WAVETERM_DEV"] = &shellenv.DeclareDeclType{Name: "WAVETERM_DEV", Value: "1", Args: "x"}
|
||||
}
|
||||
if _, exists := envMap["LANG"]; !exists {
|
||||
envMap["LANG"] = &shellenv.DeclareDeclType{Name: "LANG", Value: scbase.DetermineLang(), Args: "x"}
|
||||
}
|
||||
@ -1727,20 +1766,28 @@ func (msh *MShellProc) Launch(interactive bool) {
|
||||
}
|
||||
|
||||
func (msh *MShellProc) initActiveShells() {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
gasCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancelFn()
|
||||
activeShells, err := msh.getActiveShellTypes(ctx)
|
||||
activeShells, err := msh.getActiveShellTypes(gasCtx)
|
||||
if err != nil {
|
||||
// we're not going to fail the connect for this error (it will be unusable, but technically connected)
|
||||
msh.WriteToPtyBuffer("*error getting active shells: %v\n", err)
|
||||
return
|
||||
}
|
||||
for _, shellType := range activeShells {
|
||||
_, err = msh.ReInit(ctx, shellType, nil, false)
|
||||
if err != nil {
|
||||
msh.WriteToPtyBuffer("*error reiniting shell %q: %v\n", shellType, err)
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
for _, shellTypeForVar := range activeShells {
|
||||
wg.Add(1)
|
||||
go func(shellType string) {
|
||||
defer wg.Done()
|
||||
reinitCtx, cancelFn := context.WithTimeout(context.Background(), shellapi.ReInitTimeout)
|
||||
defer cancelFn()
|
||||
_, err = msh.ReInit(reinitCtx, base.CommandKey(""), shellType, nil, false)
|
||||
if err != nil {
|
||||
msh.WriteToPtyBuffer("*error reiniting shell %q: %v\n", shellType, err)
|
||||
}
|
||||
}(shellTypeForVar)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (msh *MShellProc) IsConnected() bool {
|
||||
@ -1775,24 +1822,14 @@ func (msh *MShellProc) IsCmdRunning(ck base.CommandKey) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
func (msh *MShellProc) SendInput(dataPk *packet.DataPacketType) error {
|
||||
if !msh.IsConnected() {
|
||||
return fmt.Errorf("remote is not connected, cannot send input")
|
||||
}
|
||||
if !msh.IsCmdRunning(dataPk.CK) {
|
||||
return fmt.Errorf("cannot send input, cmd is not running")
|
||||
}
|
||||
return msh.ServerProc.Input.SendPacket(dataPk)
|
||||
}
|
||||
|
||||
func (msh *MShellProc) KillRunningCommandAndWait(ctx context.Context, ck base.CommandKey) error {
|
||||
if !msh.IsCmdRunning(ck) {
|
||||
return nil
|
||||
}
|
||||
siPk := packet.MakeSpecialInputPacket()
|
||||
siPk.CK = ck
|
||||
siPk.SigName = "SIGTERM"
|
||||
err := msh.SendSpecialInput(siPk)
|
||||
feiPk := scpacket.MakeFeInputPacket()
|
||||
feiPk.CK = ck
|
||||
feiPk.SigName = "SIGTERM"
|
||||
err := msh.HandleFeInput(feiPk)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error trying to kill running cmd: %w", err)
|
||||
}
|
||||
@ -1809,16 +1846,6 @@ func (msh *MShellProc) KillRunningCommandAndWait(ctx context.Context, ck base.Co
|
||||
}
|
||||
}
|
||||
|
||||
func (msh *MShellProc) SendSpecialInput(siPk *packet.SpecialInputPacketType) error {
|
||||
if !msh.IsConnected() {
|
||||
return fmt.Errorf("remote is not connected, cannot send input")
|
||||
}
|
||||
if !msh.IsCmdRunning(siPk.CK) {
|
||||
return fmt.Errorf("cannot send input, cmd is not running")
|
||||
}
|
||||
return msh.ServerProc.Input.SendPacket(siPk)
|
||||
}
|
||||
|
||||
func (msh *MShellProc) SendFileData(dataPk *packet.FileDataPacketType) error {
|
||||
if !msh.IsConnected() {
|
||||
return fmt.Errorf("remote is not connected, cannot send input")
|
||||
@ -2065,6 +2092,62 @@ func makePSCLineError(existingPSC base.CommandKey, line *sstore.LineType, lineEr
|
||||
return fmt.Errorf("cannot run command while a stateful command (linenum=%d) is still running", line.LineNum)
|
||||
}
|
||||
|
||||
func (msh *MShellProc) registerInputSink(ck base.CommandKey, sink CommandInputSink) {
|
||||
msh.Lock.Lock()
|
||||
defer msh.Lock.Unlock()
|
||||
msh.CommandInputMap[ck] = sink
|
||||
}
|
||||
|
||||
func (msh *MShellProc) unregisterInputSink(ck base.CommandKey) {
|
||||
msh.Lock.Lock()
|
||||
defer msh.Lock.Unlock()
|
||||
delete(msh.CommandInputMap, ck)
|
||||
}
|
||||
|
||||
func (msh *MShellProc) HandleFeInput(inputPk *scpacket.FeInputPacketType) error {
|
||||
if inputPk == nil {
|
||||
return nil
|
||||
}
|
||||
if !msh.IsConnected() {
|
||||
return fmt.Errorf("connection is not connected, cannot send input")
|
||||
}
|
||||
if msh.IsCmdRunning(inputPk.CK) {
|
||||
if len(inputPk.InputData64) > 0 {
|
||||
inputLen := packet.B64DecodedLen(inputPk.InputData64)
|
||||
if inputLen > MaxInputDataSize {
|
||||
return fmt.Errorf("input data size too large, len=%d (max=%d)", inputLen, MaxInputDataSize)
|
||||
}
|
||||
dataPk := packet.MakeDataPacket()
|
||||
dataPk.CK = inputPk.CK
|
||||
dataPk.FdNum = 0 // stdin
|
||||
dataPk.Data64 = inputPk.InputData64
|
||||
err := msh.ServerProc.Input.SendPacket(dataPk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if inputPk.SigName != "" || inputPk.WinSize != nil {
|
||||
siPk := packet.MakeSpecialInputPacket()
|
||||
siPk.CK = inputPk.CK
|
||||
siPk.SigName = inputPk.SigName
|
||||
siPk.WinSize = inputPk.WinSize
|
||||
err := msh.ServerProc.Input.SendPacket(siPk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
msh.Lock.Lock()
|
||||
sink := msh.CommandInputMap[inputPk.CK]
|
||||
msh.Lock.Unlock()
|
||||
if sink == nil {
|
||||
// no sink and no running command
|
||||
return fmt.Errorf("cannot send input, cmd is not running")
|
||||
}
|
||||
return sink.HandleInput(inputPk)
|
||||
}
|
||||
|
||||
func (msh *MShellProc) AddRunningCmd(rct *RunCmdType) {
|
||||
msh.Lock.Lock()
|
||||
defer msh.Lock.Unlock()
|
||||
|
@ -115,6 +115,7 @@ var IgnoreVars = map[string]bool{
|
||||
"WAVESHELL_VERSION": true,
|
||||
"WAVETERM": true,
|
||||
"WAVETERM_VERSION": true,
|
||||
"WAVETERM_DEV": true,
|
||||
"TERM_PROGRAM": true,
|
||||
"TERM_PROGRAM_VERSION": true,
|
||||
"TERM_SESSION_ID": true,
|
||||
|
@ -4,6 +4,7 @@
|
||||
package scpacket
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
@ -191,6 +192,14 @@ func MakeFeInputPacket() *FeInputPacketType {
|
||||
return &FeInputPacketType{Type: FeInputPacketStr}
|
||||
}
|
||||
|
||||
func (pk *FeInputPacketType) DecodeData() ([]byte, error) {
|
||||
return base64.StdEncoding.DecodeString(pk.InputData64)
|
||||
}
|
||||
|
||||
func (pk *FeInputPacketType) SetData(data []byte) {
|
||||
pk.InputData64 = base64.StdEncoding.EncodeToString(data)
|
||||
}
|
||||
|
||||
func (*WatchScreenPacketType) GetType() string {
|
||||
return WatchScreenPacketStr
|
||||
}
|
||||
|
@ -23,7 +23,6 @@ import (
|
||||
)
|
||||
|
||||
const WSStatePacketChSize = 20
|
||||
const MaxInputDataSize = 1000
|
||||
const RemoteInputQueueSize = 100
|
||||
|
||||
var RemoteInputMapQueue *mapqueue.MapQueue
|
||||
@ -247,7 +246,7 @@ func (ws *WSState) processMessage(msgBytes []byte) error {
|
||||
err := RemoteInputMapQueue.Enqueue(feInputPk.Remote.RemoteId, func() {
|
||||
sendErr := sendCmdInput(feInputPk)
|
||||
if sendErr != nil {
|
||||
log.Printf("[scws] sending command input: %v\n", err)
|
||||
log.Printf("[scws] sending command input: %v\n", sendErr)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
@ -263,7 +262,7 @@ func (ws *WSState) processMessage(msgBytes []byte) error {
|
||||
go func() {
|
||||
sendErr := remote.SendRemoteInput(inputPk)
|
||||
if sendErr != nil {
|
||||
log.Printf("[scws] error processing remote input: %v\n", err)
|
||||
log.Printf("[scws] error processing remote input: %v\n", sendErr)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
@ -319,29 +318,5 @@ func sendCmdInput(pk *scpacket.FeInputPacketType) error {
|
||||
if msh == nil {
|
||||
return fmt.Errorf("remote %s not found", pk.Remote.RemoteId)
|
||||
}
|
||||
if len(pk.InputData64) > 0 {
|
||||
inputLen := packet.B64DecodedLen(pk.InputData64)
|
||||
if inputLen > MaxInputDataSize {
|
||||
return fmt.Errorf("input data size too large, len=%d (max=%d)", inputLen, MaxInputDataSize)
|
||||
}
|
||||
dataPk := packet.MakeDataPacket()
|
||||
dataPk.CK = pk.CK
|
||||
dataPk.FdNum = 0 // stdin
|
||||
dataPk.Data64 = pk.InputData64
|
||||
err = msh.SendInput(dataPk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if pk.SigName != "" || pk.WinSize != nil {
|
||||
siPk := packet.MakeSpecialInputPacket()
|
||||
siPk.CK = pk.CK
|
||||
siPk.SigName = pk.SigName
|
||||
siPk.WinSize = pk.WinSize
|
||||
err = msh.SendSpecialInput(siPk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return msh.HandleFeInput(pk)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user