only allow one instance of sh2 to run at a time (flock). HUP all running processes when sh2 starts or remote connection ends

This commit is contained in:
sawka 2022-07-07 16:29:14 -07:00
parent e4bf4b4ef8
commit 4cc55c46ca
5 changed files with 107 additions and 5 deletions

View File

@ -572,6 +572,12 @@ func runWebSocketServer() {
}
func main() {
scLock, err := scbase.AcquireSCLock()
if err != nil || scLock == nil {
fmt.Printf("[error] cannot acquire sh2 lock: %v\n", err)
return
}
if len(os.Args) >= 2 && strings.HasPrefix(os.Args[1], "--migrate") {
err := sstore.MigrateCommandOpts(os.Args[1:])
if err != nil {
@ -579,7 +585,7 @@ func main() {
}
return
}
err := sstore.TryMigrateUp()
err = sstore.TryMigrateUp()
if err != nil {
fmt.Printf("[error] %v\n", err)
return
@ -601,7 +607,10 @@ func main() {
return
}
sstore.AppendToCmdPtyBlob(context.Background(), "", "", nil)
err = sstore.HangupAllRunningCmds(context.Background())
if err != nil {
fmt.Printf("[error] calling HUP on all running commands\n")
}
go runWebSocketServer()
gr := mux.NewRouter()

View File

@ -152,6 +152,7 @@ func (msh *MShellProc) Launch() {
msh.Status = StatusDisconnected
}
})
fmt.Printf("[error] RUNNER PROC EXITED code[%d]\n", exitCode)
}()
go msh.ProcessPackets()
@ -206,6 +207,10 @@ func RunCommand(ctx context.Context, pk *scpacket.FeCommandPacketType, cmdId str
}
return nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk))
}
status := sstore.CmdStatusRunning
if runPacket.Detached {
status = sstore.CmdStatusDetached
}
cmd := &sstore.CmdType{
SessionId: pk.SessionId,
CmdId: startPk.CK.GetCmdId(),
@ -213,7 +218,7 @@ func RunCommand(ctx context.Context, pk *scpacket.FeCommandPacketType, cmdId str
RemoteId: msh.Remote.RemoteId,
RemoteState: convertRemoteState(pk.RemoteState),
TermOpts: makeTermOpts(),
Status: "running",
Status: status,
StartPk: startPk,
DonePk: nil,
RunOut: nil,
@ -262,11 +267,33 @@ func makeDataAckPacket(ck base.CommandKey, fdNum int, ackLen int, err error) *pa
return ack
}
func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
err := sstore.UpdateCmdDonePk(context.Background(), donePk)
if err != nil {
fmt.Printf("[error] updating cmddone: %v\n", err)
return
}
return
}
func (msh *MShellProc) handleCmdErrorPacket(errPk *packet.CmdErrorPacketType) {
err := sstore.AppendCmdErrorPk(context.Background(), errPk)
if err != nil {
fmt.Printf("[error] adding cmderr: %v\n", err)
return
}
return
}
func (runner *MShellProc) ProcessPackets() {
defer runner.WithLock(func() {
if runner.Status == StatusConnected {
runner.Status = StatusDisconnected
}
err := sstore.HangupRunningCmdsByRemoteId(context.Background(), runner.Remote.RemoteId)
if err != nil {
fmt.Printf("[error] calling HUP on remoteid=%d cmds\n", runner.Remote.RemoteId)
}
})
for pk := range runner.ServerProc.Output.MainCh {
if pk.GetType() == packet.DataPacketStr {
@ -298,8 +325,11 @@ func (runner *MShellProc) ProcessPackets() {
continue
}
if pk.GetType() == packet.CmdDonePacketStr {
donePacket := pk.(*packet.CmdDonePacketType)
fmt.Printf("cmd-done %s\n", donePacket.CK)
runner.handleCmdDonePacket(pk.(*packet.CmdDonePacketType))
continue
}
if pk.GetType() == packet.CmdErrorPacketStr {
runner.handleCmdErrorPacket(pk.(*packet.CmdErrorPacketType))
continue
}
if pk.GetType() == packet.MessagePacketStr {

View File

@ -7,12 +7,15 @@ import (
"os"
"path"
"sync"
"golang.org/x/sys/unix"
)
const HomeVarName = "HOME"
const ScHomeVarName = "SCRIPTHAUS_HOME"
const SessionsDirBaseName = "sessions"
const RemotesDirBaseName = "remotes"
const SCLockFile = "sh2.lock"
var SessionDirCache = make(map[string]string)
var BaseLock = &sync.Mutex{}
@ -29,6 +32,21 @@ func GetScHomeDir() string {
return scHome
}
func AcquireSCLock() (*os.File, error) {
homeDir := GetScHomeDir()
lockFileName := path.Join(homeDir, SCLockFile)
fd, err := os.Create(lockFileName)
if err != nil {
return nil, err
}
err = unix.Flock(int(fd.Fd()), unix.LOCK_EX|unix.LOCK_NB)
if err != nil {
fd.Close()
return nil, err
}
return fd, nil
}
func EnsureSessionDir(sessionId string) (string, error) {
BaseLock.Lock()
sdir, ok := SessionDirCache[sessionId]

View File

@ -6,6 +6,7 @@ import (
"fmt"
"github.com/google/uuid"
"github.com/scripthaus-dev/mshell/pkg/packet"
)
func NumSessions(ctx context.Context) (int, error) {
@ -225,3 +226,41 @@ func GetCmdById(ctx context.Context, sessionId string, cmdId string) (*CmdType,
}
return cmd, nil
}
func UpdateCmdDonePk(ctx context.Context, donePk *packet.CmdDonePacketType) error {
if donePk == nil || donePk.CK.IsEmpty() {
return fmt.Errorf("invalid cmddone packet (no ck)")
}
return WithTx(ctx, func(tx *TxWrap) error {
query := `UPDATE cmd SET status = ?, donepk = ? WHERE sessionid = ? AND cmdid = ?`
tx.ExecWrap(query, CmdStatusDone, quickJson(donePk), donePk.CK.GetSessionId(), donePk.CK.GetCmdId())
return nil
})
}
func AppendCmdErrorPk(ctx context.Context, errPk *packet.CmdErrorPacketType) error {
if errPk == nil || errPk.CK.IsEmpty() {
return fmt.Errorf("invalid cmderror packet (no ck)")
}
return WithTx(ctx, func(tx *TxWrap) error {
query := `UPDATE cmd SET runout = json_insert(runout, '$[#]', ?) WHERE sessionid = ? AND cmdid = ?`
tx.ExecWrap(query, quickJson(errPk), errPk.CK.GetSessionId(), errPk.CK.GetCmdId())
return nil
})
}
func HangupAllRunningCmds(ctx context.Context) error {
return WithTx(ctx, func(tx *TxWrap) error {
query := `UPDATE cmd SET status = ? WHERE status = ?`
tx.ExecWrap(query, CmdStatusHangup, CmdStatusRunning)
return nil
})
}
func HangupRunningCmdsByRemoteId(ctx context.Context, remoteId string) error {
return WithTx(ctx, func(tx *TxWrap) error {
query := `UPDATE cmd SET status = ? WHERE status = ? AND remoteid = ?`
tx.ExecWrap(query, CmdStatusHangup, CmdStatusRunning, remoteId)
return nil
})
}

View File

@ -33,6 +33,12 @@ const LocalRemoteName = "local"
const DefaultCwd = "~"
const CmdStatusRunning = "running"
const CmdStatusDetached = "detached"
const CmdStatusError = "error"
const CmdStatusDone = "done"
const CmdStatusHangup = "hangup"
var globalDBLock = &sync.Mutex{}
var globalDB *sqlx.DB
var globalDBErr error