waveterm/pkg/remote/remote.go
2022-07-04 22:18:01 -07:00

261 lines
6.0 KiB
Go

package remote
import (
"context"
"fmt"
"io"
"os"
"os/exec"
"sync"
"time"
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/mshell/pkg/shexec"
"github.com/scripthaus-dev/sh2-server/pkg/sstore"
)
// RemoteTypeMShell is the remote-type identifier for mshell-backed remotes.
// NOTE(review): not referenced elsewhere in this file; presumably matched
// against sstore.RemoteType.RemoteType by callers — confirm.
const RemoteTypeMShell = "mshell"
// Connection states stored in MShellProc.Status.
const (
// StatusInit is the state assigned by MakeMShell, before Launch has run.
StatusInit = "init"
// StatusConnecting is set by Launch once the mshell process has started but
// before its init packet has been received.
StatusConnecting = "connecting"
// StatusConnected is set by ProcessPackets when the init packet arrives.
StatusConnected = "connected"
// StatusDisconnected replaces connecting/connected when the process exits or
// the packet loop ends.
StatusDisconnected = "disconnected"
// StatusError is set by Launch when the process could not be started; the
// cause is stored in MShellProc.Err.
StatusError = "error"
)
// GlobalStore is the process-wide registry of remotes. It is nil until
// LoadRemotes has been called; GetRemote and GetAllRemoteState dereference it
// without a nil check.
var GlobalStore *Store
// Store is a registry of remote mshell processes.
type Store struct {
// Lock is held by GetRemote and GetAllRemoteState while they read Map.
Lock *sync.Mutex
// Map holds one entry per remote, keyed by the remote's RemoteName
// (see LoadRemotes).
Map map[string]*MShellProc
}
// RemoteState is a JSON-serializable snapshot of one remote, as produced by
// GetAllRemoteState.
type RemoteState struct {
// RemoteType, RemoteId and RemoteName are copied from the remote's
// sstore.RemoteType record.
RemoteType string `json:"remotetype"`
RemoteId string `json:"remoteid"`
RemoteName string `json:"remotename"`
// Status is one of the Status* constants.
Status string `json:"status"`
// Cwd is filled from the init packet's HomeDir; it stays empty until an
// init packet has been received.
Cwd string `json:"cwd"`
}
// MShellProc tracks one remote definition together with the runtime state of
// the mshell server process launched for it.
type MShellProc struct {
// Lock is taken by Launch, IsConnected and WithLock around accesses to the
// runtime fields below.
Lock *sync.Mutex
// Remote is the stored remote definition this process belongs to; it is set
// by MakeMShell.
Remote *sstore.RemoteType
// runtime
// Status is one of the Status* constants.
Status string
// InitPk is the init packet received from the process; nil until it arrives.
InitPk *packet.InitPacketType
// Cmd is the exec.Cmd created by Launch.
Cmd *exec.Cmd
// Input sends packets to the process's stdin.
Input *packet.PacketSender
// Output parses packets from the process's stdout.
Output *packet.PacketParser
// DoneCh is created by Launch and closed after the process has exited.
DoneCh chan bool
// RpcMap holds the pending PacketRpc calls, keyed by packet id.
RpcMap map[string]*RpcEntry
// Err is the failure recorded by Launch when Status is StatusError.
Err error
}
// RpcEntry is a pending PacketRpc call waiting for its response packet.
type RpcEntry struct {
// PacketId is the id of the request packet; it is also the key in
// MShellProc.RpcMap.
PacketId string
// RespCh receives the matching response from ProcessPackets and is then
// closed.
RespCh chan packet.RpcPacketType
}
// LoadRemotes replaces GlobalStore with a fresh store, fills it with one
// MShellProc per remote returned by sstore.GetAllRemotes (keyed by remote
// name), and launches every remote flagged AutoConnect in its own goroutine.
//
// If the remotes cannot be loaded the error is returned as-is; GlobalStore has
// already been reset to an empty store at that point.
func LoadRemotes(ctx context.Context) error {
	GlobalStore = &Store{Lock: &sync.Mutex{}, Map: make(map[string]*MShellProc)}

	remotes, loadErr := sstore.GetAllRemotes(ctx)
	if loadErr != nil {
		return loadErr
	}
	for _, r := range remotes {
		proc := MakeMShell(r)
		GlobalStore.Map[r.RemoteName] = proc
		if !r.AutoConnect {
			continue
		}
		// Launch does not block the loader; failures are recorded on proc.
		go proc.Launch()
	}
	return nil
}
// GetRemote looks up a remote by name in GlobalStore while holding the store
// lock. It returns nil when no remote with that name is registered.
func GetRemote(name string) *MShellProc {
	GlobalStore.Lock.Lock()
	proc := GlobalStore.Map[name]
	GlobalStore.Lock.Unlock()
	return proc
}
// GetAllRemoteState returns a snapshot of every remote registered in
// GlobalStore. The order of the result follows map iteration and is therefore
// unspecified. The result is a nil slice when no remotes are registered.
func GetAllRemoteState() []RemoteState {
	GlobalStore.Lock.Lock()
	defer GlobalStore.Lock.Unlock()
	var rtn []RemoteState
	for _, proc := range GlobalStore.Map {
		// Remote is assigned once in MakeMShell, so it is read without the
		// per-process lock.
		state := RemoteState{
			RemoteType: proc.Remote.RemoteType,
			RemoteId:   proc.Remote.RemoteId,
			RemoteName: proc.Remote.RemoteName,
		}
		// Status and InitPk are written under proc.Lock by Launch and
		// ProcessPackets; read them under the same lock to avoid a data race.
		proc.WithLock(func() {
			state.Status = proc.Status
			if proc.InitPk != nil {
				state.Cwd = proc.InitPk.HomeDir
			}
		})
		rtn = append(rtn, state)
	}
	return rtn
}
// MakeMShell wraps the remote definition r in a new MShellProc with its own
// mutex and status StatusInit. No process is started; call Launch for that.
func MakeMShell(r *sstore.RemoteType) *MShellProc {
	return &MShellProc{
		Lock:   &sync.Mutex{},
		Remote: r,
		Status: StatusInit,
	}
}
// Launch starts the mshell server process ("<mshell path> --server") for this
// remote and wires up its pipes. The whole method runs with msh.Lock held.
//
// On failure Status becomes StatusError and Err holds the cause. On success
// Status becomes StatusConnecting, Input/Output/RpcMap/DoneCh are initialized,
// and two goroutines are started: one that waits for the process to exit
// (marking the remote disconnected and closing DoneCh), and one that runs
// ProcessPackets.
func (msh *MShellProc) Launch() {
	msh.Lock.Lock()
	defer msh.Lock.Unlock()

	// fail records a launch failure; every caller returns right after it.
	fail := func(cause error) {
		msh.Status = StatusError
		msh.Err = cause
	}

	mshellPath, pathErr := base.GetMShellPath()
	if pathErr != nil {
		fail(pathErr)
		return
	}
	proc := exec.Command(mshellPath, "--server")
	msh.Cmd = proc

	stdin, stdinErr := proc.StdinPipe()
	if stdinErr != nil {
		fail(fmt.Errorf("create stdin pipe: %w", stdinErr))
		return
	}
	stdout, stdoutErr := proc.StdoutPipe()
	if stdoutErr != nil {
		fail(fmt.Errorf("create stdout pipe: %w", stdoutErr))
		return
	}
	stderr, stderrErr := proc.StderrPipe()
	if stderrErr != nil {
		fail(fmt.Errorf("create stderr pipe: %w", stderrErr))
		return
	}
	// Forward the child's stderr to our own stderr.
	go io.Copy(os.Stderr, stderr)

	if startErr := proc.Start(); startErr != nil {
		fail(fmt.Errorf("starting mshell server: %w", startErr))
		return
	}
	fmt.Printf("Started remote '%s' pid=%d\n", msh.Remote.RemoteName, proc.Process.Pid)

	msh.Status = StatusConnecting
	msh.Output = packet.MakePacketParser(stdout)
	msh.Input = packet.MakePacketSender(stdin)
	msh.RpcMap = make(map[string]*RpcEntry)
	msh.DoneCh = make(chan bool)

	// Reaper: blocks until the process exits. Its WithLock call cannot proceed
	// until Launch has returned and released msh.Lock.
	// NOTE(review): os/exec documents that Wait should not be called before all
	// reads from the stdout/stderr pipes have completed — confirm that losing
	// trailing output here is acceptable.
	go func() {
		exitCode := shexec.GetExitCode(proc.Wait())
		msh.WithLock(func() {
			switch msh.Status {
			case StatusConnected, StatusConnecting:
				msh.Status = StatusDisconnected
			}
		})
		fmt.Printf("[error] RUNNER PROC EXITED code[%d]\n", exitCode)
		close(msh.DoneCh)
	}()
	go msh.ProcessPackets()
}
// IsConnected reports whether Status is currently StatusConnected. The status
// is read while holding msh.Lock.
func (msh *MShellProc) IsConnected() bool {
	var connected bool
	msh.WithLock(func() {
		connected = msh.Status == StatusConnected
	})
	return connected
}
// PacketRpc sends pk to the mshell process and blocks until the response with
// the same packet id arrives or timeout elapses.
//
// It returns an error when the remote is not connected, when pk is nil, or
// when no response is received within timeout. The pending-call entry is
// always removed from RpcMap before returning.
func (msh *MShellProc) PacketRpc(pk packet.RpcPacketType, timeout time.Duration) (packet.RpcPacketType, error) {
	if !msh.IsConnected() {
		return nil, fmt.Errorf("runner is not connected")
	}
	if pk == nil {
		return nil, fmt.Errorf("PacketRpc passed nil packet")
	}
	id := pk.GetPacketId()
	// Capacity 1: ProcessPackets delivers the response from a separate
	// goroutine. If this call times out after that goroutine has been started,
	// an unbuffered channel would leave it blocked on the send forever.
	respCh := make(chan packet.RpcPacketType, 1)
	msh.WithLock(func() {
		msh.RpcMap[id] = &RpcEntry{PacketId: id, RespCh: respCh}
	})
	// The closure is built now; the WithLock call itself runs on return.
	defer msh.WithLock(func() {
		delete(msh.RpcMap, id)
	})
	msh.Input.SendPacket(pk)
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case rtnPk := <-respCh:
		return rtnPk, nil
	case <-timer.C:
		return nil, fmt.Errorf("PacketRpc timeout")
	}
}
// WithLock runs fn while holding msh.Lock. The lock is released through defer,
// so it is dropped even if fn panics. fn must not try to take msh.Lock again:
// sync.Mutex is not reentrant.
func (msh *MShellProc) WithLock(fn func()) {
	mu := msh.Lock
	mu.Lock()
	defer mu.Unlock()
	fn()
}
// ProcessPackets consumes packets from Output.MainCh until that channel is
// closed. When the loop ends, a connecting/connected remote is marked
// disconnected.
//
// For each packet:
//   - if it is an RPC packet with a pending entry in RpcMap, the entry is
//     removed and the packet is delivered on its RespCh from a new goroutine
//     (the packet then still goes through the type dispatch below);
//   - cmd-data, message and raw packets are printed;
//   - an init packet is stored in InitPk and Status becomes StatusConnected;
//   - anything else is printed as "runner-packet".
func (msh *MShellProc) ProcessPackets() {
	markDisconnected := func() {
		if msh.Status == StatusConnected || msh.Status == StatusConnecting {
			msh.Status = StatusDisconnected
		}
	}
	defer msh.WithLock(markDisconnected)

	for pk := range msh.Output.MainCh {
		if rpcPk, isRpc := pk.(packet.RpcPacketType); isRpc {
			rpcId := rpcPk.GetPacketId()
			msh.WithLock(func() {
				waiter := msh.RpcMap[rpcId]
				if waiter != nil {
					delete(msh.RpcMap, rpcId)
					// Deliver outside the lock so a slow receiver cannot
					// stall the packet loop.
					go func() {
						waiter.RespCh <- rpcPk
						close(waiter.RespCh)
					}()
				}
			})
		}

		switch pk.GetType() {
		case packet.CmdDataPacketStr:
			dataPk := pk.(*packet.CmdDataPacketType)
			fmt.Printf("cmd-data %s pty=%d run=%d\n", dataPk.CK, len(dataPk.PtyData), len(dataPk.RunData))
		case packet.InitPacketStr:
			initPk := pk.(*packet.InitPacketType)
			fmt.Printf("runner-init %s user=%s dir=%s\n", initPk.MShellHomeDir, initPk.User, initPk.HomeDir)
			msh.WithLock(func() {
				msh.InitPk = initPk
				msh.Status = StatusConnected
			})
		case packet.MessagePacketStr:
			msgPk := pk.(*packet.MessagePacketType)
			fmt.Printf("# %s\n", msgPk.Message)
		case packet.RawPacketStr:
			rawPk := pk.(*packet.RawPacketType)
			fmt.Printf("stderr> %s\n", rawPk.Data)
		default:
			fmt.Printf("runner-packet: %v\n", pk)
		}
	}
}