mirror of
https://github.com/wavetermdev/waveterm.git
synced 2024-12-31 18:18:02 +01:00
only allow one instance of sh2 to run at a time (flock). HUP all running processes when sh2 starts or remote connection ends
This commit is contained in:
parent
e4bf4b4ef8
commit
4cc55c46ca
@ -572,6 +572,12 @@ func runWebSocketServer() {
|
||||
}
|
||||
|
||||
func main() {
|
||||
scLock, err := scbase.AcquireSCLock()
|
||||
if err != nil || scLock == nil {
|
||||
fmt.Printf("[error] cannot acquire sh2 lock: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(os.Args) >= 2 && strings.HasPrefix(os.Args[1], "--migrate") {
|
||||
err := sstore.MigrateCommandOpts(os.Args[1:])
|
||||
if err != nil {
|
||||
@ -579,7 +585,7 @@ func main() {
|
||||
}
|
||||
return
|
||||
}
|
||||
err := sstore.TryMigrateUp()
|
||||
err = sstore.TryMigrateUp()
|
||||
if err != nil {
|
||||
fmt.Printf("[error] %v\n", err)
|
||||
return
|
||||
@ -601,7 +607,10 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
sstore.AppendToCmdPtyBlob(context.Background(), "", "", nil)
|
||||
err = sstore.HangupAllRunningCmds(context.Background())
|
||||
if err != nil {
|
||||
fmt.Printf("[error] calling HUP on all running commands\n")
|
||||
}
|
||||
|
||||
go runWebSocketServer()
|
||||
gr := mux.NewRouter()
|
||||
|
@ -152,6 +152,7 @@ func (msh *MShellProc) Launch() {
|
||||
msh.Status = StatusDisconnected
|
||||
}
|
||||
})
|
||||
|
||||
fmt.Printf("[error] RUNNER PROC EXITED code[%d]\n", exitCode)
|
||||
}()
|
||||
go msh.ProcessPackets()
|
||||
@ -206,6 +207,10 @@ func RunCommand(ctx context.Context, pk *scpacket.FeCommandPacketType, cmdId str
|
||||
}
|
||||
return nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk))
|
||||
}
|
||||
status := sstore.CmdStatusRunning
|
||||
if runPacket.Detached {
|
||||
status = sstore.CmdStatusDetached
|
||||
}
|
||||
cmd := &sstore.CmdType{
|
||||
SessionId: pk.SessionId,
|
||||
CmdId: startPk.CK.GetCmdId(),
|
||||
@ -213,7 +218,7 @@ func RunCommand(ctx context.Context, pk *scpacket.FeCommandPacketType, cmdId str
|
||||
RemoteId: msh.Remote.RemoteId,
|
||||
RemoteState: convertRemoteState(pk.RemoteState),
|
||||
TermOpts: makeTermOpts(),
|
||||
Status: "running",
|
||||
Status: status,
|
||||
StartPk: startPk,
|
||||
DonePk: nil,
|
||||
RunOut: nil,
|
||||
@ -262,11 +267,33 @@ func makeDataAckPacket(ck base.CommandKey, fdNum int, ackLen int, err error) *pa
|
||||
return ack
|
||||
}
|
||||
|
||||
func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
|
||||
err := sstore.UpdateCmdDonePk(context.Background(), donePk)
|
||||
if err != nil {
|
||||
fmt.Printf("[error] updating cmddone: %v\n", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (msh *MShellProc) handleCmdErrorPacket(errPk *packet.CmdErrorPacketType) {
|
||||
err := sstore.AppendCmdErrorPk(context.Background(), errPk)
|
||||
if err != nil {
|
||||
fmt.Printf("[error] adding cmderr: %v\n", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (runner *MShellProc) ProcessPackets() {
|
||||
defer runner.WithLock(func() {
|
||||
if runner.Status == StatusConnected {
|
||||
runner.Status = StatusDisconnected
|
||||
}
|
||||
err := sstore.HangupRunningCmdsByRemoteId(context.Background(), runner.Remote.RemoteId)
|
||||
if err != nil {
|
||||
fmt.Printf("[error] calling HUP on remoteid=%d cmds\n", runner.Remote.RemoteId)
|
||||
}
|
||||
})
|
||||
for pk := range runner.ServerProc.Output.MainCh {
|
||||
if pk.GetType() == packet.DataPacketStr {
|
||||
@ -298,8 +325,11 @@ func (runner *MShellProc) ProcessPackets() {
|
||||
continue
|
||||
}
|
||||
if pk.GetType() == packet.CmdDonePacketStr {
|
||||
donePacket := pk.(*packet.CmdDonePacketType)
|
||||
fmt.Printf("cmd-done %s\n", donePacket.CK)
|
||||
runner.handleCmdDonePacket(pk.(*packet.CmdDonePacketType))
|
||||
continue
|
||||
}
|
||||
if pk.GetType() == packet.CmdErrorPacketStr {
|
||||
runner.handleCmdErrorPacket(pk.(*packet.CmdErrorPacketType))
|
||||
continue
|
||||
}
|
||||
if pk.GetType() == packet.MessagePacketStr {
|
||||
|
@ -7,12 +7,15 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const HomeVarName = "HOME"
|
||||
const ScHomeVarName = "SCRIPTHAUS_HOME"
|
||||
const SessionsDirBaseName = "sessions"
|
||||
const RemotesDirBaseName = "remotes"
|
||||
const SCLockFile = "sh2.lock"
|
||||
|
||||
var SessionDirCache = make(map[string]string)
|
||||
var BaseLock = &sync.Mutex{}
|
||||
@ -29,6 +32,21 @@ func GetScHomeDir() string {
|
||||
return scHome
|
||||
}
|
||||
|
||||
func AcquireSCLock() (*os.File, error) {
|
||||
homeDir := GetScHomeDir()
|
||||
lockFileName := path.Join(homeDir, SCLockFile)
|
||||
fd, err := os.Create(lockFileName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = unix.Flock(int(fd.Fd()), unix.LOCK_EX|unix.LOCK_NB)
|
||||
if err != nil {
|
||||
fd.Close()
|
||||
return nil, err
|
||||
}
|
||||
return fd, nil
|
||||
}
|
||||
|
||||
func EnsureSessionDir(sessionId string) (string, error) {
|
||||
BaseLock.Lock()
|
||||
sdir, ok := SessionDirCache[sessionId]
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||
)
|
||||
|
||||
func NumSessions(ctx context.Context) (int, error) {
|
||||
@ -225,3 +226,41 @@ func GetCmdById(ctx context.Context, sessionId string, cmdId string) (*CmdType,
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func UpdateCmdDonePk(ctx context.Context, donePk *packet.CmdDonePacketType) error {
|
||||
if donePk == nil || donePk.CK.IsEmpty() {
|
||||
return fmt.Errorf("invalid cmddone packet (no ck)")
|
||||
}
|
||||
return WithTx(ctx, func(tx *TxWrap) error {
|
||||
query := `UPDATE cmd SET status = ?, donepk = ? WHERE sessionid = ? AND cmdid = ?`
|
||||
tx.ExecWrap(query, CmdStatusDone, quickJson(donePk), donePk.CK.GetSessionId(), donePk.CK.GetCmdId())
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func AppendCmdErrorPk(ctx context.Context, errPk *packet.CmdErrorPacketType) error {
|
||||
if errPk == nil || errPk.CK.IsEmpty() {
|
||||
return fmt.Errorf("invalid cmderror packet (no ck)")
|
||||
}
|
||||
return WithTx(ctx, func(tx *TxWrap) error {
|
||||
query := `UPDATE cmd SET runout = json_insert(runout, '$[#]', ?) WHERE sessionid = ? AND cmdid = ?`
|
||||
tx.ExecWrap(query, quickJson(errPk), errPk.CK.GetSessionId(), errPk.CK.GetCmdId())
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func HangupAllRunningCmds(ctx context.Context) error {
|
||||
return WithTx(ctx, func(tx *TxWrap) error {
|
||||
query := `UPDATE cmd SET status = ? WHERE status = ?`
|
||||
tx.ExecWrap(query, CmdStatusHangup, CmdStatusRunning)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func HangupRunningCmdsByRemoteId(ctx context.Context, remoteId string) error {
|
||||
return WithTx(ctx, func(tx *TxWrap) error {
|
||||
query := `UPDATE cmd SET status = ? WHERE status = ? AND remoteid = ?`
|
||||
tx.ExecWrap(query, CmdStatusHangup, CmdStatusRunning, remoteId)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
@ -33,6 +33,12 @@ const LocalRemoteName = "local"
|
||||
|
||||
const DefaultCwd = "~"
|
||||
|
||||
const CmdStatusRunning = "running"
|
||||
const CmdStatusDetached = "detached"
|
||||
const CmdStatusError = "error"
|
||||
const CmdStatusDone = "done"
|
||||
const CmdStatusHangup = "hangup"
|
||||
|
||||
var globalDBLock = &sync.Mutex{}
|
||||
var globalDB *sqlx.DB
|
||||
var globalDBErr error
|
||||
|
Loading…
Reference in New Issue
Block a user