mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-01 18:28:59 +01:00
161 lines
3.8 KiB
Go
161 lines
3.8 KiB
Go
|
package remote
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"os/exec"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/scripthaus-dev/mshell/pkg/base"
|
||
|
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||
|
"github.com/scripthaus-dev/mshell/pkg/shexec"
|
||
|
)
|
||
|
|
||
|
// RemoteTypeMShell is the remote-type identifier string for mshell remotes.
// NOTE(review): it is not referenced elsewhere in this file; presumably it is
// compared against a stored remote type by callers -- confirm.
const RemoteTypeMShell = "mshell"
|
||
|
|
||
|
type MShellProc struct {
|
||
|
Lock *sync.Mutex
|
||
|
RemoteId string
|
||
|
WindowId string
|
||
|
RemoteName string
|
||
|
Cmd *exec.Cmd
|
||
|
Input *packet.PacketSender
|
||
|
Output *packet.PacketParser
|
||
|
Local bool
|
||
|
DoneCh chan bool
|
||
|
CurDir string
|
||
|
HomeDir string
|
||
|
User string
|
||
|
Host string
|
||
|
Env []string
|
||
|
Connected bool
|
||
|
RpcMap map[string]*RpcEntry
|
||
|
}
|
||
|
|
||
|
type RpcEntry struct {
|
||
|
PacketId string
|
||
|
RespCh chan packet.RpcPacketType
|
||
|
}
|
||
|
|
||
|
func LaunchMShell() (*MShellProc, error) {
|
||
|
msPath, err := base.GetMShellPath()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
ecmd := exec.Command(msPath)
|
||
|
inputWriter, err := ecmd.StdinPipe()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
outputReader, err := ecmd.StdoutPipe()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
ecmd.Stderr = ecmd.Stdout
|
||
|
err = ecmd.Start()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
rtn := &MShellProc{Lock: &sync.Mutex{}, Local: true, Cmd: ecmd}
|
||
|
rtn.Output = packet.MakePacketParser(outputReader)
|
||
|
rtn.Input = packet.MakePacketSender(inputWriter)
|
||
|
rtn.RpcMap = make(map[string]*RpcEntry)
|
||
|
rtn.DoneCh = make(chan bool)
|
||
|
go func() {
|
||
|
exitErr := ecmd.Wait()
|
||
|
exitCode := shexec.GetExitCode(exitErr)
|
||
|
fmt.Printf("[error] RUNNER PROC EXITED code[%d]\n", exitCode)
|
||
|
close(rtn.DoneCh)
|
||
|
}()
|
||
|
return rtn, nil
|
||
|
}
|
||
|
|
||
|
func (runner *MShellProc) PacketRpc(pk packet.RpcPacketType, timeout time.Duration) (packet.RpcPacketType, error) {
|
||
|
if pk == nil {
|
||
|
return nil, fmt.Errorf("PacketRpc passed nil packet")
|
||
|
}
|
||
|
id := pk.GetPacketId()
|
||
|
respCh := make(chan packet.RpcPacketType)
|
||
|
runner.Lock.Lock()
|
||
|
runner.RpcMap[id] = &RpcEntry{PacketId: id, RespCh: respCh}
|
||
|
runner.Lock.Unlock()
|
||
|
defer func() {
|
||
|
runner.Lock.Lock()
|
||
|
delete(runner.RpcMap, id)
|
||
|
runner.Lock.Unlock()
|
||
|
}()
|
||
|
runner.Input.SendPacket(pk)
|
||
|
timer := time.NewTimer(timeout)
|
||
|
defer timer.Stop()
|
||
|
select {
|
||
|
case rtnPk := <-respCh:
|
||
|
return rtnPk, nil
|
||
|
|
||
|
case <-timer.C:
|
||
|
return nil, fmt.Errorf("PacketRpc timeout")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (runner *MShellProc) ProcessPackets() {
|
||
|
for pk := range runner.Output.MainCh {
|
||
|
if rpcPk, ok := pk.(packet.RpcPacketType); ok {
|
||
|
rpcId := rpcPk.GetPacketId()
|
||
|
runner.Lock.Lock()
|
||
|
entry := runner.RpcMap[rpcId]
|
||
|
if entry != nil {
|
||
|
delete(runner.RpcMap, rpcId)
|
||
|
go func() {
|
||
|
entry.RespCh <- rpcPk
|
||
|
close(entry.RespCh)
|
||
|
}()
|
||
|
}
|
||
|
runner.Lock.Unlock()
|
||
|
|
||
|
}
|
||
|
if pk.GetType() == packet.CmdDataPacketStr {
|
||
|
dataPacket := pk.(*packet.CmdDataPacketType)
|
||
|
fmt.Printf("cmd-data %s pty=%d run=%d\n", dataPacket.CK, len(dataPacket.PtyData), len(dataPacket.RunData))
|
||
|
continue
|
||
|
}
|
||
|
if pk.GetType() == packet.InitPacketStr {
|
||
|
initPacket := pk.(*packet.InitPacketType)
|
||
|
fmt.Printf("runner-init %s user=%s dir=%s\n", initPacket.MShellHomeDir, initPacket.User, initPacket.HomeDir)
|
||
|
runner.Lock.Lock()
|
||
|
runner.Connected = true
|
||
|
runner.User = initPacket.User
|
||
|
runner.CurDir = initPacket.HomeDir
|
||
|
runner.HomeDir = initPacket.HomeDir
|
||
|
runner.Env = initPacket.Env
|
||
|
if runner.Local {
|
||
|
runner.Host = "local"
|
||
|
}
|
||
|
runner.Lock.Unlock()
|
||
|
continue
|
||
|
}
|
||
|
if pk.GetType() == packet.MessagePacketStr {
|
||
|
msgPacket := pk.(*packet.MessagePacketType)
|
||
|
fmt.Printf("# %s\n", msgPacket.Message)
|
||
|
continue
|
||
|
}
|
||
|
if pk.GetType() == packet.RawPacketStr {
|
||
|
rawPacket := pk.(*packet.RawPacketType)
|
||
|
fmt.Printf("stderr> %s\n", rawPacket.Data)
|
||
|
continue
|
||
|
}
|
||
|
fmt.Printf("runner-packet: %v\n", pk)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (r *MShellProc) GetPrompt() string {
|
||
|
r.Lock.Lock()
|
||
|
defer r.Lock.Unlock()
|
||
|
var curDir = r.CurDir
|
||
|
if r.CurDir == r.HomeDir {
|
||
|
curDir = "~"
|
||
|
} else if strings.HasPrefix(r.CurDir, r.HomeDir+"/") {
|
||
|
curDir = "~/" + r.CurDir[0:len(r.HomeDir)+1]
|
||
|
}
|
||
|
return fmt.Sprintf("[%s@%s %s]", r.User, r.Host, curDir)
|
||
|
}
|