diff --git a/pkg/cmdrunner/cmdrunner.go b/pkg/cmdrunner/cmdrunner.go index fe1e7ae81..4e46e7626 100644 --- a/pkg/cmdrunner/cmdrunner.go +++ b/pkg/cmdrunner/cmdrunner.go @@ -180,7 +180,10 @@ func RunCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.U } } runPacket.Command = strings.TrimSpace(cmdStr) - cmd, err := remote.RunCommand(ctx, cmdId, ids.Remote.RemotePtr, ids.Remote.RemoteState, runPacket) + cmd, callback, err := remote.RunCommand(ctx, cmdId, ids.Remote.RemotePtr, ids.Remote.RemoteState, runPacket) + if callback != nil { + defer callback() + } if err != nil { return nil, err } diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index 64cbf6cfc..f6eeedab3 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -49,9 +49,10 @@ const ( var GlobalStore *Store type Store struct { - Lock *sync.Mutex - Map map[string]*MShellProc // key=remoteid - Log *CircleLog + Lock *sync.Mutex + Map map[string]*MShellProc // key=remoteid + Log *CircleLog + CmdWaitMap map[base.CommandKey][]sstore.UpdatePacket } type MShellProc struct { @@ -123,9 +124,10 @@ func (state RemoteRuntimeState) GetDisplayName(rptr *sstore.RemotePtrType) strin func LoadRemotes(ctx context.Context) error { GlobalStore = &Store{ - Lock: &sync.Mutex{}, - Map: make(map[string]*MShellProc), - Log: MakeCircleLog(100), + Lock: &sync.Mutex{}, + Map: make(map[string]*MShellProc), + Log: MakeCircleLog(100), + CmdWaitMap: make(map[base.CommandKey][]sstore.UpdatePacket), } allRemotes, err := sstore.GetAllRemotes(ctx) if err != nil { @@ -308,7 +310,7 @@ func (msh *MShellProc) GetRemoteRuntimeState() RemoteRuntimeState { return state } -func (msh *MShellProc) NotifyUpdate() { +func (msh *MShellProc) NotifyRemoteUpdate() { rstate := msh.GetRemoteRuntimeState() update := &sstore.ModelUpdate{Remotes: []interface{}{rstate}} sstore.MainBus.SendUpdate("", update) @@ -382,7 +384,7 @@ func (msh *MShellProc) setErrorStatus(err error) { defer msh.Lock.Unlock() msh.Status = StatusError msh.Err = err - go msh.NotifyUpdate() + go msh.NotifyRemoteUpdate() } func (msh *MShellProc) getRemoteCopy() sstore.RemoteType { @@ -458,7 +460,7 @@ func (msh *MShellProc) Launch() { msh.WithLock(func() { msh.ServerProc = cproc msh.Status = StatusConnected - go msh.NotifyUpdate() + go msh.NotifyRemoteUpdate() }) go func() { exitErr := cproc.Cmd.Wait() @@ -466,7 +468,7 @@ func (msh *MShellProc) Launch() { msh.WithLock(func() { if msh.Status == StatusConnected { msh.Status = StatusDisconnected - go msh.NotifyUpdate() + go msh.NotifyRemoteUpdate() } }) logf(&remoteCopy, "remote disconnected exitcode=%d", exitCode) @@ -546,39 +548,40 @@ func makeTermOpts(runPk *packet.RunPacketType) sstore.TermOpts { return sstore.TermOpts{Rows: int64(runPk.TermOpts.Rows), Cols: int64(runPk.TermOpts.Cols), FlexRows: true, MaxPtySize: DefaultMaxPtySize} } -func RunCommand(ctx context.Context, cmdId string, remotePtr sstore.RemotePtrType, remoteState *sstore.RemoteState, runPacket *packet.RunPacketType) (*sstore.CmdType, error) { +// returns (cmdtype, allow-updates-callback, err) +func RunCommand(ctx context.Context, cmdId string, remotePtr sstore.RemotePtrType, remoteState *sstore.RemoteState, runPacket *packet.RunPacketType) (*sstore.CmdType, func(), error) { if remotePtr.OwnerId != "" { - return nil, fmt.Errorf("cannot run command against another user's remote '%s'", remotePtr.MakeFullRemoteRef()) + return nil, nil, fmt.Errorf("cannot run command against another user's remote '%s'", remotePtr.MakeFullRemoteRef()) } msh := GetRemoteById(remotePtr.RemoteId) if msh == nil { - return nil, fmt.Errorf("no remote id=%s found", remotePtr.RemoteId) + return nil, nil, fmt.Errorf("no remote id=%s found", remotePtr.RemoteId) } if !msh.IsConnected() { - return nil, fmt.Errorf("remote '%s' is not connected", remotePtr.RemoteId) + return nil, nil, fmt.Errorf("remote '%s' is not connected", remotePtr.RemoteId) } if remoteState == nil { - return nil, fmt.Errorf("no remote state passed to RunCommand") + return nil, nil, fmt.Errorf("no remote state passed to RunCommand") } msh.ServerProc.Output.RegisterRpc(runPacket.ReqId) err := shexec.SendRunPacketAndRunData(ctx, msh.ServerProc.Input, runPacket) if err != nil { - return nil, fmt.Errorf("sending run packet to remote: %w", err) + return nil, nil, fmt.Errorf("sending run packet to remote: %w", err) } rtnPk := msh.ServerProc.Output.WaitForResponse(ctx, runPacket.ReqId) if rtnPk == nil { - return nil, ctx.Err() + return nil, nil, ctx.Err() } startPk, ok := rtnPk.(*packet.CmdStartPacketType) if !ok { respPk, ok := rtnPk.(*packet.ResponsePacketType) if !ok { - return nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk)) + return nil, nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk)) } if respPk.Error != "" { - return nil, errors.New(respPk.Error) + return nil, nil, errors.New(respPk.Error) } - return nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk)) + return nil, nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk)) } status := sstore.CmdStatusRunning if runPacket.Detached { @@ -598,10 +601,11 @@ func RunCommand(ctx context.Context, cmdId string, remotePtr sstore.RemotePtrTyp } err = sstore.CreateCmdPtyFile(ctx, cmd.SessionId, cmd.CmdId, cmd.TermOpts.MaxPtySize) if err != nil { - return nil, err + // TODO the cmd is running, so this is a tricky error to handle + return nil, nil, fmt.Errorf("cannot create local ptyout file for running command: %v", err) } msh.AddRunningCmd(startPk.CK) - return cmd, nil + return cmd, func() { removeCmdWait(startPk.CK) }, nil } func (msh *MShellProc) AddRunningCmd(ck base.CommandKey) { @@ -651,31 +655,6 @@ func makeDataAckPacket(ck base.CommandKey, fdNum int, ackLen int, err error) *pa return ack } -func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) { - update, err := sstore.UpdateCmdDonePk(context.Background(), donePk) - if err != nil { - fmt.Printf("[error] updating cmddone: %v\n", err) - return - } - if update != nil { - // TODO fix timing issue (this update gets to the FE before run-command returns for short lived commands) - go func() { - time.Sleep(10 * time.Millisecond) - sstore.MainBus.SendUpdate(donePk.CK.GetSessionId(), update) - }() - } - return -} - -func (msh *MShellProc) handleCmdErrorPacket(errPk *packet.CmdErrorPacketType) { - err := sstore.AppendCmdErrorPk(context.Background(), errPk) - if err != nil { - fmt.Printf("[error] adding cmderr: %v\n", err) - return - } - return -} - func (msh *MShellProc) notifyHangups_nolock() { for _, ck := range msh.RunningCmds { cmd, err := sstore.GetCmdById(context.Background(), ck.GetSessionId(), ck.GetCmdId()) @@ -688,6 +667,59 @@ func (msh *MShellProc) notifyHangups_nolock() { msh.RunningCmds = nil } +func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) { + update, err := sstore.UpdateCmdDonePk(context.Background(), donePk) + if err != nil { + fmt.Printf("[error] updating cmddone: %v\n", err) + return + } + if update != nil { + // TODO fix timing issue (this update gets to the FE before run-command returns for short lived commands) + go func() { + time.Sleep(10 * time.Millisecond) + sendCmdUpdate(donePk.CK, update) + }() + } + return +} + +// TODO notify FE about cmd errors +func (msh *MShellProc) handleCmdErrorPacket(errPk *packet.CmdErrorPacketType) { + err := sstore.AppendCmdErrorPk(context.Background(), errPk) + if err != nil { + fmt.Printf("[error] adding cmderr: %v\n", err) + return + } + return +} + +func (msh *MShellProc) handleDataPacket(dataPk *packet.DataPacketType, dataPosMap map[base.CommandKey]int64) { + realData, err := base64.StdEncoding.DecodeString(dataPk.Data64) + if err != nil { + ack := makeDataAckPacket(dataPk.CK, dataPk.FdNum, 0, err) + msh.ServerProc.Input.SendPacket(ack) + return + } + var ack *packet.DataAckPacketType + if len(realData) > 0 { + dataPos := dataPosMap[dataPk.CK] + update, err := sstore.AppendToCmdPtyBlob(context.Background(), dataPk.CK.GetSessionId(), dataPk.CK.GetCmdId(), realData, dataPos) + if err != nil { + ack = makeDataAckPacket(dataPk.CK, dataPk.FdNum, 0, err) + } else { + ack = makeDataAckPacket(dataPk.CK, dataPk.FdNum, len(realData), nil) + } + dataPosMap[dataPk.CK] += int64(len(realData)) + if update != nil { + sendCmdUpdate(dataPk.CK, update) + } + } + if ack != nil { + msh.ServerProc.Input.SendPacket(ack) + } + // fmt.Printf("data %s fd=%d len=%d eof=%v err=%v\n", dataPk.CK, dataPk.FdNum, len(realData), dataPk.Eof, dataPk.Error) +} + func (msh *MShellProc) ProcessPackets() { defer msh.WithLock(func() { if msh.Status == StatusConnected { @@ -698,33 +730,13 @@ func (msh *MShellProc) ProcessPackets() { logf(msh.Remote, "calling HUP on cmds %v", err) } msh.notifyHangups_nolock() - go msh.NotifyUpdate() + go msh.NotifyRemoteUpdate() }) dataPosMap := make(map[base.CommandKey]int64) for pk := range msh.ServerProc.Output.MainCh { if pk.GetType() == packet.DataPacketStr { dataPk := pk.(*packet.DataPacketType) - realData, err := base64.StdEncoding.DecodeString(dataPk.Data64) - if err != nil { - ack := makeDataAckPacket(dataPk.CK, dataPk.FdNum, 0, err) - msh.ServerProc.Input.SendPacket(ack) - continue - } - var ack *packet.DataAckPacketType - if len(realData) > 0 { - dataPos := dataPosMap[dataPk.CK] - err = sstore.AppendToCmdPtyBlob(context.Background(), dataPk.CK.GetSessionId(), dataPk.CK.GetCmdId(), realData, dataPos) - if err != nil { - ack = makeDataAckPacket(dataPk.CK, dataPk.FdNum, 0, err) - } else { - ack = makeDataAckPacket(dataPk.CK, dataPk.FdNum, len(realData), nil) - } - dataPosMap[dataPk.CK] += int64(len(realData)) - } - if ack != nil { - msh.ServerProc.Input.SendPacket(ack) - } - // fmt.Printf("data %s fd=%d len=%d eof=%v err=%v\n", dataPk.CK, dataPk.FdNum, len(realData), dataPk.Eof, dataPk.Error) + msh.handleDataPacket(dataPk, dataPosMap) continue } if pk.GetType() == packet.DataAckPacketStr { diff --git a/pkg/remote/updatequeue.go b/pkg/remote/updatequeue.go new file mode 100644 index 000000000..105bb61fb --- /dev/null +++ b/pkg/remote/updatequeue.go @@ -0,0 +1,66 @@ +package remote + +import ( + "github.com/scripthaus-dev/mshell/pkg/base" + "github.com/scripthaus-dev/sh2-server/pkg/sstore" +) + +func pushCmdWaitIfRequired(ck base.CommandKey, update sstore.UpdatePacket) bool { + GlobalStore.Lock.Lock() + defer GlobalStore.Lock.Unlock() + updates, ok := GlobalStore.CmdWaitMap[ck] + if !ok { + return false + } + updates = append(updates, update) + GlobalStore.CmdWaitMap[ck] = updates + return true +} + +func sendCmdUpdate(ck base.CommandKey, update sstore.UpdatePacket) { + pushed := pushCmdWaitIfRequired(ck, update) + if pushed { + return + } + sstore.MainBus.SendUpdate(ck.GetSessionId(), update) +} + +func runCmdWaitUpdates(ck base.CommandKey) { + for { + update := removeFirstCmdWaitUpdate(ck) + if update == nil { + break + } + sstore.MainBus.SendUpdate(ck.GetSessionId(), update) + } +} + +func removeFirstCmdWaitUpdate(ck base.CommandKey) sstore.UpdatePacket { + GlobalStore.Lock.Lock() + defer GlobalStore.Lock.Unlock() + + updates := GlobalStore.CmdWaitMap[ck] + if len(updates) == 0 { + delete(GlobalStore.CmdWaitMap, ck) + return nil + } + if len(updates) == 1 { + delete(GlobalStore.CmdWaitMap, ck) + return updates[0] + } + update := updates[0] + GlobalStore.CmdWaitMap[ck] = updates[1:] + return update +} + +func removeCmdWait(ck base.CommandKey) { + GlobalStore.Lock.Lock() + defer GlobalStore.Lock.Unlock() + + updates := GlobalStore.CmdWaitMap[ck] + if len(updates) == 0 { + delete(GlobalStore.CmdWaitMap, ck) + return + } + go runCmdWaitUpdates(ck) +} diff --git a/pkg/sstore/fileops.go b/pkg/sstore/fileops.go index f7de7a8b5..0ecf9aa02 100644 --- a/pkg/sstore/fileops.go +++ b/pkg/sstore/fileops.go @@ -21,22 +21,22 @@ func CreateCmdPtyFile(ctx context.Context, sessionId string, cmdId string, maxSi return f.Close() } -func AppendToCmdPtyBlob(ctx context.Context, sessionId string, cmdId string, data []byte, pos int64) error { +func AppendToCmdPtyBlob(ctx context.Context, sessionId string, cmdId string, data []byte, pos int64) (*PtyDataUpdate, error) { if pos < 0 { - return fmt.Errorf("invalid seek pos '%d' in AppendToCmdPtyBlob", pos) + return nil, fmt.Errorf("invalid seek pos '%d' in AppendToCmdPtyBlob", pos) } ptyOutFileName, err := scbase.PtyOutFile(sessionId, cmdId) if err != nil { - return err + return nil, err } f, err := cirfile.OpenCirFile(ptyOutFileName) if err != nil { - return err + return nil, err } defer f.Close() err = f.WriteAt(ctx, data, pos) if err != nil { - return err + return nil, err } data64 := base64.StdEncoding.EncodeToString(data) update := &PtyDataUpdate{ @@ -46,8 +46,7 @@ func AppendToCmdPtyBlob(ctx context.Context, sessionId string, cmdId string, dat PtyData64: data64, PtyDataLen: int64(len(data)), } - MainBus.SendUpdate(sessionId, update) - return nil + return update, nil } // returns (offset, data, err)