create a remote update queue to ensure that we send the line update before we send cmd updates

This commit is contained in:
sawka 2022-09-05 14:49:23 -07:00
parent b980fd6b74
commit 54e0ecffe1
4 changed files with 157 additions and 77 deletions

View File

@ -180,7 +180,10 @@ func RunCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.U
} }
} }
runPacket.Command = strings.TrimSpace(cmdStr) runPacket.Command = strings.TrimSpace(cmdStr)
cmd, err := remote.RunCommand(ctx, cmdId, ids.Remote.RemotePtr, ids.Remote.RemoteState, runPacket) cmd, callback, err := remote.RunCommand(ctx, cmdId, ids.Remote.RemotePtr, ids.Remote.RemoteState, runPacket)
if callback != nil {
defer callback()
}
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -49,9 +49,10 @@ const (
var GlobalStore *Store var GlobalStore *Store
type Store struct { type Store struct {
Lock *sync.Mutex Lock *sync.Mutex
Map map[string]*MShellProc // key=remoteid Map map[string]*MShellProc // key=remoteid
Log *CircleLog Log *CircleLog
CmdWaitMap map[base.CommandKey][]sstore.UpdatePacket
} }
type MShellProc struct { type MShellProc struct {
@ -123,9 +124,10 @@ func (state RemoteRuntimeState) GetDisplayName(rptr *sstore.RemotePtrType) strin
func LoadRemotes(ctx context.Context) error { func LoadRemotes(ctx context.Context) error {
GlobalStore = &Store{ GlobalStore = &Store{
Lock: &sync.Mutex{}, Lock: &sync.Mutex{},
Map: make(map[string]*MShellProc), Map: make(map[string]*MShellProc),
Log: MakeCircleLog(100), Log: MakeCircleLog(100),
CmdWaitMap: make(map[base.CommandKey][]sstore.UpdatePacket),
} }
allRemotes, err := sstore.GetAllRemotes(ctx) allRemotes, err := sstore.GetAllRemotes(ctx)
if err != nil { if err != nil {
@ -308,7 +310,7 @@ func (msh *MShellProc) GetRemoteRuntimeState() RemoteRuntimeState {
return state return state
} }
func (msh *MShellProc) NotifyUpdate() { func (msh *MShellProc) NotifyRemoteUpdate() {
rstate := msh.GetRemoteRuntimeState() rstate := msh.GetRemoteRuntimeState()
update := &sstore.ModelUpdate{Remotes: []interface{}{rstate}} update := &sstore.ModelUpdate{Remotes: []interface{}{rstate}}
sstore.MainBus.SendUpdate("", update) sstore.MainBus.SendUpdate("", update)
@ -382,7 +384,7 @@ func (msh *MShellProc) setErrorStatus(err error) {
defer msh.Lock.Unlock() defer msh.Lock.Unlock()
msh.Status = StatusError msh.Status = StatusError
msh.Err = err msh.Err = err
go msh.NotifyUpdate() go msh.NotifyRemoteUpdate()
} }
func (msh *MShellProc) getRemoteCopy() sstore.RemoteType { func (msh *MShellProc) getRemoteCopy() sstore.RemoteType {
@ -458,7 +460,7 @@ func (msh *MShellProc) Launch() {
msh.WithLock(func() { msh.WithLock(func() {
msh.ServerProc = cproc msh.ServerProc = cproc
msh.Status = StatusConnected msh.Status = StatusConnected
go msh.NotifyUpdate() go msh.NotifyRemoteUpdate()
}) })
go func() { go func() {
exitErr := cproc.Cmd.Wait() exitErr := cproc.Cmd.Wait()
@ -466,7 +468,7 @@ func (msh *MShellProc) Launch() {
msh.WithLock(func() { msh.WithLock(func() {
if msh.Status == StatusConnected { if msh.Status == StatusConnected {
msh.Status = StatusDisconnected msh.Status = StatusDisconnected
go msh.NotifyUpdate() go msh.NotifyRemoteUpdate()
} }
}) })
logf(&remoteCopy, "remote disconnected exitcode=%d", exitCode) logf(&remoteCopy, "remote disconnected exitcode=%d", exitCode)
@ -546,39 +548,40 @@ func makeTermOpts(runPk *packet.RunPacketType) sstore.TermOpts {
return sstore.TermOpts{Rows: int64(runPk.TermOpts.Rows), Cols: int64(runPk.TermOpts.Cols), FlexRows: true, MaxPtySize: DefaultMaxPtySize} return sstore.TermOpts{Rows: int64(runPk.TermOpts.Rows), Cols: int64(runPk.TermOpts.Cols), FlexRows: true, MaxPtySize: DefaultMaxPtySize}
} }
func RunCommand(ctx context.Context, cmdId string, remotePtr sstore.RemotePtrType, remoteState *sstore.RemoteState, runPacket *packet.RunPacketType) (*sstore.CmdType, error) { // returns (cmdtype, allow-updates-callback, err)
func RunCommand(ctx context.Context, cmdId string, remotePtr sstore.RemotePtrType, remoteState *sstore.RemoteState, runPacket *packet.RunPacketType) (*sstore.CmdType, func(), error) {
if remotePtr.OwnerId != "" { if remotePtr.OwnerId != "" {
return nil, fmt.Errorf("cannot run command against another user's remote '%s'", remotePtr.MakeFullRemoteRef()) return nil, nil, fmt.Errorf("cannot run command against another user's remote '%s'", remotePtr.MakeFullRemoteRef())
} }
msh := GetRemoteById(remotePtr.RemoteId) msh := GetRemoteById(remotePtr.RemoteId)
if msh == nil { if msh == nil {
return nil, fmt.Errorf("no remote id=%s found", remotePtr.RemoteId) return nil, nil, fmt.Errorf("no remote id=%s found", remotePtr.RemoteId)
} }
if !msh.IsConnected() { if !msh.IsConnected() {
return nil, fmt.Errorf("remote '%s' is not connected", remotePtr.RemoteId) return nil, nil, fmt.Errorf("remote '%s' is not connected", remotePtr.RemoteId)
} }
if remoteState == nil { if remoteState == nil {
return nil, fmt.Errorf("no remote state passed to RunCommand") return nil, nil, fmt.Errorf("no remote state passed to RunCommand")
} }
msh.ServerProc.Output.RegisterRpc(runPacket.ReqId) msh.ServerProc.Output.RegisterRpc(runPacket.ReqId)
err := shexec.SendRunPacketAndRunData(ctx, msh.ServerProc.Input, runPacket) err := shexec.SendRunPacketAndRunData(ctx, msh.ServerProc.Input, runPacket)
if err != nil { if err != nil {
return nil, fmt.Errorf("sending run packet to remote: %w", err) return nil, nil, fmt.Errorf("sending run packet to remote: %w", err)
} }
rtnPk := msh.ServerProc.Output.WaitForResponse(ctx, runPacket.ReqId) rtnPk := msh.ServerProc.Output.WaitForResponse(ctx, runPacket.ReqId)
if rtnPk == nil { if rtnPk == nil {
return nil, ctx.Err() return nil, nil, ctx.Err()
} }
startPk, ok := rtnPk.(*packet.CmdStartPacketType) startPk, ok := rtnPk.(*packet.CmdStartPacketType)
if !ok { if !ok {
respPk, ok := rtnPk.(*packet.ResponsePacketType) respPk, ok := rtnPk.(*packet.ResponsePacketType)
if !ok { if !ok {
return nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk)) return nil, nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk))
} }
if respPk.Error != "" { if respPk.Error != "" {
return nil, errors.New(respPk.Error) return nil, nil, errors.New(respPk.Error)
} }
return nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk)) return nil, nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk))
} }
status := sstore.CmdStatusRunning status := sstore.CmdStatusRunning
if runPacket.Detached { if runPacket.Detached {
@ -598,10 +601,11 @@ func RunCommand(ctx context.Context, cmdId string, remotePtr sstore.RemotePtrTyp
} }
err = sstore.CreateCmdPtyFile(ctx, cmd.SessionId, cmd.CmdId, cmd.TermOpts.MaxPtySize) err = sstore.CreateCmdPtyFile(ctx, cmd.SessionId, cmd.CmdId, cmd.TermOpts.MaxPtySize)
if err != nil { if err != nil {
return nil, err // TODO the cmd is running, so this is a tricky error to handle
return nil, nil, fmt.Errorf("cannot create local ptyout file for running command: %v", err)
} }
msh.AddRunningCmd(startPk.CK) msh.AddRunningCmd(startPk.CK)
return cmd, nil return cmd, func() { removeCmdWait(startPk.CK) }, nil
} }
func (msh *MShellProc) AddRunningCmd(ck base.CommandKey) { func (msh *MShellProc) AddRunningCmd(ck base.CommandKey) {
@ -651,31 +655,6 @@ func makeDataAckPacket(ck base.CommandKey, fdNum int, ackLen int, err error) *pa
return ack return ack
} }
func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
update, err := sstore.UpdateCmdDonePk(context.Background(), donePk)
if err != nil {
fmt.Printf("[error] updating cmddone: %v\n", err)
return
}
if update != nil {
// TODO fix timing issue (this update gets to the FE before run-command returns for short lived commands)
go func() {
time.Sleep(10 * time.Millisecond)
sstore.MainBus.SendUpdate(donePk.CK.GetSessionId(), update)
}()
}
return
}
func (msh *MShellProc) handleCmdErrorPacket(errPk *packet.CmdErrorPacketType) {
err := sstore.AppendCmdErrorPk(context.Background(), errPk)
if err != nil {
fmt.Printf("[error] adding cmderr: %v\n", err)
return
}
return
}
func (msh *MShellProc) notifyHangups_nolock() { func (msh *MShellProc) notifyHangups_nolock() {
for _, ck := range msh.RunningCmds { for _, ck := range msh.RunningCmds {
cmd, err := sstore.GetCmdById(context.Background(), ck.GetSessionId(), ck.GetCmdId()) cmd, err := sstore.GetCmdById(context.Background(), ck.GetSessionId(), ck.GetCmdId())
@ -688,6 +667,59 @@ func (msh *MShellProc) notifyHangups_nolock() {
msh.RunningCmds = nil msh.RunningCmds = nil
} }
func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
update, err := sstore.UpdateCmdDonePk(context.Background(), donePk)
if err != nil {
fmt.Printf("[error] updating cmddone: %v\n", err)
return
}
if update != nil {
// TODO fix timing issue (this update gets to the FE before run-command returns for short lived commands)
go func() {
time.Sleep(10 * time.Millisecond)
sendCmdUpdate(donePk.CK, update)
}()
}
return
}
// TODO notify FE about cmd errors
func (msh *MShellProc) handleCmdErrorPacket(errPk *packet.CmdErrorPacketType) {
err := sstore.AppendCmdErrorPk(context.Background(), errPk)
if err != nil {
fmt.Printf("[error] adding cmderr: %v\n", err)
return
}
return
}
func (msh *MShellProc) handleDataPacket(dataPk *packet.DataPacketType, dataPosMap map[base.CommandKey]int64) {
realData, err := base64.StdEncoding.DecodeString(dataPk.Data64)
if err != nil {
ack := makeDataAckPacket(dataPk.CK, dataPk.FdNum, 0, err)
msh.ServerProc.Input.SendPacket(ack)
return
}
var ack *packet.DataAckPacketType
if len(realData) > 0 {
dataPos := dataPosMap[dataPk.CK]
update, err := sstore.AppendToCmdPtyBlob(context.Background(), dataPk.CK.GetSessionId(), dataPk.CK.GetCmdId(), realData, dataPos)
if err != nil {
ack = makeDataAckPacket(dataPk.CK, dataPk.FdNum, 0, err)
} else {
ack = makeDataAckPacket(dataPk.CK, dataPk.FdNum, len(realData), nil)
}
dataPosMap[dataPk.CK] += int64(len(realData))
if update != nil {
sendCmdUpdate(dataPk.CK, update)
}
}
if ack != nil {
msh.ServerProc.Input.SendPacket(ack)
}
// fmt.Printf("data %s fd=%d len=%d eof=%v err=%v\n", dataPk.CK, dataPk.FdNum, len(realData), dataPk.Eof, dataPk.Error)
}
func (msh *MShellProc) ProcessPackets() { func (msh *MShellProc) ProcessPackets() {
defer msh.WithLock(func() { defer msh.WithLock(func() {
if msh.Status == StatusConnected { if msh.Status == StatusConnected {
@ -698,33 +730,13 @@ func (msh *MShellProc) ProcessPackets() {
logf(msh.Remote, "calling HUP on cmds %v", err) logf(msh.Remote, "calling HUP on cmds %v", err)
} }
msh.notifyHangups_nolock() msh.notifyHangups_nolock()
go msh.NotifyUpdate() go msh.NotifyRemoteUpdate()
}) })
dataPosMap := make(map[base.CommandKey]int64) dataPosMap := make(map[base.CommandKey]int64)
for pk := range msh.ServerProc.Output.MainCh { for pk := range msh.ServerProc.Output.MainCh {
if pk.GetType() == packet.DataPacketStr { if pk.GetType() == packet.DataPacketStr {
dataPk := pk.(*packet.DataPacketType) dataPk := pk.(*packet.DataPacketType)
realData, err := base64.StdEncoding.DecodeString(dataPk.Data64) msh.handleDataPacket(dataPk, dataPosMap)
if err != nil {
ack := makeDataAckPacket(dataPk.CK, dataPk.FdNum, 0, err)
msh.ServerProc.Input.SendPacket(ack)
continue
}
var ack *packet.DataAckPacketType
if len(realData) > 0 {
dataPos := dataPosMap[dataPk.CK]
err = sstore.AppendToCmdPtyBlob(context.Background(), dataPk.CK.GetSessionId(), dataPk.CK.GetCmdId(), realData, dataPos)
if err != nil {
ack = makeDataAckPacket(dataPk.CK, dataPk.FdNum, 0, err)
} else {
ack = makeDataAckPacket(dataPk.CK, dataPk.FdNum, len(realData), nil)
}
dataPosMap[dataPk.CK] += int64(len(realData))
}
if ack != nil {
msh.ServerProc.Input.SendPacket(ack)
}
// fmt.Printf("data %s fd=%d len=%d eof=%v err=%v\n", dataPk.CK, dataPk.FdNum, len(realData), dataPk.Eof, dataPk.Error)
continue continue
} }
if pk.GetType() == packet.DataAckPacketStr { if pk.GetType() == packet.DataAckPacketStr {

66
pkg/remote/updatequeue.go Normal file
View File

@ -0,0 +1,66 @@
package remote
import (
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/sh2-server/pkg/sstore"
)
func pushCmdWaitIfRequired(ck base.CommandKey, update sstore.UpdatePacket) bool {
GlobalStore.Lock.Lock()
defer GlobalStore.Lock.Unlock()
updates, ok := GlobalStore.CmdWaitMap[ck]
if !ok {
return false
}
updates = append(updates, update)
GlobalStore.CmdWaitMap[ck] = updates
return true
}
func sendCmdUpdate(ck base.CommandKey, update sstore.UpdatePacket) {
pushed := pushCmdWaitIfRequired(ck, update)
if pushed {
return
}
sstore.MainBus.SendUpdate(ck.GetSessionId(), update)
}
func runCmdWaitUpdates(ck base.CommandKey) {
for {
update := removeFirstCmdWaitUpdate(ck)
if update == nil {
break
}
sstore.MainBus.SendUpdate(ck.GetSessionId(), update)
}
}
func removeFirstCmdWaitUpdate(ck base.CommandKey) sstore.UpdatePacket {
GlobalStore.Lock.Lock()
defer GlobalStore.Lock.Unlock()
updates := GlobalStore.CmdWaitMap[ck]
if len(updates) == 0 {
delete(GlobalStore.CmdWaitMap, ck)
return nil
}
if len(updates) == 1 {
delete(GlobalStore.CmdWaitMap, ck)
return updates[0]
}
update := updates[0]
GlobalStore.CmdWaitMap[ck] = updates[1:]
return update
}
func removeCmdWait(ck base.CommandKey) {
GlobalStore.Lock.Lock()
defer GlobalStore.Lock.Unlock()
updates := GlobalStore.CmdWaitMap[ck]
if len(updates) == 0 {
delete(GlobalStore.CmdWaitMap, ck)
return
}
go runCmdWaitUpdates(ck)
}

View File

@ -21,22 +21,22 @@ func CreateCmdPtyFile(ctx context.Context, sessionId string, cmdId string, maxSi
return f.Close() return f.Close()
} }
func AppendToCmdPtyBlob(ctx context.Context, sessionId string, cmdId string, data []byte, pos int64) error { func AppendToCmdPtyBlob(ctx context.Context, sessionId string, cmdId string, data []byte, pos int64) (*PtyDataUpdate, error) {
if pos < 0 { if pos < 0 {
return fmt.Errorf("invalid seek pos '%d' in AppendToCmdPtyBlob", pos) return nil, fmt.Errorf("invalid seek pos '%d' in AppendToCmdPtyBlob", pos)
} }
ptyOutFileName, err := scbase.PtyOutFile(sessionId, cmdId) ptyOutFileName, err := scbase.PtyOutFile(sessionId, cmdId)
if err != nil { if err != nil {
return err return nil, err
} }
f, err := cirfile.OpenCirFile(ptyOutFileName) f, err := cirfile.OpenCirFile(ptyOutFileName)
if err != nil { if err != nil {
return err return nil, err
} }
defer f.Close() defer f.Close()
err = f.WriteAt(ctx, data, pos) err = f.WriteAt(ctx, data, pos)
if err != nil { if err != nil {
return err return nil, err
} }
data64 := base64.StdEncoding.EncodeToString(data) data64 := base64.StdEncoding.EncodeToString(data)
update := &PtyDataUpdate{ update := &PtyDataUpdate{
@ -46,8 +46,7 @@ func AppendToCmdPtyBlob(ctx context.Context, sessionId string, cmdId string, dat
PtyData64: data64, PtyData64: data64,
PtyDataLen: int64(len(data)), PtyDataLen: int64(len(data)),
} }
MainBus.SendUpdate(sessionId, update) return update, nil
return nil
} }
// returns (offset, data, err) // returns (offset, data, err)