SSH Wsh Install (#225)

This change adds the wsh installation to remote shells, so they have
access to its commands.
This commit is contained in:
Sylvie Crowe 2024-08-15 21:32:08 -07:00 committed by GitHub
parent 65e8d4e3fd
commit 6bc3054733
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 637 additions and 69 deletions

View File

@ -8,7 +8,7 @@ vars:
BIN_DIR: "bin"
VERSION:
sh: node version.cjs
RM: '{{if eq OS "windows"}}powershell Remove-Item{{else}}rm{{end}}'
RM: '{{if eq OS "windows"}}cmd --% /c del /S{{else}}rm {{end}}'
RMRF: '{{if eq OS "windows"}}powershell Remove-Item -Force -Recurse{{else}}rm -rf{{end}}'
DATE: '{{if eq OS "windows"}}powershell date -UFormat{{else}}date{{end}}'
@ -53,7 +53,7 @@ tasks:
status:
- exit {{if eq OS "darwin"}}1{{else}}0{{end}}
cmds:
- cmd: '{{.RM}} "dist/bin/wavesrv*"'
- cmd: "{{.RM}} dist/bin/wavesrv*"
ignore_error: true
- task: build:server:internal
vars:
@ -67,7 +67,7 @@ tasks:
status:
- exit {{if eq OS "darwin"}}0{{else}}1{{end}}
cmds:
- cmd: '{{.RM}} "dist/bin/wavesrv*"'
- cmd: "{{.RM}} dist/bin/wavesrv*"
ignore_error: true
- task: build:server:internal
vars:
@ -94,7 +94,7 @@ tasks:
build:wsh:
desc: Build the wsh component for all possible targets.
cmds:
- cmd: '{{.RM}} "dist/bin/wsh*"'
- cmd: "{{.RM}} dist/bin/wsh*"
ignore_error: true
- task: build:wsh:internal
vars:
@ -148,7 +148,7 @@ tasks:
generates:
- dist/bin/wsh-{{.VERSION}}-{{.GOOS}}.{{.GOARCH}}{{.EXT}}
cmds:
- (CGO_ENABLED=0 GOOS={{.GOOS}} GOARCH={{.GOARCH}} go build -ldflags="-s -w -X main.BuildTime=$({{.DATE}} +'%Y%m%d%H%M')" -o dist/bin/wsh-{{.VERSION}}-{{.GOOS}}.{{.GOARCH}}{{.EXT}} cmd/wsh/main-wsh.go)
- (CGO_ENABLED=0 GOOS={{.GOOS}} GOARCH={{.GOARCH}} go build -ldflags="-s -w -X main.BuildTime=$({{.DATE}} +'%Y%m%d%H%M') -X main.WaveVersion={{.VERSION}}" -o dist/bin/wsh-{{.VERSION}}-{{.GOOS}}.{{.GOARCH}}{{.EXT}} cmd/wsh/main-wsh.go)
deps:
- generate
- go:mod:tidy

View File

@ -0,0 +1,33 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package cmd
import (
"path/filepath"
"github.com/spf13/cobra"
"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
)
var WshBinDir = ".waveterm/bin"
func init() {
rootCmd.AddCommand(rcfilesCmd)
}
var rcfilesCmd = &cobra.Command{
Use: "rcfiles",
Short: "Generate the rc files needed for various shells",
Run: func(cmd *cobra.Command, args []string) {
home := wavebase.GetHomeDir()
waveDir := filepath.Join(home, ".waveterm")
winBinDir := filepath.Join(waveDir, "bin")
err := shellutil.InitRcFiles(waveDir, winBinDir)
if err != nil {
WriteStderr(err.Error())
return
}
},
}

View File

@ -0,0 +1,61 @@
//go:build !windows
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package cmd
import (
"bufio"
"os"
"os/user"
"runtime"
"strings"
"github.com/spf13/cobra"
"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
)
func init() {
rootCmd.AddCommand(shellCmd)
}
var shellCmd = &cobra.Command{
Use: "shell",
Short: "Print the login shell of this user",
Run: func(cmd *cobra.Command, args []string) {
WriteStdout(shellCmdInner())
},
}
func shellCmdInner() string {
if runtime.GOOS == "darwin" {
return shellutil.GetMacUserShell() + "\n"
}
user, err := user.Current()
if err != nil {
return "/bin/bash\n"
}
passwd, err := os.Open("/etc/passwd")
if err != nil {
return "/bin/bash\n"
}
scanner := bufio.NewScanner(passwd)
for scanner.Scan() {
line := scanner.Text()
line = strings.TrimSpace(line)
parts := strings.Split(line, ":")
if len(parts) != 7 {
continue
}
if parts[0] == user.Username {
return parts[6] + "\n"
}
}
// none found
return "bin/bash\n"
}

View File

@ -0,0 +1,26 @@
//go:build windows
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package cmd
import (
"github.com/spf13/cobra"
)
func init() {
rootCmd.AddCommand(shellCmd)
}
var shellCmd = &cobra.Command{
Use: "shell",
Short: "Print the login shell of this user",
Run: func(cmd *cobra.Command, args []string) {
shellCmdInner()
},
}
func shellCmdInner() {
WriteStderr("not implemented/n")
}

View File

@ -4,7 +4,10 @@
package cmd
import (
"fmt"
"github.com/spf13/cobra"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
)
func init() {
@ -15,6 +18,6 @@ var versionCmd = &cobra.Command{
Use: "version",
Short: "Print the version number of wsh",
Run: func(cmd *cobra.Command, args []string) {
WriteStdout("wsh v0.1.0\n")
WriteStdout(fmt.Sprintf("wsh v%s\n", wavebase.WaveVersion))
},
}

View File

@ -5,8 +5,15 @@ package main
import (
"github.com/wavetermdev/thenextwave/cmd/wsh/cmd"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
)
// set by main-server.go
var WaveVersion = "0.0.0"
var BuildTime = "0"
func main() {
wavebase.WaveVersion = WaveVersion
wavebase.BuildTime = BuildTime
cmd.Execute()
}

View File

@ -19,6 +19,7 @@ import (
"github.com/wavetermdev/thenextwave/pkg/eventbus"
"github.com/wavetermdev/thenextwave/pkg/filestore"
"github.com/wavetermdev/thenextwave/pkg/remote"
"github.com/wavetermdev/thenextwave/pkg/shellexec"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
"github.com/wavetermdev/thenextwave/pkg/waveobj"
@ -307,7 +308,19 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
}
var shellProc *shellexec.ShellProc
if remoteName != "" {
shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, remoteName)
credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
defer cancelFunc()
opts, err := remote.ParseOpts(remoteName)
if err != nil {
return err
}
client, err := remote.GetClient(credentialCtx, opts)
if err != nil {
return err
}
shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, client)
if err != nil {
return err
}

View File

@ -0,0 +1,396 @@
package remote
import (
"bytes"
"context"
"fmt"
"html/template"
"io"
"log"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"sync"
"github.com/wavetermdev/thenextwave/pkg/userinput"
"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
"golang.org/x/crypto/ssh"
)
var userHostRe = regexp.MustCompile(`^([a-zA-Z0-9][a-zA-Z0-9._@\\-]*@)?([a-z0-9][a-z0-9.-]*)(?::([0-9]+))?$`)
var globalLock = &sync.Mutex{}
var clientControllerMap = make(map[SSHOpts]*ssh.Client)
func GetClient(ctx context.Context, opts *SSHOpts) (*ssh.Client, error) {
globalLock.Lock()
defer globalLock.Unlock()
// attempt to retrieve if already opened
client, ok := clientControllerMap[*opts]
if ok {
return client, nil
}
client, err := ConnectToClient(ctx, opts) //todo specify or remove opts
if err != nil {
return nil, err
}
// check that correct wsh extensions are installed
expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
clientVersion, err := getWshVersion(client)
if err == nil && clientVersion == expectedVersion {
// save successful connection to map
clientControllerMap[*opts] = client
return client, nil
}
var queryText string
var title string
if err != nil {
queryText = "Waveterm requires `wsh` shell extensions installed on your client to ensure a seamless experience. Would you like to install them?"
title = "Install Wsh Shell Extensions"
} else {
queryText = fmt.Sprintf("Waveterm requires `wsh` shell extensions installed on your client to be updated from %s to %s. Would you like to update?", clientVersion, expectedVersion)
title = "Update Wsh Shell Extensions"
}
request := &userinput.UserInputRequest{
ResponseType: "confirm",
QueryText: queryText,
Title: title,
CheckBoxMsg: "Don't show me this again",
}
response, err := userinput.GetUserInput(ctx, request)
if err != nil || !response.Confirm {
return nil, err
}
log.Printf("attempting to install wsh to `%s@%s`", client.User(), client.RemoteAddr().String())
clientOs, err := getClientOs(client)
if err != nil {
return nil, err
}
clientArch, err := getClientArch(client)
if err != nil {
return nil, err
}
// attempt to install extension
wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
err = cpHostToRemote(client, wshLocalPath, "~/.waveterm/bin/wsh")
if err != nil {
return nil, err
}
log.Printf("successful install")
// save successful connection to map
clientControllerMap[*opts] = client
return client, nil
}
func DisconnectClient(opts *SSHOpts) error {
globalLock.Lock()
defer globalLock.Unlock()
client, ok := clientControllerMap[*opts]
if ok {
return client.Close()
}
return fmt.Errorf("client %v not found", opts)
}
func ParseOpts(input string) (*SSHOpts, error) {
m := userHostRe.FindStringSubmatch(input)
if m == nil {
return nil, fmt.Errorf("invalid format of user@host argument")
}
remoteUser, remoteHost, remotePortStr := m[1], m[2], m[3]
remoteUser = strings.Trim(remoteUser, "@")
var remotePort int
if remotePortStr != "" {
var err error
remotePort, err = strconv.Atoi(remotePortStr)
if err != nil {
return nil, fmt.Errorf("invalid port specified on user@host argument")
}
}
return &SSHOpts{SSHHost: remoteHost, SSHUser: remoteUser, SSHPort: remotePort}, nil
}
func DetectShell(client *ssh.Client) (string, error) {
wshPath := getWshPath(client)
session, err := client.NewSession()
if err != nil {
return "", err
}
log.Printf("shell detecting using command: %s shell", wshPath)
out, err := session.Output(wshPath + " shell")
if err != nil {
log.Printf("unable to determine shell. defaulting to /bin/bash: %s", err)
return "/bin/bash", nil
}
log.Printf("detecting shell: %s", out)
return fmt.Sprintf(`"%s"`, strings.TrimSpace(string(out))), nil
}
func getWshVersion(client *ssh.Client) (string, error) {
wshPath := getWshPath(client)
session, err := client.NewSession()
if err != nil {
return "", err
}
out, err := session.Output(wshPath + " version")
if err != nil {
return "", err
}
return strings.TrimSpace(string(out)), nil
}
func getWshPath(client *ssh.Client) string {
defaultPath := filepath.Join("~", ".waveterm", "bin", "wsh")
session, err := client.NewSession()
if err != nil {
log.Printf("unable to detect client's wsh path. using default. error: %v", err)
return defaultPath
}
out, whichErr := session.Output("which wsh")
if whichErr == nil {
return strings.TrimSpace(string(out))
}
session, err = client.NewSession()
if err != nil {
log.Printf("unable to detect client's wsh path. using default. error: %v", err)
return defaultPath
}
out, whereErr := session.Output("where.exe wsh")
if whereErr == nil {
return strings.TrimSpace(string(out))
}
// no custom install, use default path
return defaultPath
}
func hasBashInstalled(client *ssh.Client) (bool, error) {
session, err := client.NewSession()
if err != nil {
// this is a true error that should stop further progress
return false, err
}
out, whichErr := session.Output("which bash")
if whichErr == nil && len(out) != 0 {
return true, nil
}
session, err = client.NewSession()
if err != nil {
// this is a true error that should stop further progress
return false, err
}
out, whereErr := session.Output("where.exe bash")
if whereErr == nil && len(out) != 0 {
return true, nil
}
// note: we could also check in /bin/bash explicitly
// just in case that wasn't added to the path. but if
// that's true, we will most likely have worse
// problems going forward
return false, nil
}
func getClientOs(client *ssh.Client) (string, error) {
session, err := client.NewSession()
if err != nil {
return "", err
}
out, unixErr := session.Output("uname -s")
if unixErr == nil {
formatted := strings.ToLower(string(out))
formatted = strings.TrimSpace(formatted)
return formatted, nil
}
session, err = client.NewSession()
if err != nil {
return "", err
}
out, cmdErr := session.Output("echo %OS%")
if cmdErr == nil {
formatted := strings.ToLower(string(out))
formatted = strings.TrimSpace(formatted)
return strings.Split(formatted, "_")[0], nil
}
session, err = client.NewSession()
if err != nil {
return "", err
}
out, psErr := session.Output("echo $env:OS")
if psErr == nil {
formatted := strings.ToLower(string(out))
formatted = strings.TrimSpace(formatted)
return strings.Split(formatted, "_")[0], nil
}
return "", fmt.Errorf("unable to determine os: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
}
func getClientArch(client *ssh.Client) (string, error) {
session, err := client.NewSession()
if err != nil {
return "", err
}
out, unixErr := session.Output("uname -m")
if unixErr == nil {
formatted := strings.ToLower(string(out))
formatted = strings.TrimSpace(formatted)
if formatted == "x86_64" {
return "amd64", nil
}
return formatted, nil
}
session, err = client.NewSession()
if err != nil {
return "", err
}
out, cmdErr := session.Output("echo %PROCESSOR_ARCHITECTURE%")
if cmdErr == nil {
formatted := strings.ToLower(string(out))
return strings.TrimSpace(formatted), nil
}
session, err = client.NewSession()
if err != nil {
return "", err
}
out, psErr := session.Output("echo $env:PROCESSOR_ARCHITECTURE")
if psErr == nil {
formatted := strings.ToLower(string(out))
return strings.TrimSpace(formatted), nil
}
return "", fmt.Errorf("unable to determine architecture: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
}
var installTemplateRawBash = `bash -c ' \
mkdir -p {{.installDir}}; \
cat > {{.tempPath}}; \
mv {{.tempPath}} {{.installPath}}; \
chmod a+x {{.installPath}};' \
`
var installTemplateRawDefault = ` \
mkdir -p {{.installDir}}; \
cat > {{.tempPath}}; \
mv {{.tempPath}} {{.installPath}}; \
chmod a+x {{.installPath}}; \
`
func cpHostToRemote(client *ssh.Client, sourcePath string, destPath string) error {
// warning: does not work on windows remote yet
bashInstalled, err := hasBashInstalled(client)
if err != nil {
return err
}
var selectedTemplateRaw string
if bashInstalled {
selectedTemplateRaw = installTemplateRawBash
} else {
log.Printf("bash is not installed on remote. attempting with default shell")
selectedTemplateRaw = installTemplateRawDefault
}
var installWords = map[string]string{
"installDir": filepath.Dir(destPath),
"tempPath": destPath + ".temp",
"installPath": destPath,
}
installCmd := &bytes.Buffer{}
installTemplate := template.Must(template.New("").Parse(selectedTemplateRaw))
installTemplate.Execute(installCmd, installWords)
session, err := client.NewSession()
if err != nil {
return err
}
installStdin, err := session.StdinPipe()
if err != nil {
return err
}
err = session.Start(installCmd.String())
if err != nil {
return err
}
input, err := os.Open(sourcePath)
if err != nil {
return fmt.Errorf("cannot open local file %s to send to host: %v", sourcePath, err)
}
go func() {
io.Copy(installStdin, input)
session.Close() // this allows the command to complete for reasons i don't fully understand
}()
return session.Wait()
}
func InstallClientRcFiles(client *ssh.Client) error {
path := getWshPath(client)
session, err := client.NewSession()
if err != nil {
// this is a true error that should stop further progress
return err
}
_, err = session.Output(path + " rcfiles")
return err
}
func GetHomeDir(client *ssh.Client) string {
session, err := client.NewSession()
if err != nil {
return "~"
}
out, err := session.Output("pwd")
if err != nil {
return "~"
}
return strings.TrimSpace(string(out))
}

View File

@ -5,7 +5,6 @@ package shellexec
import (
"bytes"
"context"
"fmt"
"io"
"log"
@ -15,16 +14,15 @@ import (
"reflect"
"regexp"
"runtime"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/creack/pty"
"github.com/wavetermdev/thenextwave/pkg/remote"
"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
"golang.org/x/crypto/ssh"
)
type TermSize struct {
@ -155,45 +153,53 @@ func (pp *PipePty) WriteString(s string) (n int, err error) {
return pp.Write([]byte(s))
}
func StartRemoteShellProc(termSize TermSize, cmdStr string, cmdOpts CommandOptsType, remoteName string) (*ShellProc, error) {
ctx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
defer cancelFunc()
var shellPath string
if cmdStr == "" {
shellPath = "/bin/bash"
} else {
shellPath = cmdStr
}
var shellOpts []string
if cmdOpts.Login {
shellOpts = append(shellOpts, "-l")
}
if cmdOpts.Interactive {
shellOpts = append(shellOpts, "-i")
}
cmdCombined := fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
log.Print(cmdCombined)
m := userHostRe.FindStringSubmatch(remoteName)
if m == nil {
return nil, fmt.Errorf("invalid format of user@host argument")
}
remoteUser, remoteHost, remotePortStr := m[1], m[2], m[3]
remoteUser = strings.Trim(remoteUser, "@")
var remotePort int
if remotePortStr != "" {
var err error
remotePort, err = strconv.Atoi(remotePortStr)
if err != nil {
return nil, fmt.Errorf("invalid port specified on user@host argument")
}
}
client, err := remote.ConnectToClient(ctx, &remote.SSHOpts{SSHHost: remoteHost, SSHUser: remoteUser, SSHPort: remotePort}) //todo specify or remove opts
func StartRemoteShellProc(termSize TermSize, cmdStr string, cmdOpts CommandOptsType, client *ssh.Client) (*ShellProc, error) {
shellPath, err := remote.DetectShell(client)
if err != nil {
return nil, err
}
var shellOpts []string
var cmdCombined string
log.Printf("detected shell: %s", shellPath)
err = remote.InstallClientRcFiles(client)
if err != nil {
log.Printf("error installing rc files: %v", err)
return nil, err
}
homeDir := remote.GetHomeDir(client)
if cmdStr == "" {
/* transform command in order to inject environment vars */
if isBashShell(shellPath) {
log.Printf("recognized as bash shell")
// add --rcfile
// cant set -l or -i with --rcfile
shellOpts = append(shellOpts, "--rcfile", fmt.Sprintf(`"%s"/.waveterm/bash-integration/.bashrc`, homeDir))
} else {
if cmdOpts.Login {
shellOpts = append(shellOpts, "-l")
}
if cmdOpts.Interactive {
shellOpts = append(shellOpts, "-i")
}
// zdotdir setting moved to after session is created
}
cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
log.Printf("combined command is: %s", cmdCombined)
} else {
shellPath = cmdStr
if cmdOpts.Login {
shellOpts = append(shellOpts, "-l")
}
if cmdOpts.Interactive {
shellOpts = append(shellOpts, "-i")
}
shellOpts = append(shellOpts, "-c", cmdStr)
cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
}
session, err := client.NewSession()
if err != nil {
return nil, err
@ -223,11 +229,16 @@ func StartRemoteShellProc(termSize TermSize, cmdStr string, cmdOpts CommandOptsT
session.Stdin = remoteStdinRead
session.Stdout = remoteStdoutWrite
session.Stderr = remoteStdoutWrite
for envKey, envVal := range cmdOpts.Env {
// note these might fail depending on server settings, but we still try
session.Setenv(envKey, envVal)
}
if isZshShell(shellPath) {
cmdCombined = fmt.Sprintf(`ZDOTDIR="%s/.waveterm/zsh-integration" %s`, homeDir, cmdCombined)
}
session.RequestPty("xterm-256color", termSize.Rows, termSize.Cols, nil)
sessionWrap := SessionWrap{session, cmdCombined, pipePty, pipePty}

View File

@ -48,7 +48,7 @@ const (
# Source the original zshrc
[ -f ~/.zshrc ] && source ~/.zshrc
export PATH=$WAVETERM_WSHBINDIR:$PATH
export PATH={{.WSHBINDIR}}:$PATH
`
ZshStartup_Zlogin = `
@ -75,8 +75,7 @@ elif [ -f ~/.profile ]; then
. ~/.profile
fi
set -i
export PATH=$WAVETERM_WSHBINDIR:$PATH
export PATH={{.WSHBINDIR}}:$PATH
`
)
@ -194,9 +193,16 @@ func GetZshZDotDir() string {
return filepath.Join(wavebase.GetWaveHomeDir(), ZshIntegrationDir)
}
func initCustomShellStartupFilesInternal() error {
log.Printf("initializing wsh and shell startup files\n")
waveHome := wavebase.GetWaveHomeDir()
func GetWshBinaryPath(version string, goos string, goarch string) string {
ext := ""
if goos == "windows" {
ext = ".exe"
}
return filepath.Join(os.Getenv(WaveAppPathVarName), AppPathBinDir, fmt.Sprintf("wsh-%s-%s.%s%s", version, goos, goarch, ext))
}
func InitRcFiles(waveHome string, wshBinDir string) error {
// ensure directiries exist
zshDir := filepath.Join(waveHome, ZshIntegrationDir)
err := wavebase.CacheEnsureDir(zshDir, ZshIntegrationDir, 0755, ZshIntegrationDir)
if err != nil {
@ -207,19 +213,14 @@ func initCustomShellStartupFilesInternal() error {
if err != nil {
return err
}
binDir := filepath.Join(waveHome, WaveHomeBinDir)
err = wavebase.CacheEnsureDir(binDir, WaveHomeBinDir, 0755, WaveHomeBinDir)
if err != nil {
return err
}
// write files to directory
zprofilePath := filepath.Join(zshDir, ".zprofile")
err = os.WriteFile(zprofilePath, []byte(ZshStartup_Zprofile), 0644)
if err != nil {
return fmt.Errorf("error writing zsh-integration .zprofile: %v", err)
}
zshrcPath := filepath.Join(zshDir, ".zshrc")
err = os.WriteFile(zshrcPath, []byte(ZshStartup_Zshrc), 0644)
err = utilfn.WriteTemplateToFile(filepath.Join(zshDir, ".zshrc"), ZshStartup_Zshrc, map[string]string{"WSHBINDIR": fmt.Sprintf(`"%s"`, wshBinDir)})
if err != nil {
return fmt.Errorf("error writing zsh-integration .zshrc: %v", err)
}
@ -233,20 +234,30 @@ func initCustomShellStartupFilesInternal() error {
if err != nil {
return fmt.Errorf("error writing zsh-integration .zshenv: %v", err)
}
bashrcPath := filepath.Join(bashDir, ".bashrc")
err = os.WriteFile(bashrcPath, []byte(BashStartup_Bashrc), 0644)
err = utilfn.WriteTemplateToFile(filepath.Join(bashDir, ".bashrc"), BashStartup_Bashrc, map[string]string{"WSHBINDIR": fmt.Sprintf(`"%s"`, wshBinDir)})
if err != nil {
return fmt.Errorf("error writing bash-integration .bashrc: %v", err)
}
// copy the correct binary to bin
appPath := os.Getenv(WaveAppPathVarName)
if appPath == "" {
return fmt.Errorf("no app path set")
return nil
}
func initCustomShellStartupFilesInternal() error {
log.Printf("initializing wsh and shell startup files\n")
waveHome := wavebase.GetWaveHomeDir()
binDir := filepath.Join(waveHome, WaveHomeBinDir)
err := InitRcFiles(waveHome, `$WAVETERM_WSHBINDIR`)
if err != nil {
return err
}
appBinPath := filepath.Join(appPath, AppPathBinDir)
wshBaseName := computeWshBaseName()
wshFullPath := filepath.Join(appBinPath, wshBaseName)
err = wavebase.CacheEnsureDir(binDir, WaveHomeBinDir, 0755, WaveHomeBinDir)
if err != nil {
return err
}
// copy the correct binary to bin
wshFullPath := GetWshBinaryPath(wavebase.WaveVersion, runtime.GOOS, runtime.GOARCH)
if _, err := os.Stat(wshFullPath); err != nil {
log.Printf("error (non-fatal), could not resolve wsh binary %q: %v\n", wshFullPath, err)
return nil
@ -256,7 +267,7 @@ func initCustomShellStartupFilesInternal() error {
if err != nil {
return fmt.Errorf("error copying wsh binary to bin: %v", err)
}
log.Printf("wsh binary successfully %q copied to %q\n", wshBaseName, wshDstPath)
log.Printf("wsh binary successfully %q copied to %q\n", computeWshBaseName(), wshDstPath)
return nil
}

View File

@ -23,6 +23,7 @@ import (
"strconv"
"strings"
"syscall"
"text/template"
"unicode/utf8"
"github.com/mitchellh/mapstructure"
@ -878,3 +879,9 @@ func AtoiNoErr(str string) int {
}
return val
}
func WriteTemplateToFile(fileName string, templateText string, vars map[string]string) error {
outBuffer := &bytes.Buffer{}
template.Must(template.New("").Parse(templateText)).Execute(outBuffer, vars)
return os.WriteFile(fileName, outBuffer.Bytes(), 0644)
}