create a remote update queue to ensure that we send the line update before we send cmd updates

This commit is contained in:
sawka 2022-09-05 14:49:23 -07:00
parent b980fd6b74
commit 54e0ecffe1
4 changed files with 157 additions and 77 deletions

View File

@ -180,7 +180,10 @@ func RunCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.U
}
}
runPacket.Command = strings.TrimSpace(cmdStr)
cmd, err := remote.RunCommand(ctx, cmdId, ids.Remote.RemotePtr, ids.Remote.RemoteState, runPacket)
cmd, callback, err := remote.RunCommand(ctx, cmdId, ids.Remote.RemotePtr, ids.Remote.RemoteState, runPacket)
if callback != nil {
defer callback()
}
if err != nil {
return nil, err
}

View File

@ -49,9 +49,10 @@ const (
var GlobalStore *Store
type Store struct {
Lock *sync.Mutex
Map map[string]*MShellProc // key=remoteid
Log *CircleLog
Lock *sync.Mutex
Map map[string]*MShellProc // key=remoteid
Log *CircleLog
CmdWaitMap map[base.CommandKey][]sstore.UpdatePacket
}
type MShellProc struct {
@ -123,9 +124,10 @@ func (state RemoteRuntimeState) GetDisplayName(rptr *sstore.RemotePtrType) strin
func LoadRemotes(ctx context.Context) error {
GlobalStore = &Store{
Lock: &sync.Mutex{},
Map: make(map[string]*MShellProc),
Log: MakeCircleLog(100),
Lock: &sync.Mutex{},
Map: make(map[string]*MShellProc),
Log: MakeCircleLog(100),
CmdWaitMap: make(map[base.CommandKey][]sstore.UpdatePacket),
}
allRemotes, err := sstore.GetAllRemotes(ctx)
if err != nil {
@ -308,7 +310,7 @@ func (msh *MShellProc) GetRemoteRuntimeState() RemoteRuntimeState {
return state
}
func (msh *MShellProc) NotifyUpdate() {
func (msh *MShellProc) NotifyRemoteUpdate() {
rstate := msh.GetRemoteRuntimeState()
update := &sstore.ModelUpdate{Remotes: []interface{}{rstate}}
sstore.MainBus.SendUpdate("", update)
@ -382,7 +384,7 @@ func (msh *MShellProc) setErrorStatus(err error) {
defer msh.Lock.Unlock()
msh.Status = StatusError
msh.Err = err
go msh.NotifyUpdate()
go msh.NotifyRemoteUpdate()
}
func (msh *MShellProc) getRemoteCopy() sstore.RemoteType {
@ -458,7 +460,7 @@ func (msh *MShellProc) Launch() {
msh.WithLock(func() {
msh.ServerProc = cproc
msh.Status = StatusConnected
go msh.NotifyUpdate()
go msh.NotifyRemoteUpdate()
})
go func() {
exitErr := cproc.Cmd.Wait()
@ -466,7 +468,7 @@ func (msh *MShellProc) Launch() {
msh.WithLock(func() {
if msh.Status == StatusConnected {
msh.Status = StatusDisconnected
go msh.NotifyUpdate()
go msh.NotifyRemoteUpdate()
}
})
logf(&remoteCopy, "remote disconnected exitcode=%d", exitCode)
@ -546,39 +548,40 @@ func makeTermOpts(runPk *packet.RunPacketType) sstore.TermOpts {
return sstore.TermOpts{Rows: int64(runPk.TermOpts.Rows), Cols: int64(runPk.TermOpts.Cols), FlexRows: true, MaxPtySize: DefaultMaxPtySize}
}
func RunCommand(ctx context.Context, cmdId string, remotePtr sstore.RemotePtrType, remoteState *sstore.RemoteState, runPacket *packet.RunPacketType) (*sstore.CmdType, error) {
// returns (cmdtype, allow-updates-callback, err)
func RunCommand(ctx context.Context, cmdId string, remotePtr sstore.RemotePtrType, remoteState *sstore.RemoteState, runPacket *packet.RunPacketType) (*sstore.CmdType, func(), error) {
if remotePtr.OwnerId != "" {
return nil, fmt.Errorf("cannot run command against another user's remote '%s'", remotePtr.MakeFullRemoteRef())
return nil, nil, fmt.Errorf("cannot run command against another user's remote '%s'", remotePtr.MakeFullRemoteRef())
}
msh := GetRemoteById(remotePtr.RemoteId)
if msh == nil {
return nil, fmt.Errorf("no remote id=%s found", remotePtr.RemoteId)
return nil, nil, fmt.Errorf("no remote id=%s found", remotePtr.RemoteId)
}
if !msh.IsConnected() {
return nil, fmt.Errorf("remote '%s' is not connected", remotePtr.RemoteId)
return nil, nil, fmt.Errorf("remote '%s' is not connected", remotePtr.RemoteId)
}
if remoteState == nil {
return nil, fmt.Errorf("no remote state passed to RunCommand")
return nil, nil, fmt.Errorf("no remote state passed to RunCommand")
}
msh.ServerProc.Output.RegisterRpc(runPacket.ReqId)
err := shexec.SendRunPacketAndRunData(ctx, msh.ServerProc.Input, runPacket)
if err != nil {
return nil, fmt.Errorf("sending run packet to remote: %w", err)
return nil, nil, fmt.Errorf("sending run packet to remote: %w", err)
}
rtnPk := msh.ServerProc.Output.WaitForResponse(ctx, runPacket.ReqId)
if rtnPk == nil {
return nil, ctx.Err()
return nil, nil, ctx.Err()
}
startPk, ok := rtnPk.(*packet.CmdStartPacketType)
if !ok {
respPk, ok := rtnPk.(*packet.ResponsePacketType)
if !ok {
return nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk))
return nil, nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk))
}
if respPk.Error != "" {
return nil, errors.New(respPk.Error)
return nil, nil, errors.New(respPk.Error)
}
return nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk))
return nil, nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk))
}
status := sstore.CmdStatusRunning
if runPacket.Detached {
@ -598,10 +601,11 @@ func RunCommand(ctx context.Context, cmdId string, remotePtr sstore.RemotePtrTyp
}
err = sstore.CreateCmdPtyFile(ctx, cmd.SessionId, cmd.CmdId, cmd.TermOpts.MaxPtySize)
if err != nil {
return nil, err
// TODO the cmd is running, so this is a tricky error to handle
return nil, nil, fmt.Errorf("cannot create local ptyout file for running command: %v", err)
}
msh.AddRunningCmd(startPk.CK)
return cmd, nil
return cmd, func() { removeCmdWait(startPk.CK) }, nil
}
func (msh *MShellProc) AddRunningCmd(ck base.CommandKey) {
@ -651,31 +655,6 @@ func makeDataAckPacket(ck base.CommandKey, fdNum int, ackLen int, err error) *pa
return ack
}
func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
update, err := sstore.UpdateCmdDonePk(context.Background(), donePk)
if err != nil {
fmt.Printf("[error] updating cmddone: %v\n", err)
return
}
if update != nil {
// TODO fix timing issue (this update gets to the FE before run-command returns for short lived commands)
go func() {
time.Sleep(10 * time.Millisecond)
sstore.MainBus.SendUpdate(donePk.CK.GetSessionId(), update)
}()
}
return
}
func (msh *MShellProc) handleCmdErrorPacket(errPk *packet.CmdErrorPacketType) {
err := sstore.AppendCmdErrorPk(context.Background(), errPk)
if err != nil {
fmt.Printf("[error] adding cmderr: %v\n", err)
return
}
return
}
func (msh *MShellProc) notifyHangups_nolock() {
for _, ck := range msh.RunningCmds {
cmd, err := sstore.GetCmdById(context.Background(), ck.GetSessionId(), ck.GetCmdId())
@ -688,6 +667,59 @@ func (msh *MShellProc) notifyHangups_nolock() {
msh.RunningCmds = nil
}
func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
update, err := sstore.UpdateCmdDonePk(context.Background(), donePk)
if err != nil {
fmt.Printf("[error] updating cmddone: %v\n", err)
return
}
if update != nil {
// TODO fix timing issue (this update gets to the FE before run-command returns for short lived commands)
go func() {
time.Sleep(10 * time.Millisecond)
sendCmdUpdate(donePk.CK, update)
}()
}
return
}
// TODO notify FE about cmd errors
func (msh *MShellProc) handleCmdErrorPacket(errPk *packet.CmdErrorPacketType) {
err := sstore.AppendCmdErrorPk(context.Background(), errPk)
if err != nil {
fmt.Printf("[error] adding cmderr: %v\n", err)
return
}
return
}
func (msh *MShellProc) handleDataPacket(dataPk *packet.DataPacketType, dataPosMap map[base.CommandKey]int64) {
realData, err := base64.StdEncoding.DecodeString(dataPk.Data64)
if err != nil {
ack := makeDataAckPacket(dataPk.CK, dataPk.FdNum, 0, err)
msh.ServerProc.Input.SendPacket(ack)
return
}
var ack *packet.DataAckPacketType
if len(realData) > 0 {
dataPos := dataPosMap[dataPk.CK]
update, err := sstore.AppendToCmdPtyBlob(context.Background(), dataPk.CK.GetSessionId(), dataPk.CK.GetCmdId(), realData, dataPos)
if err != nil {
ack = makeDataAckPacket(dataPk.CK, dataPk.FdNum, 0, err)
} else {
ack = makeDataAckPacket(dataPk.CK, dataPk.FdNum, len(realData), nil)
}
dataPosMap[dataPk.CK] += int64(len(realData))
if update != nil {
sendCmdUpdate(dataPk.CK, update)
}
}
if ack != nil {
msh.ServerProc.Input.SendPacket(ack)
}
// fmt.Printf("data %s fd=%d len=%d eof=%v err=%v\n", dataPk.CK, dataPk.FdNum, len(realData), dataPk.Eof, dataPk.Error)
}
func (msh *MShellProc) ProcessPackets() {
defer msh.WithLock(func() {
if msh.Status == StatusConnected {
@ -698,33 +730,13 @@ func (msh *MShellProc) ProcessPackets() {
logf(msh.Remote, "calling HUP on cmds %v", err)
}
msh.notifyHangups_nolock()
go msh.NotifyUpdate()
go msh.NotifyRemoteUpdate()
})
dataPosMap := make(map[base.CommandKey]int64)
for pk := range msh.ServerProc.Output.MainCh {
if pk.GetType() == packet.DataPacketStr {
dataPk := pk.(*packet.DataPacketType)
realData, err := base64.StdEncoding.DecodeString(dataPk.Data64)
if err != nil {
ack := makeDataAckPacket(dataPk.CK, dataPk.FdNum, 0, err)
msh.ServerProc.Input.SendPacket(ack)
continue
}
var ack *packet.DataAckPacketType
if len(realData) > 0 {
dataPos := dataPosMap[dataPk.CK]
err = sstore.AppendToCmdPtyBlob(context.Background(), dataPk.CK.GetSessionId(), dataPk.CK.GetCmdId(), realData, dataPos)
if err != nil {
ack = makeDataAckPacket(dataPk.CK, dataPk.FdNum, 0, err)
} else {
ack = makeDataAckPacket(dataPk.CK, dataPk.FdNum, len(realData), nil)
}
dataPosMap[dataPk.CK] += int64(len(realData))
}
if ack != nil {
msh.ServerProc.Input.SendPacket(ack)
}
// fmt.Printf("data %s fd=%d len=%d eof=%v err=%v\n", dataPk.CK, dataPk.FdNum, len(realData), dataPk.Eof, dataPk.Error)
msh.handleDataPacket(dataPk, dataPosMap)
continue
}
if pk.GetType() == packet.DataAckPacketStr {

66
pkg/remote/updatequeue.go Normal file
View File

@ -0,0 +1,66 @@
package remote
import (
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/sh2-server/pkg/sstore"
)
func pushCmdWaitIfRequired(ck base.CommandKey, update sstore.UpdatePacket) bool {
GlobalStore.Lock.Lock()
defer GlobalStore.Lock.Unlock()
updates, ok := GlobalStore.CmdWaitMap[ck]
if !ok {
return false
}
updates = append(updates, update)
GlobalStore.CmdWaitMap[ck] = updates
return true
}
func sendCmdUpdate(ck base.CommandKey, update sstore.UpdatePacket) {
pushed := pushCmdWaitIfRequired(ck, update)
if pushed {
return
}
sstore.MainBus.SendUpdate(ck.GetSessionId(), update)
}
func runCmdWaitUpdates(ck base.CommandKey) {
for {
update := removeFirstCmdWaitUpdate(ck)
if update == nil {
break
}
sstore.MainBus.SendUpdate(ck.GetSessionId(), update)
}
}
func removeFirstCmdWaitUpdate(ck base.CommandKey) sstore.UpdatePacket {
GlobalStore.Lock.Lock()
defer GlobalStore.Lock.Unlock()
updates := GlobalStore.CmdWaitMap[ck]
if len(updates) == 0 {
delete(GlobalStore.CmdWaitMap, ck)
return nil
}
if len(updates) == 1 {
delete(GlobalStore.CmdWaitMap, ck)
return updates[0]
}
update := updates[0]
GlobalStore.CmdWaitMap[ck] = updates[1:]
return update
}
func removeCmdWait(ck base.CommandKey) {
GlobalStore.Lock.Lock()
defer GlobalStore.Lock.Unlock()
updates := GlobalStore.CmdWaitMap[ck]
if len(updates) == 0 {
delete(GlobalStore.CmdWaitMap, ck)
return
}
go runCmdWaitUpdates(ck)
}

View File

@ -21,22 +21,22 @@ func CreateCmdPtyFile(ctx context.Context, sessionId string, cmdId string, maxSi
return f.Close()
}
func AppendToCmdPtyBlob(ctx context.Context, sessionId string, cmdId string, data []byte, pos int64) error {
func AppendToCmdPtyBlob(ctx context.Context, sessionId string, cmdId string, data []byte, pos int64) (*PtyDataUpdate, error) {
if pos < 0 {
return fmt.Errorf("invalid seek pos '%d' in AppendToCmdPtyBlob", pos)
return nil, fmt.Errorf("invalid seek pos '%d' in AppendToCmdPtyBlob", pos)
}
ptyOutFileName, err := scbase.PtyOutFile(sessionId, cmdId)
if err != nil {
return err
return nil, err
}
f, err := cirfile.OpenCirFile(ptyOutFileName)
if err != nil {
return err
return nil, err
}
defer f.Close()
err = f.WriteAt(ctx, data, pos)
if err != nil {
return err
return nil, err
}
data64 := base64.StdEncoding.EncodeToString(data)
update := &PtyDataUpdate{
@ -46,8 +46,7 @@ func AppendToCmdPtyBlob(ctx context.Context, sessionId string, cmdId string, dat
PtyData64: data64,
PtyDataLen: int64(len(data)),
}
MainBus.SendUpdate(sessionId, update)
return nil
return update, nil
}
// returns (offset, data, err)