checkpoint, getting closer to running a command via mshell server

This commit is contained in:
sawka 2022-07-06 19:01:00 -07:00
parent 2755be315d
commit 98e46399be
7 changed files with 182 additions and 124 deletions

View File

@ -429,16 +429,25 @@ func ProcessFeCommandPacket(ctx context.Context, pk *scpacket.FeCommandPacketTyp
cdPacket.ReqId = uuid.New().String() cdPacket.ReqId = uuid.New().String()
cdPacket.Dir = newDir cdPacket.Dir = newDir
localRemote := remote.GetRemoteById(pk.RemoteState.RemoteId) localRemote := remote.GetRemoteById(pk.RemoteState.RemoteId)
if localRemote != nil { if localRemote == nil {
localRemote.Input.SendPacket(cdPacket) return nil, fmt.Errorf("invalid remote, cannot execute command")
} }
resp, err := localRemote.PacketRpc(ctx, cdPacket)
if err != nil {
return nil, err
}
fmt.Printf("GOT cd RESP: %v\n", resp)
return nil, nil return nil, nil
} }
rtnLine, err := sstore.AddCmdLine(ctx, pk.SessionId, pk.WindowId, pk.UserId) rtnLine, err := sstore.AddCmdLine(ctx, pk.SessionId, pk.WindowId, pk.UserId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = remote.RunCommand(pk, rtnLine.CmdId) startPk, err := remote.RunCommand(ctx, pk, rtnLine.CmdId)
if err != nil {
return nil, err
}
fmt.Printf("START CMD: %s\n", packet.AsString(startPk))
return &runCommandResponse{Line: rtnLine}, nil return &runCommandResponse{Line: rtnLine}, nil
} }
@ -572,6 +581,9 @@ func main() {
fmt.Printf("[error] loading remotes: %v\n", err) fmt.Printf("[error] loading remotes: %v\n", err)
return return
} }
sstore.AppendToCmdPtyBlob(context.Background(), "", "", nil)
go runWebSocketServer() go runWebSocketServer()
gr := mux.NewRouter() gr := mux.NewRouter()
gr.HandleFunc("/api/ptyout", HandleGetPtyOut) gr.HandleFunc("/api/ptyout", HandleGetPtyOut)

View File

