conn updates 4 (#1726)

This commit is contained in:
Mike Sawka 2025-01-14 14:09:26 -08:00 committed by GitHub
parent 1ded7bdd74
commit a24fe750c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 341 additions and 219 deletions

View File

@ -142,8 +142,8 @@ func aiRun(cmd *cobra.Command, args []string) (rtnErr error) {
if message.Len() == 0 { if message.Len() == 0 {
return fmt.Errorf("message is empty") return fmt.Errorf("message is empty")
} }
if message.Len() > 10*1024 { if message.Len() > 50*1024 {
return fmt.Errorf("current max message size is 10k") return fmt.Errorf("current max message size is 50k")
} }
messageData := wshrpc.AiMessageData{ messageData := wshrpc.AiMessageData{

View File

@ -19,7 +19,7 @@ var rcfilesCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
err := wshutil.InstallRcFiles() err := wshutil.InstallRcFiles()
if err != nil { if err != nil {
WriteStderr(err.Error()) WriteStderr("%s\n", err.Error())
return return
} }
}, },

View File

@ -162,7 +162,7 @@ func setBgRun(cmd *cobra.Command, args []string) (rtnErr error) {
if err != nil { if err != nil {
return fmt.Errorf("error formatting metadata: %v", err) return fmt.Errorf("error formatting metadata: %v", err)
} }
WriteStdout(string(jsonBytes) + "\n") WriteStdout("%s\n", string(jsonBytes))
return nil return nil
} }

View File

@ -38,6 +38,15 @@ import "./preview.scss";
const MaxFileSize = 1024 * 1024 * 10; // 10MB const MaxFileSize = 1024 * 1024 * 10; // 10MB
const MaxCSVSize = 1024 * 1024 * 1; // 1MB const MaxCSVSize = 1024 * 1024 * 1; // 1MB
// TODO drive this using config
const BOOKMARKS: { label: string; path: string }[] = [
{ label: "Home", path: "~" },
{ label: "Desktop", path: "~/Desktop" },
{ label: "Downloads", path: "~/Downloads" },
{ label: "Documents", path: "~/Documents" },
{ label: "Root", path: "/" },
];
type SpecializedViewProps = { type SpecializedViewProps = {
model: PreviewModel; model: PreviewModel;
parentRef: React.RefObject<HTMLDivElement>; parentRef: React.RefObject<HTMLDivElement>;
@ -185,27 +194,10 @@ export class PreviewModel implements ViewModel {
elemtype: "iconbutton", elemtype: "iconbutton",
icon: "folder-open", icon: "folder-open",
longClick: (e: React.MouseEvent<any>) => { longClick: (e: React.MouseEvent<any>) => {
const menuItems: ContextMenuItem[] = []; const menuItems: ContextMenuItem[] = BOOKMARKS.map((bookmark) => ({
menuItems.push({ label: `Go to ${bookmark.label} (${bookmark.path})`,
label: "Go to Home", click: () => this.goHistory(bookmark.path),
click: () => this.goHistory("~"), }));
});
menuItems.push({
label: "Go to Desktop",
click: () => this.goHistory("~/Desktop"),
});
menuItems.push({
label: "Go to Downloads",
click: () => this.goHistory("~/Downloads"),
});
menuItems.push({
label: "Go to Documents",
click: () => this.goHistory("~/Documents"),
});
menuItems.push({
label: "Go to Root",
click: () => this.goHistory("/"),
});
ContextMenuModel.showContextMenu(menuItems, e); ContextMenuModel.showContextMenu(menuItems, e);
}, },
}; };

View File

@ -523,7 +523,6 @@ const ChatWindow = memo(
const handleNewMessage = useCallback( const handleNewMessage = useCallback(
throttle(100, (messagesLen: number) => { throttle(100, (messagesLen: number) => {
if (osRef.current?.osInstance()) { if (osRef.current?.osInstance()) {
console.log("handleNewMessage", messagesLen, isUserScrolling.current);
const { viewport } = osRef.current.osInstance().elements(); const { viewport } = osRef.current.osInstance().elements();
if (prevMessagesLenRef.current !== messagesLen || !isUserScrolling.current) { if (prevMessagesLenRef.current !== messagesLen || !isUserScrolling.current) {
viewport.scrollTo({ viewport.scrollTo({

View File

@ -304,6 +304,7 @@ declare global {
"conn:askbeforewshinstall"?: boolean; "conn:askbeforewshinstall"?: boolean;
"conn:overrideconfig"?: boolean; "conn:overrideconfig"?: boolean;
"conn:wshpath"?: string; "conn:wshpath"?: string;
"conn:shellpath"?: string;
"display:hidden"?: boolean; "display:hidden"?: boolean;
"display:order"?: number; "display:order"?: number;
"term:*"?: boolean; "term:*"?: boolean;

View File

@ -369,18 +369,18 @@ func (bc *BlockController) setupAndStartShellProcess(rc *RunShellOpts, blockMeta
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
} }
if !conn.WshEnabled.Load() { if !conn.WshEnabled.Load() {
shellProc, err = shellexec.StartRemoteShellProcNoWsh(rc.TermSize, cmdStr, cmdOpts, conn) shellProc, err = shellexec.StartRemoteShellProcNoWsh(ctx, rc.TermSize, cmdStr, cmdOpts, conn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, conn) shellProc, err = shellexec.StartRemoteShellProc(ctx, rc.TermSize, cmdStr, cmdOpts, conn)
if err != nil { if err != nil {
conn.SetWshError(err) conn.SetWshError(err)
conn.WshEnabled.Store(false) conn.WshEnabled.Store(false)
log.Printf("error starting remote shell proc with wsh: %v", err) log.Printf("error starting remote shell proc with wsh: %v", err)
log.Print("attempting install without wsh") log.Print("attempting install without wsh")
shellProc, err = shellexec.StartRemoteShellProcNoWsh(rc.TermSize, cmdStr, cmdOpts, conn) shellProc, err = shellexec.StartRemoteShellProcNoWsh(ctx, rc.TermSize, cmdStr, cmdOpts, conn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -408,7 +408,7 @@ func (bc *BlockController) setupAndStartShellProcess(rc *RunShellOpts, blockMeta
if len(blockMeta.GetStringList(waveobj.MetaKey_TermLocalShellOpts)) > 0 { if len(blockMeta.GetStringList(waveobj.MetaKey_TermLocalShellOpts)) > 0 {
cmdOpts.ShellOpts = append([]string{}, blockMeta.GetStringList(waveobj.MetaKey_TermLocalShellOpts)...) cmdOpts.ShellOpts = append([]string{}, blockMeta.GetStringList(waveobj.MetaKey_TermLocalShellOpts)...)
} }
shellProc, err = shellexec.StartShellProc(rc.TermSize, cmdStr, cmdOpts) shellProc, err = shellexec.StartLocalShellProc(rc.TermSize, cmdStr, cmdOpts)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -7,15 +7,12 @@ import "regexp"
var ( var (
safePattern = regexp.MustCompile(`^[a-zA-Z0-9_/.-]+$`) safePattern = regexp.MustCompile(`^[a-zA-Z0-9_/.-]+$`)
psSafePattern = regexp.MustCompile(`^[a-zA-Z0-9_.-]+$`)
needsEscape = map[byte]bool{
'"': true,
'\\': true,
'$': true,
'`': true,
}
) )
// TODO: fish quoting is slightly different
// specifically \` will cause an inconsistency between fish and bash/zsh :/
// might need a specific fish quoting function, and an explicit fish shell detection
func HardQuote(s string) string { func HardQuote(s string) string {
if s == "" { if s == "" {
return "\"\"" return "\"\""
@ -29,11 +26,43 @@ func HardQuote(s string) string {
buf = append(buf, '"') buf = append(buf, '"')
for i := 0; i < len(s); i++ { for i := 0; i < len(s); i++ {
if needsEscape[s[i]] { switch s[i] {
buf = append(buf, '\\') case '"', '\\', '$', '`':
} buf = append(buf, '\\', s[i])
case '\n':
buf = append(buf, '\\', '\n')
default:
buf = append(buf, s[i]) buf = append(buf, s[i])
} }
}
buf = append(buf, '"')
return string(buf)
}
func HardQuotePowerShell(s string) string {
if s == "" {
return "\"\""
}
if psSafePattern.MatchString(s) {
return s
}
buf := make([]byte, 0, len(s)+5)
buf = append(buf, '"')
for i := 0; i < len(s); i++ {
c := s[i]
// In PowerShell, backtick (`) is the escape character
switch c {
case '"', '`', '$':
buf = append(buf, '`')
case '\n':
buf = append(buf, '`', 'n') // PowerShell uses `n for newline
}
buf = append(buf, c)
}
buf = append(buf, '"') buf = append(buf, '"')
return string(buf) return string(buf)

View File

@ -6,6 +6,7 @@ package genconn
import ( import (
"fmt" "fmt"
"io" "io"
"log"
"sync" "sync"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@ -41,6 +42,7 @@ type SSHProcessController struct {
// MakeSSHCmdClient creates a new instance of SSHCmdClient // MakeSSHCmdClient creates a new instance of SSHCmdClient
func MakeSSHCmdClient(client *ssh.Client, cmdSpec CommandSpec) (*SSHProcessController, error) { func MakeSSHCmdClient(client *ssh.Client, cmdSpec CommandSpec) (*SSHProcessController, error) {
log.Printf("SSH-NEWSESSION (cmdclient)\n")
session, err := client.NewSession() session, err := client.NewSession()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create SSH session: %w", err) return nil, fmt.Errorf("failed to create SSH session: %w", err)

View File

@ -69,11 +69,12 @@ type SSHConn struct {
ActiveConnNum int ActiveConnNum int
} }
var ConnServerCmdTemplate = strings.TrimSpace(` var ConnServerCmdTemplate = strings.TrimSpace(
%s version || echo "not-installed" strings.Join([]string{
read jwt_token "%s version 2> /dev/null || (echo -n \"not-installed \"; uname -sm);",
WAVETERM_JWT="$jwt_token" %s connserver "read jwt_token;",
`) "WAVETERM_JWT=\"$jwt_token\" %s connserver",
}, "\n"))
func GetAllConnStatus() []wshrpc.ConnStatus { func GetAllConnStatus() []wshrpc.ConnStatus {
globalLock.Lock() globalLock.Lock()
@ -225,37 +226,55 @@ func (conn *SSHConn) OpenDomainSocketListener(ctx context.Context) error {
return nil return nil
} }
// expects the output of `wsh version` which looks like `wsh v0.10.4` or "not-installed" // expects the output of `wsh version` which looks like `wsh v0.10.4` or "not-installed [os] [arch]"
// returns (up-to-date, semver, error) // returns (up-to-date, semver, osArchStr, error)
// if not up to date, or error, version might be "" // if not up to date, or error, version might be ""
func IsWshVersionUpToDate(wshVersionLine string) (bool, string, error) { func IsWshVersionUpToDate(wshVersionLine string) (bool, string, string, error) {
wshVersionLine = strings.TrimSpace(wshVersionLine) wshVersionLine = strings.TrimSpace(wshVersionLine)
if wshVersionLine == "not-installed" { if strings.HasPrefix(wshVersionLine, "not-installed") {
return false, "", nil return false, "not-installed", strings.TrimSpace(strings.TrimPrefix(wshVersionLine, "not-installed")), nil
} }
parts := strings.Fields(wshVersionLine) parts := strings.Fields(wshVersionLine)
if len(parts) != 2 { if len(parts) != 2 {
return false, "", fmt.Errorf("unexpected version format: %s", wshVersionLine) return false, "", "", fmt.Errorf("unexpected version format: %s", wshVersionLine)
} }
clientVersion := parts[1] clientVersion := parts[1]
expectedVersion := fmt.Sprintf("v%s", wavebase.WaveVersion) expectedVersion := fmt.Sprintf("v%s", wavebase.WaveVersion)
if semver.Compare(clientVersion, expectedVersion) < 0 { if semver.Compare(clientVersion, expectedVersion) < 0 {
return false, clientVersion, nil return false, clientVersion, "", nil
} }
return true, clientVersion, nil return true, clientVersion, "", nil
} }
// returns (needsInstall, clientVersion, error) func (conn *SSHConn) getWshPath() string {
func (conn *SSHConn) StartConnServer(ctx context.Context) (bool, string, error) { config, ok := conn.getConnectionConfig()
if ok && config.ConnWshPath != "" {
return config.ConnWshPath
}
return wavebase.RemoteFullWshBinPath
}
func (conn *SSHConn) GetConfigShellPath() string {
config, ok := conn.getConnectionConfig()
if !ok {
return ""
}
return config.ConnShellPath
}
// returns (needsInstall, clientVersion, osArchStr, error)
// if wsh is not installed, the clientVersion will be "not-installed", and it will also return an osArchStr
// if clientVersion is set, then no osArchStr will be returned
func (conn *SSHConn) StartConnServer(ctx context.Context) (bool, string, string, error) {
conn.Infof(ctx, "running StartConnServer...\n") conn.Infof(ctx, "running StartConnServer...\n")
allowed := WithLockRtn(conn, func() bool { allowed := WithLockRtn(conn, func() bool {
return conn.Status == Status_Connecting return conn.Status == Status_Connecting
}) })
if !allowed { if !allowed {
return false, "", fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus()) return false, "", "", fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus())
} }
client := conn.GetClient() client := conn.GetClient()
wshPath := remote.GetWshPath(client) wshPath := conn.getWshPath()
rpcCtx := wshrpc.RpcContext{ rpcCtx := wshrpc.RpcContext{
ClientType: wshrpc.ClientType_ConnServer, ClientType: wshrpc.ClientType_ConnServer,
Conn: conn.GetName(), Conn: conn.GetName(),
@ -263,49 +282,51 @@ func (conn *SSHConn) StartConnServer(ctx context.Context) (bool, string, error)
sockName := conn.GetDomainSocketName() sockName := conn.GetDomainSocketName()
jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName) jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName)
if err != nil { if err != nil {
return false, "", fmt.Errorf("unable to create jwt token for conn controller: %w", err) return false, "", "", fmt.Errorf("unable to create jwt token for conn controller: %w", err)
} }
conn.Infof(ctx, "SSH-NEWSESSION (StartConnServer)\n")
sshSession, err := client.NewSession() sshSession, err := client.NewSession()
if err != nil { if err != nil {
return false, "", fmt.Errorf("unable to create ssh session for conn controller: %w", err) return false, "", "", fmt.Errorf("unable to create ssh session for conn controller: %w", err)
} }
pipeRead, pipeWrite := io.Pipe() pipeRead, pipeWrite := io.Pipe()
sshSession.Stdout = pipeWrite sshSession.Stdout = pipeWrite
sshSession.Stderr = pipeWrite sshSession.Stderr = pipeWrite
stdinPipe, err := sshSession.StdinPipe() stdinPipe, err := sshSession.StdinPipe()
if err != nil { if err != nil {
return false, "", fmt.Errorf("unable to get stdin pipe: %w", err) return false, "", "", fmt.Errorf("unable to get stdin pipe: %w", err)
} }
cmdStr := fmt.Sprintf(ConnServerCmdTemplate, wshPath, wshPath) cmdStr := fmt.Sprintf(ConnServerCmdTemplate, wshPath, wshPath)
log.Printf("starting conn controller: %s\n", cmdStr) log.Printf("starting conn controller: %q\n", cmdStr)
shWrappedCmdStr := fmt.Sprintf("sh -c %s", genconn.HardQuote(cmdStr)) shWrappedCmdStr := fmt.Sprintf("sh -c %s", genconn.HardQuote(cmdStr))
blocklogger.Debugf(ctx, "[conndebug] wrapped command:\n%s\n", shWrappedCmdStr)
err = sshSession.Start(shWrappedCmdStr) err = sshSession.Start(shWrappedCmdStr)
if err != nil { if err != nil {
return false, "", fmt.Errorf("unable to start conn controller command: %w", err) return false, "", "", fmt.Errorf("unable to start conn controller command: %w", err)
} }
linesChan := wshutil.StreamToLinesChan(pipeRead) linesChan := wshutil.StreamToLinesChan(pipeRead)
versionLine, err := wshutil.ReadLineWithTimeout(linesChan, 2*time.Second) versionLine, err := wshutil.ReadLineWithTimeout(linesChan, 2*time.Second)
if err != nil { if err != nil {
sshSession.Close() sshSession.Close()
return false, "", fmt.Errorf("error reading wsh version: %w", err) return false, "", "", fmt.Errorf("error reading wsh version: %w", err)
} }
conn.Infof(ctx, "got connserver version: %s\n", strings.TrimSpace(versionLine)) conn.Infof(ctx, "got connserver version: %s\n", strings.TrimSpace(versionLine))
isUpToDate, clientVersion, err := IsWshVersionUpToDate(versionLine) isUpToDate, clientVersion, osArchStr, err := IsWshVersionUpToDate(versionLine)
if err != nil { if err != nil {
sshSession.Close() sshSession.Close()
return false, "", fmt.Errorf("error checking wsh version: %w", err) return false, "", "", fmt.Errorf("error checking wsh version: %w", err)
} }
conn.Infof(ctx, "connserver update to date: %v\n", isUpToDate) conn.Infof(ctx, "connserver up-to-date: %v\n", isUpToDate)
if !isUpToDate { if !isUpToDate {
sshSession.Close() sshSession.Close()
return true, clientVersion, nil return true, clientVersion, osArchStr, nil
} }
// write the jwt // write the jwt
conn.Infof(ctx, "writing jwt token to connserver\n") conn.Infof(ctx, "writing jwt token to connserver\n")
_, err = fmt.Fprintf(stdinPipe, "%s\n", jwtToken) _, err = fmt.Fprintf(stdinPipe, "%s\n", jwtToken)
if err != nil { if err != nil {
sshSession.Close() sshSession.Close()
return false, clientVersion, fmt.Errorf("failed to write JWT token: %w", err) return false, clientVersion, "", fmt.Errorf("failed to write JWT token: %w", err)
} }
conn.WithLock(func() { conn.WithLock(func() {
conn.ConnController = sshSession conn.ConnController = sshSession
@ -351,11 +372,11 @@ func (conn *SSHConn) StartConnServer(ctx context.Context) (bool, string, error)
defer cancelFn() defer cancelFn()
err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(rpcCtx.Conn)) err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(rpcCtx.Conn))
if err != nil { if err != nil {
return false, clientVersion, fmt.Errorf("timeout waiting for connserver to register") return false, clientVersion, "", fmt.Errorf("timeout waiting for connserver to register")
} }
time.Sleep(300 * time.Millisecond) // TODO remove this sleep (but we need to wait until connserver is "ready") time.Sleep(300 * time.Millisecond) // TODO remove this sleep (but we need to wait until connserver is "ready")
conn.Infof(ctx, "connserver is registered and ready\n") conn.Infof(ctx, "connserver is registered and ready\n")
return false, clientVersion, nil return false, clientVersion, "", nil
} }
type WshInstallOpts struct { type WshInstallOpts struct {
@ -438,17 +459,22 @@ func (conn *SSHConn) getPermissionToInstallWsh(ctx context.Context, clientDispla
return true, nil return true, nil
} }
func (conn *SSHConn) InstallWsh(ctx context.Context) error { func (conn *SSHConn) InstallWsh(ctx context.Context, osArchStr string) error {
conn.Infof(ctx, "running installWsh...\n") conn.Infof(ctx, "running installWsh...\n")
client := conn.GetClient() client := conn.GetClient()
if client == nil { if client == nil {
conn.Infof(ctx, "ERROR ssh client is not connected, cannot install\n") conn.Infof(ctx, "ERROR ssh client is not connected, cannot install\n")
return fmt.Errorf("ssh client is not connected, cannot install") return fmt.Errorf("ssh client is not connected, cannot install")
} }
clientOs, clientArch, err := remote.GetClientPlatform(ctx, genconn.MakeSSHShellClient(client)) var clientOs, clientArch string
var err error
if osArchStr != "" {
clientOs, clientArch, err = remote.GetClientPlatformFromOsArchStr(ctx, osArchStr)
} else {
clientOs, clientArch, err = remote.GetClientPlatform(ctx, genconn.MakeSSHShellClient(client))
}
if err != nil { if err != nil {
conn.Infof(ctx, "ERROR detecting client platform: %v\n", err) conn.Infof(ctx, "ERROR detecting client platform: %v\n", err)
return err
} }
conn.Infof(ctx, "detected remote platform os:%s arch:%s\n", clientOs, clientArch) conn.Infof(ctx, "detected remote platform os:%s arch:%s\n", clientOs, clientArch)
err = remote.CpWshToRemote(ctx, client, clientOs, clientArch) err = remote.CpWshToRemote(ctx, client, clientOs, clientArch)
@ -547,8 +573,7 @@ func (conn *SSHConn) Connect(ctx context.Context, connFlags *wshrpc.ConnKeywords
// logic for saving connection and potential flags (we only save once a connection has been made successfully) // logic for saving connection and potential flags (we only save once a connection has been made successfully)
// at the moment, identity files is the only saved flag // at the moment, identity files is the only saved flag
var identityFiles []string var identityFiles []string
existingConfig := wconfig.GetWatcher().GetFullConfig() existingConnection, ok := conn.getConnectionConfig()
existingConnection, ok := existingConfig.Connections[conn.GetName()]
if ok { if ok {
identityFiles = existingConnection.SshIdentityFile identityFiles = existingConnection.SshIdentityFile
} }
@ -592,7 +617,7 @@ func (conn *SSHConn) getConnWshSettings() (bool, bool) {
config := wconfig.GetWatcher().GetFullConfig() config := wconfig.GetWatcher().GetFullConfig()
enableWsh := config.Settings.ConnWshEnabled enableWsh := config.Settings.ConnWshEnabled
askBeforeInstall := wconfig.DefaultBoolPtr(config.Settings.ConnAskBeforeWshInstall, true) askBeforeInstall := wconfig.DefaultBoolPtr(config.Settings.ConnAskBeforeWshInstall, true)
connSettings, ok := config.Connections[conn.GetName()] connSettings, ok := conn.getConnectionConfig()
if ok { if ok {
if connSettings.ConnWshEnabled != nil { if connSettings.ConnWshEnabled != nil {
enableWsh = *connSettings.ConnWshEnabled enableWsh = *connSettings.ConnWshEnabled
@ -639,7 +664,7 @@ func (conn *SSHConn) tryEnableWsh(ctx context.Context, clientDisplayName string)
err = fmt.Errorf("error opening domain socket listener: %w", err) err = fmt.Errorf("error opening domain socket listener: %w", err)
return WshCheckResult{NoWshReason: "error opening domain socket", WshError: err} return WshCheckResult{NoWshReason: "error opening domain socket", WshError: err}
} }
needsInstall, clientVersion, err := conn.StartConnServer(ctx) needsInstall, clientVersion, osArchStr, err := conn.StartConnServer(ctx)
if err != nil { if err != nil {
conn.Infof(ctx, "ERROR starting conn server: %v\n", err) conn.Infof(ctx, "ERROR starting conn server: %v\n", err)
err = fmt.Errorf("error starting conn server: %w", err) err = fmt.Errorf("error starting conn server: %w", err)
@ -647,13 +672,13 @@ func (conn *SSHConn) tryEnableWsh(ctx context.Context, clientDisplayName string)
} }
if needsInstall { if needsInstall {
conn.Infof(ctx, "connserver needs to be (re)installed\n") conn.Infof(ctx, "connserver needs to be (re)installed\n")
err = conn.InstallWsh(ctx) err = conn.InstallWsh(ctx, osArchStr)
if err != nil { if err != nil {
conn.Infof(ctx, "ERROR installing wsh: %v\n", err) conn.Infof(ctx, "ERROR installing wsh: %v\n", err)
err = fmt.Errorf("error installing wsh: %w", err) err = fmt.Errorf("error installing wsh: %w", err)
return WshCheckResult{NoWshReason: "error installing wsh/connserver", WshError: err} return WshCheckResult{NoWshReason: "error installing wsh/connserver", WshError: err}
} }
needsInstall, clientVersion, err = conn.StartConnServer(ctx) needsInstall, clientVersion, _, err = conn.StartConnServer(ctx)
if err != nil { if err != nil {
conn.Infof(ctx, "ERROR starting conn server (after install): %v\n", err) conn.Infof(ctx, "ERROR starting conn server (after install): %v\n", err)
err = fmt.Errorf("error starting conn server (after install): %w", err) err = fmt.Errorf("error starting conn server (after install): %w", err)
@ -670,6 +695,15 @@ func (conn *SSHConn) tryEnableWsh(ctx context.Context, clientDisplayName string)
} }
} }
func (conn *SSHConn) getConnectionConfig() (wshrpc.ConnKeywords, bool) {
config := wconfig.GetWatcher().GetFullConfig()
connSettings, ok := config.Connections[conn.GetName()]
if !ok {
return wshrpc.ConnKeywords{}, false
}
return connSettings, true
}
func (conn *SSHConn) persistWshInstalled(ctx context.Context, result WshCheckResult) { func (conn *SSHConn) persistWshInstalled(ctx context.Context, result WshCheckResult) {
conn.WshEnabled.Store(result.WshEnabled) conn.WshEnabled.Store(result.WshEnabled)
conn.SetWshError(result.WshError) conn.SetWshError(result.WshError)
@ -677,9 +711,8 @@ func (conn *SSHConn) persistWshInstalled(ctx context.Context, result WshCheckRes
conn.NoWshReason = result.NoWshReason conn.NoWshReason = result.NoWshReason
conn.WshVersion = result.ClientVersion conn.WshVersion = result.ClientVersion
}) })
config := wconfig.GetWatcher().GetFullConfig() connConfig, ok := conn.getConnectionConfig()
connSettings, ok := config.Connections[conn.GetName()] if ok && connConfig.ConnWshEnabled != nil {
if ok && connSettings.ConnWshEnabled != nil {
return return
} }
meta := make(map[string]any) meta := make(map[string]any)

View File

@ -16,6 +16,7 @@ import (
"strings" "strings"
"text/template" "text/template"
"github.com/wavetermdev/waveterm/pkg/blocklogger"
"github.com/wavetermdev/waveterm/pkg/genconn" "github.com/wavetermdev/waveterm/pkg/genconn"
"github.com/wavetermdev/waveterm/pkg/util/shellutil" "github.com/wavetermdev/waveterm/pkg/util/shellutil"
"github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/wavebase"
@ -35,46 +36,6 @@ func ParseOpts(input string) (*SSHOpts, error) {
return &SSHOpts{SSHHost: remoteHost, SSHUser: remoteUser, SSHPort: remotePort}, nil return &SSHOpts{SSHHost: remoteHost, SSHUser: remoteUser, SSHPort: remotePort}, nil
} }
func GetWshPath(client *ssh.Client) string {
defaultPath := wavebase.RemoteFullWshBinPath
session, err := client.NewSession()
if err != nil {
log.Printf("unable to detect client's wsh path. using default. error: %v", err)
return defaultPath
}
out, whichErr := session.Output("which wsh")
if whichErr == nil {
return strings.TrimSpace(string(out))
}
session, err = client.NewSession()
if err != nil {
log.Printf("unable to detect client's wsh path. using default. error: %v", err)
return defaultPath
}
out, whereErr := session.Output("where.exe wsh")
if whereErr == nil {
return strings.TrimSpace(string(out))
}
// check cmd on windows since it requires an absolute path with backslashes
session, err = client.NewSession()
if err != nil {
log.Printf("unable to detect client's wsh path. using default. error: %v", err)
return defaultPath
}
out, cmdErr := session.Output("(dir 2>&1 *``|echo %userprofile%\\.waveterm%\\.waveterm\\bin\\wsh.exe);&<# rem #>echo none") //todo
if cmdErr == nil && strings.TrimSpace(string(out)) != "none" {
return strings.TrimSpace(string(out))
}
// no custom install, use default path
return defaultPath
}
func normalizeOs(os string) string { func normalizeOs(os string) string {
os = strings.ToLower(strings.TrimSpace(os)) os = strings.ToLower(strings.TrimSpace(os))
return os return os
@ -94,6 +55,7 @@ func normalizeArch(arch string) string {
// returns (os, arch, error) // returns (os, arch, error)
// guaranteed to return a supported platform // guaranteed to return a supported platform
func GetClientPlatform(ctx context.Context, shell genconn.ShellClient) (string, string, error) { func GetClientPlatform(ctx context.Context, shell genconn.ShellClient) (string, string, error) {
blocklogger.Infof(ctx, "[conndebug] running `uname -sm` to detect client platform\n")
stdout, stderr, err := genconn.RunSimpleCommand(ctx, shell, genconn.CommandSpec{ stdout, stderr, err := genconn.RunSimpleCommand(ctx, shell, genconn.CommandSpec{
Cmd: "uname -sm", Cmd: "uname -sm",
}) })
@ -112,16 +74,28 @@ func GetClientPlatform(ctx context.Context, shell genconn.ShellClient) (string,
return os, arch, nil return os, arch, nil
} }
func GetClientPlatformFromOsArchStr(ctx context.Context, osArchStr string) (string, string, error) {
parts := strings.Fields(strings.TrimSpace(osArchStr))
if len(parts) != 2 {
return "", "", fmt.Errorf("unexpected output from uname: %s", osArchStr)
}
os, arch := normalizeOs(parts[0]), normalizeArch(parts[1])
if err := wavebase.ValidateWshSupportedArch(os, arch); err != nil {
return "", "", err
}
return os, arch, nil
}
var installTemplateRawDefault = strings.TrimSpace(` var installTemplateRawDefault = strings.TrimSpace(`
mkdir -p {{.installDir}} || exit 1 mkdir -p {{.installDir}} || exit 1;
cat > {{.tempPath}} || exit 1 cat > {{.tempPath}} || exit 1;
mv {{.tempPath}} {{.installPath}} || exit 1 mv {{.tempPath}} {{.installPath}} || exit 1;
chmod a+x {{.installPath}} || exit 1 chmod a+x {{.installPath}} || exit 1;
`) `)
var installTemplate = template.Must(template.New("wsh-install-template").Parse(installTemplateRawDefault)) var installTemplate = template.Must(template.New("wsh-install-template").Parse(installTemplateRawDefault))
func CpWshToRemote(ctx context.Context, client *ssh.Client, clientOs string, clientArch string) error { func CpWshToRemote(ctx context.Context, client *ssh.Client, clientOs string, clientArch string) error {
wshLocalPath, err := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch) wshLocalPath, err := shellutil.GetLocalWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
if err != nil { if err != nil {
return err return err
} }
@ -132,13 +106,14 @@ func CpWshToRemote(ctx context.Context, client *ssh.Client, clientOs string, cli
defer input.Close() defer input.Close()
installWords := map[string]string{ installWords := map[string]string{
"installDir": filepath.ToSlash(filepath.Dir(wavebase.RemoteFullWshBinPath)), "installDir": filepath.ToSlash(filepath.Dir(wavebase.RemoteFullWshBinPath)),
"tempPath": filepath.ToSlash(wavebase.RemoteFullWshBinPath + ".temp"), "tempPath": wavebase.RemoteFullWshBinPath + ".temp",
"installPath": filepath.ToSlash(wavebase.RemoteFullWshBinPath), "installPath": wavebase.RemoteFullWshBinPath,
} }
var installCmd bytes.Buffer var installCmd bytes.Buffer
if err := installTemplate.Execute(&installCmd, installWords); err != nil { if err := installTemplate.Execute(&installCmd, installWords); err != nil {
return fmt.Errorf("failed to prepare install command: %w", err) return fmt.Errorf("failed to prepare install command: %w", err)
} }
blocklogger.Infof(ctx, "[conndebug] copying %q to remote server %q\n", wshLocalPath, wavebase.RemoteFullWshBinPath)
genCmd, err := genconn.MakeSSHCmdClient(client, genconn.CommandSpec{ genCmd, err := genconn.MakeSSHCmdClient(client, genconn.CommandSpec{
Cmd: installCmd.String(), Cmd: installCmd.String(),
}) })

View File

@ -25,7 +25,6 @@ import (
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller" "github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
"github.com/wavetermdev/waveterm/pkg/util/pamparse" "github.com/wavetermdev/waveterm/pkg/util/pamparse"
"github.com/wavetermdev/waveterm/pkg/util/shellutil" "github.com/wavetermdev/waveterm/pkg/util/shellutil"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshrpc"
@ -36,6 +35,14 @@ import (
const DefaultGracefulKillWait = 400 * time.Millisecond const DefaultGracefulKillWait = 400 * time.Millisecond
const (
ShellType_bash = "bash"
ShellType_zsh = "zsh"
ShellType_fish = "fish"
ShellType_pwsh = "pwsh"
ShellType_unknown = "unknown"
)
type CommandOptsType struct { type CommandOptsType struct {
Interactive bool `json:"interactive,omitempty"` Interactive bool `json:"interactive,omitempty"`
Login bool `json:"login,omitempty"` Login bool `json:"login,omitempty"`
@ -151,6 +158,23 @@ func (pp *PipePty) WriteString(s string) (n int, err error) {
return pp.Write([]byte(s)) return pp.Write([]byte(s))
} }
func getShellTypeFromShellPath(shellPath string) string {
shellBase := filepath.Base(shellPath)
if strings.Contains(shellBase, "bash") {
return ShellType_bash
}
if strings.Contains(shellBase, "zsh") {
return ShellType_zsh
}
if strings.Contains(shellBase, "fish") {
return ShellType_fish
}
if strings.Contains(shellBase, "pwsh") || strings.Contains(shellBase, "powershell") {
return ShellType_pwsh
}
return ShellType_unknown
}
func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wsl.WslConn) (*ShellProc, error) { func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wsl.WslConn) (*ShellProc, error) {
utilCtx, cancelFn := context.WithTimeout(ctx, 2*time.Second) utilCtx, cancelFn := context.WithTimeout(ctx, 2*time.Second)
defer cancelFn() defer cancelFn()
@ -245,8 +269,9 @@ func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr st
return &ShellProc{Cmd: cmdWrap, ConnName: conn.GetName(), CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil return &ShellProc{Cmd: cmdWrap, ConnName: conn.GetName(), CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil
} }
func StartRemoteShellProcNoWsh(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *conncontroller.SSHConn) (*ShellProc, error) { func StartRemoteShellProcNoWsh(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *conncontroller.SSHConn) (*ShellProc, error) {
client := conn.GetClient() client := conn.GetClient()
conn.Infof(ctx, "SSH-NEWSESSION (StartRemoteShellProcNoWsh)")
session, err := client.NewSession() session, err := client.NewSession()
if err != nil { if err != nil {
return nil, err return nil, err
@ -287,7 +312,7 @@ func StartRemoteShellProcNoWsh(termSize waveobj.TermSize, cmdStr string, cmdOpts
return &ShellProc{Cmd: sessionWrap, ConnName: conn.GetName(), CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil return &ShellProc{Cmd: sessionWrap, ConnName: conn.GetName(), CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil
} }
func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *conncontroller.SSHConn) (*ShellProc, error) { func StartRemoteShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *conncontroller.SSHConn) (*ShellProc, error) {
client := conn.GetClient() client := conn.GetClient()
connRoute := wshutil.MakeConnectionRouteId(conn.GetName()) connRoute := wshutil.MakeConnectionRouteId(conn.GetName())
rpcClient := wshclient.GetBareRpcClient() rpcClient := wshclient.GetBareRpcClient()
@ -296,14 +321,27 @@ func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts Comm
return nil, fmt.Errorf("unable to obtain client info: %w", err) return nil, fmt.Errorf("unable to obtain client info: %w", err)
} }
log.Printf("client info collected: %+#v", remoteInfo) log.Printf("client info collected: %+#v", remoteInfo)
var shellPath string
shellPath := cmdOpts.ShellPath if cmdOpts.ShellPath != "" {
if shellPath == "" { conn.Infof(ctx, "using shell path from command opts: %s\n", cmdOpts.ShellPath)
shellPath = cmdOpts.ShellPath
}
configShellPath := conn.GetConfigShellPath()
if shellPath == "" && configShellPath != "" {
conn.Infof(ctx, "using shell path from config (conn:shellpath): %s\n", configShellPath)
shellPath = configShellPath
}
if shellPath == "" && remoteInfo.Shell != "" {
conn.Infof(ctx, "using shell path detected on remote machine: %s\n", remoteInfo.Shell)
shellPath = remoteInfo.Shell shellPath = remoteInfo.Shell
} }
if shellPath == "" {
conn.Infof(ctx, "no shell path detected, using default (/bin/bash)\n")
shellPath = "/bin/bash"
}
var shellOpts []string var shellOpts []string
var cmdCombined string var cmdCombined string
log.Printf("using shell: %s", shellPath) log.Printf("detected shell %q for conn %q\n", shellPath, conn.GetName())
err = wshclient.RemoteInstallRcFilesCommand(rpcClient, &wshrpc.RpcOpts{Route: connRoute, Timeout: 2000}) err = wshclient.RemoteInstallRcFilesCommand(rpcClient, &wshrpc.RpcOpts{Route: connRoute, Timeout: 2000})
if err != nil { if err != nil {
@ -311,46 +349,51 @@ func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts Comm
return nil, err return nil, err
} }
shellOpts = append(shellOpts, cmdOpts.ShellOpts...) shellOpts = append(shellOpts, cmdOpts.ShellOpts...)
shellType := getShellTypeFromShellPath(shellPath)
conn.Infof(ctx, "detected shell type: %s\n", shellType)
if cmdStr == "" { if cmdStr == "" {
/* transform command in order to inject environment vars */ /* transform command in order to inject environment vars */
if isBashShell(shellPath) { if shellType == ShellType_bash {
log.Printf("recognized as bash shell")
// add --rcfile // add --rcfile
// cant set -l or -i with --rcfile // cant set -l or -i with --rcfile
bashPath := genconn.SoftQuote(fmt.Sprintf("~/.waveterm/%s/.bashrc", shellutil.BashIntegrationDir)) bashPath := fmt.Sprintf("~/.waveterm/%s/.bashrc", shellutil.BashIntegrationDir)
shellOpts = append(shellOpts, "--rcfile", bashPath) shellOpts = append(shellOpts, "--rcfile", bashPath)
} else if isFishShell(shellPath) { } else if shellType == ShellType_fish {
fishDir := genconn.SoftQuote(fmt.Sprintf("~/.waveterm/%s", shellutil.WaveHomeBinDir)) if cmdOpts.Login {
carg := fmt.Sprintf(`"set -x PATH %s $PATH"`, fishDir) shellOpts = append(shellOpts, "-l")
}
// source the wave.fish file
waveFishPath := fmt.Sprintf("~/.waveterm/%s/wave.fish", shellutil.FishIntegrationDir)
carg := fmt.Sprintf(`"source %s"`, waveFishPath)
shellOpts = append(shellOpts, "-C", carg) shellOpts = append(shellOpts, "-C", carg)
} else if remote.IsPowershell(shellPath) { } else if shellType == ShellType_pwsh {
pwshPath := genconn.SoftQuote(fmt.Sprintf("~/.waveterm/%s/wavepwsh.ps1", shellutil.PwshIntegrationDir)) pwshPath := fmt.Sprintf("~/.waveterm/%s/wavepwsh.ps1", shellutil.PwshIntegrationDir)
// powershell is weird about quoted path executables and requires an ampersand first // powershell is weird about quoted path executables and requires an ampersand first
shellPath = "& " + shellPath shellPath = "& " + shellPath
shellOpts = append(shellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", pwshPath) shellOpts = append(shellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", pwshPath)
} else { } else {
if cmdOpts.Login { if cmdOpts.Login {
shellOpts = append(shellOpts, "-l") shellOpts = append(shellOpts, "-l")
} else if cmdOpts.Interactive { }
if cmdOpts.Interactive {
shellOpts = append(shellOpts, "-i") shellOpts = append(shellOpts, "-i")
} }
// zdotdir setting moved to after session is created // zdotdir setting moved to after session is created
} }
cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " ")) cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
log.Printf("combined command is: %s", cmdCombined)
} else { } else {
// TODO check quoting of cmdStr
shellPath = cmdStr shellPath = cmdStr
shellOpts = append(shellOpts, "-c", cmdStr) shellOpts = append(shellOpts, "-c", cmdStr)
cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " ")) cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
log.Printf("combined command is: %s", cmdCombined)
} }
conn.Infof(ctx, "starting shell, using command: %s\n", cmdCombined)
conn.Infof(ctx, "SSH-NEWSESSION (StartRemoteShellProc)\n")
session, err := client.NewSession() session, err := client.NewSession()
if err != nil { if err != nil {
return nil, err return nil, err
} }
remoteStdinRead, remoteStdinWriteOurs, err := os.Pipe() remoteStdinRead, remoteStdinWriteOurs, err := os.Pipe()
if err != nil { if err != nil {
return nil, err return nil, err
@ -381,8 +424,9 @@ func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts Comm
session.Setenv(envKey, envVal) session.Setenv(envKey, envVal)
} }
if isZshShell(shellPath) { if shellType == ShellType_zsh {
zshDir := genconn.SoftQuote(fmt.Sprintf("~/.waveterm/%s", shellutil.ZshIntegrationDir)) zshDir := fmt.Sprintf("~/.waveterm/%s", shellutil.ZshIntegrationDir)
conn.Infof(ctx, "setting ZDOTDIR to %s\n", zshDir)
cmdCombined = fmt.Sprintf(`ZDOTDIR=%s %s`, zshDir, cmdCombined) cmdCombined = fmt.Sprintf(`ZDOTDIR=%s %s`, zshDir, cmdCombined)
} }
@ -390,13 +434,7 @@ func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts Comm
if !ok { if !ok {
return nil, fmt.Errorf("no jwt token provided to connection") return nil, fmt.Errorf("no jwt token provided to connection")
} }
if remote.IsPowershell(shellPath) {
cmdCombined = fmt.Sprintf(`$env:%s="%s"; %s`, wshutil.WaveJwtTokenVarName, jwtToken, cmdCombined)
} else {
cmdCombined = fmt.Sprintf(`%s=%s %s`, wshutil.WaveJwtTokenVarName, jwtToken, cmdCombined) cmdCombined = fmt.Sprintf(`%s=%s %s`, wshutil.WaveJwtTokenVarName, jwtToken, cmdCombined)
}
session.RequestPty("xterm-256color", termSize.Rows, termSize.Cols, nil) session.RequestPty("xterm-256color", termSize.Rows, termSize.Cols, nil)
sessionWrap := MakeSessionWrap(session, cmdCombined, pipePty) sessionWrap := MakeSessionWrap(session, cmdCombined, pipePty)
err = sessionWrap.Start() err = sessionWrap.Start()
@ -425,7 +463,7 @@ func isFishShell(shellPath string) bool {
return strings.Contains(shellBase, "fish") return strings.Contains(shellBase, "fish")
} }
func StartShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType) (*ShellProc, error) { func StartLocalShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType) (*ShellProc, error) {
shellutil.InitCustomShellStartupFiles() shellutil.InitCustomShellStartupFiles()
var ecmd *exec.Cmd var ecmd *exec.Cmd
var shellOpts []string var shellOpts []string
@ -433,29 +471,34 @@ func StartShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOpt
if shellPath == "" { if shellPath == "" {
shellPath = shellutil.DetectLocalShellPath() shellPath = shellutil.DetectLocalShellPath()
} }
shellType := getShellTypeFromShellPath(shellPath)
shellOpts = append(shellOpts, cmdOpts.ShellOpts...) shellOpts = append(shellOpts, cmdOpts.ShellOpts...)
if cmdStr == "" { if cmdStr == "" {
if isBashShell(shellPath) { if shellType == ShellType_bash {
// add --rcfile // add --rcfile
// cant set -l or -i with --rcfile // cant set -l or -i with --rcfile
shellOpts = append(shellOpts, "--rcfile", shellutil.GetBashRcFileOverride()) shellOpts = append(shellOpts, "--rcfile", shellutil.GetLocalBashRcFileOverride())
} else if isFishShell(shellPath) { } else if shellType == ShellType_fish {
wshBinDir := filepath.Join(wavebase.GetWaveDataDir(), shellutil.WaveHomeBinDir) if cmdOpts.Login {
quotedWshBinDir := utilfn.ShellQuote(wshBinDir, false, 300) shellOpts = append(shellOpts, "-l")
shellOpts = append(shellOpts, "-C", fmt.Sprintf("set -x PATH %s $PATH", quotedWshBinDir)) }
} else if remote.IsPowershell(shellPath) { waveFishPath := shellutil.GetLocalWaveFishFilePath()
shellOpts = append(shellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", shellutil.GetWavePowershellEnv()) carg := fmt.Sprintf("source %s", genconn.HardQuote(waveFishPath))
shellOpts = append(shellOpts, "-C", carg)
} else if shellType == ShellType_pwsh {
shellOpts = append(shellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", shellutil.GetLocalWavePowershellEnv())
} else { } else {
if cmdOpts.Login { if cmdOpts.Login {
shellOpts = append(shellOpts, "-l") shellOpts = append(shellOpts, "-l")
} else if cmdOpts.Interactive { }
if cmdOpts.Interactive {
shellOpts = append(shellOpts, "-i") shellOpts = append(shellOpts, "-i")
} }
} }
ecmd = exec.Command(shellPath, shellOpts...) ecmd = exec.Command(shellPath, shellOpts...)
ecmd.Env = os.Environ() ecmd.Env = os.Environ()
if isZshShell(shellPath) { if shellType == ShellType_zsh {
shellutil.UpdateCmdEnv(ecmd, map[string]string{"ZDOTDIR": shellutil.GetZshZDotDir()}) shellutil.UpdateCmdEnv(ecmd, map[string]string{"ZDOTDIR": shellutil.GetLocalZshZDotDir()})
} }
} else { } else {
shellOpts = append(shellOpts, "-c", cmdStr) shellOpts = append(shellOpts, "-c", cmdStr)

View File

@ -17,6 +17,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/wavetermdev/waveterm/pkg/genconn"
"github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/waveobj"
@ -33,9 +34,11 @@ var userShellRegexp = regexp.MustCompile(`^UserShell: (.*)$`)
const DefaultShellPath = "/bin/bash" const DefaultShellPath = "/bin/bash"
const ( const (
// there must be no spaces in these integration dir paths
ZshIntegrationDir = "shell/zsh" ZshIntegrationDir = "shell/zsh"
BashIntegrationDir = "shell/bash" BashIntegrationDir = "shell/bash"
PwshIntegrationDir = "shell/pwsh" PwshIntegrationDir = "shell/pwsh"
FishIntegrationDir = "shell/fish"
WaveHomeBinDir = "bin" WaveHomeBinDir = "bin"
ZshStartup_Zprofile = ` ZshStartup_Zprofile = `
@ -44,9 +47,12 @@ const (
` `
ZshStartup_Zshrc = ` ZshStartup_Zshrc = `
# Source the original zshrc # Source the original zshrc only if ZDOTDIR has not been changed
if [ "$ZDOTDIR" = "$WAVETERM_ZDOTDIR" ]; then
[ -f ~/.zshrc ] && source ~/.zshrc [ -f ~/.zshrc ] && source ~/.zshrc
fi
# Custom additions
export PATH={{.WSHBINDIR}}:$PATH export PATH={{.WSHBINDIR}}:$PATH
if [[ -n ${_comps+x} ]]; then if [[ -n ${_comps+x} ]]; then
source <(wsh completion zsh) source <(wsh completion zsh)
@ -56,10 +62,26 @@ fi
ZshStartup_Zlogin = ` ZshStartup_Zlogin = `
# Source the original zlogin # Source the original zlogin
[ -f ~/.zlogin ] && source ~/.zlogin [ -f ~/.zlogin ] && source ~/.zlogin
# Unset ZDOTDIR only if it hasn't been modified
if [ "$ZDOTDIR" = "$WAVETERM_ZDOTDIR" ]; then
unset ZDOTDIR
fi
` `
ZshStartup_Zshenv = ` ZshStartup_Zshenv = `
# Store the initial ZDOTDIR value
WAVETERM_ZDOTDIR="$ZDOTDIR"
# Source the original zshenv
[ -f ~/.zshenv ] && source ~/.zshenv [ -f ~/.zshenv ] && source ~/.zshenv
# Detect if ZDOTDIR has changed
if [ "$ZDOTDIR" != "$WAVETERM_ZDOTDIR" ]; then
# If changed, manually source your custom zshrc from the original WAVETERM_ZDOTDIR
[ -f "$WAVETERM_ZDOTDIR/.zshrc" ] && source "$WAVETERM_ZDOTDIR/.zshrc"
fi
` `
BashStartup_Bashrc = ` BashStartup_Bashrc = `
@ -83,11 +105,22 @@ if type _init_completion &>/dev/null; then
fi fi
` `
FishStartup_Wavefish = `
# this file is sourced with -C
# Add Wave binary directory to PATH
set -x PATH {{.WSHBINDIR}} $PATH
# Load Wave completions
wsh completion fish | source
`
PwshStartup_wavepwsh = ` PwshStartup_wavepwsh = `
# no need to source regular profiles since we cannot # We source this file with -NoExit -File
# overwrite those with powershell. Instead we will source $env:PATH = {{.WSHBINDIR_PWSH}} + "{{.PATHSEP}}" + $env:PATH
# this file with -NoExit
$env:PATH = "{{.WSHBINDIR}}" + "{{.PATHSEP}}" + $env:PATH # Load Wave completions
wsh completion powershell | Out-String | Invoke-Expression
` `
) )
@ -207,19 +240,23 @@ func InitCustomShellStartupFiles() error {
return err return err
} }
func GetBashRcFileOverride() string { func GetLocalBashRcFileOverride() string {
return filepath.Join(wavebase.GetWaveDataDir(), BashIntegrationDir, ".bashrc") return filepath.Join(wavebase.GetWaveDataDir(), BashIntegrationDir, ".bashrc")
} }
func GetWavePowershellEnv() string { func GetLocalWaveFishFilePath() string {
return filepath.Join(wavebase.GetWaveDataDir(), FishIntegrationDir, "wave.fish")
}
func GetLocalWavePowershellEnv() string {
return filepath.Join(wavebase.GetWaveDataDir(), PwshIntegrationDir, "wavepwsh.ps1") return filepath.Join(wavebase.GetWaveDataDir(), PwshIntegrationDir, "wavepwsh.ps1")
} }
func GetZshZDotDir() string { func GetLocalZshZDotDir() string {
return filepath.Join(wavebase.GetWaveDataDir(), ZshIntegrationDir) return filepath.Join(wavebase.GetWaveDataDir(), ZshIntegrationDir)
} }
func GetWshBinaryPath(version string, goos string, goarch string) (string, error) { func GetLocalWshBinaryPath(version string, goos string, goarch string) (string, error) {
ext := "" ext := ""
if goarch == "amd64" { if goarch == "amd64" {
goarch = "x64" goarch = "x64"
@ -237,8 +274,10 @@ func GetWshBinaryPath(version string, goos string, goarch string) (string, error
return filepath.Join(wavebase.GetWaveAppBinPath(), baseName), nil return filepath.Join(wavebase.GetWaveAppBinPath(), baseName), nil
} }
func InitRcFiles(waveHome string, wshBinDir string) error { // absWshBinDir must be an absolute, expanded path (no ~ or $HOME, etc.)
// ensure directiries exist // it will be hard-quoted appropriately for the shell
func InitRcFiles(waveHome string, absWshBinDir string) error {
// ensure directories exist
zshDir := filepath.Join(waveHome, ZshIntegrationDir) zshDir := filepath.Join(waveHome, ZshIntegrationDir)
err := wavebase.CacheEnsureDir(zshDir, ZshIntegrationDir, 0755, ZshIntegrationDir) err := wavebase.CacheEnsureDir(zshDir, ZshIntegrationDir, 0755, ZshIntegrationDir)
if err != nil { if err != nil {
@ -249,43 +288,55 @@ func InitRcFiles(waveHome string, wshBinDir string) error {
if err != nil { if err != nil {
return err return err
} }
fishDir := filepath.Join(waveHome, FishIntegrationDir)
err = wavebase.CacheEnsureDir(fishDir, FishIntegrationDir, 0755, FishIntegrationDir)
if err != nil {
return err
}
pwshDir := filepath.Join(waveHome, PwshIntegrationDir) pwshDir := filepath.Join(waveHome, PwshIntegrationDir)
err = wavebase.CacheEnsureDir(pwshDir, PwshIntegrationDir, 0755, PwshIntegrationDir) err = wavebase.CacheEnsureDir(pwshDir, PwshIntegrationDir, 0755, PwshIntegrationDir)
if err != nil { if err != nil {
return err return err
} }
// write files to directory
zprofilePath := filepath.Join(zshDir, ".zprofile")
err = os.WriteFile(zprofilePath, []byte(ZshStartup_Zprofile), 0644)
if err != nil {
return fmt.Errorf("error writing zsh-integration .zprofile: %v", err)
}
err = utilfn.WriteTemplateToFile(filepath.Join(zshDir, ".zshrc"), ZshStartup_Zshrc, map[string]string{"WSHBINDIR": fmt.Sprintf(`"%s"`, wshBinDir)})
if err != nil {
return fmt.Errorf("error writing zsh-integration .zshrc: %v", err)
}
zloginPath := filepath.Join(zshDir, ".zlogin")
err = os.WriteFile(zloginPath, []byte(ZshStartup_Zlogin), 0644)
if err != nil {
return fmt.Errorf("error writing zsh-integration .zlogin: %v", err)
}
zshenvPath := filepath.Join(zshDir, ".zshenv")
err = os.WriteFile(zshenvPath, []byte(ZshStartup_Zshenv), 0644)
if err != nil {
return fmt.Errorf("error writing zsh-integration .zshenv: %v", err)
}
err = utilfn.WriteTemplateToFile(filepath.Join(bashDir, ".bashrc"), BashStartup_Bashrc, map[string]string{"WSHBINDIR": fmt.Sprintf(`"%s"`, wshBinDir)})
if err != nil {
return fmt.Errorf("error writing bash-integration .bashrc: %v", err)
}
var pathSep string var pathSep string
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
pathSep = ";" pathSep = ";"
} else { } else {
pathSep = ":" pathSep = ":"
} }
err = utilfn.WriteTemplateToFile(filepath.Join(pwshDir, "wavepwsh.ps1"), PwshStartup_wavepwsh, map[string]string{"WSHBINDIR": toPwshEnvVarRef(wshBinDir), "PATHSEP": pathSep}) params := map[string]string{
"WSHBINDIR": genconn.HardQuote(absWshBinDir),
"WSHBINDIR_PWSH": genconn.HardQuotePowerShell(absWshBinDir),
"PATHSEP": pathSep,
}
// write files to directory
err = utilfn.WriteTemplateToFile(filepath.Join(zshDir, ".zprofile"), ZshStartup_Zprofile, params)
if err != nil {
return fmt.Errorf("error writing zsh-integration .zprofile: %v", err)
}
err = utilfn.WriteTemplateToFile(filepath.Join(zshDir, ".zshrc"), ZshStartup_Zshrc, params)
if err != nil {
return fmt.Errorf("error writing zsh-integration .zshrc: %v", err)
}
err = utilfn.WriteTemplateToFile(filepath.Join(zshDir, ".zlogin"), ZshStartup_Zlogin, params)
if err != nil {
return fmt.Errorf("error writing zsh-integration .zlogin: %v", err)
}
err = utilfn.WriteTemplateToFile(filepath.Join(zshDir, ".zshenv"), ZshStartup_Zshenv, params)
if err != nil {
return fmt.Errorf("error writing zsh-integration .zshenv: %v", err)
}
err = utilfn.WriteTemplateToFile(filepath.Join(bashDir, ".bashrc"), BashStartup_Bashrc, params)
if err != nil {
return fmt.Errorf("error writing bash-integration .bashrc: %v", err)
}
err = utilfn.WriteTemplateToFile(filepath.Join(fishDir, "wave.fish"), FishStartup_Wavefish, params)
if err != nil {
return fmt.Errorf("error writing fish-integration wave.fish: %v", err)
}
err = utilfn.WriteTemplateToFile(filepath.Join(pwshDir, "wavepwsh.ps1"), PwshStartup_wavepwsh, params)
if err != nil { if err != nil {
return fmt.Errorf("error writing pwsh-integration wavepwsh.ps1: %v", err) return fmt.Errorf("error writing pwsh-integration wavepwsh.ps1: %v", err)
} }
@ -297,7 +348,7 @@ func initCustomShellStartupFilesInternal() error {
log.Printf("initializing wsh and shell startup files\n") log.Printf("initializing wsh and shell startup files\n")
waveDataHome := wavebase.GetWaveDataDir() waveDataHome := wavebase.GetWaveDataDir()
binDir := filepath.Join(waveDataHome, WaveHomeBinDir) binDir := filepath.Join(waveDataHome, WaveHomeBinDir)
err := InitRcFiles(waveDataHome, `$WAVETERM_WSHBINDIR`) err := InitRcFiles(waveDataHome, binDir)
if err != nil { if err != nil {
return err return err
} }
@ -308,7 +359,7 @@ func initCustomShellStartupFilesInternal() error {
} }
// copy the correct binary to bin // copy the correct binary to bin
wshFullPath, err := GetWshBinaryPath(wavebase.WaveVersion, runtime.GOOS, runtime.GOARCH) wshFullPath, err := GetLocalWshBinaryPath(wavebase.WaveVersion, runtime.GOOS, runtime.GOARCH)
if err != nil { if err != nil {
log.Printf("error (non-fatal), could not resolve wsh binary path: %v\n", err) log.Printf("error (non-fatal), could not resolve wsh binary path: %v\n", err)
} }
@ -328,7 +379,3 @@ func initCustomShellStartupFilesInternal() error {
log.Printf("wsh binary successfully copied from %q to %q\n", wshBaseName, wshDstPath) log.Printf("wsh binary successfully copied from %q to %q\n", wshBaseName, wshDstPath)
return nil return nil
} }
func toPwshEnvVarRef(input string) string {
return strings.Replace(input, "$", "$env:", -1)
}

View File

@ -475,6 +475,7 @@ type ConnKeywords struct {
ConnAskBeforeWshInstall *bool `json:"conn:askbeforewshinstall,omitempty"` ConnAskBeforeWshInstall *bool `json:"conn:askbeforewshinstall,omitempty"`
ConnOverrideConfig bool `json:"conn:overrideconfig,omitempty"` ConnOverrideConfig bool `json:"conn:overrideconfig,omitempty"`
ConnWshPath string `json:"conn:wshpath,omitempty"` ConnWshPath string `json:"conn:wshpath,omitempty"`
ConnShellPath string `json:"conn:shellpath,omitempty"`
DisplayHidden *bool `json:"display:hidden,omitempty"` DisplayHidden *bool `json:"display:hidden,omitempty"`
DisplayOrder float32 `json:"display:order,omitempty"` DisplayOrder float32 `json:"display:order,omitempty"`

View File

@ -696,7 +696,7 @@ func (ws *WshServer) ConnReinstallWshCommand(ctx context.Context, data wshrpc.Co
if conn == nil { if conn == nil {
return fmt.Errorf("connection not found: %s", connName) return fmt.Errorf("connection not found: %s", connName)
} }
return conn.InstallWsh(ctx) return conn.InstallWsh(ctx, "")
} }
func (ws *WshServer) ConnUpdateWshCommand(ctx context.Context, remoteInfo wshrpc.RemoteInfo) (bool, error) { func (ws *WshServer) ConnUpdateWshCommand(ctx context.Context, remoteInfo wshrpc.RemoteInfo) (bool, error) {
@ -710,7 +710,7 @@ func (ws *WshServer) ConnUpdateWshCommand(ctx context.Context, remoteInfo wshrpc
} }
log.Printf("checking wsh version for connection %s (current: %s)", connName, remoteInfo.ClientVersion) log.Printf("checking wsh version for connection %s (current: %s)", connName, remoteInfo.ClientVersion)
upToDate, _, err := conncontroller.IsWshVersionUpToDate(remoteInfo.ClientVersion) upToDate, _, _, err := conncontroller.IsWshVersionUpToDate(remoteInfo.ClientVersion)
if err != nil { if err != nil {
return false, fmt.Errorf("unable to compare wsh version: %w", err) return false, fmt.Errorf("unable to compare wsh version: %w", err)
} }

View File

@ -566,6 +566,6 @@ func GetInfo() wshrpc.RemoteInfo {
func InstallRcFiles() error { func InstallRcFiles() error {
home := wavebase.GetHomeDir() home := wavebase.GetHomeDir()
waveDir := filepath.Join(home, wavebase.RemoteWaveHomeDirName) waveDir := filepath.Join(home, wavebase.RemoteWaveHomeDirName)
winBinDir := filepath.Join(waveDir, wavebase.RemoteWshBinDirName) wshBinDir := filepath.Join(waveDir, wavebase.RemoteWshBinDirName)
return shellutil.InitRcFiles(waveDir, winBinDir) return shellutil.InitRcFiles(waveDir, wshBinDir)
} }

View File

@ -337,7 +337,7 @@ func (conn *WslConn) CheckAndInstallWsh(ctx context.Context, clientDisplayName s
return err return err
} }
// attempt to install extension // attempt to install extension
wshLocalPath, err := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch) wshLocalPath, err := shellutil.GetLocalWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
if err != nil { if err != nil {
return err return err
} }