write auto-detect logic for arch from uname

This commit is contained in:
sawka 2022-06-27 23:14:53 -07:00
parent afd3bdb315
commit 9377619e4c
2 changed files with 68 additions and 26 deletions

View File

@ -296,6 +296,10 @@ func parseInstallOpts() (*shexec.InstallOpts, error) {
if found { if found {
continue continue
} }
if argStr == "--detect" {
opts.Detect = true
continue
}
if base.IsOption(argStr) { if base.IsOption(argStr) {
return nil, fmt.Errorf("invalid option '%s' passed to mshell --install", argStr) return nil, fmt.Errorf("invalid option '%s' passed to mshell --install", argStr)
} }
@ -402,6 +406,7 @@ func parseClientOpts() (*shexec.ClientOpts, error) {
opts.Command = strings.Join(iter.Rest(), " ") opts.Command = strings.Join(iter.Rest(), " ")
break break
} }
return nil, fmt.Errorf("invalid option '%s' passed to mshell", argStr)
} }
return opts, nil return opts, nil
} }
@ -437,21 +442,29 @@ func handleInstall() (int, error) {
if opts.SSHOpts.SSHHost == "" { if opts.SSHOpts.SSHHost == "" {
return 1, fmt.Errorf("cannot install without '--ssh user@host' option") return 1, fmt.Errorf("cannot install without '--ssh user@host' option")
} }
fullArch := opts.ArchStr if opts.Detect && opts.ArchStr != "" {
fields := strings.SplitN(fullArch, ".", 2) return 1, fmt.Errorf("cannot supply both --detect and arch '%s'", opts.ArchStr)
if len(fields) != 2 {
return 1, fmt.Errorf("invalid arch format '%s' passed to mshell --install", fullArch)
} }
goos, goarch := fields[0], fields[1] if opts.ArchStr == "" && !opts.Detect {
if !base.ValidGoArch(goos, goarch) { return 1, fmt.Errorf("must supply an arch string or '--detect' to auto detect")
return 1, fmt.Errorf("invalid arch '%s' passed to mshell --install", fullArch)
} }
optName := base.GoArchOptFile(goos, goarch) if opts.ArchStr != "" {
_, err = os.Stat(optName) fullArch := opts.ArchStr
if err != nil { fields := strings.SplitN(fullArch, ".", 2)
return 1, fmt.Errorf("cannot install mshell to remote host, cannot read '%s': %w", optName, err) if len(fields) != 2 {
return 1, fmt.Errorf("invalid arch format '%s' passed to mshell --install", fullArch)
}
goos, goarch := fields[0], fields[1]
if !base.ValidGoArch(goos, goarch) {
return 1, fmt.Errorf("invalid arch '%s' passed to mshell --install", fullArch)
}
optName := base.GoArchOptFile(goos, goarch)
_, err = os.Stat(optName)
if err != nil {
return 1, fmt.Errorf("cannot install mshell to remote host, cannot read '%s': %w", optName, err)
}
opts.OptName = optName
} }
opts.OptName = optName
err = shexec.RunInstallSSHCommand(opts) err = shexec.RunInstallSSHCommand(opts)
if err != nil { if err != nil {
return 1, err return 1, err

View File

@ -42,6 +42,7 @@ fi
` `
const InstallCommand = ` const InstallCommand = `
printf "\n##N{\"type\": \"init\", \"notfound\": true, \"uname\": \"%s | %s\"}\n" "$(uname -s)" "$(uname -m)";
mkdir -p ~/.mshell/; mkdir -p ~/.mshell/;
cat > ~/.mshell/mshell.temp; cat > ~/.mshell/mshell.temp;
mv ~/.mshell/mshell.temp ~/.mshell/mshell; mv ~/.mshell/mshell.temp ~/.mshell/mshell;
@ -246,6 +247,7 @@ type InstallOpts struct {
SSHOpts SharedSSHOpts SSHOpts SharedSSHOpts
ArchStr string ArchStr string
OptName string OptName string
Detect bool
} }
type ClientOpts struct { type ClientOpts struct {
@ -294,8 +296,8 @@ func (opts *InstallOpts) MakeExecCmd() *exec.Cmd {
moreSSHOpts = append(moreSSHOpts, userOpt) moreSSHOpts = append(moreSSHOpts, userOpt)
} }
// note that SSHOptsStr is *not* escaped // note that SSHOptsStr is *not* escaped
installCommand := strings.TrimSpace(InstallCommand) command := strings.TrimSpace(InstallCommand)
sshCmd := fmt.Sprintf("ssh %s %s %s %s", strings.Join(moreSSHOpts, " "), opts.SSHOpts.SSHOptsStr, shellescape.Quote(opts.SSHOpts.SSHHost), shellescape.Quote(installCommand)) sshCmd := fmt.Sprintf("ssh %s %s %s %s", strings.Join(moreSSHOpts, " "), opts.SSHOpts.SSHOptsStr, shellescape.Quote(opts.SSHOpts.SSHHost), shellescape.Quote(command))
ecmd := exec.Command("bash", "-c", sshCmd) ecmd := exec.Command("bash", "-c", sshCmd)
return ecmd return ecmd
} }
@ -418,7 +420,20 @@ func ValidateRemoteFds(rfds []packet.RemoteFd) error {
return nil return nil
} }
func sendOptFile(input io.WriteCloser, optName string) error {
fd, err := os.Open(optName)
if err != nil {
return fmt.Errorf("cannot open '%s': %w", optName, err)
}
go func() {
defer input.Close()
io.Copy(input, fd)
}()
return nil
}
func RunInstallSSHCommand(opts *InstallOpts) error { func RunInstallSSHCommand(opts *InstallOpts) error {
tryDetect := opts.Detect
ecmd := opts.MakeExecCmd() ecmd := opts.MakeExecCmd()
inputWriter, err := ecmd.StdinPipe() inputWriter, err := ecmd.StdinPipe()
if err != nil { if err != nil {
@ -435,21 +450,36 @@ func RunInstallSSHCommand(opts *InstallOpts) error {
go func() { go func() {
io.Copy(os.Stderr, stderrReader) io.Copy(os.Stderr, stderrReader)
}() }()
fd, err := os.Open(opts.OptName) if opts.OptName != "" {
if err != nil { sendOptFile(inputWriter, opts.OptName)
return fmt.Errorf("cannot open '%s': %w", opts.OptName, err)
} }
go func() {
defer inputWriter.Close()
io.Copy(inputWriter, fd)
}()
packetParser := packet.MakePacketParser(stdoutReader) packetParser := packet.MakePacketParser(stdoutReader)
err = ecmd.Start() err = ecmd.Start()
if err != nil { if err != nil {
return fmt.Errorf("running ssh command: %w", err) return fmt.Errorf("running ssh command: %w", err)
} }
firstInit := true
for pk := range packetParser.MainCh { for pk := range packetParser.MainCh {
if pk.GetType() == packet.InitPacketStr { if pk.GetType() == packet.InitPacketStr && firstInit {
firstInit = false
initPacket := pk.(*packet.InitPacketType)
if !tryDetect {
continue // ignore
}
tryDetect = false
if initPacket.UName == "" {
return fmt.Errorf("cannot detect arch, no uname received from remote server")
}
goos, goarch, err := DetectGoArch(initPacket.UName)
if err != nil {
return fmt.Errorf("arch cannot be detected (might be incompatible with mshell): %w", err)
}
fmt.Printf("mshell detected remote architecture as '%s.%s'\n", goos, goarch)
optName := base.GoArchOptFile(goos, goarch)
sendOptFile(inputWriter, optName)
continue
}
if pk.GetType() == packet.InitPacketStr && !firstInit {
initPacket := pk.(*packet.InitPacketType) initPacket := pk.(*packet.InitPacketType)
if initPacket.Version == base.MShellVersion { if initPacket.Version == base.MShellVersion {
fmt.Printf("mshell %s, installed successfully at %s:~/.mshell/mshell\n", initPacket.Version, opts.SSHOpts.SSHHost) fmt.Printf("mshell %s, installed successfully at %s:~/.mshell/mshell\n", initPacket.Version, opts.SSHOpts.SSHHost)
@ -537,16 +567,15 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er
if pk.GetType() == packet.InitPacketStr { if pk.GetType() == packet.InitPacketStr {
initPk := pk.(*packet.InitPacketType) initPk := pk.(*packet.InitPacketType)
if initPk.NotFound { if initPk.NotFound {
fmt.Printf("UNAME> %s\n", initPk.UName)
if initPk.UName == "" { if initPk.UName == "" {
return nil, fmt.Errorf("mshell command not found on remote server, no uname detected") return nil, fmt.Errorf("mshell command not found on remote server, no uname detected")
} }
goos, goarch, err := UNameStringToGoArch(initPk.UName) goos, goarch, err := DetectGoArch(initPk.UName)
if err != nil { if err != nil {
return nil, fmt.Errorf("mshell command not found on remote server, architecture cannot be detected (might be incompatible with mshell): %w", err) return nil, fmt.Errorf("mshell command not found on remote server, architecture cannot be detected (might be incompatible with mshell): %w", err)
} }
installCmd := opts.MakeInstallCommandString(goos, goarch) installCmd := opts.MakeInstallCommandString(goos, goarch)
return nil, fmt.Errorf("mshell command not found on remote server, can install with '%s'", installCmd) return nil, fmt.Errorf("mshell command not found on remote server, can install with '%s' (or --auto-install)", installCmd)
} }
if initPk.Version != base.MShellVersion { if initPk.Version != base.MShellVersion {
return nil, fmt.Errorf("invalid remote mshell version 'v%s', must be v%s", initPk.Version, base.MShellVersion) return nil, fmt.Errorf("invalid remote mshell version 'v%s', must be v%s", initPk.Version, base.MShellVersion)
@ -573,7 +602,7 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er
return donePacket, nil return donePacket, nil
} }
func UNameStringToGoArch(uname string) (string, string, error) { func DetectGoArch(uname string) (string, string, error) {
fields := strings.SplitN(uname, "|", 2) fields := strings.SplitN(uname, "|", 2)
if len(fields) != 2 { if len(fields) != 2 {
return "", "", fmt.Errorf("invalid uname string returned") return "", "", fmt.Errorf("invalid uname string returned")