mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-02-22 02:41:23 +01:00
tightening up server mode, fix bugs, refactor, etc.
This commit is contained in:
parent
b6711e7428
commit
0a828b7184
@ -10,9 +10,7 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/scripthaus-dev/mshell/pkg/base"
|
||||
"github.com/scripthaus-dev/mshell/pkg/cmdtail"
|
||||
@ -22,43 +20,6 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func doSingle(ck base.CommandKey) {
|
||||
packetParser := packet.MakePacketParser(os.Stdin)
|
||||
sender := packet.MakePacketSender(os.Stdout)
|
||||
var runPacket *packet.RunPacketType
|
||||
for pk := range packetParser.MainCh {
|
||||
if pk.GetType() == packet.RunPacketStr {
|
||||
runPacket, _ = pk.(*packet.RunPacketType)
|
||||
break
|
||||
}
|
||||
sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", pk.GetType()))
|
||||
return
|
||||
}
|
||||
if runPacket == nil {
|
||||
sender.SendErrorPacket("did not receive a 'run' packet")
|
||||
return
|
||||
}
|
||||
if runPacket.CK.IsEmpty() {
|
||||
runPacket.CK = ck
|
||||
}
|
||||
if runPacket.CK != ck {
|
||||
sender.SendErrorPacket(fmt.Sprintf("run packet cmdid[%s] did not match arg[%s]", runPacket.CK, ck))
|
||||
return
|
||||
}
|
||||
cmd, err := shexec.RunCommandDetached(runPacket, sender)
|
||||
if err != nil {
|
||||
sender.SendErrorPacket(fmt.Sprintf("error running command: %v", err))
|
||||
return
|
||||
}
|
||||
shexec.SetupSignalsForDetach()
|
||||
startPacket := cmd.MakeCmdStartPacket()
|
||||
sender.SendPacket(startPacket)
|
||||
donePacket := cmd.WaitForCommand()
|
||||
sender.SendPacket(donePacket)
|
||||
sender.Close()
|
||||
sender.WaitForDone()
|
||||
}
|
||||
|
||||
func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) {
|
||||
err := shexec.ValidateRunPacket(pk)
|
||||
if err != nil {
|
||||
@ -122,13 +83,8 @@ func doGetCmd(tailer *cmdtail.Tailer, pk *packet.GetCmdPacketType, sender *packe
|
||||
}
|
||||
|
||||
func doMain() {
|
||||
scHomeDir, err := base.GetScHomeDir()
|
||||
if err != nil {
|
||||
packet.SendErrorPacket(os.Stdout, err.Error())
|
||||
return
|
||||
}
|
||||
homeDir := base.GetHomeDir()
|
||||
err = os.Chdir(homeDir)
|
||||
err := os.Chdir(homeDir)
|
||||
if err != nil {
|
||||
packet.SendErrorPacket(os.Stdout, fmt.Sprintf("cannot change directory to $HOME '%s': %v", homeDir, err))
|
||||
return
|
||||
@ -146,13 +102,7 @@ func doMain() {
|
||||
return
|
||||
}
|
||||
go tailer.Run()
|
||||
initPacket := packet.MakeInitPacket()
|
||||
initPacket.Env = os.Environ()
|
||||
initPacket.HomeDir = homeDir
|
||||
initPacket.ScHomeDir = scHomeDir
|
||||
if user, _ := user.Current(); user != nil {
|
||||
initPacket.User = user.Username
|
||||
}
|
||||
initPacket := shexec.MakeInitPacket()
|
||||
sender.SendPacket(initPacket)
|
||||
for pk := range packetParser.MainCh {
|
||||
if pk.GetType() == packet.RunPacketStr {
|
||||
@ -208,19 +158,14 @@ func handleSingle() {
|
||||
packetParser := packet.MakePacketParser(os.Stdin)
|
||||
sender := packet.MakePacketSender(os.Stdout)
|
||||
defer func() {
|
||||
// wait for sender to complete
|
||||
sender.Close()
|
||||
sender.WaitForDone()
|
||||
}()
|
||||
initPacket := shexec.MakeInitPacket()
|
||||
sender.SendPacket(initPacket)
|
||||
if len(os.Args) >= 3 && os.Args[2] == "--version" {
|
||||
initPacket := packet.MakeInitPacket()
|
||||
initPacket.Version = base.MShellVersion
|
||||
sender.SendPacket(initPacket)
|
||||
return
|
||||
}
|
||||
initPacket := packet.MakeInitPacket()
|
||||
initPacket.Version = base.MShellVersion
|
||||
sender.SendPacket(initPacket)
|
||||
runPacket, err := readFullRunPacket(packetParser)
|
||||
if err != nil {
|
||||
ck := base.CommandKey("")
|
||||
@ -236,12 +181,11 @@ func handleSingle() {
|
||||
return
|
||||
}
|
||||
if runPacket.Detached {
|
||||
cmd, err := shexec.RunCommandDetached(runPacket, sender)
|
||||
err := shexec.RunCommandDetached(runPacket, sender)
|
||||
if err != nil {
|
||||
sender.SendCKErrorPacket(runPacket.CK, err.Error())
|
||||
return
|
||||
}
|
||||
cmd.WaitForCommand()
|
||||
} else {
|
||||
cmd, err := shexec.RunCommandSimple(runPacket, sender)
|
||||
if err != nil {
|
||||
@ -560,17 +504,4 @@ func main() {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if len(os.Args) >= 2 {
|
||||
ck := base.CommandKey(os.Args[1])
|
||||
if err := ck.Validate("mshell arg"); err != nil {
|
||||
packet.SendErrorPacket(os.Stdout, err.Error())
|
||||
return
|
||||
}
|
||||
doSingle(ck)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return
|
||||
} else {
|
||||
doMain()
|
||||
}
|
||||
}
|
||||
|
119
pkg/base/base.go
119
pkg/base/base.go
@ -9,6 +9,7 @@ package base
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"os/exec"
|
||||
@ -19,20 +20,15 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const DefaultMShellPath = "mshell"
|
||||
const DefaultUserMShellPath = ".mshell/mshell"
|
||||
const MShellPathVarName = "MSHELL_PATH"
|
||||
const SSHCommandVarName = "SSH_COMMAND"
|
||||
const ScHomeVarName = "SCRIPTHAUS_HOME"
|
||||
const HomeVarName = "HOME"
|
||||
const ScShell = "bash"
|
||||
const SessionsDirBaseName = ".sessions"
|
||||
const RunnerBaseName = "runner"
|
||||
const SessionDBName = "session.db"
|
||||
const ScReadyString = "scripthaus runner ready"
|
||||
const DefaultMShellHome = "~/.mshell"
|
||||
const DefaultMShellName = "mshell"
|
||||
const MShellPathVarName = "MSHELL_PATH"
|
||||
const MShellHomeVarName = "MSHELL_HOME"
|
||||
const SSHCommandVarName = "SSH_COMMAND"
|
||||
const SessionsDirBaseName = "sessions"
|
||||
const MShellVersion = "0.1.0"
|
||||
|
||||
const OSCEscError = "error"
|
||||
const RemoteIdFile = "remoteid"
|
||||
|
||||
type CommandFileNames struct {
|
||||
PtyOutFile string
|
||||
@ -110,16 +106,12 @@ func GetHomeDir() string {
|
||||
return homeVar
|
||||
}
|
||||
|
||||
func GetScHomeDir() (string, error) {
|
||||
scHome := os.Getenv(ScHomeVarName)
|
||||
if scHome == "" {
|
||||
homeVar := os.Getenv(HomeVarName)
|
||||
if homeVar == "" {
|
||||
return "", fmt.Errorf("Cannot resolve scripthaus home directory (SCRIPTHAUS_HOME and HOME not set)")
|
||||
}
|
||||
scHome = path.Join(homeVar, "scripthaus")
|
||||
func GetMShellHomeDir() string {
|
||||
homeVar := os.Getenv(MShellHomeVarName)
|
||||
if homeVar != "" {
|
||||
return homeVar
|
||||
}
|
||||
return scHome, nil
|
||||
return ExpandHomeDir(DefaultMShellHome)
|
||||
}
|
||||
|
||||
func GetCommandFileNames(ck CommandKey) (*CommandFileNames, error) {
|
||||
@ -139,8 +131,8 @@ func GetCommandFileNames(ck CommandKey) (*CommandFileNames, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func MakeCommandFileNamesWithHome(scHome string, ck CommandKey) *CommandFileNames {
|
||||
base := path.Join(scHome, SessionsDirBaseName, ck.GetSessionId(), ck.GetCmdId())
|
||||
func MakeCommandFileNamesWithHome(mhome string, ck CommandKey) *CommandFileNames {
|
||||
base := path.Join(mhome, SessionsDirBaseName, ck.GetSessionId(), ck.GetCmdId())
|
||||
return &CommandFileNames{
|
||||
PtyOutFile: base + ".ptyout",
|
||||
StdinFifo: base + ".stdin",
|
||||
@ -174,11 +166,8 @@ func EnsureSessionDir(sessionId string) (string, error) {
|
||||
if sessionId == "" {
|
||||
return "", fmt.Errorf("Bad sessionid, cannot be empty")
|
||||
}
|
||||
shhome, err := GetScHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sdir := path.Join(shhome, SessionsDirBaseName, sessionId)
|
||||
mhome := GetMShellHomeDir()
|
||||
sdir := path.Join(mhome, SessionsDirBaseName, sessionId)
|
||||
info, err := os.Stat(sdir)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
err = os.MkdirAll(sdir, 0777)
|
||||
@ -197,51 +186,22 @@ func EnsureSessionDir(sessionId string) (string, error) {
|
||||
}
|
||||
|
||||
func GetMShellPath() (string, error) {
|
||||
msPath := os.Getenv(MShellPathVarName)
|
||||
msPath := os.Getenv(MShellPathVarName) // use MSHELL_PATH
|
||||
if msPath != "" {
|
||||
return exec.LookPath(msPath)
|
||||
}
|
||||
userMShellPath := path.Join(GetHomeDir(), DefaultUserMShellPath)
|
||||
mhome := GetMShellHomeDir()
|
||||
userMShellPath := path.Join(mhome, DefaultMShellName) // look in ~/.mshell
|
||||
msPath, err := exec.LookPath(userMShellPath)
|
||||
if err != nil {
|
||||
if err == nil {
|
||||
return msPath, nil
|
||||
}
|
||||
return exec.LookPath(DefaultMShellPath)
|
||||
return exec.LookPath(DefaultMShellName) // standard path lookup for 'mshell'
|
||||
}
|
||||
|
||||
func GetScSessionsDir() (string, error) {
|
||||
scHome, err := GetScHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return path.Join(scHome, SessionsDirBaseName), nil
|
||||
}
|
||||
|
||||
func GetSessionDBName(sessionId string) (string, error) {
|
||||
scHome, err := GetScHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return path.Join(scHome, SessionDBName), nil
|
||||
}
|
||||
|
||||
// SH OSC Escapes (code 198, S=19, H=8)
|
||||
// \e]198;cmdid;(cmd-id)BEL - return command-id to server
|
||||
// \e]198;remote;0BEL - runner program not available
|
||||
// \e]198;remote;1BEL - runner program is available
|
||||
// \e]198;error;(error-str)BEL - communicate an internal error
|
||||
func MakeSHOSCEsc(escName string, data string) string {
|
||||
return fmt.Sprintf("\033]198;%s;%s\007", escName, data)
|
||||
}
|
||||
|
||||
func WriteErrorMsg(fileName string, errVal string) error {
|
||||
fd, err := os.OpenFile(fileName, os.O_APPEND|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
oscEsc := MakeSHOSCEsc(OSCEscError, errVal)
|
||||
_, writeErr := fd.Write([]byte(oscEsc))
|
||||
return writeErr
|
||||
func GetMShellSessionsDir() (string, error) {
|
||||
mhome := GetMShellHomeDir()
|
||||
return path.Join(mhome, SessionsDirBaseName), nil
|
||||
}
|
||||
|
||||
func ExpandHomeDir(pathStr string) string {
|
||||
@ -262,3 +222,32 @@ func ValidGoArch(goos string, goarch string) bool {
|
||||
func GoArchOptFile(goos string, goarch string) string {
|
||||
return fmt.Sprintf("/opt/mshell/bin/mshell.%s.%s", goos, goarch)
|
||||
}
|
||||
|
||||
func GetRemoteId() (string, error) {
|
||||
mhome := GetMShellHomeDir()
|
||||
remoteIdFile := path.Join(mhome, RemoteIdFile)
|
||||
fd, err := os.Open(remoteIdFile)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
// write the file
|
||||
remoteId := uuid.New().String()
|
||||
err = os.WriteFile(remoteIdFile, []byte(remoteId), 0644)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot write remoteid to '%s': %w", remoteIdFile, err)
|
||||
}
|
||||
return remoteId, nil
|
||||
} else if err != nil {
|
||||
return "", fmt.Errorf("cannot read remoteid file '%s': %w", remoteIdFile, err)
|
||||
} else {
|
||||
defer fd.Close()
|
||||
contents, err := io.ReadAll(fd)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot read remoteid file '%s': %w", remoteIdFile, err)
|
||||
}
|
||||
uuidStr := string(contents)
|
||||
_, err = uuid.Parse(uuidStr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid uuid read from '%s': %w", remoteIdFile, err)
|
||||
}
|
||||
return uuidStr, nil
|
||||
}
|
||||
}
|
||||
|
@ -75,7 +75,7 @@ func (pos TailPos) IsCurrent(entry CmdWatchEntry) bool {
|
||||
type Tailer struct {
|
||||
Lock *sync.Mutex
|
||||
WatchList map[base.CommandKey]CmdWatchEntry
|
||||
ScHomeDir string
|
||||
MHomeDir string
|
||||
Watcher *fsnotify.Watcher
|
||||
Sender *packet.PacketSender
|
||||
}
|
||||
@ -101,7 +101,7 @@ func (t *Tailer) removeTailPos_nolock(cmdKey base.CommandKey, reqId string) {
|
||||
}
|
||||
|
||||
// delete from watchlist, remove watches
|
||||
fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, cmdKey)
|
||||
fileNames := base.MakeCommandFileNamesWithHome(t.MHomeDir, cmdKey)
|
||||
delete(t.WatchList, cmdKey)
|
||||
t.Watcher.Remove(fileNames.PtyOutFile)
|
||||
t.Watcher.Remove(fileNames.RunnerOutFile)
|
||||
@ -130,16 +130,14 @@ func (t *Tailer) getEntryAndPos_nolock(cmdKey base.CommandKey, reqId string) (Cm
|
||||
}
|
||||
|
||||
func MakeTailer(sender *packet.PacketSender) (*Tailer, error) {
|
||||
scHomeDir, err := base.GetScHomeDir()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mhomeDir := base.GetMShellHomeDir()
|
||||
rtn := &Tailer{
|
||||
Lock: &sync.Mutex{},
|
||||
WatchList: make(map[base.CommandKey]CmdWatchEntry),
|
||||
ScHomeDir: scHomeDir,
|
||||
MHomeDir: mhomeDir,
|
||||
Sender: sender,
|
||||
}
|
||||
var err error
|
||||
rtn.Watcher, err = fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -196,7 +194,7 @@ func (t *Tailer) runSingleDataTransfer(key base.CommandKey, reqId string) (*pack
|
||||
if !foundPos {
|
||||
return nil, false
|
||||
}
|
||||
fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, key)
|
||||
fileNames := base.MakeCommandFileNamesWithHome(t.MHomeDir, key)
|
||||
dataPacket := t.makeCmdDataPacket(fileNames, entry, pos)
|
||||
|
||||
t.Lock.Lock()
|
||||
@ -353,7 +351,7 @@ func (t *Tailer) AddWatch(getPacket *packet.GetCmdPacketType) error {
|
||||
if getPacket.ReqId == "" {
|
||||
return fmt.Errorf("getcmd, no reqid specified")
|
||||
}
|
||||
fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, getPacket.CK)
|
||||
fileNames := base.MakeCommandFileNamesWithHome(t.MHomeDir, getPacket.CK)
|
||||
t.Lock.Lock()
|
||||
defer t.Lock.Unlock()
|
||||
key := getPacket.CK
|
||||
@ -370,7 +368,7 @@ func (t *Tailer) AddWatch(getPacket *packet.GetCmdPacketType) error {
|
||||
return err
|
||||
}
|
||||
entry = CmdWatchEntry{CmdKey: key}
|
||||
entry.fillFilePos(t.ScHomeDir)
|
||||
entry.fillFilePos(t.MHomeDir)
|
||||
}
|
||||
pos, foundPos := entry.getTailPos(getPacket.ReqId)
|
||||
if !foundPos {
|
||||
|
@ -200,7 +200,7 @@ func (m *Multiplexer) runPacketInputLoop() *packet.CmdDonePacketType {
|
||||
defer m.HandleInputDone()
|
||||
for pk := range m.Input.MainCh {
|
||||
if m.Debug {
|
||||
fmt.Printf("PK> %s\n", packet.AsString(pk))
|
||||
fmt.Printf("PK-M> %s\n", packet.AsString(pk))
|
||||
}
|
||||
if pk.GetType() == packet.DataPacketStr {
|
||||
dataPacket := pk.(*packet.DataPacketType)
|
||||
@ -220,10 +220,6 @@ func (m *Multiplexer) runPacketInputLoop() *packet.CmdDonePacketType {
|
||||
donePacket := pk.(*packet.CmdDonePacketType)
|
||||
return donePacket
|
||||
}
|
||||
if pk.GetType() == packet.CmdStartPacketStr {
|
||||
// nothing
|
||||
continue
|
||||
}
|
||||
m.UPR.UnknownPacket(pk)
|
||||
}
|
||||
return nil
|
||||
|
@ -73,6 +73,10 @@ func init() {
|
||||
TypeStrToFactory[DataEndPacketStr] = reflect.TypeOf(DataEndPacketType{})
|
||||
}
|
||||
|
||||
func RegisterPacketType(typeStr string, rtype reflect.Type) {
|
||||
TypeStrToFactory[typeStr] = rtype
|
||||
}
|
||||
|
||||
func MakePacket(packetType string) (PacketType, error) {
|
||||
rtype := TypeStrToFactory[packetType]
|
||||
if rtype == nil {
|
||||
@ -355,14 +359,15 @@ func FmtMessagePacket(fmtStr string, args ...interface{}) *MessagePacketType {
|
||||
}
|
||||
|
||||
type InitPacketType struct {
|
||||
Type string `json:"type"`
|
||||
Version string `json:"version"`
|
||||
ScHomeDir string `json:"schomedir,omitempty"`
|
||||
HomeDir string `json:"homedir,omitempty"`
|
||||
Env []string `json:"env,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
NotFound bool `json:"notfound,omitempty"`
|
||||
UName string `json:"uname,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Version string `json:"version"`
|
||||
MShellHomeDir string `json:"mshellhomedir,omitempty"`
|
||||
HomeDir string `json:"homedir,omitempty"`
|
||||
Env []string `json:"env,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
NotFound bool `json:"notfound,omitempty"`
|
||||
UName string `json:"uname,omitempty"`
|
||||
RemoteId string `json:"remoteid,omitempty"`
|
||||
}
|
||||
|
||||
func (*InitPacketType) GetType() string {
|
||||
@ -615,6 +620,22 @@ func MakePacketSender(output io.Writer) *PacketSender {
|
||||
return sender
|
||||
}
|
||||
|
||||
func MakeChannelPacketSender(packetCh chan PacketType) *PacketSender {
|
||||
sender := &PacketSender{
|
||||
Lock: &sync.Mutex{},
|
||||
SendCh: make(chan PacketType, PacketSenderQueueSize),
|
||||
DoneCh: make(chan bool),
|
||||
}
|
||||
go func() {
|
||||
defer close(sender.DoneCh)
|
||||
defer sender.Close()
|
||||
for pk := range sender.SendCh {
|
||||
packetCh <- pk
|
||||
}
|
||||
}()
|
||||
return sender
|
||||
}
|
||||
|
||||
func (sender *PacketSender) Close() {
|
||||
sender.Lock.Lock()
|
||||
defer sender.Lock.Unlock()
|
||||
@ -676,6 +697,8 @@ func (DefaultUPR) UnknownPacket(pk PacketType) {
|
||||
} else if pk.GetType() == RawPacketStr {
|
||||
rawPacket := pk.(*RawPacketType)
|
||||
fmt.Fprintf(os.Stderr, "%s\n", rawPacket.Data)
|
||||
} else if pk.GetType() == CmdStartPacketStr {
|
||||
return // do nothing
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "[error] invalid packet received '%s'", AsExtType(pk))
|
||||
}
|
||||
|
@ -150,8 +150,11 @@ func RunServer() (int, error) {
|
||||
server.MainInput = packet.MakePacketParser(os.Stdin)
|
||||
server.Sender = packet.MakePacketSender(os.Stdout)
|
||||
defer server.Close()
|
||||
initPacket := packet.MakeInitPacket()
|
||||
initPacket.Version = base.MShellVersion
|
||||
var err error
|
||||
initPacket, err := shexec.MakeServerInitPacket()
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
server.Sender.SendPacket(initPacket)
|
||||
builder := packet.MakeRunPacketBuilder()
|
||||
for pk := range server.MainInput.MainCh {
|
||||
@ -159,6 +162,9 @@ func RunServer() (int, error) {
|
||||
fmt.Printf("PK> %s\n", packet.AsString(pk))
|
||||
}
|
||||
ok, runPacket := builder.ProcessPacket(pk)
|
||||
if server.Debug {
|
||||
fmt.Printf("PP> %s | %v\n", pk.GetType(), ok)
|
||||
}
|
||||
if ok {
|
||||
if runPacket != nil {
|
||||
server.runCommand(runPacket)
|
||||
@ -166,6 +172,13 @@ func RunServer() (int, error) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
if startPk, ok := pk.(*packet.CmdStartPacketType); ok {
|
||||
if server.Debug {
|
||||
fmt.Printf("START> %v", startPk)
|
||||
}
|
||||
server.Sender.SendPacket(startPk)
|
||||
continue
|
||||
}
|
||||
if cmdPk, ok := pk.(packet.CommandPacketType); ok {
|
||||
server.ProcessCommandPacket(cmdPk)
|
||||
continue
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"os/user"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
@ -23,6 +24,7 @@ import (
|
||||
"github.com/scripthaus-dev/mshell/pkg/base"
|
||||
"github.com/scripthaus-dev/mshell/pkg/mpio"
|
||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const DefaultRows = 25
|
||||
@ -57,13 +59,16 @@ const RunSudoCommandFmt = `sudo -n -C %d bash /dev/fd/%d`
|
||||
const RunSudoPasswordCommandFmt = `cat /dev/fd/%d | sudo -k -S -C %d bash -c "echo '[from-mshell]'; exec %d>&-; bash /dev/fd/%d < /dev/fd/%d"`
|
||||
|
||||
type ShExecType struct {
|
||||
Lock *sync.Mutex
|
||||
StartTs time.Time
|
||||
CK base.CommandKey
|
||||
FileNames *base.CommandFileNames
|
||||
Cmd *exec.Cmd
|
||||
CmdPty *os.File
|
||||
Multiplexer *mpio.Multiplexer
|
||||
Lock *sync.Mutex
|
||||
StartTs time.Time
|
||||
CK base.CommandKey
|
||||
FileNames *base.CommandFileNames
|
||||
Cmd *exec.Cmd
|
||||
CmdPty *os.File
|
||||
Multiplexer *mpio.Multiplexer
|
||||
Detached bool
|
||||
DetachedOutput *packet.PacketSender
|
||||
RunnerOutFd *os.File
|
||||
}
|
||||
|
||||
type StdContext struct{}
|
||||
@ -115,6 +120,13 @@ func (c *ShExecType) Close() {
|
||||
c.CmdPty.Close()
|
||||
}
|
||||
c.Multiplexer.Close()
|
||||
if c.DetachedOutput != nil {
|
||||
c.DetachedOutput.Close()
|
||||
c.DetachedOutput.WaitForDone()
|
||||
}
|
||||
if c.RunnerOutFd != nil {
|
||||
c.RunnerOutFd.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ShExecType) MakeCmdStartPacket() *packet.CmdStartPacketType {
|
||||
@ -300,11 +312,13 @@ func ValidateRunPacket(pk *packet.RunPacketType) error {
|
||||
func GetWinsize(p *packet.RunPacketType) *pty.Winsize {
|
||||
rows := DefaultRows
|
||||
cols := DefaultCols
|
||||
if p.TermSize.Rows > 0 && p.TermSize.Rows <= MaxRows {
|
||||
rows = p.TermSize.Rows
|
||||
}
|
||||
if p.TermSize.Cols > 0 && p.TermSize.Cols <= MaxCols {
|
||||
cols = p.TermSize.Cols
|
||||
if p.TermSize != nil {
|
||||
if p.TermSize.Rows > 0 && p.TermSize.Rows <= MaxRows {
|
||||
rows = p.TermSize.Rows
|
||||
}
|
||||
if p.TermSize.Cols > 0 && p.TermSize.Cols <= MaxCols {
|
||||
cols = p.TermSize.Cols
|
||||
}
|
||||
}
|
||||
return &pty.Winsize{Rows: uint16(rows), Cols: uint16(cols)}
|
||||
}
|
||||
@ -834,63 +848,98 @@ func SetupSignalsForDetach() {
|
||||
}()
|
||||
}
|
||||
|
||||
func RunCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecType, error) {
|
||||
func RunCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) error {
|
||||
fileNames, err := base.GetCommandFileNames(pk.CK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
ptyOutInfo, err := os.Stat(fileNames.PtyOutFile)
|
||||
if err == nil { // non-nil error will be caught by regular OpenFile below
|
||||
// must have size 0
|
||||
if ptyOutInfo.Size() != 0 {
|
||||
return nil, fmt.Errorf("cmdkey '%s' was already used (ptyout len=%d)", pk.CK, ptyOutInfo.Size())
|
||||
return fmt.Errorf("cmdkey '%s' was already used (ptyout len=%d)", pk.CK, ptyOutInfo.Size())
|
||||
}
|
||||
}
|
||||
cmdPty, cmdTty, err := pty.Open()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("opening new pty: %w", err)
|
||||
return fmt.Errorf("opening new pty: %w", err)
|
||||
}
|
||||
pty.Setsize(cmdPty, GetWinsize(pk))
|
||||
defer func() {
|
||||
cmdTty.Close()
|
||||
}()
|
||||
rtn := MakeShExec(pk.CK, nil)
|
||||
cmd := MakeShExec(pk.CK, nil)
|
||||
cmd.FileNames = fileNames
|
||||
cmd.CmdPty = cmdPty
|
||||
cmd.Detached = true
|
||||
cmd.RunnerOutFd, err = os.OpenFile(fileNames.RunnerOutFile, os.O_TRUNC|os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot open runout file '%s': %w", fileNames.RunnerOutFile, err)
|
||||
}
|
||||
nullFd, err := os.OpenFile("/dev/null", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot open /dev/null: %w", err)
|
||||
}
|
||||
cmd.DetachedOutput = packet.MakePacketSender(cmd.RunnerOutFd)
|
||||
ecmd, err := MakeDetachedExecCmd(pk, cmdTty)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
cmd.Cmd = ecmd
|
||||
SetupSignalsForDetach()
|
||||
err = ecmd.Start()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("starting command: %w", err)
|
||||
return fmt.Errorf("starting command: %w", err)
|
||||
}
|
||||
for _, fd := range ecmd.ExtraFiles {
|
||||
if fd != cmdTty {
|
||||
fd.Close()
|
||||
}
|
||||
}
|
||||
ptyOutFd, err := os.OpenFile(fileNames.PtyOutFile, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0600)
|
||||
// after Start(), any errors must go to DetachedOutput
|
||||
// close stdin/stdout/stderr, but wait for cmdstart packet to get sent
|
||||
startPacket := cmd.MakeCmdStartPacket()
|
||||
go func() {
|
||||
sender.SendPacket(startPacket)
|
||||
sender.Close()
|
||||
sender.WaitForDone()
|
||||
fmt.Printf("sender done! start: %v\n", startPacket)
|
||||
err = unix.Dup2(int(nullFd.Fd()), int(os.Stdin.Fd()))
|
||||
if err != nil {
|
||||
cmd.DetachedOutput.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot dup2 stdin to /dev/null: %w", err))
|
||||
}
|
||||
err = unix.Dup2(int(nullFd.Fd()), int(os.Stdout.Fd()))
|
||||
if err != nil {
|
||||
cmd.DetachedOutput.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot dup2 stdin to /dev/null: %w", err))
|
||||
}
|
||||
err = unix.Dup2(int(nullFd.Fd()), int(os.Stderr.Fd()))
|
||||
if err != nil {
|
||||
cmd.DetachedOutput.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot dup2 stdin to /dev/null: %w", err))
|
||||
}
|
||||
cmd.DetachedOutput.SendPacket(startPacket)
|
||||
}()
|
||||
ptyOutFd, err := os.OpenFile(fileNames.PtyOutFile, os.O_TRUNC|os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot open ptyout file '%s': %w", fileNames.PtyOutFile, err)
|
||||
cmd.DetachedOutput.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot open ptyout file '%s': %v", fileNames.PtyOutFile, err))
|
||||
// don't return (command is already running)
|
||||
}
|
||||
go func() {
|
||||
// copy pty output to .ptyout file
|
||||
_, copyErr := io.Copy(ptyOutFd, cmdPty)
|
||||
if copyErr != nil {
|
||||
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("copying pty output to ptyout file: %v", copyErr))
|
||||
cmd.DetachedOutput.SendCKErrorPacket(pk.CK, fmt.Sprintf("copying pty output to ptyout file: %v", copyErr))
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
// copy .stdin fifo contents to pty input
|
||||
copyFifoErr := MakeAndCopyStdinFifo(cmdPty, fileNames.StdinFifo)
|
||||
if copyFifoErr != nil {
|
||||
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("reading from stdin fifo: %v", copyFifoErr))
|
||||
cmd.DetachedOutput.SendCKErrorPacket(pk.CK, fmt.Sprintf("reading from stdin fifo: %v", copyFifoErr))
|
||||
}
|
||||
}()
|
||||
rtn.FileNames = fileNames
|
||||
rtn.Cmd = ecmd
|
||||
rtn.CmdPty = cmdPty
|
||||
return rtn, nil
|
||||
donePacket := cmd.WaitForCommand()
|
||||
cmd.DetachedOutput.SendPacket(donePacket)
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetExitCode(err error) int {
|
||||
@ -919,3 +968,25 @@ func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType {
|
||||
}
|
||||
return donePacket
|
||||
}
|
||||
|
||||
func MakeInitPacket() *packet.InitPacketType {
|
||||
initPacket := packet.MakeInitPacket()
|
||||
initPacket.Version = base.MShellVersion
|
||||
initPacket.HomeDir = base.GetHomeDir()
|
||||
initPacket.MShellHomeDir = base.GetMShellHomeDir()
|
||||
if user, _ := user.Current(); user != nil {
|
||||
initPacket.User = user.Username
|
||||
}
|
||||
return initPacket
|
||||
}
|
||||
|
||||
func MakeServerInitPacket() (*packet.InitPacketType, error) {
|
||||
var err error
|
||||
initPacket := MakeInitPacket()
|
||||
initPacket.Env = os.Environ()
|
||||
initPacket.RemoteId, err = base.GetRemoteId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return initPacket, nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user