stateful commands block other commands while they are running, introduce waiting state

This commit is contained in:
sawka 2022-10-27 22:00:10 -07:00
parent 56259e3f05
commit 2df33621fd
4 changed files with 102 additions and 16 deletions

View File

@ -240,19 +240,18 @@ func RunCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.U
if err != nil {
return nil, fmt.Errorf("/run error: %w", err)
}
cmdId := scbase.GenSCUUID()
cmdStr := firstArg(pk)
isRtnStateCmd := IsReturnStateCommand(cmdStr)
runPacket := packet.MakeRunPacket()
runPacket.ReqId = uuid.New().String()
runPacket.CK = base.MakeCommandKey(ids.SessionId, cmdId)
runPacket.CK = base.MakeCommandKey(ids.SessionId, scbase.GenSCUUID())
runPacket.State = ids.Remote.RemoteState
runPacket.StateComplete = true
runPacket.UsePty = true
runPacket.TermOpts = getUITermOpts(pk.UIContext)
runPacket.Command = strings.TrimSpace(cmdStr)
runPacket.ReturnState = resolveBool(pk.Kwargs["rtnstate"], isRtnStateCmd)
cmd, callback, err := remote.RunCommand(ctx, cmdId, ids.Remote.RemotePtr, ids.Remote.RemoteState, runPacket)
cmd, callback, err := remote.RunCommand(ctx, ids.SessionId, ids.WindowId, ids.Remote.RemotePtr, runPacket)
if callback != nil {
defer callback()
}
@ -1619,7 +1618,7 @@ func LineShowCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sst
if lineId == "" {
return nil, fmt.Errorf("line %q not found", lineArg)
}
line, cmd, err := sstore.GetLineCmd(ctx, ids.SessionId, ids.WindowId, lineId)
line, cmd, err := sstore.GetLineCmdByLineId(ctx, ids.SessionId, ids.WindowId, lineId)
if err != nil {
return nil, fmt.Errorf("error getting line: %v", err)
}

View File

