zsh support (#227)

adds zsh support to waveterm.  big change, lots going on here.  lots of other improvements and bug fixes added while debugging and building out the feature.

Commits:

* refactor shexec parser.go into new package shellenv.  separate out bash specific parsing from generic functions

* checkpoint

* work on refactoring shexec.  created two new packages shellapi (for bash/zsh specific stuff), and shellutil (shared between shellapi and shexec)

* more refactoring

* create shellapi interface to abstract bash specific functionality

* more refactoring, move bash shell state parsing to shellapi

* move makeRcFile to shellapi.  remove all of the 'client' options CLI options from waveshell

* get shellType passed through to server/single paths for waveshell

* add a local shelltype detector

* mock out a zshapi

* move shelltype through more of the code

* get a command to run via zsh

* zsh can now switch directories.  poc, needs cleanup

* working on ShellState encoding differences between zsh/bash.  Working on parsing zsh decls.  move utilfn package into waveshell (shouldn't have been in wavesrv)

* switch to use []byte for vardecl serialization + diffs

* progress on zsh environment.  still have issues reconciling init environment with trap environment

* fix typeset argument parsing

* parse promptvars, more zsh specific ignores

* fix bug with promptvar not getting set (wrong check in FeState func)

* add sdk (issue #188) to list of rtnstate commands

* more zsh compatibility -- working with a larger ohmyzsh environment.  ignore more variables, handle exit trap better.  unique path/fpath.  add a processtype variable to base.

* must return a value

* zsh alias parsing/restoring.  diff changes (and rtnstate changes).  introduces linediff v1.

* force zmodload of zsh/parameter

* starting work on zsh functions

* need a v1 of mapdiff as well (to handle null chars)

* pack/unpack of ints was wrong (one used int and one use uint).  turned out we only ever encoded '0' so it worked.  that also means it is safe to change unpack to unpackUInt

* reworking for binary encoding of aliases and functions (because of zsh allows any character, including nulls, in names and values)

* fixes, working on functions, issue with line endings

* zsh functions.  lots of ugliness here around dealing with line dicipline and cooked stty.  new runcommand function to grab output from a non-tty fd.  note that we still to run the actual command in a stty to get the proper output.

* write uuid tempdir, cleanup with tmprcfilename code

* hack in some simple zsh function declaration finding code for rtnstate.  create function diff for rtnstate that supports zsh

* make sure key order is constant so shell hashes are consistent

* fix problems with state diffs to support new zsh formats.  add diff/apply code to shellapi (moved from shellenv), that is now specific to zsh or bash

* add log packet and new shellstate packets

* switch to shellstate map that's also keyed by shelltype

* add shelltype to remoteinstance

* remove shell argument from waveshell

* added new shelltype statemap to remote.go (msh), deal with fallout

* move shellstate out of init packet, and move to an explicit reinit call.  try to initialize all of the active shell states

* change dont always store init state (only store on demand).  initialize shell states on demand (if not already initialized).  allow reset to change shells

* add shellpref field to remote table.  use to drive the default shell choice for new tabs

* show shelltag on cmdinput, pass through ri and remote (defaultshellstate)

* bump mshell version to v0.4

* better version validation for shellstate.  also relax compatibility requirements for diffing states (shelltype + major version need to match)

* better error handling, check shellstate compatibility during run (on waveshell server)

* add extra separator for bash shellstate processing to deal with spurious output from rc files

* special migration for v30 -- flag invalid bash shell states and show special button in UI to fix

* format

* remove zsh-decls (unused)

* remove test code

* remove debug print

* fix typo
This commit is contained in:
Mike Sawka 2024-01-16 16:11:04 -08:00 committed by GitHub
parent 76988a5277
commit 422338c04b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
60 changed files with 4417 additions and 2374 deletions

View File

@ -7,10 +7,10 @@ signAsync({
app: "temp/Wave.app",
binaries: [
waveAppPath + "/Contents/Resources/app/bin/wavesrv",
waveAppPath + "/Contents/Resources/app/bin/mshell/mshell-v0.3-linux.amd64",
waveAppPath + "/Contents/Resources/app/bin/mshell/mshell-v0.3-linux.arm64",
waveAppPath + "/Contents/Resources/app/bin/mshell/mshell-v0.3-darwin.amd64",
waveAppPath + "/Contents/Resources/app/bin/mshell/mshell-v0.3-darwin.arm64",
waveAppPath + "/Contents/Resources/app/bin/mshell/mshell-v0.4-linux.amd64",
waveAppPath + "/Contents/Resources/app/bin/mshell/mshell-v0.4-linux.arm64",
waveAppPath + "/Contents/Resources/app/bin/mshell/mshell-v0.4-darwin.amd64",
waveAppPath + "/Contents/Resources/app/bin/mshell/mshell-v0.4-darwin.arm64",
],
}).then(() => {
console.log("signing success");

View File

@ -44,10 +44,10 @@ rm -rf bin/
rm -rf build/
node_modules/.bin/webpack --env prod
GO_LDFLAGS="-s -w -X main.BuildTime=$(date +'%Y%m%d%H%M')"
(cd waveshell; CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-darwin.amd64 main-waveshell.go)
(cd waveshell; CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-darwin.arm64 main-waveshell.go)
(cd waveshell; CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-linux.amd64 main-waveshell.go)
(cd waveshell; CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-linux.arm64 main-waveshell.go)
(cd waveshell; CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-darwin.amd64 main-waveshell.go)
(cd waveshell; CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-darwin.arm64 main-waveshell.go)
(cd waveshell; CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-linux.amd64 main-waveshell.go)
(cd waveshell; CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-linux.arm64 main-waveshell.go)
(cd wavesrv; CGO_ENABLED=1 go build -tags "osusergo,netgo,sqlite_omit_load_extension" -ldflags "-X main.BuildTime=$(date +'%Y%m%d%H%M')" -o ../bin/wavesrv ./cmd)
node_modules/.bin/electron-forge make
```
@ -60,10 +60,10 @@ rm -rf bin/
rm -rf build/
node_modules/.bin/webpack --env prod
GO_LDFLAGS="-s -w -X main.BuildTime=$(date +'%Y%m%d%H%M')"
(cd waveshell; CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-darwin.amd64 main-waveshell.go)
(cd waveshell; CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-darwin.arm64 main-waveshell.go)
(cd waveshell; CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-linux.amd64 main-waveshell.go)
(cd waveshell; CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-linux.arm64 main-waveshell.go)
(cd waveshell; CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-darwin.amd64 main-waveshell.go)
(cd waveshell; CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-darwin.arm64 main-waveshell.go)
(cd waveshell; CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-linux.amd64 main-waveshell.go)
(cd waveshell; CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-linux.arm64 main-waveshell.go)
# adds -extldflags=-static, *only* on linux (macos does not support fully static binaries) to avoid a glibc dependency
(cd wavesrv; CGO_ENABLED=1 go build -tags "osusergo,netgo,sqlite_omit_load_extension" -ldflags "-linkmode 'external' -extldflags=-static $GO_LDFLAGS" -o ../bin/wavesrv ./cmd)
node_modules/.bin/electron-forge make
@ -86,10 +86,10 @@ CGO_ENABLED=1 go build -tags "osusergo,netgo,sqlite_omit_load_extension" -ldflag
set -e
cd waveshell
GO_LDFLAGS="-s -w -X main.BuildTime=$(date +'%Y%m%d%H%M')"
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-linux.amd64 main-waveshell.go
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-linux.arm64 main-waveshell.go
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-darwin.amd64 main-waveshell.go
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-darwin.arm64 main-waveshell.go
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-linux.amd64 main-waveshell.go
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-linux.arm64 main-waveshell.go
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-darwin.amd64 main-waveshell.go
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-darwin.arm64 main-waveshell.go
```
```bash

View File

@ -25,6 +25,7 @@ class CreateRemoteConnModal extends React.Component<{}, {}> {
tempConnectMode: OV<string>;
tempPassword: OV<string>;
tempKeyFile: OV<string>;
tempShellPref: OV<string>;
errorStr: OV<string>;
remoteEdit: T.RemoteEditType;
model: RemotesModel;
@ -40,6 +41,7 @@ class CreateRemoteConnModal extends React.Component<{}, {}> {
this.tempConnectMode = mobx.observable.box("auto", { name: "CreateRemote-connectMode" });
this.tempKeyFile = mobx.observable.box("", { name: "CreateRemote-keystr" });
this.tempPassword = mobx.observable.box("", { name: "CreateRemote-password" });
this.tempShellPref = mobx.observable.box("detect", { name: "CreateRemote-shellPref" });
this.errorStr = mobx.observable.box(this.remoteEdit?.errorstr ?? null, { name: "CreateRemote-errorStr" });
}
@ -121,6 +123,7 @@ class CreateRemoteConnModal extends React.Component<{}, {}> {
kwargs["password"] = "";
}
kwargs["connectmode"] = this.tempConnectMode.get();
kwargs["shellpref"] = this.tempShellPref.get();
kwargs["visual"] = "1";
kwargs["submit"] = "1";
let prtn = GlobalCommandRunner.createRemote(cname, kwargs, false);
@ -174,6 +177,13 @@ class CreateRemoteConnModal extends React.Component<{}, {}> {
})();
}
@boundMethod
handleChangeShellPref(value: string): void {
mobx.action(() => {
this.tempShellPref.set(value);
})();
}
@boundMethod
handleChangePort(value: string): void {
mobx.action(() => {
@ -357,6 +367,20 @@ class CreateRemoteConnModal extends React.Component<{}, {}> {
}}
/>
</div>
<div className="shellpref-section">
<Dropdown
label="Shell Preference"
options={[
{ value: "detect", label: "detect" },
{ value: "bash", label: "bash" },
{ value: "zsh", label: "zsh" },
]}
value={this.tempShellPref.get()}
onChange={(val: string) => {
this.tempShellPref.set(val);
}}
/>
</div>
<If condition={!util.isBlank(this.getErrorStr() as string)}>
<div className="settings-field settings-error">Error: {this.getErrorStr()}</div>
</If>

View File

@ -24,6 +24,7 @@ class EditRemoteConnModal extends React.Component<{}, {}> {
tempPassword: OV<string>;
tempConnectMode: OV<string>;
tempAuthMode: OV<string>;
tempShellPref: OV<string>;
model: RemotesModel;
constructor(props: { remotesModel?: RemotesModel }) {
@ -34,6 +35,7 @@ class EditRemoteConnModal extends React.Component<{}, {}> {
this.tempKeyFile = mobx.observable.box(null, { name: "EditRemoteSettings-tempKeyFile" });
this.tempPassword = mobx.observable.box(null, { name: "EditRemoteSettings-tempPassword" });
this.tempConnectMode = mobx.observable.box(null, { name: "EditRemoteSettings-tempConnectMode" });
this.tempShellPref = mobx.observable.box(null, { name: "EditRemoteSettings-tempShellPref" });
}
get selectedRemoteId() {
@ -52,6 +54,10 @@ class EditRemoteConnModal extends React.Component<{}, {}> {
return this.model.isAuthEditMode();
}
isLocalRemote(): boolean {
return this.selectedRemote?.local;
}
componentDidMount(): void {
mobx.action(() => {
this.tempAlias.set(this.selectedRemote?.remotealias);
@ -59,6 +65,7 @@ class EditRemoteConnModal extends React.Component<{}, {}> {
this.tempPassword.set(this.remoteEdit?.haspassword ? PasswordUnchangedSentinel : "");
this.tempConnectMode.set(this.selectedRemote?.connectmode);
this.tempAuthMode.set(this.selectedRemote?.authtype);
this.tempShellPref.set(this.selectedRemote?.shellpref);
})();
}
@ -103,6 +110,13 @@ class EditRemoteConnModal extends React.Component<{}, {}> {
})();
}
@boundMethod
handleChangeShellPref(value: string): void {
mobx.action(() => {
this.tempShellPref.set(value);
})();
}
@boundMethod
canResetPw(): boolean {
if (this.remoteEdit == null) {
@ -154,6 +168,9 @@ class EditRemoteConnModal extends React.Component<{}, {}> {
if (!util.isStrEq(this.tempConnectMode.get(), this.selectedRemote?.connectmode)) {
kwargs["connectmode"] = this.tempConnectMode.get();
}
if (!util.isStrEq(this.tempShellPref.get(), this.selectedRemote?.shellpref)) {
kwargs["shellpref"] = this.tempShellPref.get();
}
kwargs["visual"] = "1";
kwargs["submit"] = "1";
GlobalCommandRunner.editRemote(this.selectedRemote?.remoteid, kwargs);
@ -183,18 +200,8 @@ class EditRemoteConnModal extends React.Component<{}, {}> {
return null;
}
render() {
let authMode = this.tempAuthMode.get();
if (this.remoteEdit === null || !this.isAuthEditMode) {
return null;
}
renderAlias() {
return (
<Modal className="erconn-modal">
<Modal.Header title="Edit Connection" onClose={this.model.closeModal} />
<div className="wave-modal-body">
<div className="name-actions-section">
<div className="name text-primary">{util.getRemoteName(this.selectedRemote)}</div>
</div>
<div className="alias-section">
<TextField
label="Alias"
@ -215,6 +222,47 @@ class EditRemoteConnModal extends React.Component<{}, {}> {
}}
/>
</div>
);
}
renderConnectMode() {
return (
<div className="connectmode-section">
<Dropdown
label="Connect Mode"
options={[
{ value: "startup", label: "startup" },
{ value: "auto", label: "auto" },
{ value: "manual", label: "manual" },
]}
value={this.tempConnectMode.get()}
onChange={this.handleChangeConnectMode}
/>
</div>
);
}
renderShellPref() {
return (
<div className="shellpref-section">
<Dropdown
label="Shell Preference"
options={[
{ value: "detect", label: "detect" },
{ value: "bash", label: "bash" },
{ value: "zsh", label: "zsh" },
]}
value={this.tempShellPref.get()}
onChange={this.handleChangeShellPref}
/>
</div>
);
}
renderAuthMode() {
let authMode = this.tempAuthMode.get();
return (
<>
<div className="authmode-section">
<Dropdown
label="Auth Mode"
@ -287,18 +335,26 @@ class EditRemoteConnModal extends React.Component<{}, {}> {
maxLength={400}
/>
</If>
<div className="connectmode-section">
<Dropdown
label="Connect Mode"
options={[
{ value: "startup", label: "startup" },
{ value: "auto", label: "auto" },
{ value: "manual", label: "manual" },
]}
value={this.tempConnectMode.get()}
onChange={this.handleChangeConnectMode}
/>
</>
);
}
render() {
if (this.remoteEdit === null || !this.isAuthEditMode) {
return null;
}
let isLocal = this.isLocalRemote();
return (
<Modal className="erconn-modal">
<Modal.Header title="Edit Connection" onClose={this.model.closeModal} />
<div className="wave-modal-body">
<div className="name-actions-section">
<div className="name text-primary">{util.getRemoteName(this.selectedRemote)}</div>
</div>
<If condition={!isLocal}>{this.renderAlias()}</If>
<If condition={!isLocal}>{this.renderAuthMode()}</If>
<If condition={!isLocal}>{this.renderConnectMode()}</If>
{this.renderShellPref()}
<If condition={!util.isBlank(this.remoteEdit?.errorstr)}>
<div className="settings-field settings-error">Error: {this.remoteEdit?.errorstr}</div>
</If>

View File

@ -206,7 +206,6 @@ class ViewRemoteConnDetailModal extends React.Component<{}, {}> {
);
if (remote.local) {
installNowButton = <></>;
updateAuthButton = <></>;
cancelInstallButton = <></>;
}
if (remote.sshconfigsrc == "sshconfig-import") {
@ -352,6 +351,10 @@ class ViewRemoteConnDetailModal extends React.Component<{}, {}> {
<div className="settings-label">Connect Mode</div>
<div className="settings-input">{remote.connectmode}</div>
</div>
<div className="settings-field">
<div className="settings-label">Shell Pref</div>
<div className="settings-input">{remote.shellpref}</div>
</div>
{this.renderInstallStatus(remote)}
<div className="flex-spacer" style={{ minHeight: 20 }} />
<div className="status">

View File

@ -128,6 +128,22 @@
padding: 1em 2px;
}
.textareainput-div {
position: relative;
.shelltag {
position: absolute;
bottom: 4px;
right: 3px;
font-size: 10px;
color: @text-secondary;
line-height: 1;
padding: 0px 8px 3px 8px;
background-color: @textarea-background;
border-radius: 0 0 5px 5px;
}
}
textarea {
color: @term-bright-white;
background-color: @textarea-background;

View File

@ -99,6 +99,11 @@ class CmdInput extends React.Component<{}, {}> {
})();
}
@boundMethod
clickResetState(): void {
GlobalCommandRunner.resetShellState();
}
render() {
let model = GlobalModel;
let inputModel = model.inputModel;
@ -115,6 +120,7 @@ class CmdInput extends React.Component<{}, {}> {
remote = GlobalModel.getRemote(ri.remoteid);
feState = ri.festate;
}
feState = feState || {};
let infoShow = inputModel.infoShow.get();
let historyShow = !infoShow && inputModel.historyShow.get();
let aiChatShow = inputModel.aIChatShow.get();
@ -162,6 +168,18 @@ class CmdInput extends React.Component<{}, {}> {
</If>
</div>
</If>
<If condition={feState["invalidshellstate"]}>
<div className="remote-status-warning">
WARNING:&nbsp; The shell state for this tab is invalid (
<a target="_blank" href="https://docs.waveterm.dev/reference/faq">
see FAQ
</a>
). Must reset to continue.
<div className="button is-wave-green is-outlined is-small" onClick={this.clickResetState}>
reset shell state
</div>
</div>
</If>
<div key="prompt" className="cmd-input-context">
<div className="has-text-white">
<span ref={this.promptRef}>

View File

@ -5,6 +5,8 @@ import * as React from "react";
import * as mobxReact from "mobx-react";
import * as mobx from "mobx";
import type * as T from "../../../types/types";
import * as util from "../../../util/util";
import { If } from "tsx-control-statements/components";
import { boundMethod } from "autobind-decorator";
import cn from "classnames";
import { GlobalModel, GlobalCommandRunner, Screen } from "../../../model/model";
@ -585,8 +587,24 @@ class TextAreaInput extends React.Component<{ screen: Screen; onHeightChange: ()
let computedInnerHeight = displayLines * (termFontSize * 1.5) + 2 * 0.5 * termFontSize;
// inner height + 2*1em padding
let computedOuterHeight = computedInnerHeight + 2 * 1.0 * termFontSize;
let shellType: string = "";
let screen = GlobalModel.getActiveScreen();
if (screen != null) {
let ri = screen.getCurRemoteInstance();
console.log("got ri", ri);
if (ri != null && ri.shelltype != null) {
shellType = ri.shelltype;
}
}
return (
<div className="control is-expanded" ref={this.controlRef} style={{ height: computedOuterHeight }}>
<div
className="textareainput-div control is-expanded"
ref={this.controlRef}
style={{ height: computedOuterHeight }}
>
<If condition={!disabled && !util.isBlank(shellType)}>
<div className="shelltag">{shellType}</div>
</If>
<textarea
key="main"
ref={this.mainInputRef}

View File

@ -1213,6 +1213,7 @@ class Session {
remoteid: rptr.remoteid,
name: rptr.name,
festate: remote.defaultfestate,
shelltype: remote.defaultshelltype,
};
}
return null;
@ -4567,6 +4568,10 @@ class CommandRunner {
GlobalModel.submitCommand("history", null, null, kwargs, true);
}
resetShellState() {
GlobalModel.submitCommand("reset", null, null, null, true);
}
historyPurgeLines(lines: string[]): Promise<CommandRtnType> {
let prtn = GlobalModel.submitCommand("history", "purge", lines, { nohist: "1" }, false);
return prtn;

View File

@ -121,6 +121,8 @@ type RemoteType = {
remoteopts?: RemoteOptsType;
local: boolean;
remove?: boolean;
shellpref: string;
defaultshelltype: string;
};
type RemoteStateType = {
@ -136,6 +138,7 @@ type RemoteInstanceType = {
remoteownerid: string;
remoteid: string;
festate: Record<string, string>;
shelltype: string;
remove?: boolean;
};

View File

@ -4,10 +4,8 @@
package main
import (
"bytes"
"fmt"
"os"
"strconv"
"strings"
"syscall"
"time"
@ -16,131 +14,10 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/server"
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
"golang.org/x/sys/unix"
)
var BuildTime = "0"
// func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) {
// err := shexec.ValidateRunPacket(pk)
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("invalid run packet: %v", err))
// return
// }
// fileNames, err := base.GetCommandFileNames(pk.CK)
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot get command file names: %v", err))
// return
// }
// cmd, err := shexec.MakeRunnerExec(pk.CK)
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot make mshell command: %v", err))
// return
// }
// cmdStdin, err := cmd.StdinPipe()
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot pipe stdin to command: %v", err))
// return
// }
// // touch ptyout file (should exist for tailer to work correctly)
// ptyOutFd, err := os.OpenFile(fileNames.PtyOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot open pty out file '%s': %v", fileNames.PtyOutFile, err))
// return
// }
// ptyOutFd.Close() // just opened to create the file, can close right after
// runnerOutFd, err := os.OpenFile(fileNames.RunnerOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err))
// return
// }
// defer runnerOutFd.Close()
// cmd.Stdout = runnerOutFd
// cmd.Stderr = runnerOutFd
// err = cmd.Start()
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("error starting command: %v", err))
// return
// }
// go func() {
// err = packet.SendPacket(cmdStdin, pk)
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("error sending forked runner command: %v", err))
// return
// }
// cmdStdin.Close()
// // clean up zombies
// cmd.Wait()
// }()
// }
// func doGetCmd(tailer *cmdtail.Tailer, pk *packet.GetCmdPacketType, sender *packet.PacketSender) error {
// err := tailer.AddWatch(pk)
// if err != nil {
// return err
// }
// return nil
// }
// func doMain() {
// homeDir := base.GetHomeDir()
// err := os.Chdir(homeDir)
// if err != nil {
// packet.SendErrorPacket(os.Stdout, fmt.Sprintf("cannot change directory to $HOME '%s': %v", homeDir, err))
// return
// }
// _, err = base.GetMShellPath()
// if err != nil {
// packet.SendErrorPacket(os.Stdout, err.Error())
// return
// }
// packetParser := packet.MakePacketParser(os.Stdin)
// sender := packet.MakePacketSender(os.Stdout)
// tailer, err := cmdtail.MakeTailer(sender)
// if err != nil {
// packet.SendErrorPacket(os.Stdout, err.Error())
// return
// }
// go tailer.Run()
// initPacket := shexec.MakeInitPacket()
// sender.SendPacket(initPacket)
// for pk := range packetParser.MainCh {
// if pk.GetType() == packet.RunPacketStr {
// doMainRun(pk.(*packet.RunPacketType), sender)
// continue
// }
// if pk.GetType() == packet.GetCmdPacketStr {
// err = doGetCmd(tailer, pk.(*packet.GetCmdPacketType), sender)
// if err != nil {
// errPk := packet.MakeErrorPacket(err.Error())
// sender.SendPacket(errPk)
// continue
// }
// continue
// }
// if pk.GetType() == packet.CdPacketStr {
// cdPacket := pk.(*packet.CdPacketType)
// err := os.Chdir(cdPacket.Dir)
// resp := packet.MakeResponsePacket(cdPacket.ReqId)
// if err != nil {
// resp.Error = err.Error()
// } else {
// resp.Success = true
// }
// sender.SendPacket(resp)
// continue
// }
// if pk.GetType() == packet.ErrorPacketStr {
// errPk := pk.(*packet.ErrorPacketType)
// errPk.Error = "invalid packet sent to mshell: " + errPk.Error
// sender.SendPacket(errPk)
// continue
// }
// sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", pk.GetType()))
// }
// }
func readFullRunPacket(packetParser *packet.PacketParser) (*packet.RunPacketType, error) {
rpb := packet.MakeRunPacketBuilder()
for pk := range packetParser.MainCh {
@ -155,7 +32,7 @@ func readFullRunPacket(packetParser *packet.PacketParser) (*packet.RunPacketType
return nil, fmt.Errorf("no run packet received")
}
func handleSingle(fromServer bool) {
func handleSingle() {
packetParser := packet.MakePacketParser(os.Stdin, nil)
sender := packet.MakePacketSender(os.Stdout, nil)
defer func() {
@ -177,12 +54,10 @@ func handleSingle(fromServer bool) {
sender.SendErrorResponse(runPacket.ReqId, err)
return
}
if fromServer {
err = runPacket.CK.Validate("run packet")
if err != nil {
sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("run packets from server must have a CK: %v", err))
}
}
if runPacket.Detached {
cmd, startPk, err := shexec.RunCommandDetached(runPacket, sender)
if err != nil {
@ -225,306 +100,20 @@ func handleSingle(fromServer bool) {
}
}
func detectOpenFds() ([]packet.RemoteFd, error) {
var fds []packet.RemoteFd
for fdNum := 3; fdNum <= 64; fdNum++ {
flags, err := unix.FcntlInt(uintptr(fdNum), unix.F_GETFL, 0)
if err != nil {
continue
}
flags = flags & 3
rfd := packet.RemoteFd{FdNum: fdNum}
if flags&2 == 2 {
return nil, fmt.Errorf("invalid fd=%d, mshell does not support fds open for reading and writing", fdNum)
}
if flags&1 == 1 {
rfd.Write = true
} else {
rfd.Read = true
}
fds = append(fds, rfd)
}
return fds, nil
}
func parseInstallOpts() (*shexec.InstallOpts, error) {
opts := &shexec.InstallOpts{}
iter := base.MakeOptsIter(os.Args[2:]) // first arg is --install
for iter.HasNext() {
argStr := iter.Next()
found, err := tryParseSSHOpt(iter, &opts.SSHOpts)
if err != nil {
return nil, err
}
if found {
continue
}
if argStr == "--detect" {
opts.Detect = true
continue
}
if base.IsOption(argStr) {
return nil, fmt.Errorf("invalid option '%s' passed to mshell --install", argStr)
}
opts.ArchStr = argStr
break
}
return opts, nil
}
func tryParseSSHOpt(iter *base.OptsIter, sshOpts *shexec.SSHOpts) (bool, error) {
argStr := iter.Current()
if argStr == "--ssh" {
if !iter.IsNextPlain() {
return false, fmt.Errorf("'--ssh [user@host]' missing host")
}
sshOpts.SSHHost = iter.Next()
return true, nil
}
if argStr == "--ssh-opts" {
if !iter.HasNext() {
return false, fmt.Errorf("'--ssh-opts [options]' missing options")
}
sshOpts.SSHOptsStr = iter.Next()
return true, nil
}
if argStr == "-i" {
if !iter.IsNextPlain() {
return false, fmt.Errorf("-i [identity-file]' missing file")
}
sshOpts.SSHIdentity = iter.Next()
return true, nil
}
if argStr == "-l" {
if !iter.IsNextPlain() {
return false, fmt.Errorf("-l [user]' missing user")
}
sshOpts.SSHUser = iter.Next()
return true, nil
}
if argStr == "-p" {
if !iter.IsNextPlain() {
return false, fmt.Errorf("-p [port]' missing port")
}
nextArgStr := iter.Next()
portVal, err := strconv.Atoi(nextArgStr)
if err != nil {
return false, fmt.Errorf("-p [port]' invalid port: %v", err)
}
if portVal <= 0 {
return false, fmt.Errorf("-p [port]' invalid port: %d", portVal)
}
sshOpts.SSHPort = portVal
return true, nil
}
return false, nil
}
func parseClientOpts() (*shexec.ClientOpts, error) {
opts := &shexec.ClientOpts{}
iter := base.MakeOptsIter(os.Args[1:])
for iter.HasNext() {
argStr := iter.Next()
found, err := tryParseSSHOpt(iter, &opts.SSHOpts)
if err != nil {
return nil, err
}
if found {
continue
}
if argStr == "--cwd" {
if !iter.IsNextPlain() {
return nil, fmt.Errorf("'--cwd [dir]' missing directory")
}
opts.Cwd = iter.Next()
continue
}
if argStr == "--detach" {
opts.Detach = true
continue
}
if argStr == "--pty" {
opts.UsePty = true
continue
}
if argStr == "--debug" {
opts.Debug = true
continue
}
if argStr == "--sudo" {
opts.Sudo = true
continue
}
if argStr == "--sudo-with-password" {
if !iter.HasNext() {
return nil, fmt.Errorf("'--sudo-with-password [pw]', missing password")
}
opts.Sudo = true
opts.SudoWithPass = true
opts.SudoPw = iter.Next()
continue
}
if argStr == "--sudo-with-passfile" {
if !iter.IsNextPlain() {
return nil, fmt.Errorf("'--sudo-with-passfile [file]', missing file")
}
opts.Sudo = true
opts.SudoWithPass = true
fileName := iter.Next()
contents, err := os.ReadFile(fileName)
if err != nil {
return nil, fmt.Errorf("cannot read --sudo-with-passfile file '%s': %w", fileName, err)
}
if newlineIdx := bytes.Index(contents, []byte{'\n'}); newlineIdx != -1 {
contents = contents[0:newlineIdx]
}
opts.SudoPw = string(contents) + "\n"
continue
}
if argStr == "--" {
if !iter.HasNext() {
return nil, fmt.Errorf("'--' should be followed by command")
}
opts.Command = strings.Join(iter.Rest(), " ")
break
}
return nil, fmt.Errorf("invalid option '%s' passed to mshell", argStr)
}
return opts, nil
}
func handleClient() (int, error) {
opts, err := parseClientOpts()
if err != nil {
return 1, fmt.Errorf("parsing opts: %w", err)
}
if opts.Debug {
packet.GlobalDebug = true
}
if opts.Command == "" {
return 1, fmt.Errorf("no [command] specified. [command] follows '--' option (see usage)")
}
fds, err := detectOpenFds()
if err != nil {
return 1, err
}
opts.Fds = fds
err = shexec.ValidateRemoteFds(opts.Fds)
if err != nil {
return 1, err
}
runPacket, err := opts.MakeRunPacket() // modifies opts
if err != nil {
return 1, err
}
if runPacket.Detached {
return 1, fmt.Errorf("cannot run detached command from command line client")
}
donePacket, err := shexec.RunClientSSHCommandAndWait(runPacket, shexec.StdContext{}, opts.SSHOpts, nil, opts.Debug)
if err != nil {
return 1, err
}
return donePacket.ExitCode, nil
}
func handleInstall() (int, error) {
opts, err := parseInstallOpts()
if err != nil {
return 1, fmt.Errorf("parsing opts: %w", err)
}
if opts.SSHOpts.SSHHost == "" {
return 1, fmt.Errorf("cannot install without '--ssh user@host' option")
}
if opts.Detect && opts.ArchStr != "" {
return 1, fmt.Errorf("cannot supply both --detect and arch '%s'", opts.ArchStr)
}
if opts.ArchStr == "" && !opts.Detect {
return 1, fmt.Errorf("must supply an arch string or '--detect' to auto detect")
}
if opts.ArchStr != "" {
fullArch := opts.ArchStr
fields := strings.SplitN(fullArch, ".", 2)
if len(fields) != 2 {
return 1, fmt.Errorf("invalid arch format '%s' passed to mshell --install", fullArch)
}
goos, goarch := fields[0], fields[1]
if !base.ValidGoArch(goos, goarch) {
return 1, fmt.Errorf("invalid arch '%s' passed to mshell --install", fullArch)
}
optName := base.GoArchOptFile(base.MShellVersion, goos, goarch)
_, err = os.Stat(optName)
if err != nil {
return 1, fmt.Errorf("cannot install mshell to remote host, cannot read '%s': %w", optName, err)
}
opts.OptName = optName
}
err = shexec.RunInstallFromOpts(opts)
if err != nil {
return 1, err
}
return 0, nil
}
func handleEnv() (int, error) {
cwd, err := os.Getwd()
if err != nil {
return 1, err
}
fmt.Printf("%s\x00\x00", cwd)
fullEnv := os.Environ()
var linePrinted bool
for _, envLine := range fullEnv {
if envLine != "" {
fmt.Printf("%s\x00", envLine)
linePrinted = true
}
}
if linePrinted {
fmt.Printf("\x00")
} else {
fmt.Printf("\x00\x00")
}
return 0, nil
}
func handleUsage() {
usage := `
Client Usage: mshell [opts] --ssh user@host -- [command]
mshell multiplexes input and output streams to a remote command over ssh.
mshell is a helper program for wave terminal. it is used to execute commands
Options:
-i [identity-file] - used to set '-i' option for ssh command
-l [user] - used to set '-l' option for ssh command
--cwd [dir] - execute remote command in [dir]
--ssh-opts [opts] - addition options to pass to ssh command
[command] - the remote command to execute
--help - prints this message
--version - print version
--server - multiplexer to run multiple commands
--single - run a single command (connected to multiplexer)
--single --version - return an init packet with version info
Sudo Options:
--sudo - use only if sudo never requires a password
--sudo-with-password [pw] - not recommended, use --sudo-with-passfile if possible
--sudo-with-passfile [file]
Sudo options allow you to run the given command using "sudo". The first
option only works when you can sudo without a password. Your password will be passed
securely through a high numbered fd to "sudo -S". Note that to use high numbered
file descriptors with sudo, you will need to add this line to your /etc/sudoers file:
Defaults closefrom_override
See full documentation for more details.
Examples:
# execute a python script remotely, with stdin still hooked up correctly
mshell --cwd "~/work" -i key.pem --ssh ubuntu@somehost -- "python3 /dev/fd/4" 4< myscript.py
# capture multiple outputs
mshell --ssh ubuntu@test -- "cat file1.txt > /dev/fd/3; cat file2.txt > /dev/fd/4" 3> file1.txt 4> file2.txt
# execute a script, catpure stdout/stderr in fd-3 and fd-4
# useful if you need to see stdout for interacting with ssh (password or host auth)
mshell --ssh user@host -- "test.sh > /dev/fd/3 2> /dev/fd/4" 3> test.stdout 4> test.stderr
# run a script as root (via sudo), capture output
mshell --sudo-with-passfile pw.txt --ssh ubuntu@somehost -- "python3 /dev/fd/3 > /dev/fd/4" 3< myscript.py 4> script-output.txt < script-input.txt
mshell does not open any external ports and does not require any additional permissions.
it communicates exclusively through stdin/stdout with an attached process
via a JSON packet format.
`
fmt.Printf("%s\n\n", strings.TrimSpace(usage))
}
@ -542,15 +131,13 @@ func main() {
} else if firstArg == "--version" {
fmt.Printf("mshell %s+%s\n", base.MShellVersion, base.BuildTime)
return
} else if firstArg == "--single" {
} else if firstArg == "--single" || firstArg == "--single-from-server" {
base.ProcessType = base.ProcessType_WaveShellSingle
base.InitDebugLog("single")
handleSingle(false)
return
} else if firstArg == "--single-from-server" {
base.InitDebugLog("single")
handleSingle(true)
handleSingle()
return
} else if firstArg == "--server" {
base.ProcessType = base.ProcessType_WaveShellServer
base.InitDebugLog("server")
rtnCode, err := server.RunServer()
if err != nil {
@ -560,21 +147,8 @@ func main() {
os.Exit(rtnCode)
}
return
} else if firstArg == "--install" {
rtnCode, err := handleInstall()
if err != nil {
fmt.Fprintf(os.Stderr, "[error] %v\n", err)
}
os.Exit(rtnCode)
return
} else {
rtnCode, err := handleClient()
if err != nil {
fmt.Fprintf(os.Stderr, "[error] %v\n", err)
}
if rtnCode != 0 {
os.Exit(rtnCode)
}
handleUsage()
return
}
}

View File

@ -30,7 +30,7 @@ const SSHCommandVarName = "SSH_COMMAND"
const MShellDebugVarName = "MSHELL_DEBUG"
const SessionsDirBaseName = "sessions"
const RcFilesDirBaseName = "rcfiles"
const MShellVersion = "v0.3.0"
const MShellVersion = "v0.4.0"
const RemoteIdFile = "remoteid"
const DefaultMShellInstallBinDir = "/opt/mshell/bin"
const LogFileName = "mshell.log"
@ -39,6 +39,13 @@ const ForceDebugLog = false
const DebugFlag_LogRcFile = "logrc"
const LogRcFileName = "debug.rcfile"
const (
ProcessType_Unknown = "unknown"
ProcessType_WaveSrv = "wavesrv"
ProcessType_WaveShellSingle = "waveshell-single"
ProcessType_WaveShellServer = "waveshell-server"
)
// keys are sessionids (also the key RcFilesDirBaseName)
var ensureDirCache = make(map[string]bool)
var baseLock = &sync.Mutex{}
@ -46,6 +53,8 @@ var DebugLogEnabled = false
var DebugLogger *log.Logger
var BuildTime string = "0"
var ProcessType string = ProcessType_Unknown
type CommandFileNames struct {
PtyOutFile string
StdinFifo string
@ -58,6 +67,10 @@ func SetBuildTime(build string) {
BuildTime = build
}
func IsWaveSrv() bool {
return ProcessType == ProcessType_WaveSrv
}
func MakeCommandKey(sessionId string, cmdId string) CommandKey {
if sessionId == "" && cmdId == "" {
return CommandKey("")

View File

@ -44,9 +44,9 @@ func PackStrArr(w io.Writer, strs []string) error {
return PackValue(w, barr)
}
func PackInt(w io.Writer, ival int) error {
func PackUInt(w io.Writer, ival uint64) error {
viBuf := make([]byte, binary.MaxVarintLen64)
l := binary.PutUvarint(viBuf, uint64(ival))
l := binary.PutUvarint(viBuf, ival)
_, err := w.Write(viBuf[0:l])
return err
}
@ -80,8 +80,16 @@ func UnpackStrArr(r FullByteReader) ([]string, error) {
return strs, nil
}
func UnpackInt(r io.ByteReader) (int, error) {
ival64, err := binary.ReadVarint(r)
func UnpackUInt(r io.ByteReader) (uint64, error) {
ival64, err := binary.ReadUvarint(r)
if err != nil {
return 0, err
}
return ival64, nil
}
func UnpackUIntAsInt(r io.ByteReader) (int, error) {
ival64, err := UnpackUInt(r)
if err != nil {
return 0, err
}
@ -99,15 +107,15 @@ func (u *Unpacker) UnpackValue(name string) []byte {
return rtn
}
func (u *Unpacker) UnpackInt(name string) int {
func (u *Unpacker) UnpackUInt(name string) int {
if u.Err != nil {
return 0
}
rtn, err := UnpackInt(u.R)
rtn, err := UnpackUInt(u.R)
if err != nil {
u.Err = fmt.Errorf("cannot unpack %s: %v", name, err)
}
return rtn
return int(rtn)
}
func (u *Unpacker) UnpackStrArr(name string) []string {

View File

@ -14,6 +14,7 @@ import (
"os"
"reflect"
"sync"
"time"
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
)
@ -59,11 +60,18 @@ const (
WriteFileReadyPacketStr = "writefileready" // rpc-response
WriteFileDonePacketStr = "writefiledone" // rpc-response
FileDataPacketStr = "filedata"
LogPacketStr = "log" // logging packet (sent from waveshell back to server)
ShellStatePacketStr = "shellstate"
OpenAIPacketStr = "openai" // other
OpenAICloudReqStr = "openai-cloudreq"
)
const (
ShellType_bash = "bash"
ShellType_zsh = "zsh"
)
const PacketSenderQueueSize = 20
const PacketEOFStr = "EOF"
@ -102,6 +110,8 @@ func init() {
TypeStrToFactory[WriteFilePacketStr] = reflect.TypeOf(WriteFilePacketType{})
TypeStrToFactory[WriteFileReadyPacketStr] = reflect.TypeOf(WriteFileReadyPacketType{})
TypeStrToFactory[WriteFileDonePacketStr] = reflect.TypeOf(WriteFileDonePacketType{})
TypeStrToFactory[LogPacketStr] = reflect.TypeOf(LogPacketType{})
TypeStrToFactory[ShellStatePacketStr] = reflect.TypeOf(ShellStatePacketType{})
var _ RpcPacketType = (*RunPacketType)(nil)
var _ RpcPacketType = (*GetCmdPacketType)(nil)
@ -119,6 +129,7 @@ func init() {
var _ RpcResponsePacketType = (*FileDataPacketType)(nil)
var _ RpcResponsePacketType = (*WriteFileReadyPacketType)(nil)
var _ RpcResponsePacketType = (*WriteFileDonePacketType)(nil)
var _ RpcResponsePacketType = (*ShellStatePacketType)(nil)
var _ CommandPacketType = (*DataPacketType)(nil)
var _ CommandPacketType = (*DataAckPacketType)(nil)
@ -383,6 +394,7 @@ func MakeCdPacket() *CdPacketType {
type ReInitPacketType struct {
Type string `json:"type"`
ShellType string `json:"shelltype"`
ReqId string `json:"reqid"`
}
@ -543,6 +555,54 @@ func MakeRawPacket(val string) *RawPacketType {
return &RawPacketType{Type: RawPacketStr, Data: val}
}
type LogPacketType struct {
Type string `json:"type"`
Ts int64 `json:"ts"` // log timestamp
ReqId string `json:"reqid,omitempty"` // if this log line is related to an rpc request
ProcInfo string `json:"procinfo,omitempty"` // server/single
LogLine string `json:"logline"` // the logline data
}
func (*LogPacketType) GetType() string {
return LogPacketStr
}
func (p *LogPacketType) String() string {
return "log"
}
func MakeLogPacket() *LogPacketType {
return &LogPacketType{Type: LogPacketStr, Ts: time.Now().UnixMilli()}
}
type ShellStatePacketType struct {
Type string `json:"type"`
ShellType string `json:"shelltype"`
RespId string `json:"respid,omitempty"`
State *ShellState `json:"state"`
Error string `json:"error,omitempty"`
}
func (*ShellStatePacketType) GetType() string {
return ShellStatePacketStr
}
func (p *ShellStatePacketType) String() string {
return fmt.Sprintf("shellstate[%s]", p.ShellType)
}
func (p *ShellStatePacketType) GetResponseId() string {
return p.RespId
}
func (p *ShellStatePacketType) GetResponseDone() bool {
return true
}
func MakeShellStatePacket() *ShellStatePacketType {
return &ShellStatePacketType{Type: ShellStatePacketStr}
}
type MessagePacketType struct {
Type string `json:"type"`
CK base.CommandKey `json:"ck,omitempty"`
@ -573,7 +633,6 @@ type InitPacketType struct {
BuildTime string `json:"buildtime,omitempty"`
MShellHomeDir string `json:"mshellhomedir,omitempty"`
HomeDir string `json:"homedir,omitempty"`
State *ShellState `json:"state,omitempty"`
User string `json:"user,omitempty"`
HostName string `json:"hostname,omitempty"`
NotFound bool `json:"notfound,omitempty"`
@ -701,6 +760,7 @@ type RunPacketType struct {
Type string `json:"type"`
ReqId string `json:"reqid"`
CK base.CommandKey `json:"ck"`
ShellType string `json:"shelltype"` // new in v0.6.0 (either "bash" or "zsh") (set by remote.go)
Command string `json:"command"`
State *ShellState `json:"state,omitempty"`
StateDiff *ShellStateDiff `json:"statediff,omitempty"`

View File

@ -144,14 +144,20 @@ func (p *PacketParser) getRpcEntry(reqId string) *RpcEntry {
return entry
}
// returns true if sent to an RPC channel. false if not (which then allows the packet to be sent to MainCh)
// if GetResponseId() returns "", then this will return false
func (p *PacketParser) trySendRpcResponse(pk PacketType) bool {
respPk, ok := pk.(RpcResponsePacketType)
if !ok {
return false
}
respId := respPk.GetResponseId()
if respId == "" {
return false
}
p.Lock.Lock()
defer p.Lock.Unlock()
entry := p.RpcMap[respPk.GetResponseId()]
entry := p.RpcMap[respId]
if entry == nil {
return false
}

View File

@ -30,7 +30,7 @@ type ShellState struct {
}
type ShellStateDiff struct {
Version string `json:"version"` // [type] [semver]
Version string `json:"version"` // [type] [semver] (note this should *always* be set even if the same as base)
BaseHash string `json:"basehash"`
DiffHashArr []string `json:"diffhasharr,omitempty"`
Cwd string `json:"cwd,omitempty"`
@ -41,6 +41,66 @@ type ShellStateDiff struct {
HashVal string `json:"-"`
}
func (state ShellState) GetShellType() string {
shell, _, _ := ParseShellStateVersion(state.Version)
return shell
}
// returns (shell, version, error)
func ParseShellStateVersion(fullVersionStr string) (string, string, error) {
if fullVersionStr == "" {
return "", "", fmt.Errorf("empty shellstate version")
}
fields := strings.Split(fullVersionStr, " ")
if len(fields) != 2 {
return "", "", fmt.Errorf("invalid shellstate version format: %q", fullVersionStr)
}
shell := fields[0]
version := fields[1]
if shell != ShellType_zsh && shell != ShellType_bash {
return "", "", fmt.Errorf("invalid shellstate shell type: %q", fullVersionStr)
}
if !semver.IsValid(version) {
return "", "", fmt.Errorf("invalid shellstate semver: %q", fullVersionStr)
}
return shell, version, nil
}
// we're going to allow different versions (as long as shelltype is the same)
// before we required version numbers to match exactly which was too restrictive
func StateVersionsCompatible(v1 string, v2 string) bool {
if v1 == v2 {
return true
}
shell1, version1, err := ParseShellStateVersion(v1)
if err != nil {
return false
}
shell2, version2, err := ParseShellStateVersion(v2)
if err != nil {
return false
}
if shell1 != shell2 {
return false
}
if semver.Major(version1) != semver.Major(version2) {
return false
}
return true
}
func (diff ShellStateDiff) GetShellType() string {
shell, _, _ := ParseShellStateVersion(diff.Version)
return shell
}
func (state ShellState) GetLineDiffSplitString() string {
if state.GetShellType() == ShellType_zsh {
return "\x00"
}
return "\n"
}
func (state ShellState) IsEmpty() bool {
return state.Version == "" && state.Cwd == "" && len(state.ShellVars) == 0 && state.Aliases == "" && state.Funcs == "" && state.Error == ""
}
@ -55,7 +115,7 @@ func sha1Hash(data []byte) string {
// returns (SHA1, encoded-state)
func (state ShellState) EncodeAndHash() (string, []byte) {
var buf bytes.Buffer
binpack.PackInt(&buf, ShellStatePackVersion)
binpack.PackUInt(&buf, ShellStatePackVersion)
binpack.PackValue(&buf, []byte(state.Version))
binpack.PackValue(&buf, []byte(state.Cwd))
binpack.PackValue(&buf, state.ShellVars)
@ -66,7 +126,7 @@ func (state ShellState) EncodeAndHash() (string, []byte) {
}
// returns a string like "v4" ("" is an unparseable version)
func GetBashMajorVersion(versionStr string) string {
func GetMajorVersion(versionStr string) string {
if versionStr == "" {
return ""
}
@ -94,7 +154,7 @@ func (state *ShellState) DecodeShellState(barr []byte) error {
state.HashVal = sha1Hash(barr)
buf := bytes.NewBuffer(barr)
u := binpack.MakeUnpacker(buf)
version := u.UnpackInt("ShellState pack version")
version := u.UnpackUInt("ShellState pack version")
if version != ShellStatePackVersion {
return fmt.Errorf("invalid ShellState pack version: %d", version)
}
@ -118,7 +178,7 @@ func (state *ShellState) UnmarshalJSON(jsonBytes []byte) error {
func (sdiff ShellStateDiff) EncodeAndHash() (string, []byte) {
var buf bytes.Buffer
binpack.PackInt(&buf, ShellStateDiffPackVersion)
binpack.PackUInt(&buf, ShellStateDiffPackVersion)
binpack.PackValue(&buf, []byte(sdiff.Version))
binpack.PackValue(&buf, []byte(sdiff.BaseHash))
binpack.PackStrArr(&buf, sdiff.DiffHashArr)
@ -139,7 +199,7 @@ func (sdiff *ShellStateDiff) DecodeShellStateDiff(barr []byte) error {
sdiff.HashVal = sha1Hash(barr)
buf := bytes.NewBuffer(barr)
u := binpack.MakeUnpacker(buf)
version := u.UnpackInt("ShellState pack version")
version := u.UnpackUInt("ShellState pack version")
if version != ShellStateDiffPackVersion {
return fmt.Errorf("invalid ShellStateDiff pack version: %d", version)
}

View File

@ -0,0 +1,58 @@
package packet
import "testing"
func TestShellVersions(t *testing.T) {
if !StateVersionsCompatible("bash v5.0.17", "bash v5.0.17") {
t.Errorf("versions should be compatible")
}
if !StateVersionsCompatible("bash v5.0.17", "bash v5.0.18") {
t.Errorf("versions should be compatible")
}
if !StateVersionsCompatible("bash v5.0.17", "bash v5.1.0") {
t.Errorf("versions should be compatible")
}
if StateVersionsCompatible("bash v5.0.17", "bash v6.0.0") {
t.Errorf("versions should not be compatible")
}
if StateVersionsCompatible("bash v5.0.17", "zsh v5.0.17") {
t.Errorf("versions should not be compatible")
}
shell, version, err := ParseShellStateVersion("bash v5.0.17")
if err != nil {
t.Errorf("version should be valid, got error %v", err)
}
if shell != ShellType_bash {
t.Errorf("shell should be bash")
}
if version != "v5.0.17" {
t.Errorf("version should be v5.0.17")
}
shell, version, err = ParseShellStateVersion("zsh v5.0.17")
if err != nil {
t.Errorf("version should be valid, got error %v", err)
}
if shell != ShellType_zsh {
t.Errorf("shell should be zsh")
}
if version != "v5.0.17" {
t.Errorf("version should be v5.0.17")
}
_, _, err = ParseShellStateVersion("fish v5.0.17")
if err == nil {
t.Errorf("version should be invalid")
}
_, _, err = ParseShellStateVersion("bash v5.0.17.1")
if err == nil {
t.Errorf("version should be invalid")
}
_, _, err = ParseShellStateVersion("bash")
if err == nil {
t.Errorf("version should be invalid")
}
_, _, err = ParseShellStateVersion("bash v5.0.17 extrastuff")
if err == nil {
t.Errorf("version should be invalid")
}
}

View File

@ -20,7 +20,9 @@ import (
"github.com/alessio/shellescape"
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellapi"
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
)
const MaxFileDataPacketSize = 16 * 1024
@ -28,6 +30,17 @@ const WriteFileContextTimeout = 30 * time.Second
const cleanLoopTime = 5 * time.Second
const MaxWriteFileContextData = 100
type shellStateMapKey struct {
ShellType string
Hash string
}
type ShellStateMap struct {
Lock *sync.Mutex
StateMap map[shellStateMapKey]*packet.ShellState // shelltype+hash -> state
CurrentStateMap map[string]string // shelltype -> hash
}
// TODO create unblockable packet-sender (backed by an array) for clientproc
type MServer struct {
Lock *sync.Mutex
@ -35,8 +48,7 @@ type MServer struct {
Sender *packet.PacketSender
ClientMap map[base.CommandKey]*shexec.ClientProc
Debug bool
StateMap map[string]*packet.ShellState // sha1->state
CurrentState string // sha1
StateMap *ShellStateMap
WriteErrorCh chan bool // closed if there is a I/O write error
WriteErrorChOnce *sync.Once
WriteFileContextMap map[string]*WriteFileContext
@ -146,11 +158,15 @@ func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) {
}
func runSingleCompGen(cwd string, compType string, prefix string) ([]string, bool, error) {
sapi, err := shellapi.MakeShellApi(packet.ShellType_bash)
if err != nil {
return nil, false, err
}
if !packet.IsValidCompGenType(compType) {
return nil, false, fmt.Errorf("invalid compgen type '%s'", compType)
}
compGenCmdStr := fmt.Sprintf("cd %s; compgen -A %s -- %s | sort | uniq | head -n %d", shellescape.Quote(cwd), shellescape.Quote(compType), shellescape.Quote(prefix), packet.MaxCompGenValues+1)
ecmd := exec.Command(shexec.GetLocalBashPath(), "-c", compGenCmdStr)
ecmd := exec.Command(sapi.GetLocalShellPath(), "-c", compGenCmdStr)
outputBytes, err := ecmd.Output()
if err != nil {
return nil, false, fmt.Errorf("compgen error: %w", err)
@ -230,26 +246,19 @@ func (m *MServer) runCompGen(compPk *packet.CompGenPacketType) {
return
}
func (m *MServer) setCurrentState(state *packet.ShellState) {
if state == nil {
return
}
hval, _ := state.EncodeAndHash()
m.Lock.Lock()
defer m.Lock.Unlock()
m.StateMap[hval] = state
m.CurrentState = hval
}
func (m *MServer) reinit(reqId string) {
initPk, err := shexec.MakeServerInitPacket()
func (m *MServer) reinit(reqId string, shellType string) {
ssPk, err := shexec.MakeShellStatePacket(shellType)
if err != nil {
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error creating init packet: %w", err))
return
}
m.setCurrentState(initPk.State)
initPk.RespId = reqId
m.Sender.SendPacket(initPk)
err = m.StateMap.SetCurrentState(ssPk.State.GetShellType(), ssPk.State)
if err != nil {
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error setting current state: %w", err))
return
}
ssPk.RespId = reqId
m.Sender.SendPacket(ssPk)
}
func makeTemp(path string, mode fs.FileMode) (*os.File, error) {
@ -564,8 +573,8 @@ func (m *MServer) ProcessRpcPacket(pk packet.RpcPacketType) {
go m.runCompGen(compPk)
return
}
if _, ok := pk.(*packet.ReInitPacketType); ok {
go m.reinit(reqId)
if reinitPk, ok := pk.(*packet.ReInitPacketType); ok {
go m.reinit(reqId, reinitPk.ShellType)
return
}
if streamPk, ok := pk.(*packet.StreamFilePacketType); ok {
@ -581,13 +590,7 @@ func (m *MServer) ProcessRpcPacket(pk packet.RpcPacketType) {
return
}
func (m *MServer) getCurrentState() (string, *packet.ShellState) {
m.Lock.Lock()
defer m.Lock.Unlock()
return m.CurrentState, m.StateMap[m.CurrentState]
}
func (m *MServer) clientPacketCallback(pk packet.PacketType) {
func (m *MServer) clientPacketCallback(shellType string, pk packet.PacketType) {
if pk.GetType() != packet.CmdDonePacketStr {
return
}
@ -595,16 +598,25 @@ func (m *MServer) clientPacketCallback(pk packet.PacketType) {
if donePk.FinalState == nil {
return
}
stateHash, curState := m.getCurrentState()
stateHash, curState := m.StateMap.GetCurrentState(shellType)
if curState == nil {
return
}
diff, err := shexec.MakeShellStateDiff(*curState, stateHash, *donePk.FinalState)
sapi, err := shellapi.MakeShellApi(curState.GetShellType())
if err != nil {
return
}
diff, err := sapi.MakeShellStateDiff(curState, stateHash, donePk.FinalState)
if err != nil {
return
}
donePk.FinalState = nil
donePk.FinalStateDiff = &diff
donePk.FinalStateDiff = diff
}
func (m *MServer) isShellInitialized(shellType string) bool {
_, curState := m.StateMap.GetCurrentState(shellType)
return curState != nil
}
func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
@ -612,7 +624,29 @@ func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err))
return
}
ecmd, err := shexec.SSHOpts{}.MakeMShellSingleCmd(true)
if runPacket.ShellType == "" {
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require shell type"))
return
}
_, curInitState := m.StateMap.GetCurrentState(runPacket.ShellType)
if curInitState == nil {
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("shell type %q is not initialized", runPacket.ShellType))
return
}
if runPacket.State == nil {
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require state"))
return
}
_, _, err := packet.ParseShellStateVersion(runPacket.State.Version)
if err != nil {
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("invalid shellstate version: %w", err))
return
}
if !packet.StateVersionsCompatible(runPacket.State.Version, curInitState.Version) {
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("shellstate version %q is not compatible with current shell version %q", runPacket.State.Version, curInitState.Version))
return
}
ecmd, err := shexec.MakeMShellSingleCmd()
if err != nil {
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err))
return
@ -640,7 +674,9 @@ func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
cproc.Close()
}()
shexec.SendRunPacketAndRunData(context.Background(), cproc.Input, runPacket)
cproc.ProxySingleOutput(runPacket.CK, m.Sender, m.clientPacketCallback)
cproc.ProxySingleOutput(runPacket.CK, m.Sender, func(pk packet.PacketType) {
m.clientPacketCallback(runPacket.ShellType, pk)
})
}()
}
@ -699,7 +735,7 @@ func RunServer() (int, error) {
server := &MServer{
Lock: &sync.Mutex{},
ClientMap: make(map[base.CommandKey]*shexec.ClientProc),
StateMap: make(map[string]*packet.ShellState),
StateMap: MakeShellStateMap(),
Debug: debug,
WriteErrorCh: make(chan bool),
WriteErrorChOnce: &sync.Once{},
@ -725,7 +761,6 @@ func RunServer() (int, error) {
if err != nil {
return 1, err
}
server.setCurrentState(initPacket.State)
server.Sender.SendPacket(initPacket)
ticker := time.NewTicker(1 * time.Minute)
go func() {
@ -748,3 +783,60 @@ func RunServer() (int, error) {
}
return 0, nil
}
func MakeShellStateMap() *ShellStateMap {
return &ShellStateMap{
Lock: &sync.Mutex{},
StateMap: make(map[shellStateMapKey]*packet.ShellState),
CurrentStateMap: make(map[string]string),
}
}
func (sm *ShellStateMap) GetCurrentState(shellType string) (string, *packet.ShellState) {
sm.Lock.Lock()
defer sm.Lock.Unlock()
hval := sm.CurrentStateMap[shellType]
return hval, sm.StateMap[shellStateMapKey{ShellType: shellType, Hash: hval}]
}
func (sm *ShellStateMap) SetCurrentState(shellType string, state *packet.ShellState) error {
if state == nil {
return fmt.Errorf("cannot set nil state")
}
if shellType != state.GetShellType() {
return fmt.Errorf("shell type mismatch: %s != %s", shellType, state.GetShellType())
}
sm.Lock.Lock()
defer sm.Lock.Unlock()
hval, _ := state.EncodeAndHash()
key := shellStateMapKey{ShellType: shellType, Hash: hval}
sm.StateMap[key] = state
sm.CurrentStateMap[shellType] = hval
return nil
}
func (sm *ShellStateMap) GetStateByHash(shellType string, hash string) *packet.ShellState {
sm.Lock.Lock()
defer sm.Lock.Unlock()
return sm.StateMap[shellStateMapKey{ShellType: shellType, Hash: hash}]
}
func (sm *ShellStateMap) Clear() {
sm.Lock.Lock()
defer sm.Lock.Unlock()
sm.StateMap = make(map[shellStateMapKey]*packet.ShellState)
sm.CurrentStateMap = make(map[string]string)
}
func (sm *ShellStateMap) GetShells() []string {
sm.Lock.Lock()
defer sm.Lock.Unlock()
return utilfn.GetMapKeys(sm.CurrentStateMap)
}
func (sm *ShellStateMap) HasShell(shellType string) bool {
sm.Lock.Lock()
defer sm.Lock.Unlock()
_, found := sm.CurrentStateMap[shellType]
return found
}

View File

@ -0,0 +1,260 @@
// Copyright 2023, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package shellapi
import (
"bytes"
"context"
"fmt"
"os/exec"
"runtime"
"strings"
"sync"
"github.com/alessio/shellescape"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
"github.com/wavetermdev/waveterm/waveshell/pkg/statediff"
)
const BaseBashOpts = `set +m; set +H; shopt -s extglob`
const BashShellVersionCmdStr = `echo bash v${BASH_VERSINFO[0]}.${BASH_VERSINFO[1]}.${BASH_VERSINFO[2]}`
const RemoteBashPath = "bash"
// TODO fix bash path in these constants
const RunBashSudoCommandFmt = `sudo -n -C %d bash /dev/fd/%d`
const RunBashSudoPasswordCommandFmt = `cat /dev/fd/%d | sudo -k -S -C %d bash -c "echo '[from-mshell]'; exec %d>&-; bash /dev/fd/%d < /dev/fd/%d"`
// do not use these directly, call GetLocalMajorVersion()
var localBashMajorVersionOnce = &sync.Once{}
var localBashMajorVersion = ""
// the "exec 2>" line also adds an extra printf at the *beginning* to strip out spurious rc file output
var GetBashShellStateCmds = []string{
"exec 2> /dev/null;",
BashShellVersionCmdStr + ";",
`pwd;`,
`declare -p $(compgen -A variable);`,
`alias -p;`,
`declare -f;`,
GetGitBranchCmdStr + ";",
}
type bashShellApi struct{}
func (b bashShellApi) GetShellType() string {
return packet.ShellType_bash
}
func (b bashShellApi) MakeExitTrap(fdNum int) string {
return MakeBashExitTrap(fdNum)
}
func (b bashShellApi) GetLocalMajorVersion() string {
return GetLocalBashMajorVersion()
}
func (b bashShellApi) GetLocalShellPath() string {
return GetLocalBashPath()
}
func (b bashShellApi) GetRemoteShellPath() string {
return RemoteBashPath
}
func (b bashShellApi) MakeRunCommand(cmdStr string, opts RunCommandOpts) string {
if !opts.Sudo {
return fmt.Sprintf(RunCommandFmt, cmdStr)
}
if opts.SudoWithPass {
return fmt.Sprintf(RunBashSudoPasswordCommandFmt, opts.PwFdNum, opts.MaxFdNum+1, opts.PwFdNum, opts.CommandFdNum, opts.CommandStdinFdNum)
} else {
return fmt.Sprintf(RunBashSudoCommandFmt, opts.MaxFdNum+1, opts.CommandFdNum)
}
}
func (b bashShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd {
return MakeBashShExecCommand(cmdStr, rcFileName, usePty)
}
func (b bashShellApi) GetShellState() (*packet.ShellState, error) {
return GetBashShellState()
}
func (b bashShellApi) GetBaseShellOpts() string {
return BaseBashOpts
}
func (b bashShellApi) ParseShellStateOutput(output []byte) (*packet.ShellState, error) {
return parseBashShellStateOutput(output)
}
func (b bashShellApi) MakeRcFileStr(pk *packet.RunPacketType) string {
var rcBuf bytes.Buffer
rcBuf.WriteString(b.GetBaseShellOpts() + "\n")
varDecls := shellenv.VarDeclsFromState(pk.State)
for _, varDecl := range varDecls {
if varDecl.IsExport() || varDecl.IsReadOnly() {
continue
}
rcBuf.WriteString(BashDeclareStmt(varDecl))
rcBuf.WriteString("\n")
}
if pk.State != nil && pk.State.Funcs != "" {
rcBuf.WriteString(pk.State.Funcs)
rcBuf.WriteString("\n")
}
if pk.State != nil && pk.State.Aliases != "" {
rcBuf.WriteString(pk.State.Aliases)
rcBuf.WriteString("\n")
}
return rcBuf.String()
}
func GetBashShellStateCmd() string {
return strings.Join(GetBashShellStateCmds, ` printf "\x00\x00";`)
}
func execGetLocalBashShellVersion() string {
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
defer cancelFn()
ecmd := exec.CommandContext(ctx, "bash", "-c", BashShellVersionCmdStr)
out, err := ecmd.Output()
if err != nil {
return ""
}
versionStr := strings.TrimSpace(string(out))
if strings.Index(versionStr, "bash ") == -1 {
// invalid shell version (only bash is supported)
return ""
}
return versionStr
}
func GetLocalBashMajorVersion() string {
localBashMajorVersionOnce.Do(func() {
fullVersion := execGetLocalBashShellVersion()
localBashMajorVersion = packet.GetMajorVersion(fullVersion)
})
return localBashMajorVersion
}
func GetBashShellState() (*packet.ShellState, error) {
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
defer cancelFn()
cmdStr := BaseBashOpts + "; " + GetBashShellStateCmd()
ecmd := exec.CommandContext(ctx, GetLocalBashPath(), "-l", "-i", "-c", cmdStr)
outputBytes, err := RunSimpleCmdInPty(ecmd)
if err != nil {
return nil, err
}
return parseBashShellStateOutput(outputBytes)
}
func GetLocalBashPath() string {
if runtime.GOOS == "darwin" {
macShell := GetMacUserShell()
if strings.Index(macShell, "bash") != -1 {
return shellescape.Quote(macShell)
}
}
return "bash"
}
func GetLocalZshPath() string {
if runtime.GOOS == "darwin" {
macShell := GetMacUserShell()
if strings.Index(macShell, "zsh") != -1 {
return shellescape.Quote(macShell)
}
}
return "zsh"
}
func GetBashShellStateRedirectCommandStr(outputFdNum int) string {
return fmt.Sprintf("cat <(%s) > /dev/fd/%d", GetBashShellStateCmd(), outputFdNum)
}
func MakeBashExitTrap(fdNum int) string {
stateCmd := GetBashShellStateRedirectCommandStr(fdNum)
fmtStr := `
_waveshell_exittrap () {
%s
}
trap _waveshell_exittrap EXIT
`
return fmt.Sprintf(fmtStr, stateCmd)
}
func MakeBashShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd {
if usePty {
return exec.Command(GetLocalBashPath(), "--rcfile", rcFileName, "-i", "-c", cmdStr)
} else {
return exec.Command(GetLocalBashPath(), "--rcfile", rcFileName, "-c", cmdStr)
}
}
func (bashShellApi) MakeShellStateDiff(oldState *packet.ShellState, oldStateHash string, newState *packet.ShellState) (*packet.ShellStateDiff, error) {
if oldState == nil {
return nil, fmt.Errorf("cannot diff, oldState is nil")
}
if newState == nil {
return nil, fmt.Errorf("cannot diff, newState is nil")
}
if !packet.StateVersionsCompatible(oldState.Version, newState.Version) {
return nil, fmt.Errorf("cannot diff, incompatible shell versions: %q %q", oldState.Version, newState.Version)
}
rtn := &packet.ShellStateDiff{}
rtn.BaseHash = oldStateHash
rtn.Version = newState.Version // always set version in the diff
if oldState.Cwd != newState.Cwd {
rtn.Cwd = newState.Cwd
}
rtn.Error = newState.Error
oldVars := shellenv.ShellStateVarsToMap(oldState.ShellVars)
newVars := shellenv.ShellStateVarsToMap(newState.ShellVars)
rtn.VarsDiff = statediff.MakeMapDiff(oldVars, newVars)
rtn.AliasesDiff = statediff.MakeLineDiff(oldState.Aliases, newState.Aliases, oldState.GetLineDiffSplitString())
rtn.FuncsDiff = statediff.MakeLineDiff(oldState.Funcs, newState.Funcs, oldState.GetLineDiffSplitString())
return rtn, nil
}
func (bashShellApi) ApplyShellStateDiff(oldState *packet.ShellState, diff *packet.ShellStateDiff) (*packet.ShellState, error) {
if oldState == nil {
return nil, fmt.Errorf("cannot apply diff, oldState is nil")
}
if diff == nil {
return oldState, nil
}
rtnState := &packet.ShellState{}
var err error
rtnState.Version = oldState.Version
// work around a bug (before v0.6.0) where version could be invalid.
// so only overwrite the oldversion if diff version is valid
_, _, diffVersionErr := packet.ParseShellStateVersion(diff.Version)
if diffVersionErr == nil {
rtnState.Version = diff.Version
}
rtnState.Cwd = oldState.Cwd
if diff.Cwd != "" {
rtnState.Cwd = diff.Cwd
}
rtnState.Error = diff.Error
oldVars := shellenv.ShellStateVarsToMap(oldState.ShellVars)
newVars, err := statediff.ApplyMapDiff(oldVars, diff.VarsDiff)
if err != nil {
return nil, fmt.Errorf("applying mapdiff 'vars': %v", err)
}
rtnState.ShellVars = shellenv.StrMapToShellStateVars(newVars)
rtnState.Aliases, err = statediff.ApplyLineDiff(oldState.Aliases, diff.AliasesDiff)
if err != nil {
return nil, fmt.Errorf("applying diff 'aliases': %v", err)
}
rtnState.Funcs, err = statediff.ApplyLineDiff(oldState.Funcs, diff.FuncsDiff)
if err != nil {
return nil, fmt.Errorf("applying diff 'funcs': %v", err)
}
return rtnState, nil
}

View File

@ -0,0 +1,328 @@
// Copyright 2023, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package shellapi
import (
"bytes"
"fmt"
"io"
"regexp"
"sort"
"strings"
"github.com/alessio/shellescape"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
"mvdan.cc/sh/v3/expand"
"mvdan.cc/sh/v3/syntax"
)
type DeclareDeclType = shellenv.DeclareDeclType
func doCmdSubst(commandStr string, w io.Writer, word *syntax.CmdSubst) error {
return nil
}
func doProcSubst(w *syntax.ProcSubst) (string, error) {
return "", nil
}
type bashParseEnviron struct {
Env map[string]string
}
func (e *bashParseEnviron) Get(name string) expand.Variable {
val, ok := e.Env[name]
if !ok {
return expand.Variable{}
}
return expand.Variable{
Exported: true,
Kind: expand.String,
Str: val,
}
}
func (e *bashParseEnviron) Each(fn func(name string, vr expand.Variable) bool) {
for key := range e.Env {
rtn := fn(key, e.Get(key))
if !rtn {
break
}
}
}
func GetParserConfig(envMap map[string]string) *expand.Config {
cfg := &expand.Config{
Env: &bashParseEnviron{Env: envMap},
GlobStar: false,
NullGlob: false,
NoUnset: false,
CmdSubst: func(w io.Writer, word *syntax.CmdSubst) error { return doCmdSubst("", w, word) },
ProcSubst: doProcSubst,
ReadDir: nil,
}
return cfg
}
// https://wiki.bash-hackers.org/syntax/shellvars
var BashNoStoreVarNames = map[string]bool{
"BASH": true,
"BASHOPTS": true,
"BASHPID": true,
"BASH_ALIASES": true,
"BASH_ARGC": true,
"BASH_ARGV": true,
"BASH_ARGV0": true,
"BASH_CMDS": true,
"BASH_COMMAND": true,
"BASH_EXECUTION_STRING": true,
"LINENO": true,
"BASH_LINENO": true,
"BASH_REMATCH": true,
"BASH_SOURCE": true,
"BASH_SUBSHELL": true,
"COPROC": true,
"DIRSTACK": true,
"EPOCHREALTIME": true,
"EPOCHSECONDS": true,
"FUNCNAME": true,
"HISTCMD": true,
"OLDPWD": true,
"PIPESTATUS": true,
"PPID": true,
"PWD": true,
"RANDOM": true,
"SECONDS": true,
"SHLVL": true,
"HISTFILE": true,
"HISTFILESIZE": true,
"HISTCONTROL": true,
"HISTIGNORE": true,
"HISTSIZE": true,
"HISTTIMEFORMAT": true,
"SRANDOM": true,
"COLUMNS": true,
"LINES": true,
// we want these in our remote state object
// "EUID": true,
// "SHELLOPTS": true,
// "UID": true,
// "BASH_VERSINFO": true,
// "BASH_VERSION": true,
}
var declareDeclArgsRe = regexp.MustCompile("^[aAxrifx]*$")
var bashValidIdentifierRe = regexp.MustCompile("^[a-zA-Z_][a-zA-Z0-9_]*$")
func bashValidate(d *DeclareDeclType) error {
if len(d.Name) == 0 || !isValidBashIdentifier(d.Name) {
return fmt.Errorf("invalid shell variable name (invalid bash identifier)")
}
if strings.Index(d.Value, "\x00") >= 0 {
return fmt.Errorf("invalid shell variable value (cannot contain 0 byte)")
}
if !declareDeclArgsRe.MatchString(d.Args) {
return fmt.Errorf("invalid shell variable type %s", shellescape.Quote(d.Args))
}
return nil
}
func isValidBashIdentifier(s string) bool {
return bashValidIdentifierRe.MatchString(s)
}
func bashParseDeclareStmt(stmt *syntax.Stmt, src string) (*DeclareDeclType, error) {
cmd := stmt.Cmd
decl, ok := cmd.(*syntax.DeclClause)
if !ok || decl.Variant.Value != "declare" || len(decl.Args) != 2 {
return nil, fmt.Errorf("invalid declare variant")
}
rtn := &DeclareDeclType{}
declArgs := decl.Args[0]
if !declArgs.Naked || len(declArgs.Value.Parts) != 1 {
return nil, fmt.Errorf("wrong number of declare args parts")
}
declArgsLit, ok := declArgs.Value.Parts[0].(*syntax.Lit)
if !ok {
return nil, fmt.Errorf("declare args is not a literal")
}
if !strings.HasPrefix(declArgsLit.Value, "-") {
return nil, fmt.Errorf("declare args not an argument (does not start with '-')")
}
if declArgsLit.Value == "--" {
rtn.Args = ""
} else {
rtn.Args = declArgsLit.Value[1:]
}
declAssign := decl.Args[1]
if declAssign.Name == nil {
return nil, fmt.Errorf("declare does not have a valid name")
}
rtn.Name = declAssign.Name.Value
if declAssign.Naked || declAssign.Index != nil || declAssign.Append {
return nil, fmt.Errorf("invalid decl format")
}
if declAssign.Value != nil {
rtn.Value = string(src[declAssign.Value.Pos().Offset():declAssign.Value.End().Offset()])
} else if declAssign.Array != nil {
rtn.Value = string(src[declAssign.Array.Pos().Offset():declAssign.Array.End().Offset()])
} else {
return nil, fmt.Errorf("invalid decl, not plain value or array")
}
err := bashNormalize(rtn)
if err != nil {
return nil, err
}
if err = bashValidate(rtn); err != nil {
return nil, err
}
return rtn, nil
}
func bashParseDeclareOutput(state *packet.ShellState, declareBytes []byte, pvarBytes []byte) error {
declareStr := string(declareBytes)
r := bytes.NewReader(declareBytes)
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
file, err := parser.Parse(r, "aliases")
if err != nil {
return err
}
var firstParseErr error
declMap := make(map[string]*DeclareDeclType)
for _, stmt := range file.Stmts {
decl, err := bashParseDeclareStmt(stmt, declareStr)
if err != nil {
if firstParseErr == nil {
firstParseErr = err
}
}
if decl != nil && !BashNoStoreVarNames[decl.Name] {
declMap[decl.Name] = decl
}
}
pvarMap := parsePVarOutput(pvarBytes, false)
utilfn.CombineMaps(declMap, pvarMap)
state.ShellVars = shellenv.SerializeDeclMap(declMap) // this writes out the decls in a canonical order
if firstParseErr != nil {
state.Error = firstParseErr.Error()
}
return nil
}
func parseBashShellStateOutput(outputBytes []byte) (*packet.ShellState, error) {
if scbase.IsDevMode() && DebugState {
writeStateToFile(packet.ShellType_bash, outputBytes)
}
// 7 fields: ignored [0], version [1], cwd [2], env/vars [3], aliases [4], funcs [5], pvars [6]
fields := bytes.Split(outputBytes, []byte{0, 0})
if len(fields) != 7 {
return nil, fmt.Errorf("invalid bash shell state output, wrong number of fields, fields=%d", len(fields))
}
rtn := &packet.ShellState{}
rtn.Version = strings.TrimSpace(string(fields[1]))
if rtn.GetShellType() != packet.ShellType_bash {
return nil, fmt.Errorf("invalid bash shell state output, wrong shell type: %q", rtn.Version)
}
if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil {
return nil, fmt.Errorf("invalid bash shell state output, invalid version: %v", err)
}
cwdStr := string(fields[2])
if strings.HasSuffix(cwdStr, "\r\n") {
cwdStr = cwdStr[0 : len(cwdStr)-2]
} else if strings.HasSuffix(cwdStr, "\n") {
cwdStr = cwdStr[0 : len(cwdStr)-1]
}
rtn.Cwd = string(cwdStr)
err := bashParseDeclareOutput(rtn, fields[3], fields[6])
if err != nil {
return nil, err
}
rtn.Aliases = strings.ReplaceAll(string(fields[4]), "\r\n", "\n")
rtn.Funcs = strings.ReplaceAll(string(fields[5]), "\r\n", "\n")
rtn.Funcs = shellenv.RemoveFunc(rtn.Funcs, "_waveshell_exittrap")
return rtn, nil
}
func bashNormalize(d *DeclareDeclType) error {
if d.DataType() == shellenv.DeclTypeAssocArray {
return bashNormalizeAssocArrayDecl(d)
}
return nil
}
// normalizes order of assoc array keys so value is stable
func bashNormalizeAssocArrayDecl(d *DeclareDeclType) error {
if d.DataType() != shellenv.DeclTypeAssocArray {
return fmt.Errorf("invalid decltype passed to assocArrayDeclToStr: %s", d.DataType())
}
varMap, err := bashAssocArrayVarToMap(d)
if err != nil {
return err
}
keys := make([]string, 0, len(varMap))
for key := range varMap {
keys = append(keys, key)
}
sort.Strings(keys)
var buf bytes.Buffer
buf.WriteByte('(')
for _, key := range keys {
buf.WriteByte('[')
buf.WriteString(key)
buf.WriteByte(']')
buf.WriteByte('=')
buf.WriteString(varMap[key])
buf.WriteByte(' ')
}
buf.WriteByte(')')
d.Value = buf.String()
return nil
}
func bashAssocArrayVarToMap(d *DeclareDeclType) (map[string]string, error) {
if d.DataType() != shellenv.DeclTypeAssocArray {
return nil, fmt.Errorf("decl is not an assoc-array")
}
refStr := "X=" + d.Value
r := strings.NewReader(refStr)
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
file, err := parser.Parse(r, "assocdecl")
if err != nil {
return nil, err
}
if len(file.Stmts) != 1 {
return nil, fmt.Errorf("invalid assoc-array parse (multiple stmts)")
}
stmt := file.Stmts[0]
callExpr, ok := stmt.Cmd.(*syntax.CallExpr)
if !ok || len(callExpr.Args) != 0 || len(callExpr.Assigns) != 1 {
return nil, fmt.Errorf("invalid assoc-array parse (bad expr)")
}
assign := callExpr.Assigns[0]
arrayExpr := assign.Array
if arrayExpr == nil {
return nil, fmt.Errorf("invalid assoc-array parse (no array expr)")
}
rtn := make(map[string]string)
for _, elem := range arrayExpr.Elems {
indexStr := refStr[elem.Index.Pos().Offset():elem.Index.End().Offset()]
valStr := refStr[elem.Value.Pos().Offset():elem.Value.End().Offset()]
rtn[indexStr] = valStr
}
return rtn, nil
}
func BashDeclareStmt(d *DeclareDeclType) string {
var argsStr string
if d.Args == "" {
argsStr = "--"
} else {
argsStr = "-" + d.Args
}
return fmt.Sprintf("declare %s %s=%s", argsStr, d.Name, d.Value)
}

View File

@ -0,0 +1,254 @@
// Copyright 2023, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package shellapi
import (
"bytes"
"context"
"fmt"
"io"
"os"
"os/exec"
"os/user"
"path"
"path/filepath"
"regexp"
"runtime"
"strings"
"sync"
"syscall"
"time"
"github.com/alessio/shellescape"
"github.com/creack/pty"
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellutil"
)
const GetStateTimeout = 5 * time.Second
const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"`
const RunCommandFmt = `%s`
const DebugState = false
var userShellRegexp = regexp.MustCompile(`^UserShell: (.*)$`)
var cachedMacUserShell string
var macUserShellOnce = &sync.Once{}
const DefaultMacOSShell = "/bin/bash"
type RunCommandOpts struct {
Sudo bool
SudoWithPass bool
MaxFdNum int // needed for Sudo
CommandFdNum int // needed for Sudo
PwFdNum int // needed for SudoWithPass
CommandStdinFdNum int // needed for SudoWithPass
}
type ShellApi interface {
GetShellType() string
MakeExitTrap(fdNum int) string
GetLocalMajorVersion() string
GetLocalShellPath() string
GetRemoteShellPath() string
MakeRunCommand(cmdStr string, opts RunCommandOpts) string
MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd
GetShellState() (*packet.ShellState, error)
GetBaseShellOpts() string
ParseShellStateOutput(output []byte) (*packet.ShellState, error)
MakeRcFileStr(pk *packet.RunPacketType) string
MakeShellStateDiff(oldState *packet.ShellState, oldStateHash string, newState *packet.ShellState) (*packet.ShellStateDiff, error)
ApplyShellStateDiff(oldState *packet.ShellState, diff *packet.ShellStateDiff) (*packet.ShellState, error)
}
func DetectLocalShellType() string {
shellPath := GetMacUserShell()
if shellPath == "" {
shellPath = os.Getenv("SHELL")
}
if shellPath == "" {
return packet.ShellType_bash
}
_, file := filepath.Split(shellPath)
if strings.HasPrefix(file, "zsh") {
return packet.ShellType_zsh
}
return packet.ShellType_bash
}
func HasShell(shellType string) bool {
if shellType == packet.ShellType_bash {
_, err := exec.LookPath("bash")
return err != nil
}
if shellType == packet.ShellType_zsh {
_, err := exec.LookPath("zsh")
return err != nil
}
return false
}
func MakeShellApi(shellType string) (ShellApi, error) {
if shellType == "" || shellType == packet.ShellType_bash {
return &bashShellApi{}, nil
}
if shellType == packet.ShellType_zsh {
return &zshShellApi{}, nil
}
return nil, fmt.Errorf("shell type not supported: %s", shellType)
}
func GetMacUserShell() string {
if runtime.GOOS != "darwin" {
return ""
}
macUserShellOnce.Do(func() {
cachedMacUserShell = internalMacUserShell()
})
return cachedMacUserShell
}
// dscl . -read /User/[username] UserShell
// defaults to /bin/bash
func internalMacUserShell() string {
osUser, err := user.Current()
if err != nil {
return DefaultMacOSShell
}
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
defer cancelFn()
userStr := "/Users/" + osUser.Username
out, err := exec.CommandContext(ctx, "dscl", ".", "-read", userStr, "UserShell").CombinedOutput()
if err != nil {
return DefaultMacOSShell
}
outStr := strings.TrimSpace(string(out))
m := userShellRegexp.FindStringSubmatch(outStr)
if m == nil {
return DefaultMacOSShell
}
return m[1]
}
const FirstExtraFilesFdNum = 3
// returns output(stdout+stderr), extraFdOutput, error
func RunCommandWithExtraFd(ecmd *exec.Cmd, extraFdNum int) ([]byte, []byte, error) {
ecmd.Env = os.Environ()
shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType))
cmdPty, cmdTty, err := pty.Open()
if err != nil {
return nil, nil, fmt.Errorf("opening new pty: %w", err)
}
defer cmdTty.Close()
defer cmdPty.Close()
pty.Setsize(cmdPty, &pty.Winsize{Rows: shellutil.DefaultTermRows, Cols: shellutil.DefaultTermCols})
ecmd.Stdin = cmdTty
ecmd.Stdout = cmdTty
ecmd.Stderr = cmdTty
ecmd.SysProcAttr = &syscall.SysProcAttr{}
ecmd.SysProcAttr.Setsid = true
ecmd.SysProcAttr.Setctty = true
pipeReader, pipeWriter, err := os.Pipe()
if err != nil {
return nil, nil, fmt.Errorf("could not create pipe: %w", err)
}
defer pipeWriter.Close()
defer pipeReader.Close()
extraFiles := make([]*os.File, extraFdNum+1)
extraFiles[extraFdNum] = pipeWriter
ecmd.ExtraFiles = extraFiles[FirstExtraFilesFdNum:]
defer pipeReader.Close()
ecmd.Start()
cmdTty.Close()
pipeWriter.Close()
if err != nil {
return nil, nil, err
}
var outputWg sync.WaitGroup
var outputBuf bytes.Buffer
var extraFdOutputBuf bytes.Buffer
outputWg.Add(2)
go func() {
// ignore error (/dev/ptmx has read error when process is done)
defer outputWg.Done()
io.Copy(&outputBuf, cmdPty)
}()
go func() {
defer outputWg.Done()
io.Copy(&extraFdOutputBuf, pipeReader)
}()
exitErr := ecmd.Wait()
if exitErr != nil {
return nil, nil, exitErr
}
outputWg.Wait()
return outputBuf.Bytes(), extraFdOutputBuf.Bytes(), nil
}
func RunSimpleCmdInPty(ecmd *exec.Cmd) ([]byte, error) {
ecmd.Env = os.Environ()
shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType))
cmdPty, cmdTty, err := pty.Open()
if err != nil {
return nil, fmt.Errorf("opening new pty: %w", err)
}
pty.Setsize(cmdPty, &pty.Winsize{Rows: shellutil.DefaultTermRows, Cols: shellutil.DefaultTermCols})
ecmd.Stdin = cmdTty
ecmd.Stdout = cmdTty
ecmd.Stderr = cmdTty
ecmd.SysProcAttr = &syscall.SysProcAttr{}
ecmd.SysProcAttr.Setsid = true
ecmd.SysProcAttr.Setctty = true
err = ecmd.Start()
cmdTty.Close()
if err != nil {
cmdPty.Close()
return nil, err
}
defer cmdPty.Close()
ioDone := make(chan bool)
var outputBuf bytes.Buffer
go func() {
// ignore error (/dev/ptmx has read error when process is done)
io.Copy(&outputBuf, cmdPty)
close(ioDone)
}()
exitErr := ecmd.Wait()
if exitErr != nil {
return nil, exitErr
}
<-ioDone
return outputBuf.Bytes(), nil
}
func parsePVarOutput(pvarBytes []byte, isZsh bool) map[string]*DeclareDeclType {
declMap := make(map[string]*DeclareDeclType)
pvars := bytes.Split(pvarBytes, []byte{0})
for _, pvarBA := range pvars {
pvarStr := string(pvarBA)
pvarFields := strings.SplitN(pvarStr, " ", 2)
if len(pvarFields) != 2 {
continue
}
if pvarFields[0] == "" {
continue
}
decl := &DeclareDeclType{IsZshDecl: isZsh, Args: "x"}
decl.Name = "PROMPTVAR_" + pvarFields[0]
decl.Value = shellescape.Quote(pvarFields[1])
declMap[decl.Name] = decl
}
return declMap
}
// for debugging (not for production use)
func writeStateToFile(shellType string, outputBytes []byte) error {
msHome := base.GetMShellHomeDir()
stateFileName := path.Join(msHome, shellType+"-state.txt")
os.WriteFile(stateFileName, outputBytes, 0644)
return nil
}

View File

@ -0,0 +1,826 @@
// Copyright 2023, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package shellapi
import (
"bytes"
"context"
"fmt"
"math/rand"
"os/exec"
"strings"
"sync"
"github.com/alessio/shellescape"
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/waveshell/pkg/binpack"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
"github.com/wavetermdev/waveterm/waveshell/pkg/statediff"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
)
const BaseZshOpts = ``
const ZshShellVersionCmdStr = `echo zsh v$ZSH_VERSION`
const StateOutputFdNum = 20
// TODO these need updating
const RunZshSudoCommandFmt = `sudo -n -C %d zsh /dev/fd/%d`
const RunZshSudoPasswordCommandFmt = `cat /dev/fd/%d | sudo -k -S -C %d zsh -c "echo '[from-mshell]'; exec %d>&-; zsh /dev/fd/%d < /dev/fd/%d"`
var ZshIgnoreVars = map[string]bool{
"_": true,
"0": true,
"terminfo": true,
"RANDOM": true,
"COLUMNS": true,
"LINES": true,
"argv": true,
"SECONDS": true,
"PWD": true,
"HISTCHARS": true,
"HISTFILE": true,
"HISTSIZE": true,
"SAVEHIST": true,
"ZSH_EXECUTION_STRING": true,
"EPOCHSECONDS": true,
"EPOCHREALTIME": true,
"SHLVL": true,
"TTY": true,
"epochtime": true,
"langinfo": true,
"aliases": true,
"dis_aliases": true,
"saliases": true,
"dis_saliases": true,
"galiases": true,
"dis_galiases": true,
"builtins": true,
"dis_builtins": true,
"modules": true,
"history": true,
"historywords": true,
"jobdirs": true,
"jobstates": true,
"jobtexts": true,
"funcfiletrace": true,
"funcsourcetrace": true,
"funcstack": true,
"functrace": true,
"parameters": true,
"commands": true,
"functions": true,
"dis_functions": true,
"functions_source": true,
"dis_functions_source": true,
"_comps": true,
"_patcomps": true,
"_postpatcomps": true,
}
var ZshUniqueArrayVars = map[string]bool{
"path": true,
"fpath": true,
}
var ZshSpecialDecls = map[string]bool{
"precmd_functions": true,
"preexec_functions": true,
}
var ZshUnsetVars = []string{
"HISTFILE",
"ZSH_EXECUTION_STRING",
}
// do not use these directly, call GetLocalMajorVersion()
var localZshMajorVersionOnce = &sync.Once{}
var localZshMajorVersion = ""
// sentinel value for functions that should be autoloaded
const ZshFnAutoLoad = "autoload"
type ZshParamKey struct {
// paramtype cannot contain spaces
// "aliases", "dis_aliases", "saliases", "dis_saliases", "galiases", "dis_galiases"
// "functions", "dis_functions", "functions_source", "dis_functions_source"
ParamType string
ParamName string
}
func (k ZshParamKey) String() string {
return k.ParamType + " " + k.ParamName
}
func ZshParamKeyFromString(s string) (ZshParamKey, error) {
parts := strings.SplitN(s, " ", 2)
if len(parts) != 2 {
return ZshParamKey{}, fmt.Errorf("invalid zsh param key")
}
return ZshParamKey{ParamType: parts[0], ParamName: parts[1]}, nil
}
type ZshMap = map[ZshParamKey]string
type zshShellApi struct{}
func (z zshShellApi) GetShellType() string {
return packet.ShellType_zsh
}
func (z zshShellApi) MakeExitTrap(fdNum int) string {
return MakeZshExitTrap(fdNum)
}
func (z zshShellApi) GetLocalMajorVersion() string {
return GetLocalZshMajorVersion()
}
func (z zshShellApi) GetLocalShellPath() string {
return "/bin/zsh"
}
func (z zshShellApi) GetRemoteShellPath() string {
return "zsh"
}
func (z zshShellApi) MakeRunCommand(cmdStr string, opts RunCommandOpts) string {
if !opts.Sudo {
return cmdStr
}
if opts.SudoWithPass {
return fmt.Sprintf(RunZshSudoPasswordCommandFmt, opts.PwFdNum, opts.MaxFdNum+1, opts.PwFdNum, opts.CommandFdNum, opts.CommandStdinFdNum)
} else {
return fmt.Sprintf(RunZshSudoCommandFmt, opts.MaxFdNum+1, opts.CommandFdNum)
}
}
func (z zshShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd {
return exec.Command(GetLocalZshPath(), "-l", "-i", "-c", cmdStr)
}
func (z zshShellApi) GetShellState() (*packet.ShellState, error) {
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
defer cancelFn()
cmdStr := BaseZshOpts + "; " + GetZshShellStateCmd(StateOutputFdNum)
ecmd := exec.CommandContext(ctx, GetLocalZshPath(), "-l", "-i", "-c", cmdStr)
_, outputBytes, err := RunCommandWithExtraFd(ecmd, StateOutputFdNum)
if err != nil {
return nil, err
}
rtn, err := z.ParseShellStateOutput(outputBytes)
if err != nil {
return nil, err
}
return rtn, nil
}
func (z zshShellApi) GetBaseShellOpts() string {
return BaseZshOpts
}
func makeZshTypesetStmt(varDecl *shellenv.DeclareDeclType) string {
if !varDecl.IsZshDecl {
// not sure what to do here?
return ""
}
var argsStr string
if varDecl.Args == "" {
argsStr = "--"
} else {
argsStr = "-" + varDecl.Args
}
if varDecl.IsZshScalarBound() {
// varDecl.Value contains the extra "separator" field (if present in the original typeset def)
return fmt.Sprintf("typeset %s %s %s=%s", argsStr, varDecl.ZshBoundScalar, varDecl.Name, varDecl.Value)
} else {
return fmt.Sprintf("typeset %s %s=%s", argsStr, varDecl.Name, varDecl.Value)
}
}
func (z zshShellApi) MakeRcFileStr(pk *packet.RunPacketType) string {
var rcBuf bytes.Buffer
rcBuf.WriteString(z.GetBaseShellOpts() + "\n")
rcBuf.WriteString("unsetopt GLOBAL_RCS\n")
rcBuf.WriteString("unset KSH_ARRAYS\n")
rcBuf.WriteString("zmodload zsh/parameter\n")
varDecls := shellenv.VarDeclsFromState(pk.State)
var postDecls []*shellenv.DeclareDeclType
for _, varDecl := range varDecls {
if ZshIgnoreVars[varDecl.Name] {
continue
}
if ZshUniqueArrayVars[varDecl.Name] && !varDecl.IsUniqueArray() {
varDecl.AddFlag("U")
}
if ZshSpecialDecls[varDecl.Name] {
postDecls = append(postDecls, varDecl)
continue
}
stmt := makeZshTypesetStmt(varDecl)
if stmt == "" {
continue
}
rcBuf.WriteString(makeZshTypesetStmt(varDecl))
rcBuf.WriteString("\n")
}
if shellenv.FindVarDecl(varDecls, "ZDOTDIR") == nil {
rcBuf.WriteString("unset ZDOTDIR\n")
rcBuf.WriteString("\n")
}
for _, varName := range ZshUnsetVars {
rcBuf.WriteString("unset " + shellescape.Quote(varName) + "\n")
}
// aliases
aliasMap, err := DecodeZshMap([]byte(pk.State.Aliases))
if err != nil {
base.Logf("error decoding zsh aliases: %v\n", err)
rcBuf.WriteString("# error decoding zsh aliases\n")
} else {
for aliasKey, aliasValue := range aliasMap {
// tricky here, don't quote AliasName (it gets implicit quotes, and quoting doesn't work as expected)
aliasStr := fmt.Sprintf("%s[%s]=%s\n", aliasKey.ParamType, aliasKey.ParamName, shellescape.Quote(aliasValue))
rcBuf.WriteString(aliasStr)
}
}
// functions
fnMap, err := DecodeZshMap([]byte(pk.State.Funcs))
if err != nil {
base.Logf("error decoding zsh functions: %v\n", err)
rcBuf.WriteString("# error decoding zsh functions\n")
} else {
for fnKey, fnValue := range fnMap {
if fnValue == ZshFnAutoLoad {
rcBuf.WriteString(fmt.Sprintf("autoload %s\n", shellescape.Quote(fnKey.ParamName)))
} else {
// careful, no whitespace (except newlines)
rcBuf.WriteString(fmt.Sprintf("function %s () {\n%s\n}\n", shellescape.Quote(fnKey.ParamName), fnValue))
if fnKey.ParamType == "dis_functions" {
rcBuf.WriteString(fmt.Sprintf("disable -f %s\n", shellescape.Quote(fnKey.ParamName)))
}
}
}
}
// write postdecls
for _, varDecl := range postDecls {
rcBuf.WriteString(makeZshTypesetStmt(varDecl))
rcBuf.WriteString("\n")
}
return rcBuf.String()
}
func writeZshId(buf *bytes.Buffer, idStr string) {
buf.WriteString(shellescape.Quote(idStr))
}
const numRandomBytes = 4
// returns (cmd-string)
func GetZshShellStateCmd(fdNum int) string {
var sectionSeparator []byte
// adding this extra "\n" helps with debuging and readability of output
sectionSeparator = append(sectionSeparator, byte('\n'))
for len(sectionSeparator) < numRandomBytes {
// any character *except* null (0)
rn := rand.Intn(256)
if rn > 0 && rn < 256 { // exclude 0, also helps to suppress security warning to have a guard here
sectionSeparator = append(sectionSeparator, byte(rn))
}
}
sectionSeparator = append(sectionSeparator, 0, 0)
// we have to use these crazy separators because zsh allows basically anything in
// variable names and values (including nulls).
// note that we don't need crazy separators for "env" or "typeset".
// environment variables *cannot* contain nulls by definition, and "typeset" already escapes nulls.
// the raw aliases and functions though need to be handled more carefully
// output redirection is necessary to prevent cooked tty options from screwing up the output (funcs especially)
// note we do not need the "extra" separator that bashapi uses because we are reading from OUTPUTFD (which already excludes any spurious stdout/stderr data)
cmd := `
exec > [%OUTPUTFD%]
unsetopt SH_WORD_SPLIT;
zmodload zsh/parameter;
[%ZSHVERSION%];
printf "\x00[%SECTIONSEP%]";
pwd;
printf "[%SECTIONSEP%]";
env -0;
printf "[%SECTIONSEP%]";
typeset -p +H -m '*';
printf "[%SECTIONSEP%]";
for var in "${(@k)aliases}"; do
printf "aliases %s[%PARTSEP%]%s[%PARTSEP%]" $var ${aliases[$var]}
done
for var in "${(@k)dis_aliases}"; do
printf "dis_aliases %s[%PARTSEP%]%s[%PARTSEP%]" $var ${dis_aliases[$var]}
done
for var in "${(@k)saliases}"; do
printf "saliases %s[%PARTSEP%]%s[%PARTSEP%]" $var ${saliases[$var]}
done
for var in "${(@k)dis_saliases}"; do
printf "dis_saliases %s[%PARTSEP%]%s[%PARTSEP%]" $var ${dis_saliases[$var]}
done
for var in "${(@k)galiases}"; do
printf "galiases %s[%PARTSEP%]%s[%PARTSEP%]" $var ${galiases[$var]}
done
for var in "${(@k)dis_galiases}"; do
printf "dis_galiases %s[%PARTSEP%]%s[%PARTSEP%]" $var ${dis_galiases[$var]}
done
printf "[%SECTIONSEP%]";
echo $FPATH;
printf "[%SECTIONSEP%]";
for var in "${(@k)functions}"; do
printf "functions %s[%PARTSEP%]%s[%PARTSEP%]" $var ${functions[$var]}
done
for var in "${(@k)dis_functions}"; do
printf "dis_functions %s[%PARTSEP%]%s[%PARTSEP%]" $var ${dis_functions[$var]}
done
for var in "${(@k)functions_source}"; do
printf "functions_source %s[%PARTSEP%]%s[%PARTSEP%]" $var ${functions_source[$var]}
done
for var in "${(@k)dis_functions_source}"; do
printf "dis_functions_source %s[%PARTSEP%]%s[%PARTSEP%]" $var ${dis_functions_source[$var]}
done
printf "[%SECTIONSEP%]";
[%GITBRANCH%]
`
cmd = strings.TrimSpace(cmd)
cmd = strings.ReplaceAll(cmd, "[%ZSHVERSION%]", ZshShellVersionCmdStr)
cmd = strings.ReplaceAll(cmd, "[%GITBRANCH%]", GetGitBranchCmdStr)
cmd = strings.ReplaceAll(cmd, "[%PARTSEP%]", utilfn.ShellHexEscape(string(sectionSeparator[0:len(sectionSeparator)-1])))
cmd = strings.ReplaceAll(cmd, "[%SECTIONSEP%]", utilfn.ShellHexEscape(string(sectionSeparator)))
cmd = strings.ReplaceAll(cmd, "[%OUTPUTFD%]", fmt.Sprintf("/dev/fd/%d", fdNum))
return cmd
}
func MakeZshExitTrap(fdNum int) string {
stateCmd := GetZshShellStateCmd(fdNum)
fmtStr := `
zshexit () {
%s
}
`
return fmt.Sprintf(fmtStr, stateCmd)
}
func execGetLocalZshShellVersion() string {
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
defer cancelFn()
ecmd := exec.CommandContext(ctx, "zsh", "-c", ZshShellVersionCmdStr)
out, err := ecmd.Output()
if err != nil {
return ""
}
versionStr := strings.TrimSpace(string(out))
if strings.Index(versionStr, "zsh ") == -1 {
return ""
}
return versionStr
}
func GetLocalZshMajorVersion() string {
localZshMajorVersionOnce.Do(func() {
fullVersion := execGetLocalZshShellVersion()
localZshMajorVersion = packet.GetMajorVersion(fullVersion)
})
return localZshMajorVersion
}
func EncodeZshMap(m ZshMap) []byte {
var buf bytes.Buffer
binpack.PackUInt(&buf, uint64(len(m)))
orderedKeys := utilfn.GetOrderedStringerMapKeys(m)
for _, key := range orderedKeys {
value := m[key]
binpack.PackValue(&buf, []byte(key.String()))
binpack.PackValue(&buf, []byte(value))
}
return buf.Bytes()
}
func EncodeZshMapForApply(m map[string][]byte) string {
var buf bytes.Buffer
binpack.PackUInt(&buf, uint64(len(m)))
orderedKeys := utilfn.GetOrderedMapKeys(m)
for _, key := range orderedKeys {
value := m[key]
binpack.PackValue(&buf, []byte(key))
binpack.PackValue(&buf, value)
}
return buf.String()
}
func DecodeZshMapForDiff(barr []byte) (map[string][]byte, error) {
rtn := make(map[string][]byte)
buf := bytes.NewBuffer(barr)
u := binpack.MakeUnpacker(buf)
numEntries := u.UnpackUInt("numEntries")
for idx := 0; idx < numEntries; idx++ {
key := string(u.UnpackValue("key"))
value := u.UnpackValue("value")
rtn[key] = value
}
if u.Error() != nil {
return nil, u.Error()
}
return rtn, nil
}
func DecodeZshMap(barr []byte) (ZshMap, error) {
rtn := make(ZshMap)
buf := bytes.NewBuffer(barr)
u := binpack.MakeUnpacker(buf)
numEntries := u.UnpackUInt("numEntries")
for idx := 0; idx < numEntries; idx++ {
key := string(u.UnpackValue("key"))
value := string(u.UnpackValue("value"))
zshKey, err := ZshParamKeyFromString(key)
if err != nil {
return nil, err
}
rtn[zshKey] = value
}
if u.Error() != nil {
return nil, u.Error()
}
return rtn, nil
}
func parseZshAliasStateOutput(aliasBytes []byte, partSeparator []byte) map[ZshParamKey]string {
aliasParts := bytes.Split(aliasBytes, partSeparator)
rtn := make(map[ZshParamKey]string)
for aliasPartIdx := 0; aliasPartIdx < len(aliasParts)-1; aliasPartIdx += 2 {
aliasNameAndType := string(aliasParts[aliasPartIdx])
aliasNameAndTypeParts := strings.SplitN(aliasNameAndType, " ", 2)
if len(aliasNameAndTypeParts) != 2 {
continue
}
aliasKey := ZshParamKey{ParamType: aliasNameAndTypeParts[0], ParamName: aliasNameAndTypeParts[1]}
aliasValue := string(aliasParts[aliasPartIdx+1])
rtn[aliasKey] = aliasValue
}
return rtn
}
func isSourceFileInFpath(fpathArr []string, sourceFile string) bool {
for _, fpath := range fpathArr {
if fpath == "" || fpath == "." {
continue
}
firstChar := fpath[0]
if firstChar != '/' && firstChar != '~' {
continue
}
if strings.HasPrefix(sourceFile, fpath) {
return true
}
}
return false
}
func ParseZshFunctions(fpathArr []string, fnBytes []byte, partSeparator []byte) map[ZshParamKey]string {
fnBody := make(map[ZshParamKey]string)
fnSource := make(map[string]string)
fnParts := bytes.Split(fnBytes, partSeparator)
for fnPartIdx := 0; fnPartIdx < len(fnParts)-1; fnPartIdx += 2 {
fnTypeAndName := string(fnParts[fnPartIdx])
fnValue := string(fnParts[fnPartIdx+1])
fnTypeAndNameParts := strings.SplitN(fnTypeAndName, " ", 2)
if len(fnTypeAndNameParts) != 2 {
continue
}
fnType := fnTypeAndNameParts[0]
fnName := fnTypeAndNameParts[1]
if fnName == "zshexit" {
continue
}
if fnType == "functions" || fnType == "dis_functions" {
fnBody[ZshParamKey{ParamType: fnType, ParamName: fnName}] = fnValue
}
if fnType == "functions_source" || fnType == "dis_functions_source" {
fnSource[fnName] = fnValue
}
}
// ok, so the trick here is that we want to only include functions that are *not* autoloaded
// the ones that are pending autoloading or come from a source file in fpath, can just be set to autoload
for fnKey := range fnBody {
source := fnSource[fnKey.ParamName]
if isSourceFileInFpath(fpathArr, source) {
fnBody[fnKey] = ZshFnAutoLoad
}
}
return fnBody
}
func makeZshFuncsStrForShellState(fnMap map[ZshParamKey]string) string {
var buf bytes.Buffer
for fnKey, fnValue := range fnMap {
buf.WriteString(fmt.Sprintf("%s %s %s\x00", fnKey.ParamType, fnKey.ParamName, fnValue))
}
return buf.String()
}
func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, error) {
if scbase.IsDevMode() && DebugState {
writeStateToFile(packet.ShellType_zsh, outputBytes)
}
firstZeroIdx := bytes.Index(outputBytes, []byte{0})
firstDZeroIdx := bytes.Index(outputBytes, []byte{0, 0})
if firstZeroIdx == -1 || firstDZeroIdx == -1 {
return nil, fmt.Errorf("invalid zsh shell state output, could not parse separator bytes")
}
versionStr := string(outputBytes[0:firstZeroIdx])
sectionSeparator := outputBytes[firstZeroIdx+1 : firstDZeroIdx+2]
partSeparator := sectionSeparator[0 : len(sectionSeparator)-1]
// 8 fields: version [0], cwd [1], env [2], vars [3], aliases [4], fpath [5], functions [6], pvars [7]
fields := bytes.Split(outputBytes, sectionSeparator)
if len(fields) != 8 {
base.Logf("invalid -- numfields\n")
return nil, fmt.Errorf("invalid zsh shell state output, wrong number of fields, fields=%d", len(fields))
}
rtn := &packet.ShellState{}
rtn.Version = strings.TrimSpace(versionStr)
if rtn.GetShellType() != packet.ShellType_zsh {
return nil, fmt.Errorf("invalid zsh shell state output, wrong shell type")
}
if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil {
return nil, fmt.Errorf("invalid zsh shell state output, invalid version: %v", err)
}
cwdStr := stripNewLineChars(string(fields[1]))
rtn.Cwd = cwdStr
zshEnv := parseZshEnv(fields[2])
zshDecls, err := parseZshDecls(fields[3])
if err != nil {
base.Logf("invalid - parsedecls %v\n", err)
return nil, err
}
for _, decl := range zshDecls {
if decl.IsZshScalarBound() {
decl.ZshEnvValue = zshEnv[decl.ZshBoundScalar]
}
}
aliasMap := parseZshAliasStateOutput(fields[4], partSeparator)
rtn.Aliases = string(EncodeZshMap(aliasMap))
fpathStr := stripNewLineChars(string(string(fields[5])))
fpathArr := strings.Split(fpathStr, ":")
zshFuncs := ParseZshFunctions(fpathArr, fields[6], partSeparator)
rtn.Funcs = string(EncodeZshMap(zshFuncs))
pvarMap := parsePVarOutput(fields[7], true)
utilfn.CombineMaps(zshDecls, pvarMap)
rtn.ShellVars = shellenv.SerializeDeclMap(zshDecls)
base.Logf("parse shellstate done\n")
return rtn, nil
}
func parseZshEnv(output []byte) map[string]string {
outputStr := string(output)
lines := strings.Split(outputStr, "\x00")
rtn := make(map[string]string)
for _, line := range lines {
if line == "" {
continue
}
eqIdx := strings.Index(line, "=")
if eqIdx == -1 {
continue
}
name := line[0:eqIdx]
if ZshIgnoreVars[name] {
continue
}
val := line[eqIdx+1:]
rtn[name] = val
}
return rtn
}
func parseZshScalarBoundAssignment(declStr string, decl *DeclareDeclType) error {
declStr = strings.TrimLeft(declStr, " ")
spaceIdx := strings.Index(declStr, " ")
if spaceIdx == -1 {
return fmt.Errorf("invalid zsh decl (scalar bound): %q", declStr)
}
decl.ZshBoundScalar = declStr[0:spaceIdx]
standardDecl := declStr[spaceIdx+1:]
return parseStandardZshAssignment(standardDecl, decl)
}
func parseStandardZshAssignment(declStr string, decl *DeclareDeclType) error {
declStr = strings.TrimLeft(declStr, " ")
eqIdx := strings.Index(declStr, "=")
if eqIdx == -1 {
return fmt.Errorf("invalid zsh decl: %q", declStr)
}
decl.Name = declStr[0:eqIdx]
decl.Value = declStr[eqIdx+1:]
return nil
}
func parseZshDeclAssignment(declStr string, decl *DeclareDeclType) error {
if decl.IsZshScalarBound() {
return parseZshScalarBoundAssignment(declStr, decl)
}
return parseStandardZshAssignment(declStr, decl)
}
// returns (newDeclStr, argsStr, err)
func parseZshDeclArgs(declStr string, isExport bool) (string, string, error) {
origDeclStr := declStr
var argsStr string
if isExport {
argsStr = "x"
}
declStr = strings.TrimLeft(declStr, " ")
for strings.HasPrefix(declStr, "-") {
spaceIdx := strings.Index(declStr, " ")
if spaceIdx == -1 {
return "", "", fmt.Errorf("invalid zsh export line: %q", origDeclStr)
}
newArgsStr := strings.TrimSpace(declStr[1:spaceIdx])
argsStr = argsStr + newArgsStr
declStr = declStr[spaceIdx+1:]
declStr = strings.TrimLeft(declStr, " ")
}
return declStr, argsStr, nil
}
func stripNewLineChars(s string) string {
for {
if len(s) == 0 {
return s
}
lastChar := s[len(s)-1]
if lastChar == '\n' || lastChar == '\r' {
s = s[0 : len(s)-1]
} else {
return s
}
}
}
func parseZshDeclLine(line string) (*DeclareDeclType, error) {
line = stripNewLineChars(line)
if strings.HasPrefix(line, "export ") {
exportLine := line[7:]
assignLine, exportArgs, err := parseZshDeclArgs(exportLine, true)
rtn := &DeclareDeclType{IsZshDecl: true, Args: exportArgs}
err = parseZshDeclAssignment(assignLine, rtn)
if err != nil {
return nil, err
}
return rtn, nil
} else if strings.HasPrefix(line, "typeset ") {
typesetLine := line[8:]
assignLine, typesetArgs, err := parseZshDeclArgs(typesetLine, false)
rtn := &DeclareDeclType{IsZshDecl: true, Args: typesetArgs}
err = parseZshDeclAssignment(assignLine, rtn)
if err != nil {
return nil, err
}
return rtn, nil
} else {
return nil, fmt.Errorf("invalid zsh decl line: %q", line)
}
}
// combine decl2 INTO decl1
func combineTiedZshDecls(decl1 *DeclareDeclType, decl2 *DeclareDeclType) {
if decl2.IsExport() {
decl1.AddFlag("x")
}
if decl2.IsArray() {
decl1.AddFlag("a")
}
}
func parseZshDecls(output []byte) (map[string]*DeclareDeclType, error) {
// NOTES:
// - we get extra \r characters in the output (trimmed in parseZshDeclLine) (we get \r\n)
// - tied variables (-T) are printed twice! this is especially confusing for exported vars:
// (1) `export -T PATH path=( ... )`
// (2) `typeset -aT PATH path=( ... )`
// we have to "combine" these two lines into one decl.
outputStr := string(output)
lines := strings.Split(outputStr, "\n")
rtn := make(map[string]*DeclareDeclType)
for _, line := range lines {
if line == "" {
continue
}
decl, err := parseZshDeclLine(line)
if err != nil {
base.Logf("error parsing zsh decl line: %v", err)
continue
}
if decl == nil {
continue
}
if ZshIgnoreVars[decl.Name] {
continue
}
if rtn[decl.Name] != nil && decl.IsZshScalarBound() {
combineTiedZshDecls(rtn[decl.Name], decl)
continue
}
rtn[decl.Name] = decl
}
return rtn, nil
}
func makeZshMapDiff(oldMap string, newMap string) ([]byte, error) {
oldMapMap, err := DecodeZshMapForDiff([]byte(oldMap))
if err != nil {
return nil, fmt.Errorf("error zshMapDiff decoding old-zsh map: %v", err)
}
newMapMap, err := DecodeZshMapForDiff([]byte(newMap))
if err != nil {
return nil, fmt.Errorf("error zshMapDiff decoding new-zsh map: %v", err)
}
return statediff.MakeMapDiff(oldMapMap, newMapMap), nil
}
func applyZshMapDiff(oldMap string, diff []byte) (string, error) {
oldMapMap, err := DecodeZshMapForDiff([]byte(oldMap))
if err != nil {
return "", fmt.Errorf("error zshMapDiff decoding old-zsh map: %v", err)
}
newMapMap, err := statediff.ApplyMapDiff(oldMapMap, diff)
if err != nil {
return "", fmt.Errorf("error zshMapDiff applying diff: %v", err)
}
return EncodeZshMapForApply(newMapMap), nil
}
func (zshShellApi) MakeShellStateDiff(oldState *packet.ShellState, oldStateHash string, newState *packet.ShellState) (*packet.ShellStateDiff, error) {
if oldState == nil {
return nil, fmt.Errorf("cannot diff, oldState is nil")
}
if newState == nil {
return nil, fmt.Errorf("cannot diff, newState is nil")
}
if oldState.Version != newState.Version {
return nil, fmt.Errorf("cannot diff, states have different versions")
}
rtn := &packet.ShellStateDiff{}
rtn.BaseHash = oldStateHash
rtn.Version = newState.Version // always set version
if oldState.Cwd != newState.Cwd {
rtn.Cwd = newState.Cwd
}
rtn.Error = newState.Error
oldVars := shellenv.ShellStateVarsToMap(oldState.ShellVars)
newVars := shellenv.ShellStateVarsToMap(newState.ShellVars)
rtn.VarsDiff = statediff.MakeMapDiff(oldVars, newVars)
var err error
rtn.AliasesDiff, err = makeZshMapDiff(oldState.Aliases, newState.Aliases)
if err != nil {
return nil, err
}
rtn.FuncsDiff, err = makeZshMapDiff(oldState.Funcs, newState.Funcs)
if err != nil {
return nil, err
}
return rtn, nil
}
func (zshShellApi) ApplyShellStateDiff(oldState *packet.ShellState, diff *packet.ShellStateDiff) (*packet.ShellState, error) {
if oldState == nil {
return nil, fmt.Errorf("cannot apply diff, oldState is nil")
}
if diff == nil {
return oldState, nil
}
rtnState := &packet.ShellState{}
var err error
rtnState.Version = oldState.Version
if diff.Version != rtnState.Version {
rtnState.Version = diff.Version
}
rtnState.Cwd = oldState.Cwd
if diff.Cwd != "" {
rtnState.Cwd = diff.Cwd
}
rtnState.Error = diff.Error
oldVars := shellenv.ShellStateVarsToMap(oldState.ShellVars)
newVars, err := statediff.ApplyMapDiff(oldVars, diff.VarsDiff)
if err != nil {
return nil, fmt.Errorf("applying mapdiff 'vars': %v", err)
}
rtnState.ShellVars = shellenv.StrMapToShellStateVars(newVars)
rtnState.Aliases, err = applyZshMapDiff(oldState.Aliases, diff.AliasesDiff)
if err != nil {
return nil, fmt.Errorf("applying diff 'aliases': %v", err)
}
rtnState.Funcs, err = applyZshMapDiff(oldState.Funcs, diff.FuncsDiff)
if err != nil {
return nil, fmt.Errorf("applying diff 'funcs': %v", err)
}
return rtnState, nil
}

View File

@ -0,0 +1,29 @@
package shellapi
import (
"fmt"
"testing"
)
func testSingleDecl(declStr string) {
decl, err := parseZshDeclLine(declStr)
if err != nil {
fmt.Printf("error: %v\n", err)
}
fmt.Printf("decl %#v\n", decl)
}
func TestParseZshDecl(t *testing.T) {
declStr := `export -T PATH path=( /usr/local/bin /usr/bin /bin /usr/sbin /sbin )`
testSingleDecl(declStr)
declStr = `typeset -i10 SAVEHIST=1000`
testSingleDecl(declStr)
declStr = `typeset -a signals=( EXIT HUP INT QUIT ILL TRAP ABRT EMT FPE KILL BUS SEGV )`
testSingleDecl(declStr)
declStr = `typeset -aT RC rc=(80 25) 'x'`
testSingleDecl(declStr)
declStr = `typeset -g -A foo=( [bar]=baz [quux]=quuux )`
testSingleDecl(declStr)
declStr = `typeset -x -g -aT FOO foo=( 1 2 3 )`
testSingleDecl(declStr)
}

View File

@ -0,0 +1,362 @@
// Copyright 2023, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package shellenv
import (
"bytes"
"fmt"
"strings"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/simpleexpand"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
)
const (
DeclTypeArray = "array"
DeclTypeAssocArray = "assoc"
DeclTypeInt = "int"
DeclTypeNormal = "normal"
)
type DeclareDeclType struct {
IsZshDecl bool
Args string
Name string
// this holds the raw quoted value suitable for bash. this is *not* the real expanded variable value
Value string
// special fields for zsh "-T" output.
// for bound scalars, "Value" hold everything after the "=" (including the separator character)
ZshBoundScalar string // the name of the "scalar" env variable
ZshEnvValue string // unlike Value this *is* the expanded value of scalar env variable
}
func (d *DeclareDeclType) IsExport() bool {
return strings.Index(d.Args, "x") >= 0
}
func (d *DeclareDeclType) IsReadOnly() bool {
return strings.Index(d.Args, "r") >= 0
}
func (d *DeclareDeclType) IsZshScalarBound() bool {
return strings.Index(d.Args, "T") >= 0
}
func (d *DeclareDeclType) IsArray() bool {
return strings.Index(d.Args, "a") >= 0
}
func (d *DeclareDeclType) IsAssocArray() bool {
return strings.Index(d.Args, "A") >= 0
}
func (d *DeclareDeclType) IsUniqueArray() bool {
return d.IsArray() && strings.Index(d.Args, "U") >= 0
}
func (d *DeclareDeclType) AddFlag(flag string) {
if strings.Index(d.Args, flag) >= 0 {
return
}
d.Args += flag
}
func (d *DeclareDeclType) SortZshFlags() {
// x is always first (or g)
// T is always last
// the 'i' flags are tricky (they shouldn't be sorted, because the order matters, e.g. i10)
var hasX, hasG, hasT bool
var newArgs []rune
for _, r := range d.Args {
if r == 'x' {
hasX = true
continue
}
if r == 'g' {
hasG = true
continue
}
if r == 'T' {
hasT = true
continue
}
newArgs = append(newArgs, r)
}
newArgsStr := string(newArgs)
if hasG {
newArgsStr = "g" + newArgsStr
}
if hasX {
newArgsStr = "x" + newArgsStr
}
if hasT {
newArgsStr += "T"
}
d.Args = newArgsStr
}
func (d *DeclareDeclType) DataType() string {
if strings.Index(d.Args, "a") >= 0 {
return DeclTypeArray
}
if strings.Index(d.Args, "A") >= 0 {
return DeclTypeAssocArray
}
if strings.Index(d.Args, "i") >= 0 {
return DeclTypeInt
}
return DeclTypeNormal
}
func FindVarDecl(decls []*DeclareDeclType, name string) *DeclareDeclType {
for _, decl := range decls {
if decl.Name == name {
return decl
}
}
return nil
}
// NOTE Serialize no longer writes the final null byte
func (d *DeclareDeclType) Serialize() []byte {
if d.IsZshDecl {
d.SortZshFlags()
parts := []string{
"z1",
d.Args,
d.Name,
d.Value,
d.ZshBoundScalar,
d.ZshEnvValue,
}
return utilfn.EncodeStringArray(parts)
} else {
parts := []string{
"b1",
d.Args,
d.Name,
d.Value,
}
return utilfn.EncodeStringArray(parts)
}
// this is the v0 encoding (keeping here for reference since we still need to decode this)
// rtn := fmt.Sprintf("%s|%s=%s\x00", d.Args, d.Name, d.Value)
// return []byte(rtn)
}
func DeclsEqual(compareName bool, d1 *DeclareDeclType, d2 *DeclareDeclType) bool {
if d1.IsExport() != d2.IsExport() {
return false
}
if d1.DataType() != d2.DataType() {
return false
}
if compareName && d1.Name != d2.Name {
return false
}
return d1.Value == d2.Value // this works even for assoc arrays because we normalize them when parsing
}
// envline should be valid
func parseDeclLine(envLineBytes []byte) *DeclareDeclType {
if utilfn.EncodedStringArrayHasFirstKey(envLineBytes, "z1") {
parts, err := utilfn.DecodeStringArray(envLineBytes)
if err != nil {
return nil
}
if len(parts) != 6 {
return nil
}
return &DeclareDeclType{
IsZshDecl: true,
Args: parts[1],
Name: parts[2],
Value: parts[3],
ZshBoundScalar: parts[4],
ZshEnvValue: parts[5],
}
} else if utilfn.EncodedStringArrayHasFirstKey(envLineBytes, "b1") {
parts, err := utilfn.DecodeStringArray(envLineBytes)
if err != nil {
return nil
}
if len(parts) != 4 {
return nil
}
return &DeclareDeclType{
Args: parts[1],
Name: parts[2],
Value: parts[3],
}
}
// legacy decoding (v0)
envLine := string(envLineBytes)
eqIdx := strings.Index(envLine, "=")
if eqIdx == -1 {
return nil
}
namePart := envLine[0:eqIdx]
valPart := envLine[eqIdx+1:]
pipeIdx := strings.Index(namePart, "|")
if pipeIdx == -1 {
return nil
}
return &DeclareDeclType{
Args: namePart[0:pipeIdx],
Name: namePart[pipeIdx+1:],
Value: valPart,
}
}
// returns name => full-line
func parseDeclLineToKV(envLine []byte) (string, []byte) {
decl := parseDeclLine(envLine)
if decl == nil {
return "", nil
}
return decl.Name, envLine
}
func ShellStateVarsToMap(shellVars []byte) map[string][]byte {
if len(shellVars) == 0 {
return nil
}
rtn := make(map[string][]byte)
vars := bytes.Split(shellVars, []byte{0})
for _, varLine := range vars {
name, val := parseDeclLineToKV(varLine)
if name == "" {
continue
}
rtn[name] = val
}
return rtn
}
func StrMapToShellStateVars(varMap map[string][]byte) []byte {
var buf bytes.Buffer
orderedKeys := utilfn.GetOrderedMapKeys(varMap)
for _, key := range orderedKeys {
val := varMap[key]
buf.Write(val)
buf.WriteByte(0)
}
return buf.Bytes()
}
func DeclMapFromState(state *packet.ShellState) map[string]*DeclareDeclType {
if state == nil {
return nil
}
rtn := make(map[string]*DeclareDeclType)
vars := bytes.Split(state.ShellVars, []byte{0})
for _, varLine := range vars {
decl := parseDeclLine(varLine)
if decl != nil {
rtn[decl.Name] = decl
}
}
return rtn
}
func SerializeDeclMap(declMap map[string]*DeclareDeclType) []byte {
var rtn bytes.Buffer
orderedKeys := utilfn.GetOrderedMapKeys(declMap)
for _, key := range orderedKeys {
decl := declMap[key]
rtn.Write(decl.Serialize())
rtn.WriteByte(0)
}
return rtn.Bytes()
}
func EnvMapFromState(state *packet.ShellState) map[string]string {
if state == nil {
return nil
}
rtn := make(map[string]string)
ectx := simpleexpand.SimpleExpandContext{}
vars := bytes.Split(state.ShellVars, []byte{0})
for _, varLine := range vars {
decl := parseDeclLine(varLine)
if decl != nil && decl.IsExport() {
rtn[decl.Name], _ = simpleexpand.SimpleExpandPartialWord(ectx, decl.Value, false)
}
}
return rtn
}
func ShellVarMapFromState(state *packet.ShellState) map[string]string {
if state == nil {
return nil
}
rtn := make(map[string]string)
ectx := simpleexpand.SimpleExpandContext{}
vars := bytes.Split(state.ShellVars, []byte{0})
for _, varLine := range vars {
decl := parseDeclLine(varLine)
if decl != nil {
rtn[decl.Name], _ = simpleexpand.SimpleExpandPartialWord(ectx, decl.Value, false)
}
}
return rtn
}
func DumpVarMapFromState(state *packet.ShellState) {
fmt.Printf("DUMP-STATE-VARS:\n")
if state == nil {
fmt.Printf(" nil\n")
return
}
decls := VarDeclsFromState(state)
for _, decl := range decls {
fmt.Printf(" %s %#v\n", decl.Name, decl)
}
envMap := EnvMapFromState(state)
fmt.Printf("DUMP-STATE-ENV:\n")
for k, v := range envMap {
fmt.Printf(" %s=%s\n", k, v)
}
fmt.Printf("\n\n")
}
func VarDeclsFromState(state *packet.ShellState) []*DeclareDeclType {
if state == nil {
return nil
}
var rtn []*DeclareDeclType
vars := bytes.Split(state.ShellVars, []byte{0})
for _, varLine := range vars {
decl := parseDeclLine(varLine)
if decl != nil {
rtn = append(rtn, decl)
}
}
return rtn
}
func RemoveFunc(funcs string, toRemove string) string {
lines := strings.Split(funcs, "\n")
var newLines []string
removeLine := fmt.Sprintf("%s ()", toRemove)
doingRemove := false
for _, line := range lines {
if line == removeLine {
doingRemove = true
continue
}
if doingRemove {
if line == "}" {
doingRemove = false
}
continue
}
newLines = append(newLines, line)
}
return strings.Join(newLines, "\n")
}

View File

@ -0,0 +1,62 @@
// Copyright 2023, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package shellutil
import (
"os"
"os/exec"
"strings"
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
)
const DefaultTermType = "xterm-256color"
const DefaultTermRows = 24
const DefaultTermCols = 80
func MShellEnvVars(termType string) map[string]string {
rtn := make(map[string]string)
if termType != "" {
rtn["TERM"] = termType
}
rtn["WAVESHELL"], _ = os.Executable()
rtn["WAVESHELL_VERSION"] = base.MShellVersion
return rtn
}
func UpdateCmdEnv(cmd *exec.Cmd, envVars map[string]string) {
if len(envVars) == 0 {
return
}
found := make(map[string]bool)
var newEnv []string
for _, envStr := range cmd.Env {
envKey := GetEnvStrKey(envStr)
newEnvVal, ok := envVars[envKey]
if ok {
if newEnvVal == "" {
continue
}
newEnv = append(newEnv, envKey+"="+newEnvVal)
found[envKey] = true
} else {
newEnv = append(newEnv, envStr)
}
}
for envKey, envVal := range envVars {
if found[envKey] {
continue
}
newEnv = append(newEnv, envKey+"="+envVal)
}
cmd.Env = newEnv
}
func GetEnvStrKey(envStr string) string {
eqIdx := strings.Index(envStr, "=")
if eqIdx == -1 {
return envStr
}
return envStr[0:eqIdx]
}

View File

@ -1,641 +0,0 @@
// Copyright 2023, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package shexec
import (
"bytes"
"fmt"
"io"
"regexp"
"sort"
"strings"
"github.com/alessio/shellescape"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/simpleexpand"
"github.com/wavetermdev/waveterm/waveshell/pkg/statediff"
"mvdan.cc/sh/v3/expand"
"mvdan.cc/sh/v3/syntax"
)
const (
DeclTypeArray = "array"
DeclTypeAssocArray = "assoc"
DeclTypeInt = "int"
DeclTypeNormal = "normal"
)
type ParseEnviron struct {
Env map[string]string
}
func (e *ParseEnviron) Get(name string) expand.Variable {
val, ok := e.Env[name]
if !ok {
return expand.Variable{}
}
return expand.Variable{
Exported: true,
Kind: expand.String,
Str: val,
}
}
func (e *ParseEnviron) Each(fn func(name string, vr expand.Variable) bool) {
for key := range e.Env {
rtn := fn(key, e.Get(key))
if !rtn {
break
}
}
}
func doCmdSubst(commandStr string, w io.Writer, word *syntax.CmdSubst) error {
return nil
}
func doProcSubst(w *syntax.ProcSubst) (string, error) {
return "", nil
}
func GetParserConfig(envMap map[string]string) *expand.Config {
cfg := &expand.Config{
Env: &ParseEnviron{Env: envMap},
GlobStar: false,
NullGlob: false,
NoUnset: false,
CmdSubst: func(w io.Writer, word *syntax.CmdSubst) error { return doCmdSubst("", w, word) },
ProcSubst: doProcSubst,
ReadDir: nil,
}
return cfg
}
func writeIndent(buf *bytes.Buffer, num int) {
for i := 0; i < num; i++ {
buf.WriteByte(' ')
}
}
func makeSpaceStr(num int) string {
barr := make([]byte, num)
for i := 0; i < num; i++ {
barr[i] = ' '
}
return string(barr)
}
// https://wiki.bash-hackers.org/syntax/shellvars
var NoStoreVarNames = map[string]bool{
"BASH": true,
"BASHOPTS": true,
"BASHPID": true,
"BASH_ALIASES": true,
"BASH_ARGC": true,
"BASH_ARGV": true,
"BASH_ARGV0": true,
"BASH_CMDS": true,
"BASH_COMMAND": true,
"BASH_EXECUTION_STRING": true,
"LINENO": true,
"BASH_LINENO": true,
"BASH_REMATCH": true,
"BASH_SOURCE": true,
"BASH_SUBSHELL": true,
"COPROC": true,
"DIRSTACK": true,
"EPOCHREALTIME": true,
"EPOCHSECONDS": true,
"FUNCNAME": true,
"HISTCMD": true,
"OLDPWD": true,
"PIPESTATUS": true,
"PPID": true,
"PWD": true,
"RANDOM": true,
"SECONDS": true,
"SHLVL": true,
"HISTFILE": true,
"HISTFILESIZE": true,
"HISTCONTROL": true,
"HISTIGNORE": true,
"HISTSIZE": true,
"HISTTIMEFORMAT": true,
"SRANDOM": true,
"COLUMNS": true,
"LINES": true,
// we want these in our remote state object
// "EUID": true,
// "SHELLOPTS": true,
// "UID": true,
// "BASH_VERSINFO": true,
// "BASH_VERSION": true,
}
type DeclareDeclType struct {
Args string
Name string
// this holds the raw quoted value suitable for bash. this is *not* the real expanded variable value
Value string
}
var declareDeclArgsRe = regexp.MustCompile("^[aAxrifx]*$")
var bashValidIdentifierRe = regexp.MustCompile("^[a-zA-Z_][a-zA-Z0-9_]*$")
func (d *DeclareDeclType) Validate() error {
if len(d.Name) == 0 || !IsValidBashIdentifier(d.Name) {
return fmt.Errorf("invalid shell variable name (invalid bash identifier)")
}
if strings.Index(d.Value, "\x00") >= 0 {
return fmt.Errorf("invalid shell variable value (cannot contain 0 byte)")
}
if !declareDeclArgsRe.MatchString(d.Args) {
return fmt.Errorf("invalid shell variable type %s", shellescape.Quote(d.Args))
}
return nil
}
func (d *DeclareDeclType) Serialize() string {
return fmt.Sprintf("%s|%s=%s\x00", d.Args, d.Name, d.Value)
}
func (d *DeclareDeclType) DeclareStmt() string {
var argsStr string
if d.Args == "" {
argsStr = "--"
} else {
argsStr = "-" + d.Args
}
return fmt.Sprintf("declare %s %s=%s", argsStr, d.Name, d.Value)
}
// envline should be valid
func ParseDeclLine(envLine string) *DeclareDeclType {
eqIdx := strings.Index(envLine, "=")
if eqIdx == -1 {
return nil
}
namePart := envLine[0:eqIdx]
valPart := envLine[eqIdx+1:]
pipeIdx := strings.Index(namePart, "|")
if pipeIdx == -1 {
return nil
}
return &DeclareDeclType{
Args: namePart[0:pipeIdx],
Name: namePart[pipeIdx+1:],
Value: valPart,
}
}
// returns name => full-line
func parseDeclLineToKV(envLine string) (string, string) {
decl := ParseDeclLine(envLine)
if decl == nil {
return "", ""
}
return decl.Name, envLine
}
func shellStateVarsToMap(shellVars []byte) map[string]string {
if len(shellVars) == 0 {
return nil
}
rtn := make(map[string]string)
vars := bytes.Split(shellVars, []byte{0})
for _, varLine := range vars {
name, val := parseDeclLineToKV(string(varLine))
if name == "" {
continue
}
rtn[name] = val
}
return rtn
}
func strMapToShellStateVars(varMap map[string]string) []byte {
var buf bytes.Buffer
orderedKeys := getOrderedKeysStrMap(varMap)
for _, key := range orderedKeys {
val := varMap[key]
buf.WriteString(val)
buf.WriteByte(0)
}
return buf.Bytes()
}
func getOrderedKeysStrMap(m map[string]string) []string {
keys := make([]string, 0, len(m))
for key := range m {
keys = append(keys, key)
}
sort.Strings(keys)
return keys
}
func getOrderedKeysDeclMap(m map[string]*DeclareDeclType) []string {
keys := make([]string, 0, len(m))
for key := range m {
keys = append(keys, key)
}
sort.Strings(keys)
return keys
}
func DeclMapFromState(state *packet.ShellState) map[string]*DeclareDeclType {
if state == nil {
return nil
}
rtn := make(map[string]*DeclareDeclType)
vars := bytes.Split(state.ShellVars, []byte{0})
for _, varLine := range vars {
decl := ParseDeclLine(string(varLine))
if decl != nil {
rtn[decl.Name] = decl
}
}
return rtn
}
func SerializeDeclMap(declMap map[string]*DeclareDeclType) []byte {
var rtn bytes.Buffer
orderedKeys := getOrderedKeysDeclMap(declMap)
for _, key := range orderedKeys {
decl := declMap[key]
rtn.WriteString(decl.Serialize())
}
return rtn.Bytes()
}
func EnvMapFromState(state *packet.ShellState) map[string]string {
if state == nil {
return nil
}
rtn := make(map[string]string)
ectx := simpleexpand.SimpleExpandContext{}
vars := bytes.Split(state.ShellVars, []byte{0})
for _, varLine := range vars {
decl := ParseDeclLine(string(varLine))
if decl != nil && decl.IsExport() {
rtn[decl.Name], _ = simpleexpand.SimpleExpandPartialWord(ectx, decl.Value, false)
}
}
return rtn
}
func ShellVarMapFromState(state *packet.ShellState) map[string]string {
if state == nil {
return nil
}
rtn := make(map[string]string)
ectx := simpleexpand.SimpleExpandContext{}
vars := bytes.Split(state.ShellVars, []byte{0})
for _, varLine := range vars {
decl := ParseDeclLine(string(varLine))
if decl != nil {
rtn[decl.Name], _ = simpleexpand.SimpleExpandPartialWord(ectx, decl.Value, false)
}
}
return rtn
}
func DumpVarMapFromState(state *packet.ShellState) {
fmt.Printf("DUMP-STATE-VARS:\n")
if state == nil {
fmt.Printf(" nil\n")
return
}
vars := bytes.Split(state.ShellVars, []byte{0})
for _, varLine := range vars {
fmt.Printf(" %s\n", varLine)
}
}
func VarDeclsFromState(state *packet.ShellState) []*DeclareDeclType {
if state == nil {
return nil
}
var rtn []*DeclareDeclType
vars := bytes.Split(state.ShellVars, []byte{0})
for _, varLine := range vars {
decl := ParseDeclLine(string(varLine))
if decl != nil {
rtn = append(rtn, decl)
}
}
return rtn
}
func IsValidBashIdentifier(s string) bool {
return bashValidIdentifierRe.MatchString(s)
}
func (d *DeclareDeclType) IsExport() bool {
return strings.Index(d.Args, "x") >= 0
}
func (d *DeclareDeclType) IsReadOnly() bool {
return strings.Index(d.Args, "r") >= 0
}
func (d *DeclareDeclType) DataType() string {
if strings.Index(d.Args, "a") >= 0 {
return DeclTypeArray
}
if strings.Index(d.Args, "A") >= 0 {
return DeclTypeAssocArray
}
if strings.Index(d.Args, "i") >= 0 {
return DeclTypeInt
}
return DeclTypeNormal
}
func parseDeclareStmt(stmt *syntax.Stmt, src string) (*DeclareDeclType, error) {
cmd := stmt.Cmd
decl, ok := cmd.(*syntax.DeclClause)
if !ok || decl.Variant.Value != "declare" || len(decl.Args) != 2 {
return nil, fmt.Errorf("invalid declare variant")
}
rtn := &DeclareDeclType{}
declArgs := decl.Args[0]
if !declArgs.Naked || len(declArgs.Value.Parts) != 1 {
return nil, fmt.Errorf("wrong number of declare args parts")
}
declArgsLit, ok := declArgs.Value.Parts[0].(*syntax.Lit)
if !ok {
return nil, fmt.Errorf("declare args is not a literal")
}
if !strings.HasPrefix(declArgsLit.Value, "-") {
return nil, fmt.Errorf("declare args not an argument (does not start with '-')")
}
if declArgsLit.Value == "--" {
rtn.Args = ""
} else {
rtn.Args = declArgsLit.Value[1:]
}
declAssign := decl.Args[1]
if declAssign.Name == nil {
return nil, fmt.Errorf("declare does not have a valid name")
}
rtn.Name = declAssign.Name.Value
if declAssign.Naked || declAssign.Index != nil || declAssign.Append {
return nil, fmt.Errorf("invalid decl format")
}
if declAssign.Value != nil {
rtn.Value = string(src[declAssign.Value.Pos().Offset():declAssign.Value.End().Offset()])
} else if declAssign.Array != nil {
rtn.Value = string(src[declAssign.Array.Pos().Offset():declAssign.Array.End().Offset()])
} else {
return nil, fmt.Errorf("invalid decl, not plain value or array")
}
err := rtn.normalize()
if err != nil {
return nil, err
}
if err = rtn.Validate(); err != nil {
return nil, err
}
return rtn, nil
}
func parseDeclareOutput(state *packet.ShellState, declareBytes []byte, pvarBytes []byte) error {
declareStr := string(declareBytes)
r := bytes.NewReader(declareBytes)
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
file, err := parser.Parse(r, "aliases")
if err != nil {
return err
}
var firstParseErr error
declMap := make(map[string]*DeclareDeclType)
for _, stmt := range file.Stmts {
decl, err := parseDeclareStmt(stmt, declareStr)
if err != nil {
if firstParseErr == nil {
firstParseErr = err
}
}
if decl != nil && !NoStoreVarNames[decl.Name] {
declMap[decl.Name] = decl
}
}
pvars := bytes.Split(pvarBytes, []byte{0})
for _, pvarBA := range pvars {
pvarStr := string(pvarBA)
pvarFields := strings.SplitN(pvarStr, " ", 2)
if len(pvarFields) != 2 {
continue
}
if pvarFields[0] == "" {
continue
}
decl := &DeclareDeclType{Args: "x"}
decl.Name = "PROMPTVAR_" + pvarFields[0]
decl.Value = shellescape.Quote(pvarFields[1])
declMap[decl.Name] = decl
}
state.ShellVars = SerializeDeclMap(declMap) // this writes out the decls in a canonical order
if firstParseErr != nil {
state.Error = firstParseErr.Error()
}
return nil
}
func ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, error) {
// 5 fields: version, cwd, env/vars, aliases, funcs
fields := bytes.Split(outputBytes, []byte{0, 0})
if len(fields) != 6 {
return nil, fmt.Errorf("invalid shell state output, wrong number of fields, fields=%d", len(fields))
}
rtn := &packet.ShellState{}
rtn.Version = strings.TrimSpace(string(fields[0]))
if strings.Index(rtn.Version, "bash") == -1 {
return nil, fmt.Errorf("invalid shell state output, only bash is supported")
}
cwdStr := string(fields[1])
if strings.HasSuffix(cwdStr, "\r\n") {
cwdStr = cwdStr[0 : len(cwdStr)-2]
} else if strings.HasSuffix(cwdStr, "\n") {
cwdStr = cwdStr[0 : len(cwdStr)-1]
}
rtn.Cwd = string(cwdStr)
err := parseDeclareOutput(rtn, fields[2], fields[5])
if err != nil {
return nil, err
}
rtn.Aliases = strings.ReplaceAll(string(fields[3]), "\r\n", "\n")
rtn.Funcs = strings.ReplaceAll(string(fields[4]), "\r\n", "\n")
rtn.Funcs = removeFunc(rtn.Funcs, "_mshell_exittrap")
return rtn, nil
}
func removeFunc(funcs string, toRemove string) string {
lines := strings.Split(funcs, "\n")
var newLines []string
removeLine := fmt.Sprintf("%s ()", toRemove)
doingRemove := false
for _, line := range lines {
if line == removeLine {
doingRemove = true
continue
}
if doingRemove {
if line == "}" {
doingRemove = false
}
continue
}
newLines = append(newLines, line)
}
return strings.Join(newLines, "\n")
}
func (d *DeclareDeclType) normalize() error {
if d.DataType() == DeclTypeAssocArray {
return d.normalizeAssocArrayDecl()
}
return nil
}
// normalizes order of assoc array keys so value is stable
func (d *DeclareDeclType) normalizeAssocArrayDecl() error {
if d.DataType() != DeclTypeAssocArray {
return fmt.Errorf("invalid decltype passed to assocArrayDeclToStr: %s", d.DataType())
}
varMap, err := assocArrayVarToMap(d)
if err != nil {
return err
}
keys := make([]string, 0, len(varMap))
for key := range varMap {
keys = append(keys, key)
}
sort.Strings(keys)
var buf bytes.Buffer
buf.WriteByte('(')
for _, key := range keys {
buf.WriteByte('[')
buf.WriteString(key)
buf.WriteByte(']')
buf.WriteByte('=')
buf.WriteString(varMap[key])
buf.WriteByte(' ')
}
buf.WriteByte(')')
d.Value = buf.String()
return nil
}
func assocArrayVarToMap(d *DeclareDeclType) (map[string]string, error) {
if d.DataType() != DeclTypeAssocArray {
return nil, fmt.Errorf("decl is not an assoc-array")
}
refStr := "X=" + d.Value
r := strings.NewReader(refStr)
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
file, err := parser.Parse(r, "assocdecl")
if err != nil {
return nil, err
}
if len(file.Stmts) != 1 {
return nil, fmt.Errorf("invalid assoc-array parse (multiple stmts)")
}
stmt := file.Stmts[0]
callExpr, ok := stmt.Cmd.(*syntax.CallExpr)
if !ok || len(callExpr.Args) != 0 || len(callExpr.Assigns) != 1 {
return nil, fmt.Errorf("invalid assoc-array parse (bad expr)")
}
assign := callExpr.Assigns[0]
arrayExpr := assign.Array
if arrayExpr == nil {
return nil, fmt.Errorf("invalid assoc-array parse (no array expr)")
}
rtn := make(map[string]string)
for _, elem := range arrayExpr.Elems {
indexStr := refStr[elem.Index.Pos().Offset():elem.Index.End().Offset()]
valStr := refStr[elem.Value.Pos().Offset():elem.Value.End().Offset()]
rtn[indexStr] = valStr
}
return rtn, nil
}
func strMapsEqual(m1 map[string]string, m2 map[string]string) bool {
if len(m1) != len(m2) {
return false
}
for key, val1 := range m1 {
val2, found := m2[key]
if !found || val1 != val2 {
return false
}
}
for key := range m2 {
_, found := m1[key]
if !found {
return false
}
}
return true
}
func DeclsEqual(compareName bool, d1 *DeclareDeclType, d2 *DeclareDeclType) bool {
if d1.IsExport() != d2.IsExport() {
return false
}
if d1.DataType() != d2.DataType() {
return false
}
if compareName && d1.Name != d2.Name {
return false
}
return d1.Value == d2.Value // this works even for assoc arrays because we normalize them when parsing
}
func MakeShellStateDiff(oldState packet.ShellState, oldStateHash string, newState packet.ShellState) (packet.ShellStateDiff, error) {
var rtn packet.ShellStateDiff
rtn.BaseHash = oldStateHash
if oldState.Version != newState.Version {
return rtn, fmt.Errorf("cannot diff, states have different versions")
}
rtn.Version = newState.Version
if oldState.Cwd != newState.Cwd {
rtn.Cwd = newState.Cwd
}
rtn.Error = newState.Error
oldVars := shellStateVarsToMap(oldState.ShellVars)
newVars := shellStateVarsToMap(newState.ShellVars)
rtn.VarsDiff = statediff.MakeMapDiff(oldVars, newVars)
rtn.AliasesDiff = statediff.MakeLineDiff(oldState.Aliases, newState.Aliases)
rtn.FuncsDiff = statediff.MakeLineDiff(oldState.Funcs, newState.Funcs)
return rtn, nil
}
func ApplyShellStateDiff(oldState packet.ShellState, diff packet.ShellStateDiff) (packet.ShellState, error) {
var rtnState packet.ShellState
var err error
rtnState.Version = oldState.Version
rtnState.Cwd = oldState.Cwd
if diff.Cwd != "" {
rtnState.Cwd = diff.Cwd
}
rtnState.Error = diff.Error
oldVars := shellStateVarsToMap(oldState.ShellVars)
newVars, err := statediff.ApplyMapDiff(oldVars, diff.VarsDiff)
if err != nil {
return rtnState, fmt.Errorf("applying mapdiff 'vars': %v", err)
}
rtnState.ShellVars = strMapToShellStateVars(newVars)
rtnState.Aliases, err = statediff.ApplyLineDiff(oldState.Aliases, diff.AliasesDiff)
if err != nil {
return rtnState, fmt.Errorf("applying diff 'aliases': %v", err)
}
rtnState.Funcs, err = statediff.ApplyLineDiff(oldState.Funcs, diff.FuncsDiff)
if err != nil {
return rtnState, fmt.Errorf("applying diff 'funcs': %v", err)
}
return rtnState, nil
}

