diff --git a/main-mshell.go b/main-mshell.go index 0cc8c493d..65d7a8c9f 100644 --- a/main-mshell.go +++ b/main-mshell.go @@ -425,6 +425,19 @@ func handleInstall() (int, error) { return 0, nil } +func handleEnv() (int, error) { + cwd, err := os.Getwd() + if err != nil { + return 1, err + } + fmt.Printf("%s\x00", cwd) + fullEnv := os.Environ() + for _, envLine := range fullEnv { + fmt.Printf("%s\x00", envLine) + } + return 0, nil +} + func handleUsage() { usage := ` Client Usage: mshell [opts] --ssh user@host -- [command] @@ -482,6 +495,14 @@ func main() { } else if firstArg == "--version" { fmt.Printf("mshell v%s\n", base.MShellVersion) return + } else if firstArg == "--env" { + rtnCode, err := handleEnv() + if err != nil { + fmt.Fprintf(os.Stderr, "[error] %v\n", err) + } + if rtnCode != 0 { + os.Exit(rtnCode) + } } else if firstArg == "--single" { handleSingle() return diff --git a/pkg/packet/packet.go b/pkg/packet/packet.go index 35a8d86d1..380fce05b 100644 --- a/pkg/packet/packet.go +++ b/pkg/packet/packet.go @@ -433,16 +433,17 @@ func FmtMessagePacket(fmtStr string, args ...interface{}) *MessagePacketType { } type InitPacketType struct { - Type string `json:"type"` - Version string `json:"version"` - MShellHomeDir string `json:"mshellhomedir,omitempty"` - HomeDir string `json:"homedir,omitempty"` - Env []string `json:"env,omitempty"` - User string `json:"user,omitempty"` - HostName string `json:"hostname,omitempty"` - NotFound bool `json:"notfound,omitempty"` - UName string `json:"uname,omitempty"` - RemoteId string `json:"remoteid,omitempty"` + Type string `json:"type"` + Version string `json:"version"` + MShellHomeDir string `json:"mshellhomedir,omitempty"` + HomeDir string `json:"homedir,omitempty"` + Cwd string `json:"cwd,omitempty"` + Env []byte `json:"env,omitempty"` // "env -0" format + User string `json:"user,omitempty"` + HostName string `json:"hostname,omitempty"` + NotFound bool `json:"notfound,omitempty"` + UName string `json:"uname,omitempty"` + RemoteId string `json:"remoteid,omitempty"` } func (*InitPacketType) GetType() string { diff --git a/pkg/shexec/shexec.go b/pkg/shexec/shexec.go index 5dae3e403..5cb7fb688 100644 --- a/pkg/shexec/shexec.go +++ b/pkg/shexec/shexec.go @@ -7,8 +7,10 @@ package shexec import ( + "bytes" "context" "encoding/base64" + "errors" "fmt" "io" "os" @@ -38,6 +40,8 @@ const FirstExtraFilesFdNum = 3 const DefaultTermType = "xterm-256color" const DefaultMaxPtySize = 1024 * 1024 +const GetStateTimeout = 5 * time.Second + const ClientCommand = ` PATH=$PATH:~/.mshell; which mshell > /dev/null; @@ -1095,10 +1099,70 @@ func MakeInitPacket() *packet.InitPacketType { func MakeServerInitPacket() (*packet.InitPacketType, error) { var err error initPacket := MakeInitPacket() - initPacket.Env = os.Environ() + cwd, env, err := GetCurrentState() + if err != nil { + return nil, err + } + initPacket.Cwd = cwd + initPacket.Env = env initPacket.RemoteId, err = base.GetRemoteId() if err != nil { return nil, err } return initPacket, nil } + +func parseEnv(env []byte) map[string]string { + envLines := bytes.Split(env, []byte{0}) + rtn := make(map[string]string) + for _, envLine := range envLines { + if len(envLine) == 0 { + continue + } + eqIdx := bytes.Index(envLine, []byte{'='}) + if eqIdx == -1 { + continue + } + varName := string(envLine[0:eqIdx]) + varVal := string(envLine[eqIdx+1:]) + rtn[varName] = varVal + } + return rtn +} + +func getStderr(err error) string { + exitErr, ok := err.(*exec.ExitError) + if !ok { + return "" + } + if len(exitErr.Stderr) == 0 { + return "" + } + lines := strings.SplitN(string(exitErr.Stderr), "\n", 2) + if len(lines[0]) > 100 { + return lines[0][0:100] + } + return lines[0] +} + +func GetCurrentState() (string, []byte, error) { + execFile, err := os.Executable() + if err != nil { + return "", nil, fmt.Errorf("cannot find local mshell executable: %w", err) + } + ctx, _ := context.WithTimeout(context.Background(), GetStateTimeout) + ecmd := exec.CommandContext(ctx, "bash", "-l", "-c", fmt.Sprintf("%s --env", shellescape.Quote(execFile))) + outputBytes, err := ecmd.Output() + if err != nil { + errMsg := getStderr(err) + if errMsg != "" { + return "", nil, errors.New(errMsg) + } + return "", nil, err + } + idx := bytes.Index(outputBytes, []byte{0}) + if idx == -1 { + return "", nil, fmt.Errorf("invalid current state output no NUL byte separator") + } + return string(outputBytes[0:idx]), outputBytes[idx+1:], nil +}