mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-20 21:21:44 +01:00
SSH Wsh Install (#225)
This change adds the wsh installation to remote shells, so they have access to its commands.
This commit is contained in:
parent
65e8d4e3fd
commit
6bc3054733
10
Taskfile.yml
10
Taskfile.yml
@ -8,7 +8,7 @@ vars:
|
||||
BIN_DIR: "bin"
|
||||
VERSION:
|
||||
sh: node version.cjs
|
||||
RM: '{{if eq OS "windows"}}powershell Remove-Item{{else}}rm{{end}}'
|
||||
RM: '{{if eq OS "windows"}}cmd --% /c del /S{{else}}rm {{end}}'
|
||||
RMRF: '{{if eq OS "windows"}}powershell Remove-Item -Force -Recurse{{else}}rm -rf{{end}}'
|
||||
DATE: '{{if eq OS "windows"}}powershell date -UFormat{{else}}date{{end}}'
|
||||
|
||||
@ -53,7 +53,7 @@ tasks:
|
||||
status:
|
||||
- exit {{if eq OS "darwin"}}1{{else}}0{{end}}
|
||||
cmds:
|
||||
- cmd: '{{.RM}} "dist/bin/wavesrv*"'
|
||||
- cmd: "{{.RM}} dist/bin/wavesrv*"
|
||||
ignore_error: true
|
||||
- task: build:server:internal
|
||||
vars:
|
||||
@ -67,7 +67,7 @@ tasks:
|
||||
status:
|
||||
- exit {{if eq OS "darwin"}}0{{else}}1{{end}}
|
||||
cmds:
|
||||
- cmd: '{{.RM}} "dist/bin/wavesrv*"'
|
||||
- cmd: "{{.RM}} dist/bin/wavesrv*"
|
||||
ignore_error: true
|
||||
- task: build:server:internal
|
||||
vars:
|
||||
@ -94,7 +94,7 @@ tasks:
|
||||
build:wsh:
|
||||
desc: Build the wsh component for all possible targets.
|
||||
cmds:
|
||||
- cmd: '{{.RM}} "dist/bin/wsh*"'
|
||||
- cmd: "{{.RM}} dist/bin/wsh*"
|
||||
ignore_error: true
|
||||
- task: build:wsh:internal
|
||||
vars:
|
||||
@ -148,7 +148,7 @@ tasks:
|
||||
generates:
|
||||
- dist/bin/wsh-{{.VERSION}}-{{.GOOS}}.{{.GOARCH}}{{.EXT}}
|
||||
cmds:
|
||||
- (CGO_ENABLED=0 GOOS={{.GOOS}} GOARCH={{.GOARCH}} go build -ldflags="-s -w -X main.BuildTime=$({{.DATE}} +'%Y%m%d%H%M')" -o dist/bin/wsh-{{.VERSION}}-{{.GOOS}}.{{.GOARCH}}{{.EXT}} cmd/wsh/main-wsh.go)
|
||||
- (CGO_ENABLED=0 GOOS={{.GOOS}} GOARCH={{.GOARCH}} go build -ldflags="-s -w -X main.BuildTime=$({{.DATE}} +'%Y%m%d%H%M') -X main.WaveVersion={{.VERSION}}" -o dist/bin/wsh-{{.VERSION}}-{{.GOOS}}.{{.GOARCH}}{{.EXT}} cmd/wsh/main-wsh.go)
|
||||
deps:
|
||||
- generate
|
||||
- go:mod:tidy
|
||||
|
33
cmd/wsh/cmd/wshcmd-rcfiles.go
Normal file
33
cmd/wsh/cmd/wshcmd-rcfiles.go
Normal file
@ -0,0 +1,33 @@
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
||||
)
|
||||
|
||||
var WshBinDir = ".waveterm/bin"
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(rcfilesCmd)
|
||||
}
|
||||
|
||||
var rcfilesCmd = &cobra.Command{
|
||||
Use: "rcfiles",
|
||||
Short: "Generate the rc files needed for various shells",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
home := wavebase.GetHomeDir()
|
||||
waveDir := filepath.Join(home, ".waveterm")
|
||||
winBinDir := filepath.Join(waveDir, "bin")
|
||||
err := shellutil.InitRcFiles(waveDir, winBinDir)
|
||||
if err != nil {
|
||||
WriteStderr(err.Error())
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
61
cmd/wsh/cmd/wshcmd-shell-unix.go
Normal file
61
cmd/wsh/cmd/wshcmd-shell-unix.go
Normal file
@ -0,0 +1,61 @@
|
||||
//go:build !windows
|
||||
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"os"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
|
||||
)
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(shellCmd)
|
||||
}
|
||||
|
||||
var shellCmd = &cobra.Command{
|
||||
Use: "shell",
|
||||
Short: "Print the login shell of this user",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
WriteStdout(shellCmdInner())
|
||||
},
|
||||
}
|
||||
|
||||
func shellCmdInner() string {
|
||||
if runtime.GOOS == "darwin" {
|
||||
return shellutil.GetMacUserShell() + "\n"
|
||||
}
|
||||
user, err := user.Current()
|
||||
if err != nil {
|
||||
return "/bin/bash\n"
|
||||
}
|
||||
|
||||
passwd, err := os.Open("/etc/passwd")
|
||||
if err != nil {
|
||||
return "/bin/bash\n"
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(passwd)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
line = strings.TrimSpace(line)
|
||||
parts := strings.Split(line, ":")
|
||||
|
||||
if len(parts) != 7 {
|
||||
continue
|
||||
}
|
||||
|
||||
if parts[0] == user.Username {
|
||||
return parts[6] + "\n"
|
||||
}
|
||||
}
|
||||
// none found
|
||||
return "bin/bash\n"
|
||||
}
|
26
cmd/wsh/cmd/wshcmd-shell-win.go
Normal file
26
cmd/wsh/cmd/wshcmd-shell-win.go
Normal file
@ -0,0 +1,26 @@
|
||||
//go:build windows
|
||||
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(shellCmd)
|
||||
}
|
||||
|
||||
var shellCmd = &cobra.Command{
|
||||
Use: "shell",
|
||||
Short: "Print the login shell of this user",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
shellCmdInner()
|
||||
},
|
||||
}
|
||||
|
||||
func shellCmdInner() {
|
||||
WriteStderr("not implemented/n")
|
||||
}
|
@ -4,7 +4,10 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -15,6 +18,6 @@ var versionCmd = &cobra.Command{
|
||||
Use: "version",
|
||||
Short: "Print the version number of wsh",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
WriteStdout("wsh v0.1.0\n")
|
||||
WriteStdout(fmt.Sprintf("wsh v%s\n", wavebase.WaveVersion))
|
||||
},
|
||||
}
|
||||
|
@ -5,8 +5,15 @@ package main
|
||||
|
||||
import (
|
||||
"github.com/wavetermdev/thenextwave/cmd/wsh/cmd"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
||||
)
|
||||
|
||||
// set by main-server.go
|
||||
var WaveVersion = "0.0.0"
|
||||
var BuildTime = "0"
|
||||
|
||||
func main() {
|
||||
wavebase.WaveVersion = WaveVersion
|
||||
wavebase.BuildTime = BuildTime
|
||||
cmd.Execute()
|
||||
}
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/wavetermdev/thenextwave/pkg/eventbus"
|
||||
"github.com/wavetermdev/thenextwave/pkg/filestore"
|
||||
"github.com/wavetermdev/thenextwave/pkg/remote"
|
||||
"github.com/wavetermdev/thenextwave/pkg/shellexec"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
||||
"github.com/wavetermdev/thenextwave/pkg/waveobj"
|
||||
@ -307,7 +308,19 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
|
||||
}
|
||||
var shellProc *shellexec.ShellProc
|
||||
if remoteName != "" {
|
||||
shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, remoteName)
|
||||
credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
opts, err := remote.ParseOpts(remoteName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := remote.GetClient(credentialCtx, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
396
pkg/remote/conncontroller.go
Normal file
396
pkg/remote/conncontroller.go
Normal file
@ -0,0 +1,396 @@
|
||||
package remote
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/wavetermdev/thenextwave/pkg/userinput"
|
||||
"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var userHostRe = regexp.MustCompile(`^([a-zA-Z0-9][a-zA-Z0-9._@\\-]*@)?([a-z0-9][a-z0-9.-]*)(?::([0-9]+))?$`)
|
||||
var globalLock = &sync.Mutex{}
|
||||
var clientControllerMap = make(map[SSHOpts]*ssh.Client)
|
||||
|
||||
func GetClient(ctx context.Context, opts *SSHOpts) (*ssh.Client, error) {
|
||||
globalLock.Lock()
|
||||
defer globalLock.Unlock()
|
||||
|
||||
// attempt to retrieve if already opened
|
||||
client, ok := clientControllerMap[*opts]
|
||||
if ok {
|
||||
return client, nil
|
||||
}
|
||||
|
||||
client, err := ConnectToClient(ctx, opts) //todo specify or remove opts
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// check that correct wsh extensions are installed
|
||||
expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
|
||||
clientVersion, err := getWshVersion(client)
|
||||
if err == nil && clientVersion == expectedVersion {
|
||||
// save successful connection to map
|
||||
clientControllerMap[*opts] = client
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
var queryText string
|
||||
var title string
|
||||
if err != nil {
|
||||
queryText = "Waveterm requires `wsh` shell extensions installed on your client to ensure a seamless experience. Would you like to install them?"
|
||||
title = "Install Wsh Shell Extensions"
|
||||
} else {
|
||||
queryText = fmt.Sprintf("Waveterm requires `wsh` shell extensions installed on your client to be updated from %s to %s. Would you like to update?", clientVersion, expectedVersion)
|
||||
title = "Update Wsh Shell Extensions"
|
||||
|
||||
}
|
||||
|
||||
request := &userinput.UserInputRequest{
|
||||
ResponseType: "confirm",
|
||||
QueryText: queryText,
|
||||
Title: title,
|
||||
CheckBoxMsg: "Don't show me this again",
|
||||
}
|
||||
response, err := userinput.GetUserInput(ctx, request)
|
||||
if err != nil || !response.Confirm {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Printf("attempting to install wsh to `%s@%s`", client.User(), client.RemoteAddr().String())
|
||||
|
||||
clientOs, err := getClientOs(client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clientArch, err := getClientArch(client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// attempt to install extension
|
||||
wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
|
||||
err = cpHostToRemote(client, wshLocalPath, "~/.waveterm/bin/wsh")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Printf("successful install")
|
||||
|
||||
// save successful connection to map
|
||||
clientControllerMap[*opts] = client
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func DisconnectClient(opts *SSHOpts) error {
|
||||
globalLock.Lock()
|
||||
defer globalLock.Unlock()
|
||||
|
||||
client, ok := clientControllerMap[*opts]
|
||||
if ok {
|
||||
return client.Close()
|
||||
}
|
||||
return fmt.Errorf("client %v not found", opts)
|
||||
}
|
||||
|
||||
func ParseOpts(input string) (*SSHOpts, error) {
|
||||
m := userHostRe.FindStringSubmatch(input)
|
||||
if m == nil {
|
||||
return nil, fmt.Errorf("invalid format of user@host argument")
|
||||
}
|
||||
remoteUser, remoteHost, remotePortStr := m[1], m[2], m[3]
|
||||
remoteUser = strings.Trim(remoteUser, "@")
|
||||
var remotePort int
|
||||
if remotePortStr != "" {
|
||||
var err error
|
||||
remotePort, err = strconv.Atoi(remotePortStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid port specified on user@host argument")
|
||||
}
|
||||
}
|
||||
|
||||
return &SSHOpts{SSHHost: remoteHost, SSHUser: remoteUser, SSHPort: remotePort}, nil
|
||||
}
|
||||
|
||||
func DetectShell(client *ssh.Client) (string, error) {
|
||||
wshPath := getWshPath(client)
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
log.Printf("shell detecting using command: %s shell", wshPath)
|
||||
out, err := session.Output(wshPath + " shell")
|
||||
if err != nil {
|
||||
log.Printf("unable to determine shell. defaulting to /bin/bash: %s", err)
|
||||
return "/bin/bash", nil
|
||||
}
|
||||
log.Printf("detecting shell: %s", out)
|
||||
|
||||
return fmt.Sprintf(`"%s"`, strings.TrimSpace(string(out))), nil
|
||||
}
|
||||
|
||||
func getWshVersion(client *ssh.Client) (string, error) {
|
||||
wshPath := getWshPath(client)
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
out, err := session.Output(wshPath + " version")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return strings.TrimSpace(string(out)), nil
|
||||
}
|
||||
|
||||
func getWshPath(client *ssh.Client) string {
|
||||
defaultPath := filepath.Join("~", ".waveterm", "bin", "wsh")
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
log.Printf("unable to detect client's wsh path. using default. error: %v", err)
|
||||
return defaultPath
|
||||
}
|
||||
|
||||
out, whichErr := session.Output("which wsh")
|
||||
if whichErr == nil {
|
||||
return strings.TrimSpace(string(out))
|
||||
}
|
||||
|
||||
session, err = client.NewSession()
|
||||
if err != nil {
|
||||
log.Printf("unable to detect client's wsh path. using default. error: %v", err)
|
||||
return defaultPath
|
||||
}
|
||||
|
||||
out, whereErr := session.Output("where.exe wsh")
|
||||
if whereErr == nil {
|
||||
return strings.TrimSpace(string(out))
|
||||
}
|
||||
|
||||
// no custom install, use default path
|
||||
return defaultPath
|
||||
}
|
||||
|
||||
func hasBashInstalled(client *ssh.Client) (bool, error) {
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
// this is a true error that should stop further progress
|
||||
return false, err
|
||||
}
|
||||
|
||||
out, whichErr := session.Output("which bash")
|
||||
if whichErr == nil && len(out) != 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
session, err = client.NewSession()
|
||||
if err != nil {
|
||||
// this is a true error that should stop further progress
|
||||
return false, err
|
||||
}
|
||||
|
||||
out, whereErr := session.Output("where.exe bash")
|
||||
if whereErr == nil && len(out) != 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// note: we could also check in /bin/bash explicitly
|
||||
// just in case that wasn't added to the path. but if
|
||||
// that's true, we will most likely have worse
|
||||
// problems going forward
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func getClientOs(client *ssh.Client) (string, error) {
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
out, unixErr := session.Output("uname -s")
|
||||
if unixErr == nil {
|
||||
formatted := strings.ToLower(string(out))
|
||||
formatted = strings.TrimSpace(formatted)
|
||||
return formatted, nil
|
||||
}
|
||||
|
||||
session, err = client.NewSession()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
out, cmdErr := session.Output("echo %OS%")
|
||||
if cmdErr == nil {
|
||||
formatted := strings.ToLower(string(out))
|
||||
formatted = strings.TrimSpace(formatted)
|
||||
return strings.Split(formatted, "_")[0], nil
|
||||
}
|
||||
|
||||
session, err = client.NewSession()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
out, psErr := session.Output("echo $env:OS")
|
||||
if psErr == nil {
|
||||
formatted := strings.ToLower(string(out))
|
||||
formatted = strings.TrimSpace(formatted)
|
||||
return strings.Split(formatted, "_")[0], nil
|
||||
}
|
||||
return "", fmt.Errorf("unable to determine os: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
|
||||
}
|
||||
|
||||
func getClientArch(client *ssh.Client) (string, error) {
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
out, unixErr := session.Output("uname -m")
|
||||
if unixErr == nil {
|
||||
formatted := strings.ToLower(string(out))
|
||||
formatted = strings.TrimSpace(formatted)
|
||||
if formatted == "x86_64" {
|
||||
return "amd64", nil
|
||||
}
|
||||
return formatted, nil
|
||||
}
|
||||
|
||||
session, err = client.NewSession()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
out, cmdErr := session.Output("echo %PROCESSOR_ARCHITECTURE%")
|
||||
if cmdErr == nil {
|
||||
formatted := strings.ToLower(string(out))
|
||||
return strings.TrimSpace(formatted), nil
|
||||
}
|
||||
|
||||
session, err = client.NewSession()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
out, psErr := session.Output("echo $env:PROCESSOR_ARCHITECTURE")
|
||||
if psErr == nil {
|
||||
formatted := strings.ToLower(string(out))
|
||||
return strings.TrimSpace(formatted), nil
|
||||
}
|
||||
return "", fmt.Errorf("unable to determine architecture: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
|
||||
}
|
||||
|
||||
var installTemplateRawBash = `bash -c ' \
|
||||
mkdir -p {{.installDir}}; \
|
||||
cat > {{.tempPath}}; \
|
||||
mv {{.tempPath}} {{.installPath}}; \
|
||||
chmod a+x {{.installPath}};' \
|
||||
`
|
||||
|
||||
var installTemplateRawDefault = ` \
|
||||
mkdir -p {{.installDir}}; \
|
||||
cat > {{.tempPath}}; \
|
||||
mv {{.tempPath}} {{.installPath}}; \
|
||||
chmod a+x {{.installPath}}; \
|
||||
`
|
||||
|
||||
func cpHostToRemote(client *ssh.Client, sourcePath string, destPath string) error {
|
||||
// warning: does not work on windows remote yet
|
||||
bashInstalled, err := hasBashInstalled(client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var selectedTemplateRaw string
|
||||
if bashInstalled {
|
||||
selectedTemplateRaw = installTemplateRawBash
|
||||
} else {
|
||||
log.Printf("bash is not installed on remote. attempting with default shell")
|
||||
selectedTemplateRaw = installTemplateRawDefault
|
||||
}
|
||||
|
||||
var installWords = map[string]string{
|
||||
"installDir": filepath.Dir(destPath),
|
||||
"tempPath": destPath + ".temp",
|
||||
"installPath": destPath,
|
||||
}
|
||||
|
||||
installCmd := &bytes.Buffer{}
|
||||
installTemplate := template.Must(template.New("").Parse(selectedTemplateRaw))
|
||||
installTemplate.Execute(installCmd, installWords)
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
installStdin, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = session.Start(installCmd.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
input, err := os.Open(sourcePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot open local file %s to send to host: %v", sourcePath, err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
io.Copy(installStdin, input)
|
||||
session.Close() // this allows the command to complete for reasons i don't fully understand
|
||||
}()
|
||||
|
||||
return session.Wait()
|
||||
}
|
||||
|
||||
func InstallClientRcFiles(client *ssh.Client) error {
|
||||
path := getWshPath(client)
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
// this is a true error that should stop further progress
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = session.Output(path + " rcfiles")
|
||||
return err
|
||||
}
|
||||
|
||||
func GetHomeDir(client *ssh.Client) string {
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
return "~"
|
||||
}
|
||||
|
||||
out, err := session.Output("pwd")
|
||||
if err != nil {
|
||||
return "~"
|
||||
}
|
||||
return strings.TrimSpace(string(out))
|
||||
|
||||
}
|
@ -5,7 +5,6 @@ package shellexec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@ -15,16 +14,15 @@ import (
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/creack/pty"
|
||||
"github.com/wavetermdev/thenextwave/pkg/remote"
|
||||
"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type TermSize struct {
|
||||
@ -155,45 +153,53 @@ func (pp *PipePty) WriteString(s string) (n int, err error) {
|
||||
return pp.Write([]byte(s))
|
||||
}
|
||||
|
||||
func StartRemoteShellProc(termSize TermSize, cmdStr string, cmdOpts CommandOptsType, remoteName string) (*ShellProc, error) {
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
var shellPath string
|
||||
if cmdStr == "" {
|
||||
shellPath = "/bin/bash"
|
||||
} else {
|
||||
shellPath = cmdStr
|
||||
}
|
||||
|
||||
var shellOpts []string
|
||||
if cmdOpts.Login {
|
||||
shellOpts = append(shellOpts, "-l")
|
||||
}
|
||||
if cmdOpts.Interactive {
|
||||
shellOpts = append(shellOpts, "-i")
|
||||
}
|
||||
cmdCombined := fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
|
||||
log.Print(cmdCombined)
|
||||
m := userHostRe.FindStringSubmatch(remoteName)
|
||||
if m == nil {
|
||||
return nil, fmt.Errorf("invalid format of user@host argument")
|
||||
}
|
||||
remoteUser, remoteHost, remotePortStr := m[1], m[2], m[3]
|
||||
remoteUser = strings.Trim(remoteUser, "@")
|
||||
var remotePort int
|
||||
if remotePortStr != "" {
|
||||
var err error
|
||||
remotePort, err = strconv.Atoi(remotePortStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid port specified on user@host argument")
|
||||
}
|
||||
}
|
||||
|
||||
client, err := remote.ConnectToClient(ctx, &remote.SSHOpts{SSHHost: remoteHost, SSHUser: remoteUser, SSHPort: remotePort}) //todo specify or remove opts
|
||||
func StartRemoteShellProc(termSize TermSize, cmdStr string, cmdOpts CommandOptsType, client *ssh.Client) (*ShellProc, error) {
|
||||
shellPath, err := remote.DetectShell(client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var shellOpts []string
|
||||
var cmdCombined string
|
||||
log.Printf("detected shell: %s", shellPath)
|
||||
|
||||
err = remote.InstallClientRcFiles(client)
|
||||
if err != nil {
|
||||
log.Printf("error installing rc files: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
homeDir := remote.GetHomeDir(client)
|
||||
|
||||
if cmdStr == "" {
|
||||
/* transform command in order to inject environment vars */
|
||||
if isBashShell(shellPath) {
|
||||
log.Printf("recognized as bash shell")
|
||||
// add --rcfile
|
||||
// cant set -l or -i with --rcfile
|
||||
shellOpts = append(shellOpts, "--rcfile", fmt.Sprintf(`"%s"/.waveterm/bash-integration/.bashrc`, homeDir))
|
||||
} else {
|
||||
if cmdOpts.Login {
|
||||
shellOpts = append(shellOpts, "-l")
|
||||
}
|
||||
if cmdOpts.Interactive {
|
||||
shellOpts = append(shellOpts, "-i")
|
||||
}
|
||||
// zdotdir setting moved to after session is created
|
||||
}
|
||||
cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
|
||||
log.Printf("combined command is: %s", cmdCombined)
|
||||
} else {
|
||||
shellPath = cmdStr
|
||||
if cmdOpts.Login {
|
||||
shellOpts = append(shellOpts, "-l")
|
||||
}
|
||||
if cmdOpts.Interactive {
|
||||
shellOpts = append(shellOpts, "-i")
|
||||
}
|
||||
shellOpts = append(shellOpts, "-c", cmdStr)
|
||||
cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
|
||||
}
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -223,11 +229,16 @@ func StartRemoteShellProc(termSize TermSize, cmdStr string, cmdOpts CommandOptsT
|
||||
session.Stdin = remoteStdinRead
|
||||
session.Stdout = remoteStdoutWrite
|
||||
session.Stderr = remoteStdoutWrite
|
||||
|
||||
for envKey, envVal := range cmdOpts.Env {
|
||||
// note these might fail depending on server settings, but we still try
|
||||
session.Setenv(envKey, envVal)
|
||||
}
|
||||
|
||||
if isZshShell(shellPath) {
|
||||
cmdCombined = fmt.Sprintf(`ZDOTDIR="%s/.waveterm/zsh-integration" %s`, homeDir, cmdCombined)
|
||||
}
|
||||
|
||||
session.RequestPty("xterm-256color", termSize.Rows, termSize.Cols, nil)
|
||||
|
||||
sessionWrap := SessionWrap{session, cmdCombined, pipePty, pipePty}
|
||||
|
@ -48,7 +48,7 @@ const (
|
||||
# Source the original zshrc
|
||||
[ -f ~/.zshrc ] && source ~/.zshrc
|
||||
|
||||
export PATH=$WAVETERM_WSHBINDIR:$PATH
|
||||
export PATH={{.WSHBINDIR}}:$PATH
|
||||
`
|
||||
|
||||
ZshStartup_Zlogin = `
|
||||
@ -75,8 +75,7 @@ elif [ -f ~/.profile ]; then
|
||||
. ~/.profile
|
||||
fi
|
||||
|
||||
set -i
|
||||
export PATH=$WAVETERM_WSHBINDIR:$PATH
|
||||
export PATH={{.WSHBINDIR}}:$PATH
|
||||
`
|
||||
)
|
||||
|
||||
@ -194,9 +193,16 @@ func GetZshZDotDir() string {
|
||||
return filepath.Join(wavebase.GetWaveHomeDir(), ZshIntegrationDir)
|
||||
}
|
||||
|
||||
func initCustomShellStartupFilesInternal() error {
|
||||
log.Printf("initializing wsh and shell startup files\n")
|
||||
waveHome := wavebase.GetWaveHomeDir()
|
||||
func GetWshBinaryPath(version string, goos string, goarch string) string {
|
||||
ext := ""
|
||||
if goos == "windows" {
|
||||
ext = ".exe"
|
||||
}
|
||||
return filepath.Join(os.Getenv(WaveAppPathVarName), AppPathBinDir, fmt.Sprintf("wsh-%s-%s.%s%s", version, goos, goarch, ext))
|
||||
}
|
||||
|
||||
func InitRcFiles(waveHome string, wshBinDir string) error {
|
||||
// ensure directiries exist
|
||||
zshDir := filepath.Join(waveHome, ZshIntegrationDir)
|
||||
err := wavebase.CacheEnsureDir(zshDir, ZshIntegrationDir, 0755, ZshIntegrationDir)
|
||||
if err != nil {
|
||||
@ -207,19 +213,14 @@ func initCustomShellStartupFilesInternal() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
binDir := filepath.Join(waveHome, WaveHomeBinDir)
|
||||
err = wavebase.CacheEnsureDir(binDir, WaveHomeBinDir, 0755, WaveHomeBinDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// write files to directory
|
||||
zprofilePath := filepath.Join(zshDir, ".zprofile")
|
||||
err = os.WriteFile(zprofilePath, []byte(ZshStartup_Zprofile), 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing zsh-integration .zprofile: %v", err)
|
||||
}
|
||||
zshrcPath := filepath.Join(zshDir, ".zshrc")
|
||||
err = os.WriteFile(zshrcPath, []byte(ZshStartup_Zshrc), 0644)
|
||||
err = utilfn.WriteTemplateToFile(filepath.Join(zshDir, ".zshrc"), ZshStartup_Zshrc, map[string]string{"WSHBINDIR": fmt.Sprintf(`"%s"`, wshBinDir)})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing zsh-integration .zshrc: %v", err)
|
||||
}
|
||||
@ -233,20 +234,30 @@ func initCustomShellStartupFilesInternal() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing zsh-integration .zshenv: %v", err)
|
||||
}
|
||||
bashrcPath := filepath.Join(bashDir, ".bashrc")
|
||||
err = os.WriteFile(bashrcPath, []byte(BashStartup_Bashrc), 0644)
|
||||
err = utilfn.WriteTemplateToFile(filepath.Join(bashDir, ".bashrc"), BashStartup_Bashrc, map[string]string{"WSHBINDIR": fmt.Sprintf(`"%s"`, wshBinDir)})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing bash-integration .bashrc: %v", err)
|
||||
}
|
||||
|
||||
// copy the correct binary to bin
|
||||
appPath := os.Getenv(WaveAppPathVarName)
|
||||
if appPath == "" {
|
||||
return fmt.Errorf("no app path set")
|
||||
return nil
|
||||
}
|
||||
|
||||
func initCustomShellStartupFilesInternal() error {
|
||||
log.Printf("initializing wsh and shell startup files\n")
|
||||
waveHome := wavebase.GetWaveHomeDir()
|
||||
binDir := filepath.Join(waveHome, WaveHomeBinDir)
|
||||
err := InitRcFiles(waveHome, `$WAVETERM_WSHBINDIR`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
appBinPath := filepath.Join(appPath, AppPathBinDir)
|
||||
wshBaseName := computeWshBaseName()
|
||||
wshFullPath := filepath.Join(appBinPath, wshBaseName)
|
||||
|
||||
err = wavebase.CacheEnsureDir(binDir, WaveHomeBinDir, 0755, WaveHomeBinDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// copy the correct binary to bin
|
||||
wshFullPath := GetWshBinaryPath(wavebase.WaveVersion, runtime.GOOS, runtime.GOARCH)
|
||||
if _, err := os.Stat(wshFullPath); err != nil {
|
||||
log.Printf("error (non-fatal), could not resolve wsh binary %q: %v\n", wshFullPath, err)
|
||||
return nil
|
||||
@ -256,7 +267,7 @@ func initCustomShellStartupFilesInternal() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("error copying wsh binary to bin: %v", err)
|
||||
}
|
||||
log.Printf("wsh binary successfully %q copied to %q\n", wshBaseName, wshDstPath)
|
||||
log.Printf("wsh binary successfully %q copied to %q\n", computeWshBaseName(), wshDstPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -23,6 +23,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"text/template"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
@ -878,3 +879,9 @@ func AtoiNoErr(str string) int {
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func WriteTemplateToFile(fileName string, templateText string, vars map[string]string) error {
|
||||
outBuffer := &bytes.Buffer{}
|
||||
template.Must(template.New("").Parse(templateText)).Execute(outBuffer, vars)
|
||||
return os.WriteFile(fileName, outBuffer.Bytes(), 0644)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user