checkpoint, getting closer to running a command via mshell server

This commit is contained in:
sawka 2022-07-06 19:01:00 -07:00
parent 2755be315d
commit 98e46399be
7 changed files with 182 additions and 124 deletions

View File

@ -429,16 +429,25 @@ func ProcessFeCommandPacket(ctx context.Context, pk *scpacket.FeCommandPacketTyp
cdPacket.ReqId = uuid.New().String()
cdPacket.Dir = newDir
localRemote := remote.GetRemoteById(pk.RemoteState.RemoteId)
if localRemote != nil {
localRemote.Input.SendPacket(cdPacket)
if localRemote == nil {
return nil, fmt.Errorf("invalid remote, cannot execute command")
}
resp, err := localRemote.PacketRpc(ctx, cdPacket)
if err != nil {
return nil, err
}
fmt.Printf("GOT cd RESP: %v\n", resp)
return nil, nil
}
rtnLine, err := sstore.AddCmdLine(ctx, pk.SessionId, pk.WindowId, pk.UserId)
if err != nil {
return nil, err
}
err = remote.RunCommand(pk, rtnLine.CmdId)
startPk, err := remote.RunCommand(ctx, pk, rtnLine.CmdId)
if err != nil {
return nil, err
}
fmt.Printf("START CMD: %s\n", packet.AsString(startPk))
return &runCommandResponse{Line: rtnLine}, nil
}
@ -572,6 +581,9 @@ func main() {
fmt.Printf("[error] loading remotes: %v\n", err)
return
}
sstore.AppendToCmdPtyBlob(context.Background(), "", "", nil)
go runWebSocketServer()
gr := mux.NewRouter()
gr.HandleFunc("/api/ptyout", HandleGetPtyOut)

View File

