diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index 6583710a7..eb151f8a8 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log" "os" "os/exec" "path" @@ -1401,10 +1402,38 @@ func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) { } } } + if donePk.FinalStateDiff != nil { + fmt.Printf("** final state diff! %v\n", donePk.FinalStateDiff) + } sstore.MainBus.SendUpdate(donePk.CK.GetSessionId(), update) return } +func (msh *MShellProc) handleCmdFinalPacket(finalPk *packet.CmdFinalPacketType) { + defer msh.RemoveRunningCmd(finalPk.CK) + rtnCmd, err := sstore.GetCmdById(context.Background(), finalPk.CK.GetSessionId(), finalPk.CK.GetCmdId()) + if err != nil { + log.Printf("error calling GetCmdById in handleCmdFinalPacket: %v\n", err) + return + } + if rtnCmd == nil || rtnCmd.DonePk != nil { + return + } + log.Printf("finalpk %s (hangup): %s\n", finalPk.CK, finalPk.Error) + sstore.HangupCmd(context.Background(), finalPk.CK) + rtnCmd, err = sstore.GetCmdById(context.Background(), finalPk.CK.GetSessionId(), finalPk.CK.GetCmdId()) + if err != nil { + log.Printf("error getting cmd(2) in handleCmdFinalPacket: %v\n", err) + return + } + if rtnCmd == nil { + log.Printf("error getting cmd(2) in handleCmdFinalPacket (not found)\n") + return + } + update := &sstore.ModelUpdate{Cmd: rtnCmd} + sstore.MainBus.SendUpdate(finalPk.CK.GetSessionId(), update) +} + // TODO notify FE about cmd errors func (msh *MShellProc) handleCmdErrorPacket(errPk *packet.CmdErrorPacketType) { err := sstore.AppendCmdErrorPk(context.Background(), errPk) @@ -1454,6 +1483,12 @@ func (msh *MShellProc) makeHandleCmdDonePacketClosure(donePk *packet.CmdDonePack } } +func (msh *MShellProc) makeHandleCmdFinalPacketClosure(finalPk *packet.CmdFinalPacketType) func() { + return func() { + msh.handleCmdFinalPacket(finalPk) + } +} + func (msh *MShellProc) ProcessPackets() { defer msh.WithLock(func() { if msh.Status == StatusConnected { @@ -1488,6 +1523,11 @@ func (msh *MShellProc) ProcessPackets() { runCmdUpdateFn(donePk.CK, msh.makeHandleCmdDonePacketClosure(donePk)) continue } + if pk.GetType() == packet.CmdFinalPacketStr { + finalPk := pk.(*packet.CmdFinalPacketType) + runCmdUpdateFn(finalPk.CK, msh.makeHandleCmdFinalPacketClosure(finalPk)) + continue + } if pk.GetType() == packet.CmdErrorPacketStr { msh.handleCmdErrorPacket(pk.(*packet.CmdErrorPacketType)) continue diff --git a/pkg/sstore/dbops.go b/pkg/sstore/dbops.go index 2958d9156..1544af636 100644 --- a/pkg/sstore/dbops.go +++ b/pkg/sstore/dbops.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/google/uuid" + "github.com/scripthaus-dev/mshell/pkg/base" "github.com/scripthaus-dev/mshell/pkg/packet" "github.com/scripthaus-dev/sh2-server/pkg/scbase" ) @@ -681,6 +682,18 @@ func GetCmdById(ctx context.Context, sessionId string, cmdId string) (*CmdType, return cmd, nil } +func HasDonePk(ctx context.Context, ck base.CommandKey) (bool, error) { + var found bool + txErr := WithTx(ctx, func(tx *TxWrap) error { + found = tx.Exists(`SELECT sessionid FROM cmd WHERE sessionid = ? AND cmdid = ? AND donepk is NOT NULL`, ck.GetSessionId(), ck.GetCmdId()) + return nil + }) + if txErr != nil { + return false, txErr + } + return found, nil +} + func UpdateCmdDonePk(ctx context.Context, donePk *packet.CmdDonePacketType) (*ModelUpdate, error) { if donePk == nil || donePk.CK.IsEmpty() { return nil, fmt.Errorf("invalid cmddone packet (no ck)") @@ -732,6 +745,14 @@ func HangupRunningCmdsByRemoteId(ctx context.Context, remoteId string) error { }) } +func HangupCmd(ctx context.Context, ck base.CommandKey) error { + return WithTx(ctx, func(tx *TxWrap) error { + query := `UPDATE cmd SET status = ? WHERE sessionid = ? AND cmdid = ?` + tx.ExecWrap(query, CmdStatusHangup, ck.GetSessionId(), ck.GetCmdId()) + return nil + }) +} + func getNextId(ids []string, delId string) string { if len(ids) == 0 { return "" diff --git a/pkg/utilfn/binpack.go b/pkg/utilfn/binpack.go deleted file mode 100644 index 94cc12a85..000000000 --- a/pkg/utilfn/binpack.go +++ /dev/null @@ -1,34 +0,0 @@ -package utilfn - -import ( - "bufio" - "encoding/binary" - "io" -) - -func PackValue(w io.Writer, barr []byte) error { - viBuf := make([]byte, binary.MaxVarintLen64) - viLen := binary.PutUvarint(viBuf, uint64(len(barr))) - _, err := w.Write(viBuf[0:viLen]) - if err != nil { - return err - } - _, err = w.Write(barr) - if err != nil { - return err - } - return nil -} - -func UnpackValue(r *bufio.Reader) ([]byte, error) { - lenVal, err := binary.ReadUvarint(r) - if err != nil { - return nil, err - } - rtnBuf := make([]byte, int(lenVal)) - _, err = io.ReadFull(r, rtnBuf) - if err != nil { - return nil, err - } - return rtnBuf, nil -}