From f7666fe480fd2b8855952d357c3662761f47473b Mon Sep 17 00:00:00 2001 From: sawka Date: Thu, 7 Jul 2022 00:10:37 -0700 Subject: [PATCH] checkpoint on storing cmd in db --- cmd/main-server.go | 9 ++- db/migrations/000001_init.down.sql | 2 +- db/migrations/000001_init.up.sql | 13 ++-- pkg/remote/remote.go | 34 +++++++-- pkg/sstore/dbops.go | 43 +++++++++++ pkg/sstore/sstore.go | 110 +++++++++++++++++++++++++---- pkg/sstore/txwrap.go | 22 ++++++ 7 files changed, 204 insertions(+), 29 deletions(-) diff --git a/cmd/main-server.go b/cmd/main-server.go index 14f77cfdc..5d2316223 100644 --- a/cmd/main-server.go +++ b/cmd/main-server.go @@ -459,12 +459,15 @@ func ProcessFeCommandPacket(ctx context.Context, pk *scpacket.FeCommandPacketTyp if err != nil { return nil, err } - startPk, err := remote.RunCommand(ctx, pk, rtnLine.CmdId) + cmd, err := remote.RunCommand(ctx, pk, rtnLine.CmdId) if err != nil { return nil, err } - fmt.Printf("START CMD: %s\n", packet.AsString(startPk)) - return &runCommandResponse{Line: rtnLine}, nil + err = sstore.InsertCmd(ctx, cmd) + if err != nil { + return nil, err + } + return &runCommandResponse{Line: rtnLine, Cmd: cmd}, nil } // /api/start-session diff --git a/db/migrations/000001_init.down.sql b/db/migrations/000001_init.down.sql index 6887fccdc..f523b435e 100644 --- a/db/migrations/000001_init.down.sql +++ b/db/migrations/000001_init.down.sql @@ -3,5 +3,5 @@ DROP TABLE window; DROP TABLE remote_instance; DROP TABLE line; DROP TABLE remote; -DROP TABLE session_cmd; +DROP TABLE cmd; DROP TABLE history; diff --git a/db/migrations/000001_init.up.sql b/db/migrations/000001_init.up.sql index ea0438427..6d54ab367 100644 --- a/db/migrations/000001_init.up.sql +++ b/db/migrations/000001_init.up.sql @@ -52,18 +52,17 @@ CREATE TABLE remote ( lastconnectts bigint NOT NULL ); -CREATE TABLE session_cmd ( +CREATE TABLE cmd ( sessionid varchar(36) NOT NULL, cmdid varchar(36) NOT NULL, - rsid varchar(36) NOT NULL, remoteid varchar(36) NOT NULL, + cmdstr text NOT NULL, remotestate json NOT NULL, + termopts json NOT NULL, status varchar(10) NOT NULL, - startts bigint NOT NULL, - pid int NOT NULL, - runnerpid int NOT NULL, - donets bigint NOT NULL, - exitcode int NOT NULL, + startpk json NOT NULL, + donepk json NOT NULL, + runout json NOT NULL, PRIMARY KEY (sessionid, cmdid) ); diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index 884669ee8..d47fc0e2e 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -154,7 +154,15 @@ func (msh *MShellProc) IsConnected() bool { return msh.Status == StatusConnected } -func RunCommand(ctx context.Context, pk *scpacket.FeCommandPacketType, cmdId string) (*packet.CmdStartPacketType, error) { +func convertRemoteState(rs scpacket.RemoteState) sstore.RemoteState { + return sstore.RemoteState{Cwd: rs.Cwd} +} + +func makeTermOpts() sstore.TermOpts { + return sstore.TermOpts{Rows: DefaultTermRows, Cols: DefaultTermCols, FlexRows: true} +} + +func RunCommand(ctx context.Context, pk *scpacket.FeCommandPacketType, cmdId string) (*sstore.CmdType, error) { msh := GetRemoteById(pk.RemoteState.RemoteId) if msh == nil { return nil, fmt.Errorf("no remote id=%s found", pk.RemoteState.RemoteId) @@ -177,15 +185,29 @@ func RunCommand(ctx context.Context, pk *scpacket.FeCommandPacketType, cmdId str return nil, fmt.Errorf("sending run packet to remote: %w", err) } rtnPk := msh.ServerProc.Output.WaitForResponse(ctx, runPacket.ReqId) - if startPk, ok := rtnPk.(*packet.CmdStartPacketType); ok { - return startPk, nil - } - if respPk, ok := rtnPk.(*packet.ResponsePacketType); ok { + startPk, ok := rtnPk.(*packet.CmdStartPacketType) + if !ok { + respPk, ok := rtnPk.(*packet.ResponsePacketType) + if !ok { + return nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk)) + } if respPk.Error != "" { return nil, errors.New(respPk.Error) } + return nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk)) } - return nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk)) + cmd := &sstore.CmdType{ + SessionId: pk.SessionId, + CmdId: startPk.CK.GetCmdId(), + RemoteId: msh.Remote.RemoteId, + RemoteState: convertRemoteState(pk.RemoteState), + TermOpts: makeTermOpts(), + Status: "running", + StartPk: startPk, + DonePk: nil, + RunOut: nil, + } + return cmd, nil } func (msh *MShellProc) PacketRpc(ctx context.Context, pk packet.RpcPacketType) (*packet.ResponsePacketType, error) { diff --git a/pkg/sstore/dbops.go b/pkg/sstore/dbops.go index 54904d403..d6b07262b 100644 --- a/pkg/sstore/dbops.go +++ b/pkg/sstore/dbops.go @@ -103,6 +103,11 @@ func GetSessionById(ctx context.Context, id string) (*SessionType, error) { tx.SelectWrap(&session.Windows, query, session.SessionId) query = `SELECT * FROM remote_instance WHERE sessionid = ?` tx.SelectWrap(&session.Remotes, query, session.SessionId) + query = `SELECT * FROM cmd WHERE sessionid = ?` + marr := tx.SelectMaps(query, session.SessionId) + for _, m := range marr { + session.Cmds = append(session.Cmds, CmdFromMap(m)) + } return nil }) if err != nil { @@ -191,3 +196,41 @@ func InsertLine(ctx context.Context, line *LineType) error { return nil }) } + +func InsertCmd(ctx context.Context, cmd *CmdType) error { + if cmd == nil { + return fmt.Errorf("cmd cannot be nil") + } + return WithTx(ctx, func(tx *TxWrap) error { + var sessionId string + query := `SELECT sessionid FROM session WHERE sessionid = ?` + hasSession := tx.GetWrap(&sessionId, query, cmd.SessionId) + if !hasSession { + return fmt.Errorf("session not found, cannot insert cmd") + } + cmdMap := cmd.ToMap() + query = ` +INSERT INTO cmd ( sessionid, cmdid, remoteid, cmdstr, remotestate, termopts, status, startpk, donepk, runout) + VALUES (:sessionid,:cmdid,:remoteid,:cmdstr,:remotestate,:termopts,:status,:startpk,:donepk,:runout) +` + tx.NamedExecWrap(query, cmdMap) + return nil + }) +} + +func GetCmd(ctx context.Context, sessionId string, cmdId string) (*CmdType, error) { + db, err := GetDB() + if err != nil { + return nil, err + } + var m map[string]interface{} + query := `SELECT * FROM cmd WHERE sessionid = ? AND cmdid = ?` + err = db.GetContext(ctx, &m, query, sessionId, cmdId) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + return CmdFromMap(m), nil +} diff --git a/pkg/sstore/sstore.go b/pkg/sstore/sstore.go index 4346244e6..63f632524 100644 --- a/pkg/sstore/sstore.go +++ b/pkg/sstore/sstore.go @@ -89,6 +89,30 @@ func (s *RemoteState) Value() (driver.Value, error) { return json.Marshal(s) } +type TermOpts struct { + Rows int `json:"rows"` + Cols int `json:"cols"` + FlexRows bool `json:"flexrows,omitempty"` +} + +func (opts *TermOpts) Scan(val interface{}) error { + if strVal, ok := val.(string); ok { + if strVal == "" { + return nil + } + err := json.Unmarshal([]byte(strVal), opts) + if err != nil { + return err + } + return nil + } + return fmt.Errorf("cannot scan '%T' into TermOpts", val) +} + +func (opts *TermOpts) Value() (driver.Value, error) { + return json.Marshal(opts) +} + type RemoteInstance struct { RIId string `json:"riid"` Name string `json:"name"` @@ -127,19 +151,81 @@ type RemoteType struct { } type CmdType struct { - SessionId string `json:"sessionid"` - CmdId string `json:"cmdid"` - RSId string `json:"rsid"` - RemoteId string `json:"remoteid"` - RemoteState string `json:"remotestate"` - Status string `json:"status"` - StartTs int64 `json:"startts"` - DoneTs int64 `json:"donets"` - Pid int `json:"pid"` - RunnerPid int `json:"runnerpid"` - ExitCode int `json:"exitcode"` + SessionId string `json:"sessionid"` + CmdId string `json:"cmdid"` + RemoteId string `json:"remoteid"` + CmdStr string `json:"cmdstr"` + RemoteState RemoteState `json:"remotestate"` + TermOpts TermOpts `json:"termopts"` + Status string `json:"status"` + StartPk *packet.CmdStartPacketType `json:"startpk"` + DonePk *packet.CmdDonePacketType `json:"donepk"` + RunOut []packet.PacketType `json:"runout"` +} - RunOut packet.PacketType `json:"runout"` +func quickJson(v interface{}) string { + if v == nil { + return "" + } + barr, _ := json.Marshal(v) + return string(barr) +} + +func (cmd *CmdType) ToMap() map[string]interface{} { + rtn := make(map[string]interface{}) + rtn["sessionid"] = cmd.SessionId + rtn["cmdid"] = cmd.CmdId + rtn["remoteid"] = cmd.RemoteId + rtn["cmdstr"] = cmd.CmdStr + rtn["remotestate"] = quickJson(cmd.RemoteState) + rtn["termopts"] = quickJson(cmd.TermOpts) + rtn["status"] = cmd.Status + rtn["startpk"] = quickJson(cmd.StartPk) + rtn["donepk"] = quickJson(cmd.DonePk) + rtn["runout"] = quickJson(cmd.RunOut) + return rtn +} + +func quickSetStr(strVal *string, m map[string]interface{}, name string) { + v, ok := m[name] + if !ok { + return + } + str, ok := v.(string) + if !ok { + return + } + *strVal = str +} + +func quickSetJson(ptr interface{}, m map[string]interface{}, name string) { + v, ok := m[name] + if !ok { + return + } + str, ok := v.(string) + if !ok { + return + } + if str == "" { + return + } + json.Unmarshal([]byte(str), ptr) +} + +func CmdFromMap(m map[string]interface{}) *CmdType { + var cmd CmdType + quickSetStr(&cmd.SessionId, m, "sessionid") + quickSetStr(&cmd.CmdId, m, "cmdid") + quickSetStr(&cmd.RemoteId, m, "remoteid") + quickSetStr(&cmd.CmdStr, m, "cmdstr") + quickSetJson(&cmd.RemoteState, m, "remotestate") + quickSetJson(&cmd.TermOpts, m, "termopts") + quickSetStr(&cmd.Status, m, "status") + quickSetJson(&cmd.StartPk, m, "startpk") + quickSetJson(&cmd.DonePk, m, "donepk") + quickSetJson(&cmd.RunOut, m, "runout") + return &cmd } func makeNewLineCmd(sessionId string, windowId string, userId string) *LineType { diff --git a/pkg/sstore/txwrap.go b/pkg/sstore/txwrap.go index b58b8442d..2bac4030f 100644 --- a/pkg/sstore/txwrap.go +++ b/pkg/sstore/txwrap.go @@ -90,3 +90,25 @@ func (tx *TxWrap) SelectWrap(dest interface{}, query string, args ...interface{} } return } + +func (tx *TxWrap) SelectMaps(query string, args ...interface{}) []map[string]interface{} { + if tx.Err != nil { + return nil + } + rows, err := tx.Txx.Queryx(query, args...) + if err != nil { + tx.Err = err + return nil + } + var rtn []map[string]interface{} + for rows.Next() { + m := make(map[string]interface{}) + err = rows.MapScan(m) + if err != nil { + tx.Err = err + return nil + } + rtn = append(rtn, m) + } + return rtn +}