View File

@ -14,7 +14,6 @@ import (
"os/signal"
"os/user"
"path"
"regexp"
"runtime"
"strconv"
"strings"
@ -29,19 +28,19 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/cirfile"
"github.com/wavetermdev/waveterm/waveshell/pkg/mpio"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellapi"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellutil"
"golang.org/x/mod/semver"
"golang.org/x/sys/unix"
)
const DefaultTermRows = 24
const DefaultTermCols = 80
const MinTermRows = 2
const MinTermCols = 10
const MaxTermRows = 1024
const MaxTermCols = 1024
const MaxFdNum = 1023
const FirstExtraFilesFdNum = 3
const DefaultTermType = "xterm-256color"
const DefaultMaxPtySize = 1024 * 1024
const MinMaxPtySize = 16 * 1024
const MaxMaxPtySize = 100 * 1024 * 1024
@ -52,27 +51,6 @@ const SigKillWaitTime = 2 * time.Second
const RtnStateFdNum = 20
const ReturnStateReadWaitTime = 2 * time.Second
const GetStateTimeout = 5 * time.Second
const RemoteBashPath = "bash"
const DefaultMacOSShell = "/bin/bash"
const BaseBashOpts = `set +m; set +H; shopt -s extglob`
const ShellVersionCmdStr = `echo bash v${BASH_VERSINFO[0]}.${BASH_VERSINFO[1]}.${BASH_VERSINFO[2]}`
// do not use these directly, call GetLocalBashMajorVersion()
var LocalBashMajorVersionOnce = &sync.Once{}
var LocalBashMajorVersion = ""
var GetShellStateCmds = []string{
ShellVersionCmdStr + ";",
`pwd;`,
`declare -p $(compgen -A variable);`,
`alias -p;`,
`declare -f;`,
`printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"`,
}
const ClientCommandFmt = `
PATH=$PATH:~/.mshell;
which mshell > /dev/null;
@ -104,11 +82,6 @@ func MakeInstallCommandStr() string {
return strings.ReplaceAll(InstallCommandFmt, "[%VERSION%]", semver.MajorMinor(base.MShellVersion))
}
// TODO fix bash path in these constants
const RunCommandFmt = `%s`
const RunSudoCommandFmt = `sudo -n -C %d bash /dev/fd/%d`
const RunSudoPasswordCommandFmt = `cat /dev/fd/%d | sudo -k -S -C %d bash -c "echo '[from-mshell]'; exec %d>&-; bash /dev/fd/%d < /dev/fd/%d"`
type MShellBinaryReaderFn func(version string, goos string, goarch string) (io.ReadCloser, error)
type ReturnStateBuf struct {
@ -140,7 +113,8 @@ type ShExecType struct {
MsgSender *packet.PacketSender // where to send out-of-band messages back to calling proceess
ReturnState *ReturnStateBuf
Exited bool // locked via Lock
TmpRcFileName string
TmpRcFileName string // file *or* directory holding temporary rc file(s)
SAPI shellapi.ShellApi
}
type StdContext struct{}
@ -183,20 +157,6 @@ type ShExecUPR struct {
UPR packet.UnknownPacketReporter
}
func GetLocalBashPath() string {
if runtime.GOOS == "darwin" {
macShell := GetMacUserShell()
if strings.Index(macShell, "bash") != -1 {
return shellescape.Quote(macShell)
}
}
return "bash"
}
func GetShellStateCmd() string {
return strings.Join(GetShellStateCmds, ` printf "\x00\x00";`)
}
func (s *ShExecType) processSpecialInputPacket(pk *packet.SpecialInputPacketType) error {
base.Logf("processSpecialInputPacket: %#v\n", pk)
if pk.WinSize != nil {
@ -242,12 +202,13 @@ func (s ShExecUPR) UnknownPacket(pk packet.PacketType) {
}
}
func MakeShExec(ck base.CommandKey, upr packet.UnknownPacketReporter) *ShExecType {
func MakeShExec(ck base.CommandKey, upr packet.UnknownPacketReporter, sapi shellapi.ShellApi) *ShExecType {
return &ShExecType{
Lock: &sync.Mutex{},
StartTs: time.Now(),
CK: ck,
Multiplexer: mpio.MakeMultiplexer(ck, upr),
SAPI: sapi,
}
}
@ -267,7 +228,8 @@ func (c *ShExecType) Close() {
c.ReturnState.Reader.Close()
}
if c.TmpRcFileName != "" {
os.Remove(c.TmpRcFileName)
// TmpRcFileName can be a file or a directory
os.RemoveAll(c.TmpRcFileName)
}
}
@ -280,42 +242,6 @@ func (c *ShExecType) MakeCmdStartPacket(reqId string) *packet.CmdStartPacketType
return startPacket
}
func getEnvStrKey(envStr string) string {
eqIdx := strings.Index(envStr, "=")
if eqIdx == -1 {
return envStr
}
return envStr[0:eqIdx]
}
func UpdateCmdEnv(cmd *exec.Cmd, envVars map[string]string) {
if len(envVars) == 0 {
return
}
found := make(map[string]bool)
var newEnv []string
for _, envStr := range cmd.Env {
envKey := getEnvStrKey(envStr)
newEnvVal, ok := envVars[envKey]
if ok {
if newEnvVal == "" {
continue
}
newEnv = append(newEnv, envKey+"="+newEnvVal)
found[envKey] = true
} else {
newEnv = append(newEnv, envStr)
}
}
for envKey, envVal := range envVars {
if found[envKey] {
continue
}
newEnv = append(newEnv, envKey+"="+envVal)
}
cmd.Env = newEnv
}
// returns (pr, err)
func MakeSimpleStaticWriterPipe(data []byte) (*os.File, error) {
pr, pw, err := os.Pipe()
@ -329,17 +255,30 @@ func MakeSimpleStaticWriterPipe(data []byte) (*os.File, error) {
return pr, err
}
func MakeRunnerExec(ck base.CommandKey) (*exec.Cmd, error) {
msPath, err := base.GetMShellPath()
if err != nil {
return nil, err
}
ecmd := exec.Command(msPath, string(ck))
return ecmd, nil
}
func MakeDetachedExecCmd(pk *packet.RunPacketType, cmdTty *os.File) (*exec.Cmd, error) {
sapi, err := shellapi.MakeShellApi(pk.ShellType)
if err != nil {
return nil, err
}
state := pk.State
if state == nil {
state = &packet.ShellState{}
}
ecmd := exec.Command(GetLocalBashPath(), "-c", pk.Command)
ecmd := exec.Command(sapi.GetLocalShellPath(), "-c", pk.Command)
if !pk.StateComplete {
ecmd.Env = os.Environ()
}
UpdateCmdEnv(ecmd, EnvMapFromState(state))
UpdateCmdEnv(ecmd, MShellEnvVars(getTermType(pk)))
shellutil.UpdateCmdEnv(ecmd, shellenv.EnvMapFromState(state))
shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(getTermType(pk)))
if state.Cwd != "" {
ecmd.Dir = base.ExpandHomeDir(state.Cwd)
}
@ -373,15 +312,6 @@ func MakeDetachedExecCmd(pk *packet.RunPacketType, cmdTty *os.File) (*exec.Cmd,
return ecmd, nil
}
func MakeRunnerExec(ck base.CommandKey) (*exec.Cmd, error) {
msPath, err := base.GetMShellPath()
if err != nil {
return nil, err
}
ecmd := exec.Command(msPath, string(ck))
return ecmd, nil
}
// this will never return (unless there is an error creating/opening the file), as fifoFile will never EOF
func MakeAndCopyStdinFifo(dst *os.File, fifoName string) error {
os.Remove(fifoName)
@ -457,8 +387,8 @@ func ValidateRunPacket(pk *packet.RunPacketType) error {
}
func GetWinsize(p *packet.RunPacketType) *pty.Winsize {
rows := DefaultTermRows
cols := DefaultTermCols
rows := shellutil.DefaultTermRows
cols := shellutil.DefaultTermCols
if p.TermOpts != nil {
rows = base.BoundInt(p.TermOpts.Rows, MinTermRows, MaxTermRows)
cols = base.BoundInt(p.TermOpts.Cols, MinTermCols, MaxTermCols)
@ -497,49 +427,23 @@ type ClientOpts struct {
UsePty bool
}
func (opts SSHOpts) MakeSSHInstallCmd() (*exec.Cmd, error) {
if opts.SSHHost == "" {
return nil, fmt.Errorf("no ssh host provided, can only install to a remote host")
}
cmdStr := MakeInstallCommandStr()
return opts.MakeSSHExecCmd(cmdStr), nil
}
func (opts SSHOpts) MakeMShellServerCmd() (*exec.Cmd, error) {
msPath, err := base.GetMShellPath()
if err != nil {
return nil, err
}
ecmd := exec.Command(msPath, "--server")
return ecmd, nil
}
func (opts SSHOpts) MakeMShellSingleCmd(fromServer bool) (*exec.Cmd, error) {
if opts.SSHHost == "" {
func MakeMShellSingleCmd() (*exec.Cmd, error) {
execFile, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("cannot find local mshell executable: %w", err)
}
var ecmd *exec.Cmd
if fromServer {
ecmd = exec.Command(execFile, "--single-from-server")
} else {
ecmd = exec.Command(execFile, "--single")
}
ecmd := exec.Command(execFile, "--single-from-server")
return ecmd, nil
}
cmdStr := MakeClientCommandStr()
return opts.MakeSSHExecCmd(cmdStr), nil
}
func (opts SSHOpts) MakeSSHExecCmd(remoteCommand string) *exec.Cmd {
func (opts SSHOpts) MakeSSHExecCmd(remoteCommand string, sapi shellapi.ShellApi) *exec.Cmd {
remoteCommand = strings.TrimSpace(remoteCommand)
if opts.SSHHost == "" {
homeDir, _ := os.UserHomeDir() // ignore error
if homeDir == "" {
homeDir = "/"
}
ecmd := exec.Command(GetLocalBashPath(), "-c", remoteCommand)
ecmd := exec.Command(sapi.GetLocalShellPath(), "-c", remoteCommand)
ecmd.Dir = homeDir
return ecmd
} else {
@ -566,7 +470,7 @@ func (opts SSHOpts) MakeSSHExecCmd(remoteCommand string) *exec.Cmd {
}
// note that SSHOptsStr is *not* escaped
sshCmd := fmt.Sprintf("ssh %s %s %s %s", strings.Join(moreSSHOpts, " "), opts.SSHOptsStr, shellescape.Quote(opts.SSHHost), shellescape.Quote(remoteCommand))
ecmd := exec.Command(RemoteBashPath, "-c", sshCmd)
ecmd := exec.Command(sapi.GetRemoteShellPath(), "-c", sshCmd)
return ecmd
}
}
@ -605,59 +509,6 @@ func GetTerminalSize() (int, int, error) {
return pty.Getsize(fd)
}
func (opts *ClientOpts) MakeRunPacket() (*packet.RunPacketType, error) {
runPacket := packet.MakeRunPacket()
runPacket.Detached = opts.Detach
runPacket.State = &packet.ShellState{}
runPacket.State.Cwd = opts.Cwd
runPacket.Fds = opts.Fds
if opts.UsePty {
runPacket.UsePty = true
runPacket.TermOpts = &packet.TermOpts{}
rows, cols, err := GetTerminalSize()
if err == nil {
runPacket.TermOpts.Rows = rows
runPacket.TermOpts.Cols = cols
}
term := os.Getenv("TERM")
if term != "" {
runPacket.TermOpts.Term = term
}
}
if !opts.Sudo {
// normal, non-sudo command
runPacket.Command = fmt.Sprintf(RunCommandFmt, opts.Command)
return runPacket, nil
}
if opts.SudoWithPass {
pwFdNum, err := AddRunData(runPacket, opts.SudoPw, "sudo pw")
if err != nil {
return nil, err
}
commandFdNum, err := AddRunData(runPacket, opts.Command, "command")
if err != nil {
return nil, err
}
commandStdinFdNum, err := NextFreeFdNum(runPacket)
if err != nil {
return nil, err
}
commandStdinRfd := packet.RemoteFd{FdNum: commandStdinFdNum, Read: true, DupStdin: true}
runPacket.Fds = append(runPacket.Fds, commandStdinRfd)
maxFdNum := MaxFdNumInPacket(runPacket)
runPacket.Command = fmt.Sprintf(RunSudoPasswordCommandFmt, pwFdNum, maxFdNum+1, pwFdNum, commandFdNum, commandStdinFdNum)
return runPacket, nil
} else {
commandFdNum, err := AddRunData(runPacket, opts.Command, "command")
if err != nil {
return nil, err
}
maxFdNum := MaxFdNumInPacket(runPacket)
runPacket.Command = fmt.Sprintf(RunSudoCommandFmt, maxFdNum+1, commandFdNum)
return runPacket, nil
}
}
func AddRunData(pk *packet.RunPacketType, data string, dataType string) (int, error) {
if len(data) > MaxRunDataSize {
return 0, fmt.Errorf("%s too large, exceeds read buffer size size:%d", dataType, len(data))
@ -810,31 +661,6 @@ func RunInstallFromCmd(ctx context.Context, ecmd *exec.Cmd, tryDetect bool, mshe
}
}
func RunInstallFromOpts(opts *InstallOpts) error {
ecmd, err := opts.SSHOpts.MakeSSHInstallCmd()
if err != nil {
return err
}
msgFn := func(str string) {
fmt.Printf("%s", str)
}
var mshellStream *os.File
if opts.OptName != "" {
mshellStream, err = os.Open(opts.OptName)
if err != nil {
return fmt.Errorf("cannot open mshell binary %q: %v", opts.OptName, err)
}
defer mshellStream.Close()
}
err = RunInstallFromCmd(context.Background(), ecmd, opts.Detect, mshellStream, base.MShellBinaryFromOptDir, msgFn)
if err != nil {
return err
}
mmVersion := semver.MajorMinor(base.MShellVersion)
fmt.Printf("mshell installed successfully at %s:~/.mshell/mshell%s\n", opts.SSHOpts.SSHHost, mmVersion)
return nil
}
func HasDupStdin(fds []packet.RemoteFd) bool {
for _, rfd := range fds {
if rfd.Read && rfd.DupStdin {
@ -844,101 +670,6 @@ func HasDupStdin(fds []packet.RemoteFd) bool {
return false
}
func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, fdContext FdContext, sshOpts SSHOpts, upr packet.UnknownPacketReporter, debug bool) (*packet.CmdDonePacketType, error) {
cmd := MakeShExec(runPacket.CK, upr)
ecmd, err := sshOpts.MakeMShellSingleCmd(false)
if err != nil {
return nil, err
}
cmd.Cmd = ecmd
inputWriter, err := ecmd.StdinPipe()
if err != nil {
return nil, fmt.Errorf("creating stdin pipe: %v", err)
}
stdoutReader, err := ecmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("creating stdout pipe: %v", err)
}
stderrReader, err := ecmd.StderrPipe()
if err != nil {
return nil, fmt.Errorf("creating stderr pipe: %v", err)
}
if !HasDupStdin(runPacket.Fds) {
cmd.Multiplexer.MakeRawFdReader(0, fdContext.GetReader(0), false, false)
}
cmd.Multiplexer.MakeRawFdWriter(1, fdContext.GetWriter(1), false, "client")
cmd.Multiplexer.MakeRawFdWriter(2, fdContext.GetWriter(2), false, "client")
for _, rfd := range runPacket.Fds {
if rfd.Read && rfd.DupStdin {
cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fdContext.GetReader(0), false, false)
continue
}
if rfd.Read {
fd := fdContext.GetReader(rfd.FdNum)
cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fd, false, false)
} else if rfd.Write {
fd := fdContext.GetWriter(rfd.FdNum)
cmd.Multiplexer.MakeRawFdWriter(rfd.FdNum, fd, true, "client")
}
}
err = ecmd.Start()
if err != nil {
return nil, fmt.Errorf("running ssh command: %w", err)
}
defer cmd.Close()
stdoutPacketParser := packet.MakePacketParser(stdoutReader, nil)
stderrPacketParser := packet.MakePacketParser(stderrReader, nil)
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, false)
sender := packet.MakePacketSender(inputWriter, nil)
versionOk := false
for pk := range packetParser.MainCh {
if pk.GetType() == packet.RawPacketStr {
rawPk := pk.(*packet.RawPacketType)
fmt.Printf("%s\n", rawPk.Data)
continue
}
if pk.GetType() == packet.InitPacketStr {
initPk := pk.(*packet.InitPacketType)
mmVersion := semver.MajorMinor(base.MShellVersion)
if initPk.NotFound {
if sshOpts.SSHHost == "" {
return nil, fmt.Errorf("mshell-%s command not found on local server", mmVersion)
}
if initPk.UName == "" {
return nil, fmt.Errorf("mshell-%s command not found on remote server, no uname detected", mmVersion)
}
goos, goarch, err := DetectGoArch(initPk.UName)
if err != nil {
return nil, fmt.Errorf("mshell-%s command not found on remote server, architecture cannot be detected (might be incompatible with mshell): %w", mmVersion, err)
}
sshOptsStr := sshOpts.MakeMShellSSHOpts()
return nil, fmt.Errorf("mshell-%s command not found on remote server, can install with 'mshell --install %s %s.%s'", mmVersion, sshOptsStr, goos, goarch)
}
if semver.MajorMinor(initPk.Version) != semver.MajorMinor(base.MShellVersion) {
return nil, fmt.Errorf("invalid remote mshell version '%s', must be '=%s'", initPk.Version, semver.MajorMinor(base.MShellVersion))
}
versionOk = true
if debug {
fmt.Printf("VERSION> %s\n", initPk.Version)
}
break
}
}
if !versionOk {
return nil, fmt.Errorf("did not receive version from remote mshell")
}
SendRunPacketAndRunData(context.Background(), sender, runPacket)
if debug {
cmd.Multiplexer.Debug = true
}
remoteDonePacket := cmd.Multiplexer.RunIOAndWait(packetParser, sender, false, true, true)
donePacket := cmd.WaitForCommand()
if remoteDonePacket != nil {
donePacket = remoteDonePacket
}
return donePacket, nil
}
func min(v1 int, v2 int) int {
if v1 <= v2 {
return v1
@ -1015,46 +746,13 @@ func (cmd *ShExecType) RunRemoteIOAndWait(packetParser *packet.PacketParser, sen
}
func getTermType(pk *packet.RunPacketType) string {
termType := DefaultTermType
termType := shellutil.DefaultTermType
if pk.TermOpts != nil && pk.TermOpts.Term != "" {
termType = pk.TermOpts.Term
}
return termType
}
func makeRcFileStr(pk *packet.RunPacketType) string {
var rcBuf bytes.Buffer
rcBuf.WriteString(BaseBashOpts + "\n")
varDecls := VarDeclsFromState(pk.State)
for _, varDecl := range varDecls {
if varDecl.IsExport() || varDecl.IsReadOnly() {
continue
}
rcBuf.WriteString(varDecl.DeclareStmt())
rcBuf.WriteString("\n")
}
if pk.State != nil && pk.State.Funcs != "" {
rcBuf.WriteString(pk.State.Funcs)
rcBuf.WriteString("\n")
}
if pk.State != nil && pk.State.Aliases != "" {
rcBuf.WriteString(pk.State.Aliases)
rcBuf.WriteString("\n")
}
return rcBuf.String()
}
func makeExitTrap(fdNum int) string {
stateCmd := GetShellStateRedirectCommandStr(fdNum)
fmtStr := `
_mshell_exittrap () {
%s
}
trap _mshell_exittrap EXIT
`
return fmt.Sprintf(fmtStr, stateCmd)
}
func (s *ShExecType) SendSignal(sig syscall.Signal) {
base.Logf("signal start %v\n", sig)
if sig == syscall.SIGKILL {
@ -1086,11 +784,15 @@ func (s *ShExecType) SendSignal(sig syscall.Signal) {
}
func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fromServer bool) (rtnShExec *ShExecType, rtnErr error) {
sapi, err := shellapi.MakeShellApi(pk.ShellType)
if err != nil {
return nil, err
}
state := pk.State
if state == nil {
state = &packet.ShellState{}
return nil, fmt.Errorf("invalid run packet, no state")
}
cmd := MakeShExec(pk.CK, nil)
cmd := MakeShExec(pk.CK, nil, sapi)
defer func() {
// on error, call cmd.Close()
if rtnErr != nil {
@ -1104,7 +806,7 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
cmd.MsgSender = sender
}
var rtnStateWriter *os.File
rcFileStr := makeRcFileStr(pk)
rcFileStr := sapi.MakeRcFileStr(pk)
if pk.ReturnState {
pr, pw, err := os.Pipe()
if err != nil {
@ -1115,10 +817,10 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
cmd.ReturnState.FdNum = RtnStateFdNum
rtnStateWriter = pw
defer pw.Close()
trapCmdStr := makeExitTrap(cmd.ReturnState.FdNum)
trapCmdStr := sapi.MakeExitTrap(cmd.ReturnState.FdNum)
rcFileStr += trapCmdStr
}
shellVarMap := ShellVarMapFromState(state)
shellVarMap := shellenv.ShellVarMapFromState(state)
if base.HasDebugFlag(shellVarMap, base.DebugFlag_LogRcFile) {
debugRcFileName := base.GetDebugRcFileName()
err := os.WriteFile(debugRcFileName, []byte(rcFileStr), 0600)
@ -1126,9 +828,13 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
base.Logf("error writing %s: %v\n", debugRcFileName, err)
}
}
bashVersion := GetLocalBashMajorVersion()
isOldBashVersion := (semver.Compare(bashVersion, "v4") < 0)
var isOldBashVersion bool
if sapi.GetShellType() == packet.ShellType_bash {
bashVersion := sapi.GetLocalMajorVersion()
isOldBashVersion = (semver.Compare(bashVersion, "v4") < 0)
}
var rcFileName string
var zdotdir string
if isOldBashVersion {
rcFileDir, err := base.EnsureRcFilesDir()
if err != nil {
@ -1140,12 +846,19 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
return nil, fmt.Errorf("could not write temp rcfile: %w", err)
}
cmd.TmpRcFileName = rcFileName
go func() {
// cmd.Close() will also remove rcFileName
// adding this to also try to proactively clean up after 1-second.
time.Sleep(1 * time.Second)
os.Remove(rcFileName)
}()
} else if sapi.GetShellType() == packet.ShellType_zsh {
rcFileDir, err := base.EnsureRcFilesDir()
if err != nil {
return nil, err
}
zdotdir = path.Join(rcFileDir, uuid.New().String())
os.Mkdir(zdotdir, 0700)
rcFileName = path.Join(zdotdir, ".zshenv")
err = os.WriteFile(rcFileName, []byte(rcFileStr), 0600)
if err != nil {
return nil, fmt.Errorf("could not write temp rcfile: %w", err)
}
cmd.TmpRcFileName = zdotdir
} else {
rcFileFdNum, err := AddRunData(pk, rcFileStr, "rcfile")
if err != nil {
@ -1153,19 +866,26 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
}
rcFileName = fmt.Sprintf("/dev/fd/%d", rcFileFdNum)
}
if pk.UsePty {
cmd.Cmd = exec.Command(GetLocalBashPath(), "--rcfile", rcFileName, "-i", "-c", pk.Command)
} else {
cmd.Cmd = exec.Command(GetLocalBashPath(), "--rcfile", rcFileName, "-c", pk.Command)
if cmd.TmpRcFileName != "" {
go func() {
// cmd.Close() will also remove rcFileName
// adding this to also try to proactively clean up after 2-seconds.
time.Sleep(2 * time.Second)
os.Remove(cmd.TmpRcFileName)
}()
}
cmd.Cmd = sapi.MakeShExecCommand(pk.Command, rcFileName, pk.UsePty)
if !pk.StateComplete {
cmd.Cmd.Env = os.Environ()
}
UpdateCmdEnv(cmd.Cmd, EnvMapFromState(state))
shellutil.UpdateCmdEnv(cmd.Cmd, shellenv.EnvMapFromState(state))
if sapi.GetShellType() == packet.ShellType_zsh {
shellutil.UpdateCmdEnv(cmd.Cmd, map[string]string{"ZDOTDIR": zdotdir})
}
if state.Cwd != "" {
cmd.Cmd.Dir = base.ExpandHomeDir(state.Cwd)
}
err := ValidateRemoteFds(pk.Fds)
err = ValidateRemoteFds(pk.Fds)
if err != nil {
return nil, err
}
@ -1181,7 +901,7 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
cmdTty.Close()
}()
cmd.CmdPty = cmdPty
UpdateCmdEnv(cmd.Cmd, MShellEnvVars(getTermType(pk)))
shellutil.UpdateCmdEnv(cmd.Cmd, shellutil.MShellEnvVars(getTermType(pk)))
}
if cmdTty != nil {
cmd.Cmd.Stdin = cmdTty
@ -1374,6 +1094,10 @@ func (cmd *ShExecType) DetachedWait(startPacket *packet.CmdStartPacketType) {
}
func RunCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecType, *packet.CmdStartPacketType, error) {
sapi, err := shellapi.MakeShellApi(pk.ShellType)
if err != nil {
return nil, nil, err
}
fileNames, err := base.GetCommandFileNames(pk.CK)
if err != nil {
return nil, nil, err
@ -1393,7 +1117,7 @@ func RunCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (
defer func() {
cmdTty.Close()
}()
cmd := MakeShExec(pk.CK, nil)
cmd := MakeShExec(pk.CK, nil, sapi)
cmd.FileNames = fileNames
cmd.CmdPty = cmdPty
cmd.Detached = true
@ -1462,7 +1186,7 @@ func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType {
c.ReturnState.Reader.Close()
}()
<-c.ReturnState.DoneCh
state, _ := ParseShellStateOutput(c.ReturnState.Buf) // TODO what to do with error?
state, _ := c.SAPI.ParseShellStateOutput(c.ReturnState.Buf) // TODO what to do with error?
donePacket.FinalState = state
}
endTs := time.Now()
@ -1487,18 +1211,27 @@ func MakeInitPacket() *packet.InitPacketType {
}
initPacket.HostName, _ = os.Hostname()
initPacket.UName = fmt.Sprintf("%s|%s", runtime.GOOS, runtime.GOARCH)
initPacket.Shell = shellapi.DetectLocalShellType()
return initPacket
}
func MakeShellStatePacket(shellType string) (*packet.ShellStatePacketType, error) {
sapi, err := shellapi.MakeShellApi(shellType)
if err != nil {
return nil, err
}
shellState, err := sapi.GetShellState()
if err != nil {
return nil, err
}
rtn := packet.MakeShellStatePacket()
rtn.State = shellState
return rtn, nil
}
func MakeServerInitPacket() (*packet.InitPacketType, error) {
var err error
initPacket := MakeInitPacket()
shellState, err := GetShellState()
if err != nil {
return nil, err
}
initPacket.State = shellState
initPacket.Shell = os.Getenv(ShellVarName)
initPacket.RemoteId, err = base.GetRemoteId()
if err != nil {
return nil, err
@ -1549,129 +1282,3 @@ func getStderr(err error) string {
}
return lines[0]
}
func runSimpleCmdInPty(ecmd *exec.Cmd) ([]byte, error) {
ecmd.Env = os.Environ()
UpdateCmdEnv(ecmd, MShellEnvVars(DefaultTermType))
cmdPty, cmdTty, err := pty.Open()
if err != nil {
return nil, fmt.Errorf("opening new pty: %w", err)
}
pty.Setsize(cmdPty, &pty.Winsize{Rows: DefaultTermRows, Cols: DefaultTermCols})
ecmd.Stdin = cmdTty
ecmd.Stdout = cmdTty
ecmd.Stderr = cmdTty
ecmd.SysProcAttr = &syscall.SysProcAttr{}
ecmd.SysProcAttr.Setsid = true
ecmd.SysProcAttr.Setctty = true
err = ecmd.Start()
if err != nil {
cmdTty.Close()
cmdPty.Close()
return nil, err
}
cmdTty.Close()
defer cmdPty.Close()
ioDone := make(chan bool)
var outputBuf bytes.Buffer
go func() {
// ignore error (/dev/ptmx has read error when process is done)
io.Copy(&outputBuf, cmdPty)
close(ioDone)
}()
exitErr := ecmd.Wait()
if exitErr != nil {
return nil, exitErr
}
<-ioDone
return outputBuf.Bytes(), nil
}
func GetShellStateRedirectCommandStr(outputFdNum int) string {
return fmt.Sprintf("cat <(%s) > /dev/fd/%d", GetShellStateCmd(), outputFdNum)
}
func GetShellState() (*packet.ShellState, error) {
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
defer cancelFn()
cmdStr := BaseBashOpts + "; " + GetShellStateCmd()
ecmd := exec.CommandContext(ctx, GetLocalBashPath(), "-l", "-i", "-c", cmdStr)
outputBytes, err := runSimpleCmdInPty(ecmd)
if err != nil {
return nil, err
}
return ParseShellStateOutput(outputBytes)
}
func MShellEnvVars(termType string) map[string]string {
rtn := make(map[string]string)
if termType != "" {
rtn["TERM"] = termType
}
rtn["MSHELL"], _ = os.Executable()
rtn["MSHELL_VERSION"] = base.MShellVersion
rtn["WAVESHELL"], _ = os.Executable()
rtn["WAVESHELL_VERSION"] = base.MShellVersion
return rtn
}
func ExecGetLocalShellVersion() string {
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
defer cancelFn()
ecmd := exec.CommandContext(ctx, "bash", "-c", ShellVersionCmdStr)
out, err := ecmd.Output()
if err != nil {
return ""
}
versionStr := strings.TrimSpace(string(out))
if strings.Index(versionStr, "bash ") == -1 {
// invalid shell version (only bash is supported)
return ""
}
return versionStr
}
func GetLocalBashMajorVersion() string {
LocalBashMajorVersionOnce.Do(func() {
fullVersion := ExecGetLocalShellVersion()
LocalBashMajorVersion = packet.GetBashMajorVersion(fullVersion)
})
return LocalBashMajorVersion
}
var userShellRegexp = regexp.MustCompile(`^UserShell: (.*)$`)
var cachedMacUserShell string
var macUserShellOnce = &sync.Once{}
func GetMacUserShell() string {
if runtime.GOOS != "darwin" {
return ""
}
macUserShellOnce.Do(func() {
cachedMacUserShell = internalMacUserShell()
})
return cachedMacUserShell
}
// dscl . -read /User/[username] UserShell
// defaults to /bin/bash
func internalMacUserShell() string {
osUser, err := user.Current()
if err != nil {
return DefaultMacOSShell
}
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
defer cancelFn()
userStr := "/Users/" + osUser.Username
out, err := exec.CommandContext(ctx, "dscl", ".", "-read", userStr, "UserShell").CombinedOutput()
if err != nil {
return DefaultMacOSShell
}
outStr := strings.TrimSpace(string(out))
m := userShellRegexp.FindStringSubmatch(outStr)
if m == nil {
return DefaultMacOSShell
}
return m[1]
}

View File

@ -10,7 +10,8 @@ import (
"strings"
)
const LineDiffVersion = 0
const LineDiffVersion_0 = 0
const LineDiffVersion = 1
type SingleLineEntry struct {
LineVal int
@ -18,12 +19,21 @@ type SingleLineEntry struct {
}
type LineDiffType struct {
Version int
SplitString string // added in version 1
Lines []SingleLineEntry
NewData []string
}
func (diff *LineDiffType) Clear() {
diff.Version = LineDiffVersion
diff.SplitString = ""
diff.Lines = nil
diff.NewData = nil
}
func (diff LineDiffType) Dump() {
fmt.Printf("DIFF:\n")
fmt.Printf("DIFF: v%d\n", diff.Version)
pos := 1
for _, entry := range diff.Lines {
fmt.Printf(" %d-%d: %d\n", pos, pos+entry.Run, entry.LineVal)
@ -68,13 +78,37 @@ func putUVarint(buf *bytes.Buffer, viBuf []byte, ival int) {
buf.Write(viBuf[0:l])
}
// run length encoding, writes a uvarint for length, and then that many bytes of data
func putEncodedString(buf *bytes.Buffer, viBuf []byte, str string) {
l := binary.PutUvarint(viBuf, uint64(len(str)))
buf.Write(viBuf[0:l])
buf.WriteString(str)
}
func readEncodedString(buf *bytes.Buffer) (string, error) {
strLen64, err := binary.ReadUvarint(buf)
if err != nil {
return "", fmt.Errorf("invalid diff, cannot read string length: %v", err)
}
strLen := int(strLen64)
if strLen == 0 {
return "", nil
}
strBytes := buf.Next(strLen)
if len(strBytes) != strLen {
return "", fmt.Errorf("invalid diff, partial read, expected %d, got %d", strLen, len(strBytes))
}
return string(strBytes), nil
}
// version 0 is no longer used, but kept here as a reference for decoding
// simple encoding
// write varints. first version, then len, then len-number-of-varints, then fill the rest with newdata
// write varints. first version, then then len, then len-number-of-varints, then fill the rest with newdata
// [version] [len-varint] [varint]xlen... newdata (bytes)
func (diff LineDiffType) Encode() []byte {
func (diff LineDiffType) Encode_v0() []byte {
var buf bytes.Buffer
viBuf := make([]byte, binary.MaxVarintLen64)
putUVarint(&buf, viBuf, LineDiffVersion)
putUVarint(&buf, viBuf, LineDiffVersion_0)
putUVarint(&buf, viBuf, len(diff.Lines))
for _, entry := range diff.Lines {
putUVarint(&buf, viBuf, entry.LineVal)
@ -83,37 +117,117 @@ func (diff LineDiffType) Encode() []byte {
for idx, str := range diff.NewData {
buf.WriteString(str)
if idx != len(diff.NewData)-1 {
buf.WriteByte('\n')
buf.WriteString(diff.SplitString)
}
}
return buf.Bytes()
}
func (rtn *LineDiffType) Decode(diffBytes []byte) error {
r := bytes.NewBuffer(diffBytes)
version, err := binary.ReadUvarint(r)
if err != nil {
return fmt.Errorf("invalid diff, cannot read version: %v", err)
// version 1 updates the diff to include the split-string
// it also encodes all the strings with run-length encoding
func (diff LineDiffType) Encode() []byte {
var buf bytes.Buffer
viBuf := make([]byte, binary.MaxVarintLen64)
putUVarint(&buf, viBuf, LineDiffVersion)
putEncodedString(&buf, viBuf, diff.SplitString)
putUVarint(&buf, viBuf, len(diff.Lines))
for _, entry := range diff.Lines {
putUVarint(&buf, viBuf, entry.LineVal)
putUVarint(&buf, viBuf, entry.Run)
}
if version != LineDiffVersion {
return fmt.Errorf("invalid diff, bad version: %d", version)
}
linesLen64, err := binary.ReadUvarint(r)
writeEncodedStringArray(&buf, viBuf, diff.NewData)
return buf.Bytes()
}
func (rtn *LineDiffType) readEncodedLines(buf *bytes.Buffer) error {
linesLen64, err := binary.ReadUvarint(buf)
if err != nil {
return fmt.Errorf("invalid diff, cannot read lines length: %v", err)
}
linesLen := int(linesLen64)
rtn.Lines = make([]SingleLineEntry, linesLen)
for idx := 0; idx < linesLen; idx++ {
lineVal, err := binary.ReadUvarint(r)
lineVal64, err := binary.ReadUvarint(buf)
if err != nil {
return fmt.Errorf("invalid diff, cannot read line %d: %v", idx, err)
}
lineRun, err := binary.ReadUvarint(r)
lineRun64, err := binary.ReadUvarint(buf)
if err != nil {
return fmt.Errorf("invalid diff, cannot read line-run %d: %v", idx, err)
}
rtn.Lines[idx] = SingleLineEntry{LineVal: int(lineVal), Run: int(lineRun)}
rtn.Lines[idx] = SingleLineEntry{LineVal: int(lineVal64), Run: int(lineRun64)}
}
return nil
}
func writeEncodedStringArray(buf *bytes.Buffer, viBuf []byte, strArr []string) {
putUVarint(buf, viBuf, len(strArr))
for _, str := range strArr {
putEncodedString(buf, viBuf, str)
}
}
func readEncodedStringArray(buf *bytes.Buffer) ([]string, error) {
strArrLen64, err := binary.ReadUvarint(buf)
if err != nil {
return nil, fmt.Errorf("invalid diff, cannot read string-array length: %v", err)
}
strArrLen := int(strArrLen64)
rtn := make([]string, strArrLen)
for idx := 0; idx < strArrLen; idx++ {
str, err := readEncodedString(buf)
if err != nil {
return nil, err
}
rtn[idx] = str
}
return rtn, nil
}
func (rtn *LineDiffType) Decode(diffBytes []byte) error {
rtn.Clear()
r := bytes.NewBuffer(diffBytes)
version, err := binary.ReadUvarint(r)
if err != nil {
return fmt.Errorf("invalid diff, cannot read version: %v", err)
}
if version == LineDiffVersion_0 {
return rtn.Decode_v0(diffBytes)
}
if version != LineDiffVersion {
return fmt.Errorf("invalid diff, bad version: %d", version)
}
rtn.Version = int(version)
splitString, err := readEncodedString(r)
if err != nil {
return fmt.Errorf("invalid diff, cannot read split-string: %v", err)
}
rtn.SplitString = splitString
err = rtn.readEncodedLines(r)
if err != nil {
return err
}
rtn.NewData, err = readEncodedStringArray(r)
if err != nil {
return err
}
return nil
}
func (rtn *LineDiffType) Decode_v0(diffBytes []byte) error {
r := bytes.NewBuffer(diffBytes)
version, err := binary.ReadUvarint(r)
if err != nil {
return fmt.Errorf("invalid diff, cannot read version: %v", err)
}
if version != LineDiffVersion_0 {
return fmt.Errorf("invalid diff, bad version: %d", version)
}
rtn.Version = int(version)
rtn.SplitString = "\n" // added when we added version 1
err = rtn.readEncodedLines(r)
if err != nil {
return err
}
restOfInput := string(r.Bytes())
if len(restOfInput) > 0 {
@ -122,8 +236,10 @@ func (rtn *LineDiffType) Decode(diffBytes []byte) error {
return nil
}
func makeLineDiff(oldData []string, newData []string) LineDiffType {
func makeLineDiff(oldData []string, newData []string, splitString string) LineDiffType {
var rtn LineDiffType
rtn.Version = LineDiffVersion
rtn.SplitString = splitString
oldDataMap := make(map[string]int) // 1-indexed
for idx, str := range oldData {
if _, found := oldDataMap[str]; found {
@ -163,13 +279,13 @@ func makeLineDiff(oldData []string, newData []string) LineDiffType {
return rtn
}
func MakeLineDiff(str1 string, str2 string) []byte {
func MakeLineDiff(str1 string, str2 string, splitString string) []byte {
if str1 == str2 {
return nil
}
str1Arr := strings.Split(str1, "\n")
str2Arr := strings.Split(str2, "\n")
diff := makeLineDiff(str1Arr, str2Arr)
str1Arr := strings.Split(str1, splitString)
str2Arr := strings.Split(str2, splitString)
diff := makeLineDiff(str1Arr, str2Arr, splitString)
return diff.Encode()
}
@ -182,10 +298,10 @@ func ApplyLineDiff(str1 string, diffBytes []byte) (string, error) {
if err != nil {
return "", err
}
str1Arr := strings.Split(str1, "\n")
str1Arr := strings.Split(str1, diff.SplitString)
str2Arr, err := diff.applyDiff(str1Arr)
if err != nil {
return "", err
}
return strings.Join(str2Arr, "\n"), nil
return strings.Join(str2Arr, diff.SplitString), nil
}

View File

@ -7,17 +7,27 @@ import (
"bytes"
"encoding/binary"
"fmt"
"slices"
"github.com/wavetermdev/waveterm/waveshell/pkg/binpack"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
)
const MapDiffVersion = 0
const MapDiffVersion_0 = 0
const MapDiffVersion = 1
// 0-bytes are not allowed in entries or keys (same as bash)
type MapDiffType struct {
ToAdd map[string]string
ToAdd map[string][]byte
ToRemove []string
}
func (diff *MapDiffType) Clear() {
diff.ToAdd = nil
diff.ToRemove = nil
}
func (diff MapDiffType) Dump() {
fmt.Printf("VAR-DIFF\n")
for name, val := range diff.ToAdd {
@ -28,12 +38,12 @@ func (diff MapDiffType) Dump() {
}
}
func makeMapDiff(oldMap map[string]string, newMap map[string]string) MapDiffType {
func makeMapDiff(oldMap map[string][]byte, newMap map[string][]byte) MapDiffType {
var rtn MapDiffType
rtn.ToAdd = make(map[string]string)
rtn.ToAdd = make(map[string][]byte)
for name, newVal := range newMap {
oldVal, found := oldMap[name]
if !found || oldVal != newVal {
if !found || !bytes.Equal(oldVal, newVal) {
rtn.ToAdd[name] = newVal
continue
}
@ -47,8 +57,8 @@ func makeMapDiff(oldMap map[string]string, newMap map[string]string) MapDiffType
return rtn
}
func (diff MapDiffType) apply(oldMap map[string]string) map[string]string {
rtn := make(map[string]string)
func (diff MapDiffType) apply(oldMap map[string][]byte) map[string][]byte {
rtn := make(map[string][]byte)
for name, val := range oldMap {
rtn[name] = val
}
@ -61,15 +71,16 @@ func (diff MapDiffType) apply(oldMap map[string]string) map[string]string {
return rtn
}
func (diff MapDiffType) Encode() []byte {
// this is kept for reference
func (diff MapDiffType) Encode_v0() []byte {
var buf bytes.Buffer
viBuf := make([]byte, binary.MaxVarintLen64)
putUVarint(&buf, viBuf, MapDiffVersion)
putUVarint(&buf, viBuf, MapDiffVersion_0)
putUVarint(&buf, viBuf, len(diff.ToAdd))
for key, val := range diff.ToAdd {
buf.WriteString(key)
buf.WriteByte(0)
buf.WriteString(val)
buf.Write(val)
buf.WriteByte(0)
}
for _, val := range diff.ToRemove {
@ -79,13 +90,75 @@ func (diff MapDiffType) Encode() []byte {
return buf.Bytes()
}
// we sort map keys and remove values to make the diff deterministic
func (diff MapDiffType) Encode() []byte {
var buf bytes.Buffer
binpack.PackUInt(&buf, MapDiffVersion)
binpack.PackUInt(&buf, uint64(len(diff.ToAdd)))
addKeys := utilfn.GetOrderedMapKeys(diff.ToAdd)
for _, key := range addKeys {
val := diff.ToAdd[key]
binpack.PackValue(&buf, []byte(key))
binpack.PackValue(&buf, val)
}
slices.Sort(diff.ToRemove)
binpack.PackUInt(&buf, uint64(len(diff.ToRemove)))
for _, val := range diff.ToRemove {
binpack.PackValue(&buf, []byte(val))
}
return buf.Bytes()
}
func (diff *MapDiffType) Decode(diffBytes []byte) error {
diff.Clear()
r := bytes.NewBuffer(diffBytes)
version, err := binpack.UnpackUInt(r)
if err != nil {
return fmt.Errorf("invalid diff, cannot read version: %v", err)
}
if version == MapDiffVersion_0 {
return diff.Decode_v0(diffBytes)
}
if version != MapDiffVersion {
return fmt.Errorf("invalid diff, bad version: %d", version)
}
addLen, err := binpack.UnpackUIntAsInt(r)
if err != nil {
return fmt.Errorf("invalid diff, cannot read add length: %v", err)
}
diff.ToAdd = make(map[string][]byte)
for i := 0; i < addLen; i++ {
key, err := binpack.UnpackValue(r)
if err != nil {
return fmt.Errorf("invalid diff, cannot read add key %d: %v", i, err)
}
val, err := binpack.UnpackValue(r)
if err != nil {
return fmt.Errorf("invalid diff, cannot read add val %d: %v", i, err)
}
diff.ToAdd[string(key)] = val
}
removeLen, err := binpack.UnpackUIntAsInt(r)
if err != nil {
return fmt.Errorf("invalid diff, cannot read remove length: %v", err)
}
for i := 0; i < removeLen; i++ {
val, err := binpack.UnpackValue(r)
if err != nil {
return fmt.Errorf("invalid diff, cannot read remove val %d: %v", i, err)
}
diff.ToRemove = append(diff.ToRemove, string(val))
}
return nil
}
func (diff *MapDiffType) Decode_v0(diffBytes []byte) error {
r := bytes.NewBuffer(diffBytes)
version, err := binary.ReadUvarint(r)
if err != nil {
return fmt.Errorf("invalid diff, cannot read version: %v", err)
}
if version != MapDiffVersion {
if version != MapDiffVersion_0 {
return fmt.Errorf("invalid diff, bad version: %d", version)
}
mapLen64, err := binary.ReadUvarint(r)
@ -99,9 +172,9 @@ func (diff *MapDiffType) Decode(diffBytes []byte) error {
}
mapFields := fields[0 : 2*mapLen]
removeFields := fields[2*mapLen:]
diff.ToAdd = make(map[string]string)
diff.ToAdd = make(map[string][]byte)
for i := 0; i < len(mapFields); i += 2 {
diff.ToAdd[string(mapFields[i])] = string(mapFields[i+1])
diff.ToAdd[string(mapFields[i])] = mapFields[i+1]
}
for _, removeVal := range removeFields {
if len(removeVal) == 0 {
@ -112,7 +185,7 @@ func (diff *MapDiffType) Decode(diffBytes []byte) error {
return nil
}
func MakeMapDiff(m1 map[string]string, m2 map[string]string) []byte {
func MakeMapDiff(m1 map[string][]byte, m2 map[string][]byte) []byte {
diff := makeMapDiff(m1, m2)
if len(diff.ToAdd) == 0 && len(diff.ToRemove) == 0 {
return nil
@ -120,7 +193,7 @@ func MakeMapDiff(m1 map[string]string, m2 map[string]string) []byte {
return diff.Encode()
}
func ApplyMapDiff(oldMap map[string]string, diffBytes []byte) (map[string]string, error) {
func ApplyMapDiff(oldMap map[string][]byte, diffBytes []byte) (map[string][]byte, error) {
if len(diffBytes) == 0 {
return oldMap, nil
}

View File

@ -4,8 +4,12 @@
package statediff
import (
"encoding/binary"
"fmt"
"strings"
"testing"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
)
const Str1 = `
@ -38,8 +42,8 @@ banana2
coconut
`
func testLineDiff(t *testing.T, str1 string, str2 string) {
diffBytes := MakeLineDiff(str1, str2)
func testLineDiff(t *testing.T, str1 string, str2 string, splitString string) {
diffBytes := MakeLineDiff(str1, str2, splitString)
fmt.Printf("diff-len: %d\n", len(diffBytes))
out, err := ApplyLineDiff(str1, diffBytes)
if err != nil {
@ -57,35 +61,86 @@ func testLineDiff(t *testing.T, str1 string, str2 string) {
}
func TestLineDiff(t *testing.T) {
testLineDiff(t, Str1, Str2)
testLineDiff(t, Str2, Str3)
testLineDiff(t, Str1, Str3)
testLineDiff(t, Str3, Str1)
testLineDiff(t, Str3, Str4)
testLineDiff(t, Str1, Str2, "\n")
testLineDiff(t, Str2, Str3, "\n")
testLineDiff(t, Str1, Str3, "\n")
testLineDiff(t, Str3, Str1, "\n")
testLineDiff(t, Str3, Str4, "\n")
}
func strMapsEqual(m1 map[string]string, m2 map[string]string) bool {
if len(m1) != len(m2) {
return false
func TestLineDiff0(t *testing.T) {
var str1Arr []string = []string{"a", "b", "c", "d", "e"}
var str2Arr []string = []string{"a", "e"}
str1 := strings.Join(str1Arr, "\x00")
str2 := strings.Join(str2Arr, "\x00")
diffBytes := MakeLineDiff(str1, str2, "\x00")
fmt.Printf("diff-len: %d\n", len(diffBytes))
out, err := ApplyLineDiff(str1, diffBytes)
if err != nil {
t.Errorf("error in diff: %v", err)
return
}
for key, val := range m1 {
val2, ok := m2[key]
if !ok || val != val2 {
return false
if out != str2 {
t.Errorf("bad diff output")
}
diffBytes = MakeLineDiff(str2, str1, "\x00")
fmt.Printf("diff-len: %d\n", len(diffBytes))
out, err = ApplyLineDiff(str2, diffBytes)
if err != nil {
t.Errorf("error in diff: %v", err)
return
}
for key, val := range m2 {
val2, ok := m1[key]
if !ok || val != val2 {
return false
if out != str1 {
t.Errorf("bad diff output")
}
diffBytes = MakeLineDiff(str1, str1, "\x00")
if len(diffBytes) != 0 {
t.Errorf("bad diff output (len should be 0)")
}
var diffVar LineDiffType
diffVar.Decode(diffBytes)
if len(diffVar.Lines) != 0 || len(diffVar.NewData) != 0 || diffVar.Version != LineDiffVersion {
t.Errorf("bad diff output (for decoding nil)")
}
}
func TestLineDiffVersion0(t *testing.T) {
var str1Arr []string = []string{"a", "b", "c", "d", "e"}
var str2Arr []string = []string{"a", "e"}
str1 := strings.Join(str1Arr, "\n")
str2 := strings.Join(str2Arr, "\n")
var diff LineDiffType
diff.Version = 0
diff.SplitString = "\n"
diff.Lines = []SingleLineEntry{{LineVal: 1, Run: 1}, {LineVal: 5, Run: 1}}
encDiff0 := diff.Encode_v0()
var decDiff LineDiffType
err := decDiff.Decode(encDiff0)
if err != nil {
t.Errorf("error decoding diff: %v\n", err)
}
if decDiff.Version != 0 {
t.Errorf("bad version")
}
if decDiff.SplitString != "\n" {
t.Errorf("bad split string")
}
out, err := ApplyLineDiff(str1, encDiff0)
if err != nil {
t.Errorf("error in diff: %v", err)
return
}
if out != str2 {
t.Errorf("bad diff output")
}
return true
}
func TestMapDiff(t *testing.T) {
m1 := map[string]string{"a": "5", "b": "hello", "c": "mike"}
m2 := map[string]string{"a": "5", "b": "goodbye", "d": "more"}
m1 := map[string][]byte{"a": []byte("5"), "b": []byte("hello"), "c": []byte("mike")}
m2 := map[string][]byte{"a": []byte("5"), "b": []byte("goodbye"), "d": []byte("more")}
diffBytes := MakeMapDiff(m1, m2)
fmt.Printf("mapdifflen: %d\n", len(diffBytes))
var diff MapDiffType
@ -95,8 +150,38 @@ func TestMapDiff(t *testing.T) {
if err != nil {
t.Fatalf("error applying map diff: %v", err)
}
if !strMapsEqual(m2, mcheck) {
if !utilfn.ByteMapsEqual(m2, mcheck) {
t.Errorf("maps not equal")
}
// try v0
mdiff := makeMapDiff(m1, m2)
diffBytes = mdiff.Encode_v0()
mcheck, err = ApplyMapDiff(m1, diffBytes)
if err != nil {
t.Fatalf("error applying map diff: %v", err)
}
if !utilfn.ByteMapsEqual(m2, mcheck) {
t.Errorf("maps not equal")
}
diffBytes = MakeMapDiff(m1, m1)
if len(diffBytes) != 0 {
t.Errorf("bad diff output (len should be 0)")
}
mcheck, err = ApplyMapDiff(m1, diffBytes)
if err != nil {
t.Fatalf("error applying map diff: %v", err)
}
if !utilfn.ByteMapsEqual(m1, mcheck) {
t.Errorf("maps not equal")
}
fmt.Printf("%v\n", mcheck)
}
func TestVarint(t *testing.T) {
viBuf := make([]byte, 10)
viLen := binary.PutVarint(viBuf, 1)
fmt.Printf("%#v\n", viBuf[0:viLen])
viLen = binary.PutUvarint(viBuf, 1)
fmt.Printf("%#v\n", viBuf[0:viLen])
}

View File

@ -0,0 +1,522 @@
// Copyright 2023, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package utilfn
import (
"bytes"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
"math"
"regexp"
"sort"
"strings"
"unicode/utf8"
)
var HexDigits = []byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'}
func GetStrArr(v interface{}, field string) []string {
if v == nil {
return nil
}
m, ok := v.(map[string]interface{})
if !ok {
return nil
}
fieldVal := m[field]
if fieldVal == nil {
return nil
}
iarr, ok := fieldVal.([]interface{})
if !ok {
return nil
}
var sarr []string
for _, iv := range iarr {
if sv, ok := iv.(string); ok {
sarr = append(sarr, sv)
}
}
return sarr
}
func GetBool(v interface{}, field string) bool {
if v == nil {
return false
}
m, ok := v.(map[string]interface{})
if !ok {
return false
}
fieldVal := m[field]
if fieldVal == nil {
return false
}
bval, ok := fieldVal.(bool)
if !ok {
return false
}
return bval
}
var needsQuoteRe = regexp.MustCompile(`[^\w@%:,./=+-]`)
// minimum maxlen=6
func ShellQuote(val string, forceQuote bool, maxLen int) string {
if maxLen < 6 {
maxLen = 6
}
rtn := val
if needsQuoteRe.MatchString(val) {
rtn = "'" + strings.ReplaceAll(val, "'", `'"'"'`) + "'"
}
if strings.HasPrefix(rtn, "\"") || strings.HasPrefix(rtn, "'") {
if len(rtn) > maxLen {
return rtn[0:maxLen-4] + "..." + rtn[0:1]
}
return rtn
}
if forceQuote {
if len(rtn) > maxLen-2 {
return "\"" + rtn[0:maxLen-5] + "...\""
}
return "\"" + rtn + "\""
} else {
if len(rtn) > maxLen {
return rtn[0:maxLen-3] + "..."
}
return rtn
}
}
func EllipsisStr(s string, maxLen int) string {
if maxLen < 4 {
maxLen = 4
}
if len(s) > maxLen {
return s[0:maxLen-3] + "..."
}
return s
}
func LongestPrefix(root string, strs []string) string {
if len(strs) == 0 {
return root
}
if len(strs) == 1 {
comp := strs[0]
if len(comp) >= len(root) && strings.HasPrefix(comp, root) {
if strings.HasSuffix(comp, "/") {
return strs[0]
}
return strs[0]
}
}
lcp := strs[0]
for i := 1; i < len(strs); i++ {
s := strs[i]
for j := 0; j < len(lcp); j++ {
if j >= len(s) || lcp[j] != s[j] {
lcp = lcp[0:j]
break
}
}
}
if len(lcp) < len(root) || !strings.HasPrefix(lcp, root) {
return root
}
return lcp
}
func ContainsStr(strs []string, test string) bool {
for _, s := range strs {
if s == test {
return true
}
}
return false
}
func IsPrefix(strs []string, test string) bool {
for _, s := range strs {
if len(s) > len(test) && strings.HasPrefix(s, test) {
return true
}
}
return false
}
// sentinel value for StrWithPos.Pos to indicate no position
const NoStrPos = -1
type StrWithPos struct {
Str string `json:"str"`
Pos int `json:"pos"` // this is a 'rune' position (not a byte position)
}
func (sp StrWithPos) String() string {
return strWithCursor(sp.Str, sp.Pos)
}
func ParseToSP(s string) StrWithPos {
idx := strings.Index(s, "[*]")
if idx == -1 {
return StrWithPos{Str: s, Pos: NoStrPos}
}
return StrWithPos{Str: s[0:idx] + s[idx+3:], Pos: utf8.RuneCountInString(s[0:idx])}
}
func strWithCursor(str string, pos int) string {
if pos == NoStrPos {
return str
}
if pos < 0 {
// invalid position
return "[*]_" + str
}
if pos > len(str) {
// invalid position
return str + "_[*]"
}
if pos == len(str) {
return str + "[*]"
}
var rtn []rune
for _, ch := range str {
if len(rtn) == pos {
rtn = append(rtn, '[', '*', ']')
}
rtn = append(rtn, ch)
}
return string(rtn)
}
func (sp StrWithPos) Prepend(str string) StrWithPos {
return StrWithPos{Str: str + sp.Str, Pos: utf8.RuneCountInString(str) + sp.Pos}
}
func (sp StrWithPos) Append(str string) StrWithPos {
return StrWithPos{Str: sp.Str + str, Pos: sp.Pos}
}
// returns base64 hash of data
func Sha1Hash(data []byte) string {
hvalRaw := sha1.Sum(data)
hval := base64.StdEncoding.EncodeToString(hvalRaw[:])
return hval
}
func ChunkSlice[T any](s []T, chunkSize int) [][]T {
var rtn [][]T
for len(rtn) > 0 {
if len(s) <= chunkSize {
rtn = append(rtn, s)
break
}
rtn = append(rtn, s[:chunkSize])
s = s[chunkSize:]
}
return rtn
}
var ErrOverflow = errors.New("integer overflow")
// Add two int values, returning an error if the result overflows.
func AddInt(left, right int) (int, error) {
if right > 0 {
if left > math.MaxInt-right {
return 0, ErrOverflow
}
} else {
if left < math.MinInt-right {
return 0, ErrOverflow
}
}
return left + right, nil
}
// Add a slice of ints, returning an error if the result overflows.
func AddIntSlice(vals ...int) (int, error) {
var rtn int
for _, v := range vals {
var err error
rtn, err = AddInt(rtn, v)
if err != nil {
return 0, err
}
}
return rtn, nil
}
func StrsEqual(s1arr []string, s2arr []string) bool {
if len(s1arr) != len(s2arr) {
return false
}
for i, s1 := range s1arr {
s2 := s2arr[i]
if s1 != s2 {
return false
}
}
return true
}
func StrMapsEqual(m1 map[string]string, m2 map[string]string) bool {
if len(m1) != len(m2) {
return false
}
for key, val1 := range m1 {
val2, found := m2[key]
if !found || val1 != val2 {
return false
}
}
for key := range m2 {
_, found := m1[key]
if !found {
return false
}
}
return true
}
func ByteMapsEqual(m1 map[string][]byte, m2 map[string][]byte) bool {
if len(m1) != len(m2) {
return false
}
for key, val1 := range m1 {
val2, found := m2[key]
if !found || !bytes.Equal(val1, val2) {
return false
}
}
for key := range m2 {
_, found := m1[key]
if !found {
return false
}
}
return true
}
func GetOrderedStringerMapKeys[K interface {
comparable
fmt.Stringer
}, V any](m map[K]V) []K {
keyStrMap := make(map[K]string)
keys := make([]K, 0, len(m))
for key := range m {
keys = append(keys, key)
keyStrMap[key] = key.String()
}
sort.Slice(keys, func(i, j int) bool {
return keyStrMap[keys[i]] < keyStrMap[keys[j]]
})
return keys
}
func GetOrderedMapKeys[V any](m map[string]V) []string {
keys := make([]string, 0, len(m))
for key := range m {
keys = append(keys, key)
}
sort.Strings(keys)
return keys
}
const (
nullEncodeEscByte = '\\'
nullEncodeSepByte = '|'
nullEncodeEqByte = '='
nullEncodeZeroByteEsc = '0'
nullEncodeEscByteEsc = '\\'
nullEncodeSepByteEsc = 's'
nullEncodeEqByteEsc = 'e'
)
func EncodeStringMap(m map[string]string) []byte {
var buf bytes.Buffer
for idx, key := range GetOrderedMapKeys(m) {
val := m[key]
buf.Write(NullEncodeStr(key))
buf.WriteByte(nullEncodeEqByte)
buf.Write(NullEncodeStr(val))
if idx < len(m)-1 {
buf.WriteByte(nullEncodeSepByte)
}
}
return buf.Bytes()
}
func DecodeStringMap(barr []byte) (map[string]string, error) {
if len(barr) == 0 {
return nil, nil
}
var rtn = make(map[string]string)
for _, b := range bytes.Split(barr, []byte{nullEncodeSepByte}) {
keyVal := bytes.SplitN(b, []byte{nullEncodeEqByte}, 2)
if len(keyVal) != 2 {
return nil, fmt.Errorf("invalid null encoding: %s", string(b))
}
key, err := NullDecodeStr(keyVal[0])
if err != nil {
return nil, err
}
val, err := NullDecodeStr(keyVal[1])
if err != nil {
return nil, err
}
rtn[key] = val
}
return rtn, nil
}
func EncodeStringArray(arr []string) []byte {
var buf bytes.Buffer
for idx, s := range arr {
buf.Write(NullEncodeStr(s))
if idx < len(arr)-1 {
buf.WriteByte(nullEncodeSepByte)
}
}
return buf.Bytes()
}
func DecodeStringArray(barr []byte) ([]string, error) {
if len(barr) == 0 {
return nil, nil
}
var rtn []string
for _, b := range bytes.Split(barr, []byte{nullEncodeSepByte}) {
s, err := NullDecodeStr(b)
if err != nil {
return nil, err
}
rtn = append(rtn, s)
}
return rtn, nil
}
func EncodedStringArrayHasFirstKey(encoded []byte, firstKey string) bool {
firstKeyBytes := NullEncodeStr(firstKey)
if !bytes.HasPrefix(encoded, firstKeyBytes) {
return false
}
if len(encoded) == len(firstKeyBytes) || encoded[len(firstKeyBytes)] == nullEncodeSepByte {
return true
}
return false
}
// encodes a string, removing null/zero bytes (and separators '|')
// a zero byte is encoded as "\0", a '\' is encoded as "\\", sep is encoded as "\s"
// allows for easy double splitting (first on \x00, and next on "|")
func NullEncodeStr(s string) []byte {
strBytes := []byte(s)
if bytes.IndexByte(strBytes, 0) == -1 &&
bytes.IndexByte(strBytes, nullEncodeEscByte) == -1 &&
bytes.IndexByte(strBytes, nullEncodeSepByte) == -1 &&
bytes.IndexByte(strBytes, nullEncodeEqByte) == -1 {
return strBytes
}
var rtn []byte
for _, b := range strBytes {
if b == 0 {
rtn = append(rtn, nullEncodeEscByte, nullEncodeZeroByteEsc)
} else if b == nullEncodeEscByte {
rtn = append(rtn, nullEncodeEscByte, nullEncodeEscByteEsc)
} else if b == nullEncodeSepByte {
rtn = append(rtn, nullEncodeEscByte, nullEncodeSepByteEsc)
} else if b == nullEncodeEqByte {
rtn = append(rtn, nullEncodeEscByte, nullEncodeEqByteEsc)
} else {
rtn = append(rtn, b)
}
}
return rtn
}
func NullDecodeStr(barr []byte) (string, error) {
if bytes.IndexByte(barr, nullEncodeEscByte) == -1 {
return string(barr), nil
}
var rtn []byte
for i := 0; i < len(barr); i++ {
curByte := barr[i]
if curByte == nullEncodeEscByte {
i++
nextByte := barr[i]
if nextByte == nullEncodeZeroByteEsc {
rtn = append(rtn, 0)
} else if nextByte == nullEncodeEscByteEsc {
rtn = append(rtn, nullEncodeEscByte)
} else if nextByte == nullEncodeSepByteEsc {
rtn = append(rtn, nullEncodeSepByte)
} else if nextByte == nullEncodeEqByteEsc {
rtn = append(rtn, nullEncodeEqByte)
} else {
// invalid encoding
return "", fmt.Errorf("invalid null encoding: %d", nextByte)
}
} else {
rtn = append(rtn, curByte)
}
}
return string(rtn), nil
}
func SortStringRunes(s string) string {
runes := []rune(s)
sort.Slice(runes, func(i, j int) bool {
return runes[i] < runes[j]
})
return string(runes)
}
// will overwrite m1 with m2's values
func CombineMaps[V any](m1 map[string]V, m2 map[string]V) {
for key, val := range m2 {
m1[key] = val
}
}
// returns hex escaped string (\xNN for each byte)
func ShellHexEscape(s string) string {
var rtn []byte
for _, ch := range []byte(s) {
rtn = append(rtn, []byte(fmt.Sprintf("\\x%02x", ch))...)
}
return string(rtn)
}
func GetMapKeys[K comparable, V any](m map[K]V) []K {
var rtn []K
for key := range m {
rtn = append(rtn, key)
}
return rtn
}
// combines string arrays and removes duplicates (returns a new array)
func CombineStrArrays(sarr1 []string, sarr2 []string) []string {
var rtn []string
m := make(map[string]struct{})
for _, s := range sarr1 {
if _, found := m[s]; found {
continue
}
m[s] = struct{}{}
rtn = append(rtn, s)
}
for _, s := range sarr2 {
if _, found := m[s]; found {
continue
}
m[s] = struct{}{}
rtn = append(rtn, s)
}
return rtn
}

