mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-02-22 02:41:23 +01:00
checkpoint, getting closer to running a command via mshell server
This commit is contained in:
parent
2755be315d
commit
98e46399be
@ -429,16 +429,25 @@ func ProcessFeCommandPacket(ctx context.Context, pk *scpacket.FeCommandPacketTyp
|
||||
cdPacket.ReqId = uuid.New().String()
|
||||
cdPacket.Dir = newDir
|
||||
localRemote := remote.GetRemoteById(pk.RemoteState.RemoteId)
|
||||
if localRemote != nil {
|
||||
localRemote.Input.SendPacket(cdPacket)
|
||||
if localRemote == nil {
|
||||
return nil, fmt.Errorf("invalid remote, cannot execute command")
|
||||
}
|
||||
resp, err := localRemote.PacketRpc(ctx, cdPacket)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Printf("GOT cd RESP: %v\n", resp)
|
||||
return nil, nil
|
||||
}
|
||||
rtnLine, err := sstore.AddCmdLine(ctx, pk.SessionId, pk.WindowId, pk.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = remote.RunCommand(pk, rtnLine.CmdId)
|
||||
startPk, err := remote.RunCommand(ctx, pk, rtnLine.CmdId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Printf("START CMD: %s\n", packet.AsString(startPk))
|
||||
return &runCommandResponse{Line: rtnLine}, nil
|
||||
}
|
||||
|
||||
@ -572,6 +581,9 @@ func main() {
|
||||
fmt.Printf("[error] loading remotes: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
sstore.AppendToCmdPtyBlob(context.Background(), "", "", nil)
|
||||
|
||||
go runWebSocketServer()
|
||||
gr := mux.NewRouter()
|
||||
gr.HandleFunc("/api/ptyout", HandleGetPtyOut)
|
||||
|
@ -49,8 +49,7 @@ CREATE TABLE remote (
|
||||
sshuser varchar(100) NOT NULL,
|
||||
|
||||
-- runtime data
|
||||
lastconnectts bigint NOT NULL,
|
||||
ptyout BLOB NOT NULL
|
||||
lastconnectts bigint NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE session_cmd (
|
||||
@ -65,8 +64,6 @@ CREATE TABLE session_cmd (
|
||||
runnerpid int NOT NULL,
|
||||
donets bigint NOT NULL,
|
||||
exitcode int NOT NULL,
|
||||
ptyout BLOB NOT NULL,
|
||||
runout BLOB NOT NULL,
|
||||
PRIMARY KEY (sessionid, cmdid)
|
||||
);
|
||||
|
||||
|
@ -2,14 +2,13 @@ package remote
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/scripthaus-dev/mshell/pkg/base"
|
||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||
"github.com/scripthaus-dev/mshell/pkg/shexec"
|
||||
@ -18,10 +17,12 @@ import (
|
||||
)
|
||||
|
||||
const RemoteTypeMShell = "mshell"
|
||||
const DefaultTermRows = 25
|
||||
const DefaultTermCols = 80
|
||||
const DefaultTerm = "xterm-256color"
|
||||
|
||||
const (
|
||||
StatusInit = "init"
|
||||
StatusConnecting = "connecting"
|
||||
StatusConnected = "connected"
|
||||
StatusDisconnected = "disconnected"
|
||||
StatusError = "error"
|
||||
@ -47,20 +48,9 @@ type MShellProc struct {
|
||||
Remote *sstore.RemoteType
|
||||
|
||||
// runtime
|
||||
Status string
|
||||
InitPk *packet.InitPacketType
|
||||
Cmd *exec.Cmd
|
||||
Input *packet.PacketSender
|
||||
Output *packet.PacketParser
|
||||
DoneCh chan bool
|
||||
RpcMap map[string]*RpcEntry
|
||||
|
||||
Err error
|
||||
}
|
||||
|
||||
type RpcEntry struct {
|
||||
ReqId string
|
||||
RespCh chan packet.RpcResponsePacketType
|
||||
Status string
|
||||
ServerProc *shexec.ClientProc
|
||||
Err error
|
||||
}
|
||||
|
||||
func LoadRemotes(ctx context.Context) error {
|
||||
@ -111,8 +101,8 @@ func GetAllRemoteState() []RemoteState {
|
||||
RemoteName: proc.Remote.RemoteName,
|
||||
Status: proc.Status,
|
||||
}
|
||||
if proc.InitPk != nil {
|
||||
state.DefaultState = &sstore.RemoteState{Cwd: proc.InitPk.HomeDir}
|
||||
if proc.ServerProc != nil && proc.ServerProc.InitPk != nil {
|
||||
state.DefaultState = &sstore.RemoteState{Cwd: proc.ServerProc.InitPk.HomeDir}
|
||||
}
|
||||
rtn = append(rtn, state)
|
||||
}
|
||||
@ -135,50 +125,24 @@ func (msh *MShellProc) Launch() {
|
||||
return
|
||||
}
|
||||
ecmd := exec.Command(msPath, "--server")
|
||||
msh.Cmd = ecmd
|
||||
inputWriter, err := ecmd.StdinPipe()
|
||||
cproc, err := shexec.MakeClientProc(ecmd)
|
||||
if err != nil {
|
||||
msh.Status = StatusError
|
||||
msh.Err = fmt.Errorf("create stdin pipe: %w", err)
|
||||
return
|
||||
}
|
||||
stdoutReader, err := ecmd.StdoutPipe()
|
||||
if err != nil {
|
||||
msh.Status = StatusError
|
||||
msh.Err = fmt.Errorf("create stdout pipe: %w", err)
|
||||
return
|
||||
}
|
||||
stderrReader, err := ecmd.StderrPipe()
|
||||
if err != nil {
|
||||
msh.Status = StatusError
|
||||
msh.Err = fmt.Errorf("create stderr pipe: %w", err)
|
||||
msh.Err = err
|
||||
return
|
||||
}
|
||||
msh.ServerProc = cproc
|
||||
fmt.Printf("START MAKECLIENTPROC: %#v\n", msh.ServerProc.InitPk)
|
||||
msh.Status = StatusConnected
|
||||
go func() {
|
||||
io.Copy(os.Stderr, stderrReader)
|
||||
}()
|
||||
err = ecmd.Start()
|
||||
if err != nil {
|
||||
msh.Status = StatusError
|
||||
msh.Err = fmt.Errorf("starting mshell server: %w", err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("Started remote '%s' pid=%d\n", msh.Remote.RemoteName, msh.Cmd.Process.Pid)
|
||||
msh.Status = StatusConnecting
|
||||
msh.Output = packet.MakePacketParser(stdoutReader)
|
||||
msh.Input = packet.MakePacketSender(inputWriter)
|
||||
msh.RpcMap = make(map[string]*RpcEntry)
|
||||
msh.DoneCh = make(chan bool)
|
||||
go func() {
|
||||
exitErr := ecmd.Wait()
|
||||
exitErr := cproc.Cmd.Wait()
|
||||
exitCode := shexec.GetExitCode(exitErr)
|
||||
msh.WithLock(func() {
|
||||
if msh.Status == StatusConnected || msh.Status == StatusConnecting {
|
||||
if msh.Status == StatusConnected {
|
||||
msh.Status = StatusDisconnected
|
||||
}
|
||||
})
|
||||
fmt.Printf("[error] RUNNER PROC EXITED code[%d]\n", exitCode)
|
||||
close(msh.DoneCh)
|
||||
}()
|
||||
go msh.ProcessPackets()
|
||||
return
|
||||
@ -190,51 +154,62 @@ func (msh *MShellProc) IsConnected() bool {
|
||||
return msh.Status == StatusConnected
|
||||
}
|
||||
|
||||
func RunCommand(pk *scpacket.FeCommandPacketType, cmdId string) error {
|
||||
func RunCommand(ctx context.Context, pk *scpacket.FeCommandPacketType, cmdId string) (*packet.CmdStartPacketType, error) {
|
||||
msh := GetRemoteById(pk.RemoteState.RemoteId)
|
||||
if msh == nil {
|
||||
return fmt.Errorf("no remote id=%s found", pk.RemoteState.RemoteId)
|
||||
return nil, fmt.Errorf("no remote id=%s found", pk.RemoteState.RemoteId)
|
||||
}
|
||||
if !msh.IsConnected() {
|
||||
return fmt.Errorf("remote '%s' is not connected", msh.Remote.RemoteName)
|
||||
return nil, fmt.Errorf("remote '%s' is not connected", msh.Remote.RemoteName)
|
||||
}
|
||||
runPacket := packet.MakeRunPacket()
|
||||
runPacket.ReqId = uuid.New().String()
|
||||
runPacket.CK = base.MakeCommandKey(pk.SessionId, cmdId)
|
||||
runPacket.Cwd = pk.RemoteState.Cwd
|
||||
runPacket.Env = nil
|
||||
runPacket.UsePty = true
|
||||
runPacket.TermOpts = &packet.TermOpts{Rows: DefaultTermRows, Cols: DefaultTermCols, Term: DefaultTerm}
|
||||
runPacket.Command = strings.TrimSpace(pk.CmdStr)
|
||||
fmt.Printf("run-packet %v\n", runPacket)
|
||||
go func() {
|
||||
msh.Input.SendPacket(runPacket)
|
||||
}()
|
||||
return nil
|
||||
fmt.Printf("RUN-CMD> %s\n", runPacket.CK)
|
||||
msh.ServerProc.Output.RegisterRpc(runPacket.ReqId)
|
||||
err := shexec.SendRunPacketAndRunData(ctx, msh.ServerProc.Input, runPacket)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sending run packet to remote: %w", err)
|
||||
}
|
||||
rtnPk := msh.ServerProc.Output.WaitForResponse(ctx, runPacket.ReqId)
|
||||
if startPk, ok := rtnPk.(*packet.CmdStartPacketType); ok {
|
||||
return startPk, nil
|
||||
}
|
||||
if respPk, ok := rtnPk.(*packet.ResponsePacketType); ok {
|
||||
if respPk.Error != "" {
|
||||
return nil, errors.New(respPk.Error)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("invalid response received from server for run packet: %s", packet.AsString(rtnPk))
|
||||
}
|
||||
|
||||
func (runner *MShellProc) PacketRpc(pk packet.RpcPacketType, timeout time.Duration) (packet.RpcResponsePacketType, error) {
|
||||
if !runner.IsConnected() {
|
||||
func (msh *MShellProc) PacketRpc(ctx context.Context, pk packet.RpcPacketType) (*packet.ResponsePacketType, error) {
|
||||
if !msh.IsConnected() {
|
||||
return nil, fmt.Errorf("runner is not connected")
|
||||
}
|
||||
if pk == nil {
|
||||
return nil, fmt.Errorf("PacketRpc passed nil packet")
|
||||
}
|
||||
id := pk.GetReqId()
|
||||
respCh := make(chan packet.RpcResponsePacketType)
|
||||
runner.WithLock(func() {
|
||||
runner.RpcMap[id] = &RpcEntry{ReqId: id, RespCh: respCh}
|
||||
})
|
||||
defer runner.WithLock(func() {
|
||||
delete(runner.RpcMap, id)
|
||||
})
|
||||
runner.Input.SendPacket(pk)
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case rtnPk := <-respCh:
|
||||
return rtnPk, nil
|
||||
|
||||
case <-timer.C:
|
||||
return nil, fmt.Errorf("PacketRpc timeout")
|
||||
reqId := pk.GetReqId()
|
||||
msh.ServerProc.Output.RegisterRpc(reqId)
|
||||
defer msh.ServerProc.Output.UnRegisterRpc(reqId)
|
||||
err := msh.ServerProc.Input.SendPacketCtx(ctx, pk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rtnPk := msh.ServerProc.Output.WaitForResponse(ctx, reqId)
|
||||
if rtnPk == nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
if respPk, ok := rtnPk.(*packet.ResponsePacketType); ok {
|
||||
return respPk, nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid response packet received: %s", packet.AsString(rtnPk))
|
||||
}
|
||||
|
||||
func (runner *MShellProc) WithLock(fn func()) {
|
||||
@ -245,38 +220,25 @@ func (runner *MShellProc) WithLock(fn func()) {
|
||||
|
||||
func (runner *MShellProc) ProcessPackets() {
|
||||
defer runner.WithLock(func() {
|
||||
if runner.Status == StatusConnected || runner.Status == StatusConnecting {
|
||||
if runner.Status == StatusConnected {
|
||||
runner.Status = StatusDisconnected
|
||||
}
|
||||
})
|
||||
for pk := range runner.Output.MainCh {
|
||||
fmt.Printf("MSH> %s\n", packet.AsString(pk))
|
||||
if rpcPk, ok := pk.(packet.RpcResponsePacketType); ok {
|
||||
rpcId := rpcPk.GetResponseId()
|
||||
runner.WithLock(func() {
|
||||
entry := runner.RpcMap[rpcId]
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
delete(runner.RpcMap, rpcId)
|
||||
go func() {
|
||||
entry.RespCh <- rpcPk
|
||||
close(entry.RespCh)
|
||||
}()
|
||||
})
|
||||
for pk := range runner.ServerProc.Output.MainCh {
|
||||
fmt.Printf("MSH> %s | %#v\n", packet.AsString(pk), pk)
|
||||
if pk.GetType() == packet.DataPacketStr {
|
||||
dataPacket := pk.(*packet.DataPacketType)
|
||||
fmt.Printf("data %s fd=%d len=%d eof=%v err=%v\n", dataPacket.CK, dataPacket.FdNum, packet.B64DecodedLen(dataPacket.Data64), dataPacket.Eof, dataPacket.Error)
|
||||
continue
|
||||
}
|
||||
if pk.GetType() == packet.CmdDataPacketStr {
|
||||
dataPacket := pk.(*packet.CmdDataPacketType)
|
||||
fmt.Printf("cmd-data %s pty=%d run=%d\n", dataPacket.CK, len(dataPacket.PtyData), len(dataPacket.RunData))
|
||||
fmt.Printf("cmd-data %s pty=%d run=%d\n", dataPacket.CK, dataPacket.PtyDataLen, dataPacket.RunDataLen)
|
||||
continue
|
||||
}
|
||||
if pk.GetType() == packet.InitPacketStr {
|
||||
initPacket := pk.(*packet.InitPacketType)
|
||||
fmt.Printf("runner-init %s user=%s dir=%s\n", initPacket.MShellHomeDir, initPacket.User, initPacket.HomeDir)
|
||||
runner.WithLock(func() {
|
||||
runner.InitPk = initPacket
|
||||
runner.Status = StatusConnected
|
||||
})
|
||||
if pk.GetType() == packet.CmdDonePacketStr {
|
||||
donePacket := pk.(*packet.CmdDonePacketType)
|
||||
fmt.Printf("cmd-done %s\n", donePacket.CK)
|
||||
continue
|
||||
}
|
||||
if pk.GetType() == packet.MessagePacketStr {
|
||||
|
@ -1,12 +1,21 @@
|
||||
package scbase
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const HomeVarName = "HOME"
|
||||
const ScHomeVarName = "SCRIPTHAUS_HOME"
|
||||
const SessionsDirBaseName = "sessions"
|
||||
const RemotesDirBaseName = "remotes"
|
||||
|
||||
var SessionDirCache = make(map[string]string)
|
||||
var BaseLock = &sync.Mutex{}
|
||||
|
||||
func GetScHomeDir() string {
|
||||
scHome := os.Getenv(ScHomeVarName)
|
||||
@ -19,3 +28,66 @@ func GetScHomeDir() string {
|
||||
}
|
||||
return scHome
|
||||
}
|
||||
|
||||
func EnsureSessionDir(sessionId string) (string, error) {
|
||||
BaseLock.Lock()
|
||||
sdir, ok := SessionDirCache[sessionId]
|
||||
BaseLock.Unlock()
|
||||
if ok {
|
||||
return sdir, nil
|
||||
}
|
||||
scHome := GetScHomeDir()
|
||||
sdir = path.Join(scHome, SessionsDirBaseName, sessionId)
|
||||
err := ensureDir(sdir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
BaseLock.Lock()
|
||||
SessionDirCache[sessionId] = sdir
|
||||
BaseLock.Unlock()
|
||||
return sdir, nil
|
||||
}
|
||||
|
||||
func ensureDir(dirName string) error {
|
||||
info, err := os.Stat(dirName)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
err = os.MkdirAll(dirName, 0700)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
info, err = os.Stat(dirName)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return fmt.Errorf("'%s' must be a directory", dirName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func PtyOutFile(sessionId string, cmdId string) (string, error) {
|
||||
sdir, err := EnsureSessionDir(sessionId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("%s/%s.ptyout", sdir, cmdId), nil
|
||||
}
|
||||
|
||||
func RunOutFile(sessionId string, cmdId string) (string, error) {
|
||||
sdir, err := EnsureSessionDir(sessionId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("%s/%s.runout", sdir, cmdId), nil
|
||||
}
|
||||
|
||||
func RemotePtyOut(remoteId string) (string, error) {
|
||||
scHome := GetScHomeDir()
|
||||
rdir := path.Join(scHome, RemotesDirBaseName)
|
||||
err := ensureDir(rdir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("%s/%s.ptyout", rdir, remoteId), nil
|
||||
}
|
||||
|
@ -22,7 +22,7 @@ func NumSessions(ctx context.Context) (int, error) {
|
||||
return count, nil
|
||||
}
|
||||
|
||||
const remoteSelectCols = "rowid, remoteid, remotetype, remotename, autoconnect, sshhost, sshopts, sshidentity, sshuser, lastconnectts"
|
||||
const remoteSelectCols = "remoteid, remotetype, remotename, autoconnect, sshhost, sshopts, sshidentity, sshuser, lastconnectts"
|
||||
|
||||
func GetAllRemotes(ctx context.Context) ([]*RemoteType, error) {
|
||||
db, err := GetDB()
|
||||
@ -76,23 +76,16 @@ func InsertRemote(ctx context.Context, remote *RemoteType) error {
|
||||
if remote == nil {
|
||||
return fmt.Errorf("cannot insert nil remote")
|
||||
}
|
||||
if remote.RowId != 0 {
|
||||
return fmt.Errorf("cannot insert a remote that already has rowid set, rowid=%d", remote.RowId)
|
||||
}
|
||||
db, err := GetDB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
query := `INSERT INTO remote ( remoteid, remotetype, remotename, autoconnect, sshhost, sshopts, sshidentity, sshuser, lastconnectts, ptyout) VALUES
|
||||
(:remoteid,:remotetype,:remotename,:autoconnect,:sshhost,:sshopts,:sshidentity,:sshuser, 0 , '')`
|
||||
result, err := db.NamedExec(query, remote)
|
||||
query := `INSERT INTO remote ( remoteid, remotetype, remotename, autoconnect, sshhost, sshopts, sshidentity, sshuser, lastconnectts) VALUES
|
||||
(:remoteid,:remotetype,:remotename,:autoconnect,:sshhost,:sshopts,:sshidentity,:sshuser, 0 )`
|
||||
_, err = db.NamedExec(query, remote)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
remote.RowId, err = result.LastInsertId()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot get lastinsertid from insert remote: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
24
pkg/sstore/fileops.go
Normal file
24
pkg/sstore/fileops.go
Normal file
@ -0,0 +1,24 @@
|
||||
package sstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
|
||||
"github.com/scripthaus-dev/sh2-server/pkg/scbase"
|
||||
)
|
||||
|
||||
func AppendToCmdPtyBlob(ctx context.Context, sessionId string, cmdId string, data []byte) error {
|
||||
ptyOutFileName, err := scbase.PtyOutFile(sessionId, cmdId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fd, err := os.OpenFile(ptyOutFileName, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = fd.Write(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
@ -111,7 +111,6 @@ type LineType struct {
|
||||
}
|
||||
|
||||
type RemoteType struct {
|
||||
RowId int64 `json:"rowid"`
|
||||
RemoteId string `json:"remoteid"`
|
||||
RemoteType string `json:"remotetype"`
|
||||
RemoteName string `json:"remotename"`
|
||||
@ -128,7 +127,6 @@ type RemoteType struct {
|
||||
}
|
||||
|
||||
type CmdType struct {
|
||||
RowId int64 `json:"rowid"`
|
||||
SessionId string `json:"sessionid"`
|
||||
CmdId string `json:"cmdid"`
|
||||
RSId string `json:"rsid"`
|
||||
|
Loading…
Reference in New Issue
Block a user