implement cmdfinal (hangup) from server

This commit is contained in:
sawka 2022-11-27 14:12:15 -08:00
parent 301bfaa0be
commit d5ea9e0221
3 changed files with 61 additions and 34 deletions

View File

@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"log"
"os"
"os/exec"
"path"
@ -1401,10 +1402,38 @@ func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
}
}
}
if donePk.FinalStateDiff != nil {
fmt.Printf("** final state diff! %v\n", donePk.FinalStateDiff)
}
sstore.MainBus.SendUpdate(donePk.CK.GetSessionId(), update)
return
}
func (msh *MShellProc) handleCmdFinalPacket(finalPk *packet.CmdFinalPacketType) {
defer msh.RemoveRunningCmd(finalPk.CK)
rtnCmd, err := sstore.GetCmdById(context.Background(), finalPk.CK.GetSessionId(), finalPk.CK.GetCmdId())
if err != nil {
log.Printf("error calling GetCmdById in handleCmdFinalPacket: %v\n", err)
return
}
if rtnCmd == nil || rtnCmd.DonePk != nil {
return
}
log.Printf("finalpk %s (hangup): %s\n", finalPk.CK, finalPk.Error)
sstore.HangupCmd(context.Background(), finalPk.CK)
rtnCmd, err = sstore.GetCmdById(context.Background(), finalPk.CK.GetSessionId(), finalPk.CK.GetCmdId())
if err != nil {
log.Printf("error getting cmd(2) in handleCmdFinalPacket: %v\n", err)
return
}
if rtnCmd == nil {
log.Printf("error getting cmd(2) in handleCmdFinalPacket (not found)\n")
return
}
update := &sstore.ModelUpdate{Cmd: rtnCmd}
sstore.MainBus.SendUpdate(finalPk.CK.GetSessionId(), update)
}
// TODO notify FE about cmd errors
func (msh *MShellProc) handleCmdErrorPacket(errPk *packet.CmdErrorPacketType) {
err := sstore.AppendCmdErrorPk(context.Background(), errPk)
@ -1454,6 +1483,12 @@ func (msh *MShellProc) makeHandleCmdDonePacketClosure(donePk *packet.CmdDonePack
}
}
func (msh *MShellProc) makeHandleCmdFinalPacketClosure(finalPk *packet.CmdFinalPacketType) func() {
return func() {
msh.handleCmdFinalPacket(finalPk)
}
}
func (msh *MShellProc) ProcessPackets() {
defer msh.WithLock(func() {
if msh.Status == StatusConnected {
@ -1488,6 +1523,11 @@ func (msh *MShellProc) ProcessPackets() {
runCmdUpdateFn(donePk.CK, msh.makeHandleCmdDonePacketClosure(donePk))
continue
}
if pk.GetType() == packet.CmdFinalPacketStr {
finalPk := pk.(*packet.CmdFinalPacketType)
runCmdUpdateFn(finalPk.CK, msh.makeHandleCmdFinalPacketClosure(finalPk))
continue
}
if pk.GetType() == packet.CmdErrorPacketStr {
msh.handleCmdErrorPacket(pk.(*packet.CmdErrorPacketType))
continue

View File

@ -7,6 +7,7 @@ import (
"strings"
"github.com/google/uuid"
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/sh2-server/pkg/scbase"
)
@ -681,6 +682,18 @@ func GetCmdById(ctx context.Context, sessionId string, cmdId string) (*CmdType,
return cmd, nil
}
func HasDonePk(ctx context.Context, ck base.CommandKey) (bool, error) {
var found bool
txErr := WithTx(ctx, func(tx *TxWrap) error {
found = tx.Exists(`SELECT sessionid FROM cmd WHERE sessionid = ? AND cmdid = ? AND donepk is NOT NULL`, ck.GetSessionId(), ck.GetCmdId())
return nil
})
if txErr != nil {
return false, txErr
}
return found, nil
}
func UpdateCmdDonePk(ctx context.Context, donePk *packet.CmdDonePacketType) (*ModelUpdate, error) {
if donePk == nil || donePk.CK.IsEmpty() {
return nil, fmt.Errorf("invalid cmddone packet (no ck)")
@ -732,6 +745,14 @@ func HangupRunningCmdsByRemoteId(ctx context.Context, remoteId string) error {
})
}
func HangupCmd(ctx context.Context, ck base.CommandKey) error {
return WithTx(ctx, func(tx *TxWrap) error {
query := `UPDATE cmd SET status = ? WHERE sessionid = ? AND cmdid = ?`
tx.ExecWrap(query, CmdStatusHangup, ck.GetSessionId(), ck.GetCmdId())
return nil
})
}
func getNextId(ids []string, delId string) string {
if len(ids) == 0 {
return ""

View File

@ -1,34 +0,0 @@
package utilfn
import (
"bufio"
"encoding/binary"
"io"
)
func PackValue(w io.Writer, barr []byte) error {
viBuf := make([]byte, binary.MaxVarintLen64)
viLen := binary.PutUvarint(viBuf, uint64(len(barr)))
_, err := w.Write(viBuf[0:viLen])
if err != nil {
return err
}
_, err = w.Write(barr)
if err != nil {
return err
}
return nil
}
func UnpackValue(r *bufio.Reader) ([]byte, error) {
lenVal, err := binary.ReadUvarint(r)
if err != nil {
return nil, err
}
rtnBuf := make([]byte, int(lenVal))
_, err = io.ReadFull(r, rtnBuf)
if err != nil {
return nil, err
}
return rtnBuf, nil
}