diff --git a/cmd/main-server.go b/cmd/main-server.go index 0e484c545..b2b2cdef1 100644 --- a/cmd/main-server.go +++ b/cmd/main-server.go @@ -36,7 +36,6 @@ const WSStatePacketChSize = 20 const MaxInputDataSize = 1000 -var GlobalMShellProc *remote.MShellProc var GlobalLock = &sync.Mutex{} var WSStateMap = make(map[string]*WSState) // clientid -> WsState @@ -364,7 +363,11 @@ func HandleRunCommand(w http.ResponseWriter, r *http.Request) { cdPacket := packet.MakeCdPacket() cdPacket.PacketId = uuid.New().String() cdPacket.Dir = newDir - GlobalMShellProc.Input.SendPacket(cdPacket) + localRemote := remote.GetRemote("local") + if localRemote != nil { + localRemote.Input.SendPacket(cdPacket) + return + } return } rtnLine := sstore.MakeNewLineCmd(params.SessionId, params.WindowId) @@ -377,12 +380,9 @@ func HandleRunCommand(w http.ResponseWriter, r *http.Request) { fmt.Printf("run-packet %v\n", runPacket) WriteJsonSuccess(w, &runCommandResponse{Line: rtnLine}) go func() { - GlobalMShellProc.Input.SendPacket(runPacket) - if !GlobalMShellProc.Local { - getPacket := packet.MakeGetCmdPacket() - getPacket.CK = runPacket.CK - getPacket.Tail = true - GlobalMShellProc.Input.SendPacket(getPacket) + localRemote := remote.GetRemote("local") + if localRemote != nil { + localRemote.Input.SendPacket(runPacket) } }() return @@ -506,17 +506,12 @@ func main() { fmt.Printf("[error] ensuring default session: %v\n", err) return } - fmt.Printf("session: %#v\n", defaultSession) - return - - runnerProc, err := remote.LaunchMShell() + fmt.Printf("session: %v\n", defaultSession) + err = remote.LoadRemotes(context.Background()) if err != nil { - fmt.Printf("error launching runner-proc: %v\n", err) + fmt.Printf("[error] loading remotes: %v\n", err) return } - GlobalMShellProc = runnerProc - go runnerProc.ProcessPackets() - fmt.Printf("Started local runner pid[%d]\n", runnerProc.Cmd.Process.Pid) go runWebSocketServer() gr := mux.NewRouter() gr.HandleFunc("/api/ptyout", GetPtyOut) diff --git a/db/migrations/000001_init.up.sql b/db/migrations/000001_init.up.sql index 942d3bbfc..622a3b6f6 100644 --- a/db/migrations/000001_init.up.sql +++ b/db/migrations/000001_init.up.sql @@ -39,9 +39,16 @@ CREATE TABLE remote ( remoteid varchar(36) PRIMARY KEY, remotetype varchar(10) NOT NULL, remotename varchar(50) NOT NULL, - hostname varchar(200) NOT NULL, + autoconnect boolean NOT NULL, + + -- ssh specific opts + sshhost varchar(300) NOT NULL, + sshopts varchar(300) NOT NULL, + sshidentity varchar(300) NOT NULL, + sshuser varchar(100) NOT NULL, + + -- runtime data lastconnectts bigint NOT NULL, - connectopts varchar(300) NOT NULL, ptyout BLOB NOT NULL ); diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index f023190a6..5123e5ff0 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -1,19 +1,32 @@ package remote import ( + "context" "fmt" + "io" + "os" "os/exec" - "strings" "sync" "time" "github.com/scripthaus-dev/mshell/pkg/base" "github.com/scripthaus-dev/mshell/pkg/packet" "github.com/scripthaus-dev/mshell/pkg/shexec" + "github.com/scripthaus-dev/sh2-server/pkg/sstore" ) const RemoteTypeMShell = "mshell" +const ( + StatusInit = "init" + StatusConnecting = "connecting" + StatusConnected = "connected" + StatusDisconnected = "disconnected" + StatusError = "error" +) + +var GlobalStore *Store + type Store struct { Lock *sync.Mutex Map map[string]*MShellProc @@ -21,17 +34,18 @@ type Store struct { type MShellProc struct { Lock *sync.Mutex - Remote *RemoteType + Remote *sstore.RemoteType // runtime - Connected bool - InitPk *packet.InitPacketType - Cmd *exec.Cmd - Input *packet.PacketSender - Output *packet.PacketParser - Local bool - DoneCh chan bool - RpcMap map[string]*RpcEntry + Status string + InitPk *packet.InitPacketType + Cmd *exec.Cmd + Input *packet.PacketSender + Output *packet.PacketParser + DoneCh chan bool + RpcMap map[string]*RpcEntry + + Err error } type RpcEntry struct { @@ -39,57 +53,117 @@ type RpcEntry struct { RespCh chan packet.RpcPacketType } -func LoadRemotes() { - +func LoadRemotes(ctx context.Context) error { + GlobalStore = &Store{ + Lock: &sync.Mutex{}, + Map: make(map[string]*MShellProc), + } + allRemotes, err := sstore.GetAllRemotes(ctx) + if err != nil { + return err + } + for _, remote := range allRemotes { + msh := MakeMShell(remote) + GlobalStore.Map[remote.RemoteName] = msh + if remote.AutoConnect { + go msh.Launch() + } + } + return nil } -func LaunchMShell() (*MShellProc, error) { +func GetRemote(name string) *MShellProc { + GlobalStore.Lock.Lock() + defer GlobalStore.Lock.Unlock() + return GlobalStore.Map[name] +} + +func MakeMShell(r *sstore.RemoteType) *MShellProc { + rtn := &MShellProc{Lock: &sync.Mutex{}, Remote: r, Status: StatusInit} + return rtn +} + +func (msh *MShellProc) Launch() { + msh.Lock.Lock() + defer msh.Lock.Unlock() + msPath, err := base.GetMShellPath() if err != nil { - return nil, err + msh.Status = StatusError + msh.Err = err + return } - ecmd := exec.Command(msPath) + ecmd := exec.Command(msPath, "--server") + msh.Cmd = ecmd inputWriter, err := ecmd.StdinPipe() if err != nil { - return nil, err + msh.Status = StatusError + msh.Err = fmt.Errorf("create stdin pipe: %w", err) + return } - outputReader, err := ecmd.StdoutPipe() + stdoutReader, err := ecmd.StdoutPipe() if err != nil { - return nil, err + msh.Status = StatusError + msh.Err = fmt.Errorf("create stdout pipe: %w", err) + return } - ecmd.Stderr = ecmd.Stdout + stderrReader, err := ecmd.StderrPipe() + if err != nil { + msh.Status = StatusError + msh.Err = fmt.Errorf("create stderr pipe: %w", err) + return + } + go func() { + io.Copy(os.Stderr, stderrReader) + }() err = ecmd.Start() if err != nil { - return nil, err + msh.Status = StatusError + msh.Err = fmt.Errorf("starting mshell server: %w", err) + return } - rtn := &MShellProc{Lock: &sync.Mutex{}, Local: true, Cmd: ecmd} - rtn.Output = packet.MakePacketParser(outputReader) - rtn.Input = packet.MakePacketSender(inputWriter) - rtn.RpcMap = make(map[string]*RpcEntry) - rtn.DoneCh = make(chan bool) + fmt.Printf("Started remote '%s' pid=%d\n", msh.Remote.RemoteName, msh.Cmd.Process.Pid) + msh.Status = StatusConnecting + msh.Output = packet.MakePacketParser(stdoutReader) + msh.Input = packet.MakePacketSender(inputWriter) + msh.RpcMap = make(map[string]*RpcEntry) + msh.DoneCh = make(chan bool) go func() { exitErr := ecmd.Wait() exitCode := shexec.GetExitCode(exitErr) + msh.WithLock(func() { + if msh.Status == StatusConnected || msh.Status == StatusConnecting { + msh.Status = StatusDisconnected + } + }) fmt.Printf("[error] RUNNER PROC EXITED code[%d]\n", exitCode) - close(rtn.DoneCh) + close(msh.DoneCh) }() - return rtn, nil + go msh.ProcessPackets() + return +} + +func (msh *MShellProc) IsConnected() bool { + msh.Lock.Lock() + defer msh.Lock.Unlock() + return msh.Status == StatusConnected } func (runner *MShellProc) PacketRpc(pk packet.RpcPacketType, timeout time.Duration) (packet.RpcPacketType, error) { + if !runner.IsConnected() { + return nil, fmt.Errorf("runner is not connected") + } if pk == nil { return nil, fmt.Errorf("PacketRpc passed nil packet") } id := pk.GetPacketId() respCh := make(chan packet.RpcPacketType) - runner.Lock.Lock() - runner.RpcMap[id] = &RpcEntry{PacketId: id, RespCh: respCh} - runner.Lock.Unlock() - defer func() { - runner.Lock.Lock() + runner.WithLock(func() { + runner.RpcMap[id] = &RpcEntry{PacketId: id, RespCh: respCh} + }) + defer runner.WithLock(func() { delete(runner.RpcMap, id) - runner.Lock.Unlock() - }() + }) runner.Input.SendPacket(pk) timer := time.NewTimer(timeout) defer timer.Stop() @@ -102,21 +176,32 @@ func (runner *MShellProc) PacketRpc(pk packet.RpcPacketType, timeout time.Durati } } +func (runner *MShellProc) WithLock(fn func()) { + runner.Lock.Lock() + defer runner.Lock.Unlock() + fn() +} + func (runner *MShellProc) ProcessPackets() { + defer runner.WithLock(func() { + if runner.Status == StatusConnected || runner.Status == StatusConnecting { + runner.Status = StatusDisconnected + } + }) for pk := range runner.Output.MainCh { if rpcPk, ok := pk.(packet.RpcPacketType); ok { rpcId := rpcPk.GetPacketId() - runner.Lock.Lock() - entry := runner.RpcMap[rpcId] - if entry != nil { + runner.WithLock(func() { + entry := runner.RpcMap[rpcId] + if entry == nil { + return + } delete(runner.RpcMap, rpcId) go func() { entry.RespCh <- rpcPk close(entry.RespCh) }() - } - runner.Lock.Unlock() - + }) } if pk.GetType() == packet.CmdDataPacketStr { dataPacket := pk.(*packet.CmdDataPacketType) @@ -126,16 +211,10 @@ func (runner *MShellProc) ProcessPackets() { if pk.GetType() == packet.InitPacketStr { initPacket := pk.(*packet.InitPacketType) fmt.Printf("runner-init %s user=%s dir=%s\n", initPacket.MShellHomeDir, initPacket.User, initPacket.HomeDir) - runner.Lock.Lock() - runner.Connected = true - runner.User = initPacket.User - runner.CurDir = initPacket.HomeDir - runner.HomeDir = initPacket.HomeDir - runner.Env = initPacket.Env - if runner.Local { - runner.Host = "local" - } - runner.Lock.Unlock() + runner.WithLock(func() { + runner.InitPk = initPacket + runner.Status = StatusConnected + }) continue } if pk.GetType() == packet.MessagePacketStr { @@ -151,15 +230,3 @@ func (runner *MShellProc) ProcessPackets() { fmt.Printf("runner-packet: %v\n", pk) } } - -func (r *MShellProc) GetPrompt() string { - r.Lock.Lock() - defer r.Lock.Unlock() - var curDir = r.CurDir - if r.CurDir == r.HomeDir { - curDir = "~" - } else if strings.HasPrefix(r.CurDir, r.HomeDir+"/") { - curDir = "~/" + r.CurDir[0:len(r.HomeDir)+1] - } - return fmt.Sprintf("[%s@%s %s]", r.User, r.Host, curDir) -} diff --git a/pkg/sstore/dbops.go b/pkg/sstore/dbops.go index cae208c62..1bb1767e2 100644 --- a/pkg/sstore/dbops.go +++ b/pkg/sstore/dbops.go @@ -22,12 +22,28 @@ func NumSessions(ctx context.Context) (int, error) { return count, nil } +const remoteSelectCols = "rowid, remoteid, remotetype, remotename, autoconnect, sshhost, sshopts, sshidentity, sshuser, lastconnectts" + +func GetAllRemotes(ctx context.Context) ([]*RemoteType, error) { + db, err := GetDB() + if err != nil { + return nil, err + } + query := fmt.Sprintf(`SELECT %s FROM remote`, remoteSelectCols) + var remoteArr []*RemoteType + err = db.SelectContext(ctx, &remoteArr, query) + if err != nil { + return nil, err + } + return remoteArr, nil +} + func GetRemoteByName(ctx context.Context, remoteName string) (*RemoteType, error) { db, err := GetDB() if err != nil { return nil, err } - query := `SELECT rowid, remoteid, remotetype, remotename, hostname, connectopts, lastconnectts FROM remote WHERE remotename = ?` + query := fmt.Sprintf(`SELECT %s FROM remote WHERE remotename = ?`, remoteSelectCols) var remote RemoteType err = db.GetContext(ctx, &remote, query, remoteName) if err == sql.ErrNoRows { @@ -44,7 +60,7 @@ func GetRemoteById(ctx context.Context, remoteId string) (*RemoteType, error) { if err != nil { return nil, err } - query := `SELECT rowid, remoteid, remotetype, remotename, hostname, connectopts, lastconnectts FROM remote WHERE remoteid = ?` + query := fmt.Sprintf(`SELECT %s FROM remote WHERE remoteid = ?`, remoteSelectCols) var remote RemoteType err = db.GetContext(ctx, &remote, query, remoteId) if err == sql.ErrNoRows { @@ -67,7 +83,8 @@ func InsertRemote(ctx context.Context, remote *RemoteType) error { if err != nil { return err } - query := `INSERT INTO remote (remoteid, remotetype, remotename, hostname, connectopts, lastconnectts, ptyout) VALUES (:remoteid, :remotetype, :remotename, :hostname, :connectopts, 0, '')` + query := `INSERT INTO remote ( remoteid, remotetype, remotename, autoconnect, sshhost, sshopts, sshidentity, sshuser, lastconnectts, ptyout) VALUES + (:remoteid,:remotetype,:remotename,:autoconnect,:sshhost,:sshopts,:sshidentity,:sshuser, 0 , '')` result, err := db.NamedExec(query, remote) if err != nil { return err diff --git a/pkg/sstore/sstore.go b/pkg/sstore/sstore.go index fa6c30643..19aa0c326 100644 --- a/pkg/sstore/sstore.go +++ b/pkg/sstore/sstore.go @@ -2,9 +2,7 @@ package sstore import ( "context" - "fmt" "log" - "os" "path" "sync" "time" @@ -86,13 +84,20 @@ type LineType struct { } type RemoteType struct { - RowId int64 `json:"rowid"` - RemoteId string `json:"remoteid"` - RemoteType string `json:"remotetype"` - RemoteName string `json:"remotename"` - HostName string `json:"hostname"` - LastConnectTs int64 `json:"lastconnectts"` - ConnectOpts string `json:"connectopts"` + RowId int64 `json:"rowid"` + RemoteId string `json:"remoteid"` + RemoteType string `json:"remotetype"` + RemoteName string `json:"remotename"` + AutoConnect bool `json:"autoconnect"` + + // type=ssh options + SSHHost string `json:"sshhost"` + SSHOpts string `json:"sshopts"` + SSHIdentity string `json:"sshidentity"` + SSHUser string `json:"sshuser"` + + // runtime data + LastConnectTs int64 `json:"lastconnectts"` } type CmdType struct { @@ -154,16 +159,12 @@ func EnsureLocalRemote(ctx context.Context) error { if remote != nil { return nil } - hostName, err := os.Hostname() - if err != nil { - return fmt.Errorf("cannot get hostname: %w", err) - } // create the local remote localRemote := &RemoteType{ - RemoteId: remoteId, - RemoteType: "ssh", - RemoteName: LocalRemoteName, - HostName: hostName, + RemoteId: remoteId, + RemoteType: "ssh", + RemoteName: LocalRemoteName, + AutoConnect: true, } err = InsertRemote(ctx, localRemote) if err != nil { @@ -187,38 +188,3 @@ func EnsureDefaultSession(ctx context.Context) (*SessionType, error) { } return GetSessionByName(ctx, DefaultSessionName) } - -func CreateInitialSession(ctx context.Context) error { - db, err := GetDB() - if err != nil { - return err - } - session := &SessionType{ - SessionId: uuid.New().String(), - Name: DefaultSessionName, - } - window := &WindowType{ - SessionId: session.SessionId, - WindowId: uuid.New().String(), - Name: DefaultWindowName, - CurRemote: LocalRemoteName, - } - remoteId, err := base.GetRemoteId() - if err != nil { - return err - } - localRemote := &RemoteType{ - RemoteId: remoteId, - RemoteType: "ssh", - RemoteName: LocalRemoteName, - } - sessRemote := &SessionRemote{ - SessionId: session.SessionId, - WindowId: window.WindowId, - RemoteId: remoteId, - RemoteName: localRemote.RemoteName, - Cwd: base.GetHomeDir(), - } - fmt.Printf("db=%v s=%v w=%v r=%v sr=%v\n", db, session, window, localRemote, sessRemote) - return nil -} diff --git a/scripthaus.md b/scripthaus.md index 08a6c52db..6ae418626 100644 --- a/scripthaus.md +++ b/scripthaus.md @@ -10,3 +10,7 @@ sqlite3 /Users/mike/scripthaus/sh2.db .schema > db/schema.sql sqlite3 /Users/mike/scripthaus/sh2.db ``` +```bash +# @scripthaus command build +go build -o server cmd/main-server.go +```