View File

@ -0,0 +1,175 @@
// Copyright 2023, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package utilfn
import (
"fmt"
"math"
"testing"
)
const Str1 = `
hello
line #2
more
stuff
apple
`
const Str2 = `
line #2
apple
grapes
banana
`
const Str3 = `
more
stuff
banana
coconut
`
func testDiff(t *testing.T, str1 string, str2 string) {
diffBytes := MakeDiff(str1, str2)
fmt.Printf("diff-len: %d\n", len(diffBytes))
out, err := ApplyDiff(str1, diffBytes)
if err != nil {
t.Errorf("error in diff: %v", err)
return
}
if out != str2 {
t.Errorf("bad diff output")
}
}
func TestDiff(t *testing.T) {
testDiff(t, Str1, Str2)
testDiff(t, Str2, Str3)
testDiff(t, Str1, Str3)
testDiff(t, Str3, Str1)
}
func testArithmetic(t *testing.T, fn func() (int, error), shouldError bool, expected int) {
retVal, err := fn()
if err != nil {
if !shouldError {
t.Errorf("unexpected error")
}
return
}
if shouldError {
t.Errorf("expected error")
return
}
if retVal != expected {
t.Errorf("wrong return value")
}
}
func testAddInt(t *testing.T, shouldError bool, expected int, a int, b int) {
testArithmetic(t, func() (int, error) { return AddInt(a, b) }, shouldError, expected)
}
func TestAddInt(t *testing.T) {
testAddInt(t, false, 3, 1, 2)
testAddInt(t, true, 0, 1, math.MaxInt)
testAddInt(t, true, 0, math.MinInt, -1)
testAddInt(t, false, math.MaxInt-1, math.MaxInt, -1)
testAddInt(t, false, math.MinInt+1, math.MinInt, 1)
testAddInt(t, false, math.MaxInt, math.MaxInt, 0)
testAddInt(t, true, 0, math.MinInt, -1)
}
func testAddIntSlice(t *testing.T, shouldError bool, expected int, vals ...int) {
testArithmetic(t, func() (int, error) { return AddIntSlice(vals...) }, shouldError, expected)
}
func TestAddIntSlice(t *testing.T) {
testAddIntSlice(t, false, 0)
testAddIntSlice(t, false, 1, 1)
testAddIntSlice(t, false, 3, 1, 2)
testAddIntSlice(t, false, 6, 1, 2, 3)
testAddIntSlice(t, true, 0, 1, math.MaxInt)
testAddIntSlice(t, true, 0, 1, 2, math.MaxInt)
testAddIntSlice(t, true, 0, math.MaxInt, 2, 1)
testAddIntSlice(t, false, math.MaxInt, 0, 0, math.MaxInt)
testAddIntSlice(t, true, 0, math.MinInt, -1)
testAddIntSlice(t, false, math.MaxInt, math.MaxInt-3, 1, 2)
testAddIntSlice(t, true, 0, math.MaxInt-2, 1, 2)
testAddIntSlice(t, false, math.MinInt, math.MinInt+3, -1, -2)
testAddIntSlice(t, true, 0, math.MinInt+2, -1, -2)
}
func testNullEncodeStr(t *testing.T, str string, expected string) {
encoded := NullEncodeStr(str)
decoded, err := NullDecodeStr(encoded)
if err != nil {
t.Errorf("error in null encoding: %v", err)
} else if decoded != str {
t.Errorf("bad null encoding")
}
if string(encoded) != expected {
t.Errorf("bad null encoding, %q != %q", str, expected)
}
}
func TestNullEncodeStr(t *testing.T) {
testNullEncodeStr(t, "", "")
testNullEncodeStr(t, "hello", "hello")
testNullEncodeStr(t, "hello\x00", "hello\\0")
testNullEncodeStr(t, "abc|def", "abc\\sdef")
testNullEncodeStr(t, "a|b\x00c\\d", "a\\sb\\0c\\\\d")
testNullEncodeStr(t, "v==v", "v\\e\\ev")
}
func testEncodeStringArray(t *testing.T, strs []string) {
encoded := EncodeStringArray(strs)
decoded, err := DecodeStringArray(encoded)
if err != nil {
t.Errorf("error in string array encoding: %v", err)
} else if !StrsEqual(strs, decoded) {
t.Errorf("bad string array encoding: %#v != %#v", strs, decoded)
}
}
func TestEncodeStringArray(t *testing.T) {
testEncodeStringArray(t, nil)
testEncodeStringArray(t, []string{})
testEncodeStringArray(t, []string{"hello"})
testEncodeStringArray(t, []string{"hello", "world=bar"})
testEncodeStringArray(t, []string{"hello", "wor\x00ld", "fo|\\o", "N\\\x00|||ul==l"})
}
func testEncodeStringMap(t *testing.T, m map[string]string) {
encoded := EncodeStringMap(m)
decoded, err := DecodeStringMap(encoded)
if err != nil {
t.Errorf("error in string map encoding: %v", err)
} else if !StrMapsEqual(m, decoded) {
t.Errorf("bad string map encoding: %#v != %#v", m, decoded)
}
}
func TestEncodeStringMap(t *testing.T) {
testEncodeStringMap(t, nil)
testEncodeStringMap(t, map[string]string{})
testEncodeStringMap(t, map[string]string{"hello": "world"})
testEncodeStringMap(t, map[string]string{"hello": "world", "foo": "bar"})
testEncodeStringMap(t, map[string]string{"hello": "world", "fo=o": "b=ar", "a|b": "c\\d"})
testEncodeStringMap(t, map[string]string{"hello\x00|": "w\x00orld", "foo": "bar", "a|b": "c\\d", "v==v": "v\\e\\ev"})
}
func testShellHexEscape(t *testing.T, s string, expected string) {
encoded := ShellHexEscape(s)
if encoded != expected {
t.Errorf("bad shell hex encoding, %q != %q", encoded, expected)
}
}
func TestShellHexEscape(t *testing.T) {
testShellHexEscape(t, "", "")
testShellHexEscape(t, "a", `\x61`)
testShellHexEscape(t, "\x00\x01abc\x00", `\x00\x01\x61\x62\x63\x00`)
}

