// Mirror of https://github.com/wavetermdev/waveterm.git
// (synced 2025-01-19 21:11:32 +01:00); 261 lines, 6.0 KiB, Go.

package remote
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"os/exec"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/scripthaus-dev/mshell/pkg/base"
|
|
"github.com/scripthaus-dev/mshell/pkg/packet"
|
|
"github.com/scripthaus-dev/mshell/pkg/shexec"
|
|
"github.com/scripthaus-dev/sh2-server/pkg/sstore"
|
|
)
|
|
|
|
// Remote type identifier and the values MShellProc.Status takes over the
// lifetime of a remote process.
const (
	// RemoteTypeMShell is the remote type string for mshell-backed remotes.
	RemoteTypeMShell = "mshell"

	// StatusInit is the status of a freshly built MShellProc (see MakeMShell).
	StatusInit = "init"
	// StatusConnecting is set once the mshell process has started, before its
	// init packet has been received.
	StatusConnecting = "connecting"
	// StatusConnected is set when the init packet arrives; RPCs are allowed.
	StatusConnected = "connected"
	// StatusDisconnected is set when the process exits or its packet stream ends.
	StatusDisconnected = "disconnected"
	// StatusError is set when Launch fails; MShellProc.Err holds the cause.
	StatusError = "error"
)
|
|
|
|
var GlobalStore *Store
|
|
|
|
type Store struct {
|
|
Lock *sync.Mutex
|
|
Map map[string]*MShellProc
|
|
}
|
|
|
|
type RemoteState struct {
|
|
RemoteType string `json:"remotetype"`
|
|
RemoteId string `json:"remoteid"`
|
|
RemoteName string `json:"remotename"`
|
|
Status string `json:"status"`
|
|
Cwd string `json:"cwd"`
|
|
}
|
|
|
|
type MShellProc struct {
|
|
Lock *sync.Mutex
|
|
Remote *sstore.RemoteType
|
|
|
|
// runtime
|
|
Status string
|
|
InitPk *packet.InitPacketType
|
|
Cmd *exec.Cmd
|
|
Input *packet.PacketSender
|
|
Output *packet.PacketParser
|
|
DoneCh chan bool
|
|
RpcMap map[string]*RpcEntry
|
|
|
|
Err error
|
|
}
|
|
|
|
type RpcEntry struct {
|
|
PacketId string
|
|
RespCh chan packet.RpcPacketType
|
|
}
|
|
|
|
func LoadRemotes(ctx context.Context) error {
|
|
GlobalStore = &Store{
|
|
Lock: &sync.Mutex{},
|
|
Map: make(map[string]*MShellProc),
|
|
}
|
|
allRemotes, err := sstore.GetAllRemotes(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, remote := range allRemotes {
|
|
msh := MakeMShell(remote)
|
|
GlobalStore.Map[remote.RemoteName] = msh
|
|
if remote.AutoConnect {
|
|
go msh.Launch()
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func GetRemote(name string) *MShellProc {
|
|
GlobalStore.Lock.Lock()
|
|
defer GlobalStore.Lock.Unlock()
|
|
return GlobalStore.Map[name]
|
|
}
|
|
|
|
func GetAllRemoteState() []RemoteState {
|
|
GlobalStore.Lock.Lock()
|
|
defer GlobalStore.Lock.Unlock()
|
|
|
|
var rtn []RemoteState
|
|
for _, proc := range GlobalStore.Map {
|
|
state := RemoteState{
|
|
RemoteType: proc.Remote.RemoteType,
|
|
RemoteId: proc.Remote.RemoteId,
|
|
RemoteName: proc.Remote.RemoteName,
|
|
Status: proc.Status,
|
|
}
|
|
if proc.InitPk != nil {
|
|
state.Cwd = proc.InitPk.HomeDir
|
|
}
|
|
rtn = append(rtn, state)
|
|
}
|
|
return rtn
|
|
}
|
|
|
|
func MakeMShell(r *sstore.RemoteType) *MShellProc {
|
|
rtn := &MShellProc{Lock: &sync.Mutex{}, Remote: r, Status: StatusInit}
|
|
return rtn
|
|
}
|
|
|
|
func (msh *MShellProc) Launch() {
|
|
msh.Lock.Lock()
|
|
defer msh.Lock.Unlock()
|
|
|
|
msPath, err := base.GetMShellPath()
|
|
if err != nil {
|
|
msh.Status = StatusError
|
|
msh.Err = err
|
|
return
|
|
}
|
|
ecmd := exec.Command(msPath, "--server")
|
|
msh.Cmd = ecmd
|
|
inputWriter, err := ecmd.StdinPipe()
|
|
if err != nil {
|
|
msh.Status = StatusError
|
|
msh.Err = fmt.Errorf("create stdin pipe: %w", err)
|
|
return
|
|
}
|
|
stdoutReader, err := ecmd.StdoutPipe()
|
|
if err != nil {
|
|
msh.Status = StatusError
|
|
msh.Err = fmt.Errorf("create stdout pipe: %w", err)
|
|
return
|
|
}
|
|
stderrReader, err := ecmd.StderrPipe()
|
|
if err != nil {
|
|
msh.Status = StatusError
|
|
msh.Err = fmt.Errorf("create stderr pipe: %w", err)
|
|
return
|
|
}
|
|
go func() {
|
|
io.Copy(os.Stderr, stderrReader)
|
|
}()
|
|
err = ecmd.Start()
|
|
if err != nil {
|
|
msh.Status = StatusError
|
|
msh.Err = fmt.Errorf("starting mshell server: %w", err)
|
|
return
|
|
}
|
|
fmt.Printf("Started remote '%s' pid=%d\n", msh.Remote.RemoteName, msh.Cmd.Process.Pid)
|
|
msh.Status = StatusConnecting
|
|
msh.Output = packet.MakePacketParser(stdoutReader)
|
|
msh.Input = packet.MakePacketSender(inputWriter)
|
|
msh.RpcMap = make(map[string]*RpcEntry)
|
|
msh.DoneCh = make(chan bool)
|
|
go func() {
|
|
exitErr := ecmd.Wait()
|
|
exitCode := shexec.GetExitCode(exitErr)
|
|
msh.WithLock(func() {
|
|
if msh.Status == StatusConnected || msh.Status == StatusConnecting {
|
|
msh.Status = StatusDisconnected
|
|
}
|
|
})
|
|
fmt.Printf("[error] RUNNER PROC EXITED code[%d]\n", exitCode)
|
|
close(msh.DoneCh)
|
|
}()
|
|
go msh.ProcessPackets()
|
|
return
|
|
}
|
|
|
|
func (msh *MShellProc) IsConnected() bool {
|
|
msh.Lock.Lock()
|
|
defer msh.Lock.Unlock()
|
|
return msh.Status == StatusConnected
|
|
}
|
|
|
|
func (runner *MShellProc) PacketRpc(pk packet.RpcPacketType, timeout time.Duration) (packet.RpcPacketType, error) {
|
|
if !runner.IsConnected() {
|
|
return nil, fmt.Errorf("runner is not connected")
|
|
}
|
|
if pk == nil {
|
|
return nil, fmt.Errorf("PacketRpc passed nil packet")
|
|
}
|
|
id := pk.GetPacketId()
|
|
respCh := make(chan packet.RpcPacketType)
|
|
runner.WithLock(func() {
|
|
runner.RpcMap[id] = &RpcEntry{PacketId: id, RespCh: respCh}
|
|
})
|
|
defer runner.WithLock(func() {
|
|
delete(runner.RpcMap, id)
|
|
})
|
|
runner.Input.SendPacket(pk)
|
|
timer := time.NewTimer(timeout)
|
|
defer timer.Stop()
|
|
select {
|
|
case rtnPk := <-respCh:
|
|
return rtnPk, nil
|
|
|
|
case <-timer.C:
|
|
return nil, fmt.Errorf("PacketRpc timeout")
|
|
}
|
|
}
|
|
|
|
func (runner *MShellProc) WithLock(fn func()) {
|
|
runner.Lock.Lock()
|
|
defer runner.Lock.Unlock()
|
|
fn()
|
|
}
|
|
|
|
func (runner *MShellProc) ProcessPackets() {
|
|
defer runner.WithLock(func() {
|
|
if runner.Status == StatusConnected || runner.Status == StatusConnecting {
|
|
runner.Status = StatusDisconnected
|
|
}
|
|
})
|
|
for pk := range runner.Output.MainCh {
|
|
if rpcPk, ok := pk.(packet.RpcPacketType); ok {
|
|
rpcId := rpcPk.GetPacketId()
|
|
runner.WithLock(func() {
|
|
entry := runner.RpcMap[rpcId]
|
|
if entry == nil {
|
|
return
|
|
}
|
|
delete(runner.RpcMap, rpcId)
|
|
go func() {
|
|
entry.RespCh <- rpcPk
|
|
close(entry.RespCh)
|
|
}()
|
|
})
|
|
}
|
|
if pk.GetType() == packet.CmdDataPacketStr {
|
|
dataPacket := pk.(*packet.CmdDataPacketType)
|
|
fmt.Printf("cmd-data %s pty=%d run=%d\n", dataPacket.CK, len(dataPacket.PtyData), len(dataPacket.RunData))
|
|
continue
|
|
}
|
|
if pk.GetType() == packet.InitPacketStr {
|
|
initPacket := pk.(*packet.InitPacketType)
|
|
fmt.Printf("runner-init %s user=%s dir=%s\n", initPacket.MShellHomeDir, initPacket.User, initPacket.HomeDir)
|
|
runner.WithLock(func() {
|
|
runner.InitPk = initPacket
|
|
runner.Status = StatusConnected
|
|
})
|
|
continue
|
|
}
|
|
if pk.GetType() == packet.MessagePacketStr {
|
|
msgPacket := pk.(*packet.MessagePacketType)
|
|
fmt.Printf("# %s\n", msgPacket.Message)
|
|
continue
|
|
}
|
|
if pk.GetType() == packet.RawPacketStr {
|
|
rawPacket := pk.(*packet.RawPacketType)
|
|
fmt.Printf("stderr> %s\n", rawPacket.Data)
|
|
continue
|
|
}
|
|
fmt.Printf("runner-packet: %v\n", pk)
|
|
}
|
|
}
|