lots of changes related to ephemeral commands (for sync), checkpoint

This commit is contained in:
sawka 2024-03-12 16:36:10 -07:00
parent fde141fce3
commit a69ac24659
3 changed files with 116 additions and 87 deletions

View File

@ -510,22 +510,21 @@ func SyncCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbus.U
SessionId: ids.SessionId,
ScreenId: ids.ScreenId,
RemotePtr: ids.Remote.RemotePtr,
Ephemeral: true,
}
cmd, callback, err := remote.RunCommand(ctx, rcOpts, runPacket)
_, callback, err := remote.RunCommand(ctx, rcOpts, runPacket)
if callback != nil {
defer callback()
}
if err != nil {
return nil, err
}
cmd.RawCmdStr = pk.GetRawStr()
update, err := addLineForCmd(ctx, "/sync", true, ids, cmd, "terminal", nil)
if err != nil {
return nil, err
}
update.AddUpdate(sstore.InteractiveUpdate(pk.Interactive))
scbus.MainUpdateBus.DoScreenUpdate(ids.ScreenId, update)
return nil, nil
update := scbus.MakeUpdatePacket()
update.AddUpdate(sstore.InfoMsgType{
InfoMsg: "syncing state",
TimeoutMs: 2000,
})
return update, nil
}
func getRendererArg(pk *scpacket.FeCommandPacketType) (string, error) {
@ -1175,7 +1174,8 @@ func deferWriteCmdStatus(ctx context.Context, cmd *sstore.CmdType, startTime tim
donePk.Ts = time.Now().UnixMilli()
donePk.ExitCode = exitCode
donePk.DurationMs = duration.Milliseconds()
update, err := sstore.UpdateCmdDoneInfo(context.Background(), ck, donePk, cmdStatus)
update := scbus.MakeUpdatePacket()
err := sstore.UpdateCmdDoneInfo(context.Background(), update, ck, donePk, cmdStatus)
if err != nil {
// nothing to do
log.Printf("error updating cmddoneinfo (in openai): %v\n", err)
@ -2551,7 +2551,8 @@ func doOpenAICompletion(cmd *sstore.CmdType, opts *sstore.OpenAIOptsType, prompt
donePk.Ts = time.Now().UnixMilli()
donePk.ExitCode = exitCode
donePk.DurationMs = duration.Milliseconds()
update, err := sstore.UpdateCmdDoneInfo(context.Background(), ck, donePk, cmdStatus)
update := scbus.MakeUpdatePacket()
err := sstore.UpdateCmdDoneInfo(context.Background(), update, ck, donePk, cmdStatus)
if err != nil {
// nothing to do
log.Printf("error updating cmddoneinfo (in openai): %v\n", err)
@ -2710,7 +2711,8 @@ func doOpenAIStreamCompletion(cmd *sstore.CmdType, clientId string, opts *sstore
donePk.Ts = time.Now().UnixMilli()
donePk.ExitCode = exitCode
donePk.DurationMs = duration.Milliseconds()
update, err := sstore.UpdateCmdDoneInfo(context.Background(), ck, donePk, cmdStatus)
update := scbus.MakeUpdatePacket()
err := sstore.UpdateCmdDoneInfo(context.Background(), update, ck, donePk, cmdStatus)
if err != nil {
// nothing to do
log.Printf("error updating cmddoneinfo (in openai): %v\n", err)

View File

@ -18,6 +18,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
@ -167,12 +168,13 @@ type MShellProc struct {
}
type RunCmdType struct {
CK base.CommandKey
SessionId string
ScreenId string
RemotePtr sstore.RemotePtrType
RunPacket *packet.RunPacketType
Ephemeral bool
CK base.CommandKey
SessionId string
ScreenId string
RemotePtr sstore.RemotePtrType
RunPacket *packet.RunPacketType
Ephemeral bool
EphCancled atomic.Bool // only for Ephemeral commands, if true, then the command result should be discarded
}
type RemoteRuntimeState = sstore.RemoteRuntimeState
@ -1930,6 +1932,8 @@ func makeTermOpts(runPk *packet.RunPacketType) sstore.TermOpts {
}
// returns (ok, currentPSC)
// if ok is true, currentPSC will be nil
// if ok is false, currentPSC will be the existing pending state command (not nil)
func (msh *MShellProc) testAndSetPendingStateCmd(screenId string, rptr sstore.RemotePtrType, newCK *base.CommandKey) (bool, *base.CommandKey) {
key := pendingStateKey{ScreenId: screenId, RemotePtr: rptr}
msh.Lock.Lock()
@ -2028,14 +2032,14 @@ func RunCommand(ctx context.Context, rcOpts RunCommandOpts, runPacket *packet.Ru
}
ok, existingPSC := msh.testAndSetPendingStateCmd(screenId, remotePtr, newPSC)
if !ok {
line, _, err := sstore.GetLineCmdByLineId(ctx, screenId, existingPSC.GetCmdId())
if err != nil {
return nil, nil, fmt.Errorf("cannot run command while a stateful command is still running: %v", err)
rct := msh.GetRunningCmd(*existingPSC)
if rct.Ephemeral {
// if the existing command is ephemeral, we cancel it and continue
rct.EphCancled.Store(true)
} else {
line, _, err := sstore.GetLineCmdByLineId(ctx, screenId, existingPSC.GetCmdId())
return nil, nil, makePSCLineError(*existingPSC, line, err)
}
if line == nil {
return nil, nil, fmt.Errorf("cannot run command while a stateful command is still running %s", *existingPSC)
}
return nil, nil, fmt.Errorf("cannot run command while a stateful command (linenum=%d) is still running", line.LineNum)
}
if newPSC != nil {
defer func() {
@ -2127,7 +2131,7 @@ func RunCommand(ctx context.Context, rcOpts RunCommandOpts, runPacket *packet.Ru
RunOut: nil,
RtnState: runPacket.ReturnState,
}
if !rcOpts.NoCreateCmdPtyFile {
if !rcOpts.NoCreateCmdPtyFile && !rcOpts.Ephemeral {
err = sstore.CreateCmdPtyFile(ctx, cmd.ScreenId, cmd.LineId, cmd.TermOpts.MaxPtySize)
if err != nil {
// TODO the cmd is running, so this is a tricky error to handle
@ -2146,6 +2150,17 @@ func RunCommand(ctx context.Context, rcOpts RunCommandOpts, runPacket *packet.Ru
return cmd, func() { removeCmdWait(runPacket.CK) }, nil
}
// helper func to construct the proper error given what information we have
func makePSCLineError(existingPSC base.CommandKey, line *sstore.LineType, lineErr error) error {
if lineErr != nil {
return fmt.Errorf("cannot run command while a stateful command is still running: %v", lineErr)
}
if line == nil {
return fmt.Errorf("cannot run command while a stateful command is still running %s", existingPSC)
}
return fmt.Errorf("cannot run command while a stateful command (linenum=%d) is still running", line.LineNum)
}
func (msh *MShellProc) AddRunningCmd(rct *RunCmdType) {
msh.Lock.Lock()
defer msh.Lock.Unlock()
@ -2249,38 +2264,70 @@ func (msh *MShellProc) notifyHangups_nolock() {
msh.PendingStateCmds = make(map[pendingStateKey]base.CommandKey)
}
// either fullstate or statediff will be set (not both) <- this is so the result is compatible with the sstore.UpdateRemoteState function
// note that this function *does* touch the DB, if FinalStateDiff is set, will ensure that StateBase is written to DB
func (msh *MShellProc) makeStatePtrFromFinalState(ctx context.Context, donePk *packet.CmdDonePacketType) (*sstore.ShellStatePtr, map[string]string, *packet.ShellState, *packet.ShellStateDiff, error) {
if donePk.FinalState != nil {
finalState := stripScVarsFromState(donePk.FinalState)
feState := sstore.FeStateFromShellState(finalState)
statePtr := &sstore.ShellStatePtr{BaseHash: finalState.GetHashVal(false)}
return statePtr, feState, finalState, nil, nil
}
if donePk.FinalStateDiff != nil {
stateDiff := stripScVarsFromStateDiff(donePk.FinalStateDiff)
feState, err := msh.getFeStateFromDiff(stateDiff)
if err != nil {
return nil, nil, nil, nil, err
}
fullState := msh.StateMap.GetStateByHash(stateDiff.GetShellType(), stateDiff.BaseHash)
if fullState != nil {
sstore.StoreStateBase(ctx, fullState)
}
diffHashArr := append(([]string)(nil), donePk.FinalStateDiff.DiffHashArr...)
diffHashArr = append(diffHashArr, donePk.FinalStateDiff.GetHashVal(false))
statePtr := &sstore.ShellStatePtr{BaseHash: donePk.FinalStateDiff.BaseHash, DiffHashArr: diffHashArr}
return statePtr, feState, nil, stateDiff, nil
}
return nil, nil, nil, nil, nil
}
func (msh *MShellProc) handleCmdDonePacket(rct *RunCmdType, donePk *packet.CmdDonePacketType) {
if rct == nil {
log.Printf("cmddone packet received, but no running command found for it %q\n", donePk.CK)
return
}
ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFn()
// this will remove from RunningCmds and from PendingStateCmds
defer msh.RemoveRunningCmd(donePk.CK)
if donePk.FinalState != nil {
donePk.FinalState = stripScVarsFromState(donePk.FinalState)
}
if donePk.FinalStateDiff != nil {
donePk.FinalStateDiff = stripScVarsFromStateDiff(donePk.FinalStateDiff)
}
update, err := sstore.UpdateCmdDoneInfo(ctx, donePk.CK, donePk, sstore.CmdStatusDone)
if err != nil {
msh.WriteToPtyBuffer("*error updating cmddone: %v\n", err)
if rct.Ephemeral && rct.EphCancled.Load() {
// do nothing when an ephemeral command is canceled
return
}
screen, err := sstore.UpdateScreenFocusForDoneCmd(ctx, donePk.CK.GetGroupId(), donePk.CK.GetCmdId())
if err != nil {
msh.WriteToPtyBuffer("*error trying to update screen focus type: %v\n", err)
// fall-through (nothing to do)
ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFn()
update := scbus.MakeUpdatePacket()
if !rct.Ephemeral {
// only update DB for non-ephemeral commands
err := sstore.UpdateCmdDoneInfo(ctx, update, donePk.CK, donePk, sstore.CmdStatusDone)
if err != nil {
msh.WriteToPtyBuffer("*error updating cmddone: %v\n", err)
return
}
screen, err := sstore.UpdateScreenFocusForDoneCmd(ctx, donePk.CK.GetGroupId(), donePk.CK.GetCmdId())
if err != nil {
msh.WriteToPtyBuffer("*error trying to update screen focus type: %v\n", err)
// fall-through (nothing to do)
}
if screen != nil {
update.AddUpdate(*screen)
}
}
if screen != nil {
update.AddUpdate(*screen)
}
var statePtr *sstore.ShellStatePtr
if donePk.FinalState != nil {
feState := sstore.FeStateFromShellState(donePk.FinalState)
remoteInst, err := sstore.UpdateRemoteState(ctx, rct.SessionId, rct.ScreenId, rct.RemotePtr, feState, donePk.FinalState, nil)
// ephemeral commands *do* update the remote state
if donePk.FinalState != nil || donePk.FinalStateDiff != nil {
statePtr, feState, finalState, finalStateDiff, err := msh.makeStatePtrFromFinalState(ctx, donePk)
if err != nil {
msh.WriteToPtyBuffer("*error trying to read final command state: %v\n", err)
}
remoteInst, err := sstore.UpdateRemoteState(ctx, rct.SessionId, rct.ScreenId, rct.RemotePtr, feState, finalState, finalStateDiff)
if err != nil {
msh.WriteToPtyBuffer("*error trying to update remotestate: %v\n", err)
// fall-through (nothing to do)
@ -2288,36 +2335,13 @@ func (msh *MShellProc) handleCmdDonePacket(rct *RunCmdType, donePk *packet.CmdDo
if remoteInst != nil {
update.AddUpdate(sstore.MakeSessionUpdateForRemote(rct.SessionId, remoteInst))
}
statePtr = &sstore.ShellStatePtr{BaseHash: donePk.FinalState.GetHashVal(false)}
} else if donePk.FinalStateDiff != nil {
feState, err := msh.getFeStateFromDiff(donePk.FinalStateDiff)
if err != nil {
msh.WriteToPtyBuffer("*error trying to update remotestate: %v\n", err)
// fall-through (nothing to do)
} else {
stateDiff := donePk.FinalStateDiff
fullState := msh.StateMap.GetStateByHash(stateDiff.GetShellType(), stateDiff.BaseHash)
if fullState != nil {
sstore.StoreStateBase(ctx, fullState)
}
remoteInst, err := sstore.UpdateRemoteState(ctx, rct.SessionId, rct.ScreenId, rct.RemotePtr, feState, nil, stateDiff)
// ephemeral commands *do not* update cmd state (there is no command)
if statePtr != nil && !rct.Ephemeral {
err = sstore.UpdateCmdRtnState(ctx, donePk.CK, *statePtr)
if err != nil {
msh.WriteToPtyBuffer("*error trying to update remotestate: %v\n", err)
msh.WriteToPtyBuffer("*error trying to update cmd rtnstate: %v\n", err)
// fall-through (nothing to do)
}
if remoteInst != nil {
update.AddUpdate(sstore.MakeSessionUpdateForRemote(rct.SessionId, remoteInst))
}
diffHashArr := append(([]string)(nil), donePk.FinalStateDiff.DiffHashArr...)
diffHashArr = append(diffHashArr, donePk.FinalStateDiff.GetHashVal(false))
statePtr = &sstore.ShellStatePtr{BaseHash: donePk.FinalStateDiff.BaseHash, DiffHashArr: diffHashArr}
}
}
if statePtr != nil {
err = sstore.UpdateCmdRtnState(ctx, donePk.CK, *statePtr)
if err != nil {
msh.WriteToPtyBuffer("*error trying to update cmd rtnstate: %v\n", err)
// fall-through (nothing to do)
}
}
scbus.MainUpdateBus.DoUpdate(update)
@ -2329,6 +2353,10 @@ func (msh *MShellProc) handleCmdFinalPacket(rct *RunCmdType, finalPk *packet.Cmd
return
}
defer msh.RemoveRunningCmd(finalPk.CK)
if rct.Ephemeral {
// just remove the running command, but there is no DB state to update in this case
return
}
rtnCmd, err := sstore.GetCmdByScreenId(context.Background(), finalPk.CK.GetGroupId(), finalPk.CK.GetCmdId())
if err != nil {
log.Printf("error calling GetCmdById in handleCmdFinalPacket: %v\n", err)
@ -2377,6 +2405,11 @@ func (msh *MShellProc) handleDataPacket(rct *RunCmdType, dataPk *packet.DataPack
msh.ServerProc.Input.SendPacket(ack)
return
}
if rct.Ephemeral {
ack := makeDataAckPacket(dataPk.CK, dataPk.FdNum, len(realData), nil)
msh.ServerProc.Input.SendPacket(ack)
return
}
var ack *packet.DataAckPacketType
if len(realData) > 0 {
dataPos := dataPosMap.Get(dataPk.CK)
@ -2394,7 +2427,6 @@ func (msh *MShellProc) handleDataPacket(rct *RunCmdType, dataPk *packet.DataPack
if ack != nil {
msh.ServerProc.Input.SendPacket(ack)
}
// log.Printf("data %s fd=%d len=%d eof=%v err=%v\n", dataPk.CK, dataPk.FdNum, len(realData), dataPk.Eof, dataPk.Error)
}
func sendScreenUpdates(screens []*sstore.ScreenType) {

View File

@ -916,12 +916,12 @@ func UpdateCmdForRestart(ctx context.Context, ck base.CommandKey, ts int64, cmdP
})
}
func UpdateCmdDoneInfo(ctx context.Context, ck base.CommandKey, donePk *packet.CmdDonePacketType, status string) (*scbus.ModelUpdatePacketType, error) {
func UpdateCmdDoneInfo(ctx context.Context, update *scbus.ModelUpdatePacketType, ck base.CommandKey, donePk *packet.CmdDonePacketType, status string) error {
if donePk == nil {
return nil, fmt.Errorf("invalid cmddone packet")
return fmt.Errorf("invalid cmddone packet")
}
if ck.IsEmpty() {
return nil, fmt.Errorf("cannot update cmddoneinfo, empty ck")
return fmt.Errorf("cannot update cmddoneinfo, empty ck")
}
screenId := ck.GetGroupId()
var rtnCmd *CmdType
@ -944,15 +944,12 @@ func UpdateCmdDoneInfo(ctx context.Context, ck base.CommandKey, donePk *packet.C
return nil
})
if txErr != nil {
return nil, txErr
return txErr
}
if rtnCmd == nil {
return nil, fmt.Errorf("cmd data not found for ck[%s]", ck)
return fmt.Errorf("cmd data not found for ck[%s]", ck)
}
update := scbus.MakeUpdatePacket()
update.AddUpdate(*rtnCmd)
// Update in-memory screen indicator status
var indicator StatusIndicatorLevel
if rtnCmd.ExitCode == 0 {
@ -960,15 +957,13 @@ func UpdateCmdDoneInfo(ctx context.Context, ck base.CommandKey, donePk *packet.C
} else {
indicator = StatusIndicatorLevel_Error
}
err := SetStatusIndicatorLevel_Update(ctx, update, screenId, indicator, false)
if err != nil {
// This is not a fatal error, so just log it
log.Printf("error setting status indicator level after done packet: %v\n", err)
}
IncrementNumRunningCmds_Update(update, screenId, -1)
return update, nil
return nil
}
func UpdateCmdRtnState(ctx context.Context, ck base.CommandKey, statePtr ShellStatePtr) error {