diff --git a/cmd/main-server.go b/cmd/main-server.go index 5d2316223..1656ce01f 100644 --- a/cmd/main-server.go +++ b/cmd/main-server.go @@ -572,6 +572,12 @@ func runWebSocketServer() { } func main() { + scLock, err := scbase.AcquireSCLock() + if err != nil || scLock == nil { + fmt.Printf("[error] cannot acquire sh2 lock: %v\n", err) + return + } + if len(os.Args) >= 2 && strings.HasPrefix(os.Args[1], "--migrate") { err := sstore.MigrateCommandOpts(os.Args[1:]) if err != nil { @@ -579,7 +585,7 @@ func main() { } return } - err := sstore.TryMigrateUp() + err = sstore.TryMigrateUp() if err != nil { fmt.Printf("[error] %v\n", err) return @@ -601,7 +607,10 @@ func main() { return } - sstore.AppendToCmdPtyBlob(context.Background(), "", "", nil) + err = sstore.HangupAllRunningCmds(context.Background()) + if err != nil { + fmt.Printf("[error] calling HUP on all running commands\n") + } go runWebSocketServer() gr := mux.NewRouter() diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index b954b01f9..bb7726cdc 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -152,6 +152,7 @@ func (msh *MShellProc) Launch() { msh.Status = StatusDisconnected } }) + fmt.Printf("[error] RUNNER PROC EXITED code[%d]\n", exitCode) }() go msh.ProcessPackets() @@ -206,6 +207,10 @@ func RunCommand(ctx context.Context, pk *scpacket.FeCommandPacketType, cmdId str } return nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk)) } + status := sstore.CmdStatusRunning + if runPacket.Detached { + status = sstore.CmdStatusDetached + } cmd := &sstore.CmdType{ SessionId: pk.SessionId, CmdId: startPk.CK.GetCmdId(), @@ -213,7 +218,7 @@ func RunCommand(ctx context.Context, pk *scpacket.FeCommandPacketType, cmdId str RemoteId: msh.Remote.RemoteId, RemoteState: convertRemoteState(pk.RemoteState), TermOpts: makeTermOpts(), - Status: "running", + Status: status, StartPk: startPk, DonePk: nil, RunOut: nil, @@ -262,11 +267,33 @@ func makeDataAckPacket(ck base.CommandKey, fdNum int, ackLen int, err error) *pa return ack } +func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) { + err := sstore.UpdateCmdDonePk(context.Background(), donePk) + if err != nil { + fmt.Printf("[error] updating cmddone: %v\n", err) + return + } + return +} + +func (msh *MShellProc) handleCmdErrorPacket(errPk *packet.CmdErrorPacketType) { + err := sstore.AppendCmdErrorPk(context.Background(), errPk) + if err != nil { + fmt.Printf("[error] adding cmderr: %v\n", err) + return + } + return +} + func (runner *MShellProc) ProcessPackets() { defer runner.WithLock(func() { if runner.Status == StatusConnected { runner.Status = StatusDisconnected } + err := sstore.HangupRunningCmdsByRemoteId(context.Background(), runner.Remote.RemoteId) + if err != nil { + fmt.Printf("[error] calling HUP on remoteid=%d cmds\n", runner.Remote.RemoteId) + } }) for pk := range runner.ServerProc.Output.MainCh { if pk.GetType() == packet.DataPacketStr { @@ -298,8 +325,11 @@ func (runner *MShellProc) ProcessPackets() { continue } if pk.GetType() == packet.CmdDonePacketStr { - donePacket := pk.(*packet.CmdDonePacketType) - fmt.Printf("cmd-done %s\n", donePacket.CK) + runner.handleCmdDonePacket(pk.(*packet.CmdDonePacketType)) + continue + } + if pk.GetType() == packet.CmdErrorPacketStr { + runner.handleCmdErrorPacket(pk.(*packet.CmdErrorPacketType)) continue } if pk.GetType() == packet.MessagePacketStr { diff --git a/pkg/scbase/scbase.go b/pkg/scbase/scbase.go index 27e147062..8578ed328 100644 --- a/pkg/scbase/scbase.go +++ b/pkg/scbase/scbase.go @@ -7,12 +7,15 @@ import ( "os" "path" "sync" + + "golang.org/x/sys/unix" ) const HomeVarName = "HOME" const ScHomeVarName = "SCRIPTHAUS_HOME" const SessionsDirBaseName = "sessions" const RemotesDirBaseName = "remotes" +const SCLockFile = "sh2.lock" var SessionDirCache = make(map[string]string) var BaseLock = &sync.Mutex{} @@ -29,6 +32,21 @@ func GetScHomeDir() string { return scHome } +func AcquireSCLock() (*os.File, error) { + homeDir := GetScHomeDir() + lockFileName := path.Join(homeDir, SCLockFile) + fd, err := os.Create(lockFileName) + if err != nil { + return nil, err + } + err = unix.Flock(int(fd.Fd()), unix.LOCK_EX|unix.LOCK_NB) + if err != nil { + fd.Close() + return nil, err + } + return fd, nil +} + func EnsureSessionDir(sessionId string) (string, error) { BaseLock.Lock() sdir, ok := SessionDirCache[sessionId] diff --git a/pkg/sstore/dbops.go b/pkg/sstore/dbops.go index 108cf39d1..f957a3cd4 100644 --- a/pkg/sstore/dbops.go +++ b/pkg/sstore/dbops.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/google/uuid" + "github.com/scripthaus-dev/mshell/pkg/packet" ) func NumSessions(ctx context.Context) (int, error) { @@ -225,3 +226,41 @@ func GetCmdById(ctx context.Context, sessionId string, cmdId string) (*CmdType, } return cmd, nil } + +func UpdateCmdDonePk(ctx context.Context, donePk *packet.CmdDonePacketType) error { + if donePk == nil || donePk.CK.IsEmpty() { + return fmt.Errorf("invalid cmddone packet (no ck)") + } + return WithTx(ctx, func(tx *TxWrap) error { + query := `UPDATE cmd SET status = ?, donepk = ? WHERE sessionid = ? AND cmdid = ?` + tx.ExecWrap(query, CmdStatusDone, quickJson(donePk), donePk.CK.GetSessionId(), donePk.CK.GetCmdId()) + return nil + }) +} + +func AppendCmdErrorPk(ctx context.Context, errPk *packet.CmdErrorPacketType) error { + if errPk == nil || errPk.CK.IsEmpty() { + return fmt.Errorf("invalid cmderror packet (no ck)") + } + return WithTx(ctx, func(tx *TxWrap) error { + query := `UPDATE cmd SET runout = json_insert(runout, '$[#]', ?) WHERE sessionid = ? AND cmdid = ?` + tx.ExecWrap(query, quickJson(errPk), errPk.CK.GetSessionId(), errPk.CK.GetCmdId()) + return nil + }) +} + +func HangupAllRunningCmds(ctx context.Context) error { + return WithTx(ctx, func(tx *TxWrap) error { + query := `UPDATE cmd SET status = ? WHERE status = ?` + tx.ExecWrap(query, CmdStatusHangup, CmdStatusRunning) + return nil + }) +} + +func HangupRunningCmdsByRemoteId(ctx context.Context, remoteId string) error { + return WithTx(ctx, func(tx *TxWrap) error { + query := `UPDATE cmd SET status = ? WHERE status = ? AND remoteid = ?` + tx.ExecWrap(query, CmdStatusHangup, CmdStatusRunning, remoteId) + return nil + }) +} diff --git a/pkg/sstore/sstore.go b/pkg/sstore/sstore.go index bca56f4f5..c867c8e4d 100644 --- a/pkg/sstore/sstore.go +++ b/pkg/sstore/sstore.go @@ -33,6 +33,12 @@ const LocalRemoteName = "local" const DefaultCwd = "~" +const CmdStatusRunning = "running" +const CmdStatusDetached = "detached" +const CmdStatusError = "error" +const CmdStatusDone = "done" +const CmdStatusHangup = "hangup" + var globalDBLock = &sync.Mutex{} var globalDB *sqlx.DB var globalDBErr error