View File

@ -30,6 +30,7 @@ import (
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/server"
"github.com/wavetermdev/waveterm/wavesrv/pkg/cmdrunner"
@ -802,6 +803,7 @@ func doShutdown(reason string) {
func main() {
scbase.BuildTime = BuildTime
base.ProcessType = base.ProcessType_WaveSrv
if len(os.Args) >= 2 && os.Args[1] == "--test" {
log.Printf("running test fn\n")

View File

@ -0,0 +1,2 @@
ALTER TABLE remote_instance DROP COLUMN shelltype;
ALTER TABLE remote DROP COLUMN shellpref;

View File

@ -0,0 +1,2 @@
ALTER TABLE remote_instance ADD COLUMN shelltype varchar(20) NOT NULL DEFAULT 'bash';
ALTER TABLE remote ADD COLUMN shellpref varchar(20) NOT NULL DEFAULT 'detect';

View File

@ -27,7 +27,10 @@ import (
"github.com/kevinburke/ssh_config"
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellutil"
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
"github.com/wavetermdev/waveterm/wavesrv/pkg/comp"
"github.com/wavetermdev/waveterm/wavesrv/pkg/dbutil"
"github.com/wavetermdev/waveterm/wavesrv/pkg/pcloud"
@ -37,7 +40,6 @@ import (
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
"golang.org/x/mod/semver"
)
@ -243,6 +245,7 @@ func init() {
registerCmdFn("chat", OpenAICommand)
registerCmdFn("_killserver", KillServerCommand)
registerCmdFn("_dumpstate", DumpStateCommand)
registerCmdFn("set", SetCommand)
@ -1200,9 +1203,8 @@ type RemoteEditArgs struct {
ConnectMode string
Alias string
AutoInstall bool
SSHPassword string
SSHKeyFile string
Color string
ShellPref string
EditMap map[string]interface{}
}
@ -1286,6 +1288,16 @@ func parseRemoteEditArgs(isNew bool, pk *scpacket.FeCommandPacketType, isLocal b
return nil, fmt.Errorf("invalid alias format")
}
}
var shellPref string
if isNew {
shellPref = sstore.ShellTypePref_Detect
}
if pk.Kwargs["shellpref"] != "" {
shellPref = pk.Kwargs["shellpref"]
}
if shellPref != "" && shellPref != packet.ShellType_bash && shellPref != packet.ShellType_zsh && shellPref != sstore.ShellTypePref_Detect {
return nil, fmt.Errorf("invalid shellpref %q, must be %s", shellPref, formatStrs([]string{packet.ShellType_bash, packet.ShellType_zsh, sstore.ShellTypePref_Detect}, "or", false))
}
var connectMode string
if isNew {
connectMode = sstore.ConnectModeAuto
@ -1340,6 +1352,9 @@ func parseRemoteEditArgs(isNew bool, pk *scpacket.FeCommandPacketType, isLocal b
}
editMap[sstore.RemoteField_SSHPassword] = sshPassword
}
if _, found := pk.Kwargs["shellpref"]; found {
editMap[sstore.RemoteField_ShellPref] = shellPref
}
return &RemoteEditArgs{
SSHOpts: sshOpts,
@ -1347,10 +1362,9 @@ func parseRemoteEditArgs(isNew bool, pk *scpacket.FeCommandPacketType, isLocal b
Alias: alias,
AutoInstall: true,
CanonicalName: canonicalName,
SSHKeyFile: keyFile,
SSHPassword: sshPassword,
Color: color,
EditMap: editMap,
ShellPref: shellPref,
}, nil
}
@ -1375,6 +1389,7 @@ func RemoteNewCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (ss
AutoInstall: editArgs.AutoInstall,
SSHOpts: editArgs.SSHOpts,
SSHConfigSrc: sstore.SSHConfigSrcTypeManual,
ShellPref: editArgs.ShellPref,
}
if editArgs.Color != "" {
r.RemoteOpts = &sstore.RemoteOptsType{Color: editArgs.Color}
@ -1881,7 +1896,7 @@ func crShowCommand(ctx context.Context, pk *scpacket.FeCommandPacketType, ids re
if riBaseMap[remoteId] {
continue
}
feState := msh.GetDefaultFeState()
feState := msh.GetDefaultFeState(msh.GetShellPref())
if feState == nil {
continue
}
@ -2380,7 +2395,7 @@ func makeStaticCmd(ctx context.Context, metaCmd string, ids resolvedIds, cmdStr
CmdStr: cmdStr,
RawCmdStr: cmdStr,
Remote: ids.Remote.RemotePtr,
TermOpts: sstore.TermOpts{Rows: shexec.DefaultTermRows, Cols: shexec.DefaultTermCols, FlexRows: true, MaxPtySize: remote.DefaultMaxPtySize},
TermOpts: sstore.TermOpts{Rows: shellutil.DefaultTermRows, Cols: shellutil.DefaultTermCols, FlexRows: true, MaxPtySize: remote.DefaultMaxPtySize},
Status: sstore.CmdStatusDone,
RunOut: nil,
}
@ -2977,23 +2992,31 @@ func SessionCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (ssto
}
func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.UpdatePacket, error) {
ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen)
ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen|R_Remote)
if err != nil {
return nil, err
}
initPk, err := ids.Remote.MShell.ReInit(ctx)
shellType := ids.Remote.ShellType
if pk.Kwargs["shell"] != "" {
shellArg := pk.Kwargs["shell"]
if shellArg != packet.ShellType_bash && shellArg != packet.ShellType_zsh {
return nil, fmt.Errorf("/reset invalid shell type %q", shellArg)
}
shellType = shellArg
}
ssPk, err := ids.Remote.MShell.ReInit(ctx, shellType)
if err != nil {
return nil, err
}
if initPk == nil || initPk.State == nil {
if ssPk == nil || ssPk.State == nil {
return nil, fmt.Errorf("invalid initpk received from remote (no remote state)")
}
feState := sstore.FeStateFromShellState(initPk.State)
remoteInst, err := sstore.UpdateRemoteState(ctx, ids.SessionId, ids.ScreenId, ids.Remote.RemotePtr, feState, initPk.State, nil)
feState := sstore.FeStateFromShellState(ssPk.State)
remoteInst, err := sstore.UpdateRemoteState(ctx, ids.SessionId, ids.ScreenId, ids.Remote.RemotePtr, feState, ssPk.State, nil)
if err != nil {
return nil, err
}
outputStr := "reset remote state"
outputStr := fmt.Sprintf("reset remote state (shell:%s)", ssPk.State.GetShellType())
cmd, err := makeStaticCmd(ctx, "reset", ids, pk.GetRawStr(), []byte(outputStr))
if err != nil {
// TODO tricky error since the command was a success, but we can't show the output
@ -4199,6 +4222,20 @@ func KillServerCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (s
return nil, nil
}
func DumpStateCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.UpdatePacket, error) {
ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen|R_Remote)
if err != nil {
return nil, err
}
currentState, err := sstore.GetFullState(ctx, *ids.Remote.StatePtr)
if err != nil {
return nil, fmt.Errorf("error getting state: %v", err)
}
feState := sstore.FeStateFromShellState(currentState)
shellenv.DumpVarMapFromState(currentState)
return sstore.InfoMsgUpdate("current connection state sent to log. festate: %s", dbutil.QuickJson(feState)), nil
}
func ClientCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.UpdatePacket, error) {
return nil, fmt.Errorf("/client requires a subcommand: %s", formatStrs([]string{"show", "set"}, "or", false))
}

View File

@ -36,6 +36,7 @@ type ResolvedRemote struct {
MShell *remote.MShellProc
RState remote.RemoteRuntimeState
RemoteCopy *sstore.RemoteType
ShellType string
StatePtr *sstore.ShellStatePtr
FeState map[string]string
}
@ -472,6 +473,7 @@ func ResolveRemoteFromPtr(ctx context.Context, rptr *sstore.RemotePtrType, sessi
RemoteCopy: &rcopy,
StatePtr: nil,
FeState: nil,
ShellType: "",
}
if sessionId != "" && screenId != "" {
ri, err := sstore.GetRemoteInstance(ctx, sessionId, screenId, *rptr)
@ -480,11 +482,13 @@ func ResolveRemoteFromPtr(ctx context.Context, rptr *sstore.RemotePtrType, sessi
// continue with state set to nil
} else {
if ri == nil {
rtn.StatePtr = msh.GetDefaultStatePtr()
rtn.FeState = msh.GetDefaultFeState()
rtn.ShellType = msh.GetShellPref()
rtn.StatePtr = msh.GetDefaultStatePtr(rtn.ShellType)
rtn.FeState = msh.GetDefaultFeState(rtn.ShellType)
} else {
rtn.StatePtr = &sstore.ShellStatePtr{BaseHash: ri.StateBaseHash, DiffHashArr: ri.StateDiffHashArr}
rtn.FeState = ri.FeState
rtn.ShellType = ri.ShellType
}
}
}

View File

@ -9,10 +9,10 @@ import (
"regexp"
"strings"
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellapi"
"github.com/wavetermdev/waveterm/waveshell/pkg/simpleexpand"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
"mvdan.cc/sh/v3/expand"
"mvdan.cc/sh/v3/syntax"
)
@ -141,6 +141,12 @@ func onlyRawArgs(metaCmd string, metaSubCmd string) bool {
return CmdParseOverrides[metaCmd] == CmdParseTypeRaw
}
var waveValidIdentifierRe = regexp.MustCompile("^[a-zA-Z_][a-zA-Z0-9_]*$")
func isValidWaveParamName(name string) bool {
return waveValidIdentifierRe.MatchString(name)
}
func setBracketArgs(argMap map[string]string, bracketStr string) error {
bracketStr = strings.TrimSpace(bracketStr)
if bracketStr == "" {
@ -160,7 +166,7 @@ func setBracketArgs(argMap map[string]string, bracketStr string) error {
varName = litStr[0:eqIdx]
varVal = litStr[eqIdx+1:]
}
if !shexec.IsValidBashIdentifier(varName) {
if !isValidWaveParamName(varName) {
wordErr = fmt.Errorf("invalid identifier %s in bracket args", utilfn.ShellQuote(varName, true, 20))
return false
}
@ -183,14 +189,29 @@ var literalRtnStateCommands = []string{
".",
"source",
"unset",
"unsetopt",
"cd",
"alias",
"unalias",
"deactivate",
"eval",
"asdf",
"sdk",
"nvm",
"virtualenv",
"builtin",
"typeset",
"declare",
"float",
"functions",
"integer",
"local",
"readonly",
"unfunction",
"shopt",
"enable",
"disable",
"function",
}
func getCallExprLitArg(callExpr *syntax.CallExpr, argNum int) string {
@ -235,14 +256,13 @@ func isRtnStateCmd(cmd syntax.Command) bool {
if arg0 != "" && utilfn.ContainsStr(literalRtnStateCommands, arg0) {
return true
}
if arg0 == "git" {
arg1 := getCallExprLitArg(callExpr, 1)
if arg0 == "git" {
if arg1 == "checkout" || arg1 == "switch" {
return true
}
}
if arg0 == "conda" {
arg1 := getCallExprLitArg(callExpr, 1)
if arg1 == "activate" || arg1 == "deactivate" {
return true
}
@ -253,12 +273,30 @@ func isRtnStateCmd(cmd syntax.Command) bool {
return false
}
func checkSimpleRtnStateCmd(cmdStr string) bool {
cmdStr = strings.TrimSpace(cmdStr)
if strings.HasPrefix(cmdStr, "function ") {
return true
}
firstSpace := strings.Index(cmdStr, " ")
if firstSpace != -1 {
firstWord := strings.TrimSpace(cmdStr[:firstSpace])
if strings.HasSuffix(firstWord, "()") {
return true
}
}
return false
}
// detects: export, declare, ., source, X=1, unset
func IsReturnStateCommand(cmdStr string) bool {
cmdReader := strings.NewReader(cmdStr)
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
file, err := parser.Parse(cmdReader, "cmd")
if err != nil {
if checkSimpleRtnStateCmd(cmdStr) {
return true
}
return false
}
for _, stmt := range file.Stmts {
@ -353,7 +391,7 @@ func EvalMetaCommand(ctx context.Context, origPk *scpacket.FeCommandPacketType)
return nil, fmt.Errorf("parsing metacmd, position %v", err)
}
envMap := make(map[string]string) // later we can add vars like session, screen, remote, and user
cfg := shexec.GetParserConfig(envMap)
cfg := shellapi.GetParserConfig(envMap)
// process arguments
for idx, w := range words {
literalVal, err := expand.Literal(cfg, w)

View File

@ -10,6 +10,7 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellutil"
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
"github.com/wavetermdev/waveterm/wavesrv/pkg/remote"
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
@ -87,15 +88,15 @@ func GetUITermOpts(winSize *packet.WinSize, ptermStr string) (*packet.TermOpts,
if err != nil {
return nil, err
}
termOpts := &packet.TermOpts{Rows: shexec.DefaultTermRows, Cols: shexec.DefaultTermCols, Term: remote.DefaultTerm, MaxPtySize: shexec.DefaultMaxPtySize}
termOpts := &packet.TermOpts{Rows: shellutil.DefaultTermRows, Cols: shellutil.DefaultTermCols, Term: remote.DefaultTerm, MaxPtySize: shexec.DefaultMaxPtySize}
if winSize == nil {
winSize = &packet.WinSize{Rows: shexec.DefaultTermRows, Cols: shexec.DefaultTermCols}
winSize = &packet.WinSize{Rows: shellutil.DefaultTermRows, Cols: shellutil.DefaultTermCols}
}
if winSize.Rows == 0 {
winSize.Rows = shexec.DefaultTermRows
winSize.Rows = shellutil.DefaultTermRows
}
if winSize.Cols == 0 {
winSize.Cols = shexec.DefaultTermCols
winSize.Cols = shellutil.DefaultTermCols
}
if opts.Rows == PTermMax {
termOpts.Rows = winSize.Rows

View File

@ -15,9 +15,9 @@ import (
"unicode/utf8"
"github.com/wavetermdev/waveterm/waveshell/pkg/simpleexpand"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
"github.com/wavetermdev/waveterm/wavesrv/pkg/shparse"
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
"mvdan.cc/sh/v3/syntax"
)
@ -482,7 +482,7 @@ func splitCompWord(p *CompPoint) {
w1 := ParsedWord{Offset: w.Offset, Prefix: w.Prefix[:prefixPos]}
w2 := ParsedWord{Offset: w.Offset + prefixPos, Prefix: w.Prefix[prefixPos:], Word: w.Word, PartialWord: w.PartialWord}
p.CompWord = p.CompWord // the same (w1)
// p.CompWord = p.CompWord // the same (w1)
p.CompWordPos = 0 // will be at 0 since w1 has a word length of 0
var newWords []ParsedWord
if p.CompWord > 0 {

View File

@ -8,7 +8,7 @@ import (
"strings"
"testing"
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
)
func parseToSP(s string) utilfn.StrWithPos {

View File

@ -10,8 +10,8 @@ import (
"github.com/google/uuid"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
"github.com/wavetermdev/waveterm/wavesrv/pkg/remote"
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
)
var globalLock = &sync.Mutex{}

View File

@ -12,7 +12,7 @@ import (
"io"
"reflect"
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
ccp "golang.org/x/crypto/chacha20poly1305"
)

View File

@ -27,8 +27,12 @@ import (
"github.com/google/uuid"
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/server"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellapi"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
"github.com/wavetermdev/waveterm/waveshell/pkg/statediff"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
@ -44,6 +48,18 @@ const RemoteTermCols = 80
const PtyReadBufSize = 100
const RemoteConnectTimeout = 15 * time.Second
var envVarsToStrip map[string]bool = map[string]bool{
"PROMPT": true,
"PROMPT_VERSION": true,
"MSHELL": true,
"MSHELL_VERSION": true,
"WAVETERM": true,
"WAVETERM_VERSION": true,
"TERM_PROGRAM": true,
"TERM_PROGRAM_VERSION": true,
"TERM_SESSION_ID": true,
}
// we add this ping packet to the MShellServer Commands in order to deal with spurious SSH output
// basically we guarantee the parser will see a valid packet (either an init error or a ping)
// so we can pass ignoreUntilValid to PacketParser
@ -120,9 +136,9 @@ type MShellProc struct {
PtyBuffer *circbuf.Buffer
MakeClientCancelFn context.CancelFunc
MakeClientDeadline *time.Time
StateMap map[string]*packet.ShellState // sha1->state
CurrentState string // sha1
StateMap *server.ShellStateMap
NumTryConnect int
InitPkShellType string
// install
InstallStatus string
@ -159,32 +175,26 @@ func (msh *MShellProc) GetStatus() string {
return msh.Status
}
func (msh *MShellProc) GetDefaultState() *packet.ShellState {
msh.Lock.Lock()
defer msh.Lock.Unlock()
return msh.StateMap[msh.CurrentState]
func (msh *MShellProc) GetDefaultState(shellType string) *packet.ShellState {
_, state := msh.StateMap.GetCurrentState(shellType)
return state
}
func (msh *MShellProc) GetDefaultStatePtr() *sstore.ShellStatePtr {
func (msh *MShellProc) GetDefaultStatePtr(shellType string) *sstore.ShellStatePtr {
msh.Lock.Lock()
defer msh.Lock.Unlock()
if msh.CurrentState == "" {
hash, _ := msh.StateMap.GetCurrentState(shellType)
if hash == "" {
return nil
}
return &sstore.ShellStatePtr{BaseHash: msh.CurrentState}
return &sstore.ShellStatePtr{BaseHash: hash}
}
func (msh *MShellProc) GetDefaultFeState() map[string]string {
state := msh.GetDefaultState()
func (msh *MShellProc) GetDefaultFeState(shellType string) map[string]string {
state := msh.GetDefaultState(shellType)
return sstore.FeStateFromShellState(state)
}
func (msh *MShellProc) GetStateByHash(hval string) *packet.ShellState {
msh.Lock.Lock()
defer msh.Lock.Unlock()
return msh.StateMap[hval]
}
func (msh *MShellProc) GetRemoteId() string {
msh.Lock.Lock()
defer msh.Lock.Unlock()
@ -494,7 +504,19 @@ func (msh *MShellProc) tryAutoInstall() {
go msh.RunInstall()
}
// if msh.IsConnected() then GetShellPref() should return a valid shell
// if msh is not connected, then InitPkShellType might be empty
func (msh *MShellProc) GetShellPref() string {
msh.Lock.Lock()
defer msh.Lock.Unlock()
if msh.Remote.ShellPref == sstore.ShellTypePref_Detect {
return msh.InitPkShellType
}
return msh.Remote.ShellPref
}
func (msh *MShellProc) GetRemoteRuntimeState() RemoteRuntimeState {
shellPref := msh.GetShellPref()
msh.Lock.Lock()
defer msh.Lock.Unlock()
state := RemoteRuntimeState{
@ -514,6 +536,8 @@ func (msh *MShellProc) GetRemoteRuntimeState() RemoteRuntimeState {
Local: msh.Remote.Local,
NoInitPk: msh.ErrNoInitPk,
AuthType: sstore.RemoteAuthTypeNone,
ShellPref: msh.Remote.ShellPref,
DefaultShellType: shellPref,
}
if msh.Remote.SSHOpts != nil {
state.AuthType = msh.Remote.SSHOpts.GetAuthType()
@ -580,7 +604,7 @@ func (msh *MShellProc) GetRemoteRuntimeState() RemoteRuntimeState {
vars["besthost"] = vars["remotehost"]
vars["bestshorthost"] = vars["remoteshorthost"]
}
curState := msh.StateMap[msh.CurrentState]
_, curState := msh.StateMap.GetCurrentState(shellPref)
if curState != nil {
state.DefaultFeState = sstore.FeStateFromShellState(curState)
vars["cwd"] = curState.Cwd
@ -601,6 +625,7 @@ func (msh *MShellProc) GetRemoteRuntimeState() RemoteRuntimeState {
vars["isroot"] = "1"
}
state.RemoteVars = vars
state.ActiveShells = msh.StateMap.GetShells()
return state
}
@ -622,21 +647,6 @@ func GetAllRemoteRuntimeState() []RemoteRuntimeState {
return rtn
}
func GetDefaultRemoteStateById(remoteId string) (*packet.ShellState, error) {
remote := GetRemoteById(remoteId)
if remote == nil {
return nil, fmt.Errorf("remote not found")
}
if !remote.IsConnected() {
return nil, fmt.Errorf("remote not connected")
}
state := remote.GetDefaultState()
if state == nil {
return nil, fmt.Errorf("could not get default remote state")
}
return state, nil
}
func MakeMShell(r *sstore.RemoteType) *MShellProc {
buf, err := circbuf.NewBuffer(CircBufSize)
if err != nil {
@ -651,7 +661,7 @@ func MakeMShell(r *sstore.RemoteType) *MShellProc {
InstallStatus: StatusDisconnected,
RunningCmds: make(map[base.CommandKey]RunCmdType),
PendingStateCmds: make(map[pendingStateKey]base.CommandKey),
StateMap: make(map[string]*packet.ShellState),
StateMap: server.MakeShellStateMap(),
}
rtn.WriteToPtyBuffer("console for connection [%s]\n", r.GetName())
return rtn
@ -999,11 +1009,16 @@ func (msh *MShellProc) RunInstall() {
msh.WriteToPtyBuffer("*error: cannot install on remote that is already trying to install, cancel current install to try again\n")
return
}
sapi, err := shellapi.MakeShellApi(packet.ShellType_bash)
if err != nil {
msh.WriteToPtyBuffer("*error: %v\n", err)
return
}
msh.WriteToPtyBuffer("installing mshell %s to %s...\n", scbase.MShellVersion, remoteCopy.RemoteCanonicalName)
sshOpts := convertSSHOpts(remoteCopy.SSHOpts)
sshOpts.SSHErrorsToTty = true
cmdStr := shexec.MakeInstallCommandStr()
ecmd := sshOpts.MakeSSHExecCmd(cmdStr)
ecmd := sshOpts.MakeSSHExecCmd(cmdStr, sapi)
cmdPty, err := msh.addControllingTty(ecmd)
if err != nil {
statusErr := fmt.Errorf("cannot attach controlling tty to mshell install command: %w", err)
@ -1084,12 +1099,20 @@ func getStateVarsFromInitPk(initPk *packet.InitPacketType) map[string]string {
rtn["remoteuser"] = initPk.User
rtn["remotehost"] = initPk.HostName
rtn["remoteuname"] = initPk.UName
rtn["shelltype"] = initPk.Shell
return rtn
}
func (msh *MShellProc) ReInit(ctx context.Context) (*packet.InitPacketType, error) {
func (msh *MShellProc) ReInit(ctx context.Context, shellType string) (*packet.ShellStatePacketType, error) {
if !msh.IsConnected() {
return nil, fmt.Errorf("cannot reinit, remote is not connected")
}
if shellType != packet.ShellType_bash && shellType != packet.ShellType_zsh {
return nil, fmt.Errorf("invalid shell type %q", shellType)
}
reinitPk := packet.MakeReInitPacket()
reinitPk.ReqId = uuid.New().String()
reinitPk.ShellType = shellType
resp, err := msh.PacketRpcRaw(ctx, reinitPk)
if err != nil {
return nil, err
@ -1097,21 +1120,26 @@ func (msh *MShellProc) ReInit(ctx context.Context) (*packet.InitPacketType, erro
if resp == nil {
return nil, fmt.Errorf("no response")
}
initPk, ok := resp.(*packet.InitPacketType)
ssPk, ok := resp.(*packet.ShellStatePacketType)
if !ok {
return nil, fmt.Errorf("invalid reinit response (not an initpacket): %T", resp)
if respPk, ok := resp.(*packet.ResponsePacketType); ok && respPk.Error != "" {
return nil, fmt.Errorf("error reinitializing remote: %s", respPk.Error)
}
if initPk.State == nil {
return nil, fmt.Errorf("invalid reinit response initpk does not contain remote state")
return nil, fmt.Errorf("invalid reinit response (not an shellstate packet): %T", resp)
}
hval := initPk.State.GetHashVal(false)
sstore.StoreStateBase(ctx, initPk.State)
msh.WithLock(func() {
msh.CurrentState = hval
msh.StateMap[hval] = initPk.State
})
msh.updateRemoteStateVars(ctx, msh.RemoteId, initPk)
return initPk, nil
if ssPk.State == nil {
return nil, fmt.Errorf("invalid reinit response shellstate packet does not contain remote state")
}
// TODO: maybe we don't need to save statebase here. should be possible to save it on demand
// when it is actually used. complication from other functions that try to get the statebase
// from the DB. probably need to route those through MShellProc.
err = sstore.StoreStateBase(ctx, ssPk.State)
if err != nil {
return nil, fmt.Errorf("error storing remote state: %w", err)
}
msh.StateMap.SetCurrentState(ssPk.State.GetShellType(), ssPk.State)
msh.WriteToPtyBuffer("initialized shell:%s state:%s\n", shellType, ssPk.State.GetHashVal(false))
return ssPk, nil
}
func (msh *MShellProc) StreamFile(ctx context.Context, streamPk *packet.StreamFilePacketType) (*packet.RpcResponseIter, error) {
@ -1123,13 +1151,15 @@ func addScVarsToState(state *packet.ShellState) *packet.ShellState {
return nil
}
rtn := *state
envMap := shexec.DeclMapFromState(&rtn)
envMap["PROMPT"] = &shexec.DeclareDeclType{Name: "PROMPT", Value: "1", Args: "x"}
envMap["PROMPT_VERSION"] = &shexec.DeclareDeclType{Name: "PROMPT_VERSION", Value: scbase.WaveVersion, Args: "x"}
envMap := shellenv.DeclMapFromState(&rtn)
envMap["WAVETERM"] = &shellenv.DeclareDeclType{Name: "WAVETERM", Value: "1", Args: "x"}
envMap["WAVETERM_VERSION"] = &shellenv.DeclareDeclType{Name: "WAVETERM_VERSION", Value: scbase.WaveVersion, Args: "x"}
envMap["TERM_PROGRAM"] = &shellenv.DeclareDeclType{Name: "TERM_PROGRAM", Value: "waveterm", Args: "x"}
envMap["TERM_PROGRAM_VERSION"] = &shellenv.DeclareDeclType{Name: "TERM_PROGRAM_VERSION", Value: scbase.WaveVersion, Args: "x"}
if _, exists := envMap["LANG"]; !exists {
envMap["LANG"] = &shexec.DeclareDeclType{Name: "LANG", Value: scbase.DetermineLang(), Args: "x"}
envMap["LANG"] = &shellenv.DeclareDeclType{Name: "LANG", Value: scbase.DetermineLang(), Args: "x"}
}
rtn.ShellVars = shexec.SerializeDeclMap(envMap)
rtn.ShellVars = shellenv.SerializeDeclMap(envMap)
return &rtn
}
@ -1139,10 +1169,11 @@ func stripScVarsFromState(state *packet.ShellState) *packet.ShellState {
}
rtn := *state
rtn.HashVal = ""
envMap := shexec.DeclMapFromState(&rtn)
delete(envMap, "PROMPT")
delete(envMap, "PROMPT_VERSION")
rtn.ShellVars = shexec.SerializeDeclMap(envMap)
envMap := shellenv.DeclMapFromState(&rtn)
for key := range envVarsToStrip {
delete(envMap, key)
}
rtn.ShellVars = shellenv.SerializeDeclMap(envMap)
return &rtn
}
@ -1158,12 +1189,23 @@ func stripScVarsFromStateDiff(stateDiff *packet.ShellStateDiff) *packet.ShellSta
log.Printf("error decoding statediff in stripScVarsFromStateDiff: %v\n", err)
return stateDiff
}
delete(mapDiff.ToAdd, "PROMPT")
delete(mapDiff.ToAdd, "PROMPT_VERSION")
for key := range envVarsToStrip {
delete(mapDiff.ToAdd, key)
}
rtn.VarsDiff = mapDiff.Encode()
return &rtn
}
func (msh *MShellProc) getActiveShellTypes(ctx context.Context) ([]string, error) {
shellPref := msh.GetShellPref()
rtn := []string{shellPref}
activeShells, err := sstore.GetRemoteActiveShells(ctx, msh.RemoteId)
if err != nil {
return nil, err
}
return utilfn.CombineStrArrays(rtn, activeShells), nil
}
func (msh *MShellProc) Launch(interactive bool) {
remoteCopy := msh.GetRemoteCopy()
if remoteCopy.Archived {
@ -1179,6 +1221,11 @@ func (msh *MShellProc) Launch(interactive bool) {
msh.WriteToPtyBuffer("remote is already connecting, disconnect before trying to connect again\n")
return
}
sapi, err := shellapi.MakeShellApi(msh.GetShellType())
if err != nil {
msh.WriteToPtyBuffer("*error, %v\n", err)
return
}
istatus := msh.GetInstallStatus()
if istatus == StatusConnecting {
msh.WriteToPtyBuffer("remote is trying to install, cancel install before trying to connect again\n")
@ -1205,7 +1252,7 @@ func (msh *MShellProc) Launch(interactive bool) {
} else {
cmdStr = MakeServerCommandStr()
}
ecmd := sshOpts.MakeSSHExecCmd(cmdStr)
ecmd := sshOpts.MakeSSHExecCmd(cmdStr, sapi)
cmdPty, err := msh.addControllingTty(ecmd)
if err != nil {
statusErr := fmt.Errorf("cannot attach controlling tty to mshell command: %w", err)
@ -1237,7 +1284,6 @@ func (msh *MShellProc) Launch(interactive bool) {
cproc, initPk, err := shexec.MakeClientProc(makeClientCtx, ecmd)
// TODO check if initPk.State is not nil
var mshellVersion string
var stateBaseHash string
var hitDeadline bool
msh.WithLock(func() {
msh.MakeClientCancelFn = nil
@ -1255,16 +1301,9 @@ func (msh *MShellProc) Launch(interactive bool) {
// only set NeedsMShellUpgrade if we got an InitPk
msh.NeedsMShellUpgrade = true
}
msh.InitPkShellType = initPk.Shell
}
if initPk != nil && initPk.State != nil {
hval := initPk.State.GetHashVal(false)
msh.CurrentState = hval
msh.StateMap[hval] = initPk.State
sstore.StoreStateBase(context.Background(), initPk.State)
stateBaseHash = hval
} else {
msh.CurrentState = ""
}
msh.StateMap.Clear()
// no notify here, because we'll call notify in either case below
})
if err == context.Canceled {
@ -1290,11 +1329,9 @@ func (msh *MShellProc) Launch(interactive bool) {
return
}
msh.updateRemoteStateVars(context.Background(), msh.RemoteId, initPk)
msh.WriteToPtyBuffer("connected state:%s\n", stateBaseHash)
msh.WithLock(func() {
msh.ServerProc = cproc
msh.Status = StatusConnected
go msh.NotifyRemoteUpdate()
})
go func() {
exitErr := cproc.Cmd.Wait()
@ -1308,15 +1345,40 @@ func (msh *MShellProc) Launch(interactive bool) {
msh.WriteToPtyBuffer("*disconnected exitcode=%d\n", exitCode)
}()
go msh.ProcessPackets()
msh.initActiveShells()
go msh.NotifyRemoteUpdate()
return
}
func (msh *MShellProc) initActiveShells() {
ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFn()
activeShells, err := msh.getActiveShellTypes(ctx)
if err != nil {
// we're not going to fail the connect for this error (it will be unusable, but technically connected)
msh.WriteToPtyBuffer("*error getting active shells: %v\n", err)
return
}
for _, shellType := range activeShells {
_, err = msh.ReInit(ctx, shellType)
if err != nil {
msh.WriteToPtyBuffer("*error reiniting shell %q: %v\n", shellType, err)
}
}
}
func (msh *MShellProc) IsConnected() bool {
msh.Lock.Lock()
defer msh.Lock.Unlock()
return msh.Status == StatusConnected
}
func (msh *MShellProc) GetShellType() string {
msh.Lock.Lock()
defer msh.Lock.Unlock()
return msh.InitPkShellType
}
func replaceHomePath(pathStr string, homeDir string) string {
if homeDir == "" {
return pathStr
@ -1451,13 +1513,13 @@ func RunCommand(ctx context.Context, sessionId string, screenId string, remotePt
// get current remote-instance state
statePtr, err := sstore.GetRemoteStatePtr(ctx, sessionId, screenId, remotePtr)
if err != nil {
return nil, nil, fmt.Errorf("cannot get current remote stateptr: %w", err)
return nil, nil, fmt.Errorf("cannot get current connection stateptr: %w", err)
}
if statePtr == nil {
statePtr = msh.GetDefaultStatePtr()
statePtr = msh.GetDefaultStatePtr(msh.GetShellPref())
}
if statePtr == nil {
return nil, nil, fmt.Errorf("cannot run command, no valid remote stateptr")
return nil, nil, fmt.Errorf("cannot run command, no valid connection stateptr")
}
currentState, err := sstore.GetFullState(ctx, *statePtr)
if err != nil || currentState == nil {
@ -1465,6 +1527,15 @@ func RunCommand(ctx context.Context, sessionId string, screenId string, remotePt
}
runPacket.State = addScVarsToState(currentState)
runPacket.StateComplete = true
runPacket.ShellType = currentState.GetShellType()
// check to see if shellType is initialized
if !msh.StateMap.HasShell(runPacket.ShellType) {
// try to reinit the shell
_, err := msh.ReInit(ctx, runPacket.ShellType)
if err != nil {
return nil, nil, fmt.Errorf("error trying to initialize shell %q: %v", runPacket.ShellType, err)
}
}
msh.ServerProc.Output.RegisterRpc(runPacket.ReqId)
err = shexec.SendRunPacketAndRunData(ctx, msh.ServerProc.Input, runPacket)
if err != nil {
@ -1653,6 +1724,8 @@ func (msh *MShellProc) notifyHangups_nolock() {
}
func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFn()
// this will remove from RunningCmds and from PendingStateCmds
defer msh.RemoveRunningCmd(donePk.CK)
if donePk.FinalState != nil {
@ -1661,12 +1734,12 @@ func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
if donePk.FinalStateDiff != nil {
donePk.FinalStateDiff = stripScVarsFromStateDiff(donePk.FinalStateDiff)
}
update, err := sstore.UpdateCmdDoneInfo(context.Background(), donePk.CK, donePk, sstore.CmdStatusDone)
update, err := sstore.UpdateCmdDoneInfo(ctx, donePk.CK, donePk, sstore.CmdStatusDone)
if err != nil {
msh.WriteToPtyBuffer("*error updating cmddone: %v\n", err)
return
}
screen, err := sstore.UpdateScreenFocusForDoneCmd(context.Background(), donePk.CK.GetGroupId(), donePk.CK.GetCmdId())
screen, err := sstore.UpdateScreenFocusForDoneCmd(ctx, donePk.CK.GetGroupId(), donePk.CK.GetCmdId())
if err != nil {
msh.WriteToPtyBuffer("*error trying to update screen focus type: %v\n", err)
// fall-through (nothing to do)
@ -1678,7 +1751,7 @@ func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
var statePtr *sstore.ShellStatePtr
if donePk.FinalState != nil && rct != nil {
feState := sstore.FeStateFromShellState(donePk.FinalState)
remoteInst, err := sstore.UpdateRemoteState(context.Background(), rct.SessionId, rct.ScreenId, rct.RemotePtr, feState, donePk.FinalState, nil)
remoteInst, err := sstore.UpdateRemoteState(ctx, rct.SessionId, rct.ScreenId, rct.RemotePtr, feState, donePk.FinalState, nil)
if err != nil {
msh.WriteToPtyBuffer("*error trying to update remotestate: %v\n", err)
// fall-through (nothing to do)
@ -1693,7 +1766,12 @@ func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
msh.WriteToPtyBuffer("*error trying to update remotestate: %v\n", err)
// fall-through (nothing to do)
} else {
remoteInst, err := sstore.UpdateRemoteState(context.Background(), rct.SessionId, rct.ScreenId, rct.RemotePtr, feState, nil, donePk.FinalStateDiff)
stateDiff := donePk.FinalStateDiff
fullState := msh.StateMap.GetStateByHash(stateDiff.GetShellType(), stateDiff.BaseHash)
if fullState != nil {
sstore.StoreStateBase(ctx, fullState)
}
remoteInst, err := sstore.UpdateRemoteState(ctx, rct.SessionId, rct.ScreenId, rct.RemotePtr, feState, nil, stateDiff)
if err != nil {
msh.WriteToPtyBuffer("*error trying to update remotestate: %v\n", err)
// fall-through (nothing to do)
@ -1707,7 +1785,7 @@ func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
}
}
if statePtr != nil {
err = sstore.UpdateCmdRtnState(context.Background(), donePk.CK, *statePtr)
err = sstore.UpdateCmdRtnState(ctx, donePk.CK, *statePtr)
if err != nil {
msh.WriteToPtyBuffer("*error trying to update cmd rtnstate: %v\n", err)
// fall-through (nothing to do)
@ -1951,7 +2029,7 @@ func evalPromptEsc(escCode string, vars map[string]string, state *packet.ShellSt
return ""
}
varName := escCode[2 : len(escCode)-1]
varMap := shexec.ShellVarMapFromState(state)
varMap := shellenv.ShellVarMapFromState(state)
return varMap[varName]
}
if escCode == "h" {
@ -2024,43 +2102,56 @@ func evalPromptEsc(escCode string, vars map[string]string, state *packet.ShellSt
return "(" + escCode + ")"
}
func (msh *MShellProc) getFullState(stateDiff *packet.ShellStateDiff) (*packet.ShellState, error) {
baseState := msh.GetStateByHash(stateDiff.BaseHash)
func (msh *MShellProc) getFullState(shellType string, stateDiff *packet.ShellStateDiff) (*packet.ShellState, error) {
baseState := msh.StateMap.GetStateByHash(shellType, stateDiff.BaseHash)
if baseState != nil && len(stateDiff.DiffHashArr) == 0 {
newState, err := shexec.ApplyShellStateDiff(*baseState, *stateDiff)
sapi, err := shellapi.MakeShellApi(baseState.GetShellType())
newState, err := sapi.ApplyShellStateDiff(baseState, stateDiff)
if err != nil {
return nil, err
}
return &newState, nil
return newState, nil
} else {
fullState, err := sstore.GetFullState(context.Background(), sstore.ShellStatePtr{BaseHash: stateDiff.BaseHash, DiffHashArr: stateDiff.DiffHashArr})
if err != nil {
return nil, err
}
newState, err := shexec.ApplyShellStateDiff(*fullState, *stateDiff)
return &newState, nil
sapi, err := shellapi.MakeShellApi(fullState.GetShellType())
if err != nil {
return nil, err
}
newState, err := sapi.ApplyShellStateDiff(fullState, stateDiff)
return newState, nil
}
}
// internal func, first tries the StateMap, otherwise will fallback on sstore.GetFullState
func (msh *MShellProc) getFeStateFromDiff(stateDiff *packet.ShellStateDiff) (map[string]string, error) {
baseState := msh.GetStateByHash(stateDiff.BaseHash)
baseState := msh.StateMap.GetStateByHash(stateDiff.GetShellType(), stateDiff.BaseHash)
if baseState != nil && len(stateDiff.DiffHashArr) == 0 {
newState, err := shexec.ApplyShellStateDiff(*baseState, *stateDiff)
sapi, err := shellapi.MakeShellApi(baseState.GetShellType())
if err != nil {
return nil, err
}
return sstore.FeStateFromShellState(&newState), nil
newState, err := sapi.ApplyShellStateDiff(baseState, stateDiff)
if err != nil {
return nil, err
}
return sstore.FeStateFromShellState(newState), nil
} else {
fullState, err := sstore.GetFullState(context.Background(), sstore.ShellStatePtr{BaseHash: stateDiff.BaseHash, DiffHashArr: stateDiff.DiffHashArr})
if err != nil {
return nil, err
}
newState, err := shexec.ApplyShellStateDiff(*fullState, *stateDiff)
sapi, err := shellapi.MakeShellApi(fullState.GetShellType())
if err != nil {
return nil, err
}
return sstore.FeStateFromShellState(&newState), nil
newState, err := sapi.ApplyShellStateDiff(fullState, stateDiff)
if err != nil {
return nil, err
}
return sstore.FeStateFromShellState(newState), nil
}
}

View File

@ -11,10 +11,11 @@ import (
"github.com/alessio/shellescape"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellapi"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
"github.com/wavetermdev/waveterm/waveshell/pkg/simpleexpand"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
"mvdan.cc/sh/v3/syntax"
)
@ -112,21 +113,89 @@ var IgnoreVars = map[string]bool{
"MSHELL_VERSION": true,
"WAVESHELL": true,
"WAVESHELL_VERSION": true,
"WAVETERM": true,
"WAVETERM_VERSION": true,
"TERM_PROGRAM": true,
"TERM_PROGRAM_VERSION": true,
"TERM_SESSION_ID": true,
}
func displayStateUpdateDiff(buf *bytes.Buffer, oldState packet.ShellState, newState packet.ShellState) {
func makeBashAliasesDiff(buf *bytes.Buffer, oldAliases string, newAliases string) {
newAliasMap, _ := ParseAliases(newAliases)
oldAliasMap, _ := ParseAliases(oldAliases)
for aliasName, newAliasVal := range newAliasMap {
oldAliasVal, found := oldAliasMap[aliasName]
if !found || newAliasVal != oldAliasVal {
buf.WriteString(fmt.Sprintf("alias %s\n", utilfn.EllipsisStr(shellescape.Quote(aliasName), MaxDiffKeyLen)))
}
}
for aliasName := range oldAliasMap {
_, found := newAliasMap[aliasName]
if !found {
buf.WriteString(fmt.Sprintf("unalias %s\n", utilfn.EllipsisStr(shellescape.Quote(aliasName), MaxDiffKeyLen)))
}
}
}
func makeZshAlisesDiff(buf *bytes.Buffer, oldAliases string, newAliases string) {
newAliasMap, err := shellapi.DecodeZshMap([]byte(newAliases))
if err != nil {
return
}
oldAliasMap, err := shellapi.DecodeZshMap([]byte(oldAliases))
if err != nil {
return
}
for aliasKey, newAliasVal := range newAliasMap {
oldAliasVal, found := oldAliasMap[aliasKey]
if !found || newAliasVal != oldAliasVal {
buf.WriteString(fmt.Sprintf("%s %s=%s\n", aliasKey.ParamType, aliasKey.ParamName, utilfn.EllipsisStr(shellescape.Quote(newAliasVal), MaxDiffKeyLen)))
}
}
for aliasKey := range oldAliasMap {
_, found := newAliasMap[aliasKey]
if !found {
buf.WriteString(fmt.Sprintf("remove %s %s\n", aliasKey.ParamType, aliasKey.ParamName))
}
}
}
func makeZshFuncsDiff(buf *bytes.Buffer, oldFuncs string, newFuncs string) {
newFuncMap, err := shellapi.DecodeZshMap([]byte(newFuncs))
if err != nil {
return
}
oldFuncMap, err := shellapi.DecodeZshMap([]byte(oldFuncs))
if err != nil {
return
}
for funcKey, newFuncVal := range newFuncMap {
oldFuncVal, found := oldFuncMap[funcKey]
if !found || newFuncVal != oldFuncVal {
buf.WriteString(fmt.Sprintf("%s %s\n", funcKey.ParamType, funcKey.ParamName))
}
}
for funcKey := range oldFuncMap {
_, found := newFuncMap[funcKey]
if !found {
buf.WriteString(fmt.Sprintf("remove %s %s\n", funcKey.ParamType, funcKey.ParamName))
}
}
}
func DisplayStateUpdateDiff(buf *bytes.Buffer, oldState packet.ShellState, newState packet.ShellState) {
if newState.Cwd != oldState.Cwd {
buf.WriteString(fmt.Sprintf("cwd %s\n", newState.Cwd))
}
if !bytes.Equal(newState.ShellVars, oldState.ShellVars) {
newEnvMap := shexec.DeclMapFromState(&newState)
oldEnvMap := shexec.DeclMapFromState(&oldState)
newEnvMap := shellenv.DeclMapFromState(&newState)
oldEnvMap := shellenv.DeclMapFromState(&oldState)
for key, newVal := range newEnvMap {
if IgnoreVars[key] {
continue
}
oldVal, found := oldEnvMap[key]
if !found || !shexec.DeclsEqual(false, oldVal, newVal) {
if !found || !shellenv.DeclsEqual(false, oldVal, newVal) {
var exportStr string
if newVal.IsExport() {
exportStr = "export "
@ -134,7 +203,7 @@ func displayStateUpdateDiff(buf *bytes.Buffer, oldState packet.ShellState, newSt
buf.WriteString(fmt.Sprintf("%s%s=%s\n", exportStr, utilfn.EllipsisStr(key, MaxDiffKeyLen), utilfn.EllipsisStr(newVal.Value, MaxDiffValLen)))
}
}
for key, _ := range oldEnvMap {
for key := range oldEnvMap {
if IgnoreVars[key] {
continue
}
@ -144,23 +213,19 @@ func displayStateUpdateDiff(buf *bytes.Buffer, oldState packet.ShellState, newSt
}
}
}
if newState.Aliases != oldState.Aliases {
newAliasMap, _ := ParseAliases(newState.Aliases)
oldAliasMap, _ := ParseAliases(oldState.Aliases)
for aliasName, newAliasVal := range newAliasMap {
oldAliasVal, found := oldAliasMap[aliasName]
if !found || newAliasVal != oldAliasVal {
buf.WriteString(fmt.Sprintf("alias %s\n", utilfn.EllipsisStr(shellescape.Quote(aliasName), MaxDiffKeyLen)))
if newState.GetShellType() == packet.ShellType_zsh {
makeZshAlisesDiff(buf, oldState.Aliases, newState.Aliases)
makeZshFuncsDiff(buf, oldState.Funcs, newState.Funcs)
} else {
makeBashAliasesDiff(buf, oldState.Aliases, newState.Aliases)
makeBashFuncsDiff(newState, oldState, buf)
}
}
func makeBashFuncsDiff(newState packet.ShellState, oldState packet.ShellState, buf *bytes.Buffer) {
if newState.Funcs == oldState.Funcs {
return
}
for aliasName, _ := range oldAliasMap {
_, found := newAliasMap[aliasName]
if !found {
buf.WriteString(fmt.Sprintf("unalias %s\n", utilfn.EllipsisStr(shellescape.Quote(aliasName), MaxDiffKeyLen)))
}
}
}
if newState.Funcs != oldState.Funcs {
newFuncMap, _ := ParseFuncs(newState.Funcs)
oldFuncMap, _ := ParseFuncs(oldState.Funcs)
for funcName, newFuncVal := range newFuncMap {
@ -169,13 +234,12 @@ func displayStateUpdateDiff(buf *bytes.Buffer, oldState packet.ShellState, newSt
buf.WriteString(fmt.Sprintf("function %s\n", utilfn.EllipsisStr(shellescape.Quote(funcName), MaxDiffKeyLen)))
}
}
for funcName, _ := range oldFuncMap {
for funcName := range oldFuncMap {
_, found := newFuncMap[funcName]
if !found {
buf.WriteString(fmt.Sprintf("unset -f %s\n", utilfn.EllipsisStr(shellescape.Quote(funcName), MaxDiffKeyLen)))
}
}
}
}
func GetRtnStateDiff(ctx context.Context, screenId string, lineId string) ([]byte, error) {
@ -201,6 +265,6 @@ func GetRtnStateDiff(ctx context.Context, screenId string, lineId string) ([]byt
if err != nil {
return nil, fmt.Errorf("getting rtn full state: %v", err)
}
displayStateUpdateDiff(&outputBytes, *initialState, *rtnState)
DisplayStateUpdateDiff(&outputBytes, *initialState, *rtnState)
return outputBytes.Bytes(), nil
}

View File

@ -37,7 +37,7 @@ const WaveDevDirName = ".waveterm-dev" // must match emain.ts
const WaveAppPathVarName = "WAVETERM_APP_PATH"
const WaveVersion = "v0.5.3"
const WaveAuthKeyFileName = "waveterm.authkey"
const MShellVersion = "v0.3.0"
const MShellVersion = "v0.4.0"
var SessionDirCache = make(map[string]string)
var ScreenDirCache = make(map[string]string)

View File

@ -11,8 +11,8 @@ import (
"github.com/alessio/shellescape"
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
)
const FeCommandPacketStr = "fecmd"

View File

@ -6,7 +6,7 @@ package shparse
import (
"strings"
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
)
const (

View File

@ -8,7 +8,7 @@ import (
"unicode"
"unicode/utf8"
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
)
var noEscChars []bool

View File

@ -7,7 +7,7 @@ import (
"bytes"
"fmt"
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
)
//

View File

@ -7,7 +7,7 @@ import (
"fmt"
"testing"
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
)
// $(ls f[*]); ./x

View File

@ -18,7 +18,8 @@ import (
"github.com/sawka/txwrap"
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellapi"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
"github.com/wavetermdev/waveterm/wavesrv/pkg/dbutil"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
)
@ -217,8 +218,8 @@ func UpsertRemote(ctx context.Context, r *RemoteType) error {
maxRemoteIdx := tx.GetInt(query)
r.RemoteIdx = int64(maxRemoteIdx + 1)
query = `INSERT INTO remote
( remoteid, remotetype, remotealias, remotecanonicalname, remoteuser, remotehost, connectmode, autoinstall, sshopts, remoteopts, lastconnectts, archived, remoteidx, local, statevars, sshconfigsrc, openaiopts) VALUES
(:remoteid,:remotetype,:remotealias,:remotecanonicalname,:remoteuser,:remotehost,:connectmode,:autoinstall,:sshopts,:remoteopts,:lastconnectts,:archived,:remoteidx,:local,:statevars,:sshconfigsrc,:openaiopts)`
( remoteid, remotetype, remotealias, remotecanonicalname, remoteuser, remotehost, connectmode, autoinstall, sshopts, remoteopts, lastconnectts, archived, remoteidx, local, statevars, sshconfigsrc, openaiopts, shellpref) VALUES
(:remoteid,:remotetype,:remotealias,:remotecanonicalname,:remoteuser,:remotehost,:connectmode,:autoinstall,:sshopts,:remoteopts,:lastconnectts,:archived,:remoteidx,:local,:statevars,:sshconfigsrc,:openaiopts,:shellpref)`
tx.NamedExec(query, r.ToMap())
return nil
})
@ -1288,11 +1289,12 @@ func GetRemoteInstance(ctx context.Context, sessionId string, screenId string, r
return ri, nil
}
// internal function for UpdateRemoteState
// internal function for UpdateRemoteState (sets StateBaseHash, StateDiffHashArr, and ShellType)
func updateRIWithState(ctx context.Context, ri *RemoteInstance, stateBase *packet.ShellState, stateDiff *packet.ShellStateDiff) error {
if stateBase != nil {
ri.StateBaseHash = stateBase.GetHashVal(false)
ri.StateDiffHashArr = nil
ri.ShellType = stateBase.GetShellType()
err := StoreStateBase(ctx, stateBase)
if err != nil {
return err
@ -1300,6 +1302,7 @@ func updateRIWithState(ctx context.Context, ri *RemoteInstance, stateBase *packe
} else if stateDiff != nil {
ri.StateBaseHash = stateDiff.BaseHash
ri.StateDiffHashArr = append(stateDiff.DiffHashArr, stateDiff.GetHashVal(false))
ri.ShellType = stateDiff.GetShellType()
err := StoreStateDiff(ctx, stateDiff)
if err != nil {
return err
@ -1340,18 +1343,18 @@ func UpdateRemoteState(ctx context.Context, sessionId string, screenId string, r
if err != nil {
return err
}
query = `INSERT INTO remote_instance ( riid, name, sessionid, screenid, remoteownerid, remoteid, festate, statebasehash, statediffhasharr)
VALUES (:riid,:name,:sessionid,:screenid,:remoteownerid,:remoteid,:festate,:statebasehash,:statediffhasharr)`
query = `INSERT INTO remote_instance ( riid, name, sessionid, screenid, remoteownerid, remoteid, festate, statebasehash, statediffhasharr, shelltype)
VALUES (:riid,:name,:sessionid,:screenid,:remoteownerid,:remoteid,:festate,:statebasehash,:statediffhasharr,:shelltype)`
tx.NamedExec(query, ri.ToMap())
return nil
} else {
query = `UPDATE remote_instance SET festate = ?, statebasehash = ?, statediffhasharr = ? WHERE riid = ?`
query = `UPDATE remote_instance SET festate = ?, statebasehash = ?, statediffhasharr = ?, shelltype = ? WHERE riid = ?`
ri.FeState = feState
err = updateRIWithState(tx.Context(), ri, stateBase, stateDiff)
if err != nil {
return err
}
tx.Exec(query, quickJson(ri.FeState), ri.StateBaseHash, quickJsonArr(ri.StateDiffHashArr), ri.RIId)
tx.Exec(query, quickJson(ri.FeState), ri.StateBaseHash, quickJsonArr(ri.StateDiffHashArr), ri.ShellType, ri.RIId)
return nil
}
})
@ -1730,9 +1733,11 @@ const (
RemoteField_SSHKey = "sshkey" // string
RemoteField_SSHPassword = "sshpassword" // string
RemoteField_Color = "color" // string
RemoteField_ShellPref = "shellpref" // string
)
// editMap: alias, connectmode, autoinstall, sshkey, color, sshpassword (from constants)
// note that all validation should have already happened outside of this function
func UpdateRemote(ctx context.Context, remoteId string, editMap map[string]interface{}) (*RemoteType, error) {
var rtn *RemoteType
txErr := WithTx(ctx, func(tx *TxWrap) error {
@ -1760,6 +1765,10 @@ func UpdateRemote(ctx context.Context, remoteId string, editMap map[string]inter
query = `UPDATE remote SET sshopts = json_set(sshopts, '$.sshpassword', ?) WHERE remoteid = ?`
tx.Exec(query, sshPassword, remoteId)
}
if shellPref, found := editMap[RemoteField_ShellPref]; found {
query = `UPDATE remote SET shellpref = ? WHERE remoteid = ?`
tx.Exec(query, shellPref, remoteId)
}
if color, found := editMap[RemoteField_Color]; found {
query = `UPDATE remote SET remoteopts = json_set(remoteopts, '$.color', ?) WHERE remoteid = ?`
tx.Exec(query, color, remoteId)
@ -1958,22 +1967,26 @@ func GetFullState(ctx context.Context, ssPtr ShellStatePtr) (*packet.ShellState,
if err != nil {
return err
}
sapi, err := shellapi.MakeShellApi(state.GetShellType())
if err != nil {
return err
}
for idx, diffHash := range ssPtr.DiffHashArr {
query = `SELECT * FROM state_diff WHERE diffhash = ?`
stateDiff := dbutil.GetMapGen[*StateDiff](tx, query, diffHash)
if stateDiff == nil {
return fmt.Errorf("ShellStateDiff %s not found", diffHash)
}
var ssDiff packet.ShellStateDiff
ssDiff := &packet.ShellStateDiff{}
err = ssDiff.DecodeShellStateDiff(stateDiff.Data)
if err != nil {
return err
}
newState, err := shexec.ApplyShellStateDiff(*state, ssDiff)
newState, err := sapi.ApplyShellStateDiff(state, ssDiff)
if err != nil {
return fmt.Errorf("GetFullState, diff[%d]:%s: %v", idx, diffHash, err)
}
state = &newState
state = newState
}
return nil
})
@ -2750,3 +2763,15 @@ func SetWebPtyPos(ctx context.Context, screenId string, lineId string, ptyPos in
return nil
})
}
func GetRemoteActiveShells(ctx context.Context, remoteId string) ([]string, error) {
return WithTxRtn(ctx, func(tx *TxWrap) ([]string, error) {
query := `SELECT * FROM remote_instance WHERE remoteid = ?`
riArr := dbutil.SelectMapsGen[*RemoteInstance](tx, query, remoteId)
shellTypeMap := make(map[string]bool)
for _, ri := range riArr {
shellTypeMap[ri.ShellType] = true
}
return utilfn.GetMapKeys(shellTypeMap), nil
})
}

View File

@ -10,7 +10,7 @@ import (
"sync"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
)
// global lock for all memory operations

View File

@ -22,10 +22,11 @@ import (
"github.com/golang-migrate/migrate/v4"
)
const MaxMigration = 29
const MaxMigration = 30
const MigratePrimaryScreenVersion = 9
const CmdScreenSpecialMigration = 13
const CmdLineSpecialMigration = 20
const RISpecialMigration = 30
func MakeMigrate() (*migrate.Migrate, error) {
fsVar, err := iofs.New(sh2db.MigrationFS, "migrations")
@ -84,6 +85,12 @@ func MigrateUpStep(m *migrate.Migrate, newVersion uint) error {
return fmt.Errorf("migrating to v%d: %w", newVersion, mErr)
}
}
if newVersion == RISpecialMigration {
mErr := RunMigration30()
if mErr != nil {
return fmt.Errorf("migrating to v%d: %w", newVersion, mErr)
}
}
log.Printf("[db] migration v%d, elapsed %v\n", newVersion, time.Since(startTime))
return nil
}

View File

@ -25,7 +25,7 @@ import (
"github.com/sawka/txwrap"
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
"github.com/wavetermdev/waveterm/wavesrv/pkg/dbutil"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
@ -47,6 +47,10 @@ const LocalRemoteAlias = "local"
const DefaultCwd = "~"
const APITokenSentinel = "--apitoken--"
// defined here and not in packet.go since this value should never
// be passed to waveshell (it should always get resolved prior to sending a run packet)
const ShellTypePref_Detect = "detect"
const (
LineTypeCmd = "cmd"
LineTypeText = "text"
@ -694,6 +698,7 @@ type RemoteInstance struct {
RemoteOwnerId string `json:"remoteownerid"`
RemoteId string `json:"remoteid"`
FeState map[string]string `json:"festate"`
ShellType string `json:"shelltype"`
StateBaseHash string `json:"-"`
StateDiffHashArr []string `json:"-"`
@ -741,15 +746,19 @@ func FeStateFromShellState(state *packet.ShellState) map[string]string {
}
rtn := make(map[string]string)
rtn["cwd"] = state.Cwd
envMap := shexec.EnvMapFromState(state)
envMap := shellenv.EnvMapFromState(state)
if envMap["VIRTUAL_ENV"] != "" {
rtn["VIRTUAL_ENV"] = envMap["VIRTUAL_ENV"]
}
for key, val := range envMap {
if strings.HasPrefix(key, "PROMPTVAR_") && rtn[key] != "" {
if strings.HasPrefix(key, "PROMPTVAR_") && envMap[key] != "" {
rtn[key] = val
}
}
_, _, err := packet.ParseShellStateVersion(state.Version)
if err != nil {
rtn["invalidstate"] = "1"
}
return rtn
}
@ -763,6 +772,7 @@ func (ri *RemoteInstance) FromMap(m map[string]interface{}) bool {
quickSetJson(&ri.FeState, m, "festate")
quickSetStr(&ri.StateBaseHash, m, "statebasehash")
quickSetJsonArr(&ri.StateDiffHashArr, m, "statediffhasharr")
quickSetStr(&ri.ShellType, m, "shelltype")
return true
}
@ -777,6 +787,7 @@ func (ri *RemoteInstance) ToMap() map[string]interface{} {
rtn["festate"] = quickJson(ri.FeState)
rtn["statebasehash"] = ri.StateBaseHash
rtn["statediffhasharr"] = quickJsonArr(ri.StateDiffHashArr)
rtn["shelltype"] = ri.ShellType
return rtn
}
@ -1014,6 +1025,9 @@ type RemoteRuntimeState struct {
Local bool `json:"local,omitempty"`
RemoteOpts *RemoteOptsType `json:"remoteopts,omitempty"`
CanComplete bool `json:"cancomplete,omitempty"`
ActiveShells []string `json:"activeshells,omitempty"`
ShellPref string `json:"shellpref,omitempty"`
DefaultShellType string `json:"defaultshelltype,omitempty"`
}
func (state RemoteRuntimeState) IsConnected() bool {
@ -1068,8 +1082,9 @@ type RemoteType struct {
SSHOpts *SSHOpts `json:"sshopts"`
StateVars map[string]string `json:"statevars"`
SSHConfigSrc string `json:"sshconfigsrc"`
ShellPref string `json:"shellpref"` // bash, zsh, or detect
// OpenAI fields
// OpenAI fields (unused)
OpenAIOpts *OpenAIOptsType `json:"openaiopts,omitempty"`
}
@ -1125,6 +1140,7 @@ func (r *RemoteType) ToMap() map[string]interface{} {
rtn["statevars"] = quickJson(r.StateVars)
rtn["sshconfigsrc"] = r.SSHConfigSrc
rtn["openaiopts"] = quickJson(r.OpenAIOpts)
rtn["shellpref"] = r.ShellPref
return rtn
}
@ -1146,6 +1162,7 @@ func (r *RemoteType) FromMap(m map[string]interface{}) bool {
quickSetJson(&r.StateVars, m, "statevars")
quickSetStr(&r.SSHConfigSrc, m, "sshconfigsrc")
quickSetJson(&r.OpenAIOpts, m, "openaiopts")
quickSetStr(&r.ShellPref, m, "shellpref")
return true
}
@ -1308,6 +1325,7 @@ func EnsureLocalRemote(ctx context.Context) error {
SSHOpts: &SSHOpts{Local: true},
Local: true,
SSHConfigSrc: SSHConfigSrcTypeManual,
ShellPref: ShellTypePref_Detect,
}
err = UpsertRemote(ctx, localRemote)
if err != nil {
@ -1327,6 +1345,7 @@ func EnsureLocalRemote(ctx context.Context) error {
RemoteOpts: &RemoteOptsType{Color: "red"},
Local: true,
SSHConfigSrc: SSHConfigSrcTypeManual,
ShellPref: ShellTypePref_Detect,
}
err = UpsertRemote(ctx, sudoRemote)
if err != nil {

View File

@ -10,6 +10,7 @@ import (
"os"
"time"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
)
@ -34,6 +35,36 @@ func getSliceChunk[T any](slice []T, chunkSize int) ([]T, []T) {
return slice[0:chunkSize], slice[chunkSize:]
}
// we're going to mark any invalid basestate versions as "invalid"
// so we can give a better error message for the FE and prompt a reset
func RunMigration30() error {
ctx := context.Background()
startTime := time.Now()
updateCount := 0
txErr := WithTx(ctx, func(tx *TxWrap) error {
query := `SELECT riid FROM remote_instance`
riidArr := tx.SelectStrings(query)
for _, riid := range riidArr {
query = `SELECT version FROM state_base WHERE basehash = (SELECT statebasehash FROM remote_instance WHERE riid = ?)`
version := tx.GetString(query, riid)
_, _, err := packet.ParseShellStateVersion(version)
if err == nil {
continue
}
// deal with bad versions by marking festate with an invalidshellstate flag
query = `UPDATE remote_instance SET festate = json_set(festate, '$.invalidshellstate', '1') WHERE riid = ?`
tx.Exec(query, riid)
updateCount++
}
return nil
})
if txErr != nil {
return fmt.Errorf("error running remote-instance v30 migration: %w", txErr)
}
log.Printf("[db] remote-instance v30 migration done: %v (%d bad versions)\n", time.Since(startTime), updateCount)
return nil
}
func RunMigration20() error {
ctx := context.Background()
startTime := time.Now()

View File

@ -9,7 +9,7 @@ import (
"sync"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
)
var MainBus *UpdateBus = MakeUpdateBus()

View File

@ -1,249 +0,0 @@
// Copyright 2023, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package utilfn
import (
"crypto/sha1"
"encoding/base64"
"errors"
"math"
"regexp"
"strings"
"unicode/utf8"
)
var HexDigits = []byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'}
func GetStrArr(v interface{}, field string) []string {
if v == nil {
return nil
}
m, ok := v.(map[string]interface{})
if !ok {
return nil
}
fieldVal := m[field]
if fieldVal == nil {
return nil
}
iarr, ok := fieldVal.([]interface{})
if !ok {
return nil
}
var sarr []string
for _, iv := range iarr {
if sv, ok := iv.(string); ok {
sarr = append(sarr, sv)
}
}
return sarr
}
func GetBool(v interface{}, field string) bool {
if v == nil {
return false
}
m, ok := v.(map[string]interface{})
if !ok {
return false
}
fieldVal := m[field]
if fieldVal == nil {
return false
}
bval, ok := fieldVal.(bool)
if !ok {
return false
}
return bval
}
var needsQuoteRe = regexp.MustCompile(`[^\w@%:,./=+-]`)
// minimum maxlen=6
func ShellQuote(val string, forceQuote bool, maxLen int) string {
if maxLen < 6 {
maxLen = 6
}
rtn := val
if needsQuoteRe.MatchString(val) {
rtn = "'" + strings.ReplaceAll(val, "'", `'"'"'`) + "'"
}
if strings.HasPrefix(rtn, "\"") || strings.HasPrefix(rtn, "'") {
if len(rtn) > maxLen {
return rtn[0:maxLen-4] + "..." + rtn[0:1]
}
return rtn
}
if forceQuote {
if len(rtn) > maxLen-2 {
return "\"" + rtn[0:maxLen-5] + "...\""
}
return "\"" + rtn + "\""
} else {
if len(rtn) > maxLen {
return rtn[0:maxLen-3] + "..."
}
return rtn
}
}
func EllipsisStr(s string, maxLen int) string {
if maxLen < 4 {
maxLen = 4
}
if len(s) > maxLen {
return s[0:maxLen-3] + "..."
}
return s
}
func LongestPrefix(root string, strs []string) string {
if len(strs) == 0 {
return root
}
if len(strs) == 1 {
comp := strs[0]
if len(comp) >= len(root) && strings.HasPrefix(comp, root) {
if strings.HasSuffix(comp, "/") {
return strs[0]
}
return strs[0]
}
}
lcp := strs[0]
for i := 1; i < len(strs); i++ {
s := strs[i]
for j := 0; j < len(lcp); j++ {
if j >= len(s) || lcp[j] != s[j] {
lcp = lcp[0:j]
break
}
}
}
if len(lcp) < len(root) || !strings.HasPrefix(lcp, root) {
return root
}
return lcp
}
func ContainsStr(strs []string, test string) bool {
for _, s := range strs {
if s == test {
return true
}
}
return false
}
func IsPrefix(strs []string, test string) bool {
for _, s := range strs {
if len(s) > len(test) && strings.HasPrefix(s, test) {
return true
}
}
return false
}
// sentinel value for StrWithPos.Pos to indicate no position
const NoStrPos = -1
type StrWithPos struct {
Str string `json:"str"`
Pos int `json:"pos"` // this is a 'rune' position (not a byte position)
}
func (sp StrWithPos) String() string {
return strWithCursor(sp.Str, sp.Pos)
}
func ParseToSP(s string) StrWithPos {
idx := strings.Index(s, "[*]")
if idx == -1 {
return StrWithPos{Str: s, Pos: NoStrPos}
}
return StrWithPos{Str: s[0:idx] + s[idx+3:], Pos: utf8.RuneCountInString(s[0:idx])}
}
func strWithCursor(str string, pos int) string {
if pos == NoStrPos {
return str
}
if pos < 0 {
// invalid position
return "[*]_" + str
}
if pos > len(str) {
// invalid position
return str + "_[*]"
}
if pos == len(str) {
return str + "[*]"
}
var rtn []rune
for _, ch := range str {
if len(rtn) == pos {
rtn = append(rtn, '[', '*', ']')
}
rtn = append(rtn, ch)
}
return string(rtn)
}
func (sp StrWithPos) Prepend(str string) StrWithPos {
return StrWithPos{Str: str + sp.Str, Pos: utf8.RuneCountInString(str) + sp.Pos}
}
func (sp StrWithPos) Append(str string) StrWithPos {
return StrWithPos{Str: sp.Str + str, Pos: sp.Pos}
}
// returns base64 hash of data
func Sha1Hash(data []byte) string {
hvalRaw := sha1.Sum(data)
hval := base64.StdEncoding.EncodeToString(hvalRaw[:])
return hval
}
func ChunkSlice[T any](s []T, chunkSize int) [][]T {
var rtn [][]T
for len(rtn) > 0 {
if len(s) <= chunkSize {
rtn = append(rtn, s)
break
}
rtn = append(rtn, s[:chunkSize])
s = s[chunkSize:]
}
return rtn
}
var ErrOverflow = errors.New("integer overflow")
// Add two int values, returning an error if the result overflows.
func AddInt(left, right int) (int, error) {
if right > 0 {
if left > math.MaxInt-right {
return 0, ErrOverflow
}
} else {
if left < math.MinInt-right {
return 0, ErrOverflow
}
}
return left + right, nil
}
// Add a slice of ints, returning an error if the result overflows.
func AddIntSlice(vals ...int) (int, error) {
var rtn int
for _, v := range vals {
var err error
rtn, err = AddInt(rtn, v)
if err != nil {
return 0, err
}
}
return rtn, nil
}

View File

@ -1,103 +0,0 @@
// Copyright 2023, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package utilfn
import (
"fmt"
"math"
"testing"
)
const Str1 = `
hello
line #2
more
stuff
apple
`
const Str2 = `
line #2
apple
grapes
banana
`
const Str3 = `
more
stuff
banana
coconut
`
func testDiff(t *testing.T, str1 string, str2 string) {
diffBytes := MakeDiff(str1, str2)
fmt.Printf("diff-len: %d\n", len(diffBytes))
out, err := ApplyDiff(str1, diffBytes)
if err != nil {
t.Errorf("error in diff: %v", err)
return
}
if out != str2 {
t.Errorf("bad diff output")
}
}
func TestDiff(t *testing.T) {
testDiff(t, Str1, Str2)
testDiff(t, Str2, Str3)
testDiff(t, Str1, Str3)
testDiff(t, Str3, Str1)
}
func testArithmetic(t *testing.T, fn func() (int, error), shouldError bool, expected int) {
retVal, err := fn()
if err != nil {
if !shouldError {
t.Errorf("unexpected error")
}
return
}
if shouldError {
t.Errorf("expected error")
return
}
if retVal != expected {
t.Errorf("wrong return value")
}
}
func testAddInt(t *testing.T, shouldError bool, expected int, a int, b int) {
testArithmetic(t, func() (int, error) { return AddInt(a, b) }, shouldError, expected)
}
func TestAddInt(t *testing.T) {
testAddInt(t, false, 3, 1, 2)
testAddInt(t, true, 0, 1, math.MaxInt)
testAddInt(t, true, 0, math.MinInt, -1)
testAddInt(t, false, math.MaxInt-1, math.MaxInt, -1)
testAddInt(t, false, math.MinInt+1, math.MinInt, 1)
testAddInt(t, false, math.MaxInt, math.MaxInt, 0)
testAddInt(t, true, 0, math.MinInt, -1)
}
func testAddIntSlice(t *testing.T, shouldError bool, expected int, vals ...int) {
testArithmetic(t, func() (int, error) { return AddIntSlice(vals...) }, shouldError, expected)
}
func TestAddIntSlice(t *testing.T) {
testAddIntSlice(t, false, 0)
testAddIntSlice(t, false, 1, 1)
testAddIntSlice(t, false, 3, 1, 2)
testAddIntSlice(t, false, 6, 1, 2, 3)
testAddIntSlice(t, true, 0, 1, math.MaxInt)
testAddIntSlice(t, true, 0, 1, 2, math.MaxInt)
testAddIntSlice(t, true, 0, math.MaxInt, 2, 1)
testAddIntSlice(t, false, math.MaxInt, 0, 0, math.MaxInt)
testAddIntSlice(t, true, 0, math.MinInt, -1)
testAddIntSlice(t, false, math.MaxInt, math.MaxInt-3, 1, 2)
testAddIntSlice(t, true, 0, math.MaxInt-2, 1, 2)
testAddIntSlice(t, false, math.MinInt, math.MinInt+3, -1, -2)
testAddIntSlice(t, true, 0, math.MinInt+2, -1, -2)
}