@ -94,7 +94,14 @@ type MShellProc struct {
InstallCancelFn context.CancelFunc
InstallErr error
RunningCmds map[base.CommandKey]bool
RunningCmds map[base.CommandKey]bool
WaitingCmds []RunCmdType
PendingStateCmds map[string]base.CommandKey // key=[remoteinstance name]
}
type RunCmdType struct {
RemotePtr sstore.RemotePtrType
RunPacket *packet.RunPacketType
}
type RemoteRuntimeState struct {
@ -530,12 +537,13 @@ func MakeMShell(r *sstore.RemoteType) *MShellProc {
panic(err) // this should never happen (NewBuffer only returns an error if CirBufSize <= 0)
}
rtn := &MShellProc{
Lock: &sync.Mutex{},
Remote: r,
Status: StatusDisconnected,
PtyBuffer: buf,
InstallStatus: StatusDisconnected,
RunningCmds: make(map[base.CommandKey]bool),
Lock: &sync.Mutex{},
Remote: r,
Status: StatusDisconnected,
PtyBuffer: buf,
InstallStatus: StatusDisconnected,
RunningCmds: make(map[base.CommandKey]bool),
PendingStateCmds: make(map[string]base.CommandKey),
}
rtn.WriteToPtyBuffer("console for remote [%s]\n", r.GetName())
return rtn
@ -1076,11 +1084,40 @@ func makeTermOpts(runPk *packet.RunPacketType) sstore.TermOpts {
return sstore.TermOpts{Rows: int64(runPk.TermOpts.Rows), Cols: int64(runPk.TermOpts.Cols), FlexRows: true, MaxPtySize: DefaultMaxPtySize}
}
// returns (ok, currentPSC)
func (msh *MShellProc) testAndSetPendingStateCmd(name string, newCK *base.CommandKey) (bool, *base.CommandKey) {
msh.Lock.Lock()
defer msh.Lock.Unlock()
ck, found := msh.PendingStateCmds[name]
if found {
return false, &ck
}
if newCK != nil {
msh.PendingStateCmds[name] = *newCK
}
return true, nil
}
func (msh *MShellProc) removePendingStateCmd(name string, ck base.CommandKey) {
msh.Lock.Lock()
defer msh.Lock.Unlock()
existingCK, found := msh.PendingStateCmds[name]
if !found {
return
}
if existingCK == ck {
delete(msh.PendingStateCmds, name)
}
}
// returns (cmdtype, allow-updates-callback, err)
func RunCommand(ctx context.Context, cmdId string, remotePtr sstore.RemotePtrType, remoteState *packet.ShellState, runPacket *packet.RunPacketType) (rtnCmd *sstore.CmdType, rtnCallback func(), rtnErr error) {
func RunCommand(ctx context.Context, sessionId string, windowId string, remotePtr sstore.RemotePtrType, runPacket *packet.RunPacketType) (rtnCmd *sstore.CmdType, rtnCallback func(), rtnErr error) {
if remotePtr.OwnerId != "" {
return nil, nil, fmt.Errorf("cannot run command against another user's remote '%s'", remotePtr.MakeFullRemoteRef())
}
if sessionId != runPacket.CK.GetSessionId() {
return nil, nil, fmt.Errorf("run commands sessionids do not match")
}
msh := GetRemoteById(remotePtr.RemoteId)
if msh == nil {
return nil, nil, fmt.Errorf("no remote id=%s found", remotePtr.RemoteId)
@ -1088,9 +1125,24 @@ func RunCommand(ctx context.Context, cmdId string, remotePtr sstore.RemotePtrTyp
if !msh.IsConnected() {
return nil, nil, fmt.Errorf("remote '%s' is not connected", remotePtr.RemoteId)
}
if remoteState == nil {
if runPacket.State == nil {
return nil, nil, fmt.Errorf("no remote state passed to RunCommand")
}
var newPSC *base.CommandKey
if runPacket.ReturnState {
newPSC = &runPacket.CK
}
ok, existingPSC := msh.testAndSetPendingStateCmd(remotePtr.Name, newPSC)
if !ok {
line, _, err := sstore.GetLineCmdByCmdId(ctx, sessionId, windowId, existingPSC.GetCmdId())
if err != nil {
return nil, nil, fmt.Errorf("cannot run command while a stateful command is still running: %v", err)
}
if line == nil {
return nil, nil, fmt.Errorf("cannot run command while a stateful command is still running %s", *existingPSC, windowId)
}
return nil, nil, fmt.Errorf("cannot run command while a stateful command (linenum=%d) is still running", line.LineNum)
}
callbackFn := func() {
removeCmdWait(runPacket.CK)
}
@ -1098,6 +1150,9 @@ func RunCommand(ctx context.Context, cmdId string, remotePtr sstore.RemotePtrTyp
defer func() {
if rtnErr != nil {
callbackFn()
if newPSC != nil {
msh.removePendingStateCmd(remotePtr.Name, *newPSC)
}
}
}()
msh.ServerProc.Output.RegisterRpc(runPacket.ReqId)
@ -1126,10 +1181,10 @@ func RunCommand(ctx context.Context, cmdId string, remotePtr sstore.RemotePtrTyp
}
cmd := &sstore.CmdType{
SessionId: runPacket.CK.GetSessionId(),
CmdId: startPk.CK.GetCmdId(),
CmdId: runPacket.CK.GetCmdId(),
CmdStr: runPacket.Command,
Remote: remotePtr,
RemoteState: *remoteState,
RemoteState: *runPacket.State,
TermOpts: makeTermOpts(runPacket),
Status: status,
StartPk: startPk,
@ -1156,6 +1211,11 @@ func (msh *MShellProc) RemoveRunningCmd(ck base.CommandKey) {
msh.Lock.Lock()
defer msh.Lock.Unlock()
delete(msh.RunningCmds, ck)
for name, pendingCk := range msh.PendingStateCmds {
if pendingCk == ck {
delete(msh.PendingStateCmds, name)
}
}
}
func (msh *MShellProc) PacketRpcRaw(ctx context.Context, pk packet.RpcPacketType) (packet.RpcResponsePacketType, error) {

View File

@ -577,7 +577,7 @@ func FindLineIdByArg(ctx context.Context, sessionId string, windowId string, lin
return lineId, nil
}
func GetLineCmd(ctx context.Context, sessionId string, windowId string, lineId string) (*LineType, *CmdType, error) {
func GetLineCmdByLineId(ctx context.Context, sessionId string, windowId string, lineId string) (*LineType, *CmdType, error) {
var lineRtn *LineType
var cmdRtn *CmdType
txErr := WithTx(ctx, func(tx *TxWrap) error {
@ -605,6 +605,32 @@ func GetLineCmd(ctx context.Context, sessionId string, windowId string, lineId s
return lineRtn, cmdRtn, nil
}
func GetLineCmdByCmdId(ctx context.Context, sessionId string, windowId string, cmdId string) (*LineType, *CmdType, error) {
var lineRtn *LineType
var cmdRtn *CmdType
txErr := WithTx(ctx, func(tx *TxWrap) error {
query := `SELECT windowid FROM window WHERE sessionid = ? AND windowid = ?`
if !tx.Exists(query, sessionId, windowId) {
return fmt.Errorf("window not found")
}
var lineVal LineType
query = `SELECT * FROM line WHERE sessionid = ? AND windowid = ? AND cmdid = ?`
found := tx.GetWrap(&lineVal, query, sessionId, windowId, cmdId)
if !found {
return nil
}
lineRtn = &lineVal
query = `SELECT * FROM cmd WHERE sessionid = ? AND cmdid = ?`
m := tx.GetMap(query, sessionId, cmdId)
cmdRtn = CmdFromMap(m)
return nil
})
if txErr != nil {
return nil, nil, txErr
}
return lineRtn, cmdRtn, nil
}
func InsertLine(ctx context.Context, line *LineType, cmd *CmdType) error {
if line == nil {
return fmt.Errorf("line cannot be nil")

View File

@ -44,6 +44,7 @@ const (
CmdStatusError = "error"
CmdStatusDone = "done"
CmdStatusHangup = "hangup"
CmdStatusWaiting = "waiting"
)
const (