From 45dfeb69f6c8bfc1750878642148263493ef6700 Mon Sep 17 00:00:00 2001 From: sawka Date: Thu, 7 Jul 2022 21:39:25 -0700 Subject: [PATCH] updates to allow cmd tailing to work with mshell --- cmd/main-server.go | 26 ++++++++++++++++---------- pkg/remote/remote.go | 6 +++++- pkg/scbase/scbase.go | 13 +++++++++++++ pkg/sstore/dbops.go | 25 ++++++------------------- pkg/sstore/fileops.go | 3 +++ pkg/sstore/sstore.go | 13 ++++++------- 6 files changed, 49 insertions(+), 37 deletions(-) diff --git a/cmd/main-server.go b/cmd/main-server.go index 1656ce01f..21cdc332a 100644 --- a/cmd/main-server.go +++ b/cmd/main-server.go @@ -87,7 +87,8 @@ func MakeWSState(clientId string) (*WSState, error) { rtn.ConnectTime = time.Now() rtn.PacketCh = make(chan packet.PacketType, WSStatePacketChSize) chSender := packet.MakeChannelPacketSender(rtn.PacketCh) - rtn.Tailer, err = cmdtail.MakeTailer(chSender) + gen := scbase.ScFileNameGenerator{ScHome: scbase.GetScHomeDir()} + rtn.Tailer, err = cmdtail.MakeTailer(chSender, gen) if err != nil { return nil, err } @@ -198,9 +199,18 @@ func HandleWs(w http.ResponseWriter, r *http.Request) { continue } if pk.GetType() == "getcmd" { - err = state.Tailer.AddWatch(pk.(*packet.GetCmdPacketType)) + getPk := pk.(*packet.GetCmdPacketType) + done, err := state.Tailer.AddWatch(getPk) if err != nil { + // TODO: send responseerror + respPk := packet.MakeErrorResponsePacket(getPk.ReqId, err) fmt.Printf("[error] adding watch to tailer: %v\n", err) + fmt.Printf("%v\n", respPk) + } + if done { + respPk := packet.MakeResponsePacket(getPk.ReqId, true) + fmt.Printf("%v\n", respPk) + // TODO: send response } continue } @@ -455,15 +465,12 @@ func ProcessFeCommandPacket(ctx context.Context, pk *scpacket.FeCommandPacketTyp fmt.Printf("GOT cd RESP: %v\n", resp) return nil, nil } - rtnLine, err := sstore.AddCmdLine(ctx, pk.SessionId, pk.WindowId, pk.UserId) + cmdId := uuid.New().String() + cmd, err := remote.RunCommand(ctx, pk, cmdId) if err != nil { return nil, err } - cmd, err := remote.RunCommand(ctx, pk, rtnLine.CmdId) - if err != nil { - return nil, err - } - err = sstore.InsertCmd(ctx, cmd) + rtnLine, err := sstore.AddCmdLine(ctx, pk.SessionId, pk.WindowId, pk.UserId, cmd) if err != nil { return nil, err } @@ -595,12 +602,11 @@ func main() { fmt.Printf("[error] ensuring local remote: %v\n", err) return } - defaultSession, err := sstore.EnsureDefaultSession(context.Background()) + _, err = sstore.EnsureDefaultSession(context.Background()) if err != nil { fmt.Printf("[error] ensuring default session: %v\n", err) return } - fmt.Printf("session: %v\n", defaultSession) err = remote.LoadRemotes(context.Background()) if err != nil { fmt.Printf("[error] loading remotes: %v\n", err) diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index bb7726cdc..5941233ab 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -223,6 +223,10 @@ func RunCommand(ctx context.Context, pk *scpacket.FeCommandPacketType, cmdId str DonePk: nil, RunOut: nil, } + err = sstore.AppendToCmdPtyBlob(ctx, cmd.SessionId, cmd.CmdId, nil) + if err != nil { + return nil, err + } return cmd, nil } @@ -316,7 +320,7 @@ func (runner *MShellProc) ProcessPackets() { if ack != nil { runner.ServerProc.Input.SendPacket(ack) } - fmt.Printf("data %s fd=%d len=%d eof=%v err=%v\n", dataPk.CK, dataPk.FdNum, len(realData), dataPk.Eof, dataPk.Error) + // fmt.Printf("data %s fd=%d len=%d eof=%v err=%v\n", dataPk.CK, dataPk.FdNum, len(realData), dataPk.Eof, dataPk.Error) continue } if pk.GetType() == packet.CmdDataPacketStr { diff --git a/pkg/scbase/scbase.go b/pkg/scbase/scbase.go index 8578ed328..f30b53092 100644 --- a/pkg/scbase/scbase.go +++ b/pkg/scbase/scbase.go @@ -8,6 +8,7 @@ import ( "path" "sync" + "github.com/scripthaus-dev/mshell/pkg/base" "golang.org/x/sys/unix" ) @@ -109,3 +110,15 @@ func RemotePtyOut(remoteId string) (string, error) { } return fmt.Sprintf("%s/%s.ptyout", rdir, remoteId), nil } + +type ScFileNameGenerator struct { + ScHome string +} + +func (g ScFileNameGenerator) PtyOutFile(ck base.CommandKey) string { + return path.Join(g.ScHome, SessionsDirBaseName, ck.GetSessionId(), ck.GetCmdId()+".ptyout") +} + +func (g ScFileNameGenerator) RunOutFile(ck base.CommandKey) string { + return path.Join(g.ScHome, SessionsDirBaseName, ck.GetSessionId(), ck.GetCmdId()+".runout") +} diff --git a/pkg/sstore/dbops.go b/pkg/sstore/dbops.go index f957a3cd4..23d2b5af6 100644 --- a/pkg/sstore/dbops.go +++ b/pkg/sstore/dbops.go @@ -167,7 +167,7 @@ func InsertSessionWithName(ctx context.Context, sessionName string) error { }) } -func InsertLine(ctx context.Context, line *LineType) error { +func InsertLine(ctx context.Context, line *LineType, cmd *CmdType) error { if line == nil { return fmt.Errorf("line cannot be nil") } @@ -188,27 +188,14 @@ func InsertLine(ctx context.Context, line *LineType) error { query = `INSERT INTO line ( sessionid, windowid, lineid, ts, userid, linetype, text, cmdid) VALUES (:sessionid,:windowid,:lineid,:ts,:userid,:linetype,:text,:cmdid)` tx.NamedExecWrap(query, line) - return nil - }) -} - -func InsertCmd(ctx context.Context, cmd *CmdType) error { - if cmd == nil { - return fmt.Errorf("cmd cannot be nil") - } - return WithTx(ctx, func(tx *TxWrap) error { - var sessionId string - query := `SELECT sessionid FROM session WHERE sessionid = ?` - hasSession := tx.GetWrap(&sessionId, query, cmd.SessionId) - if !hasSession { - return fmt.Errorf("session not found, cannot insert cmd") - } - cmdMap := cmd.ToMap() - query = ` + if cmd != nil { + cmdMap := cmd.ToMap() + query = ` INSERT INTO cmd ( sessionid, cmdid, remoteid, cmdstr, remotestate, termopts, status, startpk, donepk, runout) VALUES (:sessionid,:cmdid,:remoteid,:cmdstr,:remotestate,:termopts,:status,:startpk,:donepk,:runout) ` - tx.NamedExecWrap(query, cmdMap) + tx.NamedExecWrap(query, cmdMap) + } return nil }) } diff --git a/pkg/sstore/fileops.go b/pkg/sstore/fileops.go index 5971a8996..67026ba49 100644 --- a/pkg/sstore/fileops.go +++ b/pkg/sstore/fileops.go @@ -16,6 +16,9 @@ func AppendToCmdPtyBlob(ctx context.Context, sessionId string, cmdId string, dat if err != nil { return err } + if len(data) == 0 { + return nil + } _, err = fd.Write(data) if err != nil { return err diff --git a/pkg/sstore/sstore.go b/pkg/sstore/sstore.go index c867c8e4d..3f50ee08f 100644 --- a/pkg/sstore/sstore.go +++ b/pkg/sstore/sstore.go @@ -11,7 +11,6 @@ import ( "sync" "time" - "github.com/google/uuid" "github.com/jmoiron/sqlx" "github.com/scripthaus-dev/mshell/pkg/base" "github.com/scripthaus-dev/mshell/pkg/packet" @@ -245,14 +244,14 @@ func CmdFromMap(m map[string]interface{}) *CmdType { return &cmd } -func makeNewLineCmd(sessionId string, windowId string, userId string) *LineType { +func makeNewLineCmd(sessionId string, windowId string, userId string, cmdId string) *LineType { rtn := &LineType{} rtn.SessionId = sessionId rtn.WindowId = windowId rtn.Ts = time.Now().UnixMilli() rtn.UserId = userId rtn.LineType = LineTypeCmd - rtn.CmdId = uuid.New().String() + rtn.CmdId = cmdId return rtn } @@ -269,16 +268,16 @@ func makeNewLineText(sessionId string, windowId string, userId string, text stri func AddCommentLine(ctx context.Context, sessionId string, windowId string, userId string, commentText string) (*LineType, error) { rtnLine := makeNewLineText(sessionId, windowId, userId, commentText) - err := InsertLine(ctx, rtnLine) + err := InsertLine(ctx, rtnLine, nil) if err != nil { return nil, err } return rtnLine, nil } -func AddCmdLine(ctx context.Context, sessionId string, windowId string, userId string) (*LineType, error) { - rtnLine := makeNewLineCmd(sessionId, windowId, userId) - err := InsertLine(ctx, rtnLine) +func AddCmdLine(ctx context.Context, sessionId string, windowId string, userId string, cmd *CmdType) (*LineType, error) { + rtnLine := makeNewLineCmd(sessionId, windowId, userId, cmd.CmdId) + err := InsertLine(ctx, rtnLine, cmd) if err != nil { return nil, err }