@ -49,8 +49,7 @@ CREATE TABLE remote (
sshuser varchar(100) NOT NULL,
-- runtime data
lastconnectts bigint NOT NULL,
ptyout BLOB NOT NULL
lastconnectts bigint NOT NULL
);
CREATE TABLE session_cmd (
@ -65,8 +64,6 @@ CREATE TABLE session_cmd (
runnerpid int NOT NULL,
donets bigint NOT NULL,
exitcode int NOT NULL,
ptyout BLOB NOT NULL,
runout BLOB NOT NULL,
PRIMARY KEY (sessionid, cmdid)
);

View File

@ -2,14 +2,13 @@ package remote
import (
"context"
"errors"
"fmt"
"io"
"os"
"os/exec"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/mshell/pkg/shexec"
@ -18,10 +17,12 @@ import (
)
const RemoteTypeMShell = "mshell"
const DefaultTermRows = 25
const DefaultTermCols = 80
const DefaultTerm = "xterm-256color"
const (
StatusInit = "init"
StatusConnecting = "connecting"
StatusConnected = "connected"
StatusDisconnected = "disconnected"
StatusError = "error"
@ -47,20 +48,9 @@ type MShellProc struct {
Remote *sstore.RemoteType
// runtime
Status string
InitPk *packet.InitPacketType
Cmd *exec.Cmd
Input *packet.PacketSender
Output *packet.PacketParser
DoneCh chan bool
RpcMap map[string]*RpcEntry
Err error
}
type RpcEntry struct {
ReqId string
RespCh chan packet.RpcResponsePacketType
Status string
ServerProc *shexec.ClientProc
Err error
}
func LoadRemotes(ctx context.Context) error {
@ -111,8 +101,8 @@ func GetAllRemoteState() []RemoteState {
RemoteName: proc.Remote.RemoteName,
Status: proc.Status,
}
if proc.InitPk != nil {
state.DefaultState = &sstore.RemoteState{Cwd: proc.InitPk.HomeDir}
if proc.ServerProc != nil && proc.ServerProc.InitPk != nil {
state.DefaultState = &sstore.RemoteState{Cwd: proc.ServerProc.InitPk.HomeDir}
}
rtn = append(rtn, state)
}
@ -135,50 +125,24 @@ func (msh *MShellProc) Launch() {
return
}
ecmd := exec.Command(msPath, "--server")
msh.Cmd = ecmd
inputWriter, err := ecmd.StdinPipe()
cproc, err := shexec.MakeClientProc(ecmd)
if err != nil {
msh.Status = StatusError
msh.Err = fmt.Errorf("create stdin pipe: %w", err)
return
}
stdoutReader, err := ecmd.StdoutPipe()
if err != nil {
msh.Status = StatusError
msh.Err = fmt.Errorf("create stdout pipe: %w", err)
return
}
stderrReader, err := ecmd.StderrPipe()
if err != nil {
msh.Status = StatusError
msh.Err = fmt.Errorf("create stderr pipe: %w", err)
msh.Err = err
return
}
msh.ServerProc = cproc
fmt.Printf("START MAKECLIENTPROC: %#v\n", msh.ServerProc.InitPk)
msh.Status = StatusConnected
go func() {
io.Copy(os.Stderr, stderrReader)
}()
err = ecmd.Start()
if err != nil {
msh.Status = StatusError
msh.Err = fmt.Errorf("starting mshell server: %w", err)
return
}
fmt.Printf("Started remote '%s' pid=%d\n", msh.Remote.RemoteName, msh.Cmd.Process.Pid)
msh.Status = StatusConnecting
msh.Output = packet.MakePacketParser(stdoutReader)
msh.Input = packet.MakePacketSender(inputWriter)
msh.RpcMap = make(map[string]*RpcEntry)
msh.DoneCh = make(chan bool)
go func() {
exitErr := ecmd.Wait()
exitErr := cproc.Cmd.Wait()
exitCode := shexec.GetExitCode(exitErr)
msh.WithLock(func() {
if msh.Status == StatusConnected || msh.Status == StatusConnecting {
if msh.Status == StatusConnected {
msh.Status = StatusDisconnected
}
})
fmt.Printf("[error] RUNNER PROC EXITED code[%d]\n", exitCode)
close(msh.DoneCh)
}()
go msh.ProcessPackets()
return
@ -190,51 +154,62 @@ func (msh *MShellProc) IsConnected() bool {
return msh.Status == StatusConnected
}
func RunCommand(pk *scpacket.FeCommandPacketType, cmdId string) error {
func RunCommand(ctx context.Context, pk *scpacket.FeCommandPacketType, cmdId string) (*packet.CmdStartPacketType, error) {
msh := GetRemoteById(pk.RemoteState.RemoteId)
if msh == nil {
return fmt.Errorf("no remote id=%s found", pk.RemoteState.RemoteId)
return nil, fmt.Errorf("no remote id=%s found", pk.RemoteState.RemoteId)
}
if !msh.IsConnected() {
return fmt.Errorf("remote '%s' is not connected", msh.Remote.RemoteName)
return nil, fmt.Errorf("remote '%s' is not connected", msh.Remote.RemoteName)
}
runPacket := packet.MakeRunPacket()
runPacket.ReqId = uuid.New().String()
runPacket.CK = base.MakeCommandKey(pk.SessionId, cmdId)
runPacket.Cwd = pk.RemoteState.Cwd
runPacket.Env = nil
runPacket.UsePty = true
runPacket.TermOpts = &packet.TermOpts{Rows: DefaultTermRows, Cols: DefaultTermCols, Term: DefaultTerm}
runPacket.Command = strings.TrimSpace(pk.CmdStr)
fmt.Printf("run-packet %v\n", runPacket)
go func() {
msh.Input.SendPacket(runPacket)
}()
return nil
fmt.Printf("RUN-CMD> %s\n", runPacket.CK)
msh.ServerProc.Output.RegisterRpc(runPacket.ReqId)
err := shexec.SendRunPacketAndRunData(ctx, msh.ServerProc.Input, runPacket)
if err != nil {
return nil, fmt.Errorf("sending run packet to remote: %w", err)
}
rtnPk := msh.ServerProc.Output.WaitForResponse(ctx, runPacket.ReqId)
if startPk, ok := rtnPk.(*packet.CmdStartPacketType); ok {
return startPk, nil
}
if respPk, ok := rtnPk.(*packet.ResponsePacketType); ok {
if respPk.Error != "" {
return nil, errors.New(respPk.Error)
}
}
return nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk))
}
func (runner *MShellProc) PacketRpc(pk packet.RpcPacketType, timeout time.Duration) (packet.RpcResponsePacketType, error) {
if !runner.IsConnected() {
func (msh *MShellProc) PacketRpc(ctx context.Context, pk packet.RpcPacketType) (*packet.ResponsePacketType, error) {
if !msh.IsConnected() {
return nil, fmt.Errorf("runner is not connected")
}
if pk == nil {
return nil, fmt.Errorf("PacketRpc passed nil packet")
}
id := pk.GetReqId()
respCh := make(chan packet.RpcResponsePacketType)
runner.WithLock(func() {
runner.RpcMap[id] = &RpcEntry{ReqId: id, RespCh: respCh}
})
defer runner.WithLock(func() {
delete(runner.RpcMap, id)
})
runner.Input.SendPacket(pk)
timer := time.NewTimer(timeout)
defer timer.Stop()
select {
case rtnPk := <-respCh:
return rtnPk, nil
case <-timer.C:
return nil, fmt.Errorf("PacketRpc timeout")
reqId := pk.GetReqId()
msh.ServerProc.Output.RegisterRpc(reqId)
defer msh.ServerProc.Output.UnRegisterRpc(reqId)
err := msh.ServerProc.Input.SendPacketCtx(ctx, pk)
if err != nil {
return nil, err
}
rtnPk := msh.ServerProc.Output.WaitForResponse(ctx, reqId)
if rtnPk == nil {
return nil, ctx.Err()
}
if respPk, ok := rtnPk.(*packet.ResponsePacketType); ok {
return respPk, nil
}
return nil, fmt.Errorf("invalid response packet received: %s", packet.AsString(rtnPk))
}
func (runner *MShellProc) WithLock(fn func()) {
@ -245,38 +220,25 @@ func (runner *MShellProc) WithLock(fn func()) {
func (runner *MShellProc) ProcessPackets() {
defer runner.WithLock(func() {
if runner.Status == StatusConnected || runner.Status == StatusConnecting {
if runner.Status == StatusConnected {
runner.Status = StatusDisconnected
}
})
for pk := range runner.Output.MainCh {
fmt.Printf("MSH> %s\n", packet.AsString(pk))
if rpcPk, ok := pk.(packet.RpcResponsePacketType); ok {
rpcId := rpcPk.GetResponseId()
runner.WithLock(func() {
entry := runner.RpcMap[rpcId]
if entry == nil {
return
}
delete(runner.RpcMap, rpcId)
go func() {
entry.RespCh <- rpcPk
close(entry.RespCh)
}()
})
for pk := range runner.ServerProc.Output.MainCh {
fmt.Printf("MSH> %s | %#v\n", packet.AsString(pk), pk)
if pk.GetType() == packet.DataPacketStr {
dataPacket := pk.(*packet.DataPacketType)
fmt.Printf("data %s fd=%d len=%d eof=%v err=%v\n", dataPacket.CK, dataPacket.FdNum, packet.B64DecodedLen(dataPacket.Data64), dataPacket.Eof, dataPacket.Error)
continue
}
if pk.GetType() == packet.CmdDataPacketStr {
dataPacket := pk.(*packet.CmdDataPacketType)
fmt.Printf("cmd-data %s pty=%d run=%d\n", dataPacket.CK, len(dataPacket.PtyData), len(dataPacket.RunData))
fmt.Printf("cmd-data %s pty=%d run=%d\n", dataPacket.CK, dataPacket.PtyDataLen, dataPacket.RunDataLen)
continue
}
if pk.GetType() == packet.InitPacketStr {
initPacket := pk.(*packet.InitPacketType)
fmt.Printf("runner-init %s user=%s dir=%s\n", initPacket.MShellHomeDir, initPacket.User, initPacket.HomeDir)
runner.WithLock(func() {
runner.InitPk = initPacket
runner.Status = StatusConnected
})
if pk.GetType() == packet.CmdDonePacketStr {
donePacket := pk.(*packet.CmdDonePacketType)
fmt.Printf("cmd-done %s\n", donePacket.CK)
continue
}
if pk.GetType() == packet.MessagePacketStr {

View File

@ -1,12 +1,21 @@
package scbase
import (
"errors"
"fmt"
"io/fs"
"os"
"path"
"sync"
)
const HomeVarName = "HOME"
const ScHomeVarName = "SCRIPTHAUS_HOME"
const SessionsDirBaseName = "sessions"
const RemotesDirBaseName = "remotes"
var SessionDirCache = make(map[string]string)
var BaseLock = &sync.Mutex{}
func GetScHomeDir() string {
scHome := os.Getenv(ScHomeVarName)
@ -19,3 +28,66 @@ func GetScHomeDir() string {
}
return scHome
}
func EnsureSessionDir(sessionId string) (string, error) {
BaseLock.Lock()
sdir, ok := SessionDirCache[sessionId]
BaseLock.Unlock()
if ok {
return sdir, nil
}
scHome := GetScHomeDir()
sdir = path.Join(scHome, SessionsDirBaseName, sessionId)
err := ensureDir(sdir)
if err != nil {
return "", err
}
BaseLock.Lock()
SessionDirCache[sessionId] = sdir
BaseLock.Unlock()
return sdir, nil
}
func ensureDir(dirName string) error {
info, err := os.Stat(dirName)
if errors.Is(err, fs.ErrNotExist) {
err = os.MkdirAll(dirName, 0700)
if err != nil {
return err
}
info, err = os.Stat(dirName)
}
if err != nil {
return err
}
if !info.IsDir() {
return fmt.Errorf("'%s' must be a directory", dirName)
}
return nil
}
func PtyOutFile(sessionId string, cmdId string) (string, error) {
sdir, err := EnsureSessionDir(sessionId)
if err != nil {
return "", err
}
return fmt.Sprintf("%s/%s.ptyout", sdir, cmdId), nil
}
func RunOutFile(sessionId string, cmdId string) (string, error) {
sdir, err := EnsureSessionDir(sessionId)
if err != nil {
return "", err
}
return fmt.Sprintf("%s/%s.runout", sdir, cmdId), nil
}
func RemotePtyOut(remoteId string) (string, error) {
scHome := GetScHomeDir()
rdir := path.Join(scHome, RemotesDirBaseName)
err := ensureDir(rdir)
if err != nil {
return "", err
}
return fmt.Sprintf("%s/%s.ptyout", rdir, remoteId), nil
}

View File

@ -22,7 +22,7 @@ func NumSessions(ctx context.Context) (int, error) {
return count, nil
}
const remoteSelectCols = "rowid, remoteid, remotetype, remotename, autoconnect, sshhost, sshopts, sshidentity, sshuser, lastconnectts"
const remoteSelectCols = "remoteid, remotetype, remotename, autoconnect, sshhost, sshopts, sshidentity, sshuser, lastconnectts"
func GetAllRemotes(ctx context.Context) ([]*RemoteType, error) {
db, err := GetDB()
@ -76,23 +76,16 @@ func InsertRemote(ctx context.Context, remote *RemoteType) error {
if remote == nil {
return fmt.Errorf("cannot insert nil remote")
}
if remote.RowId != 0 {
return fmt.Errorf("cannot insert a remote that already has rowid set, rowid=%d", remote.RowId)
}
db, err := GetDB()
if err != nil {
return err
}
query := `INSERT INTO remote ( remoteid, remotetype, remotename, autoconnect, sshhost, sshopts, sshidentity, sshuser, lastconnectts, ptyout) VALUES
(:remoteid,:remotetype,:remotename,:autoconnect,:sshhost,:sshopts,:sshidentity,:sshuser, 0 , '')`
result, err := db.NamedExec(query, remote)
query := `INSERT INTO remote ( remoteid, remotetype, remotename, autoconnect, sshhost, sshopts, sshidentity, sshuser, lastconnectts) VALUES
(:remoteid,:remotetype,:remotename,:autoconnect,:sshhost,:sshopts,:sshidentity,:sshuser, 0 )`
_, err = db.NamedExec(query, remote)
if err != nil {
return err
}
remote.RowId, err = result.LastInsertId()
if err != nil {
return fmt.Errorf("cannot get lastinsertid from insert remote: %w", err)
}
return nil
}

24
pkg/sstore/fileops.go Normal file
View File

@ -0,0 +1,24 @@
package sstore
import (
"context"
"os"
"github.com/scripthaus-dev/sh2-server/pkg/scbase"
)
func AppendToCmdPtyBlob(ctx context.Context, sessionId string, cmdId string, data []byte) error {
ptyOutFileName, err := scbase.PtyOutFile(sessionId, cmdId)
if err != nil {
return err
}
fd, err := os.OpenFile(ptyOutFileName, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600)
if err != nil {
return err
}
_, err = fd.Write(data)
if err != nil {
return err
}
return nil
}

View File

@ -111,7 +111,6 @@ type LineType struct {
}
type RemoteType struct {
RowId int64 `json:"rowid"`
RemoteId string `json:"remoteid"`
RemoteType string `json:"remotetype"`
RemoteName string `json:"remotename"`
@ -128,7 +127,6 @@ type RemoteType struct {
}
type CmdType struct {
RowId int64 `json:"rowid"`
SessionId string `json:"sessionid"`
CmdId string `json:"cmdid"`
RSId string `json:"rsid"`