diff --git a/main-mshell.go b/main-mshell.go index db6e34e58..103bf4511 100644 --- a/main-mshell.go +++ b/main-mshell.go @@ -296,6 +296,10 @@ func parseInstallOpts() (*shexec.InstallOpts, error) { if found { continue } + if argStr == "--detect" { + opts.Detect = true + continue + } if base.IsOption(argStr) { return nil, fmt.Errorf("invalid option '%s' passed to mshell --install", argStr) } @@ -402,6 +406,7 @@ func parseClientOpts() (*shexec.ClientOpts, error) { opts.Command = strings.Join(iter.Rest(), " ") break } + return nil, fmt.Errorf("invalid option '%s' passed to mshell", argStr) } return opts, nil } @@ -437,21 +442,29 @@ func handleInstall() (int, error) { if opts.SSHOpts.SSHHost == "" { return 1, fmt.Errorf("cannot install without '--ssh user@host' option") } - fullArch := opts.ArchStr - fields := strings.SplitN(fullArch, ".", 2) - if len(fields) != 2 { - return 1, fmt.Errorf("invalid arch format '%s' passed to mshell --install", fullArch) + if opts.Detect && opts.ArchStr != "" { + return 1, fmt.Errorf("cannot supply both --detect and arch '%s'", opts.ArchStr) } - goos, goarch := fields[0], fields[1] - if !base.ValidGoArch(goos, goarch) { - return 1, fmt.Errorf("invalid arch '%s' passed to mshell --install", fullArch) + if opts.ArchStr == "" && !opts.Detect { + return 1, fmt.Errorf("must supply an arch string or '--detect' to auto detect") } - optName := base.GoArchOptFile(goos, goarch) - _, err = os.Stat(optName) - if err != nil { - return 1, fmt.Errorf("cannot install mshell to remote host, cannot read '%s': %w", optName, err) + if opts.ArchStr != "" { + fullArch := opts.ArchStr + fields := strings.SplitN(fullArch, ".", 2) + if len(fields) != 2 { + return 1, fmt.Errorf("invalid arch format '%s' passed to mshell --install", fullArch) + } + goos, goarch := fields[0], fields[1] + if !base.ValidGoArch(goos, goarch) { + return 1, fmt.Errorf("invalid arch '%s' passed to mshell --install", fullArch) + } + optName := base.GoArchOptFile(goos, goarch) + _, err = os.Stat(optName) + if err != nil { + return 1, fmt.Errorf("cannot install mshell to remote host, cannot read '%s': %w", optName, err) + } + opts.OptName = optName } - opts.OptName = optName err = shexec.RunInstallSSHCommand(opts) if err != nil { return 1, err diff --git a/pkg/shexec/shexec.go b/pkg/shexec/shexec.go index 814cfac49..9a45077bc 100644 --- a/pkg/shexec/shexec.go +++ b/pkg/shexec/shexec.go @@ -42,6 +42,7 @@ fi ` const InstallCommand = ` +printf "\n##N{\"type\": \"init\", \"notfound\": true, \"uname\": \"%s | %s\"}\n" "$(uname -s)" "$(uname -m)"; mkdir -p ~/.mshell/; cat > ~/.mshell/mshell.temp; mv ~/.mshell/mshell.temp ~/.mshell/mshell; @@ -246,6 +247,7 @@ type InstallOpts struct { SSHOpts SharedSSHOpts ArchStr string OptName string + Detect bool } type ClientOpts struct { @@ -294,8 +296,8 @@ func (opts *InstallOpts) MakeExecCmd() *exec.Cmd { moreSSHOpts = append(moreSSHOpts, userOpt) } // note that SSHOptsStr is *not* escaped - installCommand := strings.TrimSpace(InstallCommand) - sshCmd := fmt.Sprintf("ssh %s %s %s %s", strings.Join(moreSSHOpts, " "), opts.SSHOpts.SSHOptsStr, shellescape.Quote(opts.SSHOpts.SSHHost), shellescape.Quote(installCommand)) + command := strings.TrimSpace(InstallCommand) + sshCmd := fmt.Sprintf("ssh %s %s %s %s", strings.Join(moreSSHOpts, " "), opts.SSHOpts.SSHOptsStr, shellescape.Quote(opts.SSHOpts.SSHHost), shellescape.Quote(command)) ecmd := exec.Command("bash", "-c", sshCmd) return ecmd } @@ -418,7 +420,20 @@ func ValidateRemoteFds(rfds []packet.RemoteFd) error { return nil } +func sendOptFile(input io.WriteCloser, optName string) error { + fd, err := os.Open(optName) + if err != nil { + return fmt.Errorf("cannot open '%s': %w", optName, err) + } + go func() { + defer input.Close() + io.Copy(input, fd) + }() + return nil +} + func RunInstallSSHCommand(opts *InstallOpts) error { + tryDetect := opts.Detect ecmd := opts.MakeExecCmd() inputWriter, err := ecmd.StdinPipe() if err != nil { @@ -435,21 +450,36 @@ func RunInstallSSHCommand(opts *InstallOpts) error { go func() { io.Copy(os.Stderr, stderrReader) }() - fd, err := os.Open(opts.OptName) - if err != nil { - return fmt.Errorf("cannot open '%s': %w", opts.OptName, err) + if opts.OptName != "" { + sendOptFile(inputWriter, opts.OptName) } - go func() { - defer inputWriter.Close() - io.Copy(inputWriter, fd) - }() packetParser := packet.MakePacketParser(stdoutReader) err = ecmd.Start() if err != nil { return fmt.Errorf("running ssh command: %w", err) } + firstInit := true for pk := range packetParser.MainCh { - if pk.GetType() == packet.InitPacketStr { + if pk.GetType() == packet.InitPacketStr && firstInit { + firstInit = false + initPacket := pk.(*packet.InitPacketType) + if !tryDetect { + continue // ignore + } + tryDetect = false + if initPacket.UName == "" { + return fmt.Errorf("cannot detect arch, no uname received from remote server") + } + goos, goarch, err := DetectGoArch(initPacket.UName) + if err != nil { + return fmt.Errorf("arch cannot be detected (might be incompatible with mshell): %w", err) + } + fmt.Printf("mshell detected remote architecture as '%s.%s'\n", goos, goarch) + optName := base.GoArchOptFile(goos, goarch) + sendOptFile(inputWriter, optName) + continue + } + if pk.GetType() == packet.InitPacketStr && !firstInit { initPacket := pk.(*packet.InitPacketType) if initPacket.Version == base.MShellVersion { fmt.Printf("mshell %s, installed successfully at %s:~/.mshell/mshell\n", initPacket.Version, opts.SSHOpts.SSHHost) @@ -537,16 +567,15 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er if pk.GetType() == packet.InitPacketStr { initPk := pk.(*packet.InitPacketType) if initPk.NotFound { - fmt.Printf("UNAME> %s\n", initPk.UName) if initPk.UName == "" { return nil, fmt.Errorf("mshell command not found on remote server, no uname detected") } - goos, goarch, err := UNameStringToGoArch(initPk.UName) + goos, goarch, err := DetectGoArch(initPk.UName) if err != nil { return nil, fmt.Errorf("mshell command not found on remote server, architecture cannot be detected (might be incompatible with mshell): %w", err) } installCmd := opts.MakeInstallCommandString(goos, goarch) - return nil, fmt.Errorf("mshell command not found on remote server, can install with '%s'", installCmd) + return nil, fmt.Errorf("mshell command not found on remote server, can install with '%s' (or --auto-install)", installCmd) } if initPk.Version != base.MShellVersion { return nil, fmt.Errorf("invalid remote mshell version 'v%s', must be v%s", initPk.Version, base.MShellVersion) @@ -573,7 +602,7 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er return donePacket, nil } -func UNameStringToGoArch(uname string) (string, string, error) { +func DetectGoArch(uname string) (string, string, error) { fields := strings.SplitN(uname, "|", 2) if len(fields) != 2 { return "", "", fmt.Errorf("invalid uname string returned")