// Package remote keeps a registry of mshell remotes, launches their server
// processes, and exchanges packets with them.
//
// Source path: waveterm/pkg/remote/remote.go
package remote
import (
	"context"
	"fmt"
	"io"
	"os"
	"os/exec"
	"strings"
	"sync"
	"time"

	"github.com/scripthaus-dev/mshell/pkg/base"
	"github.com/scripthaus-dev/mshell/pkg/packet"
	"github.com/scripthaus-dev/mshell/pkg/shexec"
	"github.com/scripthaus-dev/sh2-server/pkg/scpacket"
	"github.com/scripthaus-dev/sh2-server/pkg/sstore"
)
// RemoteTypeMShell is the remote-type name "mshell".
// NOTE(review): presumably compared against sstore.RemoteType.RemoteType —
// confirm against callers.
const (
	RemoteTypeMShell = "mshell"
)
// Connection status values stored in MShellProc.Status.
const (
	// StatusInit is the status MakeMShell gives a freshly created remote.
	StatusInit = "init"
	// StatusConnecting is set by Launch once the server process has started.
	StatusConnecting = "connecting"
	// StatusConnected is set when the init packet has been received.
	StatusConnected = "connected"
	// StatusDisconnected is set when the process exits or its packet stream ends.
	StatusDisconnected = "disconnected"
	// StatusError is set when Launch fails; the cause is kept in MShellProc.Err.
	StatusError = "error"
)
// GlobalStore is the package-wide registry of remotes. It is nil until
// LoadRemotes assigns a fresh Store to it.
var (
	GlobalStore *Store
)
2022-07-01 23:57:42 +02:00
type Store struct {
Lock *sync.Mutex
Map map[string]*MShellProc // key=remoteid
2022-07-01 23:57:42 +02:00
}
2022-07-05 07:18:01 +02:00
type RemoteState struct {
RemoteType string `json:"remotetype"`
RemoteId string `json:"remoteid"`
RemoteName string `json:"remotename"`
Status string `json:"status"`
DefaultState *sstore.RemoteState `json:"defaultstate"`
2022-07-05 07:18:01 +02:00
}
2022-07-01 21:17:19 +02:00
type MShellProc struct {
2022-07-01 23:57:42 +02:00
Lock *sync.Mutex
Remote *sstore.RemoteType
2022-07-01 23:57:42 +02:00
// runtime
Status string
InitPk *packet.InitPacketType
Cmd *exec.Cmd
Input *packet.PacketSender
Output *packet.PacketParser
DoneCh chan bool
RpcMap map[string]*RpcEntry
Err error
2022-07-01 21:17:19 +02:00
}
type RpcEntry struct {
ReqId string
RespCh chan packet.RpcResponsePacketType
2022-07-01 21:17:19 +02:00
}
func LoadRemotes(ctx context.Context) error {
GlobalStore = &Store{
Lock: &sync.Mutex{},
Map: make(map[string]*MShellProc),
}
allRemotes, err := sstore.GetAllRemotes(ctx)
if err != nil {
return err
}
for _, remote := range allRemotes {
msh := MakeMShell(remote)
GlobalStore.Map[remote.RemoteId] = msh
if remote.AutoConnect {
go msh.Launch()
}
}
return nil
}
// GetRemoteByName scans GlobalStore for a remote whose RemoteName equals name
// and returns it, or nil when none matches. Map iteration order is random, so
// if several remotes share a name any one of them may be returned.
func GetRemoteByName(name string) *MShellProc {
	GlobalStore.Lock.Lock()
	defer GlobalStore.Lock.Unlock()

	for _, proc := range GlobalStore.Map {
		if proc.Remote.RemoteName != name {
			continue
		}
		return proc
	}
	return nil
}
// GetRemoteById looks up remoteId in GlobalStore under the store lock and
// returns the matching remote, or nil when the id is not registered.
func GetRemoteById(remoteId string) *MShellProc {
	GlobalStore.Lock.Lock()
	defer GlobalStore.Lock.Unlock()

	proc := GlobalStore.Map[remoteId]
	return proc
}
2022-07-01 23:57:42 +02:00
2022-07-05 07:18:01 +02:00
func GetAllRemoteState() []RemoteState {
GlobalStore.Lock.Lock()
defer GlobalStore.Lock.Unlock()
var rtn []RemoteState
for _, proc := range GlobalStore.Map {
state := RemoteState{
RemoteType: proc.Remote.RemoteType,
RemoteId: proc.Remote.RemoteId,
RemoteName: proc.Remote.RemoteName,
Status: proc.Status,
}
if proc.InitPk != nil {
state.DefaultState = &sstore.RemoteState{Cwd: proc.InitPk.HomeDir}
2022-07-05 07:18:01 +02:00
}
rtn = append(rtn, state)
}
return rtn
}
func MakeMShell(r *sstore.RemoteType) *MShellProc {
rtn := &MShellProc{Lock: &sync.Mutex{}, Remote: r, Status: StatusInit}
return rtn
2022-07-01 23:57:42 +02:00
}
func (msh *MShellProc) Launch() {
msh.Lock.Lock()
defer msh.Lock.Unlock()
2022-07-01 21:17:19 +02:00
msPath, err := base.GetMShellPath()
if err != nil {
msh.Status = StatusError
msh.Err = err
return
2022-07-01 21:17:19 +02:00
}
ecmd := exec.Command(msPath, "--server")
msh.Cmd = ecmd
2022-07-01 21:17:19 +02:00
inputWriter, err := ecmd.StdinPipe()
if err != nil {
msh.Status = StatusError
msh.Err = fmt.Errorf("create stdin pipe: %w", err)
return
2022-07-01 21:17:19 +02:00
}
stdoutReader, err := ecmd.StdoutPipe()
2022-07-01 21:17:19 +02:00
if err != nil {
msh.Status = StatusError
msh.Err = fmt.Errorf("create stdout pipe: %w", err)
return
2022-07-01 21:17:19 +02:00
}
stderrReader, err := ecmd.StderrPipe()
if err != nil {
msh.Status = StatusError
msh.Err = fmt.Errorf("create stderr pipe: %w", err)
return
}
go func() {
io.Copy(os.Stderr, stderrReader)
}()
2022-07-01 21:17:19 +02:00
err = ecmd.Start()
if err != nil {
msh.Status = StatusError
msh.Err = fmt.Errorf("starting mshell server: %w", err)
return
2022-07-01 21:17:19 +02:00
}
fmt.Printf("Started remote '%s' pid=%d\n", msh.Remote.RemoteName, msh.Cmd.Process.Pid)
msh.Status = StatusConnecting
msh.Output = packet.MakePacketParser(stdoutReader)
msh.Input = packet.MakePacketSender(inputWriter)
msh.RpcMap = make(map[string]*RpcEntry)
msh.DoneCh = make(chan bool)
2022-07-01 21:17:19 +02:00
go func() {
exitErr := ecmd.Wait()
exitCode := shexec.GetExitCode(exitErr)
msh.WithLock(func() {
if msh.Status == StatusConnected || msh.Status == StatusConnecting {
msh.Status = StatusDisconnected
}
})
2022-07-01 21:17:19 +02:00
fmt.Printf("[error] RUNNER PROC EXITED code[%d]\n", exitCode)
close(msh.DoneCh)
2022-07-01 21:17:19 +02:00
}()
go msh.ProcessPackets()
return
}
func (msh *MShellProc) IsConnected() bool {
msh.Lock.Lock()
defer msh.Lock.Unlock()
return msh.Status == StatusConnected
2022-07-01 21:17:19 +02:00
}
// RunCommand builds a run packet from the front-end command packet pk and
// sends it to the remote named by pk.RemoteState.RemoteId.
//
// It returns an error when that remote is unknown or not connected. The send
// itself happens on a separate goroutine, so a nil return only means the
// packet was handed off, not that the command ran.
func RunCommand(pk *scpacket.FeCommandPacketType, cmdId string) error {
	remoteId := pk.RemoteState.RemoteId
	msh := GetRemoteById(remoteId)
	switch {
	case msh == nil:
		return fmt.Errorf("no remote id=%s found", remoteId)
	case !msh.IsConnected():
		return fmt.Errorf("remote '%s' is not connected", msh.Remote.RemoteName)
	}

	runPacket := packet.MakeRunPacket()
	runPacket.CK = base.MakeCommandKey(pk.SessionId, cmdId)
	runPacket.Cwd = pk.RemoteState.Cwd
	runPacket.Env = nil
	runPacket.Command = strings.TrimSpace(pk.CmdStr)
	fmt.Printf("run-packet %v\n", runPacket)

	send := func() {
		msh.Input.SendPacket(runPacket)
	}
	go send()
	return nil
}
func (runner *MShellProc) PacketRpc(pk packet.RpcPacketType, timeout time.Duration) (packet.RpcResponsePacketType, error) {
if !runner.IsConnected() {
return nil, fmt.Errorf("runner is not connected")
}
2022-07-01 21:17:19 +02:00
if pk == nil {
return nil, fmt.Errorf("PacketRpc passed nil packet")
}
id := pk.GetReqId()
respCh := make(chan packet.RpcResponsePacketType)
runner.WithLock(func() {
runner.RpcMap[id] = &RpcEntry{ReqId: id, RespCh: respCh}
})
defer runner.WithLock(func() {
2022-07-01 21:17:19 +02:00
delete(runner.RpcMap, id)
})
2022-07-01 21:17:19 +02:00
runner.Input.SendPacket(pk)
timer := time.NewTimer(timeout)
defer timer.Stop()
select {
case rtnPk := <-respCh:
return rtnPk, nil
case <-timer.C:
return nil, fmt.Errorf("PacketRpc timeout")
}
}
// WithLock calls fn while holding msh.Lock. The lock is released even if fn
// panics. fn must not try to take msh.Lock itself.
func (msh *MShellProc) WithLock(fn func()) {
	mu := msh.Lock
	mu.Lock()
	defer mu.Unlock()
	fn()
}
2022-07-01 21:17:19 +02:00
func (runner *MShellProc) ProcessPackets() {
defer runner.WithLock(func() {
if runner.Status == StatusConnected || runner.Status == StatusConnecting {
runner.Status = StatusDisconnected
}
})
2022-07-01 21:17:19 +02:00
for pk := range runner.Output.MainCh {
fmt.Printf("MSH> %s\n", packet.AsString(pk))
if rpcPk, ok := pk.(packet.RpcResponsePacketType); ok {
rpcId := rpcPk.GetResponseId()
runner.WithLock(func() {
entry := runner.RpcMap[rpcId]
if entry == nil {
return
}
2022-07-01 21:17:19 +02:00
delete(runner.RpcMap, rpcId)
go func() {
entry.RespCh <- rpcPk
close(entry.RespCh)
}()
})
2022-07-01 21:17:19 +02:00
}
if pk.GetType() == packet.CmdDataPacketStr {
dataPacket := pk.(*packet.CmdDataPacketType)
fmt.Printf("cmd-data %s pty=%d run=%d\n", dataPacket.CK, len(dataPacket.PtyData), len(dataPacket.RunData))
continue
}
if pk.GetType() == packet.InitPacketStr {
initPacket := pk.(*packet.InitPacketType)
fmt.Printf("runner-init %s user=%s dir=%s\n", initPacket.MShellHomeDir, initPacket.User, initPacket.HomeDir)
runner.WithLock(func() {
runner.InitPk = initPacket
runner.Status = StatusConnected
})
2022-07-01 21:17:19 +02:00
continue
}
if pk.GetType() == packet.MessagePacketStr {
msgPacket := pk.(*packet.MessagePacketType)
fmt.Printf("# %s\n", msgPacket.Message)
continue
}
if pk.GetType() == packet.RawPacketStr {
rawPacket := pk.(*packet.RawPacketType)
fmt.Printf("stderr> %s\n", rawPacket.Data)
continue
}
fmt.Printf("runner-packet: %v\n", pk)
}
}