@ -49,8 +49,7 @@ CREATE TABLE remote (
sshuser varchar(100) NOT NULL, sshuser varchar(100) NOT NULL,
-- runtime data -- runtime data
lastconnectts bigint NOT NULL, lastconnectts bigint NOT NULL
ptyout BLOB NOT NULL
); );
CREATE TABLE session_cmd ( CREATE TABLE session_cmd (
@ -65,8 +64,6 @@ CREATE TABLE session_cmd (
runnerpid int NOT NULL, runnerpid int NOT NULL,
donets bigint NOT NULL, donets bigint NOT NULL,
exitcode int NOT NULL, exitcode int NOT NULL,
ptyout BLOB NOT NULL,
runout BLOB NOT NULL,
PRIMARY KEY (sessionid, cmdid) PRIMARY KEY (sessionid, cmdid)
); );

View File

@ -2,14 +2,13 @@ package remote
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io"
"os"
"os/exec" "os/exec"
"strings" "strings"
"sync" "sync"
"time"
"github.com/google/uuid"
"github.com/scripthaus-dev/mshell/pkg/base" "github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/packet" "github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/mshell/pkg/shexec" "github.com/scripthaus-dev/mshell/pkg/shexec"
@ -18,10 +17,12 @@ import (
) )
const RemoteTypeMShell = "mshell" const RemoteTypeMShell = "mshell"
const DefaultTermRows = 25
const DefaultTermCols = 80
const DefaultTerm = "xterm-256color"
const ( const (
StatusInit = "init" StatusInit = "init"
StatusConnecting = "connecting"
StatusConnected = "connected" StatusConnected = "connected"
StatusDisconnected = "disconnected" StatusDisconnected = "disconnected"
StatusError = "error" StatusError = "error"
@ -47,20 +48,9 @@ type MShellProc struct {
Remote *sstore.RemoteType Remote *sstore.RemoteType
// runtime // runtime
Status string Status string
InitPk *packet.InitPacketType ServerProc *shexec.ClientProc
Cmd *exec.Cmd Err error
Input *packet.PacketSender
Output *packet.PacketParser
DoneCh chan bool
RpcMap map[string]*RpcEntry
Err error
}
type RpcEntry struct {
ReqId string
RespCh chan packet.RpcResponsePacketType
} }
func LoadRemotes(ctx context.Context) error { func LoadRemotes(ctx context.Context) error {
@ -111,8 +101,8 @@ func GetAllRemoteState() []RemoteState {
RemoteName: proc.Remote.RemoteName, RemoteName: proc.Remote.RemoteName,
Status: proc.Status, Status: proc.Status,
} }
if proc.InitPk != nil { if proc.ServerProc != nil && proc.ServerProc.InitPk != nil {
state.DefaultState = &sstore.RemoteState{Cwd: proc.InitPk.HomeDir} state.DefaultState = &sstore.RemoteState{Cwd: proc.ServerProc.InitPk.HomeDir}
} }
rtn = append(rtn, state) rtn = append(rtn, state)
} }
@ -135,50 +125,24 @@ func (msh *MShellProc) Launch() {
return return
} }
ecmd := exec.Command(msPath, "--server") ecmd := exec.Command(msPath, "--server")
msh.Cmd = ecmd cproc, err := shexec.MakeClientProc(ecmd)
inputWriter, err := ecmd.StdinPipe()
if err != nil { if err != nil {
msh.Status = StatusError msh.Status = StatusError
msh.Err = fmt.Errorf("create stdin pipe: %w", err) msh.Err = err
return
}
stdoutReader, err := ecmd.StdoutPipe()
if err != nil {
msh.Status = StatusError
msh.Err = fmt.Errorf("create stdout pipe: %w", err)
return
}
stderrReader, err := ecmd.StderrPipe()
if err != nil {
msh.Status = StatusError
msh.Err = fmt.Errorf("create stderr pipe: %w", err)
return return
} }
msh.ServerProc = cproc
fmt.Printf("START MAKECLIENTPROC: %#v\n", msh.ServerProc.InitPk)
msh.Status = StatusConnected
go func() { go func() {
io.Copy(os.Stderr, stderrReader) exitErr := cproc.Cmd.Wait()
}()
err = ecmd.Start()
if err != nil {
msh.Status = StatusError
msh.Err = fmt.Errorf("starting mshell server: %w", err)
return
}
fmt.Printf("Started remote '%s' pid=%d\n", msh.Remote.RemoteName, msh.Cmd.Process.Pid)
msh.Status = StatusConnecting
msh.Output = packet.MakePacketParser(stdoutReader)
msh.Input = packet.MakePacketSender(inputWriter)
msh.RpcMap = make(map[string]*RpcEntry)
msh.DoneCh = make(chan bool)
go func() {
exitErr := ecmd.Wait()
exitCode := shexec.GetExitCode(exitErr) exitCode := shexec.GetExitCode(exitErr)
msh.WithLock(func() { msh.WithLock(func() {
if msh.Status == StatusConnected || msh.Status == StatusConnecting { if msh.Status == StatusConnected {
msh.Status = StatusDisconnected msh.Status = StatusDisconnected
} }
}) })
fmt.Printf("[error] RUNNER PROC EXITED code[%d]\n", exitCode) fmt.Printf("[error] RUNNER PROC EXITED code[%d]\n", exitCode)
close(msh.DoneCh)
}() }()
go msh.ProcessPackets() go msh.ProcessPackets()
return return
@ -190,51 +154,62 @@ func (msh *MShellProc) IsConnected() bool {
return msh.Status == StatusConnected return msh.Status == StatusConnected
} }
func RunCommand(pk *scpacket.FeCommandPacketType, cmdId string) error { func RunCommand(ctx context.Context, pk *scpacket.FeCommandPacketType, cmdId string) (*packet.CmdStartPacketType, error) {
msh := GetRemoteById(pk.RemoteState.RemoteId) msh := GetRemoteById(pk.RemoteState.RemoteId)
if msh == nil { if msh == nil {
return fmt.Errorf("no remote id=%s found", pk.RemoteState.RemoteId) return nil, fmt.Errorf("no remote id=%s found", pk.RemoteState.RemoteId)
} }
if !msh.IsConnected() { if !msh.IsConnected() {
return fmt.Errorf("remote '%s' is not connected", msh.Remote.RemoteName) return nil, fmt.Errorf("remote '%s' is not connected", msh.Remote.RemoteName)
} }
runPacket := packet.MakeRunPacket() runPacket := packet.MakeRunPacket()
runPacket.ReqId = uuid.New().String()
runPacket.CK = base.MakeCommandKey(pk.SessionId, cmdId) runPacket.CK = base.MakeCommandKey(pk.SessionId, cmdId)
runPacket.Cwd = pk.RemoteState.Cwd runPacket.Cwd = pk.RemoteState.Cwd
runPacket.Env = nil runPacket.Env = nil
runPacket.UsePty = true
runPacket.TermOpts = &packet.TermOpts{Rows: DefaultTermRows, Cols: DefaultTermCols, Term: DefaultTerm}
runPacket.Command = strings.TrimSpace(pk.CmdStr) runPacket.Command = strings.TrimSpace(pk.CmdStr)
fmt.Printf("run-packet %v\n", runPacket) fmt.Printf("RUN-CMD> %s\n", runPacket.CK)
go func() { msh.ServerProc.Output.RegisterRpc(runPacket.ReqId)
msh.Input.SendPacket(runPacket) err := shexec.SendRunPacketAndRunData(ctx, msh.ServerProc.Input, runPacket)
}() if err != nil {
return nil return nil, fmt.Errorf("sending run packet to remote: %w", err)
}
rtnPk := msh.ServerProc.Output.WaitForResponse(ctx, runPacket.ReqId)
if startPk, ok := rtnPk.(*packet.CmdStartPacketType); ok {
return startPk, nil
}
if respPk, ok := rtnPk.(*packet.ResponsePacketType); ok {
if respPk.Error != "" {
return nil, errors.New(respPk.Error)
}
}
return nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk))
} }
func (runner *MShellProc) PacketRpc(pk packet.RpcPacketType, timeout time.Duration) (packet.RpcResponsePacketType, error) { func (msh *MShellProc) PacketRpc(ctx context.Context, pk packet.RpcPacketType) (*packet.ResponsePacketType, error) {
if !runner.IsConnected() { if !msh.IsConnected() {
return nil, fmt.Errorf("runner is not connected") return nil, fmt.Errorf("runner is not connected")
} }
if pk == nil { if pk == nil {
return nil, fmt.Errorf("PacketRpc passed nil packet") return nil, fmt.Errorf("PacketRpc passed nil packet")
} }
id := pk.GetReqId() reqId := pk.GetReqId()
respCh := make(chan packet.RpcResponsePacketType) msh.ServerProc.Output.RegisterRpc(reqId)
runner.WithLock(func() { defer msh.ServerProc.Output.UnRegisterRpc(reqId)
runner.RpcMap[id] = &RpcEntry{ReqId: id, RespCh: respCh} err := msh.ServerProc.Input.SendPacketCtx(ctx, pk)
}) if err != nil {
defer runner.WithLock(func() { return nil, err
delete(runner.RpcMap, id)
})
runner.Input.SendPacket(pk)
timer := time.NewTimer(timeout)
defer timer.Stop()
select {
case rtnPk := <-respCh:
return rtnPk, nil
case <-timer.C:
return nil, fmt.Errorf("PacketRpc timeout")
} }
rtnPk := msh.ServerProc.Output.WaitForResponse(ctx, reqId)
if rtnPk == nil {
return nil, ctx.Err()
}
if respPk, ok := rtnPk.(*packet.ResponsePacketType); ok {
return respPk, nil
}
return nil, fmt.Errorf("invalid response packet received: %s", packet.AsString(rtnPk))
} }
func (runner *MShellProc) WithLock(fn func()) { func (runner *MShellProc) WithLock(fn func()) {
@ -245,38 +220,25 @@ func (runner *MShellProc) WithLock(fn func()) {
func (runner *MShellProc) ProcessPackets() { func (runner *MShellProc) ProcessPackets() {
defer runner.WithLock(func() { defer runner.WithLock(func() {
if runner.Status == StatusConnected || runner.Status == StatusConnecting { if runner.Status == StatusConnected {
runner.Status = StatusDisconnected runner.Status = StatusDisconnected
} }
}) })
for pk := range runner.Output.MainCh { for pk := range runner.ServerProc.Output.MainCh {
fmt.Printf("MSH> %s\n", packet.AsString(pk)) fmt.Printf("MSH> %s | %#v\n", packet.AsString(pk), pk)
if rpcPk, ok := pk.(packet.RpcResponsePacketType); ok { if pk.GetType() == packet.DataPacketStr {
rpcId := rpcPk.GetResponseId() dataPacket := pk.(*packet.DataPacketType)
runner.WithLock(func() { fmt.Printf("data %s fd=%d len=%d eof=%v err=%v\n", dataPacket.CK, dataPacket.FdNum, packet.B64DecodedLen(dataPacket.Data64), dataPacket.Eof, dataPacket.Error)
entry := runner.RpcMap[rpcId] continue
if entry == nil {
return
}
delete(runner.RpcMap, rpcId)
go func() {
entry.RespCh <- rpcPk
close(entry.RespCh)
}()
})
} }
if pk.GetType() == packet.CmdDataPacketStr { if pk.GetType() == packet.CmdDataPacketStr {
dataPacket := pk.(*packet.CmdDataPacketType) dataPacket := pk.(*packet.CmdDataPacketType)
fmt.Printf("cmd-data %s pty=%d run=%d\n", dataPacket.CK, len(dataPacket.PtyData), len(dataPacket.RunData)) fmt.Printf("cmd-data %s pty=%d run=%d\n", dataPacket.CK, dataPacket.PtyDataLen, dataPacket.RunDataLen)
continue continue
} }
if pk.GetType() == packet.InitPacketStr { if pk.GetType() == packet.CmdDonePacketStr {
initPacket := pk.(*packet.InitPacketType) donePacket := pk.(*packet.CmdDonePacketType)
fmt.Printf("runner-init %s user=%s dir=%s\n", initPacket.MShellHomeDir, initPacket.User, initPacket.HomeDir) fmt.Printf("cmd-done %s\n", donePacket.CK)
runner.WithLock(func() {
runner.InitPk = initPacket
runner.Status = StatusConnected
})
continue continue
} }
if pk.GetType() == packet.MessagePacketStr { if pk.GetType() == packet.MessagePacketStr {

View File

@ -1,12 +1,21 @@
package scbase package scbase
import ( import (
"errors"
"fmt"
"io/fs"
"os" "os"
"path" "path"
"sync"
) )
const HomeVarName = "HOME" const HomeVarName = "HOME"
const ScHomeVarName = "SCRIPTHAUS_HOME" const ScHomeVarName = "SCRIPTHAUS_HOME"
const SessionsDirBaseName = "sessions"
const RemotesDirBaseName = "remotes"
var SessionDirCache = make(map[string]string)
var BaseLock = &sync.Mutex{}
func GetScHomeDir() string { func GetScHomeDir() string {
scHome := os.Getenv(ScHomeVarName) scHome := os.Getenv(ScHomeVarName)
@ -19,3 +28,66 @@ func GetScHomeDir() string {
} }
return scHome return scHome
} }
func EnsureSessionDir(sessionId string) (string, error) {
BaseLock.Lock()
sdir, ok := SessionDirCache[sessionId]
BaseLock.Unlock()
if ok {
return sdir, nil
}
scHome := GetScHomeDir()
sdir = path.Join(scHome, SessionsDirBaseName, sessionId)
err := ensureDir(sdir)
if err != nil {
return "", err
}
BaseLock.Lock()
SessionDirCache[sessionId] = sdir
BaseLock.Unlock()
return sdir, nil
}
func ensureDir(dirName string) error {
info, err := os.Stat(dirName)
if errors.Is(err, fs.ErrNotExist) {
err = os.MkdirAll(dirName, 0700)
if err != nil {
return err
}
info, err = os.Stat(dirName)
}
if err != nil {
return err
}
if !info.IsDir() {
return fmt.Errorf("'%s' must be a directory", dirName)
}
return nil
}
func PtyOutFile(sessionId string, cmdId string) (string, error) {
sdir, err := EnsureSessionDir(sessionId)
if err != nil {
return "", err
}
return fmt.Sprintf("%s/%s.ptyout", sdir, cmdId), nil
}
func RunOutFile(sessionId string, cmdId string) (string, error) {
sdir, err := EnsureSessionDir(sessionId)
if err != nil {
return "", err
}
return fmt.Sprintf("%s/%s.runout", sdir, cmdId), nil
}
func RemotePtyOut(remoteId string) (string, error) {
scHome := GetScHomeDir()
rdir := path.Join(scHome, RemotesDirBaseName)
err := ensureDir(rdir)
if err != nil {
return "", err
}
return fmt.Sprintf("%s/%s.ptyout", rdir, remoteId), nil
}

View File

@ -22,7 +22,7 @@ func NumSessions(ctx context.Context) (int, error) {
return count, nil return count, nil
} }
const remoteSelectCols = "rowid, remoteid, remotetype, remotename, autoconnect, sshhost, sshopts, sshidentity, sshuser, lastconnectts" const remoteSelectCols = "remoteid, remotetype, remotename, autoconnect, sshhost, sshopts, sshidentity, sshuser, lastconnectts"
func GetAllRemotes(ctx context.Context) ([]*RemoteType, error) { func GetAllRemotes(ctx context.Context) ([]*RemoteType, error) {
db, err := GetDB() db, err := GetDB()
@ -76,23 +76,16 @@ func InsertRemote(ctx context.Context, remote *RemoteType) error {
if remote == nil { if remote == nil {
return fmt.Errorf("cannot insert nil remote") return fmt.Errorf("cannot insert nil remote")
} }
if remote.RowId != 0 {
return fmt.Errorf("cannot insert a remote that already has rowid set, rowid=%d", remote.RowId)
}
db, err := GetDB() db, err := GetDB()
if err != nil { if err != nil {
return err return err
} }
query := `INSERT INTO remote ( remoteid, remotetype, remotename, autoconnect, sshhost, sshopts, sshidentity, sshuser, lastconnectts, ptyout) VALUES query := `INSERT INTO remote ( remoteid, remotetype, remotename, autoconnect, sshhost, sshopts, sshidentity, sshuser, lastconnectts) VALUES
(:remoteid,:remotetype,:remotename,:autoconnect,:sshhost,:sshopts,:sshidentity,:sshuser, 0 , '')` (:remoteid,:remotetype,:remotename,:autoconnect,:sshhost,:sshopts,:sshidentity,:sshuser, 0 )`
result, err := db.NamedExec(query, remote) _, err = db.NamedExec(query, remote)
if err != nil { if err != nil {
return err return err
} }
remote.RowId, err = result.LastInsertId()
if err != nil {
return fmt.Errorf("cannot get lastinsertid from insert remote: %w", err)
}
return nil return nil
} }

24
pkg/sstore/fileops.go Normal file
View File

@ -0,0 +1,24 @@
package sstore
import (
"context"
"os"
"github.com/scripthaus-dev/sh2-server/pkg/scbase"
)
func AppendToCmdPtyBlob(ctx context.Context, sessionId string, cmdId string, data []byte) error {
ptyOutFileName, err := scbase.PtyOutFile(sessionId, cmdId)
if err != nil {
return err
}
fd, err := os.OpenFile(ptyOutFileName, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600)
if err != nil {
return err
}
_, err = fd.Write(data)
if err != nil {
return err
}
return nil
}

View File

@ -111,7 +111,6 @@ type LineType struct {
} }
type RemoteType struct { type RemoteType struct {
RowId int64 `json:"rowid"`
RemoteId string `json:"remoteid"` RemoteId string `json:"remoteid"`
RemoteType string `json:"remotetype"` RemoteType string `json:"remotetype"`
RemoteName string `json:"remotename"` RemoteName string `json:"remotename"`
@ -128,7 +127,6 @@ type RemoteType struct {
} }
type CmdType struct { type CmdType struct {
RowId int64 `json:"rowid"`
SessionId string `json:"sessionid"` SessionId string `json:"sessionid"`
CmdId string `json:"cmdid"` CmdId string `json:"cmdid"`
RSId string `json:"rsid"` RSId string `json:"rsid"`