mirror of
https://github.com/wavetermdev/waveterm.git
synced 2024-12-21 16:38:23 +01:00
zsh support (#227)
adds zsh support to waveterm. big change, lots going on here. lots of other improvements and bug fixes added while debugging and building out the feature. Commits: * refactor shexec parser.go into new package shellenv. separate out bash specific parsing from generic functions * checkpoint * work on refactoring shexec. created two new packages shellapi (for bash/zsh specific stuff), and shellutil (shared between shellapi and shexec) * more refactoring * create shellapi interface to abstract bash specific functionality * more refactoring, move bash shell state parsing to shellapi * move makeRcFile to shellapi. remove all of the 'client' options CLI options from waveshell * get shellType passed through to server/single paths for waveshell * add a local shelltype detector * mock out a zshapi * move shelltype through more of the code * get a command to run via zsh * zsh can now switch directories. poc, needs cleanup * working on ShellState encoding differences between zsh/bash. Working on parsing zsh decls. move utilfn package into waveshell (shouldn't have been in wavesrv) * switch to use []byte for vardecl serialization + diffs * progress on zsh environment. still have issues reconciling init environment with trap environment * fix typeset argument parsing * parse promptvars, more zsh specific ignores * fix bug with promptvar not getting set (wrong check in FeState func) * add sdk (issue #188) to list of rtnstate commands * more zsh compatibility -- working with a larger ohmyzsh environment. ignore more variables, handle exit trap better. unique path/fpath. add a processtype variable to base. * must return a value * zsh alias parsing/restoring. diff changes (and rtnstate changes). introduces linediff v1. * force zmodload of zsh/parameter * starting work on zsh functions * need a v1 of mapdiff as well (to handle null chars) * pack/unpack of ints was wrong (one used int and one use uint). turned out we only ever encoded '0' so it worked. that also means it is safe to change unpack to unpackUInt * reworking for binary encoding of aliases and functions (because of zsh allows any character, including nulls, in names and values) * fixes, working on functions, issue with line endings * zsh functions. lots of ugliness here around dealing with line dicipline and cooked stty. new runcommand function to grab output from a non-tty fd. note that we still to run the actual command in a stty to get the proper output. * write uuid tempdir, cleanup with tmprcfilename code * hack in some simple zsh function declaration finding code for rtnstate. create function diff for rtnstate that supports zsh * make sure key order is constant so shell hashes are consistent * fix problems with state diffs to support new zsh formats. add diff/apply code to shellapi (moved from shellenv), that is now specific to zsh or bash * add log packet and new shellstate packets * switch to shellstate map that's also keyed by shelltype * add shelltype to remoteinstance * remove shell argument from waveshell * added new shelltype statemap to remote.go (msh), deal with fallout * move shellstate out of init packet, and move to an explicit reinit call. try to initialize all of the active shell states * change dont always store init state (only store on demand). initialize shell states on demand (if not already initialized). allow reset to change shells * add shellpref field to remote table. use to drive the default shell choice for new tabs * show shelltag on cmdinput, pass through ri and remote (defaultshellstate) * bump mshell version to v0.4 * better version validation for shellstate. also relax compatibility requirements for diffing states (shelltype + major version need to match) * better error handling, check shellstate compatibility during run (on waveshell server) * add extra separator for bash shellstate processing to deal with spurious output from rc files * special migration for v30 -- flag invalid bash shell states and show special button in UI to fix * format * remove zsh-decls (unused) * remove test code * remove debug print * fix typo
This commit is contained in:
parent
76988a5277
commit
422338c04b
@ -7,10 +7,10 @@ signAsync({
|
||||
app: "temp/Wave.app",
|
||||
binaries: [
|
||||
waveAppPath + "/Contents/Resources/app/bin/wavesrv",
|
||||
waveAppPath + "/Contents/Resources/app/bin/mshell/mshell-v0.3-linux.amd64",
|
||||
waveAppPath + "/Contents/Resources/app/bin/mshell/mshell-v0.3-linux.arm64",
|
||||
waveAppPath + "/Contents/Resources/app/bin/mshell/mshell-v0.3-darwin.amd64",
|
||||
waveAppPath + "/Contents/Resources/app/bin/mshell/mshell-v0.3-darwin.arm64",
|
||||
waveAppPath + "/Contents/Resources/app/bin/mshell/mshell-v0.4-linux.amd64",
|
||||
waveAppPath + "/Contents/Resources/app/bin/mshell/mshell-v0.4-linux.arm64",
|
||||
waveAppPath + "/Contents/Resources/app/bin/mshell/mshell-v0.4-darwin.amd64",
|
||||
waveAppPath + "/Contents/Resources/app/bin/mshell/mshell-v0.4-darwin.arm64",
|
||||
],
|
||||
}).then(() => {
|
||||
console.log("signing success");
|
||||
|
@ -44,10 +44,10 @@ rm -rf bin/
|
||||
rm -rf build/
|
||||
node_modules/.bin/webpack --env prod
|
||||
GO_LDFLAGS="-s -w -X main.BuildTime=$(date +'%Y%m%d%H%M')"
|
||||
(cd waveshell; CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-darwin.amd64 main-waveshell.go)
|
||||
(cd waveshell; CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-darwin.arm64 main-waveshell.go)
|
||||
(cd waveshell; CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-linux.amd64 main-waveshell.go)
|
||||
(cd waveshell; CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-linux.arm64 main-waveshell.go)
|
||||
(cd waveshell; CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-darwin.amd64 main-waveshell.go)
|
||||
(cd waveshell; CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-darwin.arm64 main-waveshell.go)
|
||||
(cd waveshell; CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-linux.amd64 main-waveshell.go)
|
||||
(cd waveshell; CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-linux.arm64 main-waveshell.go)
|
||||
(cd wavesrv; CGO_ENABLED=1 go build -tags "osusergo,netgo,sqlite_omit_load_extension" -ldflags "-X main.BuildTime=$(date +'%Y%m%d%H%M')" -o ../bin/wavesrv ./cmd)
|
||||
node_modules/.bin/electron-forge make
|
||||
```
|
||||
@ -60,10 +60,10 @@ rm -rf bin/
|
||||
rm -rf build/
|
||||
node_modules/.bin/webpack --env prod
|
||||
GO_LDFLAGS="-s -w -X main.BuildTime=$(date +'%Y%m%d%H%M')"
|
||||
(cd waveshell; CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-darwin.amd64 main-waveshell.go)
|
||||
(cd waveshell; CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-darwin.arm64 main-waveshell.go)
|
||||
(cd waveshell; CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-linux.amd64 main-waveshell.go)
|
||||
(cd waveshell; CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-linux.arm64 main-waveshell.go)
|
||||
(cd waveshell; CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-darwin.amd64 main-waveshell.go)
|
||||
(cd waveshell; CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-darwin.arm64 main-waveshell.go)
|
||||
(cd waveshell; CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-linux.amd64 main-waveshell.go)
|
||||
(cd waveshell; CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-linux.arm64 main-waveshell.go)
|
||||
# adds -extldflags=-static, *only* on linux (macos does not support fully static binaries) to avoid a glibc dependency
|
||||
(cd wavesrv; CGO_ENABLED=1 go build -tags "osusergo,netgo,sqlite_omit_load_extension" -ldflags "-linkmode 'external' -extldflags=-static $GO_LDFLAGS" -o ../bin/wavesrv ./cmd)
|
||||
node_modules/.bin/electron-forge make
|
||||
@ -86,10 +86,10 @@ CGO_ENABLED=1 go build -tags "osusergo,netgo,sqlite_omit_load_extension" -ldflag
|
||||
set -e
|
||||
cd waveshell
|
||||
GO_LDFLAGS="-s -w -X main.BuildTime=$(date +'%Y%m%d%H%M')"
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-linux.amd64 main-waveshell.go
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-linux.arm64 main-waveshell.go
|
||||
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-darwin.amd64 main-waveshell.go
|
||||
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.3-darwin.arm64 main-waveshell.go
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-linux.amd64 main-waveshell.go
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-linux.arm64 main-waveshell.go
|
||||
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-darwin.amd64 main-waveshell.go
|
||||
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o ../bin/mshell/mshell-v0.4-darwin.arm64 main-waveshell.go
|
||||
```
|
||||
|
||||
```bash
|
||||
|
@ -25,6 +25,7 @@ class CreateRemoteConnModal extends React.Component<{}, {}> {
|
||||
tempConnectMode: OV<string>;
|
||||
tempPassword: OV<string>;
|
||||
tempKeyFile: OV<string>;
|
||||
tempShellPref: OV<string>;
|
||||
errorStr: OV<string>;
|
||||
remoteEdit: T.RemoteEditType;
|
||||
model: RemotesModel;
|
||||
@ -40,6 +41,7 @@ class CreateRemoteConnModal extends React.Component<{}, {}> {
|
||||
this.tempConnectMode = mobx.observable.box("auto", { name: "CreateRemote-connectMode" });
|
||||
this.tempKeyFile = mobx.observable.box("", { name: "CreateRemote-keystr" });
|
||||
this.tempPassword = mobx.observable.box("", { name: "CreateRemote-password" });
|
||||
this.tempShellPref = mobx.observable.box("detect", { name: "CreateRemote-shellPref" });
|
||||
this.errorStr = mobx.observable.box(this.remoteEdit?.errorstr ?? null, { name: "CreateRemote-errorStr" });
|
||||
}
|
||||
|
||||
@ -121,6 +123,7 @@ class CreateRemoteConnModal extends React.Component<{}, {}> {
|
||||
kwargs["password"] = "";
|
||||
}
|
||||
kwargs["connectmode"] = this.tempConnectMode.get();
|
||||
kwargs["shellpref"] = this.tempShellPref.get();
|
||||
kwargs["visual"] = "1";
|
||||
kwargs["submit"] = "1";
|
||||
let prtn = GlobalCommandRunner.createRemote(cname, kwargs, false);
|
||||
@ -174,6 +177,13 @@ class CreateRemoteConnModal extends React.Component<{}, {}> {
|
||||
})();
|
||||
}
|
||||
|
||||
@boundMethod
|
||||
handleChangeShellPref(value: string): void {
|
||||
mobx.action(() => {
|
||||
this.tempShellPref.set(value);
|
||||
})();
|
||||
}
|
||||
|
||||
@boundMethod
|
||||
handleChangePort(value: string): void {
|
||||
mobx.action(() => {
|
||||
@ -357,6 +367,20 @@ class CreateRemoteConnModal extends React.Component<{}, {}> {
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<div className="shellpref-section">
|
||||
<Dropdown
|
||||
label="Shell Preference"
|
||||
options={[
|
||||
{ value: "detect", label: "detect" },
|
||||
{ value: "bash", label: "bash" },
|
||||
{ value: "zsh", label: "zsh" },
|
||||
]}
|
||||
value={this.tempShellPref.get()}
|
||||
onChange={(val: string) => {
|
||||
this.tempShellPref.set(val);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<If condition={!util.isBlank(this.getErrorStr() as string)}>
|
||||
<div className="settings-field settings-error">Error: {this.getErrorStr()}</div>
|
||||
</If>
|
||||
|
@ -24,6 +24,7 @@ class EditRemoteConnModal extends React.Component<{}, {}> {
|
||||
tempPassword: OV<string>;
|
||||
tempConnectMode: OV<string>;
|
||||
tempAuthMode: OV<string>;
|
||||
tempShellPref: OV<string>;
|
||||
model: RemotesModel;
|
||||
|
||||
constructor(props: { remotesModel?: RemotesModel }) {
|
||||
@ -34,6 +35,7 @@ class EditRemoteConnModal extends React.Component<{}, {}> {
|
||||
this.tempKeyFile = mobx.observable.box(null, { name: "EditRemoteSettings-tempKeyFile" });
|
||||
this.tempPassword = mobx.observable.box(null, { name: "EditRemoteSettings-tempPassword" });
|
||||
this.tempConnectMode = mobx.observable.box(null, { name: "EditRemoteSettings-tempConnectMode" });
|
||||
this.tempShellPref = mobx.observable.box(null, { name: "EditRemoteSettings-tempShellPref" });
|
||||
}
|
||||
|
||||
get selectedRemoteId() {
|
||||
@ -52,6 +54,10 @@ class EditRemoteConnModal extends React.Component<{}, {}> {
|
||||
return this.model.isAuthEditMode();
|
||||
}
|
||||
|
||||
isLocalRemote(): boolean {
|
||||
return this.selectedRemote?.local;
|
||||
}
|
||||
|
||||
componentDidMount(): void {
|
||||
mobx.action(() => {
|
||||
this.tempAlias.set(this.selectedRemote?.remotealias);
|
||||
@ -59,6 +65,7 @@ class EditRemoteConnModal extends React.Component<{}, {}> {
|
||||
this.tempPassword.set(this.remoteEdit?.haspassword ? PasswordUnchangedSentinel : "");
|
||||
this.tempConnectMode.set(this.selectedRemote?.connectmode);
|
||||
this.tempAuthMode.set(this.selectedRemote?.authtype);
|
||||
this.tempShellPref.set(this.selectedRemote?.shellpref);
|
||||
})();
|
||||
}
|
||||
|
||||
@ -103,6 +110,13 @@ class EditRemoteConnModal extends React.Component<{}, {}> {
|
||||
})();
|
||||
}
|
||||
|
||||
@boundMethod
|
||||
handleChangeShellPref(value: string): void {
|
||||
mobx.action(() => {
|
||||
this.tempShellPref.set(value);
|
||||
})();
|
||||
}
|
||||
|
||||
@boundMethod
|
||||
canResetPw(): boolean {
|
||||
if (this.remoteEdit == null) {
|
||||
@ -154,6 +168,9 @@ class EditRemoteConnModal extends React.Component<{}, {}> {
|
||||
if (!util.isStrEq(this.tempConnectMode.get(), this.selectedRemote?.connectmode)) {
|
||||
kwargs["connectmode"] = this.tempConnectMode.get();
|
||||
}
|
||||
if (!util.isStrEq(this.tempShellPref.get(), this.selectedRemote?.shellpref)) {
|
||||
kwargs["shellpref"] = this.tempShellPref.get();
|
||||
}
|
||||
kwargs["visual"] = "1";
|
||||
kwargs["submit"] = "1";
|
||||
GlobalCommandRunner.editRemote(this.selectedRemote?.remoteid, kwargs);
|
||||
@ -183,11 +200,150 @@ class EditRemoteConnModal extends React.Component<{}, {}> {
|
||||
return null;
|
||||
}
|
||||
|
||||
render() {
|
||||
renderAlias() {
|
||||
return (
|
||||
<div className="alias-section">
|
||||
<TextField
|
||||
label="Alias"
|
||||
onChange={this.handleChangeAlias}
|
||||
value={this.tempAlias.get()}
|
||||
maxLength={100}
|
||||
decoration={{
|
||||
endDecoration: (
|
||||
<InputDecoration>
|
||||
<Tooltip
|
||||
message={`(Optional) A short alias to use when selecting or displaying this connection.`}
|
||||
icon={<i className="fa-sharp fa-regular fa-circle-question" />}
|
||||
>
|
||||
<i className="fa-sharp fa-regular fa-circle-question" />
|
||||
</Tooltip>
|
||||
</InputDecoration>
|
||||
),
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
renderConnectMode() {
|
||||
return (
|
||||
<div className="connectmode-section">
|
||||
<Dropdown
|
||||
label="Connect Mode"
|
||||
options={[
|
||||
{ value: "startup", label: "startup" },
|
||||
{ value: "auto", label: "auto" },
|
||||
{ value: "manual", label: "manual" },
|
||||
]}
|
||||
value={this.tempConnectMode.get()}
|
||||
onChange={this.handleChangeConnectMode}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
renderShellPref() {
|
||||
return (
|
||||
<div className="shellpref-section">
|
||||
<Dropdown
|
||||
label="Shell Preference"
|
||||
options={[
|
||||
{ value: "detect", label: "detect" },
|
||||
{ value: "bash", label: "bash" },
|
||||
{ value: "zsh", label: "zsh" },
|
||||
]}
|
||||
value={this.tempShellPref.get()}
|
||||
onChange={this.handleChangeShellPref}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
renderAuthMode() {
|
||||
let authMode = this.tempAuthMode.get();
|
||||
return (
|
||||
<>
|
||||
<div className="authmode-section">
|
||||
<Dropdown
|
||||
label="Auth Mode"
|
||||
options={[
|
||||
{ value: "none", label: "none" },
|
||||
{ value: "key", label: "key" },
|
||||
{ value: "password", label: "password" },
|
||||
{ value: "key+password", label: "key+password" },
|
||||
]}
|
||||
value={this.tempAuthMode.get()}
|
||||
onChange={this.handleChangeAuthMode}
|
||||
decoration={{
|
||||
endDecoration: (
|
||||
<InputDecoration>
|
||||
<Tooltip
|
||||
message={
|
||||
<ul>
|
||||
<li>
|
||||
<b>none</b> - no authentication, or authentication is already
|
||||
configured in your ssh config.
|
||||
</li>
|
||||
<li>
|
||||
<b>key</b> - use a private key.
|
||||
</li>
|
||||
<li>
|
||||
<b>password</b> - use a password.
|
||||
</li>
|
||||
<li>
|
||||
<b>key+password</b> - use a key with a passphrase.
|
||||
</li>
|
||||
</ul>
|
||||
}
|
||||
icon={<i className="fa-sharp fa-regular fa-circle-question" />}
|
||||
>
|
||||
<i className="fa-sharp fa-regular fa-circle-question" />
|
||||
</Tooltip>
|
||||
</InputDecoration>
|
||||
),
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<If condition={authMode == "key" || authMode == "key+password"}>
|
||||
<TextField
|
||||
label="SSH Keyfile"
|
||||
placeholder="keyfile path"
|
||||
onChange={this.handleChangeKeyFile}
|
||||
value={this.tempKeyFile.get()}
|
||||
maxLength={400}
|
||||
required={true}
|
||||
decoration={{
|
||||
endDecoration: (
|
||||
<InputDecoration>
|
||||
<Tooltip
|
||||
message={`(Required) The path to your ssh key file.`}
|
||||
icon={<i className="fa-sharp fa-regular fa-circle-question" />}
|
||||
>
|
||||
<i className="fa-sharp fa-regular fa-circle-question" />
|
||||
</Tooltip>
|
||||
</InputDecoration>
|
||||
),
|
||||
}}
|
||||
/>
|
||||
</If>
|
||||
<If condition={authMode == "password" || authMode == "key+password"}>
|
||||
<PasswordField
|
||||
label={authMode == "password" ? "SSH Password" : "Key Passphrase"}
|
||||
placeholder="password"
|
||||
onChange={this.handleChangePassword}
|
||||
value={this.tempPassword.get()}
|
||||
maxLength={400}
|
||||
/>
|
||||
</If>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
render() {
|
||||
if (this.remoteEdit === null || !this.isAuthEditMode) {
|
||||
return null;
|
||||
}
|
||||
let isLocal = this.isLocalRemote();
|
||||
return (
|
||||
<Modal className="erconn-modal">
|
||||
<Modal.Header title="Edit Connection" onClose={this.model.closeModal} />
|
||||
@ -195,110 +351,10 @@ class EditRemoteConnModal extends React.Component<{}, {}> {
|
||||
<div className="name-actions-section">
|
||||
<div className="name text-primary">{util.getRemoteName(this.selectedRemote)}</div>
|
||||
</div>
|
||||
<div className="alias-section">
|
||||
<TextField
|
||||
label="Alias"
|
||||
onChange={this.handleChangeAlias}
|
||||
value={this.tempAlias.get()}
|
||||
maxLength={100}
|
||||
decoration={{
|
||||
endDecoration: (
|
||||
<InputDecoration>
|
||||
<Tooltip
|
||||
message={`(Optional) A short alias to use when selecting or displaying this connection.`}
|
||||
icon={<i className="fa-sharp fa-regular fa-circle-question" />}
|
||||
>
|
||||
<i className="fa-sharp fa-regular fa-circle-question" />
|
||||
</Tooltip>
|
||||
</InputDecoration>
|
||||
),
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<div className="authmode-section">
|
||||
<Dropdown
|
||||
label="Auth Mode"
|
||||
options={[
|
||||
{ value: "none", label: "none" },
|
||||
{ value: "key", label: "key" },
|
||||
{ value: "password", label: "password" },
|
||||
{ value: "key+password", label: "key+password" },
|
||||
]}
|
||||
value={this.tempAuthMode.get()}
|
||||
onChange={this.handleChangeAuthMode}
|
||||
decoration={{
|
||||
endDecoration: (
|
||||
<InputDecoration>
|
||||
<Tooltip
|
||||
message={
|
||||
<ul>
|
||||
<li>
|
||||
<b>none</b> - no authentication, or authentication is already
|
||||
configured in your ssh config.
|
||||
</li>
|
||||
<li>
|
||||
<b>key</b> - use a private key.
|
||||
</li>
|
||||
<li>
|
||||
<b>password</b> - use a password.
|
||||
</li>
|
||||
<li>
|
||||
<b>key+password</b> - use a key with a passphrase.
|
||||
</li>
|
||||
</ul>
|
||||
}
|
||||
icon={<i className="fa-sharp fa-regular fa-circle-question" />}
|
||||
>
|
||||
<i className="fa-sharp fa-regular fa-circle-question" />
|
||||
</Tooltip>
|
||||
</InputDecoration>
|
||||
),
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<If condition={authMode == "key" || authMode == "key+password"}>
|
||||
<TextField
|
||||
label="SSH Keyfile"
|
||||
placeholder="keyfile path"
|
||||
onChange={this.handleChangeKeyFile}
|
||||
value={this.tempKeyFile.get()}
|
||||
maxLength={400}
|
||||
required={true}
|
||||
decoration={{
|
||||
endDecoration: (
|
||||
<InputDecoration>
|
||||
<Tooltip
|
||||
message={`(Required) The path to your ssh key file.`}
|
||||
icon={<i className="fa-sharp fa-regular fa-circle-question" />}
|
||||
>
|
||||
<i className="fa-sharp fa-regular fa-circle-question" />
|
||||
</Tooltip>
|
||||
</InputDecoration>
|
||||
),
|
||||
}}
|
||||
/>
|
||||
</If>
|
||||
<If condition={authMode == "password" || authMode == "key+password"}>
|
||||
<PasswordField
|
||||
label={authMode == "password" ? "SSH Password" : "Key Passphrase"}
|
||||
placeholder="password"
|
||||
onChange={this.handleChangePassword}
|
||||
value={this.tempPassword.get()}
|
||||
maxLength={400}
|
||||
/>
|
||||
</If>
|
||||
<div className="connectmode-section">
|
||||
<Dropdown
|
||||
label="Connect Mode"
|
||||
options={[
|
||||
{ value: "startup", label: "startup" },
|
||||
{ value: "auto", label: "auto" },
|
||||
{ value: "manual", label: "manual" },
|
||||
]}
|
||||
value={this.tempConnectMode.get()}
|
||||
onChange={this.handleChangeConnectMode}
|
||||
/>
|
||||
</div>
|
||||
<If condition={!isLocal}>{this.renderAlias()}</If>
|
||||
<If condition={!isLocal}>{this.renderAuthMode()}</If>
|
||||
<If condition={!isLocal}>{this.renderConnectMode()}</If>
|
||||
{this.renderShellPref()}
|
||||
<If condition={!util.isBlank(this.remoteEdit?.errorstr)}>
|
||||
<div className="settings-field settings-error">Error: {this.remoteEdit?.errorstr}</div>
|
||||
</If>
|
||||
|
@ -206,7 +206,6 @@ class ViewRemoteConnDetailModal extends React.Component<{}, {}> {
|
||||
);
|
||||
if (remote.local) {
|
||||
installNowButton = <></>;
|
||||
updateAuthButton = <></>;
|
||||
cancelInstallButton = <></>;
|
||||
}
|
||||
if (remote.sshconfigsrc == "sshconfig-import") {
|
||||
@ -352,6 +351,10 @@ class ViewRemoteConnDetailModal extends React.Component<{}, {}> {
|
||||
<div className="settings-label">Connect Mode</div>
|
||||
<div className="settings-input">{remote.connectmode}</div>
|
||||
</div>
|
||||
<div className="settings-field">
|
||||
<div className="settings-label">Shell Pref</div>
|
||||
<div className="settings-input">{remote.shellpref}</div>
|
||||
</div>
|
||||
{this.renderInstallStatus(remote)}
|
||||
<div className="flex-spacer" style={{ minHeight: 20 }} />
|
||||
<div className="status">
|
||||
|
@ -128,6 +128,22 @@
|
||||
padding: 1em 2px;
|
||||
}
|
||||
|
||||
.textareainput-div {
|
||||
position: relative;
|
||||
|
||||
.shelltag {
|
||||
position: absolute;
|
||||
bottom: 4px;
|
||||
right: 3px;
|
||||
font-size: 10px;
|
||||
color: @text-secondary;
|
||||
line-height: 1;
|
||||
padding: 0px 8px 3px 8px;
|
||||
background-color: @textarea-background;
|
||||
border-radius: 0 0 5px 5px;
|
||||
}
|
||||
}
|
||||
|
||||
textarea {
|
||||
color: @term-bright-white;
|
||||
background-color: @textarea-background;
|
||||
|
@ -99,6 +99,11 @@ class CmdInput extends React.Component<{}, {}> {
|
||||
})();
|
||||
}
|
||||
|
||||
@boundMethod
|
||||
clickResetState(): void {
|
||||
GlobalCommandRunner.resetShellState();
|
||||
}
|
||||
|
||||
render() {
|
||||
let model = GlobalModel;
|
||||
let inputModel = model.inputModel;
|
||||
@ -115,6 +120,7 @@ class CmdInput extends React.Component<{}, {}> {
|
||||
remote = GlobalModel.getRemote(ri.remoteid);
|
||||
feState = ri.festate;
|
||||
}
|
||||
feState = feState || {};
|
||||
let infoShow = inputModel.infoShow.get();
|
||||
let historyShow = !infoShow && inputModel.historyShow.get();
|
||||
let aiChatShow = inputModel.aIChatShow.get();
|
||||
@ -162,6 +168,18 @@ class CmdInput extends React.Component<{}, {}> {
|
||||
</If>
|
||||
</div>
|
||||
</If>
|
||||
<If condition={feState["invalidshellstate"]}>
|
||||
<div className="remote-status-warning">
|
||||
WARNING: The shell state for this tab is invalid (
|
||||
<a target="_blank" href="https://docs.waveterm.dev/reference/faq">
|
||||
see FAQ
|
||||
</a>
|
||||
). Must reset to continue.
|
||||
<div className="button is-wave-green is-outlined is-small" onClick={this.clickResetState}>
|
||||
reset shell state
|
||||
</div>
|
||||
</div>
|
||||
</If>
|
||||
<div key="prompt" className="cmd-input-context">
|
||||
<div className="has-text-white">
|
||||
<span ref={this.promptRef}>
|
||||
|
@ -5,6 +5,8 @@ import * as React from "react";
|
||||
import * as mobxReact from "mobx-react";
|
||||
import * as mobx from "mobx";
|
||||
import type * as T from "../../../types/types";
|
||||
import * as util from "../../../util/util";
|
||||
import { If } from "tsx-control-statements/components";
|
||||
import { boundMethod } from "autobind-decorator";
|
||||
import cn from "classnames";
|
||||
import { GlobalModel, GlobalCommandRunner, Screen } from "../../../model/model";
|
||||
@ -585,8 +587,24 @@ class TextAreaInput extends React.Component<{ screen: Screen; onHeightChange: ()
|
||||
let computedInnerHeight = displayLines * (termFontSize * 1.5) + 2 * 0.5 * termFontSize;
|
||||
// inner height + 2*1em padding
|
||||
let computedOuterHeight = computedInnerHeight + 2 * 1.0 * termFontSize;
|
||||
let shellType: string = "";
|
||||
let screen = GlobalModel.getActiveScreen();
|
||||
if (screen != null) {
|
||||
let ri = screen.getCurRemoteInstance();
|
||||
console.log("got ri", ri);
|
||||
if (ri != null && ri.shelltype != null) {
|
||||
shellType = ri.shelltype;
|
||||
}
|
||||
}
|
||||
return (
|
||||
<div className="control is-expanded" ref={this.controlRef} style={{ height: computedOuterHeight }}>
|
||||
<div
|
||||
className="textareainput-div control is-expanded"
|
||||
ref={this.controlRef}
|
||||
style={{ height: computedOuterHeight }}
|
||||
>
|
||||
<If condition={!disabled && !util.isBlank(shellType)}>
|
||||
<div className="shelltag">{shellType}</div>
|
||||
</If>
|
||||
<textarea
|
||||
key="main"
|
||||
ref={this.mainInputRef}
|
||||
|
@ -1213,6 +1213,7 @@ class Session {
|
||||
remoteid: rptr.remoteid,
|
||||
name: rptr.name,
|
||||
festate: remote.defaultfestate,
|
||||
shelltype: remote.defaultshelltype,
|
||||
};
|
||||
}
|
||||
return null;
|
||||
@ -4567,6 +4568,10 @@ class CommandRunner {
|
||||
GlobalModel.submitCommand("history", null, null, kwargs, true);
|
||||
}
|
||||
|
||||
resetShellState() {
|
||||
GlobalModel.submitCommand("reset", null, null, null, true);
|
||||
}
|
||||
|
||||
historyPurgeLines(lines: string[]): Promise<CommandRtnType> {
|
||||
let prtn = GlobalModel.submitCommand("history", "purge", lines, { nohist: "1" }, false);
|
||||
return prtn;
|
||||
|
@ -121,6 +121,8 @@ type RemoteType = {
|
||||
remoteopts?: RemoteOptsType;
|
||||
local: boolean;
|
||||
remove?: boolean;
|
||||
shellpref: string;
|
||||
defaultshelltype: string;
|
||||
};
|
||||
|
||||
type RemoteStateType = {
|
||||
@ -136,6 +138,7 @@ type RemoteInstanceType = {
|
||||
remoteownerid: string;
|
||||
remoteid: string;
|
||||
festate: Record<string, string>;
|
||||
shelltype: string;
|
||||
|
||||
remove?: boolean;
|
||||
};
|
||||
|
@ -4,10 +4,8 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
@ -16,131 +14,10 @@ import (
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/server"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var BuildTime = "0"
|
||||
|
||||
// func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) {
|
||||
// err := shexec.ValidateRunPacket(pk)
|
||||
// if err != nil {
|
||||
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("invalid run packet: %v", err))
|
||||
// return
|
||||
// }
|
||||
// fileNames, err := base.GetCommandFileNames(pk.CK)
|
||||
// if err != nil {
|
||||
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot get command file names: %v", err))
|
||||
// return
|
||||
// }
|
||||
// cmd, err := shexec.MakeRunnerExec(pk.CK)
|
||||
// if err != nil {
|
||||
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot make mshell command: %v", err))
|
||||
// return
|
||||
// }
|
||||
// cmdStdin, err := cmd.StdinPipe()
|
||||
// if err != nil {
|
||||
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot pipe stdin to command: %v", err))
|
||||
// return
|
||||
// }
|
||||
// // touch ptyout file (should exist for tailer to work correctly)
|
||||
// ptyOutFd, err := os.OpenFile(fileNames.PtyOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
|
||||
// if err != nil {
|
||||
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot open pty out file '%s': %v", fileNames.PtyOutFile, err))
|
||||
// return
|
||||
// }
|
||||
// ptyOutFd.Close() // just opened to create the file, can close right after
|
||||
// runnerOutFd, err := os.OpenFile(fileNames.RunnerOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
|
||||
// if err != nil {
|
||||
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err))
|
||||
// return
|
||||
// }
|
||||
// defer runnerOutFd.Close()
|
||||
// cmd.Stdout = runnerOutFd
|
||||
// cmd.Stderr = runnerOutFd
|
||||
// err = cmd.Start()
|
||||
// if err != nil {
|
||||
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("error starting command: %v", err))
|
||||
// return
|
||||
// }
|
||||
// go func() {
|
||||
// err = packet.SendPacket(cmdStdin, pk)
|
||||
// if err != nil {
|
||||
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("error sending forked runner command: %v", err))
|
||||
// return
|
||||
// }
|
||||
// cmdStdin.Close()
|
||||
|
||||
// // clean up zombies
|
||||
// cmd.Wait()
|
||||
// }()
|
||||
// }
|
||||
|
||||
// func doGetCmd(tailer *cmdtail.Tailer, pk *packet.GetCmdPacketType, sender *packet.PacketSender) error {
|
||||
// err := tailer.AddWatch(pk)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// func doMain() {
|
||||
// homeDir := base.GetHomeDir()
|
||||
// err := os.Chdir(homeDir)
|
||||
// if err != nil {
|
||||
// packet.SendErrorPacket(os.Stdout, fmt.Sprintf("cannot change directory to $HOME '%s': %v", homeDir, err))
|
||||
// return
|
||||
// }
|
||||
// _, err = base.GetMShellPath()
|
||||
// if err != nil {
|
||||
// packet.SendErrorPacket(os.Stdout, err.Error())
|
||||
// return
|
||||
// }
|
||||
// packetParser := packet.MakePacketParser(os.Stdin)
|
||||
// sender := packet.MakePacketSender(os.Stdout)
|
||||
// tailer, err := cmdtail.MakeTailer(sender)
|
||||
// if err != nil {
|
||||
// packet.SendErrorPacket(os.Stdout, err.Error())
|
||||
// return
|
||||
// }
|
||||
// go tailer.Run()
|
||||
// initPacket := shexec.MakeInitPacket()
|
||||
// sender.SendPacket(initPacket)
|
||||
// for pk := range packetParser.MainCh {
|
||||
// if pk.GetType() == packet.RunPacketStr {
|
||||
// doMainRun(pk.(*packet.RunPacketType), sender)
|
||||
// continue
|
||||
// }
|
||||
// if pk.GetType() == packet.GetCmdPacketStr {
|
||||
// err = doGetCmd(tailer, pk.(*packet.GetCmdPacketType), sender)
|
||||
// if err != nil {
|
||||
// errPk := packet.MakeErrorPacket(err.Error())
|
||||
// sender.SendPacket(errPk)
|
||||
// continue
|
||||
// }
|
||||
// continue
|
||||
// }
|
||||
// if pk.GetType() == packet.CdPacketStr {
|
||||
// cdPacket := pk.(*packet.CdPacketType)
|
||||
// err := os.Chdir(cdPacket.Dir)
|
||||
// resp := packet.MakeResponsePacket(cdPacket.ReqId)
|
||||
// if err != nil {
|
||||
// resp.Error = err.Error()
|
||||
// } else {
|
||||
// resp.Success = true
|
||||
// }
|
||||
// sender.SendPacket(resp)
|
||||
// continue
|
||||
// }
|
||||
// if pk.GetType() == packet.ErrorPacketStr {
|
||||
// errPk := pk.(*packet.ErrorPacketType)
|
||||
// errPk.Error = "invalid packet sent to mshell: " + errPk.Error
|
||||
// sender.SendPacket(errPk)
|
||||
// continue
|
||||
// }
|
||||
// sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", pk.GetType()))
|
||||
// }
|
||||
// }
|
||||
|
||||
func readFullRunPacket(packetParser *packet.PacketParser) (*packet.RunPacketType, error) {
|
||||
rpb := packet.MakeRunPacketBuilder()
|
||||
for pk := range packetParser.MainCh {
|
||||
@ -155,7 +32,7 @@ func readFullRunPacket(packetParser *packet.PacketParser) (*packet.RunPacketType
|
||||
return nil, fmt.Errorf("no run packet received")
|
||||
}
|
||||
|
||||
func handleSingle(fromServer bool) {
|
||||
func handleSingle() {
|
||||
packetParser := packet.MakePacketParser(os.Stdin, nil)
|
||||
sender := packet.MakePacketSender(os.Stdout, nil)
|
||||
defer func() {
|
||||
@ -177,11 +54,9 @@ func handleSingle(fromServer bool) {
|
||||
sender.SendErrorResponse(runPacket.ReqId, err)
|
||||
return
|
||||
}
|
||||
if fromServer {
|
||||
err = runPacket.CK.Validate("run packet")
|
||||
if err != nil {
|
||||
sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("run packets from server must have a CK: %v", err))
|
||||
}
|
||||
err = runPacket.CK.Validate("run packet")
|
||||
if err != nil {
|
||||
sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("run packets from server must have a CK: %v", err))
|
||||
}
|
||||
if runPacket.Detached {
|
||||
cmd, startPk, err := shexec.RunCommandDetached(runPacket, sender)
|
||||
@ -225,306 +100,20 @@ func handleSingle(fromServer bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func detectOpenFds() ([]packet.RemoteFd, error) {
|
||||
var fds []packet.RemoteFd
|
||||
for fdNum := 3; fdNum <= 64; fdNum++ {
|
||||
flags, err := unix.FcntlInt(uintptr(fdNum), unix.F_GETFL, 0)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
flags = flags & 3
|
||||
rfd := packet.RemoteFd{FdNum: fdNum}
|
||||
if flags&2 == 2 {
|
||||
return nil, fmt.Errorf("invalid fd=%d, mshell does not support fds open for reading and writing", fdNum)
|
||||
}
|
||||
if flags&1 == 1 {
|
||||
rfd.Write = true
|
||||
} else {
|
||||
rfd.Read = true
|
||||
}
|
||||
fds = append(fds, rfd)
|
||||
}
|
||||
return fds, nil
|
||||
}
|
||||
|
||||
func parseInstallOpts() (*shexec.InstallOpts, error) {
|
||||
opts := &shexec.InstallOpts{}
|
||||
iter := base.MakeOptsIter(os.Args[2:]) // first arg is --install
|
||||
for iter.HasNext() {
|
||||
argStr := iter.Next()
|
||||
found, err := tryParseSSHOpt(iter, &opts.SSHOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
if argStr == "--detect" {
|
||||
opts.Detect = true
|
||||
continue
|
||||
}
|
||||
if base.IsOption(argStr) {
|
||||
return nil, fmt.Errorf("invalid option '%s' passed to mshell --install", argStr)
|
||||
}
|
||||
opts.ArchStr = argStr
|
||||
break
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func tryParseSSHOpt(iter *base.OptsIter, sshOpts *shexec.SSHOpts) (bool, error) {
|
||||
argStr := iter.Current()
|
||||
if argStr == "--ssh" {
|
||||
if !iter.IsNextPlain() {
|
||||
return false, fmt.Errorf("'--ssh [user@host]' missing host")
|
||||
}
|
||||
sshOpts.SSHHost = iter.Next()
|
||||
return true, nil
|
||||
}
|
||||
if argStr == "--ssh-opts" {
|
||||
if !iter.HasNext() {
|
||||
return false, fmt.Errorf("'--ssh-opts [options]' missing options")
|
||||
}
|
||||
sshOpts.SSHOptsStr = iter.Next()
|
||||
return true, nil
|
||||
}
|
||||
if argStr == "-i" {
|
||||
if !iter.IsNextPlain() {
|
||||
return false, fmt.Errorf("-i [identity-file]' missing file")
|
||||
}
|
||||
sshOpts.SSHIdentity = iter.Next()
|
||||
return true, nil
|
||||
}
|
||||
if argStr == "-l" {
|
||||
if !iter.IsNextPlain() {
|
||||
return false, fmt.Errorf("-l [user]' missing user")
|
||||
}
|
||||
sshOpts.SSHUser = iter.Next()
|
||||
return true, nil
|
||||
}
|
||||
if argStr == "-p" {
|
||||
if !iter.IsNextPlain() {
|
||||
return false, fmt.Errorf("-p [port]' missing port")
|
||||
}
|
||||
nextArgStr := iter.Next()
|
||||
portVal, err := strconv.Atoi(nextArgStr)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("-p [port]' invalid port: %v", err)
|
||||
}
|
||||
if portVal <= 0 {
|
||||
return false, fmt.Errorf("-p [port]' invalid port: %d", portVal)
|
||||
}
|
||||
sshOpts.SSHPort = portVal
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func parseClientOpts() (*shexec.ClientOpts, error) {
|
||||
opts := &shexec.ClientOpts{}
|
||||
iter := base.MakeOptsIter(os.Args[1:])
|
||||
for iter.HasNext() {
|
||||
argStr := iter.Next()
|
||||
found, err := tryParseSSHOpt(iter, &opts.SSHOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
if argStr == "--cwd" {
|
||||
if !iter.IsNextPlain() {
|
||||
return nil, fmt.Errorf("'--cwd [dir]' missing directory")
|
||||
}
|
||||
opts.Cwd = iter.Next()
|
||||
continue
|
||||
}
|
||||
if argStr == "--detach" {
|
||||
opts.Detach = true
|
||||
continue
|
||||
}
|
||||
if argStr == "--pty" {
|
||||
opts.UsePty = true
|
||||
continue
|
||||
}
|
||||
if argStr == "--debug" {
|
||||
opts.Debug = true
|
||||
continue
|
||||
}
|
||||
if argStr == "--sudo" {
|
||||
opts.Sudo = true
|
||||
continue
|
||||
}
|
||||
if argStr == "--sudo-with-password" {
|
||||
if !iter.HasNext() {
|
||||
return nil, fmt.Errorf("'--sudo-with-password [pw]', missing password")
|
||||
}
|
||||
opts.Sudo = true
|
||||
opts.SudoWithPass = true
|
||||
opts.SudoPw = iter.Next()
|
||||
continue
|
||||
}
|
||||
if argStr == "--sudo-with-passfile" {
|
||||
if !iter.IsNextPlain() {
|
||||
return nil, fmt.Errorf("'--sudo-with-passfile [file]', missing file")
|
||||
}
|
||||
opts.Sudo = true
|
||||
opts.SudoWithPass = true
|
||||
fileName := iter.Next()
|
||||
contents, err := os.ReadFile(fileName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot read --sudo-with-passfile file '%s': %w", fileName, err)
|
||||
}
|
||||
if newlineIdx := bytes.Index(contents, []byte{'\n'}); newlineIdx != -1 {
|
||||
contents = contents[0:newlineIdx]
|
||||
}
|
||||
opts.SudoPw = string(contents) + "\n"
|
||||
continue
|
||||
}
|
||||
if argStr == "--" {
|
||||
if !iter.HasNext() {
|
||||
return nil, fmt.Errorf("'--' should be followed by command")
|
||||
}
|
||||
opts.Command = strings.Join(iter.Rest(), " ")
|
||||
break
|
||||
}
|
||||
return nil, fmt.Errorf("invalid option '%s' passed to mshell", argStr)
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func handleClient() (int, error) {
|
||||
opts, err := parseClientOpts()
|
||||
if err != nil {
|
||||
return 1, fmt.Errorf("parsing opts: %w", err)
|
||||
}
|
||||
if opts.Debug {
|
||||
packet.GlobalDebug = true
|
||||
}
|
||||
if opts.Command == "" {
|
||||
return 1, fmt.Errorf("no [command] specified. [command] follows '--' option (see usage)")
|
||||
}
|
||||
fds, err := detectOpenFds()
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
opts.Fds = fds
|
||||
err = shexec.ValidateRemoteFds(opts.Fds)
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
runPacket, err := opts.MakeRunPacket() // modifies opts
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
if runPacket.Detached {
|
||||
return 1, fmt.Errorf("cannot run detached command from command line client")
|
||||
}
|
||||
donePacket, err := shexec.RunClientSSHCommandAndWait(runPacket, shexec.StdContext{}, opts.SSHOpts, nil, opts.Debug)
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
return donePacket.ExitCode, nil
|
||||
}
|
||||
|
||||
func handleInstall() (int, error) {
|
||||
opts, err := parseInstallOpts()
|
||||
if err != nil {
|
||||
return 1, fmt.Errorf("parsing opts: %w", err)
|
||||
}
|
||||
if opts.SSHOpts.SSHHost == "" {
|
||||
return 1, fmt.Errorf("cannot install without '--ssh user@host' option")
|
||||
}
|
||||
if opts.Detect && opts.ArchStr != "" {
|
||||
return 1, fmt.Errorf("cannot supply both --detect and arch '%s'", opts.ArchStr)
|
||||
}
|
||||
if opts.ArchStr == "" && !opts.Detect {
|
||||
return 1, fmt.Errorf("must supply an arch string or '--detect' to auto detect")
|
||||
}
|
||||
if opts.ArchStr != "" {
|
||||
fullArch := opts.ArchStr
|
||||
fields := strings.SplitN(fullArch, ".", 2)
|
||||
if len(fields) != 2 {
|
||||
return 1, fmt.Errorf("invalid arch format '%s' passed to mshell --install", fullArch)
|
||||
}
|
||||
goos, goarch := fields[0], fields[1]
|
||||
if !base.ValidGoArch(goos, goarch) {
|
||||
return 1, fmt.Errorf("invalid arch '%s' passed to mshell --install", fullArch)
|
||||
}
|
||||
optName := base.GoArchOptFile(base.MShellVersion, goos, goarch)
|
||||
_, err = os.Stat(optName)
|
||||
if err != nil {
|
||||
return 1, fmt.Errorf("cannot install mshell to remote host, cannot read '%s': %w", optName, err)
|
||||
}
|
||||
opts.OptName = optName
|
||||
}
|
||||
err = shexec.RunInstallFromOpts(opts)
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func handleEnv() (int, error) {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
fmt.Printf("%s\x00\x00", cwd)
|
||||
fullEnv := os.Environ()
|
||||
var linePrinted bool
|
||||
for _, envLine := range fullEnv {
|
||||
if envLine != "" {
|
||||
fmt.Printf("%s\x00", envLine)
|
||||
linePrinted = true
|
||||
}
|
||||
}
|
||||
if linePrinted {
|
||||
fmt.Printf("\x00")
|
||||
} else {
|
||||
fmt.Printf("\x00\x00")
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func handleUsage() {
|
||||
usage := `
|
||||
Client Usage: mshell [opts] --ssh user@host -- [command]
|
||||
|
||||
mshell multiplexes input and output streams to a remote command over ssh.
|
||||
mshell is a helper program for wave terminal. it is used to execute commands
|
||||
|
||||
Options:
|
||||
-i [identity-file] - used to set '-i' option for ssh command
|
||||
-l [user] - used to set '-l' option for ssh command
|
||||
--cwd [dir] - execute remote command in [dir]
|
||||
--ssh-opts [opts] - addition options to pass to ssh command
|
||||
[command] - the remote command to execute
|
||||
--help - prints this message
|
||||
--version - print version
|
||||
--server - multiplexer to run multiple commands
|
||||
--single - run a single command (connected to multiplexer)
|
||||
--single --version - return an init packet with version info
|
||||
|
||||
Sudo Options:
|
||||
--sudo - use only if sudo never requires a password
|
||||
--sudo-with-password [pw] - not recommended, use --sudo-with-passfile if possible
|
||||
--sudo-with-passfile [file]
|
||||
|
||||
Sudo options allow you to run the given command using "sudo". The first
|
||||
option only works when you can sudo without a password. Your password will be passed
|
||||
securely through a high numbered fd to "sudo -S". Note that to use high numbered
|
||||
file descriptors with sudo, you will need to add this line to your /etc/sudoers file:
|
||||
Defaults closefrom_override
|
||||
See full documentation for more details.
|
||||
|
||||
Examples:
|
||||
# execute a python script remotely, with stdin still hooked up correctly
|
||||
mshell --cwd "~/work" -i key.pem --ssh ubuntu@somehost -- "python3 /dev/fd/4" 4< myscript.py
|
||||
|
||||
# capture multiple outputs
|
||||
mshell --ssh ubuntu@test -- "cat file1.txt > /dev/fd/3; cat file2.txt > /dev/fd/4" 3> file1.txt 4> file2.txt
|
||||
|
||||
# execute a script, catpure stdout/stderr in fd-3 and fd-4
|
||||
# useful if you need to see stdout for interacting with ssh (password or host auth)
|
||||
mshell --ssh user@host -- "test.sh > /dev/fd/3 2> /dev/fd/4" 3> test.stdout 4> test.stderr
|
||||
|
||||
# run a script as root (via sudo), capture output
|
||||
mshell --sudo-with-passfile pw.txt --ssh ubuntu@somehost -- "python3 /dev/fd/3 > /dev/fd/4" 3< myscript.py 4> script-output.txt < script-input.txt
|
||||
mshell does not open any external ports and does not require any additional permissions.
|
||||
it communicates exclusively through stdin/stdout with an attached process
|
||||
via a JSON packet format.
|
||||
`
|
||||
fmt.Printf("%s\n\n", strings.TrimSpace(usage))
|
||||
}
|
||||
@ -542,15 +131,13 @@ func main() {
|
||||
} else if firstArg == "--version" {
|
||||
fmt.Printf("mshell %s+%s\n", base.MShellVersion, base.BuildTime)
|
||||
return
|
||||
} else if firstArg == "--single" {
|
||||
} else if firstArg == "--single" || firstArg == "--single-from-server" {
|
||||
base.ProcessType = base.ProcessType_WaveShellSingle
|
||||
base.InitDebugLog("single")
|
||||
handleSingle(false)
|
||||
return
|
||||
} else if firstArg == "--single-from-server" {
|
||||
base.InitDebugLog("single")
|
||||
handleSingle(true)
|
||||
handleSingle()
|
||||
return
|
||||
} else if firstArg == "--server" {
|
||||
base.ProcessType = base.ProcessType_WaveShellServer
|
||||
base.InitDebugLog("server")
|
||||
rtnCode, err := server.RunServer()
|
||||
if err != nil {
|
||||
@ -560,21 +147,8 @@ func main() {
|
||||
os.Exit(rtnCode)
|
||||
}
|
||||
return
|
||||
} else if firstArg == "--install" {
|
||||
rtnCode, err := handleInstall()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[error] %v\n", err)
|
||||
}
|
||||
os.Exit(rtnCode)
|
||||
return
|
||||
} else {
|
||||
rtnCode, err := handleClient()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[error] %v\n", err)
|
||||
}
|
||||
if rtnCode != 0 {
|
||||
os.Exit(rtnCode)
|
||||
}
|
||||
handleUsage()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -30,7 +30,7 @@ const SSHCommandVarName = "SSH_COMMAND"
|
||||
const MShellDebugVarName = "MSHELL_DEBUG"
|
||||
const SessionsDirBaseName = "sessions"
|
||||
const RcFilesDirBaseName = "rcfiles"
|
||||
const MShellVersion = "v0.3.0"
|
||||
const MShellVersion = "v0.4.0"
|
||||
const RemoteIdFile = "remoteid"
|
||||
const DefaultMShellInstallBinDir = "/opt/mshell/bin"
|
||||
const LogFileName = "mshell.log"
|
||||
@ -39,6 +39,13 @@ const ForceDebugLog = false
|
||||
const DebugFlag_LogRcFile = "logrc"
|
||||
const LogRcFileName = "debug.rcfile"
|
||||
|
||||
const (
|
||||
ProcessType_Unknown = "unknown"
|
||||
ProcessType_WaveSrv = "wavesrv"
|
||||
ProcessType_WaveShellSingle = "waveshell-single"
|
||||
ProcessType_WaveShellServer = "waveshell-server"
|
||||
)
|
||||
|
||||
// keys are sessionids (also the key RcFilesDirBaseName)
|
||||
var ensureDirCache = make(map[string]bool)
|
||||
var baseLock = &sync.Mutex{}
|
||||
@ -46,6 +53,8 @@ var DebugLogEnabled = false
|
||||
var DebugLogger *log.Logger
|
||||
var BuildTime string = "0"
|
||||
|
||||
var ProcessType string = ProcessType_Unknown
|
||||
|
||||
type CommandFileNames struct {
|
||||
PtyOutFile string
|
||||
StdinFifo string
|
||||
@ -58,6 +67,10 @@ func SetBuildTime(build string) {
|
||||
BuildTime = build
|
||||
}
|
||||
|
||||
func IsWaveSrv() bool {
|
||||
return ProcessType == ProcessType_WaveSrv
|
||||
}
|
||||
|
||||
func MakeCommandKey(sessionId string, cmdId string) CommandKey {
|
||||
if sessionId == "" && cmdId == "" {
|
||||
return CommandKey("")
|
||||
|
@ -44,9 +44,9 @@ func PackStrArr(w io.Writer, strs []string) error {
|
||||
return PackValue(w, barr)
|
||||
}
|
||||
|
||||
func PackInt(w io.Writer, ival int) error {
|
||||
func PackUInt(w io.Writer, ival uint64) error {
|
||||
viBuf := make([]byte, binary.MaxVarintLen64)
|
||||
l := binary.PutUvarint(viBuf, uint64(ival))
|
||||
l := binary.PutUvarint(viBuf, ival)
|
||||
_, err := w.Write(viBuf[0:l])
|
||||
return err
|
||||
}
|
||||
@ -80,8 +80,16 @@ func UnpackStrArr(r FullByteReader) ([]string, error) {
|
||||
return strs, nil
|
||||
}
|
||||
|
||||
func UnpackInt(r io.ByteReader) (int, error) {
|
||||
ival64, err := binary.ReadVarint(r)
|
||||
func UnpackUInt(r io.ByteReader) (uint64, error) {
|
||||
ival64, err := binary.ReadUvarint(r)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return ival64, nil
|
||||
}
|
||||
|
||||
func UnpackUIntAsInt(r io.ByteReader) (int, error) {
|
||||
ival64, err := UnpackUInt(r)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -99,15 +107,15 @@ func (u *Unpacker) UnpackValue(name string) []byte {
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (u *Unpacker) UnpackInt(name string) int {
|
||||
func (u *Unpacker) UnpackUInt(name string) int {
|
||||
if u.Err != nil {
|
||||
return 0
|
||||
}
|
||||
rtn, err := UnpackInt(u.R)
|
||||
rtn, err := UnpackUInt(u.R)
|
||||
if err != nil {
|
||||
u.Err = fmt.Errorf("cannot unpack %s: %v", name, err)
|
||||
}
|
||||
return rtn
|
||||
return int(rtn)
|
||||
}
|
||||
|
||||
func (u *Unpacker) UnpackStrArr(name string) []string {
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
|
||||
)
|
||||
@ -59,11 +60,18 @@ const (
|
||||
WriteFileReadyPacketStr = "writefileready" // rpc-response
|
||||
WriteFileDonePacketStr = "writefiledone" // rpc-response
|
||||
FileDataPacketStr = "filedata"
|
||||
LogPacketStr = "log" // logging packet (sent from waveshell back to server)
|
||||
ShellStatePacketStr = "shellstate"
|
||||
|
||||
OpenAIPacketStr = "openai" // other
|
||||
OpenAICloudReqStr = "openai-cloudreq"
|
||||
)
|
||||
|
||||
const (
|
||||
ShellType_bash = "bash"
|
||||
ShellType_zsh = "zsh"
|
||||
)
|
||||
|
||||
const PacketSenderQueueSize = 20
|
||||
|
||||
const PacketEOFStr = "EOF"
|
||||
@ -102,6 +110,8 @@ func init() {
|
||||
TypeStrToFactory[WriteFilePacketStr] = reflect.TypeOf(WriteFilePacketType{})
|
||||
TypeStrToFactory[WriteFileReadyPacketStr] = reflect.TypeOf(WriteFileReadyPacketType{})
|
||||
TypeStrToFactory[WriteFileDonePacketStr] = reflect.TypeOf(WriteFileDonePacketType{})
|
||||
TypeStrToFactory[LogPacketStr] = reflect.TypeOf(LogPacketType{})
|
||||
TypeStrToFactory[ShellStatePacketStr] = reflect.TypeOf(ShellStatePacketType{})
|
||||
|
||||
var _ RpcPacketType = (*RunPacketType)(nil)
|
||||
var _ RpcPacketType = (*GetCmdPacketType)(nil)
|
||||
@ -119,6 +129,7 @@ func init() {
|
||||
var _ RpcResponsePacketType = (*FileDataPacketType)(nil)
|
||||
var _ RpcResponsePacketType = (*WriteFileReadyPacketType)(nil)
|
||||
var _ RpcResponsePacketType = (*WriteFileDonePacketType)(nil)
|
||||
var _ RpcResponsePacketType = (*ShellStatePacketType)(nil)
|
||||
|
||||
var _ CommandPacketType = (*DataPacketType)(nil)
|
||||
var _ CommandPacketType = (*DataAckPacketType)(nil)
|
||||
@ -382,8 +393,9 @@ func MakeCdPacket() *CdPacketType {
|
||||
}
|
||||
|
||||
type ReInitPacketType struct {
|
||||
Type string `json:"type"`
|
||||
ReqId string `json:"reqid"`
|
||||
Type string `json:"type"`
|
||||
ShellType string `json:"shelltype"`
|
||||
ReqId string `json:"reqid"`
|
||||
}
|
||||
|
||||
func (*ReInitPacketType) GetType() string {
|
||||
@ -543,6 +555,54 @@ func MakeRawPacket(val string) *RawPacketType {
|
||||
return &RawPacketType{Type: RawPacketStr, Data: val}
|
||||
}
|
||||
|
||||
type LogPacketType struct {
|
||||
Type string `json:"type"`
|
||||
Ts int64 `json:"ts"` // log timestamp
|
||||
ReqId string `json:"reqid,omitempty"` // if this log line is related to an rpc request
|
||||
ProcInfo string `json:"procinfo,omitempty"` // server/single
|
||||
LogLine string `json:"logline"` // the logline data
|
||||
}
|
||||
|
||||
func (*LogPacketType) GetType() string {
|
||||
return LogPacketStr
|
||||
}
|
||||
|
||||
func (p *LogPacketType) String() string {
|
||||
return "log"
|
||||
}
|
||||
|
||||
func MakeLogPacket() *LogPacketType {
|
||||
return &LogPacketType{Type: LogPacketStr, Ts: time.Now().UnixMilli()}
|
||||
}
|
||||
|
||||
type ShellStatePacketType struct {
|
||||
Type string `json:"type"`
|
||||
ShellType string `json:"shelltype"`
|
||||
RespId string `json:"respid,omitempty"`
|
||||
State *ShellState `json:"state"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (*ShellStatePacketType) GetType() string {
|
||||
return ShellStatePacketStr
|
||||
}
|
||||
|
||||
func (p *ShellStatePacketType) String() string {
|
||||
return fmt.Sprintf("shellstate[%s]", p.ShellType)
|
||||
}
|
||||
|
||||
func (p *ShellStatePacketType) GetResponseId() string {
|
||||
return p.RespId
|
||||
}
|
||||
|
||||
func (p *ShellStatePacketType) GetResponseDone() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func MakeShellStatePacket() *ShellStatePacketType {
|
||||
return &ShellStatePacketType{Type: ShellStatePacketStr}
|
||||
}
|
||||
|
||||
type MessagePacketType struct {
|
||||
Type string `json:"type"`
|
||||
CK base.CommandKey `json:"ck,omitempty"`
|
||||
@ -567,19 +627,18 @@ func FmtMessagePacket(fmtStr string, args ...interface{}) *MessagePacketType {
|
||||
}
|
||||
|
||||
type InitPacketType struct {
|
||||
Type string `json:"type"`
|
||||
RespId string `json:"respid,omitempty"`
|
||||
Version string `json:"version"`
|
||||
BuildTime string `json:"buildtime,omitempty"`
|
||||
MShellHomeDir string `json:"mshellhomedir,omitempty"`
|
||||
HomeDir string `json:"homedir,omitempty"`
|
||||
State *ShellState `json:"state,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
HostName string `json:"hostname,omitempty"`
|
||||
NotFound bool `json:"notfound,omitempty"`
|
||||
UName string `json:"uname,omitempty"`
|
||||
Shell string `json:"shell,omitempty"`
|
||||
RemoteId string `json:"remoteid,omitempty"`
|
||||
Type string `json:"type"`
|
||||
RespId string `json:"respid,omitempty"`
|
||||
Version string `json:"version"`
|
||||
BuildTime string `json:"buildtime,omitempty"`
|
||||
MShellHomeDir string `json:"mshellhomedir,omitempty"`
|
||||
HomeDir string `json:"homedir,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
HostName string `json:"hostname,omitempty"`
|
||||
NotFound bool `json:"notfound,omitempty"`
|
||||
UName string `json:"uname,omitempty"`
|
||||
Shell string `json:"shell,omitempty"`
|
||||
RemoteId string `json:"remoteid,omitempty"`
|
||||
}
|
||||
|
||||
func (*InitPacketType) GetType() string {
|
||||
@ -701,6 +760,7 @@ type RunPacketType struct {
|
||||
Type string `json:"type"`
|
||||
ReqId string `json:"reqid"`
|
||||
CK base.CommandKey `json:"ck"`
|
||||
ShellType string `json:"shelltype"` // new in v0.6.0 (either "bash" or "zsh") (set by remote.go)
|
||||
Command string `json:"command"`
|
||||
State *ShellState `json:"state,omitempty"`
|
||||
StateDiff *ShellStateDiff `json:"statediff,omitempty"`
|
||||
|
@ -144,14 +144,20 @@ func (p *PacketParser) getRpcEntry(reqId string) *RpcEntry {
|
||||
return entry
|
||||
}
|
||||
|
||||
// returns true if sent to an RPC channel. false if not (which then allows the packet to be sent to MainCh)
|
||||
// if GetResponseId() returns "", then this will return false
|
||||
func (p *PacketParser) trySendRpcResponse(pk PacketType) bool {
|
||||
respPk, ok := pk.(RpcResponsePacketType)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
respId := respPk.GetResponseId()
|
||||
if respId == "" {
|
||||
return false
|
||||
}
|
||||
p.Lock.Lock()
|
||||
defer p.Lock.Unlock()
|
||||
entry := p.RpcMap[respPk.GetResponseId()]
|
||||
entry := p.RpcMap[respId]
|
||||
if entry == nil {
|
||||
return false
|
||||
}
|
||||
|
@ -30,7 +30,7 @@ type ShellState struct {
|
||||
}
|
||||
|
||||
type ShellStateDiff struct {
|
||||
Version string `json:"version"` // [type] [semver]
|
||||
Version string `json:"version"` // [type] [semver] (note this should *always* be set even if the same as base)
|
||||
BaseHash string `json:"basehash"`
|
||||
DiffHashArr []string `json:"diffhasharr,omitempty"`
|
||||
Cwd string `json:"cwd,omitempty"`
|
||||
@ -41,6 +41,66 @@ type ShellStateDiff struct {
|
||||
HashVal string `json:"-"`
|
||||
}
|
||||
|
||||
func (state ShellState) GetShellType() string {
|
||||
shell, _, _ := ParseShellStateVersion(state.Version)
|
||||
return shell
|
||||
}
|
||||
|
||||
// returns (shell, version, error)
|
||||
func ParseShellStateVersion(fullVersionStr string) (string, string, error) {
|
||||
if fullVersionStr == "" {
|
||||
return "", "", fmt.Errorf("empty shellstate version")
|
||||
}
|
||||
fields := strings.Split(fullVersionStr, " ")
|
||||
if len(fields) != 2 {
|
||||
return "", "", fmt.Errorf("invalid shellstate version format: %q", fullVersionStr)
|
||||
}
|
||||
shell := fields[0]
|
||||
version := fields[1]
|
||||
if shell != ShellType_zsh && shell != ShellType_bash {
|
||||
return "", "", fmt.Errorf("invalid shellstate shell type: %q", fullVersionStr)
|
||||
}
|
||||
if !semver.IsValid(version) {
|
||||
return "", "", fmt.Errorf("invalid shellstate semver: %q", fullVersionStr)
|
||||
}
|
||||
return shell, version, nil
|
||||
}
|
||||
|
||||
// we're going to allow different versions (as long as shelltype is the same)
|
||||
// before we required version numbers to match exactly which was too restrictive
|
||||
func StateVersionsCompatible(v1 string, v2 string) bool {
|
||||
if v1 == v2 {
|
||||
return true
|
||||
}
|
||||
shell1, version1, err := ParseShellStateVersion(v1)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
shell2, version2, err := ParseShellStateVersion(v2)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if shell1 != shell2 {
|
||||
return false
|
||||
}
|
||||
if semver.Major(version1) != semver.Major(version2) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (diff ShellStateDiff) GetShellType() string {
|
||||
shell, _, _ := ParseShellStateVersion(diff.Version)
|
||||
return shell
|
||||
}
|
||||
|
||||
func (state ShellState) GetLineDiffSplitString() string {
|
||||
if state.GetShellType() == ShellType_zsh {
|
||||
return "\x00"
|
||||
}
|
||||
return "\n"
|
||||
}
|
||||
|
||||
func (state ShellState) IsEmpty() bool {
|
||||
return state.Version == "" && state.Cwd == "" && len(state.ShellVars) == 0 && state.Aliases == "" && state.Funcs == "" && state.Error == ""
|
||||
}
|
||||
@ -55,7 +115,7 @@ func sha1Hash(data []byte) string {
|
||||
// returns (SHA1, encoded-state)
|
||||
func (state ShellState) EncodeAndHash() (string, []byte) {
|
||||
var buf bytes.Buffer
|
||||
binpack.PackInt(&buf, ShellStatePackVersion)
|
||||
binpack.PackUInt(&buf, ShellStatePackVersion)
|
||||
binpack.PackValue(&buf, []byte(state.Version))
|
||||
binpack.PackValue(&buf, []byte(state.Cwd))
|
||||
binpack.PackValue(&buf, state.ShellVars)
|
||||
@ -66,7 +126,7 @@ func (state ShellState) EncodeAndHash() (string, []byte) {
|
||||
}
|
||||
|
||||
// returns a string like "v4" ("" is an unparseable version)
|
||||
func GetBashMajorVersion(versionStr string) string {
|
||||
func GetMajorVersion(versionStr string) string {
|
||||
if versionStr == "" {
|
||||
return ""
|
||||
}
|
||||
@ -94,7 +154,7 @@ func (state *ShellState) DecodeShellState(barr []byte) error {
|
||||
state.HashVal = sha1Hash(barr)
|
||||
buf := bytes.NewBuffer(barr)
|
||||
u := binpack.MakeUnpacker(buf)
|
||||
version := u.UnpackInt("ShellState pack version")
|
||||
version := u.UnpackUInt("ShellState pack version")
|
||||
if version != ShellStatePackVersion {
|
||||
return fmt.Errorf("invalid ShellState pack version: %d", version)
|
||||
}
|
||||
@ -118,7 +178,7 @@ func (state *ShellState) UnmarshalJSON(jsonBytes []byte) error {
|
||||
|
||||
func (sdiff ShellStateDiff) EncodeAndHash() (string, []byte) {
|
||||
var buf bytes.Buffer
|
||||
binpack.PackInt(&buf, ShellStateDiffPackVersion)
|
||||
binpack.PackUInt(&buf, ShellStateDiffPackVersion)
|
||||
binpack.PackValue(&buf, []byte(sdiff.Version))
|
||||
binpack.PackValue(&buf, []byte(sdiff.BaseHash))
|
||||
binpack.PackStrArr(&buf, sdiff.DiffHashArr)
|
||||
@ -139,7 +199,7 @@ func (sdiff *ShellStateDiff) DecodeShellStateDiff(barr []byte) error {
|
||||
sdiff.HashVal = sha1Hash(barr)
|
||||
buf := bytes.NewBuffer(barr)
|
||||
u := binpack.MakeUnpacker(buf)
|
||||
version := u.UnpackInt("ShellState pack version")
|
||||
version := u.UnpackUInt("ShellState pack version")
|
||||
if version != ShellStateDiffPackVersion {
|
||||
return fmt.Errorf("invalid ShellStateDiff pack version: %d", version)
|
||||
}
|
||||
|
58
waveshell/pkg/packet/shellstate_test.go
Normal file
58
waveshell/pkg/packet/shellstate_test.go
Normal file
@ -0,0 +1,58 @@
|
||||
package packet
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestShellVersions(t *testing.T) {
|
||||
if !StateVersionsCompatible("bash v5.0.17", "bash v5.0.17") {
|
||||
t.Errorf("versions should be compatible")
|
||||
}
|
||||
if !StateVersionsCompatible("bash v5.0.17", "bash v5.0.18") {
|
||||
t.Errorf("versions should be compatible")
|
||||
}
|
||||
if !StateVersionsCompatible("bash v5.0.17", "bash v5.1.0") {
|
||||
t.Errorf("versions should be compatible")
|
||||
}
|
||||
if StateVersionsCompatible("bash v5.0.17", "bash v6.0.0") {
|
||||
t.Errorf("versions should not be compatible")
|
||||
}
|
||||
if StateVersionsCompatible("bash v5.0.17", "zsh v5.0.17") {
|
||||
t.Errorf("versions should not be compatible")
|
||||
}
|
||||
|
||||
shell, version, err := ParseShellStateVersion("bash v5.0.17")
|
||||
if err != nil {
|
||||
t.Errorf("version should be valid, got error %v", err)
|
||||
}
|
||||
if shell != ShellType_bash {
|
||||
t.Errorf("shell should be bash")
|
||||
}
|
||||
if version != "v5.0.17" {
|
||||
t.Errorf("version should be v5.0.17")
|
||||
}
|
||||
shell, version, err = ParseShellStateVersion("zsh v5.0.17")
|
||||
if err != nil {
|
||||
t.Errorf("version should be valid, got error %v", err)
|
||||
}
|
||||
if shell != ShellType_zsh {
|
||||
t.Errorf("shell should be zsh")
|
||||
}
|
||||
if version != "v5.0.17" {
|
||||
t.Errorf("version should be v5.0.17")
|
||||
}
|
||||
_, _, err = ParseShellStateVersion("fish v5.0.17")
|
||||
if err == nil {
|
||||
t.Errorf("version should be invalid")
|
||||
}
|
||||
_, _, err = ParseShellStateVersion("bash v5.0.17.1")
|
||||
if err == nil {
|
||||
t.Errorf("version should be invalid")
|
||||
}
|
||||
_, _, err = ParseShellStateVersion("bash")
|
||||
if err == nil {
|
||||
t.Errorf("version should be invalid")
|
||||
}
|
||||
_, _, err = ParseShellStateVersion("bash v5.0.17 extrastuff")
|
||||
if err == nil {
|
||||
t.Errorf("version should be invalid")
|
||||
}
|
||||
}
|
@ -20,7 +20,9 @@ import (
|
||||
"github.com/alessio/shellescape"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellapi"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
)
|
||||
|
||||
const MaxFileDataPacketSize = 16 * 1024
|
||||
@ -28,6 +30,17 @@ const WriteFileContextTimeout = 30 * time.Second
|
||||
const cleanLoopTime = 5 * time.Second
|
||||
const MaxWriteFileContextData = 100
|
||||
|
||||
type shellStateMapKey struct {
|
||||
ShellType string
|
||||
Hash string
|
||||
}
|
||||
|
||||
type ShellStateMap struct {
|
||||
Lock *sync.Mutex
|
||||
StateMap map[shellStateMapKey]*packet.ShellState // shelltype+hash -> state
|
||||
CurrentStateMap map[string]string // shelltype -> hash
|
||||
}
|
||||
|
||||
// TODO create unblockable packet-sender (backed by an array) for clientproc
|
||||
type MServer struct {
|
||||
Lock *sync.Mutex
|
||||
@ -35,9 +48,8 @@ type MServer struct {
|
||||
Sender *packet.PacketSender
|
||||
ClientMap map[base.CommandKey]*shexec.ClientProc
|
||||
Debug bool
|
||||
StateMap map[string]*packet.ShellState // sha1->state
|
||||
CurrentState string // sha1
|
||||
WriteErrorCh chan bool // closed if there is a I/O write error
|
||||
StateMap *ShellStateMap
|
||||
WriteErrorCh chan bool // closed if there is a I/O write error
|
||||
WriteErrorChOnce *sync.Once
|
||||
WriteFileContextMap map[string]*WriteFileContext
|
||||
Done bool
|
||||
@ -146,11 +158,15 @@ func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) {
|
||||
}
|
||||
|
||||
func runSingleCompGen(cwd string, compType string, prefix string) ([]string, bool, error) {
|
||||
sapi, err := shellapi.MakeShellApi(packet.ShellType_bash)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if !packet.IsValidCompGenType(compType) {
|
||||
return nil, false, fmt.Errorf("invalid compgen type '%s'", compType)
|
||||
}
|
||||
compGenCmdStr := fmt.Sprintf("cd %s; compgen -A %s -- %s | sort | uniq | head -n %d", shellescape.Quote(cwd), shellescape.Quote(compType), shellescape.Quote(prefix), packet.MaxCompGenValues+1)
|
||||
ecmd := exec.Command(shexec.GetLocalBashPath(), "-c", compGenCmdStr)
|
||||
ecmd := exec.Command(sapi.GetLocalShellPath(), "-c", compGenCmdStr)
|
||||
outputBytes, err := ecmd.Output()
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("compgen error: %w", err)
|
||||
@ -230,26 +246,19 @@ func (m *MServer) runCompGen(compPk *packet.CompGenPacketType) {
|
||||
return
|
||||
}
|
||||
|
||||
func (m *MServer) setCurrentState(state *packet.ShellState) {
|
||||
if state == nil {
|
||||
return
|
||||
}
|
||||
hval, _ := state.EncodeAndHash()
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
m.StateMap[hval] = state
|
||||
m.CurrentState = hval
|
||||
}
|
||||
|
||||
func (m *MServer) reinit(reqId string) {
|
||||
initPk, err := shexec.MakeServerInitPacket()
|
||||
func (m *MServer) reinit(reqId string, shellType string) {
|
||||
ssPk, err := shexec.MakeShellStatePacket(shellType)
|
||||
if err != nil {
|
||||
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error creating init packet: %w", err))
|
||||
return
|
||||
}
|
||||
m.setCurrentState(initPk.State)
|
||||
initPk.RespId = reqId
|
||||
m.Sender.SendPacket(initPk)
|
||||
err = m.StateMap.SetCurrentState(ssPk.State.GetShellType(), ssPk.State)
|
||||
if err != nil {
|
||||
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error setting current state: %w", err))
|
||||
return
|
||||
}
|
||||
ssPk.RespId = reqId
|
||||
m.Sender.SendPacket(ssPk)
|
||||
}
|
||||
|
||||
func makeTemp(path string, mode fs.FileMode) (*os.File, error) {
|
||||
@ -564,8 +573,8 @@ func (m *MServer) ProcessRpcPacket(pk packet.RpcPacketType) {
|
||||
go m.runCompGen(compPk)
|
||||
return
|
||||
}
|
||||
if _, ok := pk.(*packet.ReInitPacketType); ok {
|
||||
go m.reinit(reqId)
|
||||
if reinitPk, ok := pk.(*packet.ReInitPacketType); ok {
|
||||
go m.reinit(reqId, reinitPk.ShellType)
|
||||
return
|
||||
}
|
||||
if streamPk, ok := pk.(*packet.StreamFilePacketType); ok {
|
||||
@ -581,13 +590,7 @@ func (m *MServer) ProcessRpcPacket(pk packet.RpcPacketType) {
|
||||
return
|
||||
}
|
||||
|
||||
func (m *MServer) getCurrentState() (string, *packet.ShellState) {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
return m.CurrentState, m.StateMap[m.CurrentState]
|
||||
}
|
||||
|
||||
func (m *MServer) clientPacketCallback(pk packet.PacketType) {
|
||||
func (m *MServer) clientPacketCallback(shellType string, pk packet.PacketType) {
|
||||
if pk.GetType() != packet.CmdDonePacketStr {
|
||||
return
|
||||
}
|
||||
@ -595,16 +598,25 @@ func (m *MServer) clientPacketCallback(pk packet.PacketType) {
|
||||
if donePk.FinalState == nil {
|
||||
return
|
||||
}
|
||||
stateHash, curState := m.getCurrentState()
|
||||
stateHash, curState := m.StateMap.GetCurrentState(shellType)
|
||||
if curState == nil {
|
||||
return
|
||||
}
|
||||
diff, err := shexec.MakeShellStateDiff(*curState, stateHash, *donePk.FinalState)
|
||||
sapi, err := shellapi.MakeShellApi(curState.GetShellType())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
diff, err := sapi.MakeShellStateDiff(curState, stateHash, donePk.FinalState)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
donePk.FinalState = nil
|
||||
donePk.FinalStateDiff = &diff
|
||||
donePk.FinalStateDiff = diff
|
||||
}
|
||||
|
||||
func (m *MServer) isShellInitialized(shellType string) bool {
|
||||
_, curState := m.StateMap.GetCurrentState(shellType)
|
||||
return curState != nil
|
||||
}
|
||||
|
||||
func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
|
||||
@ -612,7 +624,29 @@ func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
|
||||
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err))
|
||||
return
|
||||
}
|
||||
ecmd, err := shexec.SSHOpts{}.MakeMShellSingleCmd(true)
|
||||
if runPacket.ShellType == "" {
|
||||
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require shell type"))
|
||||
return
|
||||
}
|
||||
_, curInitState := m.StateMap.GetCurrentState(runPacket.ShellType)
|
||||
if curInitState == nil {
|
||||
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("shell type %q is not initialized", runPacket.ShellType))
|
||||
return
|
||||
}
|
||||
if runPacket.State == nil {
|
||||
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require state"))
|
||||
return
|
||||
}
|
||||
_, _, err := packet.ParseShellStateVersion(runPacket.State.Version)
|
||||
if err != nil {
|
||||
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("invalid shellstate version: %w", err))
|
||||
return
|
||||
}
|
||||
if !packet.StateVersionsCompatible(runPacket.State.Version, curInitState.Version) {
|
||||
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("shellstate version %q is not compatible with current shell version %q", runPacket.State.Version, curInitState.Version))
|
||||
return
|
||||
}
|
||||
ecmd, err := shexec.MakeMShellSingleCmd()
|
||||
if err != nil {
|
||||
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err))
|
||||
return
|
||||
@ -640,7 +674,9 @@ func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
|
||||
cproc.Close()
|
||||
}()
|
||||
shexec.SendRunPacketAndRunData(context.Background(), cproc.Input, runPacket)
|
||||
cproc.ProxySingleOutput(runPacket.CK, m.Sender, m.clientPacketCallback)
|
||||
cproc.ProxySingleOutput(runPacket.CK, m.Sender, func(pk packet.PacketType) {
|
||||
m.clientPacketCallback(runPacket.ShellType, pk)
|
||||
})
|
||||
}()
|
||||
}
|
||||
|
||||
@ -699,7 +735,7 @@ func RunServer() (int, error) {
|
||||
server := &MServer{
|
||||
Lock: &sync.Mutex{},
|
||||
ClientMap: make(map[base.CommandKey]*shexec.ClientProc),
|
||||
StateMap: make(map[string]*packet.ShellState),
|
||||
StateMap: MakeShellStateMap(),
|
||||
Debug: debug,
|
||||
WriteErrorCh: make(chan bool),
|
||||
WriteErrorChOnce: &sync.Once{},
|
||||
@ -725,7 +761,6 @@ func RunServer() (int, error) {
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
server.setCurrentState(initPacket.State)
|
||||
server.Sender.SendPacket(initPacket)
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
go func() {
|
||||
@ -748,3 +783,60 @@ func RunServer() (int, error) {
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func MakeShellStateMap() *ShellStateMap {
|
||||
return &ShellStateMap{
|
||||
Lock: &sync.Mutex{},
|
||||
StateMap: make(map[shellStateMapKey]*packet.ShellState),
|
||||
CurrentStateMap: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *ShellStateMap) GetCurrentState(shellType string) (string, *packet.ShellState) {
|
||||
sm.Lock.Lock()
|
||||
defer sm.Lock.Unlock()
|
||||
hval := sm.CurrentStateMap[shellType]
|
||||
return hval, sm.StateMap[shellStateMapKey{ShellType: shellType, Hash: hval}]
|
||||
}
|
||||
|
||||
func (sm *ShellStateMap) SetCurrentState(shellType string, state *packet.ShellState) error {
|
||||
if state == nil {
|
||||
return fmt.Errorf("cannot set nil state")
|
||||
}
|
||||
if shellType != state.GetShellType() {
|
||||
return fmt.Errorf("shell type mismatch: %s != %s", shellType, state.GetShellType())
|
||||
}
|
||||
sm.Lock.Lock()
|
||||
defer sm.Lock.Unlock()
|
||||
hval, _ := state.EncodeAndHash()
|
||||
key := shellStateMapKey{ShellType: shellType, Hash: hval}
|
||||
sm.StateMap[key] = state
|
||||
sm.CurrentStateMap[shellType] = hval
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *ShellStateMap) GetStateByHash(shellType string, hash string) *packet.ShellState {
|
||||
sm.Lock.Lock()
|
||||
defer sm.Lock.Unlock()
|
||||
return sm.StateMap[shellStateMapKey{ShellType: shellType, Hash: hash}]
|
||||
}
|
||||
|
||||
func (sm *ShellStateMap) Clear() {
|
||||
sm.Lock.Lock()
|
||||
defer sm.Lock.Unlock()
|
||||
sm.StateMap = make(map[shellStateMapKey]*packet.ShellState)
|
||||
sm.CurrentStateMap = make(map[string]string)
|
||||
}
|
||||
|
||||
func (sm *ShellStateMap) GetShells() []string {
|
||||
sm.Lock.Lock()
|
||||
defer sm.Lock.Unlock()
|
||||
return utilfn.GetMapKeys(sm.CurrentStateMap)
|
||||
}
|
||||
|
||||
func (sm *ShellStateMap) HasShell(shellType string) bool {
|
||||
sm.Lock.Lock()
|
||||
defer sm.Lock.Unlock()
|
||||
_, found := sm.CurrentStateMap[shellType]
|
||||
return found
|
||||
}
|
||||
|
260
waveshell/pkg/shellapi/bashapi.go
Normal file
260
waveshell/pkg/shellapi/bashapi.go
Normal file
@ -0,0 +1,260 @@
|
||||
// Copyright 2023, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package shellapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/alessio/shellescape"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/statediff"
|
||||
)
|
||||
|
||||
const BaseBashOpts = `set +m; set +H; shopt -s extglob`
|
||||
|
||||
const BashShellVersionCmdStr = `echo bash v${BASH_VERSINFO[0]}.${BASH_VERSINFO[1]}.${BASH_VERSINFO[2]}`
|
||||
const RemoteBashPath = "bash"
|
||||
|
||||
// TODO fix bash path in these constants
|
||||
const RunBashSudoCommandFmt = `sudo -n -C %d bash /dev/fd/%d`
|
||||
const RunBashSudoPasswordCommandFmt = `cat /dev/fd/%d | sudo -k -S -C %d bash -c "echo '[from-mshell]'; exec %d>&-; bash /dev/fd/%d < /dev/fd/%d"`
|
||||
|
||||
// do not use these directly, call GetLocalMajorVersion()
|
||||
var localBashMajorVersionOnce = &sync.Once{}
|
||||
var localBashMajorVersion = ""
|
||||
|
||||
// the "exec 2>" line also adds an extra printf at the *beginning* to strip out spurious rc file output
|
||||
var GetBashShellStateCmds = []string{
|
||||
"exec 2> /dev/null;",
|
||||
BashShellVersionCmdStr + ";",
|
||||
`pwd;`,
|
||||
`declare -p $(compgen -A variable);`,
|
||||
`alias -p;`,
|
||||
`declare -f;`,
|
||||
GetGitBranchCmdStr + ";",
|
||||
}
|
||||
|
||||
type bashShellApi struct{}
|
||||
|
||||
func (b bashShellApi) GetShellType() string {
|
||||
return packet.ShellType_bash
|
||||
}
|
||||
|
||||
func (b bashShellApi) MakeExitTrap(fdNum int) string {
|
||||
return MakeBashExitTrap(fdNum)
|
||||
}
|
||||
|
||||
func (b bashShellApi) GetLocalMajorVersion() string {
|
||||
return GetLocalBashMajorVersion()
|
||||
}
|
||||
|
||||
func (b bashShellApi) GetLocalShellPath() string {
|
||||
return GetLocalBashPath()
|
||||
}
|
||||
|
||||
func (b bashShellApi) GetRemoteShellPath() string {
|
||||
return RemoteBashPath
|
||||
}
|
||||
|
||||
func (b bashShellApi) MakeRunCommand(cmdStr string, opts RunCommandOpts) string {
|
||||
if !opts.Sudo {
|
||||
return fmt.Sprintf(RunCommandFmt, cmdStr)
|
||||
}
|
||||
if opts.SudoWithPass {
|
||||
return fmt.Sprintf(RunBashSudoPasswordCommandFmt, opts.PwFdNum, opts.MaxFdNum+1, opts.PwFdNum, opts.CommandFdNum, opts.CommandStdinFdNum)
|
||||
} else {
|
||||
return fmt.Sprintf(RunBashSudoCommandFmt, opts.MaxFdNum+1, opts.CommandFdNum)
|
||||
}
|
||||
}
|
||||
|
||||
func (b bashShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd {
|
||||
return MakeBashShExecCommand(cmdStr, rcFileName, usePty)
|
||||
}
|
||||
|
||||
func (b bashShellApi) GetShellState() (*packet.ShellState, error) {
|
||||
return GetBashShellState()
|
||||
}
|
||||
|
||||
func (b bashShellApi) GetBaseShellOpts() string {
|
||||
return BaseBashOpts
|
||||
}
|
||||
|
||||
func (b bashShellApi) ParseShellStateOutput(output []byte) (*packet.ShellState, error) {
|
||||
return parseBashShellStateOutput(output)
|
||||
}
|
||||
|
||||
func (b bashShellApi) MakeRcFileStr(pk *packet.RunPacketType) string {
|
||||
var rcBuf bytes.Buffer
|
||||
rcBuf.WriteString(b.GetBaseShellOpts() + "\n")
|
||||
varDecls := shellenv.VarDeclsFromState(pk.State)
|
||||
for _, varDecl := range varDecls {
|
||||
if varDecl.IsExport() || varDecl.IsReadOnly() {
|
||||
continue
|
||||
}
|
||||
rcBuf.WriteString(BashDeclareStmt(varDecl))
|
||||
rcBuf.WriteString("\n")
|
||||
}
|
||||
if pk.State != nil && pk.State.Funcs != "" {
|
||||
rcBuf.WriteString(pk.State.Funcs)
|
||||
rcBuf.WriteString("\n")
|
||||
}
|
||||
if pk.State != nil && pk.State.Aliases != "" {
|
||||
rcBuf.WriteString(pk.State.Aliases)
|
||||
rcBuf.WriteString("\n")
|
||||
}
|
||||
return rcBuf.String()
|
||||
}
|
||||
|
||||
func GetBashShellStateCmd() string {
|
||||
return strings.Join(GetBashShellStateCmds, ` printf "\x00\x00";`)
|
||||
}
|
||||
|
||||
func execGetLocalBashShellVersion() string {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
|
||||
defer cancelFn()
|
||||
ecmd := exec.CommandContext(ctx, "bash", "-c", BashShellVersionCmdStr)
|
||||
out, err := ecmd.Output()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
versionStr := strings.TrimSpace(string(out))
|
||||
if strings.Index(versionStr, "bash ") == -1 {
|
||||
// invalid shell version (only bash is supported)
|
||||
return ""
|
||||
}
|
||||
return versionStr
|
||||
}
|
||||
|
||||
func GetLocalBashMajorVersion() string {
|
||||
localBashMajorVersionOnce.Do(func() {
|
||||
fullVersion := execGetLocalBashShellVersion()
|
||||
localBashMajorVersion = packet.GetMajorVersion(fullVersion)
|
||||
})
|
||||
return localBashMajorVersion
|
||||
}
|
||||
|
||||
func GetBashShellState() (*packet.ShellState, error) {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
|
||||
defer cancelFn()
|
||||
cmdStr := BaseBashOpts + "; " + GetBashShellStateCmd()
|
||||
ecmd := exec.CommandContext(ctx, GetLocalBashPath(), "-l", "-i", "-c", cmdStr)
|
||||
outputBytes, err := RunSimpleCmdInPty(ecmd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return parseBashShellStateOutput(outputBytes)
|
||||
}
|
||||
|
||||
func GetLocalBashPath() string {
|
||||
if runtime.GOOS == "darwin" {
|
||||
macShell := GetMacUserShell()
|
||||
if strings.Index(macShell, "bash") != -1 {
|
||||
return shellescape.Quote(macShell)
|
||||
}
|
||||
}
|
||||
return "bash"
|
||||
}
|
||||
|
||||
func GetLocalZshPath() string {
|
||||
if runtime.GOOS == "darwin" {
|
||||
macShell := GetMacUserShell()
|
||||
if strings.Index(macShell, "zsh") != -1 {
|
||||
return shellescape.Quote(macShell)
|
||||
}
|
||||
}
|
||||
return "zsh"
|
||||
}
|
||||
|
||||
func GetBashShellStateRedirectCommandStr(outputFdNum int) string {
|
||||
return fmt.Sprintf("cat <(%s) > /dev/fd/%d", GetBashShellStateCmd(), outputFdNum)
|
||||
}
|
||||
|
||||
func MakeBashExitTrap(fdNum int) string {
|
||||
stateCmd := GetBashShellStateRedirectCommandStr(fdNum)
|
||||
fmtStr := `
|
||||
_waveshell_exittrap () {
|
||||
%s
|
||||
}
|
||||
trap _waveshell_exittrap EXIT
|
||||
`
|
||||
return fmt.Sprintf(fmtStr, stateCmd)
|
||||
}
|
||||
|
||||
func MakeBashShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd {
|
||||
if usePty {
|
||||
return exec.Command(GetLocalBashPath(), "--rcfile", rcFileName, "-i", "-c", cmdStr)
|
||||
} else {
|
||||
return exec.Command(GetLocalBashPath(), "--rcfile", rcFileName, "-c", cmdStr)
|
||||
}
|
||||
}
|
||||
|
||||
func (bashShellApi) MakeShellStateDiff(oldState *packet.ShellState, oldStateHash string, newState *packet.ShellState) (*packet.ShellStateDiff, error) {
|
||||
if oldState == nil {
|
||||
return nil, fmt.Errorf("cannot diff, oldState is nil")
|
||||
}
|
||||
if newState == nil {
|
||||
return nil, fmt.Errorf("cannot diff, newState is nil")
|
||||
}
|
||||
if !packet.StateVersionsCompatible(oldState.Version, newState.Version) {
|
||||
return nil, fmt.Errorf("cannot diff, incompatible shell versions: %q %q", oldState.Version, newState.Version)
|
||||
}
|
||||
rtn := &packet.ShellStateDiff{}
|
||||
rtn.BaseHash = oldStateHash
|
||||
rtn.Version = newState.Version // always set version in the diff
|
||||
if oldState.Cwd != newState.Cwd {
|
||||
rtn.Cwd = newState.Cwd
|
||||
}
|
||||
rtn.Error = newState.Error
|
||||
oldVars := shellenv.ShellStateVarsToMap(oldState.ShellVars)
|
||||
newVars := shellenv.ShellStateVarsToMap(newState.ShellVars)
|
||||
rtn.VarsDiff = statediff.MakeMapDiff(oldVars, newVars)
|
||||
rtn.AliasesDiff = statediff.MakeLineDiff(oldState.Aliases, newState.Aliases, oldState.GetLineDiffSplitString())
|
||||
rtn.FuncsDiff = statediff.MakeLineDiff(oldState.Funcs, newState.Funcs, oldState.GetLineDiffSplitString())
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func (bashShellApi) ApplyShellStateDiff(oldState *packet.ShellState, diff *packet.ShellStateDiff) (*packet.ShellState, error) {
|
||||
if oldState == nil {
|
||||
return nil, fmt.Errorf("cannot apply diff, oldState is nil")
|
||||
}
|
||||
if diff == nil {
|
||||
return oldState, nil
|
||||
}
|
||||
rtnState := &packet.ShellState{}
|
||||
var err error
|
||||
rtnState.Version = oldState.Version
|
||||
// work around a bug (before v0.6.0) where version could be invalid.
|
||||
// so only overwrite the oldversion if diff version is valid
|
||||
_, _, diffVersionErr := packet.ParseShellStateVersion(diff.Version)
|
||||
if diffVersionErr == nil {
|
||||
rtnState.Version = diff.Version
|
||||
}
|
||||
rtnState.Cwd = oldState.Cwd
|
||||
if diff.Cwd != "" {
|
||||
rtnState.Cwd = diff.Cwd
|
||||
}
|
||||
rtnState.Error = diff.Error
|
||||
oldVars := shellenv.ShellStateVarsToMap(oldState.ShellVars)
|
||||
newVars, err := statediff.ApplyMapDiff(oldVars, diff.VarsDiff)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("applying mapdiff 'vars': %v", err)
|
||||
}
|
||||
rtnState.ShellVars = shellenv.StrMapToShellStateVars(newVars)
|
||||
rtnState.Aliases, err = statediff.ApplyLineDiff(oldState.Aliases, diff.AliasesDiff)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("applying diff 'aliases': %v", err)
|
||||
}
|
||||
rtnState.Funcs, err = statediff.ApplyLineDiff(oldState.Funcs, diff.FuncsDiff)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("applying diff 'funcs': %v", err)
|
||||
}
|
||||
return rtnState, nil
|
||||
}
|
328
waveshell/pkg/shellapi/bashparser.go
Normal file
328
waveshell/pkg/shellapi/bashparser.go
Normal file
@ -0,0 +1,328 @@
|
||||
// Copyright 2023, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package shellapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/alessio/shellescape"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
|
||||
"mvdan.cc/sh/v3/expand"
|
||||
"mvdan.cc/sh/v3/syntax"
|
||||
)
|
||||
|
||||
type DeclareDeclType = shellenv.DeclareDeclType
|
||||
|
||||
func doCmdSubst(commandStr string, w io.Writer, word *syntax.CmdSubst) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func doProcSubst(w *syntax.ProcSubst) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
type bashParseEnviron struct {
|
||||
Env map[string]string
|
||||
}
|
||||
|
||||
func (e *bashParseEnviron) Get(name string) expand.Variable {
|
||||
val, ok := e.Env[name]
|
||||
if !ok {
|
||||
return expand.Variable{}
|
||||
}
|
||||
return expand.Variable{
|
||||
Exported: true,
|
||||
Kind: expand.String,
|
||||
Str: val,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *bashParseEnviron) Each(fn func(name string, vr expand.Variable) bool) {
|
||||
for key := range e.Env {
|
||||
rtn := fn(key, e.Get(key))
|
||||
if !rtn {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GetParserConfig(envMap map[string]string) *expand.Config {
|
||||
cfg := &expand.Config{
|
||||
Env: &bashParseEnviron{Env: envMap},
|
||||
GlobStar: false,
|
||||
NullGlob: false,
|
||||
NoUnset: false,
|
||||
CmdSubst: func(w io.Writer, word *syntax.CmdSubst) error { return doCmdSubst("", w, word) },
|
||||
ProcSubst: doProcSubst,
|
||||
ReadDir: nil,
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// https://wiki.bash-hackers.org/syntax/shellvars
|
||||
var BashNoStoreVarNames = map[string]bool{
|
||||
"BASH": true,
|
||||
"BASHOPTS": true,
|
||||
"BASHPID": true,
|
||||
"BASH_ALIASES": true,
|
||||
"BASH_ARGC": true,
|
||||
"BASH_ARGV": true,
|
||||
"BASH_ARGV0": true,
|
||||
"BASH_CMDS": true,
|
||||
"BASH_COMMAND": true,
|
||||
"BASH_EXECUTION_STRING": true,
|
||||
"LINENO": true,
|
||||
"BASH_LINENO": true,
|
||||
"BASH_REMATCH": true,
|
||||
"BASH_SOURCE": true,
|
||||
"BASH_SUBSHELL": true,
|
||||
"COPROC": true,
|
||||
"DIRSTACK": true,
|
||||
"EPOCHREALTIME": true,
|
||||
"EPOCHSECONDS": true,
|
||||
"FUNCNAME": true,
|
||||
"HISTCMD": true,
|
||||
"OLDPWD": true,
|
||||
"PIPESTATUS": true,
|
||||
"PPID": true,
|
||||
"PWD": true,
|
||||
"RANDOM": true,
|
||||
"SECONDS": true,
|
||||
"SHLVL": true,
|
||||
"HISTFILE": true,
|
||||
"HISTFILESIZE": true,
|
||||
"HISTCONTROL": true,
|
||||
"HISTIGNORE": true,
|
||||
"HISTSIZE": true,
|
||||
"HISTTIMEFORMAT": true,
|
||||
"SRANDOM": true,
|
||||
"COLUMNS": true,
|
||||
"LINES": true,
|
||||
|
||||
// we want these in our remote state object
|
||||
// "EUID": true,
|
||||
// "SHELLOPTS": true,
|
||||
// "UID": true,
|
||||
// "BASH_VERSINFO": true,
|
||||
// "BASH_VERSION": true,
|
||||
}
|
||||
|
||||
var declareDeclArgsRe = regexp.MustCompile("^[aAxrifx]*$")
|
||||
var bashValidIdentifierRe = regexp.MustCompile("^[a-zA-Z_][a-zA-Z0-9_]*$")
|
||||
|
||||
func bashValidate(d *DeclareDeclType) error {
|
||||
if len(d.Name) == 0 || !isValidBashIdentifier(d.Name) {
|
||||
return fmt.Errorf("invalid shell variable name (invalid bash identifier)")
|
||||
}
|
||||
if strings.Index(d.Value, "\x00") >= 0 {
|
||||
return fmt.Errorf("invalid shell variable value (cannot contain 0 byte)")
|
||||
}
|
||||
if !declareDeclArgsRe.MatchString(d.Args) {
|
||||
return fmt.Errorf("invalid shell variable type %s", shellescape.Quote(d.Args))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isValidBashIdentifier(s string) bool {
|
||||
return bashValidIdentifierRe.MatchString(s)
|
||||
}
|
||||
|
||||
func bashParseDeclareStmt(stmt *syntax.Stmt, src string) (*DeclareDeclType, error) {
|
||||
cmd := stmt.Cmd
|
||||
decl, ok := cmd.(*syntax.DeclClause)
|
||||
if !ok || decl.Variant.Value != "declare" || len(decl.Args) != 2 {
|
||||
return nil, fmt.Errorf("invalid declare variant")
|
||||
}
|
||||
rtn := &DeclareDeclType{}
|
||||
declArgs := decl.Args[0]
|
||||
if !declArgs.Naked || len(declArgs.Value.Parts) != 1 {
|
||||
return nil, fmt.Errorf("wrong number of declare args parts")
|
||||
}
|
||||
declArgsLit, ok := declArgs.Value.Parts[0].(*syntax.Lit)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("declare args is not a literal")
|
||||
}
|
||||
if !strings.HasPrefix(declArgsLit.Value, "-") {
|
||||
return nil, fmt.Errorf("declare args not an argument (does not start with '-')")
|
||||
}
|
||||
if declArgsLit.Value == "--" {
|
||||
rtn.Args = ""
|
||||
} else {
|
||||
rtn.Args = declArgsLit.Value[1:]
|
||||
}
|
||||
declAssign := decl.Args[1]
|
||||
if declAssign.Name == nil {
|
||||
return nil, fmt.Errorf("declare does not have a valid name")
|
||||
}
|
||||
rtn.Name = declAssign.Name.Value
|
||||
if declAssign.Naked || declAssign.Index != nil || declAssign.Append {
|
||||
return nil, fmt.Errorf("invalid decl format")
|
||||
}
|
||||
if declAssign.Value != nil {
|
||||
rtn.Value = string(src[declAssign.Value.Pos().Offset():declAssign.Value.End().Offset()])
|
||||
} else if declAssign.Array != nil {
|
||||
rtn.Value = string(src[declAssign.Array.Pos().Offset():declAssign.Array.End().Offset()])
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid decl, not plain value or array")
|
||||
}
|
||||
err := bashNormalize(rtn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = bashValidate(rtn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func bashParseDeclareOutput(state *packet.ShellState, declareBytes []byte, pvarBytes []byte) error {
|
||||
declareStr := string(declareBytes)
|
||||
r := bytes.NewReader(declareBytes)
|
||||
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
|
||||
file, err := parser.Parse(r, "aliases")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var firstParseErr error
|
||||
declMap := make(map[string]*DeclareDeclType)
|
||||
for _, stmt := range file.Stmts {
|
||||
decl, err := bashParseDeclareStmt(stmt, declareStr)
|
||||
if err != nil {
|
||||
if firstParseErr == nil {
|
||||
firstParseErr = err
|
||||
}
|
||||
}
|
||||
if decl != nil && !BashNoStoreVarNames[decl.Name] {
|
||||
declMap[decl.Name] = decl
|
||||
}
|
||||
}
|
||||
pvarMap := parsePVarOutput(pvarBytes, false)
|
||||
utilfn.CombineMaps(declMap, pvarMap)
|
||||
state.ShellVars = shellenv.SerializeDeclMap(declMap) // this writes out the decls in a canonical order
|
||||
if firstParseErr != nil {
|
||||
state.Error = firstParseErr.Error()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseBashShellStateOutput(outputBytes []byte) (*packet.ShellState, error) {
|
||||
if scbase.IsDevMode() && DebugState {
|
||||
writeStateToFile(packet.ShellType_bash, outputBytes)
|
||||
}
|
||||
// 7 fields: ignored [0], version [1], cwd [2], env/vars [3], aliases [4], funcs [5], pvars [6]
|
||||
fields := bytes.Split(outputBytes, []byte{0, 0})
|
||||
if len(fields) != 7 {
|
||||
return nil, fmt.Errorf("invalid bash shell state output, wrong number of fields, fields=%d", len(fields))
|
||||
}
|
||||
rtn := &packet.ShellState{}
|
||||
rtn.Version = strings.TrimSpace(string(fields[1]))
|
||||
if rtn.GetShellType() != packet.ShellType_bash {
|
||||
return nil, fmt.Errorf("invalid bash shell state output, wrong shell type: %q", rtn.Version)
|
||||
}
|
||||
if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil {
|
||||
return nil, fmt.Errorf("invalid bash shell state output, invalid version: %v", err)
|
||||
}
|
||||
cwdStr := string(fields[2])
|
||||
if strings.HasSuffix(cwdStr, "\r\n") {
|
||||
cwdStr = cwdStr[0 : len(cwdStr)-2]
|
||||
} else if strings.HasSuffix(cwdStr, "\n") {
|
||||
cwdStr = cwdStr[0 : len(cwdStr)-1]
|
||||
}
|
||||
rtn.Cwd = string(cwdStr)
|
||||
err := bashParseDeclareOutput(rtn, fields[3], fields[6])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rtn.Aliases = strings.ReplaceAll(string(fields[4]), "\r\n", "\n")
|
||||
rtn.Funcs = strings.ReplaceAll(string(fields[5]), "\r\n", "\n")
|
||||
rtn.Funcs = shellenv.RemoveFunc(rtn.Funcs, "_waveshell_exittrap")
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func bashNormalize(d *DeclareDeclType) error {
|
||||
if d.DataType() == shellenv.DeclTypeAssocArray {
|
||||
return bashNormalizeAssocArrayDecl(d)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// normalizes order of assoc array keys so value is stable
|
||||
func bashNormalizeAssocArrayDecl(d *DeclareDeclType) error {
|
||||
if d.DataType() != shellenv.DeclTypeAssocArray {
|
||||
return fmt.Errorf("invalid decltype passed to assocArrayDeclToStr: %s", d.DataType())
|
||||
}
|
||||
varMap, err := bashAssocArrayVarToMap(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keys := make([]string, 0, len(varMap))
|
||||
for key := range varMap {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
var buf bytes.Buffer
|
||||
buf.WriteByte('(')
|
||||
for _, key := range keys {
|
||||
buf.WriteByte('[')
|
||||
buf.WriteString(key)
|
||||
buf.WriteByte(']')
|
||||
buf.WriteByte('=')
|
||||
buf.WriteString(varMap[key])
|
||||
buf.WriteByte(' ')
|
||||
}
|
||||
buf.WriteByte(')')
|
||||
d.Value = buf.String()
|
||||
return nil
|
||||
}
|
||||
|
||||
func bashAssocArrayVarToMap(d *DeclareDeclType) (map[string]string, error) {
|
||||
if d.DataType() != shellenv.DeclTypeAssocArray {
|
||||
return nil, fmt.Errorf("decl is not an assoc-array")
|
||||
}
|
||||
refStr := "X=" + d.Value
|
||||
r := strings.NewReader(refStr)
|
||||
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
|
||||
file, err := parser.Parse(r, "assocdecl")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(file.Stmts) != 1 {
|
||||
return nil, fmt.Errorf("invalid assoc-array parse (multiple stmts)")
|
||||
}
|
||||
stmt := file.Stmts[0]
|
||||
callExpr, ok := stmt.Cmd.(*syntax.CallExpr)
|
||||
if !ok || len(callExpr.Args) != 0 || len(callExpr.Assigns) != 1 {
|
||||
return nil, fmt.Errorf("invalid assoc-array parse (bad expr)")
|
||||
}
|
||||
assign := callExpr.Assigns[0]
|
||||
arrayExpr := assign.Array
|
||||
if arrayExpr == nil {
|
||||
return nil, fmt.Errorf("invalid assoc-array parse (no array expr)")
|
||||
}
|
||||
rtn := make(map[string]string)
|
||||
for _, elem := range arrayExpr.Elems {
|
||||
indexStr := refStr[elem.Index.Pos().Offset():elem.Index.End().Offset()]
|
||||
valStr := refStr[elem.Value.Pos().Offset():elem.Value.End().Offset()]
|
||||
rtn[indexStr] = valStr
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func BashDeclareStmt(d *DeclareDeclType) string {
|
||||
var argsStr string
|
||||
if d.Args == "" {
|
||||
argsStr = "--"
|
||||
} else {
|
||||
argsStr = "-" + d.Args
|
||||
}
|
||||
return fmt.Sprintf("declare %s %s=%s", argsStr, d.Name, d.Value)
|
||||
}
|
254
waveshell/pkg/shellapi/shellapi.go
Normal file
254
waveshell/pkg/shellapi/shellapi.go
Normal file
@ -0,0 +1,254 @@
|
||||
// Copyright 2023, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package shellapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/alessio/shellescape"
|
||||
"github.com/creack/pty"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellutil"
|
||||
)
|
||||
|
||||
const GetStateTimeout = 5 * time.Second
|
||||
const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"`
|
||||
const RunCommandFmt = `%s`
|
||||
const DebugState = false
|
||||
|
||||
var userShellRegexp = regexp.MustCompile(`^UserShell: (.*)$`)
|
||||
|
||||
var cachedMacUserShell string
|
||||
var macUserShellOnce = &sync.Once{}
|
||||
|
||||
const DefaultMacOSShell = "/bin/bash"
|
||||
|
||||
type RunCommandOpts struct {
|
||||
Sudo bool
|
||||
SudoWithPass bool
|
||||
MaxFdNum int // needed for Sudo
|
||||
CommandFdNum int // needed for Sudo
|
||||
PwFdNum int // needed for SudoWithPass
|
||||
CommandStdinFdNum int // needed for SudoWithPass
|
||||
}
|
||||
|
||||
type ShellApi interface {
|
||||
GetShellType() string
|
||||
MakeExitTrap(fdNum int) string
|
||||
GetLocalMajorVersion() string
|
||||
GetLocalShellPath() string
|
||||
GetRemoteShellPath() string
|
||||
MakeRunCommand(cmdStr string, opts RunCommandOpts) string
|
||||
MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd
|
||||
GetShellState() (*packet.ShellState, error)
|
||||
GetBaseShellOpts() string
|
||||
ParseShellStateOutput(output []byte) (*packet.ShellState, error)
|
||||
MakeRcFileStr(pk *packet.RunPacketType) string
|
||||
MakeShellStateDiff(oldState *packet.ShellState, oldStateHash string, newState *packet.ShellState) (*packet.ShellStateDiff, error)
|
||||
ApplyShellStateDiff(oldState *packet.ShellState, diff *packet.ShellStateDiff) (*packet.ShellState, error)
|
||||
}
|
||||
|
||||
func DetectLocalShellType() string {
|
||||
shellPath := GetMacUserShell()
|
||||
if shellPath == "" {
|
||||
shellPath = os.Getenv("SHELL")
|
||||
}
|
||||
if shellPath == "" {
|
||||
return packet.ShellType_bash
|
||||
}
|
||||
_, file := filepath.Split(shellPath)
|
||||
if strings.HasPrefix(file, "zsh") {
|
||||
return packet.ShellType_zsh
|
||||
}
|
||||
return packet.ShellType_bash
|
||||
}
|
||||
|
||||
func HasShell(shellType string) bool {
|
||||
if shellType == packet.ShellType_bash {
|
||||
_, err := exec.LookPath("bash")
|
||||
return err != nil
|
||||
}
|
||||
if shellType == packet.ShellType_zsh {
|
||||
_, err := exec.LookPath("zsh")
|
||||
return err != nil
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func MakeShellApi(shellType string) (ShellApi, error) {
|
||||
if shellType == "" || shellType == packet.ShellType_bash {
|
||||
return &bashShellApi{}, nil
|
||||
}
|
||||
if shellType == packet.ShellType_zsh {
|
||||
return &zshShellApi{}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("shell type not supported: %s", shellType)
|
||||
}
|
||||
|
||||
func GetMacUserShell() string {
|
||||
if runtime.GOOS != "darwin" {
|
||||
return ""
|
||||
}
|
||||
macUserShellOnce.Do(func() {
|
||||
cachedMacUserShell = internalMacUserShell()
|
||||
})
|
||||
return cachedMacUserShell
|
||||
}
|
||||
|
||||
// dscl . -read /User/[username] UserShell
|
||||
// defaults to /bin/bash
|
||||
func internalMacUserShell() string {
|
||||
osUser, err := user.Current()
|
||||
if err != nil {
|
||||
return DefaultMacOSShell
|
||||
}
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancelFn()
|
||||
userStr := "/Users/" + osUser.Username
|
||||
out, err := exec.CommandContext(ctx, "dscl", ".", "-read", userStr, "UserShell").CombinedOutput()
|
||||
if err != nil {
|
||||
return DefaultMacOSShell
|
||||
}
|
||||
outStr := strings.TrimSpace(string(out))
|
||||
m := userShellRegexp.FindStringSubmatch(outStr)
|
||||
if m == nil {
|
||||
return DefaultMacOSShell
|
||||
}
|
||||
return m[1]
|
||||
}
|
||||
|
||||
const FirstExtraFilesFdNum = 3
|
||||
|
||||
// returns output(stdout+stderr), extraFdOutput, error
|
||||
func RunCommandWithExtraFd(ecmd *exec.Cmd, extraFdNum int) ([]byte, []byte, error) {
|
||||
ecmd.Env = os.Environ()
|
||||
shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType))
|
||||
cmdPty, cmdTty, err := pty.Open()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("opening new pty: %w", err)
|
||||
}
|
||||
defer cmdTty.Close()
|
||||
defer cmdPty.Close()
|
||||
pty.Setsize(cmdPty, &pty.Winsize{Rows: shellutil.DefaultTermRows, Cols: shellutil.DefaultTermCols})
|
||||
ecmd.Stdin = cmdTty
|
||||
ecmd.Stdout = cmdTty
|
||||
ecmd.Stderr = cmdTty
|
||||
ecmd.SysProcAttr = &syscall.SysProcAttr{}
|
||||
ecmd.SysProcAttr.Setsid = true
|
||||
ecmd.SysProcAttr.Setctty = true
|
||||
pipeReader, pipeWriter, err := os.Pipe()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("could not create pipe: %w", err)
|
||||
}
|
||||
defer pipeWriter.Close()
|
||||
defer pipeReader.Close()
|
||||
extraFiles := make([]*os.File, extraFdNum+1)
|
||||
extraFiles[extraFdNum] = pipeWriter
|
||||
ecmd.ExtraFiles = extraFiles[FirstExtraFilesFdNum:]
|
||||
defer pipeReader.Close()
|
||||
ecmd.Start()
|
||||
cmdTty.Close()
|
||||
pipeWriter.Close()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
var outputWg sync.WaitGroup
|
||||
var outputBuf bytes.Buffer
|
||||
var extraFdOutputBuf bytes.Buffer
|
||||
outputWg.Add(2)
|
||||
go func() {
|
||||
// ignore error (/dev/ptmx has read error when process is done)
|
||||
defer outputWg.Done()
|
||||
io.Copy(&outputBuf, cmdPty)
|
||||
}()
|
||||
go func() {
|
||||
defer outputWg.Done()
|
||||
io.Copy(&extraFdOutputBuf, pipeReader)
|
||||
}()
|
||||
exitErr := ecmd.Wait()
|
||||
if exitErr != nil {
|
||||
return nil, nil, exitErr
|
||||
}
|
||||
outputWg.Wait()
|
||||
return outputBuf.Bytes(), extraFdOutputBuf.Bytes(), nil
|
||||
}
|
||||
|
||||
func RunSimpleCmdInPty(ecmd *exec.Cmd) ([]byte, error) {
|
||||
ecmd.Env = os.Environ()
|
||||
shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType))
|
||||
cmdPty, cmdTty, err := pty.Open()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("opening new pty: %w", err)
|
||||
}
|
||||
pty.Setsize(cmdPty, &pty.Winsize{Rows: shellutil.DefaultTermRows, Cols: shellutil.DefaultTermCols})
|
||||
ecmd.Stdin = cmdTty
|
||||
ecmd.Stdout = cmdTty
|
||||
ecmd.Stderr = cmdTty
|
||||
ecmd.SysProcAttr = &syscall.SysProcAttr{}
|
||||
ecmd.SysProcAttr.Setsid = true
|
||||
ecmd.SysProcAttr.Setctty = true
|
||||
err = ecmd.Start()
|
||||
cmdTty.Close()
|
||||
if err != nil {
|
||||
cmdPty.Close()
|
||||
return nil, err
|
||||
}
|
||||
defer cmdPty.Close()
|
||||
ioDone := make(chan bool)
|
||||
var outputBuf bytes.Buffer
|
||||
go func() {
|
||||
// ignore error (/dev/ptmx has read error when process is done)
|
||||
io.Copy(&outputBuf, cmdPty)
|
||||
close(ioDone)
|
||||
}()
|
||||
exitErr := ecmd.Wait()
|
||||
if exitErr != nil {
|
||||
return nil, exitErr
|
||||
}
|
||||
<-ioDone
|
||||
return outputBuf.Bytes(), nil
|
||||
}
|
||||
|
||||
func parsePVarOutput(pvarBytes []byte, isZsh bool) map[string]*DeclareDeclType {
|
||||
declMap := make(map[string]*DeclareDeclType)
|
||||
pvars := bytes.Split(pvarBytes, []byte{0})
|
||||
for _, pvarBA := range pvars {
|
||||
pvarStr := string(pvarBA)
|
||||
pvarFields := strings.SplitN(pvarStr, " ", 2)
|
||||
if len(pvarFields) != 2 {
|
||||
continue
|
||||
}
|
||||
if pvarFields[0] == "" {
|
||||
continue
|
||||
}
|
||||
decl := &DeclareDeclType{IsZshDecl: isZsh, Args: "x"}
|
||||
decl.Name = "PROMPTVAR_" + pvarFields[0]
|
||||
decl.Value = shellescape.Quote(pvarFields[1])
|
||||
declMap[decl.Name] = decl
|
||||
}
|
||||
return declMap
|
||||
}
|
||||
|
||||
// for debugging (not for production use)
|
||||
func writeStateToFile(shellType string, outputBytes []byte) error {
|
||||
msHome := base.GetMShellHomeDir()
|
||||
stateFileName := path.Join(msHome, shellType+"-state.txt")
|
||||
os.WriteFile(stateFileName, outputBytes, 0644)
|
||||
return nil
|
||||
}
|
826
waveshell/pkg/shellapi/zshapi.go
Normal file
826
waveshell/pkg/shellapi/zshapi.go
Normal file
@ -0,0 +1,826 @@
|
||||
// Copyright 2023, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package shellapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/alessio/shellescape"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/binpack"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/statediff"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
|
||||
)
|
||||
|
||||
const BaseZshOpts = ``
|
||||
|
||||
const ZshShellVersionCmdStr = `echo zsh v$ZSH_VERSION`
|
||||
const StateOutputFdNum = 20
|
||||
|
||||
// TODO these need updating
|
||||
const RunZshSudoCommandFmt = `sudo -n -C %d zsh /dev/fd/%d`
|
||||
const RunZshSudoPasswordCommandFmt = `cat /dev/fd/%d | sudo -k -S -C %d zsh -c "echo '[from-mshell]'; exec %d>&-; zsh /dev/fd/%d < /dev/fd/%d"`
|
||||
|
||||
var ZshIgnoreVars = map[string]bool{
|
||||
"_": true,
|
||||
"0": true,
|
||||
"terminfo": true,
|
||||
"RANDOM": true,
|
||||
"COLUMNS": true,
|
||||
"LINES": true,
|
||||
"argv": true,
|
||||
"SECONDS": true,
|
||||
"PWD": true,
|
||||
"HISTCHARS": true,
|
||||
"HISTFILE": true,
|
||||
"HISTSIZE": true,
|
||||
"SAVEHIST": true,
|
||||
"ZSH_EXECUTION_STRING": true,
|
||||
"EPOCHSECONDS": true,
|
||||
"EPOCHREALTIME": true,
|
||||
"SHLVL": true,
|
||||
"TTY": true,
|
||||
"epochtime": true,
|
||||
"langinfo": true,
|
||||
|
||||
"aliases": true,
|
||||
"dis_aliases": true,
|
||||
"saliases": true,
|
||||
"dis_saliases": true,
|
||||
"galiases": true,
|
||||
"dis_galiases": true,
|
||||
"builtins": true,
|
||||
"dis_builtins": true,
|
||||
"modules": true,
|
||||
"history": true,
|
||||
"historywords": true,
|
||||
"jobdirs": true,
|
||||
"jobstates": true,
|
||||
"jobtexts": true,
|
||||
"funcfiletrace": true,
|
||||
"funcsourcetrace": true,
|
||||
"funcstack": true,
|
||||
"functrace": true,
|
||||
"parameters": true,
|
||||
"commands": true,
|
||||
"functions": true,
|
||||
"dis_functions": true,
|
||||
"functions_source": true,
|
||||
"dis_functions_source": true,
|
||||
"_comps": true,
|
||||
"_patcomps": true,
|
||||
"_postpatcomps": true,
|
||||
}
|
||||
|
||||
var ZshUniqueArrayVars = map[string]bool{
|
||||
"path": true,
|
||||
"fpath": true,
|
||||
}
|
||||
|
||||
var ZshSpecialDecls = map[string]bool{
|
||||
"precmd_functions": true,
|
||||
"preexec_functions": true,
|
||||
}
|
||||
|
||||
var ZshUnsetVars = []string{
|
||||
"HISTFILE",
|
||||
"ZSH_EXECUTION_STRING",
|
||||
}
|
||||
|
||||
// do not use these directly, call GetLocalMajorVersion()
|
||||
var localZshMajorVersionOnce = &sync.Once{}
|
||||
var localZshMajorVersion = ""
|
||||
|
||||
// sentinel value for functions that should be autoloaded
|
||||
const ZshFnAutoLoad = "autoload"
|
||||
|
||||
type ZshParamKey struct {
|
||||
// paramtype cannot contain spaces
|
||||
// "aliases", "dis_aliases", "saliases", "dis_saliases", "galiases", "dis_galiases"
|
||||
// "functions", "dis_functions", "functions_source", "dis_functions_source"
|
||||
ParamType string
|
||||
ParamName string
|
||||
}
|
||||
|
||||
func (k ZshParamKey) String() string {
|
||||
return k.ParamType + " " + k.ParamName
|
||||
}
|
||||
|
||||
func ZshParamKeyFromString(s string) (ZshParamKey, error) {
|
||||
parts := strings.SplitN(s, " ", 2)
|
||||
if len(parts) != 2 {
|
||||
return ZshParamKey{}, fmt.Errorf("invalid zsh param key")
|
||||
}
|
||||
return ZshParamKey{ParamType: parts[0], ParamName: parts[1]}, nil
|
||||
}
|
||||
|
||||
type ZshMap = map[ZshParamKey]string
|
||||
|
||||
type zshShellApi struct{}
|
||||
|
||||
func (z zshShellApi) GetShellType() string {
|
||||
return packet.ShellType_zsh
|
||||
}
|
||||
|
||||
func (z zshShellApi) MakeExitTrap(fdNum int) string {
|
||||
return MakeZshExitTrap(fdNum)
|
||||
}
|
||||
|
||||
func (z zshShellApi) GetLocalMajorVersion() string {
|
||||
return GetLocalZshMajorVersion()
|
||||
}
|
||||
|
||||
func (z zshShellApi) GetLocalShellPath() string {
|
||||
return "/bin/zsh"
|
||||
}
|
||||
|
||||
func (z zshShellApi) GetRemoteShellPath() string {
|
||||
return "zsh"
|
||||
}
|
||||
|
||||
func (z zshShellApi) MakeRunCommand(cmdStr string, opts RunCommandOpts) string {
|
||||
if !opts.Sudo {
|
||||
return cmdStr
|
||||
}
|
||||
if opts.SudoWithPass {
|
||||
return fmt.Sprintf(RunZshSudoPasswordCommandFmt, opts.PwFdNum, opts.MaxFdNum+1, opts.PwFdNum, opts.CommandFdNum, opts.CommandStdinFdNum)
|
||||
} else {
|
||||
return fmt.Sprintf(RunZshSudoCommandFmt, opts.MaxFdNum+1, opts.CommandFdNum)
|
||||
}
|
||||
}
|
||||
|
||||
func (z zshShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd {
|
||||
return exec.Command(GetLocalZshPath(), "-l", "-i", "-c", cmdStr)
|
||||
}
|
||||
|
||||
func (z zshShellApi) GetShellState() (*packet.ShellState, error) {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
|
||||
defer cancelFn()
|
||||
cmdStr := BaseZshOpts + "; " + GetZshShellStateCmd(StateOutputFdNum)
|
||||
ecmd := exec.CommandContext(ctx, GetLocalZshPath(), "-l", "-i", "-c", cmdStr)
|
||||
_, outputBytes, err := RunCommandWithExtraFd(ecmd, StateOutputFdNum)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rtn, err := z.ParseShellStateOutput(outputBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func (z zshShellApi) GetBaseShellOpts() string {
|
||||
return BaseZshOpts
|
||||
}
|
||||
|
||||
func makeZshTypesetStmt(varDecl *shellenv.DeclareDeclType) string {
|
||||
if !varDecl.IsZshDecl {
|
||||
// not sure what to do here?
|
||||
return ""
|
||||
}
|
||||
var argsStr string
|
||||
if varDecl.Args == "" {
|
||||
argsStr = "--"
|
||||
} else {
|
||||
argsStr = "-" + varDecl.Args
|
||||
}
|
||||
if varDecl.IsZshScalarBound() {
|
||||
// varDecl.Value contains the extra "separator" field (if present in the original typeset def)
|
||||
return fmt.Sprintf("typeset %s %s %s=%s", argsStr, varDecl.ZshBoundScalar, varDecl.Name, varDecl.Value)
|
||||
} else {
|
||||
return fmt.Sprintf("typeset %s %s=%s", argsStr, varDecl.Name, varDecl.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func (z zshShellApi) MakeRcFileStr(pk *packet.RunPacketType) string {
|
||||
var rcBuf bytes.Buffer
|
||||
rcBuf.WriteString(z.GetBaseShellOpts() + "\n")
|
||||
rcBuf.WriteString("unsetopt GLOBAL_RCS\n")
|
||||
rcBuf.WriteString("unset KSH_ARRAYS\n")
|
||||
rcBuf.WriteString("zmodload zsh/parameter\n")
|
||||
varDecls := shellenv.VarDeclsFromState(pk.State)
|
||||
var postDecls []*shellenv.DeclareDeclType
|
||||
for _, varDecl := range varDecls {
|
||||
if ZshIgnoreVars[varDecl.Name] {
|
||||
continue
|
||||
}
|
||||
if ZshUniqueArrayVars[varDecl.Name] && !varDecl.IsUniqueArray() {
|
||||
varDecl.AddFlag("U")
|
||||
}
|
||||
if ZshSpecialDecls[varDecl.Name] {
|
||||
postDecls = append(postDecls, varDecl)
|
||||
continue
|
||||
}
|
||||
stmt := makeZshTypesetStmt(varDecl)
|
||||
if stmt == "" {
|
||||
continue
|
||||
}
|
||||
rcBuf.WriteString(makeZshTypesetStmt(varDecl))
|
||||
rcBuf.WriteString("\n")
|
||||
}
|
||||
if shellenv.FindVarDecl(varDecls, "ZDOTDIR") == nil {
|
||||
rcBuf.WriteString("unset ZDOTDIR\n")
|
||||
rcBuf.WriteString("\n")
|
||||
}
|
||||
for _, varName := range ZshUnsetVars {
|
||||
rcBuf.WriteString("unset " + shellescape.Quote(varName) + "\n")
|
||||
}
|
||||
|
||||
// aliases
|
||||
aliasMap, err := DecodeZshMap([]byte(pk.State.Aliases))
|
||||
if err != nil {
|
||||
base.Logf("error decoding zsh aliases: %v\n", err)
|
||||
rcBuf.WriteString("# error decoding zsh aliases\n")
|
||||
} else {
|
||||
for aliasKey, aliasValue := range aliasMap {
|
||||
// tricky here, don't quote AliasName (it gets implicit quotes, and quoting doesn't work as expected)
|
||||
aliasStr := fmt.Sprintf("%s[%s]=%s\n", aliasKey.ParamType, aliasKey.ParamName, shellescape.Quote(aliasValue))
|
||||
rcBuf.WriteString(aliasStr)
|
||||
}
|
||||
}
|
||||
|
||||
// functions
|
||||
fnMap, err := DecodeZshMap([]byte(pk.State.Funcs))
|
||||
if err != nil {
|
||||
base.Logf("error decoding zsh functions: %v\n", err)
|
||||
rcBuf.WriteString("# error decoding zsh functions\n")
|
||||
} else {
|
||||
for fnKey, fnValue := range fnMap {
|
||||
if fnValue == ZshFnAutoLoad {
|
||||
rcBuf.WriteString(fmt.Sprintf("autoload %s\n", shellescape.Quote(fnKey.ParamName)))
|
||||
} else {
|
||||
// careful, no whitespace (except newlines)
|
||||
rcBuf.WriteString(fmt.Sprintf("function %s () {\n%s\n}\n", shellescape.Quote(fnKey.ParamName), fnValue))
|
||||
if fnKey.ParamType == "dis_functions" {
|
||||
rcBuf.WriteString(fmt.Sprintf("disable -f %s\n", shellescape.Quote(fnKey.ParamName)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// write postdecls
|
||||
for _, varDecl := range postDecls {
|
||||
rcBuf.WriteString(makeZshTypesetStmt(varDecl))
|
||||
rcBuf.WriteString("\n")
|
||||
}
|
||||
|
||||
return rcBuf.String()
|
||||
}
|
||||
|
||||
func writeZshId(buf *bytes.Buffer, idStr string) {
|
||||
buf.WriteString(shellescape.Quote(idStr))
|
||||
}
|
||||
|
||||
const numRandomBytes = 4
|
||||
|
||||
// returns (cmd-string)
|
||||
func GetZshShellStateCmd(fdNum int) string {
|
||||
var sectionSeparator []byte
|
||||
// adding this extra "\n" helps with debuging and readability of output
|
||||
sectionSeparator = append(sectionSeparator, byte('\n'))
|
||||
for len(sectionSeparator) < numRandomBytes {
|
||||
// any character *except* null (0)
|
||||
rn := rand.Intn(256)
|
||||
if rn > 0 && rn < 256 { // exclude 0, also helps to suppress security warning to have a guard here
|
||||
sectionSeparator = append(sectionSeparator, byte(rn))
|
||||
}
|
||||
}
|
||||
sectionSeparator = append(sectionSeparator, 0, 0)
|
||||
// we have to use these crazy separators because zsh allows basically anything in
|
||||
// variable names and values (including nulls).
|
||||
// note that we don't need crazy separators for "env" or "typeset".
|
||||
// environment variables *cannot* contain nulls by definition, and "typeset" already escapes nulls.
|
||||
// the raw aliases and functions though need to be handled more carefully
|
||||
// output redirection is necessary to prevent cooked tty options from screwing up the output (funcs especially)
|
||||
// note we do not need the "extra" separator that bashapi uses because we are reading from OUTPUTFD (which already excludes any spurious stdout/stderr data)
|
||||
cmd := `
|
||||
exec > [%OUTPUTFD%]
|
||||
unsetopt SH_WORD_SPLIT;
|
||||
zmodload zsh/parameter;
|
||||
[%ZSHVERSION%];
|
||||
printf "\x00[%SECTIONSEP%]";
|
||||
pwd;
|
||||
printf "[%SECTIONSEP%]";
|
||||
env -0;
|
||||
printf "[%SECTIONSEP%]";
|
||||
typeset -p +H -m '*';
|
||||
printf "[%SECTIONSEP%]";
|
||||
for var in "${(@k)aliases}"; do
|
||||
printf "aliases %s[%PARTSEP%]%s[%PARTSEP%]" $var ${aliases[$var]}
|
||||
done
|
||||
for var in "${(@k)dis_aliases}"; do
|
||||
printf "dis_aliases %s[%PARTSEP%]%s[%PARTSEP%]" $var ${dis_aliases[$var]}
|
||||
done
|
||||
for var in "${(@k)saliases}"; do
|
||||
printf "saliases %s[%PARTSEP%]%s[%PARTSEP%]" $var ${saliases[$var]}
|
||||
done
|
||||
for var in "${(@k)dis_saliases}"; do
|
||||
printf "dis_saliases %s[%PARTSEP%]%s[%PARTSEP%]" $var ${dis_saliases[$var]}
|
||||
done
|
||||
for var in "${(@k)galiases}"; do
|
||||
printf "galiases %s[%PARTSEP%]%s[%PARTSEP%]" $var ${galiases[$var]}
|
||||
done
|
||||
for var in "${(@k)dis_galiases}"; do
|
||||
printf "dis_galiases %s[%PARTSEP%]%s[%PARTSEP%]" $var ${dis_galiases[$var]}
|
||||
done
|
||||
printf "[%SECTIONSEP%]";
|
||||
echo $FPATH;
|
||||
printf "[%SECTIONSEP%]";
|
||||
for var in "${(@k)functions}"; do
|
||||
printf "functions %s[%PARTSEP%]%s[%PARTSEP%]" $var ${functions[$var]}
|
||||
done
|
||||
for var in "${(@k)dis_functions}"; do
|
||||
printf "dis_functions %s[%PARTSEP%]%s[%PARTSEP%]" $var ${dis_functions[$var]}
|
||||
done
|
||||
for var in "${(@k)functions_source}"; do
|
||||
printf "functions_source %s[%PARTSEP%]%s[%PARTSEP%]" $var ${functions_source[$var]}
|
||||
done
|
||||
for var in "${(@k)dis_functions_source}"; do
|
||||
printf "dis_functions_source %s[%PARTSEP%]%s[%PARTSEP%]" $var ${dis_functions_source[$var]}
|
||||
done
|
||||
printf "[%SECTIONSEP%]";
|
||||
[%GITBRANCH%]
|
||||
`
|
||||
cmd = strings.TrimSpace(cmd)
|
||||
cmd = strings.ReplaceAll(cmd, "[%ZSHVERSION%]", ZshShellVersionCmdStr)
|
||||
cmd = strings.ReplaceAll(cmd, "[%GITBRANCH%]", GetGitBranchCmdStr)
|
||||
cmd = strings.ReplaceAll(cmd, "[%PARTSEP%]", utilfn.ShellHexEscape(string(sectionSeparator[0:len(sectionSeparator)-1])))
|
||||
cmd = strings.ReplaceAll(cmd, "[%SECTIONSEP%]", utilfn.ShellHexEscape(string(sectionSeparator)))
|
||||
cmd = strings.ReplaceAll(cmd, "[%OUTPUTFD%]", fmt.Sprintf("/dev/fd/%d", fdNum))
|
||||
return cmd
|
||||
}
|
||||
|
||||
func MakeZshExitTrap(fdNum int) string {
|
||||
stateCmd := GetZshShellStateCmd(fdNum)
|
||||
fmtStr := `
|
||||
zshexit () {
|
||||
%s
|
||||
}
|
||||
`
|
||||
return fmt.Sprintf(fmtStr, stateCmd)
|
||||
}
|
||||
|
||||
func execGetLocalZshShellVersion() string {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
|
||||
defer cancelFn()
|
||||
ecmd := exec.CommandContext(ctx, "zsh", "-c", ZshShellVersionCmdStr)
|
||||
out, err := ecmd.Output()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
versionStr := strings.TrimSpace(string(out))
|
||||
if strings.Index(versionStr, "zsh ") == -1 {
|
||||
return ""
|
||||
}
|
||||
return versionStr
|
||||
}
|
||||
|
||||
func GetLocalZshMajorVersion() string {
|
||||
localZshMajorVersionOnce.Do(func() {
|
||||
fullVersion := execGetLocalZshShellVersion()
|
||||
localZshMajorVersion = packet.GetMajorVersion(fullVersion)
|
||||
})
|
||||
return localZshMajorVersion
|
||||
}
|
||||
|
||||
func EncodeZshMap(m ZshMap) []byte {
|
||||
var buf bytes.Buffer
|
||||
binpack.PackUInt(&buf, uint64(len(m)))
|
||||
orderedKeys := utilfn.GetOrderedStringerMapKeys(m)
|
||||
for _, key := range orderedKeys {
|
||||
value := m[key]
|
||||
binpack.PackValue(&buf, []byte(key.String()))
|
||||
binpack.PackValue(&buf, []byte(value))
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func EncodeZshMapForApply(m map[string][]byte) string {
|
||||
var buf bytes.Buffer
|
||||
binpack.PackUInt(&buf, uint64(len(m)))
|
||||
orderedKeys := utilfn.GetOrderedMapKeys(m)
|
||||
for _, key := range orderedKeys {
|
||||
value := m[key]
|
||||
binpack.PackValue(&buf, []byte(key))
|
||||
binpack.PackValue(&buf, value)
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func DecodeZshMapForDiff(barr []byte) (map[string][]byte, error) {
|
||||
rtn := make(map[string][]byte)
|
||||
buf := bytes.NewBuffer(barr)
|
||||
u := binpack.MakeUnpacker(buf)
|
||||
numEntries := u.UnpackUInt("numEntries")
|
||||
for idx := 0; idx < numEntries; idx++ {
|
||||
key := string(u.UnpackValue("key"))
|
||||
value := u.UnpackValue("value")
|
||||
rtn[key] = value
|
||||
}
|
||||
if u.Error() != nil {
|
||||
return nil, u.Error()
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func DecodeZshMap(barr []byte) (ZshMap, error) {
|
||||
rtn := make(ZshMap)
|
||||
buf := bytes.NewBuffer(barr)
|
||||
u := binpack.MakeUnpacker(buf)
|
||||
numEntries := u.UnpackUInt("numEntries")
|
||||
for idx := 0; idx < numEntries; idx++ {
|
||||
key := string(u.UnpackValue("key"))
|
||||
value := string(u.UnpackValue("value"))
|
||||
zshKey, err := ZshParamKeyFromString(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rtn[zshKey] = value
|
||||
}
|
||||
if u.Error() != nil {
|
||||
return nil, u.Error()
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func parseZshAliasStateOutput(aliasBytes []byte, partSeparator []byte) map[ZshParamKey]string {
|
||||
aliasParts := bytes.Split(aliasBytes, partSeparator)
|
||||
rtn := make(map[ZshParamKey]string)
|
||||
for aliasPartIdx := 0; aliasPartIdx < len(aliasParts)-1; aliasPartIdx += 2 {
|
||||
aliasNameAndType := string(aliasParts[aliasPartIdx])
|
||||
aliasNameAndTypeParts := strings.SplitN(aliasNameAndType, " ", 2)
|
||||
if len(aliasNameAndTypeParts) != 2 {
|
||||
continue
|
||||
}
|
||||
aliasKey := ZshParamKey{ParamType: aliasNameAndTypeParts[0], ParamName: aliasNameAndTypeParts[1]}
|
||||
aliasValue := string(aliasParts[aliasPartIdx+1])
|
||||
rtn[aliasKey] = aliasValue
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func isSourceFileInFpath(fpathArr []string, sourceFile string) bool {
|
||||
for _, fpath := range fpathArr {
|
||||
if fpath == "" || fpath == "." {
|
||||
continue
|
||||
}
|
||||
firstChar := fpath[0]
|
||||
if firstChar != '/' && firstChar != '~' {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(sourceFile, fpath) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func ParseZshFunctions(fpathArr []string, fnBytes []byte, partSeparator []byte) map[ZshParamKey]string {
|
||||
fnBody := make(map[ZshParamKey]string)
|
||||
fnSource := make(map[string]string)
|
||||
fnParts := bytes.Split(fnBytes, partSeparator)
|
||||
for fnPartIdx := 0; fnPartIdx < len(fnParts)-1; fnPartIdx += 2 {
|
||||
fnTypeAndName := string(fnParts[fnPartIdx])
|
||||
fnValue := string(fnParts[fnPartIdx+1])
|
||||
fnTypeAndNameParts := strings.SplitN(fnTypeAndName, " ", 2)
|
||||
if len(fnTypeAndNameParts) != 2 {
|
||||
continue
|
||||
}
|
||||
fnType := fnTypeAndNameParts[0]
|
||||
fnName := fnTypeAndNameParts[1]
|
||||
if fnName == "zshexit" {
|
||||
continue
|
||||
}
|
||||
if fnType == "functions" || fnType == "dis_functions" {
|
||||
fnBody[ZshParamKey{ParamType: fnType, ParamName: fnName}] = fnValue
|
||||
}
|
||||
if fnType == "functions_source" || fnType == "dis_functions_source" {
|
||||
fnSource[fnName] = fnValue
|
||||
}
|
||||
}
|
||||
// ok, so the trick here is that we want to only include functions that are *not* autoloaded
|
||||
// the ones that are pending autoloading or come from a source file in fpath, can just be set to autoload
|
||||
for fnKey := range fnBody {
|
||||
source := fnSource[fnKey.ParamName]
|
||||
if isSourceFileInFpath(fpathArr, source) {
|
||||
fnBody[fnKey] = ZshFnAutoLoad
|
||||
}
|
||||
}
|
||||
return fnBody
|
||||
}
|
||||
|
||||
func makeZshFuncsStrForShellState(fnMap map[ZshParamKey]string) string {
|
||||
var buf bytes.Buffer
|
||||
for fnKey, fnValue := range fnMap {
|
||||
buf.WriteString(fmt.Sprintf("%s %s %s\x00", fnKey.ParamType, fnKey.ParamName, fnValue))
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, error) {
|
||||
if scbase.IsDevMode() && DebugState {
|
||||
writeStateToFile(packet.ShellType_zsh, outputBytes)
|
||||
}
|
||||
firstZeroIdx := bytes.Index(outputBytes, []byte{0})
|
||||
firstDZeroIdx := bytes.Index(outputBytes, []byte{0, 0})
|
||||
if firstZeroIdx == -1 || firstDZeroIdx == -1 {
|
||||
return nil, fmt.Errorf("invalid zsh shell state output, could not parse separator bytes")
|
||||
}
|
||||
versionStr := string(outputBytes[0:firstZeroIdx])
|
||||
sectionSeparator := outputBytes[firstZeroIdx+1 : firstDZeroIdx+2]
|
||||
partSeparator := sectionSeparator[0 : len(sectionSeparator)-1]
|
||||
// 8 fields: version [0], cwd [1], env [2], vars [3], aliases [4], fpath [5], functions [6], pvars [7]
|
||||
fields := bytes.Split(outputBytes, sectionSeparator)
|
||||
if len(fields) != 8 {
|
||||
base.Logf("invalid -- numfields\n")
|
||||
return nil, fmt.Errorf("invalid zsh shell state output, wrong number of fields, fields=%d", len(fields))
|
||||
}
|
||||
rtn := &packet.ShellState{}
|
||||
rtn.Version = strings.TrimSpace(versionStr)
|
||||
if rtn.GetShellType() != packet.ShellType_zsh {
|
||||
return nil, fmt.Errorf("invalid zsh shell state output, wrong shell type")
|
||||
}
|
||||
if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil {
|
||||
return nil, fmt.Errorf("invalid zsh shell state output, invalid version: %v", err)
|
||||
}
|
||||
cwdStr := stripNewLineChars(string(fields[1]))
|
||||
rtn.Cwd = cwdStr
|
||||
zshEnv := parseZshEnv(fields[2])
|
||||
zshDecls, err := parseZshDecls(fields[3])
|
||||
if err != nil {
|
||||
base.Logf("invalid - parsedecls %v\n", err)
|
||||
return nil, err
|
||||
}
|
||||
for _, decl := range zshDecls {
|
||||
if decl.IsZshScalarBound() {
|
||||
decl.ZshEnvValue = zshEnv[decl.ZshBoundScalar]
|
||||
}
|
||||
}
|
||||
aliasMap := parseZshAliasStateOutput(fields[4], partSeparator)
|
||||
rtn.Aliases = string(EncodeZshMap(aliasMap))
|
||||
fpathStr := stripNewLineChars(string(string(fields[5])))
|
||||
fpathArr := strings.Split(fpathStr, ":")
|
||||
zshFuncs := ParseZshFunctions(fpathArr, fields[6], partSeparator)
|
||||
rtn.Funcs = string(EncodeZshMap(zshFuncs))
|
||||
pvarMap := parsePVarOutput(fields[7], true)
|
||||
utilfn.CombineMaps(zshDecls, pvarMap)
|
||||
rtn.ShellVars = shellenv.SerializeDeclMap(zshDecls)
|
||||
base.Logf("parse shellstate done\n")
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func parseZshEnv(output []byte) map[string]string {
|
||||
outputStr := string(output)
|
||||
lines := strings.Split(outputStr, "\x00")
|
||||
rtn := make(map[string]string)
|
||||
for _, line := range lines {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
eqIdx := strings.Index(line, "=")
|
||||
if eqIdx == -1 {
|
||||
continue
|
||||
}
|
||||
name := line[0:eqIdx]
|
||||
if ZshIgnoreVars[name] {
|
||||
continue
|
||||
}
|
||||
val := line[eqIdx+1:]
|
||||
rtn[name] = val
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func parseZshScalarBoundAssignment(declStr string, decl *DeclareDeclType) error {
|
||||
declStr = strings.TrimLeft(declStr, " ")
|
||||
spaceIdx := strings.Index(declStr, " ")
|
||||
if spaceIdx == -1 {
|
||||
return fmt.Errorf("invalid zsh decl (scalar bound): %q", declStr)
|
||||
}
|
||||
decl.ZshBoundScalar = declStr[0:spaceIdx]
|
||||
standardDecl := declStr[spaceIdx+1:]
|
||||
return parseStandardZshAssignment(standardDecl, decl)
|
||||
}
|
||||
|
||||
func parseStandardZshAssignment(declStr string, decl *DeclareDeclType) error {
|
||||
declStr = strings.TrimLeft(declStr, " ")
|
||||
eqIdx := strings.Index(declStr, "=")
|
||||
if eqIdx == -1 {
|
||||
return fmt.Errorf("invalid zsh decl: %q", declStr)
|
||||
}
|
||||
decl.Name = declStr[0:eqIdx]
|
||||
decl.Value = declStr[eqIdx+1:]
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseZshDeclAssignment(declStr string, decl *DeclareDeclType) error {
|
||||
if decl.IsZshScalarBound() {
|
||||
return parseZshScalarBoundAssignment(declStr, decl)
|
||||
}
|
||||
return parseStandardZshAssignment(declStr, decl)
|
||||
}
|
||||
|
||||
// returns (newDeclStr, argsStr, err)
|
||||
func parseZshDeclArgs(declStr string, isExport bool) (string, string, error) {
|
||||
origDeclStr := declStr
|
||||
var argsStr string
|
||||
if isExport {
|
||||
argsStr = "x"
|
||||
}
|
||||
declStr = strings.TrimLeft(declStr, " ")
|
||||
for strings.HasPrefix(declStr, "-") {
|
||||
spaceIdx := strings.Index(declStr, " ")
|
||||
if spaceIdx == -1 {
|
||||
return "", "", fmt.Errorf("invalid zsh export line: %q", origDeclStr)
|
||||
}
|
||||
newArgsStr := strings.TrimSpace(declStr[1:spaceIdx])
|
||||
argsStr = argsStr + newArgsStr
|
||||
declStr = declStr[spaceIdx+1:]
|
||||
declStr = strings.TrimLeft(declStr, " ")
|
||||
}
|
||||
return declStr, argsStr, nil
|
||||
}
|
||||
|
||||
func stripNewLineChars(s string) string {
|
||||
for {
|
||||
if len(s) == 0 {
|
||||
return s
|
||||
}
|
||||
lastChar := s[len(s)-1]
|
||||
if lastChar == '\n' || lastChar == '\r' {
|
||||
s = s[0 : len(s)-1]
|
||||
} else {
|
||||
return s
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseZshDeclLine(line string) (*DeclareDeclType, error) {
|
||||
line = stripNewLineChars(line)
|
||||
if strings.HasPrefix(line, "export ") {
|
||||
exportLine := line[7:]
|
||||
assignLine, exportArgs, err := parseZshDeclArgs(exportLine, true)
|
||||
rtn := &DeclareDeclType{IsZshDecl: true, Args: exportArgs}
|
||||
err = parseZshDeclAssignment(assignLine, rtn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rtn, nil
|
||||
} else if strings.HasPrefix(line, "typeset ") {
|
||||
typesetLine := line[8:]
|
||||
assignLine, typesetArgs, err := parseZshDeclArgs(typesetLine, false)
|
||||
rtn := &DeclareDeclType{IsZshDecl: true, Args: typesetArgs}
|
||||
err = parseZshDeclAssignment(assignLine, rtn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rtn, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid zsh decl line: %q", line)
|
||||
}
|
||||
}
|
||||
|
||||
// combine decl2 INTO decl1
|
||||
func combineTiedZshDecls(decl1 *DeclareDeclType, decl2 *DeclareDeclType) {
|
||||
if decl2.IsExport() {
|
||||
decl1.AddFlag("x")
|
||||
}
|
||||
if decl2.IsArray() {
|
||||
decl1.AddFlag("a")
|
||||
}
|
||||
}
|
||||
|
||||
func parseZshDecls(output []byte) (map[string]*DeclareDeclType, error) {
|
||||
// NOTES:
|
||||
// - we get extra \r characters in the output (trimmed in parseZshDeclLine) (we get \r\n)
|
||||
// - tied variables (-T) are printed twice! this is especially confusing for exported vars:
|
||||
// (1) `export -T PATH path=( ... )`
|
||||
// (2) `typeset -aT PATH path=( ... )`
|
||||
// we have to "combine" these two lines into one decl.
|
||||
outputStr := string(output)
|
||||
lines := strings.Split(outputStr, "\n")
|
||||
rtn := make(map[string]*DeclareDeclType)
|
||||
for _, line := range lines {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
decl, err := parseZshDeclLine(line)
|
||||
if err != nil {
|
||||
base.Logf("error parsing zsh decl line: %v", err)
|
||||
continue
|
||||
}
|
||||
if decl == nil {
|
||||
continue
|
||||
}
|
||||
if ZshIgnoreVars[decl.Name] {
|
||||
continue
|
||||
}
|
||||
if rtn[decl.Name] != nil && decl.IsZshScalarBound() {
|
||||
combineTiedZshDecls(rtn[decl.Name], decl)
|
||||
continue
|
||||
}
|
||||
rtn[decl.Name] = decl
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func makeZshMapDiff(oldMap string, newMap string) ([]byte, error) {
|
||||
oldMapMap, err := DecodeZshMapForDiff([]byte(oldMap))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error zshMapDiff decoding old-zsh map: %v", err)
|
||||
}
|
||||
newMapMap, err := DecodeZshMapForDiff([]byte(newMap))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error zshMapDiff decoding new-zsh map: %v", err)
|
||||
}
|
||||
return statediff.MakeMapDiff(oldMapMap, newMapMap), nil
|
||||
}
|
||||
|
||||
func applyZshMapDiff(oldMap string, diff []byte) (string, error) {
|
||||
oldMapMap, err := DecodeZshMapForDiff([]byte(oldMap))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error zshMapDiff decoding old-zsh map: %v", err)
|
||||
}
|
||||
newMapMap, err := statediff.ApplyMapDiff(oldMapMap, diff)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error zshMapDiff applying diff: %v", err)
|
||||
}
|
||||
return EncodeZshMapForApply(newMapMap), nil
|
||||
}
|
||||
|
||||
func (zshShellApi) MakeShellStateDiff(oldState *packet.ShellState, oldStateHash string, newState *packet.ShellState) (*packet.ShellStateDiff, error) {
|
||||
if oldState == nil {
|
||||
return nil, fmt.Errorf("cannot diff, oldState is nil")
|
||||
}
|
||||
if newState == nil {
|
||||
return nil, fmt.Errorf("cannot diff, newState is nil")
|
||||
}
|
||||
if oldState.Version != newState.Version {
|
||||
return nil, fmt.Errorf("cannot diff, states have different versions")
|
||||
}
|
||||
rtn := &packet.ShellStateDiff{}
|
||||
rtn.BaseHash = oldStateHash
|
||||
rtn.Version = newState.Version // always set version
|
||||
if oldState.Cwd != newState.Cwd {
|
||||
rtn.Cwd = newState.Cwd
|
||||
}
|
||||
rtn.Error = newState.Error
|
||||
oldVars := shellenv.ShellStateVarsToMap(oldState.ShellVars)
|
||||
newVars := shellenv.ShellStateVarsToMap(newState.ShellVars)
|
||||
rtn.VarsDiff = statediff.MakeMapDiff(oldVars, newVars)
|
||||
var err error
|
||||
rtn.AliasesDiff, err = makeZshMapDiff(oldState.Aliases, newState.Aliases)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rtn.FuncsDiff, err = makeZshMapDiff(oldState.Funcs, newState.Funcs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func (zshShellApi) ApplyShellStateDiff(oldState *packet.ShellState, diff *packet.ShellStateDiff) (*packet.ShellState, error) {
|
||||
if oldState == nil {
|
||||
return nil, fmt.Errorf("cannot apply diff, oldState is nil")
|
||||
}
|
||||
if diff == nil {
|
||||
return oldState, nil
|
||||
}
|
||||
rtnState := &packet.ShellState{}
|
||||
var err error
|
||||
rtnState.Version = oldState.Version
|
||||
if diff.Version != rtnState.Version {
|
||||
rtnState.Version = diff.Version
|
||||
}
|
||||
rtnState.Cwd = oldState.Cwd
|
||||
if diff.Cwd != "" {
|
||||
rtnState.Cwd = diff.Cwd
|
||||
}
|
||||
rtnState.Error = diff.Error
|
||||
oldVars := shellenv.ShellStateVarsToMap(oldState.ShellVars)
|
||||
newVars, err := statediff.ApplyMapDiff(oldVars, diff.VarsDiff)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("applying mapdiff 'vars': %v", err)
|
||||
}
|
||||
rtnState.ShellVars = shellenv.StrMapToShellStateVars(newVars)
|
||||
rtnState.Aliases, err = applyZshMapDiff(oldState.Aliases, diff.AliasesDiff)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("applying diff 'aliases': %v", err)
|
||||
}
|
||||
rtnState.Funcs, err = applyZshMapDiff(oldState.Funcs, diff.FuncsDiff)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("applying diff 'funcs': %v", err)
|
||||
}
|
||||
return rtnState, nil
|
||||
}
|
29
waveshell/pkg/shellapi/zshapi_test.go
Normal file
29
waveshell/pkg/shellapi/zshapi_test.go
Normal file
@ -0,0 +1,29 @@
|
||||
package shellapi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func testSingleDecl(declStr string) {
|
||||
decl, err := parseZshDeclLine(declStr)
|
||||
if err != nil {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
}
|
||||
fmt.Printf("decl %#v\n", decl)
|
||||
}
|
||||
|
||||
func TestParseZshDecl(t *testing.T) {
|
||||
declStr := `export -T PATH path=( /usr/local/bin /usr/bin /bin /usr/sbin /sbin )`
|
||||
testSingleDecl(declStr)
|
||||
declStr = `typeset -i10 SAVEHIST=1000`
|
||||
testSingleDecl(declStr)
|
||||
declStr = `typeset -a signals=( EXIT HUP INT QUIT ILL TRAP ABRT EMT FPE KILL BUS SEGV )`
|
||||
testSingleDecl(declStr)
|
||||
declStr = `typeset -aT RC rc=(80 25) 'x'`
|
||||
testSingleDecl(declStr)
|
||||
declStr = `typeset -g -A foo=( [bar]=baz [quux]=quuux )`
|
||||
testSingleDecl(declStr)
|
||||
declStr = `typeset -x -g -aT FOO foo=( 1 2 3 )`
|
||||
testSingleDecl(declStr)
|
||||
}
|
362
waveshell/pkg/shellenv/shellenv.go
Normal file
362
waveshell/pkg/shellenv/shellenv.go
Normal file
@ -0,0 +1,362 @@
|
||||
// Copyright 2023, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package shellenv
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/simpleexpand"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
)
|
||||
|
||||
const (
|
||||
DeclTypeArray = "array"
|
||||
DeclTypeAssocArray = "assoc"
|
||||
DeclTypeInt = "int"
|
||||
DeclTypeNormal = "normal"
|
||||
)
|
||||
|
||||
type DeclareDeclType struct {
|
||||
IsZshDecl bool
|
||||
|
||||
Args string
|
||||
Name string
|
||||
|
||||
// this holds the raw quoted value suitable for bash. this is *not* the real expanded variable value
|
||||
Value string
|
||||
|
||||
// special fields for zsh "-T" output.
|
||||
// for bound scalars, "Value" hold everything after the "=" (including the separator character)
|
||||
ZshBoundScalar string // the name of the "scalar" env variable
|
||||
ZshEnvValue string // unlike Value this *is* the expanded value of scalar env variable
|
||||
}
|
||||
|
||||
func (d *DeclareDeclType) IsExport() bool {
|
||||
return strings.Index(d.Args, "x") >= 0
|
||||
}
|
||||
|
||||
func (d *DeclareDeclType) IsReadOnly() bool {
|
||||
return strings.Index(d.Args, "r") >= 0
|
||||
}
|
||||
|
||||
func (d *DeclareDeclType) IsZshScalarBound() bool {
|
||||
return strings.Index(d.Args, "T") >= 0
|
||||
}
|
||||
|
||||
func (d *DeclareDeclType) IsArray() bool {
|
||||
return strings.Index(d.Args, "a") >= 0
|
||||
}
|
||||
|
||||
func (d *DeclareDeclType) IsAssocArray() bool {
|
||||
return strings.Index(d.Args, "A") >= 0
|
||||
}
|
||||
|
||||
func (d *DeclareDeclType) IsUniqueArray() bool {
|
||||
return d.IsArray() && strings.Index(d.Args, "U") >= 0
|
||||
}
|
||||
|
||||
func (d *DeclareDeclType) AddFlag(flag string) {
|
||||
if strings.Index(d.Args, flag) >= 0 {
|
||||
return
|
||||
}
|
||||
d.Args += flag
|
||||
}
|
||||
|
||||
func (d *DeclareDeclType) SortZshFlags() {
|
||||
// x is always first (or g)
|
||||
// T is always last
|
||||
// the 'i' flags are tricky (they shouldn't be sorted, because the order matters, e.g. i10)
|
||||
var hasX, hasG, hasT bool
|
||||
var newArgs []rune
|
||||
for _, r := range d.Args {
|
||||
if r == 'x' {
|
||||
hasX = true
|
||||
continue
|
||||
}
|
||||
if r == 'g' {
|
||||
hasG = true
|
||||
continue
|
||||
}
|
||||
if r == 'T' {
|
||||
hasT = true
|
||||
continue
|
||||
}
|
||||
newArgs = append(newArgs, r)
|
||||
}
|
||||
newArgsStr := string(newArgs)
|
||||
if hasG {
|
||||
newArgsStr = "g" + newArgsStr
|
||||
}
|
||||
if hasX {
|
||||
newArgsStr = "x" + newArgsStr
|
||||
}
|
||||
if hasT {
|
||||
newArgsStr += "T"
|
||||
}
|
||||
d.Args = newArgsStr
|
||||
}
|
||||
|
||||
func (d *DeclareDeclType) DataType() string {
|
||||
if strings.Index(d.Args, "a") >= 0 {
|
||||
return DeclTypeArray
|
||||
}
|
||||
if strings.Index(d.Args, "A") >= 0 {
|
||||
return DeclTypeAssocArray
|
||||
}
|
||||
if strings.Index(d.Args, "i") >= 0 {
|
||||
return DeclTypeInt
|
||||
}
|
||||
return DeclTypeNormal
|
||||
}
|
||||
|
||||
func FindVarDecl(decls []*DeclareDeclType, name string) *DeclareDeclType {
|
||||
for _, decl := range decls {
|
||||
if decl.Name == name {
|
||||
return decl
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NOTE Serialize no longer writes the final null byte
|
||||
func (d *DeclareDeclType) Serialize() []byte {
|
||||
if d.IsZshDecl {
|
||||
d.SortZshFlags()
|
||||
parts := []string{
|
||||
"z1",
|
||||
d.Args,
|
||||
d.Name,
|
||||
d.Value,
|
||||
d.ZshBoundScalar,
|
||||
d.ZshEnvValue,
|
||||
}
|
||||
return utilfn.EncodeStringArray(parts)
|
||||
} else {
|
||||
parts := []string{
|
||||
"b1",
|
||||
d.Args,
|
||||
d.Name,
|
||||
d.Value,
|
||||
}
|
||||
return utilfn.EncodeStringArray(parts)
|
||||
}
|
||||
// this is the v0 encoding (keeping here for reference since we still need to decode this)
|
||||
// rtn := fmt.Sprintf("%s|%s=%s\x00", d.Args, d.Name, d.Value)
|
||||
// return []byte(rtn)
|
||||
}
|
||||
|
||||
func DeclsEqual(compareName bool, d1 *DeclareDeclType, d2 *DeclareDeclType) bool {
|
||||
if d1.IsExport() != d2.IsExport() {
|
||||
return false
|
||||
}
|
||||
if d1.DataType() != d2.DataType() {
|
||||
return false
|
||||
}
|
||||
if compareName && d1.Name != d2.Name {
|
||||
return false
|
||||
}
|
||||
return d1.Value == d2.Value // this works even for assoc arrays because we normalize them when parsing
|
||||
}
|
||||
|
||||
// envline should be valid
|
||||
func parseDeclLine(envLineBytes []byte) *DeclareDeclType {
|
||||
if utilfn.EncodedStringArrayHasFirstKey(envLineBytes, "z1") {
|
||||
parts, err := utilfn.DecodeStringArray(envLineBytes)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if len(parts) != 6 {
|
||||
return nil
|
||||
}
|
||||
return &DeclareDeclType{
|
||||
IsZshDecl: true,
|
||||
Args: parts[1],
|
||||
Name: parts[2],
|
||||
Value: parts[3],
|
||||
ZshBoundScalar: parts[4],
|
||||
ZshEnvValue: parts[5],
|
||||
}
|
||||
} else if utilfn.EncodedStringArrayHasFirstKey(envLineBytes, "b1") {
|
||||
parts, err := utilfn.DecodeStringArray(envLineBytes)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if len(parts) != 4 {
|
||||
return nil
|
||||
}
|
||||
return &DeclareDeclType{
|
||||
Args: parts[1],
|
||||
Name: parts[2],
|
||||
Value: parts[3],
|
||||
}
|
||||
}
|
||||
// legacy decoding (v0)
|
||||
envLine := string(envLineBytes)
|
||||
eqIdx := strings.Index(envLine, "=")
|
||||
if eqIdx == -1 {
|
||||
return nil
|
||||
}
|
||||
namePart := envLine[0:eqIdx]
|
||||
valPart := envLine[eqIdx+1:]
|
||||
pipeIdx := strings.Index(namePart, "|")
|
||||
if pipeIdx == -1 {
|
||||
return nil
|
||||
}
|
||||
return &DeclareDeclType{
|
||||
Args: namePart[0:pipeIdx],
|
||||
Name: namePart[pipeIdx+1:],
|
||||
Value: valPart,
|
||||
}
|
||||
}
|
||||
|
||||
// returns name => full-line
|
||||
func parseDeclLineToKV(envLine []byte) (string, []byte) {
|
||||
decl := parseDeclLine(envLine)
|
||||
if decl == nil {
|
||||
return "", nil
|
||||
}
|
||||
return decl.Name, envLine
|
||||
}
|
||||
|
||||
func ShellStateVarsToMap(shellVars []byte) map[string][]byte {
|
||||
if len(shellVars) == 0 {
|
||||
return nil
|
||||
}
|
||||
rtn := make(map[string][]byte)
|
||||
vars := bytes.Split(shellVars, []byte{0})
|
||||
for _, varLine := range vars {
|
||||
name, val := parseDeclLineToKV(varLine)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
rtn[name] = val
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func StrMapToShellStateVars(varMap map[string][]byte) []byte {
|
||||
var buf bytes.Buffer
|
||||
orderedKeys := utilfn.GetOrderedMapKeys(varMap)
|
||||
for _, key := range orderedKeys {
|
||||
val := varMap[key]
|
||||
buf.Write(val)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func DeclMapFromState(state *packet.ShellState) map[string]*DeclareDeclType {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
rtn := make(map[string]*DeclareDeclType)
|
||||
vars := bytes.Split(state.ShellVars, []byte{0})
|
||||
for _, varLine := range vars {
|
||||
decl := parseDeclLine(varLine)
|
||||
if decl != nil {
|
||||
rtn[decl.Name] = decl
|
||||
}
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func SerializeDeclMap(declMap map[string]*DeclareDeclType) []byte {
|
||||
var rtn bytes.Buffer
|
||||
orderedKeys := utilfn.GetOrderedMapKeys(declMap)
|
||||
for _, key := range orderedKeys {
|
||||
decl := declMap[key]
|
||||
rtn.Write(decl.Serialize())
|
||||
rtn.WriteByte(0)
|
||||
}
|
||||
return rtn.Bytes()
|
||||
}
|
||||
|
||||
func EnvMapFromState(state *packet.ShellState) map[string]string {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
rtn := make(map[string]string)
|
||||
ectx := simpleexpand.SimpleExpandContext{}
|
||||
vars := bytes.Split(state.ShellVars, []byte{0})
|
||||
for _, varLine := range vars {
|
||||
decl := parseDeclLine(varLine)
|
||||
if decl != nil && decl.IsExport() {
|
||||
rtn[decl.Name], _ = simpleexpand.SimpleExpandPartialWord(ectx, decl.Value, false)
|
||||
}
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func ShellVarMapFromState(state *packet.ShellState) map[string]string {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
rtn := make(map[string]string)
|
||||
ectx := simpleexpand.SimpleExpandContext{}
|
||||
vars := bytes.Split(state.ShellVars, []byte{0})
|
||||
for _, varLine := range vars {
|
||||
decl := parseDeclLine(varLine)
|
||||
if decl != nil {
|
||||
rtn[decl.Name], _ = simpleexpand.SimpleExpandPartialWord(ectx, decl.Value, false)
|
||||
}
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func DumpVarMapFromState(state *packet.ShellState) {
|
||||
fmt.Printf("DUMP-STATE-VARS:\n")
|
||||
if state == nil {
|
||||
fmt.Printf(" nil\n")
|
||||
return
|
||||
}
|
||||
decls := VarDeclsFromState(state)
|
||||
for _, decl := range decls {
|
||||
fmt.Printf(" %s %#v\n", decl.Name, decl)
|
||||
}
|
||||
envMap := EnvMapFromState(state)
|
||||
fmt.Printf("DUMP-STATE-ENV:\n")
|
||||
for k, v := range envMap {
|
||||
fmt.Printf(" %s=%s\n", k, v)
|
||||
}
|
||||
fmt.Printf("\n\n")
|
||||
}
|
||||
|
||||
func VarDeclsFromState(state *packet.ShellState) []*DeclareDeclType {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
var rtn []*DeclareDeclType
|
||||
vars := bytes.Split(state.ShellVars, []byte{0})
|
||||
for _, varLine := range vars {
|
||||
decl := parseDeclLine(varLine)
|
||||
if decl != nil {
|
||||
rtn = append(rtn, decl)
|
||||
}
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func RemoveFunc(funcs string, toRemove string) string {
|
||||
lines := strings.Split(funcs, "\n")
|
||||
var newLines []string
|
||||
removeLine := fmt.Sprintf("%s ()", toRemove)
|
||||
doingRemove := false
|
||||
for _, line := range lines {
|
||||
if line == removeLine {
|
||||
doingRemove = true
|
||||
continue
|
||||
}
|
||||
if doingRemove {
|
||||
if line == "}" {
|
||||
doingRemove = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
newLines = append(newLines, line)
|
||||
}
|
||||
return strings.Join(newLines, "\n")
|
||||
}
|
62
waveshell/pkg/shellutil/shellutil.go
Normal file
62
waveshell/pkg/shellutil/shellutil.go
Normal file
@ -0,0 +1,62 @@
|
||||
// Copyright 2023, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package shellutil
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
|
||||
)
|
||||
|
||||
const DefaultTermType = "xterm-256color"
|
||||
const DefaultTermRows = 24
|
||||
const DefaultTermCols = 80
|
||||
|
||||
func MShellEnvVars(termType string) map[string]string {
|
||||
rtn := make(map[string]string)
|
||||
if termType != "" {
|
||||
rtn["TERM"] = termType
|
||||
}
|
||||
rtn["WAVESHELL"], _ = os.Executable()
|
||||
rtn["WAVESHELL_VERSION"] = base.MShellVersion
|
||||
return rtn
|
||||
}
|
||||
|
||||
func UpdateCmdEnv(cmd *exec.Cmd, envVars map[string]string) {
|
||||
if len(envVars) == 0 {
|
||||
return
|
||||
}
|
||||
found := make(map[string]bool)
|
||||
var newEnv []string
|
||||
for _, envStr := range cmd.Env {
|
||||
envKey := GetEnvStrKey(envStr)
|
||||
newEnvVal, ok := envVars[envKey]
|
||||
if ok {
|
||||
if newEnvVal == "" {
|
||||
continue
|
||||
}
|
||||
newEnv = append(newEnv, envKey+"="+newEnvVal)
|
||||
found[envKey] = true
|
||||
} else {
|
||||
newEnv = append(newEnv, envStr)
|
||||
}
|
||||
}
|
||||
for envKey, envVal := range envVars {
|
||||
if found[envKey] {
|
||||
continue
|
||||
}
|
||||
newEnv = append(newEnv, envKey+"="+envVal)
|
||||
}
|
||||
cmd.Env = newEnv
|
||||
}
|
||||
|
||||
func GetEnvStrKey(envStr string) string {
|
||||
eqIdx := strings.Index(envStr, "=")
|
||||
if eqIdx == -1 {
|
||||
return envStr
|
||||
}
|
||||
return envStr[0:eqIdx]
|
||||
}
|
@ -1,641 +0,0 @@
|
||||
// Copyright 2023, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package shexec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/alessio/shellescape"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/simpleexpand"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/statediff"
|
||||
"mvdan.cc/sh/v3/expand"
|
||||
"mvdan.cc/sh/v3/syntax"
|
||||
)
|
||||
|
||||
const (
|
||||
DeclTypeArray = "array"
|
||||
DeclTypeAssocArray = "assoc"
|
||||
DeclTypeInt = "int"
|
||||
DeclTypeNormal = "normal"
|
||||
)
|
||||
|
||||
type ParseEnviron struct {
|
||||
Env map[string]string
|
||||
}
|
||||
|
||||
func (e *ParseEnviron) Get(name string) expand.Variable {
|
||||
val, ok := e.Env[name]
|
||||
if !ok {
|
||||
return expand.Variable{}
|
||||
}
|
||||
return expand.Variable{
|
||||
Exported: true,
|
||||
Kind: expand.String,
|
||||
Str: val,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ParseEnviron) Each(fn func(name string, vr expand.Variable) bool) {
|
||||
for key := range e.Env {
|
||||
rtn := fn(key, e.Get(key))
|
||||
if !rtn {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func doCmdSubst(commandStr string, w io.Writer, word *syntax.CmdSubst) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func doProcSubst(w *syntax.ProcSubst) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func GetParserConfig(envMap map[string]string) *expand.Config {
|
||||
cfg := &expand.Config{
|
||||
Env: &ParseEnviron{Env: envMap},
|
||||
GlobStar: false,
|
||||
NullGlob: false,
|
||||
NoUnset: false,
|
||||
CmdSubst: func(w io.Writer, word *syntax.CmdSubst) error { return doCmdSubst("", w, word) },
|
||||
ProcSubst: doProcSubst,
|
||||
ReadDir: nil,
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func writeIndent(buf *bytes.Buffer, num int) {
|
||||
for i := 0; i < num; i++ {
|
||||
buf.WriteByte(' ')
|
||||
}
|
||||
}
|
||||
|
||||
func makeSpaceStr(num int) string {
|
||||
barr := make([]byte, num)
|
||||
for i := 0; i < num; i++ {
|
||||
barr[i] = ' '
|
||||
}
|
||||
return string(barr)
|
||||
}
|
||||
|
||||
// https://wiki.bash-hackers.org/syntax/shellvars
|
||||
var NoStoreVarNames = map[string]bool{
|
||||
"BASH": true,
|
||||
"BASHOPTS": true,
|
||||
"BASHPID": true,
|
||||
"BASH_ALIASES": true,
|
||||
"BASH_ARGC": true,
|
||||
"BASH_ARGV": true,
|
||||
"BASH_ARGV0": true,
|
||||
"BASH_CMDS": true,
|
||||
"BASH_COMMAND": true,
|
||||
"BASH_EXECUTION_STRING": true,
|
||||
"LINENO": true,
|
||||
"BASH_LINENO": true,
|
||||
"BASH_REMATCH": true,
|
||||
"BASH_SOURCE": true,
|
||||
"BASH_SUBSHELL": true,
|
||||
"COPROC": true,
|
||||
"DIRSTACK": true,
|
||||
"EPOCHREALTIME": true,
|
||||
"EPOCHSECONDS": true,
|
||||
"FUNCNAME": true,
|
||||
"HISTCMD": true,
|
||||
"OLDPWD": true,
|
||||
"PIPESTATUS": true,
|
||||
"PPID": true,
|
||||
"PWD": true,
|
||||
"RANDOM": true,
|
||||
"SECONDS": true,
|
||||
"SHLVL": true,
|
||||
"HISTFILE": true,
|
||||
"HISTFILESIZE": true,
|
||||
"HISTCONTROL": true,
|
||||
"HISTIGNORE": true,
|
||||
"HISTSIZE": true,
|
||||
"HISTTIMEFORMAT": true,
|
||||
"SRANDOM": true,
|
||||
"COLUMNS": true,
|
||||
"LINES": true,
|
||||
|
||||
// we want these in our remote state object
|
||||
// "EUID": true,
|
||||
// "SHELLOPTS": true,
|
||||
// "UID": true,
|
||||
// "BASH_VERSINFO": true,
|
||||
// "BASH_VERSION": true,
|
||||
}
|
||||
|
||||
type DeclareDeclType struct {
|
||||
Args string
|
||||
Name string
|
||||
|
||||
// this holds the raw quoted value suitable for bash. this is *not* the real expanded variable value
|
||||
Value string
|
||||
}
|
||||
|
||||
var declareDeclArgsRe = regexp.MustCompile("^[aAxrifx]*$")
|
||||
var bashValidIdentifierRe = regexp.MustCompile("^[a-zA-Z_][a-zA-Z0-9_]*$")
|
||||
|
||||
func (d *DeclareDeclType) Validate() error {
|
||||
if len(d.Name) == 0 || !IsValidBashIdentifier(d.Name) {
|
||||
return fmt.Errorf("invalid shell variable name (invalid bash identifier)")
|
||||
}
|
||||
if strings.Index(d.Value, "\x00") >= 0 {
|
||||
return fmt.Errorf("invalid shell variable value (cannot contain 0 byte)")
|
||||
}
|
||||
if !declareDeclArgsRe.MatchString(d.Args) {
|
||||
return fmt.Errorf("invalid shell variable type %s", shellescape.Quote(d.Args))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DeclareDeclType) Serialize() string {
|
||||
return fmt.Sprintf("%s|%s=%s\x00", d.Args, d.Name, d.Value)
|
||||
}
|
||||
|
||||
func (d *DeclareDeclType) DeclareStmt() string {
|
||||
var argsStr string
|
||||
if d.Args == "" {
|
||||
argsStr = "--"
|
||||
} else {
|
||||
argsStr = "-" + d.Args
|
||||
}
|
||||
return fmt.Sprintf("declare %s %s=%s", argsStr, d.Name, d.Value)
|
||||
}
|
||||
|
||||
// envline should be valid
|
||||
func ParseDeclLine(envLine string) *DeclareDeclType {
|
||||
eqIdx := strings.Index(envLine, "=")
|
||||
if eqIdx == -1 {
|
||||
return nil
|
||||
}
|
||||
namePart := envLine[0:eqIdx]
|
||||
valPart := envLine[eqIdx+1:]
|
||||
pipeIdx := strings.Index(namePart, "|")
|
||||
if pipeIdx == -1 {
|
||||
return nil
|
||||
}
|
||||
return &DeclareDeclType{
|
||||
Args: namePart[0:pipeIdx],
|
||||
Name: namePart[pipeIdx+1:],
|
||||
Value: valPart,
|
||||
}
|
||||
}
|
||||
|
||||
// returns name => full-line
|
||||
func parseDeclLineToKV(envLine string) (string, string) {
|
||||
decl := ParseDeclLine(envLine)
|
||||
if decl == nil {
|
||||
return "", ""
|
||||
}
|
||||
return decl.Name, envLine
|
||||
}
|
||||
|
||||
func shellStateVarsToMap(shellVars []byte) map[string]string {
|
||||
if len(shellVars) == 0 {
|
||||
return nil
|
||||
}
|
||||
rtn := make(map[string]string)
|
||||
vars := bytes.Split(shellVars, []byte{0})
|
||||
for _, varLine := range vars {
|
||||
name, val := parseDeclLineToKV(string(varLine))
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
rtn[name] = val
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func strMapToShellStateVars(varMap map[string]string) []byte {
|
||||
var buf bytes.Buffer
|
||||
orderedKeys := getOrderedKeysStrMap(varMap)
|
||||
for _, key := range orderedKeys {
|
||||
val := varMap[key]
|
||||
buf.WriteString(val)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func getOrderedKeysStrMap(m map[string]string) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for key := range m {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func getOrderedKeysDeclMap(m map[string]*DeclareDeclType) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for key := range m {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func DeclMapFromState(state *packet.ShellState) map[string]*DeclareDeclType {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
rtn := make(map[string]*DeclareDeclType)
|
||||
vars := bytes.Split(state.ShellVars, []byte{0})
|
||||
for _, varLine := range vars {
|
||||
decl := ParseDeclLine(string(varLine))
|
||||
if decl != nil {
|
||||
rtn[decl.Name] = decl
|
||||
}
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func SerializeDeclMap(declMap map[string]*DeclareDeclType) []byte {
|
||||
var rtn bytes.Buffer
|
||||
orderedKeys := getOrderedKeysDeclMap(declMap)
|
||||
for _, key := range orderedKeys {
|
||||
decl := declMap[key]
|
||||
rtn.WriteString(decl.Serialize())
|
||||
}
|
||||
return rtn.Bytes()
|
||||
}
|
||||
|
||||
func EnvMapFromState(state *packet.ShellState) map[string]string {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
rtn := make(map[string]string)
|
||||
ectx := simpleexpand.SimpleExpandContext{}
|
||||
vars := bytes.Split(state.ShellVars, []byte{0})
|
||||
for _, varLine := range vars {
|
||||
decl := ParseDeclLine(string(varLine))
|
||||
if decl != nil && decl.IsExport() {
|
||||
rtn[decl.Name], _ = simpleexpand.SimpleExpandPartialWord(ectx, decl.Value, false)
|
||||
}
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func ShellVarMapFromState(state *packet.ShellState) map[string]string {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
rtn := make(map[string]string)
|
||||
ectx := simpleexpand.SimpleExpandContext{}
|
||||
vars := bytes.Split(state.ShellVars, []byte{0})
|
||||
for _, varLine := range vars {
|
||||
decl := ParseDeclLine(string(varLine))
|
||||
if decl != nil {
|
||||
rtn[decl.Name], _ = simpleexpand.SimpleExpandPartialWord(ectx, decl.Value, false)
|
||||
}
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func DumpVarMapFromState(state *packet.ShellState) {
|
||||
fmt.Printf("DUMP-STATE-VARS:\n")
|
||||
if state == nil {
|
||||
fmt.Printf(" nil\n")
|
||||
return
|
||||
}
|
||||
vars := bytes.Split(state.ShellVars, []byte{0})
|
||||
for _, varLine := range vars {
|
||||
fmt.Printf(" %s\n", varLine)
|
||||
}
|
||||
}
|
||||
|
||||
func VarDeclsFromState(state *packet.ShellState) []*DeclareDeclType {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
var rtn []*DeclareDeclType
|
||||
vars := bytes.Split(state.ShellVars, []byte{0})
|
||||
for _, varLine := range vars {
|
||||
decl := ParseDeclLine(string(varLine))
|
||||
if decl != nil {
|
||||
rtn = append(rtn, decl)
|
||||
}
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func IsValidBashIdentifier(s string) bool {
|
||||
return bashValidIdentifierRe.MatchString(s)
|
||||
}
|
||||
|
||||
func (d *DeclareDeclType) IsExport() bool {
|
||||
return strings.Index(d.Args, "x") >= 0
|
||||
}
|
||||
|
||||
func (d *DeclareDeclType) IsReadOnly() bool {
|
||||
return strings.Index(d.Args, "r") >= 0
|
||||
}
|
||||
|
||||
func (d *DeclareDeclType) DataType() string {
|
||||
if strings.Index(d.Args, "a") >= 0 {
|
||||
return DeclTypeArray
|
||||
}
|
||||
if strings.Index(d.Args, "A") >= 0 {
|
||||
return DeclTypeAssocArray
|
||||
}
|
||||
if strings.Index(d.Args, "i") >= 0 {
|
||||
return DeclTypeInt
|
||||
}
|
||||
return DeclTypeNormal
|
||||
}
|
||||
|
||||
func parseDeclareStmt(stmt *syntax.Stmt, src string) (*DeclareDeclType, error) {
|
||||
cmd := stmt.Cmd
|
||||
decl, ok := cmd.(*syntax.DeclClause)
|
||||
if !ok || decl.Variant.Value != "declare" || len(decl.Args) != 2 {
|
||||
return nil, fmt.Errorf("invalid declare variant")
|
||||
}
|
||||
rtn := &DeclareDeclType{}
|
||||
declArgs := decl.Args[0]
|
||||
if !declArgs.Naked || len(declArgs.Value.Parts) != 1 {
|
||||
return nil, fmt.Errorf("wrong number of declare args parts")
|
||||
}
|
||||
declArgsLit, ok := declArgs.Value.Parts[0].(*syntax.Lit)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("declare args is not a literal")
|
||||
}
|
||||
if !strings.HasPrefix(declArgsLit.Value, "-") {
|
||||
return nil, fmt.Errorf("declare args not an argument (does not start with '-')")
|
||||
}
|
||||
if declArgsLit.Value == "--" {
|
||||
rtn.Args = ""
|
||||
} else {
|
||||
rtn.Args = declArgsLit.Value[1:]
|
||||
}
|
||||
declAssign := decl.Args[1]
|
||||
if declAssign.Name == nil {
|
||||
return nil, fmt.Errorf("declare does not have a valid name")
|
||||
}
|
||||
rtn.Name = declAssign.Name.Value
|
||||
if declAssign.Naked || declAssign.Index != nil || declAssign.Append {
|
||||
return nil, fmt.Errorf("invalid decl format")
|
||||
}
|
||||
if declAssign.Value != nil {
|
||||
rtn.Value = string(src[declAssign.Value.Pos().Offset():declAssign.Value.End().Offset()])
|
||||
} else if declAssign.Array != nil {
|
||||
rtn.Value = string(src[declAssign.Array.Pos().Offset():declAssign.Array.End().Offset()])
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid decl, not plain value or array")
|
||||
}
|
||||
err := rtn.normalize()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = rtn.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func parseDeclareOutput(state *packet.ShellState, declareBytes []byte, pvarBytes []byte) error {
|
||||
declareStr := string(declareBytes)
|
||||
r := bytes.NewReader(declareBytes)
|
||||
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
|
||||
file, err := parser.Parse(r, "aliases")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var firstParseErr error
|
||||
declMap := make(map[string]*DeclareDeclType)
|
||||
for _, stmt := range file.Stmts {
|
||||
decl, err := parseDeclareStmt(stmt, declareStr)
|
||||
if err != nil {
|
||||
if firstParseErr == nil {
|
||||
firstParseErr = err
|
||||
}
|
||||
}
|
||||
if decl != nil && !NoStoreVarNames[decl.Name] {
|
||||
declMap[decl.Name] = decl
|
||||
}
|
||||
}
|
||||
pvars := bytes.Split(pvarBytes, []byte{0})
|
||||
for _, pvarBA := range pvars {
|
||||
pvarStr := string(pvarBA)
|
||||
pvarFields := strings.SplitN(pvarStr, " ", 2)
|
||||
if len(pvarFields) != 2 {
|
||||
continue
|
||||
}
|
||||
if pvarFields[0] == "" {
|
||||
continue
|
||||
}
|
||||
decl := &DeclareDeclType{Args: "x"}
|
||||
decl.Name = "PROMPTVAR_" + pvarFields[0]
|
||||
decl.Value = shellescape.Quote(pvarFields[1])
|
||||
declMap[decl.Name] = decl
|
||||
}
|
||||
state.ShellVars = SerializeDeclMap(declMap) // this writes out the decls in a canonical order
|
||||
if firstParseErr != nil {
|
||||
state.Error = firstParseErr.Error()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, error) {
|
||||
// 5 fields: version, cwd, env/vars, aliases, funcs
|
||||
fields := bytes.Split(outputBytes, []byte{0, 0})
|
||||
if len(fields) != 6 {
|
||||
return nil, fmt.Errorf("invalid shell state output, wrong number of fields, fields=%d", len(fields))
|
||||
}
|
||||
rtn := &packet.ShellState{}
|
||||
rtn.Version = strings.TrimSpace(string(fields[0]))
|
||||
if strings.Index(rtn.Version, "bash") == -1 {
|
||||
return nil, fmt.Errorf("invalid shell state output, only bash is supported")
|
||||
}
|
||||
cwdStr := string(fields[1])
|
||||
if strings.HasSuffix(cwdStr, "\r\n") {
|
||||
cwdStr = cwdStr[0 : len(cwdStr)-2]
|
||||
} else if strings.HasSuffix(cwdStr, "\n") {
|
||||
cwdStr = cwdStr[0 : len(cwdStr)-1]
|
||||
}
|
||||
rtn.Cwd = string(cwdStr)
|
||||
err := parseDeclareOutput(rtn, fields[2], fields[5])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rtn.Aliases = strings.ReplaceAll(string(fields[3]), "\r\n", "\n")
|
||||
rtn.Funcs = strings.ReplaceAll(string(fields[4]), "\r\n", "\n")
|
||||
rtn.Funcs = removeFunc(rtn.Funcs, "_mshell_exittrap")
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func removeFunc(funcs string, toRemove string) string {
|
||||
lines := strings.Split(funcs, "\n")
|
||||
var newLines []string
|
||||
removeLine := fmt.Sprintf("%s ()", toRemove)
|
||||
doingRemove := false
|
||||
for _, line := range lines {
|
||||
if line == removeLine {
|
||||
doingRemove = true
|
||||
continue
|
||||
}
|
||||
if doingRemove {
|
||||
if line == "}" {
|
||||
doingRemove = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
newLines = append(newLines, line)
|
||||
}
|
||||
return strings.Join(newLines, "\n")
|
||||
}
|
||||
|
||||
func (d *DeclareDeclType) normalize() error {
|
||||
if d.DataType() == DeclTypeAssocArray {
|
||||
return d.normalizeAssocArrayDecl()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// normalizes order of assoc array keys so value is stable
|
||||
func (d *DeclareDeclType) normalizeAssocArrayDecl() error {
|
||||
if d.DataType() != DeclTypeAssocArray {
|
||||
return fmt.Errorf("invalid decltype passed to assocArrayDeclToStr: %s", d.DataType())
|
||||
}
|
||||
varMap, err := assocArrayVarToMap(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keys := make([]string, 0, len(varMap))
|
||||
for key := range varMap {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
var buf bytes.Buffer
|
||||
buf.WriteByte('(')
|
||||
for _, key := range keys {
|
||||
buf.WriteByte('[')
|
||||
buf.WriteString(key)
|
||||
buf.WriteByte(']')
|
||||
buf.WriteByte('=')
|
||||
buf.WriteString(varMap[key])
|
||||
buf.WriteByte(' ')
|
||||
}
|
||||
buf.WriteByte(')')
|
||||
d.Value = buf.String()
|
||||
return nil
|
||||
}
|
||||
|
||||
func assocArrayVarToMap(d *DeclareDeclType) (map[string]string, error) {
|
||||
if d.DataType() != DeclTypeAssocArray {
|
||||
return nil, fmt.Errorf("decl is not an assoc-array")
|
||||
}
|
||||
refStr := "X=" + d.Value
|
||||
r := strings.NewReader(refStr)
|
||||
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
|
||||
file, err := parser.Parse(r, "assocdecl")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(file.Stmts) != 1 {
|
||||
return nil, fmt.Errorf("invalid assoc-array parse (multiple stmts)")
|
||||
}
|
||||
stmt := file.Stmts[0]
|
||||
callExpr, ok := stmt.Cmd.(*syntax.CallExpr)
|
||||
if !ok || len(callExpr.Args) != 0 || len(callExpr.Assigns) != 1 {
|
||||
return nil, fmt.Errorf("invalid assoc-array parse (bad expr)")
|
||||
}
|
||||
assign := callExpr.Assigns[0]
|
||||
arrayExpr := assign.Array
|
||||
if arrayExpr == nil {
|
||||
return nil, fmt.Errorf("invalid assoc-array parse (no array expr)")
|
||||
}
|
||||
rtn := make(map[string]string)
|
||||
for _, elem := range arrayExpr.Elems {
|
||||
indexStr := refStr[elem.Index.Pos().Offset():elem.Index.End().Offset()]
|
||||
valStr := refStr[elem.Value.Pos().Offset():elem.Value.End().Offset()]
|
||||
rtn[indexStr] = valStr
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func strMapsEqual(m1 map[string]string, m2 map[string]string) bool {
|
||||
if len(m1) != len(m2) {
|
||||
return false
|
||||
}
|
||||
for key, val1 := range m1 {
|
||||
val2, found := m2[key]
|
||||
if !found || val1 != val2 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for key := range m2 {
|
||||
_, found := m1[key]
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func DeclsEqual(compareName bool, d1 *DeclareDeclType, d2 *DeclareDeclType) bool {
|
||||
if d1.IsExport() != d2.IsExport() {
|
||||
return false
|
||||
}
|
||||
if d1.DataType() != d2.DataType() {
|
||||
return false
|
||||
}
|
||||
if compareName && d1.Name != d2.Name {
|
||||
return false
|
||||
}
|
||||
return d1.Value == d2.Value // this works even for assoc arrays because we normalize them when parsing
|
||||
}
|
||||
|
||||
func MakeShellStateDiff(oldState packet.ShellState, oldStateHash string, newState packet.ShellState) (packet.ShellStateDiff, error) {
|
||||
var rtn packet.ShellStateDiff
|
||||
rtn.BaseHash = oldStateHash
|
||||
if oldState.Version != newState.Version {
|
||||
return rtn, fmt.Errorf("cannot diff, states have different versions")
|
||||
}
|
||||
rtn.Version = newState.Version
|
||||
if oldState.Cwd != newState.Cwd {
|
||||
rtn.Cwd = newState.Cwd
|
||||
}
|
||||
rtn.Error = newState.Error
|
||||
oldVars := shellStateVarsToMap(oldState.ShellVars)
|
||||
newVars := shellStateVarsToMap(newState.ShellVars)
|
||||
rtn.VarsDiff = statediff.MakeMapDiff(oldVars, newVars)
|
||||
rtn.AliasesDiff = statediff.MakeLineDiff(oldState.Aliases, newState.Aliases)
|
||||
rtn.FuncsDiff = statediff.MakeLineDiff(oldState.Funcs, newState.Funcs)
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func ApplyShellStateDiff(oldState packet.ShellState, diff packet.ShellStateDiff) (packet.ShellState, error) {
|
||||
var rtnState packet.ShellState
|
||||
var err error
|
||||
rtnState.Version = oldState.Version
|
||||
rtnState.Cwd = oldState.Cwd
|
||||
if diff.Cwd != "" {
|
||||
rtnState.Cwd = diff.Cwd
|
||||
}
|
||||
rtnState.Error = diff.Error
|
||||
oldVars := shellStateVarsToMap(oldState.ShellVars)
|
||||
newVars, err := statediff.ApplyMapDiff(oldVars, diff.VarsDiff)
|
||||
if err != nil {
|
||||
return rtnState, fmt.Errorf("applying mapdiff 'vars': %v", err)
|
||||
}
|
||||
rtnState.ShellVars = strMapToShellStateVars(newVars)
|
||||
rtnState.Aliases, err = statediff.ApplyLineDiff(oldState.Aliases, diff.AliasesDiff)
|
||||
if err != nil {
|
||||
return rtnState, fmt.Errorf("applying diff 'aliases': %v", err)
|
||||
}
|
||||
rtnState.Funcs, err = statediff.ApplyLineDiff(oldState.Funcs, diff.FuncsDiff)
|
||||
if err != nil {
|
||||
return rtnState, fmt.Errorf("applying diff 'funcs': %v", err)
|
||||
}
|
||||
return rtnState, nil
|
||||
}
|
@ -14,7 +14,6 @@ import (
|
||||
"os/signal"
|
||||
"os/user"
|
||||
"path"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -29,19 +28,19 @@ import (
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/cirfile"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/mpio"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellapi"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellutil"
|
||||
"golang.org/x/mod/semver"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const DefaultTermRows = 24
|
||||
const DefaultTermCols = 80
|
||||
const MinTermRows = 2
|
||||
const MinTermCols = 10
|
||||
const MaxTermRows = 1024
|
||||
const MaxTermCols = 1024
|
||||
const MaxFdNum = 1023
|
||||
const FirstExtraFilesFdNum = 3
|
||||
const DefaultTermType = "xterm-256color"
|
||||
const DefaultMaxPtySize = 1024 * 1024
|
||||
const MinMaxPtySize = 16 * 1024
|
||||
const MaxMaxPtySize = 100 * 1024 * 1024
|
||||
@ -52,27 +51,6 @@ const SigKillWaitTime = 2 * time.Second
|
||||
const RtnStateFdNum = 20
|
||||
const ReturnStateReadWaitTime = 2 * time.Second
|
||||
|
||||
const GetStateTimeout = 5 * time.Second
|
||||
const RemoteBashPath = "bash"
|
||||
const DefaultMacOSShell = "/bin/bash"
|
||||
|
||||
const BaseBashOpts = `set +m; set +H; shopt -s extglob`
|
||||
|
||||
const ShellVersionCmdStr = `echo bash v${BASH_VERSINFO[0]}.${BASH_VERSINFO[1]}.${BASH_VERSINFO[2]}`
|
||||
|
||||
// do not use these directly, call GetLocalBashMajorVersion()
|
||||
var LocalBashMajorVersionOnce = &sync.Once{}
|
||||
var LocalBashMajorVersion = ""
|
||||
|
||||
var GetShellStateCmds = []string{
|
||||
ShellVersionCmdStr + ";",
|
||||
`pwd;`,
|
||||
`declare -p $(compgen -A variable);`,
|
||||
`alias -p;`,
|
||||
`declare -f;`,
|
||||
`printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"`,
|
||||
}
|
||||
|
||||
const ClientCommandFmt = `
|
||||
PATH=$PATH:~/.mshell;
|
||||
which mshell > /dev/null;
|
||||
@ -104,11 +82,6 @@ func MakeInstallCommandStr() string {
|
||||
return strings.ReplaceAll(InstallCommandFmt, "[%VERSION%]", semver.MajorMinor(base.MShellVersion))
|
||||
}
|
||||
|
||||
// TODO fix bash path in these constants
|
||||
const RunCommandFmt = `%s`
|
||||
const RunSudoCommandFmt = `sudo -n -C %d bash /dev/fd/%d`
|
||||
const RunSudoPasswordCommandFmt = `cat /dev/fd/%d | sudo -k -S -C %d bash -c "echo '[from-mshell]'; exec %d>&-; bash /dev/fd/%d < /dev/fd/%d"`
|
||||
|
||||
type MShellBinaryReaderFn func(version string, goos string, goarch string) (io.ReadCloser, error)
|
||||
|
||||
type ReturnStateBuf struct {
|
||||
@ -139,8 +112,9 @@ type ShExecType struct {
|
||||
RunnerOutFd *os.File
|
||||
MsgSender *packet.PacketSender // where to send out-of-band messages back to calling proceess
|
||||
ReturnState *ReturnStateBuf
|
||||
Exited bool // locked via Lock
|
||||
TmpRcFileName string
|
||||
Exited bool // locked via Lock
|
||||
TmpRcFileName string // file *or* directory holding temporary rc file(s)
|
||||
SAPI shellapi.ShellApi
|
||||
}
|
||||
|
||||
type StdContext struct{}
|
||||
@ -183,20 +157,6 @@ type ShExecUPR struct {
|
||||
UPR packet.UnknownPacketReporter
|
||||
}
|
||||
|
||||
func GetLocalBashPath() string {
|
||||
if runtime.GOOS == "darwin" {
|
||||
macShell := GetMacUserShell()
|
||||
if strings.Index(macShell, "bash") != -1 {
|
||||
return shellescape.Quote(macShell)
|
||||
}
|
||||
}
|
||||
return "bash"
|
||||
}
|
||||
|
||||
func GetShellStateCmd() string {
|
||||
return strings.Join(GetShellStateCmds, ` printf "\x00\x00";`)
|
||||
}
|
||||
|
||||
func (s *ShExecType) processSpecialInputPacket(pk *packet.SpecialInputPacketType) error {
|
||||
base.Logf("processSpecialInputPacket: %#v\n", pk)
|
||||
if pk.WinSize != nil {
|
||||
@ -242,12 +202,13 @@ func (s ShExecUPR) UnknownPacket(pk packet.PacketType) {
|
||||
}
|
||||
}
|
||||
|
||||
func MakeShExec(ck base.CommandKey, upr packet.UnknownPacketReporter) *ShExecType {
|
||||
func MakeShExec(ck base.CommandKey, upr packet.UnknownPacketReporter, sapi shellapi.ShellApi) *ShExecType {
|
||||
return &ShExecType{
|
||||
Lock: &sync.Mutex{},
|
||||
StartTs: time.Now(),
|
||||
CK: ck,
|
||||
Multiplexer: mpio.MakeMultiplexer(ck, upr),
|
||||
SAPI: sapi,
|
||||
}
|
||||
}
|
||||
|
||||
@ -267,7 +228,8 @@ func (c *ShExecType) Close() {
|
||||
c.ReturnState.Reader.Close()
|
||||
}
|
||||
if c.TmpRcFileName != "" {
|
||||
os.Remove(c.TmpRcFileName)
|
||||
// TmpRcFileName can be a file or a directory
|
||||
os.RemoveAll(c.TmpRcFileName)
|
||||
}
|
||||
}
|
||||
|
||||
@ -280,42 +242,6 @@ func (c *ShExecType) MakeCmdStartPacket(reqId string) *packet.CmdStartPacketType
|
||||
return startPacket
|
||||
}
|
||||
|
||||
func getEnvStrKey(envStr string) string {
|
||||
eqIdx := strings.Index(envStr, "=")
|
||||
if eqIdx == -1 {
|
||||
return envStr
|
||||
}
|
||||
return envStr[0:eqIdx]
|
||||
}
|
||||
|
||||
func UpdateCmdEnv(cmd *exec.Cmd, envVars map[string]string) {
|
||||
if len(envVars) == 0 {
|
||||
return
|
||||
}
|
||||
found := make(map[string]bool)
|
||||
var newEnv []string
|
||||
for _, envStr := range cmd.Env {
|
||||
envKey := getEnvStrKey(envStr)
|
||||
newEnvVal, ok := envVars[envKey]
|
||||
if ok {
|
||||
if newEnvVal == "" {
|
||||
continue
|
||||
}
|
||||
newEnv = append(newEnv, envKey+"="+newEnvVal)
|
||||
found[envKey] = true
|
||||
} else {
|
||||
newEnv = append(newEnv, envStr)
|
||||
}
|
||||
}
|
||||
for envKey, envVal := range envVars {
|
||||
if found[envKey] {
|
||||
continue
|
||||
}
|
||||
newEnv = append(newEnv, envKey+"="+envVal)
|
||||
}
|
||||
cmd.Env = newEnv
|
||||
}
|
||||
|
||||
// returns (pr, err)
|
||||
func MakeSimpleStaticWriterPipe(data []byte) (*os.File, error) {
|
||||
pr, pw, err := os.Pipe()
|
||||
@ -329,17 +255,30 @@ func MakeSimpleStaticWriterPipe(data []byte) (*os.File, error) {
|
||||
return pr, err
|
||||
}
|
||||
|
||||
func MakeRunnerExec(ck base.CommandKey) (*exec.Cmd, error) {
|
||||
msPath, err := base.GetMShellPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ecmd := exec.Command(msPath, string(ck))
|
||||
return ecmd, nil
|
||||
}
|
||||
|
||||
func MakeDetachedExecCmd(pk *packet.RunPacketType, cmdTty *os.File) (*exec.Cmd, error) {
|
||||
sapi, err := shellapi.MakeShellApi(pk.ShellType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
state := pk.State
|
||||
if state == nil {
|
||||
state = &packet.ShellState{}
|
||||
}
|
||||
ecmd := exec.Command(GetLocalBashPath(), "-c", pk.Command)
|
||||
ecmd := exec.Command(sapi.GetLocalShellPath(), "-c", pk.Command)
|
||||
if !pk.StateComplete {
|
||||
ecmd.Env = os.Environ()
|
||||
}
|
||||
UpdateCmdEnv(ecmd, EnvMapFromState(state))
|
||||
UpdateCmdEnv(ecmd, MShellEnvVars(getTermType(pk)))
|
||||
shellutil.UpdateCmdEnv(ecmd, shellenv.EnvMapFromState(state))
|
||||
shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(getTermType(pk)))
|
||||
if state.Cwd != "" {
|
||||
ecmd.Dir = base.ExpandHomeDir(state.Cwd)
|
||||
}
|
||||
@ -373,15 +312,6 @@ func MakeDetachedExecCmd(pk *packet.RunPacketType, cmdTty *os.File) (*exec.Cmd,
|
||||
return ecmd, nil
|
||||
}
|
||||
|
||||
func MakeRunnerExec(ck base.CommandKey) (*exec.Cmd, error) {
|
||||
msPath, err := base.GetMShellPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ecmd := exec.Command(msPath, string(ck))
|
||||
return ecmd, nil
|
||||
}
|
||||
|
||||
// this will never return (unless there is an error creating/opening the file), as fifoFile will never EOF
|
||||
func MakeAndCopyStdinFifo(dst *os.File, fifoName string) error {
|
||||
os.Remove(fifoName)
|
||||
@ -457,8 +387,8 @@ func ValidateRunPacket(pk *packet.RunPacketType) error {
|
||||
}
|
||||
|
||||
func GetWinsize(p *packet.RunPacketType) *pty.Winsize {
|
||||
rows := DefaultTermRows
|
||||
cols := DefaultTermCols
|
||||
rows := shellutil.DefaultTermRows
|
||||
cols := shellutil.DefaultTermCols
|
||||
if p.TermOpts != nil {
|
||||
rows = base.BoundInt(p.TermOpts.Rows, MinTermRows, MaxTermRows)
|
||||
cols = base.BoundInt(p.TermOpts.Cols, MinTermCols, MaxTermCols)
|
||||
@ -497,49 +427,23 @@ type ClientOpts struct {
|
||||
UsePty bool
|
||||
}
|
||||
|
||||
func (opts SSHOpts) MakeSSHInstallCmd() (*exec.Cmd, error) {
|
||||
if opts.SSHHost == "" {
|
||||
return nil, fmt.Errorf("no ssh host provided, can only install to a remote host")
|
||||
}
|
||||
cmdStr := MakeInstallCommandStr()
|
||||
return opts.MakeSSHExecCmd(cmdStr), nil
|
||||
}
|
||||
|
||||
func (opts SSHOpts) MakeMShellServerCmd() (*exec.Cmd, error) {
|
||||
msPath, err := base.GetMShellPath()
|
||||
func MakeMShellSingleCmd() (*exec.Cmd, error) {
|
||||
execFile, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("cannot find local mshell executable: %w", err)
|
||||
}
|
||||
ecmd := exec.Command(msPath, "--server")
|
||||
ecmd := exec.Command(execFile, "--single-from-server")
|
||||
return ecmd, nil
|
||||
}
|
||||
|
||||
func (opts SSHOpts) MakeMShellSingleCmd(fromServer bool) (*exec.Cmd, error) {
|
||||
if opts.SSHHost == "" {
|
||||
execFile, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot find local mshell executable: %w", err)
|
||||
}
|
||||
var ecmd *exec.Cmd
|
||||
if fromServer {
|
||||
ecmd = exec.Command(execFile, "--single-from-server")
|
||||
} else {
|
||||
ecmd = exec.Command(execFile, "--single")
|
||||
}
|
||||
return ecmd, nil
|
||||
}
|
||||
cmdStr := MakeClientCommandStr()
|
||||
return opts.MakeSSHExecCmd(cmdStr), nil
|
||||
}
|
||||
|
||||
func (opts SSHOpts) MakeSSHExecCmd(remoteCommand string) *exec.Cmd {
|
||||
func (opts SSHOpts) MakeSSHExecCmd(remoteCommand string, sapi shellapi.ShellApi) *exec.Cmd {
|
||||
remoteCommand = strings.TrimSpace(remoteCommand)
|
||||
if opts.SSHHost == "" {
|
||||
homeDir, _ := os.UserHomeDir() // ignore error
|
||||
if homeDir == "" {
|
||||
homeDir = "/"
|
||||
}
|
||||
ecmd := exec.Command(GetLocalBashPath(), "-c", remoteCommand)
|
||||
ecmd := exec.Command(sapi.GetLocalShellPath(), "-c", remoteCommand)
|
||||
ecmd.Dir = homeDir
|
||||
return ecmd
|
||||
} else {
|
||||
@ -566,7 +470,7 @@ func (opts SSHOpts) MakeSSHExecCmd(remoteCommand string) *exec.Cmd {
|
||||
}
|
||||
// note that SSHOptsStr is *not* escaped
|
||||
sshCmd := fmt.Sprintf("ssh %s %s %s %s", strings.Join(moreSSHOpts, " "), opts.SSHOptsStr, shellescape.Quote(opts.SSHHost), shellescape.Quote(remoteCommand))
|
||||
ecmd := exec.Command(RemoteBashPath, "-c", sshCmd)
|
||||
ecmd := exec.Command(sapi.GetRemoteShellPath(), "-c", sshCmd)
|
||||
return ecmd
|
||||
}
|
||||
}
|
||||
@ -605,59 +509,6 @@ func GetTerminalSize() (int, int, error) {
|
||||
return pty.Getsize(fd)
|
||||
}
|
||||
|
||||
func (opts *ClientOpts) MakeRunPacket() (*packet.RunPacketType, error) {
|
||||
runPacket := packet.MakeRunPacket()
|
||||
runPacket.Detached = opts.Detach
|
||||
runPacket.State = &packet.ShellState{}
|
||||
runPacket.State.Cwd = opts.Cwd
|
||||
runPacket.Fds = opts.Fds
|
||||
if opts.UsePty {
|
||||
runPacket.UsePty = true
|
||||
runPacket.TermOpts = &packet.TermOpts{}
|
||||
rows, cols, err := GetTerminalSize()
|
||||
if err == nil {
|
||||
runPacket.TermOpts.Rows = rows
|
||||
runPacket.TermOpts.Cols = cols
|
||||
}
|
||||
term := os.Getenv("TERM")
|
||||
if term != "" {
|
||||
runPacket.TermOpts.Term = term
|
||||
}
|
||||
}
|
||||
if !opts.Sudo {
|
||||
// normal, non-sudo command
|
||||
runPacket.Command = fmt.Sprintf(RunCommandFmt, opts.Command)
|
||||
return runPacket, nil
|
||||
}
|
||||
if opts.SudoWithPass {
|
||||
pwFdNum, err := AddRunData(runPacket, opts.SudoPw, "sudo pw")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
commandFdNum, err := AddRunData(runPacket, opts.Command, "command")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
commandStdinFdNum, err := NextFreeFdNum(runPacket)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
commandStdinRfd := packet.RemoteFd{FdNum: commandStdinFdNum, Read: true, DupStdin: true}
|
||||
runPacket.Fds = append(runPacket.Fds, commandStdinRfd)
|
||||
maxFdNum := MaxFdNumInPacket(runPacket)
|
||||
runPacket.Command = fmt.Sprintf(RunSudoPasswordCommandFmt, pwFdNum, maxFdNum+1, pwFdNum, commandFdNum, commandStdinFdNum)
|
||||
return runPacket, nil
|
||||
} else {
|
||||
commandFdNum, err := AddRunData(runPacket, opts.Command, "command")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
maxFdNum := MaxFdNumInPacket(runPacket)
|
||||
runPacket.Command = fmt.Sprintf(RunSudoCommandFmt, maxFdNum+1, commandFdNum)
|
||||
return runPacket, nil
|
||||
}
|
||||
}
|
||||
|
||||
func AddRunData(pk *packet.RunPacketType, data string, dataType string) (int, error) {
|
||||
if len(data) > MaxRunDataSize {
|
||||
return 0, fmt.Errorf("%s too large, exceeds read buffer size size:%d", dataType, len(data))
|
||||
@ -810,31 +661,6 @@ func RunInstallFromCmd(ctx context.Context, ecmd *exec.Cmd, tryDetect bool, mshe
|
||||
}
|
||||
}
|
||||
|
||||
func RunInstallFromOpts(opts *InstallOpts) error {
|
||||
ecmd, err := opts.SSHOpts.MakeSSHInstallCmd()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
msgFn := func(str string) {
|
||||
fmt.Printf("%s", str)
|
||||
}
|
||||
var mshellStream *os.File
|
||||
if opts.OptName != "" {
|
||||
mshellStream, err = os.Open(opts.OptName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot open mshell binary %q: %v", opts.OptName, err)
|
||||
}
|
||||
defer mshellStream.Close()
|
||||
}
|
||||
err = RunInstallFromCmd(context.Background(), ecmd, opts.Detect, mshellStream, base.MShellBinaryFromOptDir, msgFn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mmVersion := semver.MajorMinor(base.MShellVersion)
|
||||
fmt.Printf("mshell installed successfully at %s:~/.mshell/mshell%s\n", opts.SSHOpts.SSHHost, mmVersion)
|
||||
return nil
|
||||
}
|
||||
|
||||
func HasDupStdin(fds []packet.RemoteFd) bool {
|
||||
for _, rfd := range fds {
|
||||
if rfd.Read && rfd.DupStdin {
|
||||
@ -844,101 +670,6 @@ func HasDupStdin(fds []packet.RemoteFd) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, fdContext FdContext, sshOpts SSHOpts, upr packet.UnknownPacketReporter, debug bool) (*packet.CmdDonePacketType, error) {
|
||||
cmd := MakeShExec(runPacket.CK, upr)
|
||||
ecmd, err := sshOpts.MakeMShellSingleCmd(false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cmd.Cmd = ecmd
|
||||
inputWriter, err := ecmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating stdin pipe: %v", err)
|
||||
}
|
||||
stdoutReader, err := ecmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating stdout pipe: %v", err)
|
||||
}
|
||||
stderrReader, err := ecmd.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating stderr pipe: %v", err)
|
||||
}
|
||||
if !HasDupStdin(runPacket.Fds) {
|
||||
cmd.Multiplexer.MakeRawFdReader(0, fdContext.GetReader(0), false, false)
|
||||
}
|
||||
cmd.Multiplexer.MakeRawFdWriter(1, fdContext.GetWriter(1), false, "client")
|
||||
cmd.Multiplexer.MakeRawFdWriter(2, fdContext.GetWriter(2), false, "client")
|
||||
for _, rfd := range runPacket.Fds {
|
||||
if rfd.Read && rfd.DupStdin {
|
||||
cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fdContext.GetReader(0), false, false)
|
||||
continue
|
||||
}
|
||||
if rfd.Read {
|
||||
fd := fdContext.GetReader(rfd.FdNum)
|
||||
cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fd, false, false)
|
||||
} else if rfd.Write {
|
||||
fd := fdContext.GetWriter(rfd.FdNum)
|
||||
cmd.Multiplexer.MakeRawFdWriter(rfd.FdNum, fd, true, "client")
|
||||
}
|
||||
}
|
||||
err = ecmd.Start()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("running ssh command: %w", err)
|
||||
}
|
||||
defer cmd.Close()
|
||||
stdoutPacketParser := packet.MakePacketParser(stdoutReader, nil)
|
||||
stderrPacketParser := packet.MakePacketParser(stderrReader, nil)
|
||||
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, false)
|
||||
sender := packet.MakePacketSender(inputWriter, nil)
|
||||
versionOk := false
|
||||
for pk := range packetParser.MainCh {
|
||||
if pk.GetType() == packet.RawPacketStr {
|
||||
rawPk := pk.(*packet.RawPacketType)
|
||||
fmt.Printf("%s\n", rawPk.Data)
|
||||
continue
|
||||
}
|
||||
if pk.GetType() == packet.InitPacketStr {
|
||||
initPk := pk.(*packet.InitPacketType)
|
||||
mmVersion := semver.MajorMinor(base.MShellVersion)
|
||||
if initPk.NotFound {
|
||||
if sshOpts.SSHHost == "" {
|
||||
return nil, fmt.Errorf("mshell-%s command not found on local server", mmVersion)
|
||||
}
|
||||
if initPk.UName == "" {
|
||||
return nil, fmt.Errorf("mshell-%s command not found on remote server, no uname detected", mmVersion)
|
||||
}
|
||||
goos, goarch, err := DetectGoArch(initPk.UName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mshell-%s command not found on remote server, architecture cannot be detected (might be incompatible with mshell): %w", mmVersion, err)
|
||||
}
|
||||
sshOptsStr := sshOpts.MakeMShellSSHOpts()
|
||||
return nil, fmt.Errorf("mshell-%s command not found on remote server, can install with 'mshell --install %s %s.%s'", mmVersion, sshOptsStr, goos, goarch)
|
||||
}
|
||||
if semver.MajorMinor(initPk.Version) != semver.MajorMinor(base.MShellVersion) {
|
||||
return nil, fmt.Errorf("invalid remote mshell version '%s', must be '=%s'", initPk.Version, semver.MajorMinor(base.MShellVersion))
|
||||
}
|
||||
versionOk = true
|
||||
if debug {
|
||||
fmt.Printf("VERSION> %s\n", initPk.Version)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if !versionOk {
|
||||
return nil, fmt.Errorf("did not receive version from remote mshell")
|
||||
}
|
||||
SendRunPacketAndRunData(context.Background(), sender, runPacket)
|
||||
if debug {
|
||||
cmd.Multiplexer.Debug = true
|
||||
}
|
||||
remoteDonePacket := cmd.Multiplexer.RunIOAndWait(packetParser, sender, false, true, true)
|
||||
donePacket := cmd.WaitForCommand()
|
||||
if remoteDonePacket != nil {
|
||||
donePacket = remoteDonePacket
|
||||
}
|
||||
return donePacket, nil
|
||||
}
|
||||
|
||||
func min(v1 int, v2 int) int {
|
||||
if v1 <= v2 {
|
||||
return v1
|
||||
@ -1015,46 +746,13 @@ func (cmd *ShExecType) RunRemoteIOAndWait(packetParser *packet.PacketParser, sen
|
||||
}
|
||||
|
||||
func getTermType(pk *packet.RunPacketType) string {
|
||||
termType := DefaultTermType
|
||||
termType := shellutil.DefaultTermType
|
||||
if pk.TermOpts != nil && pk.TermOpts.Term != "" {
|
||||
termType = pk.TermOpts.Term
|
||||
}
|
||||
return termType
|
||||
}
|
||||
|
||||
func makeRcFileStr(pk *packet.RunPacketType) string {
|
||||
var rcBuf bytes.Buffer
|
||||
rcBuf.WriteString(BaseBashOpts + "\n")
|
||||
varDecls := VarDeclsFromState(pk.State)
|
||||
for _, varDecl := range varDecls {
|
||||
if varDecl.IsExport() || varDecl.IsReadOnly() {
|
||||
continue
|
||||
}
|
||||
rcBuf.WriteString(varDecl.DeclareStmt())
|
||||
rcBuf.WriteString("\n")
|
||||
}
|
||||
if pk.State != nil && pk.State.Funcs != "" {
|
||||
rcBuf.WriteString(pk.State.Funcs)
|
||||
rcBuf.WriteString("\n")
|
||||
}
|
||||
if pk.State != nil && pk.State.Aliases != "" {
|
||||
rcBuf.WriteString(pk.State.Aliases)
|
||||
rcBuf.WriteString("\n")
|
||||
}
|
||||
return rcBuf.String()
|
||||
}
|
||||
|
||||
func makeExitTrap(fdNum int) string {
|
||||
stateCmd := GetShellStateRedirectCommandStr(fdNum)
|
||||
fmtStr := `
|
||||
_mshell_exittrap () {
|
||||
%s
|
||||
}
|
||||
trap _mshell_exittrap EXIT
|
||||
`
|
||||
return fmt.Sprintf(fmtStr, stateCmd)
|
||||
}
|
||||
|
||||
func (s *ShExecType) SendSignal(sig syscall.Signal) {
|
||||
base.Logf("signal start %v\n", sig)
|
||||
if sig == syscall.SIGKILL {
|
||||
@ -1086,11 +784,15 @@ func (s *ShExecType) SendSignal(sig syscall.Signal) {
|
||||
}
|
||||
|
||||
func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fromServer bool) (rtnShExec *ShExecType, rtnErr error) {
|
||||
sapi, err := shellapi.MakeShellApi(pk.ShellType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
state := pk.State
|
||||
if state == nil {
|
||||
state = &packet.ShellState{}
|
||||
return nil, fmt.Errorf("invalid run packet, no state")
|
||||
}
|
||||
cmd := MakeShExec(pk.CK, nil)
|
||||
cmd := MakeShExec(pk.CK, nil, sapi)
|
||||
defer func() {
|
||||
// on error, call cmd.Close()
|
||||
if rtnErr != nil {
|
||||
@ -1104,7 +806,7 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
|
||||
cmd.MsgSender = sender
|
||||
}
|
||||
var rtnStateWriter *os.File
|
||||
rcFileStr := makeRcFileStr(pk)
|
||||
rcFileStr := sapi.MakeRcFileStr(pk)
|
||||
if pk.ReturnState {
|
||||
pr, pw, err := os.Pipe()
|
||||
if err != nil {
|
||||
@ -1115,10 +817,10 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
|
||||
cmd.ReturnState.FdNum = RtnStateFdNum
|
||||
rtnStateWriter = pw
|
||||
defer pw.Close()
|
||||
trapCmdStr := makeExitTrap(cmd.ReturnState.FdNum)
|
||||
trapCmdStr := sapi.MakeExitTrap(cmd.ReturnState.FdNum)
|
||||
rcFileStr += trapCmdStr
|
||||
}
|
||||
shellVarMap := ShellVarMapFromState(state)
|
||||
shellVarMap := shellenv.ShellVarMapFromState(state)
|
||||
if base.HasDebugFlag(shellVarMap, base.DebugFlag_LogRcFile) {
|
||||
debugRcFileName := base.GetDebugRcFileName()
|
||||
err := os.WriteFile(debugRcFileName, []byte(rcFileStr), 0600)
|
||||
@ -1126,9 +828,13 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
|
||||
base.Logf("error writing %s: %v\n", debugRcFileName, err)
|
||||
}
|
||||
}
|
||||
bashVersion := GetLocalBashMajorVersion()
|
||||
isOldBashVersion := (semver.Compare(bashVersion, "v4") < 0)
|
||||
var isOldBashVersion bool
|
||||
if sapi.GetShellType() == packet.ShellType_bash {
|
||||
bashVersion := sapi.GetLocalMajorVersion()
|
||||
isOldBashVersion = (semver.Compare(bashVersion, "v4") < 0)
|
||||
}
|
||||
var rcFileName string
|
||||
var zdotdir string
|
||||
if isOldBashVersion {
|
||||
rcFileDir, err := base.EnsureRcFilesDir()
|
||||
if err != nil {
|
||||
@ -1140,12 +846,19 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
|
||||
return nil, fmt.Errorf("could not write temp rcfile: %w", err)
|
||||
}
|
||||
cmd.TmpRcFileName = rcFileName
|
||||
go func() {
|
||||
// cmd.Close() will also remove rcFileName
|
||||
// adding this to also try to proactively clean up after 1-second.
|
||||
time.Sleep(1 * time.Second)
|
||||
os.Remove(rcFileName)
|
||||
}()
|
||||
} else if sapi.GetShellType() == packet.ShellType_zsh {
|
||||
rcFileDir, err := base.EnsureRcFilesDir()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
zdotdir = path.Join(rcFileDir, uuid.New().String())
|
||||
os.Mkdir(zdotdir, 0700)
|
||||
rcFileName = path.Join(zdotdir, ".zshenv")
|
||||
err = os.WriteFile(rcFileName, []byte(rcFileStr), 0600)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not write temp rcfile: %w", err)
|
||||
}
|
||||
cmd.TmpRcFileName = zdotdir
|
||||
} else {
|
||||
rcFileFdNum, err := AddRunData(pk, rcFileStr, "rcfile")
|
||||
if err != nil {
|
||||
@ -1153,19 +866,26 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
|
||||
}
|
||||
rcFileName = fmt.Sprintf("/dev/fd/%d", rcFileFdNum)
|
||||
}
|
||||
if pk.UsePty {
|
||||
cmd.Cmd = exec.Command(GetLocalBashPath(), "--rcfile", rcFileName, "-i", "-c", pk.Command)
|
||||
} else {
|
||||
cmd.Cmd = exec.Command(GetLocalBashPath(), "--rcfile", rcFileName, "-c", pk.Command)
|
||||
if cmd.TmpRcFileName != "" {
|
||||
go func() {
|
||||
// cmd.Close() will also remove rcFileName
|
||||
// adding this to also try to proactively clean up after 2-seconds.
|
||||
time.Sleep(2 * time.Second)
|
||||
os.Remove(cmd.TmpRcFileName)
|
||||
}()
|
||||
}
|
||||
cmd.Cmd = sapi.MakeShExecCommand(pk.Command, rcFileName, pk.UsePty)
|
||||
if !pk.StateComplete {
|
||||
cmd.Cmd.Env = os.Environ()
|
||||
}
|
||||
UpdateCmdEnv(cmd.Cmd, EnvMapFromState(state))
|
||||
shellutil.UpdateCmdEnv(cmd.Cmd, shellenv.EnvMapFromState(state))
|
||||
if sapi.GetShellType() == packet.ShellType_zsh {
|
||||
shellutil.UpdateCmdEnv(cmd.Cmd, map[string]string{"ZDOTDIR": zdotdir})
|
||||
}
|
||||
if state.Cwd != "" {
|
||||
cmd.Cmd.Dir = base.ExpandHomeDir(state.Cwd)
|
||||
}
|
||||
err := ValidateRemoteFds(pk.Fds)
|
||||
err = ValidateRemoteFds(pk.Fds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -1181,7 +901,7 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
|
||||
cmdTty.Close()
|
||||
}()
|
||||
cmd.CmdPty = cmdPty
|
||||
UpdateCmdEnv(cmd.Cmd, MShellEnvVars(getTermType(pk)))
|
||||
shellutil.UpdateCmdEnv(cmd.Cmd, shellutil.MShellEnvVars(getTermType(pk)))
|
||||
}
|
||||
if cmdTty != nil {
|
||||
cmd.Cmd.Stdin = cmdTty
|
||||
@ -1374,6 +1094,10 @@ func (cmd *ShExecType) DetachedWait(startPacket *packet.CmdStartPacketType) {
|
||||
}
|
||||
|
||||
func RunCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecType, *packet.CmdStartPacketType, error) {
|
||||
sapi, err := shellapi.MakeShellApi(pk.ShellType)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
fileNames, err := base.GetCommandFileNames(pk.CK)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@ -1393,7 +1117,7 @@ func RunCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (
|
||||
defer func() {
|
||||
cmdTty.Close()
|
||||
}()
|
||||
cmd := MakeShExec(pk.CK, nil)
|
||||
cmd := MakeShExec(pk.CK, nil, sapi)
|
||||
cmd.FileNames = fileNames
|
||||
cmd.CmdPty = cmdPty
|
||||
cmd.Detached = true
|
||||
@ -1462,7 +1186,7 @@ func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType {
|
||||
c.ReturnState.Reader.Close()
|
||||
}()
|
||||
<-c.ReturnState.DoneCh
|
||||
state, _ := ParseShellStateOutput(c.ReturnState.Buf) // TODO what to do with error?
|
||||
state, _ := c.SAPI.ParseShellStateOutput(c.ReturnState.Buf) // TODO what to do with error?
|
||||
donePacket.FinalState = state
|
||||
}
|
||||
endTs := time.Now()
|
||||
@ -1487,18 +1211,27 @@ func MakeInitPacket() *packet.InitPacketType {
|
||||
}
|
||||
initPacket.HostName, _ = os.Hostname()
|
||||
initPacket.UName = fmt.Sprintf("%s|%s", runtime.GOOS, runtime.GOARCH)
|
||||
initPacket.Shell = shellapi.DetectLocalShellType()
|
||||
return initPacket
|
||||
}
|
||||
|
||||
func MakeShellStatePacket(shellType string) (*packet.ShellStatePacketType, error) {
|
||||
sapi, err := shellapi.MakeShellApi(shellType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shellState, err := sapi.GetShellState()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rtn := packet.MakeShellStatePacket()
|
||||
rtn.State = shellState
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func MakeServerInitPacket() (*packet.InitPacketType, error) {
|
||||
var err error
|
||||
initPacket := MakeInitPacket()
|
||||
shellState, err := GetShellState()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
initPacket.State = shellState
|
||||
initPacket.Shell = os.Getenv(ShellVarName)
|
||||
initPacket.RemoteId, err = base.GetRemoteId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -1549,129 +1282,3 @@ func getStderr(err error) string {
|
||||
}
|
||||
return lines[0]
|
||||
}
|
||||
|
||||
func runSimpleCmdInPty(ecmd *exec.Cmd) ([]byte, error) {
|
||||
ecmd.Env = os.Environ()
|
||||
UpdateCmdEnv(ecmd, MShellEnvVars(DefaultTermType))
|
||||
cmdPty, cmdTty, err := pty.Open()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("opening new pty: %w", err)
|
||||
}
|
||||
pty.Setsize(cmdPty, &pty.Winsize{Rows: DefaultTermRows, Cols: DefaultTermCols})
|
||||
ecmd.Stdin = cmdTty
|
||||
ecmd.Stdout = cmdTty
|
||||
ecmd.Stderr = cmdTty
|
||||
ecmd.SysProcAttr = &syscall.SysProcAttr{}
|
||||
ecmd.SysProcAttr.Setsid = true
|
||||
ecmd.SysProcAttr.Setctty = true
|
||||
err = ecmd.Start()
|
||||
if err != nil {
|
||||
cmdTty.Close()
|
||||
cmdPty.Close()
|
||||
return nil, err
|
||||
}
|
||||
cmdTty.Close()
|
||||
defer cmdPty.Close()
|
||||
ioDone := make(chan bool)
|
||||
var outputBuf bytes.Buffer
|
||||
go func() {
|
||||
// ignore error (/dev/ptmx has read error when process is done)
|
||||
io.Copy(&outputBuf, cmdPty)
|
||||
close(ioDone)
|
||||
}()
|
||||
exitErr := ecmd.Wait()
|
||||
if exitErr != nil {
|
||||
return nil, exitErr
|
||||
}
|
||||
<-ioDone
|
||||
return outputBuf.Bytes(), nil
|
||||
}
|
||||
|
||||
func GetShellStateRedirectCommandStr(outputFdNum int) string {
|
||||
return fmt.Sprintf("cat <(%s) > /dev/fd/%d", GetShellStateCmd(), outputFdNum)
|
||||
}
|
||||
|
||||
func GetShellState() (*packet.ShellState, error) {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
|
||||
defer cancelFn()
|
||||
cmdStr := BaseBashOpts + "; " + GetShellStateCmd()
|
||||
ecmd := exec.CommandContext(ctx, GetLocalBashPath(), "-l", "-i", "-c", cmdStr)
|
||||
outputBytes, err := runSimpleCmdInPty(ecmd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ParseShellStateOutput(outputBytes)
|
||||
}
|
||||
|
||||
func MShellEnvVars(termType string) map[string]string {
|
||||
rtn := make(map[string]string)
|
||||
if termType != "" {
|
||||
rtn["TERM"] = termType
|
||||
}
|
||||
rtn["MSHELL"], _ = os.Executable()
|
||||
rtn["MSHELL_VERSION"] = base.MShellVersion
|
||||
rtn["WAVESHELL"], _ = os.Executable()
|
||||
rtn["WAVESHELL_VERSION"] = base.MShellVersion
|
||||
return rtn
|
||||
}
|
||||
|
||||
func ExecGetLocalShellVersion() string {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
|
||||
defer cancelFn()
|
||||
ecmd := exec.CommandContext(ctx, "bash", "-c", ShellVersionCmdStr)
|
||||
out, err := ecmd.Output()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
versionStr := strings.TrimSpace(string(out))
|
||||
if strings.Index(versionStr, "bash ") == -1 {
|
||||
// invalid shell version (only bash is supported)
|
||||
return ""
|
||||
}
|
||||
return versionStr
|
||||
}
|
||||
|
||||
func GetLocalBashMajorVersion() string {
|
||||
LocalBashMajorVersionOnce.Do(func() {
|
||||
fullVersion := ExecGetLocalShellVersion()
|
||||
LocalBashMajorVersion = packet.GetBashMajorVersion(fullVersion)
|
||||
})
|
||||
return LocalBashMajorVersion
|
||||
}
|
||||
|
||||
var userShellRegexp = regexp.MustCompile(`^UserShell: (.*)$`)
|
||||
|
||||
var cachedMacUserShell string
|
||||
var macUserShellOnce = &sync.Once{}
|
||||
|
||||
func GetMacUserShell() string {
|
||||
if runtime.GOOS != "darwin" {
|
||||
return ""
|
||||
}
|
||||
macUserShellOnce.Do(func() {
|
||||
cachedMacUserShell = internalMacUserShell()
|
||||
})
|
||||
return cachedMacUserShell
|
||||
}
|
||||
|
||||
// dscl . -read /User/[username] UserShell
|
||||
// defaults to /bin/bash
|
||||
func internalMacUserShell() string {
|
||||
osUser, err := user.Current()
|
||||
if err != nil {
|
||||
return DefaultMacOSShell
|
||||
}
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancelFn()
|
||||
userStr := "/Users/" + osUser.Username
|
||||
out, err := exec.CommandContext(ctx, "dscl", ".", "-read", userStr, "UserShell").CombinedOutput()
|
||||
if err != nil {
|
||||
return DefaultMacOSShell
|
||||
}
|
||||
outStr := strings.TrimSpace(string(out))
|
||||
m := userShellRegexp.FindStringSubmatch(outStr)
|
||||
if m == nil {
|
||||
return DefaultMacOSShell
|
||||
}
|
||||
return m[1]
|
||||
}
|
||||
|
@ -10,7 +10,8 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
const LineDiffVersion = 0
|
||||
const LineDiffVersion_0 = 0
|
||||
const LineDiffVersion = 1
|
||||
|
||||
type SingleLineEntry struct {
|
||||
LineVal int
|
||||
@ -18,12 +19,21 @@ type SingleLineEntry struct {
|
||||
}
|
||||
|
||||
type LineDiffType struct {
|
||||
Lines []SingleLineEntry
|
||||
NewData []string
|
||||
Version int
|
||||
SplitString string // added in version 1
|
||||
Lines []SingleLineEntry
|
||||
NewData []string
|
||||
}
|
||||
|
||||
func (diff *LineDiffType) Clear() {
|
||||
diff.Version = LineDiffVersion
|
||||
diff.SplitString = ""
|
||||
diff.Lines = nil
|
||||
diff.NewData = nil
|
||||
}
|
||||
|
||||
func (diff LineDiffType) Dump() {
|
||||
fmt.Printf("DIFF:\n")
|
||||
fmt.Printf("DIFF: v%d\n", diff.Version)
|
||||
pos := 1
|
||||
for _, entry := range diff.Lines {
|
||||
fmt.Printf(" %d-%d: %d\n", pos, pos+entry.Run, entry.LineVal)
|
||||
@ -68,13 +78,37 @@ func putUVarint(buf *bytes.Buffer, viBuf []byte, ival int) {
|
||||
buf.Write(viBuf[0:l])
|
||||
}
|
||||
|
||||
// run length encoding, writes a uvarint for length, and then that many bytes of data
|
||||
func putEncodedString(buf *bytes.Buffer, viBuf []byte, str string) {
|
||||
l := binary.PutUvarint(viBuf, uint64(len(str)))
|
||||
buf.Write(viBuf[0:l])
|
||||
buf.WriteString(str)
|
||||
}
|
||||
|
||||
func readEncodedString(buf *bytes.Buffer) (string, error) {
|
||||
strLen64, err := binary.ReadUvarint(buf)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid diff, cannot read string length: %v", err)
|
||||
}
|
||||
strLen := int(strLen64)
|
||||
if strLen == 0 {
|
||||
return "", nil
|
||||
}
|
||||
strBytes := buf.Next(strLen)
|
||||
if len(strBytes) != strLen {
|
||||
return "", fmt.Errorf("invalid diff, partial read, expected %d, got %d", strLen, len(strBytes))
|
||||
}
|
||||
return string(strBytes), nil
|
||||
}
|
||||
|
||||
// version 0 is no longer used, but kept here as a reference for decoding
|
||||
// simple encoding
|
||||
// write varints. first version, then len, then len-number-of-varints, then fill the rest with newdata
|
||||
// write varints. first version, then then len, then len-number-of-varints, then fill the rest with newdata
|
||||
// [version] [len-varint] [varint]xlen... newdata (bytes)
|
||||
func (diff LineDiffType) Encode() []byte {
|
||||
func (diff LineDiffType) Encode_v0() []byte {
|
||||
var buf bytes.Buffer
|
||||
viBuf := make([]byte, binary.MaxVarintLen64)
|
||||
putUVarint(&buf, viBuf, LineDiffVersion)
|
||||
putUVarint(&buf, viBuf, LineDiffVersion_0)
|
||||
putUVarint(&buf, viBuf, len(diff.Lines))
|
||||
for _, entry := range diff.Lines {
|
||||
putUVarint(&buf, viBuf, entry.LineVal)
|
||||
@ -83,37 +117,117 @@ func (diff LineDiffType) Encode() []byte {
|
||||
for idx, str := range diff.NewData {
|
||||
buf.WriteString(str)
|
||||
if idx != len(diff.NewData)-1 {
|
||||
buf.WriteByte('\n')
|
||||
buf.WriteString(diff.SplitString)
|
||||
}
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func (rtn *LineDiffType) Decode(diffBytes []byte) error {
|
||||
r := bytes.NewBuffer(diffBytes)
|
||||
version, err := binary.ReadUvarint(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid diff, cannot read version: %v", err)
|
||||
// version 1 updates the diff to include the split-string
|
||||
// it also encodes all the strings with run-length encoding
|
||||
func (diff LineDiffType) Encode() []byte {
|
||||
var buf bytes.Buffer
|
||||
viBuf := make([]byte, binary.MaxVarintLen64)
|
||||
putUVarint(&buf, viBuf, LineDiffVersion)
|
||||
putEncodedString(&buf, viBuf, diff.SplitString)
|
||||
putUVarint(&buf, viBuf, len(diff.Lines))
|
||||
for _, entry := range diff.Lines {
|
||||
putUVarint(&buf, viBuf, entry.LineVal)
|
||||
putUVarint(&buf, viBuf, entry.Run)
|
||||
}
|
||||
if version != LineDiffVersion {
|
||||
return fmt.Errorf("invalid diff, bad version: %d", version)
|
||||
}
|
||||
linesLen64, err := binary.ReadUvarint(r)
|
||||
writeEncodedStringArray(&buf, viBuf, diff.NewData)
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func (rtn *LineDiffType) readEncodedLines(buf *bytes.Buffer) error {
|
||||
linesLen64, err := binary.ReadUvarint(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid diff, cannot read lines length: %v", err)
|
||||
}
|
||||
linesLen := int(linesLen64)
|
||||
rtn.Lines = make([]SingleLineEntry, linesLen)
|
||||
for idx := 0; idx < linesLen; idx++ {
|
||||
lineVal, err := binary.ReadUvarint(r)
|
||||
lineVal64, err := binary.ReadUvarint(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid diff, cannot read line %d: %v", idx, err)
|
||||
}
|
||||
lineRun, err := binary.ReadUvarint(r)
|
||||
lineRun64, err := binary.ReadUvarint(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid diff, cannot read line-run %d: %v", idx, err)
|
||||
}
|
||||
rtn.Lines[idx] = SingleLineEntry{LineVal: int(lineVal), Run: int(lineRun)}
|
||||
rtn.Lines[idx] = SingleLineEntry{LineVal: int(lineVal64), Run: int(lineRun64)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeEncodedStringArray(buf *bytes.Buffer, viBuf []byte, strArr []string) {
|
||||
putUVarint(buf, viBuf, len(strArr))
|
||||
for _, str := range strArr {
|
||||
putEncodedString(buf, viBuf, str)
|
||||
}
|
||||
}
|
||||
|
||||
func readEncodedStringArray(buf *bytes.Buffer) ([]string, error) {
|
||||
strArrLen64, err := binary.ReadUvarint(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid diff, cannot read string-array length: %v", err)
|
||||
}
|
||||
strArrLen := int(strArrLen64)
|
||||
rtn := make([]string, strArrLen)
|
||||
for idx := 0; idx < strArrLen; idx++ {
|
||||
str, err := readEncodedString(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rtn[idx] = str
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func (rtn *LineDiffType) Decode(diffBytes []byte) error {
|
||||
rtn.Clear()
|
||||
r := bytes.NewBuffer(diffBytes)
|
||||
version, err := binary.ReadUvarint(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid diff, cannot read version: %v", err)
|
||||
}
|
||||
if version == LineDiffVersion_0 {
|
||||
return rtn.Decode_v0(diffBytes)
|
||||
}
|
||||
if version != LineDiffVersion {
|
||||
return fmt.Errorf("invalid diff, bad version: %d", version)
|
||||
}
|
||||
rtn.Version = int(version)
|
||||
splitString, err := readEncodedString(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid diff, cannot read split-string: %v", err)
|
||||
}
|
||||
rtn.SplitString = splitString
|
||||
err = rtn.readEncodedLines(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rtn.NewData, err = readEncodedStringArray(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rtn *LineDiffType) Decode_v0(diffBytes []byte) error {
|
||||
r := bytes.NewBuffer(diffBytes)
|
||||
version, err := binary.ReadUvarint(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid diff, cannot read version: %v", err)
|
||||
}
|
||||
if version != LineDiffVersion_0 {
|
||||
return fmt.Errorf("invalid diff, bad version: %d", version)
|
||||
}
|
||||
rtn.Version = int(version)
|
||||
rtn.SplitString = "\n" // added when we added version 1
|
||||
err = rtn.readEncodedLines(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
restOfInput := string(r.Bytes())
|
||||
if len(restOfInput) > 0 {
|
||||
@ -122,8 +236,10 @@ func (rtn *LineDiffType) Decode(diffBytes []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeLineDiff(oldData []string, newData []string) LineDiffType {
|
||||
func makeLineDiff(oldData []string, newData []string, splitString string) LineDiffType {
|
||||
var rtn LineDiffType
|
||||
rtn.Version = LineDiffVersion
|
||||
rtn.SplitString = splitString
|
||||
oldDataMap := make(map[string]int) // 1-indexed
|
||||
for idx, str := range oldData {
|
||||
if _, found := oldDataMap[str]; found {
|
||||
@ -163,13 +279,13 @@ func makeLineDiff(oldData []string, newData []string) LineDiffType {
|
||||
return rtn
|
||||
}
|
||||
|
||||
func MakeLineDiff(str1 string, str2 string) []byte {
|
||||
func MakeLineDiff(str1 string, str2 string, splitString string) []byte {
|
||||
if str1 == str2 {
|
||||
return nil
|
||||
}
|
||||
str1Arr := strings.Split(str1, "\n")
|
||||
str2Arr := strings.Split(str2, "\n")
|
||||
diff := makeLineDiff(str1Arr, str2Arr)
|
||||
str1Arr := strings.Split(str1, splitString)
|
||||
str2Arr := strings.Split(str2, splitString)
|
||||
diff := makeLineDiff(str1Arr, str2Arr, splitString)
|
||||
return diff.Encode()
|
||||
}
|
||||
|
||||
@ -182,10 +298,10 @@ func ApplyLineDiff(str1 string, diffBytes []byte) (string, error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
str1Arr := strings.Split(str1, "\n")
|
||||
str1Arr := strings.Split(str1, diff.SplitString)
|
||||
str2Arr, err := diff.applyDiff(str1Arr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.Join(str2Arr, "\n"), nil
|
||||
return strings.Join(str2Arr, diff.SplitString), nil
|
||||
}
|
||||
|
@ -7,17 +7,27 @@ import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/binpack"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
)
|
||||
|
||||
const MapDiffVersion = 0
|
||||
const MapDiffVersion_0 = 0
|
||||
const MapDiffVersion = 1
|
||||
|
||||
// 0-bytes are not allowed in entries or keys (same as bash)
|
||||
|
||||
type MapDiffType struct {
|
||||
ToAdd map[string]string
|
||||
ToAdd map[string][]byte
|
||||
ToRemove []string
|
||||
}
|
||||
|
||||
func (diff *MapDiffType) Clear() {
|
||||
diff.ToAdd = nil
|
||||
diff.ToRemove = nil
|
||||
}
|
||||
|
||||
func (diff MapDiffType) Dump() {
|
||||
fmt.Printf("VAR-DIFF\n")
|
||||
for name, val := range diff.ToAdd {
|
||||
@ -28,12 +38,12 @@ func (diff MapDiffType) Dump() {
|
||||
}
|
||||
}
|
||||
|
||||
func makeMapDiff(oldMap map[string]string, newMap map[string]string) MapDiffType {
|
||||
func makeMapDiff(oldMap map[string][]byte, newMap map[string][]byte) MapDiffType {
|
||||
var rtn MapDiffType
|
||||
rtn.ToAdd = make(map[string]string)
|
||||
rtn.ToAdd = make(map[string][]byte)
|
||||
for name, newVal := range newMap {
|
||||
oldVal, found := oldMap[name]
|
||||
if !found || oldVal != newVal {
|
||||
if !found || !bytes.Equal(oldVal, newVal) {
|
||||
rtn.ToAdd[name] = newVal
|
||||
continue
|
||||
}
|
||||
@ -47,8 +57,8 @@ func makeMapDiff(oldMap map[string]string, newMap map[string]string) MapDiffType
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (diff MapDiffType) apply(oldMap map[string]string) map[string]string {
|
||||
rtn := make(map[string]string)
|
||||
func (diff MapDiffType) apply(oldMap map[string][]byte) map[string][]byte {
|
||||
rtn := make(map[string][]byte)
|
||||
for name, val := range oldMap {
|
||||
rtn[name] = val
|
||||
}
|
||||
@ -61,15 +71,16 @@ func (diff MapDiffType) apply(oldMap map[string]string) map[string]string {
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (diff MapDiffType) Encode() []byte {
|
||||
// this is kept for reference
|
||||
func (diff MapDiffType) Encode_v0() []byte {
|
||||
var buf bytes.Buffer
|
||||
viBuf := make([]byte, binary.MaxVarintLen64)
|
||||
putUVarint(&buf, viBuf, MapDiffVersion)
|
||||
putUVarint(&buf, viBuf, MapDiffVersion_0)
|
||||
putUVarint(&buf, viBuf, len(diff.ToAdd))
|
||||
for key, val := range diff.ToAdd {
|
||||
buf.WriteString(key)
|
||||
buf.WriteByte(0)
|
||||
buf.WriteString(val)
|
||||
buf.Write(val)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
for _, val := range diff.ToRemove {
|
||||
@ -79,13 +90,75 @@ func (diff MapDiffType) Encode() []byte {
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// we sort map keys and remove values to make the diff deterministic
|
||||
func (diff MapDiffType) Encode() []byte {
|
||||
var buf bytes.Buffer
|
||||
binpack.PackUInt(&buf, MapDiffVersion)
|
||||
binpack.PackUInt(&buf, uint64(len(diff.ToAdd)))
|
||||
addKeys := utilfn.GetOrderedMapKeys(diff.ToAdd)
|
||||
for _, key := range addKeys {
|
||||
val := diff.ToAdd[key]
|
||||
binpack.PackValue(&buf, []byte(key))
|
||||
binpack.PackValue(&buf, val)
|
||||
}
|
||||
slices.Sort(diff.ToRemove)
|
||||
binpack.PackUInt(&buf, uint64(len(diff.ToRemove)))
|
||||
for _, val := range diff.ToRemove {
|
||||
binpack.PackValue(&buf, []byte(val))
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func (diff *MapDiffType) Decode(diffBytes []byte) error {
|
||||
diff.Clear()
|
||||
r := bytes.NewBuffer(diffBytes)
|
||||
version, err := binpack.UnpackUInt(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid diff, cannot read version: %v", err)
|
||||
}
|
||||
if version == MapDiffVersion_0 {
|
||||
return diff.Decode_v0(diffBytes)
|
||||
}
|
||||
if version != MapDiffVersion {
|
||||
return fmt.Errorf("invalid diff, bad version: %d", version)
|
||||
}
|
||||
addLen, err := binpack.UnpackUIntAsInt(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid diff, cannot read add length: %v", err)
|
||||
}
|
||||
diff.ToAdd = make(map[string][]byte)
|
||||
for i := 0; i < addLen; i++ {
|
||||
key, err := binpack.UnpackValue(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid diff, cannot read add key %d: %v", i, err)
|
||||
}
|
||||
val, err := binpack.UnpackValue(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid diff, cannot read add val %d: %v", i, err)
|
||||
}
|
||||
diff.ToAdd[string(key)] = val
|
||||
}
|
||||
removeLen, err := binpack.UnpackUIntAsInt(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid diff, cannot read remove length: %v", err)
|
||||
}
|
||||
for i := 0; i < removeLen; i++ {
|
||||
val, err := binpack.UnpackValue(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid diff, cannot read remove val %d: %v", i, err)
|
||||
}
|
||||
diff.ToRemove = append(diff.ToRemove, string(val))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (diff *MapDiffType) Decode_v0(diffBytes []byte) error {
|
||||
r := bytes.NewBuffer(diffBytes)
|
||||
version, err := binary.ReadUvarint(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid diff, cannot read version: %v", err)
|
||||
}
|
||||
if version != MapDiffVersion {
|
||||
if version != MapDiffVersion_0 {
|
||||
return fmt.Errorf("invalid diff, bad version: %d", version)
|
||||
}
|
||||
mapLen64, err := binary.ReadUvarint(r)
|
||||
@ -99,9 +172,9 @@ func (diff *MapDiffType) Decode(diffBytes []byte) error {
|
||||
}
|
||||
mapFields := fields[0 : 2*mapLen]
|
||||
removeFields := fields[2*mapLen:]
|
||||
diff.ToAdd = make(map[string]string)
|
||||
diff.ToAdd = make(map[string][]byte)
|
||||
for i := 0; i < len(mapFields); i += 2 {
|
||||
diff.ToAdd[string(mapFields[i])] = string(mapFields[i+1])
|
||||
diff.ToAdd[string(mapFields[i])] = mapFields[i+1]
|
||||
}
|
||||
for _, removeVal := range removeFields {
|
||||
if len(removeVal) == 0 {
|
||||
@ -112,7 +185,7 @@ func (diff *MapDiffType) Decode(diffBytes []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func MakeMapDiff(m1 map[string]string, m2 map[string]string) []byte {
|
||||
func MakeMapDiff(m1 map[string][]byte, m2 map[string][]byte) []byte {
|
||||
diff := makeMapDiff(m1, m2)
|
||||
if len(diff.ToAdd) == 0 && len(diff.ToRemove) == 0 {
|
||||
return nil
|
||||
@ -120,7 +193,7 @@ func MakeMapDiff(m1 map[string]string, m2 map[string]string) []byte {
|
||||
return diff.Encode()
|
||||
}
|
||||
|
||||
func ApplyMapDiff(oldMap map[string]string, diffBytes []byte) (map[string]string, error) {
|
||||
func ApplyMapDiff(oldMap map[string][]byte, diffBytes []byte) (map[string][]byte, error) {
|
||||
if len(diffBytes) == 0 {
|
||||
return oldMap, nil
|
||||
}
|
||||
|
@ -4,8 +4,12 @@
|
||||
package statediff
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
)
|
||||
|
||||
const Str1 = `
|
||||
@ -38,8 +42,8 @@ banana2
|
||||
coconut
|
||||
`
|
||||
|
||||
func testLineDiff(t *testing.T, str1 string, str2 string) {
|
||||
diffBytes := MakeLineDiff(str1, str2)
|
||||
func testLineDiff(t *testing.T, str1 string, str2 string, splitString string) {
|
||||
diffBytes := MakeLineDiff(str1, str2, splitString)
|
||||
fmt.Printf("diff-len: %d\n", len(diffBytes))
|
||||
out, err := ApplyLineDiff(str1, diffBytes)
|
||||
if err != nil {
|
||||
@ -57,35 +61,86 @@ func testLineDiff(t *testing.T, str1 string, str2 string) {
|
||||
}
|
||||
|
||||
func TestLineDiff(t *testing.T) {
|
||||
testLineDiff(t, Str1, Str2)
|
||||
testLineDiff(t, Str2, Str3)
|
||||
testLineDiff(t, Str1, Str3)
|
||||
testLineDiff(t, Str3, Str1)
|
||||
testLineDiff(t, Str3, Str4)
|
||||
testLineDiff(t, Str1, Str2, "\n")
|
||||
testLineDiff(t, Str2, Str3, "\n")
|
||||
testLineDiff(t, Str1, Str3, "\n")
|
||||
testLineDiff(t, Str3, Str1, "\n")
|
||||
testLineDiff(t, Str3, Str4, "\n")
|
||||
}
|
||||
|
||||
func strMapsEqual(m1 map[string]string, m2 map[string]string) bool {
|
||||
if len(m1) != len(m2) {
|
||||
return false
|
||||
func TestLineDiff0(t *testing.T) {
|
||||
var str1Arr []string = []string{"a", "b", "c", "d", "e"}
|
||||
var str2Arr []string = []string{"a", "e"}
|
||||
str1 := strings.Join(str1Arr, "\x00")
|
||||
str2 := strings.Join(str2Arr, "\x00")
|
||||
diffBytes := MakeLineDiff(str1, str2, "\x00")
|
||||
fmt.Printf("diff-len: %d\n", len(diffBytes))
|
||||
out, err := ApplyLineDiff(str1, diffBytes)
|
||||
if err != nil {
|
||||
t.Errorf("error in diff: %v", err)
|
||||
return
|
||||
}
|
||||
for key, val := range m1 {
|
||||
val2, ok := m2[key]
|
||||
if !ok || val != val2 {
|
||||
return false
|
||||
}
|
||||
if out != str2 {
|
||||
t.Errorf("bad diff output")
|
||||
}
|
||||
for key, val := range m2 {
|
||||
val2, ok := m1[key]
|
||||
if !ok || val != val2 {
|
||||
return false
|
||||
}
|
||||
diffBytes = MakeLineDiff(str2, str1, "\x00")
|
||||
fmt.Printf("diff-len: %d\n", len(diffBytes))
|
||||
out, err = ApplyLineDiff(str2, diffBytes)
|
||||
if err != nil {
|
||||
t.Errorf("error in diff: %v", err)
|
||||
return
|
||||
}
|
||||
if out != str1 {
|
||||
t.Errorf("bad diff output")
|
||||
}
|
||||
|
||||
diffBytes = MakeLineDiff(str1, str1, "\x00")
|
||||
if len(diffBytes) != 0 {
|
||||
t.Errorf("bad diff output (len should be 0)")
|
||||
}
|
||||
var diffVar LineDiffType
|
||||
diffVar.Decode(diffBytes)
|
||||
if len(diffVar.Lines) != 0 || len(diffVar.NewData) != 0 || diffVar.Version != LineDiffVersion {
|
||||
t.Errorf("bad diff output (for decoding nil)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLineDiffVersion0(t *testing.T) {
|
||||
var str1Arr []string = []string{"a", "b", "c", "d", "e"}
|
||||
var str2Arr []string = []string{"a", "e"}
|
||||
str1 := strings.Join(str1Arr, "\n")
|
||||
str2 := strings.Join(str2Arr, "\n")
|
||||
|
||||
var diff LineDiffType
|
||||
diff.Version = 0
|
||||
diff.SplitString = "\n"
|
||||
diff.Lines = []SingleLineEntry{{LineVal: 1, Run: 1}, {LineVal: 5, Run: 1}}
|
||||
encDiff0 := diff.Encode_v0()
|
||||
|
||||
var decDiff LineDiffType
|
||||
err := decDiff.Decode(encDiff0)
|
||||
if err != nil {
|
||||
t.Errorf("error decoding diff: %v\n", err)
|
||||
}
|
||||
if decDiff.Version != 0 {
|
||||
t.Errorf("bad version")
|
||||
}
|
||||
if decDiff.SplitString != "\n" {
|
||||
t.Errorf("bad split string")
|
||||
}
|
||||
out, err := ApplyLineDiff(str1, encDiff0)
|
||||
if err != nil {
|
||||
t.Errorf("error in diff: %v", err)
|
||||
return
|
||||
}
|
||||
if out != str2 {
|
||||
t.Errorf("bad diff output")
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestMapDiff(t *testing.T) {
|
||||
m1 := map[string]string{"a": "5", "b": "hello", "c": "mike"}
|
||||
m2 := map[string]string{"a": "5", "b": "goodbye", "d": "more"}
|
||||
m1 := map[string][]byte{"a": []byte("5"), "b": []byte("hello"), "c": []byte("mike")}
|
||||
m2 := map[string][]byte{"a": []byte("5"), "b": []byte("goodbye"), "d": []byte("more")}
|
||||
diffBytes := MakeMapDiff(m1, m2)
|
||||
fmt.Printf("mapdifflen: %d\n", len(diffBytes))
|
||||
var diff MapDiffType
|
||||
@ -95,8 +150,38 @@ func TestMapDiff(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("error applying map diff: %v", err)
|
||||
}
|
||||
if !strMapsEqual(m2, mcheck) {
|
||||
if !utilfn.ByteMapsEqual(m2, mcheck) {
|
||||
t.Errorf("maps not equal")
|
||||
}
|
||||
|
||||
// try v0
|
||||
mdiff := makeMapDiff(m1, m2)
|
||||
diffBytes = mdiff.Encode_v0()
|
||||
mcheck, err = ApplyMapDiff(m1, diffBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("error applying map diff: %v", err)
|
||||
}
|
||||
if !utilfn.ByteMapsEqual(m2, mcheck) {
|
||||
t.Errorf("maps not equal")
|
||||
}
|
||||
|
||||
diffBytes = MakeMapDiff(m1, m1)
|
||||
if len(diffBytes) != 0 {
|
||||
t.Errorf("bad diff output (len should be 0)")
|
||||
}
|
||||
mcheck, err = ApplyMapDiff(m1, diffBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("error applying map diff: %v", err)
|
||||
}
|
||||
if !utilfn.ByteMapsEqual(m1, mcheck) {
|
||||
t.Errorf("maps not equal")
|
||||
}
|
||||
fmt.Printf("%v\n", mcheck)
|
||||
}
|
||||
|
||||
func TestVarint(t *testing.T) {
|
||||
viBuf := make([]byte, 10)
|
||||
viLen := binary.PutVarint(viBuf, 1)
|
||||
fmt.Printf("%#v\n", viBuf[0:viLen])
|
||||
viLen = binary.PutUvarint(viBuf, 1)
|
||||
fmt.Printf("%#v\n", viBuf[0:viLen])
|
||||
}
|
||||
|
522
waveshell/pkg/utilfn/utilfn.go
Normal file
522
waveshell/pkg/utilfn/utilfn.go
Normal file
@ -0,0 +1,522 @@
|
||||
// Copyright 2023, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package utilfn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
var HexDigits = []byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'}
|
||||
|
||||
func GetStrArr(v interface{}, field string) []string {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
m, ok := v.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
fieldVal := m[field]
|
||||
if fieldVal == nil {
|
||||
return nil
|
||||
}
|
||||
iarr, ok := fieldVal.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
var sarr []string
|
||||
for _, iv := range iarr {
|
||||
if sv, ok := iv.(string); ok {
|
||||
sarr = append(sarr, sv)
|
||||
}
|
||||
}
|
||||
return sarr
|
||||
}
|
||||
|
||||
func GetBool(v interface{}, field string) bool {
|
||||
if v == nil {
|
||||
return false
|
||||
}
|
||||
m, ok := v.(map[string]interface{})
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
fieldVal := m[field]
|
||||
if fieldVal == nil {
|
||||
return false
|
||||
}
|
||||
bval, ok := fieldVal.(bool)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return bval
|
||||
}
|
||||
|
||||
var needsQuoteRe = regexp.MustCompile(`[^\w@%:,./=+-]`)
|
||||
|
||||
// minimum maxlen=6
|
||||
func ShellQuote(val string, forceQuote bool, maxLen int) string {
|
||||
if maxLen < 6 {
|
||||
maxLen = 6
|
||||
}
|
||||
rtn := val
|
||||
if needsQuoteRe.MatchString(val) {
|
||||
rtn = "'" + strings.ReplaceAll(val, "'", `'"'"'`) + "'"
|
||||
}
|
||||
if strings.HasPrefix(rtn, "\"") || strings.HasPrefix(rtn, "'") {
|
||||
if len(rtn) > maxLen {
|
||||
return rtn[0:maxLen-4] + "..." + rtn[0:1]
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
if forceQuote {
|
||||
if len(rtn) > maxLen-2 {
|
||||
return "\"" + rtn[0:maxLen-5] + "...\""
|
||||
}
|
||||
return "\"" + rtn + "\""
|
||||
} else {
|
||||
if len(rtn) > maxLen {
|
||||
return rtn[0:maxLen-3] + "..."
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
}
|
||||
|
||||
func EllipsisStr(s string, maxLen int) string {
|
||||
if maxLen < 4 {
|
||||
maxLen = 4
|
||||
}
|
||||
if len(s) > maxLen {
|
||||
return s[0:maxLen-3] + "..."
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func LongestPrefix(root string, strs []string) string {
|
||||
if len(strs) == 0 {
|
||||
return root
|
||||
}
|
||||
if len(strs) == 1 {
|
||||
comp := strs[0]
|
||||
if len(comp) >= len(root) && strings.HasPrefix(comp, root) {
|
||||
if strings.HasSuffix(comp, "/") {
|
||||
return strs[0]
|
||||
}
|
||||
return strs[0]
|
||||
}
|
||||
}
|
||||
lcp := strs[0]
|
||||
for i := 1; i < len(strs); i++ {
|
||||
s := strs[i]
|
||||
for j := 0; j < len(lcp); j++ {
|
||||
if j >= len(s) || lcp[j] != s[j] {
|
||||
lcp = lcp[0:j]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(lcp) < len(root) || !strings.HasPrefix(lcp, root) {
|
||||
return root
|
||||
}
|
||||
return lcp
|
||||
}
|
||||
|
||||
func ContainsStr(strs []string, test string) bool {
|
||||
for _, s := range strs {
|
||||
if s == test {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func IsPrefix(strs []string, test string) bool {
|
||||
for _, s := range strs {
|
||||
if len(s) > len(test) && strings.HasPrefix(s, test) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// sentinel value for StrWithPos.Pos to indicate no position
|
||||
const NoStrPos = -1
|
||||
|
||||
type StrWithPos struct {
|
||||
Str string `json:"str"`
|
||||
Pos int `json:"pos"` // this is a 'rune' position (not a byte position)
|
||||
}
|
||||
|
||||
func (sp StrWithPos) String() string {
|
||||
return strWithCursor(sp.Str, sp.Pos)
|
||||
}
|
||||
|
||||
func ParseToSP(s string) StrWithPos {
|
||||
idx := strings.Index(s, "[*]")
|
||||
if idx == -1 {
|
||||
return StrWithPos{Str: s, Pos: NoStrPos}
|
||||
}
|
||||
return StrWithPos{Str: s[0:idx] + s[idx+3:], Pos: utf8.RuneCountInString(s[0:idx])}
|
||||
}
|
||||
|
||||
func strWithCursor(str string, pos int) string {
|
||||
if pos == NoStrPos {
|
||||
return str
|
||||
}
|
||||
if pos < 0 {
|
||||
// invalid position
|
||||
return "[*]_" + str
|
||||
}
|
||||
if pos > len(str) {
|
||||
// invalid position
|
||||
return str + "_[*]"
|
||||
}
|
||||
if pos == len(str) {
|
||||
return str + "[*]"
|
||||
}
|
||||
var rtn []rune
|
||||
for _, ch := range str {
|
||||
if len(rtn) == pos {
|
||||
rtn = append(rtn, '[', '*', ']')
|
||||
}
|
||||
rtn = append(rtn, ch)
|
||||
}
|
||||
return string(rtn)
|
||||
}
|
||||
|
||||
func (sp StrWithPos) Prepend(str string) StrWithPos {
|
||||
return StrWithPos{Str: str + sp.Str, Pos: utf8.RuneCountInString(str) + sp.Pos}
|
||||
}
|
||||
|
||||
func (sp StrWithPos) Append(str string) StrWithPos {
|
||||
return StrWithPos{Str: sp.Str + str, Pos: sp.Pos}
|
||||
}
|
||||
|
||||
// returns base64 hash of data
|
||||
func Sha1Hash(data []byte) string {
|
||||
hvalRaw := sha1.Sum(data)
|
||||
hval := base64.StdEncoding.EncodeToString(hvalRaw[:])
|
||||
return hval
|
||||
}
|
||||
|
||||
func ChunkSlice[T any](s []T, chunkSize int) [][]T {
|
||||
var rtn [][]T
|
||||
for len(rtn) > 0 {
|
||||
if len(s) <= chunkSize {
|
||||
rtn = append(rtn, s)
|
||||
break
|
||||
}
|
||||
rtn = append(rtn, s[:chunkSize])
|
||||
s = s[chunkSize:]
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
var ErrOverflow = errors.New("integer overflow")
|
||||
|
||||
// Add two int values, returning an error if the result overflows.
|
||||
func AddInt(left, right int) (int, error) {
|
||||
if right > 0 {
|
||||
if left > math.MaxInt-right {
|
||||
return 0, ErrOverflow
|
||||
}
|
||||
} else {
|
||||
if left < math.MinInt-right {
|
||||
return 0, ErrOverflow
|
||||
}
|
||||
}
|
||||
return left + right, nil
|
||||
}
|
||||
|
||||
// Add a slice of ints, returning an error if the result overflows.
|
||||
func AddIntSlice(vals ...int) (int, error) {
|
||||
var rtn int
|
||||
for _, v := range vals {
|
||||
var err error
|
||||
rtn, err = AddInt(rtn, v)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func StrsEqual(s1arr []string, s2arr []string) bool {
|
||||
if len(s1arr) != len(s2arr) {
|
||||
return false
|
||||
}
|
||||
for i, s1 := range s1arr {
|
||||
s2 := s2arr[i]
|
||||
if s1 != s2 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func StrMapsEqual(m1 map[string]string, m2 map[string]string) bool {
|
||||
if len(m1) != len(m2) {
|
||||
return false
|
||||
}
|
||||
for key, val1 := range m1 {
|
||||
val2, found := m2[key]
|
||||
if !found || val1 != val2 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for key := range m2 {
|
||||
_, found := m1[key]
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func ByteMapsEqual(m1 map[string][]byte, m2 map[string][]byte) bool {
|
||||
if len(m1) != len(m2) {
|
||||
return false
|
||||
}
|
||||
for key, val1 := range m1 {
|
||||
val2, found := m2[key]
|
||||
if !found || !bytes.Equal(val1, val2) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for key := range m2 {
|
||||
_, found := m1[key]
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func GetOrderedStringerMapKeys[K interface {
|
||||
comparable
|
||||
fmt.Stringer
|
||||
}, V any](m map[K]V) []K {
|
||||
keyStrMap := make(map[K]string)
|
||||
keys := make([]K, 0, len(m))
|
||||
for key := range m {
|
||||
keys = append(keys, key)
|
||||
keyStrMap[key] = key.String()
|
||||
}
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
return keyStrMap[keys[i]] < keyStrMap[keys[j]]
|
||||
})
|
||||
return keys
|
||||
}
|
||||
|
||||
func GetOrderedMapKeys[V any](m map[string]V) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for key := range m {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
const (
|
||||
nullEncodeEscByte = '\\'
|
||||
nullEncodeSepByte = '|'
|
||||
nullEncodeEqByte = '='
|
||||
nullEncodeZeroByteEsc = '0'
|
||||
nullEncodeEscByteEsc = '\\'
|
||||
nullEncodeSepByteEsc = 's'
|
||||
nullEncodeEqByteEsc = 'e'
|
||||
)
|
||||
|
||||
func EncodeStringMap(m map[string]string) []byte {
|
||||
var buf bytes.Buffer
|
||||
for idx, key := range GetOrderedMapKeys(m) {
|
||||
val := m[key]
|
||||
buf.Write(NullEncodeStr(key))
|
||||
buf.WriteByte(nullEncodeEqByte)
|
||||
buf.Write(NullEncodeStr(val))
|
||||
if idx < len(m)-1 {
|
||||
buf.WriteByte(nullEncodeSepByte)
|
||||
}
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func DecodeStringMap(barr []byte) (map[string]string, error) {
|
||||
if len(barr) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var rtn = make(map[string]string)
|
||||
for _, b := range bytes.Split(barr, []byte{nullEncodeSepByte}) {
|
||||
keyVal := bytes.SplitN(b, []byte{nullEncodeEqByte}, 2)
|
||||
if len(keyVal) != 2 {
|
||||
return nil, fmt.Errorf("invalid null encoding: %s", string(b))
|
||||
}
|
||||
key, err := NullDecodeStr(keyVal[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
val, err := NullDecodeStr(keyVal[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rtn[key] = val
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func EncodeStringArray(arr []string) []byte {
|
||||
var buf bytes.Buffer
|
||||
for idx, s := range arr {
|
||||
buf.Write(NullEncodeStr(s))
|
||||
if idx < len(arr)-1 {
|
||||
buf.WriteByte(nullEncodeSepByte)
|
||||
}
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func DecodeStringArray(barr []byte) ([]string, error) {
|
||||
if len(barr) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var rtn []string
|
||||
for _, b := range bytes.Split(barr, []byte{nullEncodeSepByte}) {
|
||||
s, err := NullDecodeStr(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rtn = append(rtn, s)
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func EncodedStringArrayHasFirstKey(encoded []byte, firstKey string) bool {
|
||||
firstKeyBytes := NullEncodeStr(firstKey)
|
||||
if !bytes.HasPrefix(encoded, firstKeyBytes) {
|
||||
return false
|
||||
}
|
||||
if len(encoded) == len(firstKeyBytes) || encoded[len(firstKeyBytes)] == nullEncodeSepByte {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// encodes a string, removing null/zero bytes (and separators '|')
|
||||
// a zero byte is encoded as "\0", a '\' is encoded as "\\", sep is encoded as "\s"
|
||||
// allows for easy double splitting (first on \x00, and next on "|")
|
||||
func NullEncodeStr(s string) []byte {
|
||||
strBytes := []byte(s)
|
||||
if bytes.IndexByte(strBytes, 0) == -1 &&
|
||||
bytes.IndexByte(strBytes, nullEncodeEscByte) == -1 &&
|
||||
bytes.IndexByte(strBytes, nullEncodeSepByte) == -1 &&
|
||||
bytes.IndexByte(strBytes, nullEncodeEqByte) == -1 {
|
||||
return strBytes
|
||||
}
|
||||
var rtn []byte
|
||||
for _, b := range strBytes {
|
||||
if b == 0 {
|
||||
rtn = append(rtn, nullEncodeEscByte, nullEncodeZeroByteEsc)
|
||||
} else if b == nullEncodeEscByte {
|
||||
rtn = append(rtn, nullEncodeEscByte, nullEncodeEscByteEsc)
|
||||
} else if b == nullEncodeSepByte {
|
||||
rtn = append(rtn, nullEncodeEscByte, nullEncodeSepByteEsc)
|
||||
} else if b == nullEncodeEqByte {
|
||||
rtn = append(rtn, nullEncodeEscByte, nullEncodeEqByteEsc)
|
||||
} else {
|
||||
rtn = append(rtn, b)
|
||||
}
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func NullDecodeStr(barr []byte) (string, error) {
|
||||
if bytes.IndexByte(barr, nullEncodeEscByte) == -1 {
|
||||
return string(barr), nil
|
||||
}
|
||||
var rtn []byte
|
||||
for i := 0; i < len(barr); i++ {
|
||||
curByte := barr[i]
|
||||
if curByte == nullEncodeEscByte {
|
||||
i++
|
||||
nextByte := barr[i]
|
||||
if nextByte == nullEncodeZeroByteEsc {
|
||||
rtn = append(rtn, 0)
|
||||
} else if nextByte == nullEncodeEscByteEsc {
|
||||
rtn = append(rtn, nullEncodeEscByte)
|
||||
} else if nextByte == nullEncodeSepByteEsc {
|
||||
rtn = append(rtn, nullEncodeSepByte)
|
||||
} else if nextByte == nullEncodeEqByteEsc {
|
||||
rtn = append(rtn, nullEncodeEqByte)
|
||||
} else {
|
||||
// invalid encoding
|
||||
return "", fmt.Errorf("invalid null encoding: %d", nextByte)
|
||||
}
|
||||
} else {
|
||||
rtn = append(rtn, curByte)
|
||||
}
|
||||
}
|
||||
return string(rtn), nil
|
||||
}
|
||||
|
||||
func SortStringRunes(s string) string {
|
||||
runes := []rune(s)
|
||||
sort.Slice(runes, func(i, j int) bool {
|
||||
return runes[i] < runes[j]
|
||||
})
|
||||
return string(runes)
|
||||
}
|
||||
|
||||
// will overwrite m1 with m2's values
|
||||
func CombineMaps[V any](m1 map[string]V, m2 map[string]V) {
|
||||
for key, val := range m2 {
|
||||
m1[key] = val
|
||||
}
|
||||
}
|
||||
|
||||
// returns hex escaped string (\xNN for each byte)
|
||||
func ShellHexEscape(s string) string {
|
||||
var rtn []byte
|
||||
for _, ch := range []byte(s) {
|
||||
rtn = append(rtn, []byte(fmt.Sprintf("\\x%02x", ch))...)
|
||||
}
|
||||
return string(rtn)
|
||||
}
|
||||
|
||||
func GetMapKeys[K comparable, V any](m map[K]V) []K {
|
||||
var rtn []K
|
||||
for key := range m {
|
||||
rtn = append(rtn, key)
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
// combines string arrays and removes duplicates (returns a new array)
|
||||
func CombineStrArrays(sarr1 []string, sarr2 []string) []string {
|
||||
var rtn []string
|
||||
m := make(map[string]struct{})
|
||||
for _, s := range sarr1 {
|
||||
if _, found := m[s]; found {
|
||||
continue
|
||||
}
|
||||
m[s] = struct{}{}
|
||||
rtn = append(rtn, s)
|
||||
}
|
||||
for _, s := range sarr2 {
|
||||
if _, found := m[s]; found {
|
||||
continue
|
||||
}
|
||||
m[s] = struct{}{}
|
||||
rtn = append(rtn, s)
|
||||
}
|
||||
return rtn
|
||||
}
|
175
waveshell/pkg/utilfn/utilfn_test.go
Normal file
175
waveshell/pkg/utilfn/utilfn_test.go
Normal file
@ -0,0 +1,175 @@
|
||||
// Copyright 2023, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package utilfn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const Str1 = `
|
||||
hello
|
||||
line #2
|
||||
more
|
||||
stuff
|
||||
apple
|
||||
`
|
||||
|
||||
const Str2 = `
|
||||
line #2
|
||||
apple
|
||||
grapes
|
||||
banana
|
||||
`
|
||||
|
||||
const Str3 = `
|
||||
more
|
||||
stuff
|
||||
banana
|
||||
coconut
|
||||
`
|
||||
|
||||
func testDiff(t *testing.T, str1 string, str2 string) {
|
||||
diffBytes := MakeDiff(str1, str2)
|
||||
fmt.Printf("diff-len: %d\n", len(diffBytes))
|
||||
out, err := ApplyDiff(str1, diffBytes)
|
||||
if err != nil {
|
||||
t.Errorf("error in diff: %v", err)
|
||||
return
|
||||
}
|
||||
if out != str2 {
|
||||
t.Errorf("bad diff output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiff(t *testing.T) {
|
||||
testDiff(t, Str1, Str2)
|
||||
testDiff(t, Str2, Str3)
|
||||
testDiff(t, Str1, Str3)
|
||||
testDiff(t, Str3, Str1)
|
||||
}
|
||||
|
||||
func testArithmetic(t *testing.T, fn func() (int, error), shouldError bool, expected int) {
|
||||
retVal, err := fn()
|
||||
if err != nil {
|
||||
if !shouldError {
|
||||
t.Errorf("unexpected error")
|
||||
}
|
||||
return
|
||||
}
|
||||
if shouldError {
|
||||
t.Errorf("expected error")
|
||||
return
|
||||
}
|
||||
if retVal != expected {
|
||||
t.Errorf("wrong return value")
|
||||
}
|
||||
}
|
||||
|
||||
func testAddInt(t *testing.T, shouldError bool, expected int, a int, b int) {
|
||||
testArithmetic(t, func() (int, error) { return AddInt(a, b) }, shouldError, expected)
|
||||
}
|
||||
|
||||
func TestAddInt(t *testing.T) {
|
||||
testAddInt(t, false, 3, 1, 2)
|
||||
testAddInt(t, true, 0, 1, math.MaxInt)
|
||||
testAddInt(t, true, 0, math.MinInt, -1)
|
||||
testAddInt(t, false, math.MaxInt-1, math.MaxInt, -1)
|
||||
testAddInt(t, false, math.MinInt+1, math.MinInt, 1)
|
||||
testAddInt(t, false, math.MaxInt, math.MaxInt, 0)
|
||||
testAddInt(t, true, 0, math.MinInt, -1)
|
||||
}
|
||||
|
||||
func testAddIntSlice(t *testing.T, shouldError bool, expected int, vals ...int) {
|
||||
testArithmetic(t, func() (int, error) { return AddIntSlice(vals...) }, shouldError, expected)
|
||||
}
|
||||
|
||||
func TestAddIntSlice(t *testing.T) {
|
||||
testAddIntSlice(t, false, 0)
|
||||
testAddIntSlice(t, false, 1, 1)
|
||||
testAddIntSlice(t, false, 3, 1, 2)
|
||||
testAddIntSlice(t, false, 6, 1, 2, 3)
|
||||
testAddIntSlice(t, true, 0, 1, math.MaxInt)
|
||||
testAddIntSlice(t, true, 0, 1, 2, math.MaxInt)
|
||||
testAddIntSlice(t, true, 0, math.MaxInt, 2, 1)
|
||||
testAddIntSlice(t, false, math.MaxInt, 0, 0, math.MaxInt)
|
||||
testAddIntSlice(t, true, 0, math.MinInt, -1)
|
||||
testAddIntSlice(t, false, math.MaxInt, math.MaxInt-3, 1, 2)
|
||||
testAddIntSlice(t, true, 0, math.MaxInt-2, 1, 2)
|
||||
testAddIntSlice(t, false, math.MinInt, math.MinInt+3, -1, -2)
|
||||
testAddIntSlice(t, true, 0, math.MinInt+2, -1, -2)
|
||||
}
|
||||
|
||||
func testNullEncodeStr(t *testing.T, str string, expected string) {
|
||||
encoded := NullEncodeStr(str)
|
||||
decoded, err := NullDecodeStr(encoded)
|
||||
if err != nil {
|
||||
t.Errorf("error in null encoding: %v", err)
|
||||
} else if decoded != str {
|
||||
t.Errorf("bad null encoding")
|
||||
}
|
||||
if string(encoded) != expected {
|
||||
t.Errorf("bad null encoding, %q != %q", str, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNullEncodeStr(t *testing.T) {
|
||||
testNullEncodeStr(t, "", "")
|
||||
testNullEncodeStr(t, "hello", "hello")
|
||||
testNullEncodeStr(t, "hello\x00", "hello\\0")
|
||||
testNullEncodeStr(t, "abc|def", "abc\\sdef")
|
||||
testNullEncodeStr(t, "a|b\x00c\\d", "a\\sb\\0c\\\\d")
|
||||
testNullEncodeStr(t, "v==v", "v\\e\\ev")
|
||||
}
|
||||
|
||||
func testEncodeStringArray(t *testing.T, strs []string) {
|
||||
encoded := EncodeStringArray(strs)
|
||||
decoded, err := DecodeStringArray(encoded)
|
||||
if err != nil {
|
||||
t.Errorf("error in string array encoding: %v", err)
|
||||
} else if !StrsEqual(strs, decoded) {
|
||||
t.Errorf("bad string array encoding: %#v != %#v", strs, decoded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeStringArray(t *testing.T) {
|
||||
testEncodeStringArray(t, nil)
|
||||
testEncodeStringArray(t, []string{})
|
||||
testEncodeStringArray(t, []string{"hello"})
|
||||
testEncodeStringArray(t, []string{"hello", "world=bar"})
|
||||
testEncodeStringArray(t, []string{"hello", "wor\x00ld", "fo|\\o", "N\\\x00|||ul==l"})
|
||||
}
|
||||
|
||||
func testEncodeStringMap(t *testing.T, m map[string]string) {
|
||||
encoded := EncodeStringMap(m)
|
||||
decoded, err := DecodeStringMap(encoded)
|
||||
if err != nil {
|
||||
t.Errorf("error in string map encoding: %v", err)
|
||||
} else if !StrMapsEqual(m, decoded) {
|
||||
t.Errorf("bad string map encoding: %#v != %#v", m, decoded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeStringMap(t *testing.T) {
|
||||
testEncodeStringMap(t, nil)
|
||||
testEncodeStringMap(t, map[string]string{})
|
||||
testEncodeStringMap(t, map[string]string{"hello": "world"})
|
||||
testEncodeStringMap(t, map[string]string{"hello": "world", "foo": "bar"})
|
||||
testEncodeStringMap(t, map[string]string{"hello": "world", "fo=o": "b=ar", "a|b": "c\\d"})
|
||||
testEncodeStringMap(t, map[string]string{"hello\x00|": "w\x00orld", "foo": "bar", "a|b": "c\\d", "v==v": "v\\e\\ev"})
|
||||
}
|
||||
|
||||
func testShellHexEscape(t *testing.T, s string, expected string) {
|
||||
encoded := ShellHexEscape(s)
|
||||
if encoded != expected {
|
||||
t.Errorf("bad shell hex encoding, %q != %q", encoded, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShellHexEscape(t *testing.T) {
|
||||
testShellHexEscape(t, "", "")
|
||||
testShellHexEscape(t, "a", `\x61`)
|
||||
testShellHexEscape(t, "\x00\x01abc\x00", `\x00\x01\x61\x62\x63\x00`)
|
||||
}
|
@ -30,6 +30,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/server"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/cmdrunner"
|
||||
@ -802,6 +803,7 @@ func doShutdown(reason string) {
|
||||
|
||||
func main() {
|
||||
scbase.BuildTime = BuildTime
|
||||
base.ProcessType = base.ProcessType_WaveSrv
|
||||
|
||||
if len(os.Args) >= 2 && os.Args[1] == "--test" {
|
||||
log.Printf("running test fn\n")
|
||||
|
2
wavesrv/db/migrations/000030_zsh_support.down.sql
Normal file
2
wavesrv/db/migrations/000030_zsh_support.down.sql
Normal file
@ -0,0 +1,2 @@
|
||||
ALTER TABLE remote_instance DROP COLUMN shelltype;
|
||||
ALTER TABLE remote DROP COLUMN shellpref;
|
2
wavesrv/db/migrations/000030_zsh_support.up.sql
Normal file
2
wavesrv/db/migrations/000030_zsh_support.up.sql
Normal file
@ -0,0 +1,2 @@
|
||||
ALTER TABLE remote_instance ADD COLUMN shelltype varchar(20) NOT NULL DEFAULT 'bash';
|
||||
ALTER TABLE remote ADD COLUMN shellpref varchar(20) NOT NULL DEFAULT 'detect';
|
@ -27,7 +27,10 @@ import (
|
||||
"github.com/kevinburke/ssh_config"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellutil"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/comp"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/dbutil"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/pcloud"
|
||||
@ -37,7 +40,6 @@ import (
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
@ -243,6 +245,7 @@ func init() {
|
||||
registerCmdFn("chat", OpenAICommand)
|
||||
|
||||
registerCmdFn("_killserver", KillServerCommand)
|
||||
registerCmdFn("_dumpstate", DumpStateCommand)
|
||||
|
||||
registerCmdFn("set", SetCommand)
|
||||
|
||||
@ -1200,9 +1203,8 @@ type RemoteEditArgs struct {
|
||||
ConnectMode string
|
||||
Alias string
|
||||
AutoInstall bool
|
||||
SSHPassword string
|
||||
SSHKeyFile string
|
||||
Color string
|
||||
ShellPref string
|
||||
EditMap map[string]interface{}
|
||||
}
|
||||
|
||||
@ -1286,6 +1288,16 @@ func parseRemoteEditArgs(isNew bool, pk *scpacket.FeCommandPacketType, isLocal b
|
||||
return nil, fmt.Errorf("invalid alias format")
|
||||
}
|
||||
}
|
||||
var shellPref string
|
||||
if isNew {
|
||||
shellPref = sstore.ShellTypePref_Detect
|
||||
}
|
||||
if pk.Kwargs["shellpref"] != "" {
|
||||
shellPref = pk.Kwargs["shellpref"]
|
||||
}
|
||||
if shellPref != "" && shellPref != packet.ShellType_bash && shellPref != packet.ShellType_zsh && shellPref != sstore.ShellTypePref_Detect {
|
||||
return nil, fmt.Errorf("invalid shellpref %q, must be %s", shellPref, formatStrs([]string{packet.ShellType_bash, packet.ShellType_zsh, sstore.ShellTypePref_Detect}, "or", false))
|
||||
}
|
||||
var connectMode string
|
||||
if isNew {
|
||||
connectMode = sstore.ConnectModeAuto
|
||||
@ -1340,6 +1352,9 @@ func parseRemoteEditArgs(isNew bool, pk *scpacket.FeCommandPacketType, isLocal b
|
||||
}
|
||||
editMap[sstore.RemoteField_SSHPassword] = sshPassword
|
||||
}
|
||||
if _, found := pk.Kwargs["shellpref"]; found {
|
||||
editMap[sstore.RemoteField_ShellPref] = shellPref
|
||||
}
|
||||
|
||||
return &RemoteEditArgs{
|
||||
SSHOpts: sshOpts,
|
||||
@ -1347,10 +1362,9 @@ func parseRemoteEditArgs(isNew bool, pk *scpacket.FeCommandPacketType, isLocal b
|
||||
Alias: alias,
|
||||
AutoInstall: true,
|
||||
CanonicalName: canonicalName,
|
||||
SSHKeyFile: keyFile,
|
||||
SSHPassword: sshPassword,
|
||||
Color: color,
|
||||
EditMap: editMap,
|
||||
ShellPref: shellPref,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -1375,6 +1389,7 @@ func RemoteNewCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (ss
|
||||
AutoInstall: editArgs.AutoInstall,
|
||||
SSHOpts: editArgs.SSHOpts,
|
||||
SSHConfigSrc: sstore.SSHConfigSrcTypeManual,
|
||||
ShellPref: editArgs.ShellPref,
|
||||
}
|
||||
if editArgs.Color != "" {
|
||||
r.RemoteOpts = &sstore.RemoteOptsType{Color: editArgs.Color}
|
||||
@ -1881,7 +1896,7 @@ func crShowCommand(ctx context.Context, pk *scpacket.FeCommandPacketType, ids re
|
||||
if riBaseMap[remoteId] {
|
||||
continue
|
||||
}
|
||||
feState := msh.GetDefaultFeState()
|
||||
feState := msh.GetDefaultFeState(msh.GetShellPref())
|
||||
if feState == nil {
|
||||
continue
|
||||
}
|
||||
@ -2380,7 +2395,7 @@ func makeStaticCmd(ctx context.Context, metaCmd string, ids resolvedIds, cmdStr
|
||||
CmdStr: cmdStr,
|
||||
RawCmdStr: cmdStr,
|
||||
Remote: ids.Remote.RemotePtr,
|
||||
TermOpts: sstore.TermOpts{Rows: shexec.DefaultTermRows, Cols: shexec.DefaultTermCols, FlexRows: true, MaxPtySize: remote.DefaultMaxPtySize},
|
||||
TermOpts: sstore.TermOpts{Rows: shellutil.DefaultTermRows, Cols: shellutil.DefaultTermCols, FlexRows: true, MaxPtySize: remote.DefaultMaxPtySize},
|
||||
Status: sstore.CmdStatusDone,
|
||||
RunOut: nil,
|
||||
}
|
||||
@ -2977,23 +2992,31 @@ func SessionCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (ssto
|
||||
}
|
||||
|
||||
func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.UpdatePacket, error) {
|
||||
ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen)
|
||||
ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen|R_Remote)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
initPk, err := ids.Remote.MShell.ReInit(ctx)
|
||||
shellType := ids.Remote.ShellType
|
||||
if pk.Kwargs["shell"] != "" {
|
||||
shellArg := pk.Kwargs["shell"]
|
||||
if shellArg != packet.ShellType_bash && shellArg != packet.ShellType_zsh {
|
||||
return nil, fmt.Errorf("/reset invalid shell type %q", shellArg)
|
||||
}
|
||||
shellType = shellArg
|
||||
}
|
||||
ssPk, err := ids.Remote.MShell.ReInit(ctx, shellType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if initPk == nil || initPk.State == nil {
|
||||
if ssPk == nil || ssPk.State == nil {
|
||||
return nil, fmt.Errorf("invalid initpk received from remote (no remote state)")
|
||||
}
|
||||
feState := sstore.FeStateFromShellState(initPk.State)
|
||||
remoteInst, err := sstore.UpdateRemoteState(ctx, ids.SessionId, ids.ScreenId, ids.Remote.RemotePtr, feState, initPk.State, nil)
|
||||
feState := sstore.FeStateFromShellState(ssPk.State)
|
||||
remoteInst, err := sstore.UpdateRemoteState(ctx, ids.SessionId, ids.ScreenId, ids.Remote.RemotePtr, feState, ssPk.State, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outputStr := "reset remote state"
|
||||
outputStr := fmt.Sprintf("reset remote state (shell:%s)", ssPk.State.GetShellType())
|
||||
cmd, err := makeStaticCmd(ctx, "reset", ids, pk.GetRawStr(), []byte(outputStr))
|
||||
if err != nil {
|
||||
// TODO tricky error since the command was a success, but we can't show the output
|
||||
@ -4199,6 +4222,20 @@ func KillServerCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (s
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func DumpStateCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.UpdatePacket, error) {
|
||||
ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen|R_Remote)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
currentState, err := sstore.GetFullState(ctx, *ids.Remote.StatePtr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting state: %v", err)
|
||||
}
|
||||
feState := sstore.FeStateFromShellState(currentState)
|
||||
shellenv.DumpVarMapFromState(currentState)
|
||||
return sstore.InfoMsgUpdate("current connection state sent to log. festate: %s", dbutil.QuickJson(feState)), nil
|
||||
}
|
||||
|
||||
func ClientCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.UpdatePacket, error) {
|
||||
return nil, fmt.Errorf("/client requires a subcommand: %s", formatStrs([]string{"show", "set"}, "or", false))
|
||||
}
|
||||
|
@ -36,6 +36,7 @@ type ResolvedRemote struct {
|
||||
MShell *remote.MShellProc
|
||||
RState remote.RemoteRuntimeState
|
||||
RemoteCopy *sstore.RemoteType
|
||||
ShellType string
|
||||
StatePtr *sstore.ShellStatePtr
|
||||
FeState map[string]string
|
||||
}
|
||||
@ -472,6 +473,7 @@ func ResolveRemoteFromPtr(ctx context.Context, rptr *sstore.RemotePtrType, sessi
|
||||
RemoteCopy: &rcopy,
|
||||
StatePtr: nil,
|
||||
FeState: nil,
|
||||
ShellType: "",
|
||||
}
|
||||
if sessionId != "" && screenId != "" {
|
||||
ri, err := sstore.GetRemoteInstance(ctx, sessionId, screenId, *rptr)
|
||||
@ -480,11 +482,13 @@ func ResolveRemoteFromPtr(ctx context.Context, rptr *sstore.RemotePtrType, sessi
|
||||
// continue with state set to nil
|
||||
} else {
|
||||
if ri == nil {
|
||||
rtn.StatePtr = msh.GetDefaultStatePtr()
|
||||
rtn.FeState = msh.GetDefaultFeState()
|
||||
rtn.ShellType = msh.GetShellPref()
|
||||
rtn.StatePtr = msh.GetDefaultStatePtr(rtn.ShellType)
|
||||
rtn.FeState = msh.GetDefaultFeState(rtn.ShellType)
|
||||
} else {
|
||||
rtn.StatePtr = &sstore.ShellStatePtr{BaseHash: ri.StateBaseHash, DiffHashArr: ri.StateDiffHashArr}
|
||||
rtn.FeState = ri.FeState
|
||||
rtn.ShellType = ri.ShellType
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -9,10 +9,10 @@ import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellapi"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/simpleexpand"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
|
||||
"mvdan.cc/sh/v3/expand"
|
||||
"mvdan.cc/sh/v3/syntax"
|
||||
)
|
||||
@ -141,6 +141,12 @@ func onlyRawArgs(metaCmd string, metaSubCmd string) bool {
|
||||
return CmdParseOverrides[metaCmd] == CmdParseTypeRaw
|
||||
}
|
||||
|
||||
var waveValidIdentifierRe = regexp.MustCompile("^[a-zA-Z_][a-zA-Z0-9_]*$")
|
||||
|
||||
func isValidWaveParamName(name string) bool {
|
||||
return waveValidIdentifierRe.MatchString(name)
|
||||
}
|
||||
|
||||
func setBracketArgs(argMap map[string]string, bracketStr string) error {
|
||||
bracketStr = strings.TrimSpace(bracketStr)
|
||||
if bracketStr == "" {
|
||||
@ -160,7 +166,7 @@ func setBracketArgs(argMap map[string]string, bracketStr string) error {
|
||||
varName = litStr[0:eqIdx]
|
||||
varVal = litStr[eqIdx+1:]
|
||||
}
|
||||
if !shexec.IsValidBashIdentifier(varName) {
|
||||
if !isValidWaveParamName(varName) {
|
||||
wordErr = fmt.Errorf("invalid identifier %s in bracket args", utilfn.ShellQuote(varName, true, 20))
|
||||
return false
|
||||
}
|
||||
@ -183,14 +189,29 @@ var literalRtnStateCommands = []string{
|
||||
".",
|
||||
"source",
|
||||
"unset",
|
||||
"unsetopt",
|
||||
"cd",
|
||||
"alias",
|
||||
"unalias",
|
||||
"deactivate",
|
||||
"eval",
|
||||
"asdf",
|
||||
"sdk",
|
||||
"nvm",
|
||||
"virtualenv",
|
||||
"builtin",
|
||||
"typeset",
|
||||
"declare",
|
||||
"float",
|
||||
"functions",
|
||||
"integer",
|
||||
"local",
|
||||
"readonly",
|
||||
"unfunction",
|
||||
"shopt",
|
||||
"enable",
|
||||
"disable",
|
||||
"function",
|
||||
}
|
||||
|
||||
func getCallExprLitArg(callExpr *syntax.CallExpr, argNum int) string {
|
||||
@ -235,14 +256,13 @@ func isRtnStateCmd(cmd syntax.Command) bool {
|
||||
if arg0 != "" && utilfn.ContainsStr(literalRtnStateCommands, arg0) {
|
||||
return true
|
||||
}
|
||||
arg1 := getCallExprLitArg(callExpr, 1)
|
||||
if arg0 == "git" {
|
||||
arg1 := getCallExprLitArg(callExpr, 1)
|
||||
if arg1 == "checkout" || arg1 == "switch" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if arg0 == "conda" {
|
||||
arg1 := getCallExprLitArg(callExpr, 1)
|
||||
if arg1 == "activate" || arg1 == "deactivate" {
|
||||
return true
|
||||
}
|
||||
@ -253,12 +273,30 @@ func isRtnStateCmd(cmd syntax.Command) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func checkSimpleRtnStateCmd(cmdStr string) bool {
|
||||
cmdStr = strings.TrimSpace(cmdStr)
|
||||
if strings.HasPrefix(cmdStr, "function ") {
|
||||
return true
|
||||
}
|
||||
firstSpace := strings.Index(cmdStr, " ")
|
||||
if firstSpace != -1 {
|
||||
firstWord := strings.TrimSpace(cmdStr[:firstSpace])
|
||||
if strings.HasSuffix(firstWord, "()") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// detects: export, declare, ., source, X=1, unset
|
||||
func IsReturnStateCommand(cmdStr string) bool {
|
||||
cmdReader := strings.NewReader(cmdStr)
|
||||
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
|
||||
file, err := parser.Parse(cmdReader, "cmd")
|
||||
if err != nil {
|
||||
if checkSimpleRtnStateCmd(cmdStr) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
for _, stmt := range file.Stmts {
|
||||
@ -353,7 +391,7 @@ func EvalMetaCommand(ctx context.Context, origPk *scpacket.FeCommandPacketType)
|
||||
return nil, fmt.Errorf("parsing metacmd, position %v", err)
|
||||
}
|
||||
envMap := make(map[string]string) // later we can add vars like session, screen, remote, and user
|
||||
cfg := shexec.GetParserConfig(envMap)
|
||||
cfg := shellapi.GetParserConfig(envMap)
|
||||
// process arguments
|
||||
for idx, w := range words {
|
||||
literalVal, err := expand.Literal(cfg, w)
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellutil"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/remote"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
|
||||
@ -87,15 +88,15 @@ func GetUITermOpts(winSize *packet.WinSize, ptermStr string) (*packet.TermOpts,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
termOpts := &packet.TermOpts{Rows: shexec.DefaultTermRows, Cols: shexec.DefaultTermCols, Term: remote.DefaultTerm, MaxPtySize: shexec.DefaultMaxPtySize}
|
||||
termOpts := &packet.TermOpts{Rows: shellutil.DefaultTermRows, Cols: shellutil.DefaultTermCols, Term: remote.DefaultTerm, MaxPtySize: shexec.DefaultMaxPtySize}
|
||||
if winSize == nil {
|
||||
winSize = &packet.WinSize{Rows: shexec.DefaultTermRows, Cols: shexec.DefaultTermCols}
|
||||
winSize = &packet.WinSize{Rows: shellutil.DefaultTermRows, Cols: shellutil.DefaultTermCols}
|
||||
}
|
||||
if winSize.Rows == 0 {
|
||||
winSize.Rows = shexec.DefaultTermRows
|
||||
winSize.Rows = shellutil.DefaultTermRows
|
||||
}
|
||||
if winSize.Cols == 0 {
|
||||
winSize.Cols = shexec.DefaultTermCols
|
||||
winSize.Cols = shellutil.DefaultTermCols
|
||||
}
|
||||
if opts.Rows == PTermMax {
|
||||
termOpts.Rows = winSize.Rows
|
||||
|
@ -15,9 +15,9 @@ import (
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/simpleexpand"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/shparse"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
|
||||
"mvdan.cc/sh/v3/syntax"
|
||||
)
|
||||
|
||||
@ -482,8 +482,8 @@ func splitCompWord(p *CompPoint) {
|
||||
|
||||
w1 := ParsedWord{Offset: w.Offset, Prefix: w.Prefix[:prefixPos]}
|
||||
w2 := ParsedWord{Offset: w.Offset + prefixPos, Prefix: w.Prefix[prefixPos:], Word: w.Word, PartialWord: w.PartialWord}
|
||||
p.CompWord = p.CompWord // the same (w1)
|
||||
p.CompWordPos = 0 // will be at 0 since w1 has a word length of 0
|
||||
// p.CompWord = p.CompWord // the same (w1)
|
||||
p.CompWordPos = 0 // will be at 0 since w1 has a word length of 0
|
||||
var newWords []ParsedWord
|
||||
if p.CompWord > 0 {
|
||||
newWords = append(newWords, p.Words[0:p.CompWord]...)
|
||||
|
@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
)
|
||||
|
||||
func parseToSP(s string) utilfn.StrWithPos {
|
||||
|
@ -10,8 +10,8 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/remote"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
|
||||
)
|
||||
|
||||
var globalLock = &sync.Mutex{}
|
||||
|
@ -12,7 +12,7 @@ import (
|
||||
"io"
|
||||
"reflect"
|
||||
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
ccp "golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
|
@ -27,8 +27,12 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/server"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellapi"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/statediff"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
|
||||
@ -44,6 +48,18 @@ const RemoteTermCols = 80
|
||||
const PtyReadBufSize = 100
|
||||
const RemoteConnectTimeout = 15 * time.Second
|
||||
|
||||
var envVarsToStrip map[string]bool = map[string]bool{
|
||||
"PROMPT": true,
|
||||
"PROMPT_VERSION": true,
|
||||
"MSHELL": true,
|
||||
"MSHELL_VERSION": true,
|
||||
"WAVETERM": true,
|
||||
"WAVETERM_VERSION": true,
|
||||
"TERM_PROGRAM": true,
|
||||
"TERM_PROGRAM_VERSION": true,
|
||||
"TERM_SESSION_ID": true,
|
||||
}
|
||||
|
||||
// we add this ping packet to the MShellServer Commands in order to deal with spurious SSH output
|
||||
// basically we guarantee the parser will see a valid packet (either an init error or a ping)
|
||||
// so we can pass ignoreUntilValid to PacketParser
|
||||
@ -120,9 +136,9 @@ type MShellProc struct {
|
||||
PtyBuffer *circbuf.Buffer
|
||||
MakeClientCancelFn context.CancelFunc
|
||||
MakeClientDeadline *time.Time
|
||||
StateMap map[string]*packet.ShellState // sha1->state
|
||||
CurrentState string // sha1
|
||||
StateMap *server.ShellStateMap
|
||||
NumTryConnect int
|
||||
InitPkShellType string
|
||||
|
||||
// install
|
||||
InstallStatus string
|
||||
@ -159,32 +175,26 @@ func (msh *MShellProc) GetStatus() string {
|
||||
return msh.Status
|
||||
}
|
||||
|
||||
func (msh *MShellProc) GetDefaultState() *packet.ShellState {
|
||||
msh.Lock.Lock()
|
||||
defer msh.Lock.Unlock()
|
||||
return msh.StateMap[msh.CurrentState]
|
||||
func (msh *MShellProc) GetDefaultState(shellType string) *packet.ShellState {
|
||||
_, state := msh.StateMap.GetCurrentState(shellType)
|
||||
return state
|
||||
}
|
||||
|
||||
func (msh *MShellProc) GetDefaultStatePtr() *sstore.ShellStatePtr {
|
||||
func (msh *MShellProc) GetDefaultStatePtr(shellType string) *sstore.ShellStatePtr {
|
||||
msh.Lock.Lock()
|
||||
defer msh.Lock.Unlock()
|
||||
if msh.CurrentState == "" {
|
||||
hash, _ := msh.StateMap.GetCurrentState(shellType)
|
||||
if hash == "" {
|
||||
return nil
|
||||
}
|
||||
return &sstore.ShellStatePtr{BaseHash: msh.CurrentState}
|
||||
return &sstore.ShellStatePtr{BaseHash: hash}
|
||||
}
|
||||
|
||||
func (msh *MShellProc) GetDefaultFeState() map[string]string {
|
||||
state := msh.GetDefaultState()
|
||||
func (msh *MShellProc) GetDefaultFeState(shellType string) map[string]string {
|
||||
state := msh.GetDefaultState(shellType)
|
||||
return sstore.FeStateFromShellState(state)
|
||||
}
|
||||
|
||||
func (msh *MShellProc) GetStateByHash(hval string) *packet.ShellState {
|
||||
msh.Lock.Lock()
|
||||
defer msh.Lock.Unlock()
|
||||
return msh.StateMap[hval]
|
||||
}
|
||||
|
||||
func (msh *MShellProc) GetRemoteId() string {
|
||||
msh.Lock.Lock()
|
||||
defer msh.Lock.Unlock()
|
||||
@ -494,7 +504,19 @@ func (msh *MShellProc) tryAutoInstall() {
|
||||
go msh.RunInstall()
|
||||
}
|
||||
|
||||
// if msh.IsConnected() then GetShellPref() should return a valid shell
|
||||
// if msh is not connected, then InitPkShellType might be empty
|
||||
func (msh *MShellProc) GetShellPref() string {
|
||||
msh.Lock.Lock()
|
||||
defer msh.Lock.Unlock()
|
||||
if msh.Remote.ShellPref == sstore.ShellTypePref_Detect {
|
||||
return msh.InitPkShellType
|
||||
}
|
||||
return msh.Remote.ShellPref
|
||||
}
|
||||
|
||||
func (msh *MShellProc) GetRemoteRuntimeState() RemoteRuntimeState {
|
||||
shellPref := msh.GetShellPref()
|
||||
msh.Lock.Lock()
|
||||
defer msh.Lock.Unlock()
|
||||
state := RemoteRuntimeState{
|
||||
@ -514,6 +536,8 @@ func (msh *MShellProc) GetRemoteRuntimeState() RemoteRuntimeState {
|
||||
Local: msh.Remote.Local,
|
||||
NoInitPk: msh.ErrNoInitPk,
|
||||
AuthType: sstore.RemoteAuthTypeNone,
|
||||
ShellPref: msh.Remote.ShellPref,
|
||||
DefaultShellType: shellPref,
|
||||
}
|
||||
if msh.Remote.SSHOpts != nil {
|
||||
state.AuthType = msh.Remote.SSHOpts.GetAuthType()
|
||||
@ -580,7 +604,7 @@ func (msh *MShellProc) GetRemoteRuntimeState() RemoteRuntimeState {
|
||||
vars["besthost"] = vars["remotehost"]
|
||||
vars["bestshorthost"] = vars["remoteshorthost"]
|
||||
}
|
||||
curState := msh.StateMap[msh.CurrentState]
|
||||
_, curState := msh.StateMap.GetCurrentState(shellPref)
|
||||
if curState != nil {
|
||||
state.DefaultFeState = sstore.FeStateFromShellState(curState)
|
||||
vars["cwd"] = curState.Cwd
|
||||
@ -601,6 +625,7 @@ func (msh *MShellProc) GetRemoteRuntimeState() RemoteRuntimeState {
|
||||
vars["isroot"] = "1"
|
||||
}
|
||||
state.RemoteVars = vars
|
||||
state.ActiveShells = msh.StateMap.GetShells()
|
||||
return state
|
||||
}
|
||||
|
||||
@ -622,21 +647,6 @@ func GetAllRemoteRuntimeState() []RemoteRuntimeState {
|
||||
return rtn
|
||||
}
|
||||
|
||||
func GetDefaultRemoteStateById(remoteId string) (*packet.ShellState, error) {
|
||||
remote := GetRemoteById(remoteId)
|
||||
if remote == nil {
|
||||
return nil, fmt.Errorf("remote not found")
|
||||
}
|
||||
if !remote.IsConnected() {
|
||||
return nil, fmt.Errorf("remote not connected")
|
||||
}
|
||||
state := remote.GetDefaultState()
|
||||
if state == nil {
|
||||
return nil, fmt.Errorf("could not get default remote state")
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func MakeMShell(r *sstore.RemoteType) *MShellProc {
|
||||
buf, err := circbuf.NewBuffer(CircBufSize)
|
||||
if err != nil {
|
||||
@ -651,7 +661,7 @@ func MakeMShell(r *sstore.RemoteType) *MShellProc {
|
||||
InstallStatus: StatusDisconnected,
|
||||
RunningCmds: make(map[base.CommandKey]RunCmdType),
|
||||
PendingStateCmds: make(map[pendingStateKey]base.CommandKey),
|
||||
StateMap: make(map[string]*packet.ShellState),
|
||||
StateMap: server.MakeShellStateMap(),
|
||||
}
|
||||
rtn.WriteToPtyBuffer("console for connection [%s]\n", r.GetName())
|
||||
return rtn
|
||||
@ -999,11 +1009,16 @@ func (msh *MShellProc) RunInstall() {
|
||||
msh.WriteToPtyBuffer("*error: cannot install on remote that is already trying to install, cancel current install to try again\n")
|
||||
return
|
||||
}
|
||||
sapi, err := shellapi.MakeShellApi(packet.ShellType_bash)
|
||||
if err != nil {
|
||||
msh.WriteToPtyBuffer("*error: %v\n", err)
|
||||
return
|
||||
}
|
||||
msh.WriteToPtyBuffer("installing mshell %s to %s...\n", scbase.MShellVersion, remoteCopy.RemoteCanonicalName)
|
||||
sshOpts := convertSSHOpts(remoteCopy.SSHOpts)
|
||||
sshOpts.SSHErrorsToTty = true
|
||||
cmdStr := shexec.MakeInstallCommandStr()
|
||||
ecmd := sshOpts.MakeSSHExecCmd(cmdStr)
|
||||
ecmd := sshOpts.MakeSSHExecCmd(cmdStr, sapi)
|
||||
cmdPty, err := msh.addControllingTty(ecmd)
|
||||
if err != nil {
|
||||
statusErr := fmt.Errorf("cannot attach controlling tty to mshell install command: %w", err)
|
||||
@ -1084,12 +1099,20 @@ func getStateVarsFromInitPk(initPk *packet.InitPacketType) map[string]string {
|
||||
rtn["remoteuser"] = initPk.User
|
||||
rtn["remotehost"] = initPk.HostName
|
||||
rtn["remoteuname"] = initPk.UName
|
||||
rtn["shelltype"] = initPk.Shell
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (msh *MShellProc) ReInit(ctx context.Context) (*packet.InitPacketType, error) {
|
||||
func (msh *MShellProc) ReInit(ctx context.Context, shellType string) (*packet.ShellStatePacketType, error) {
|
||||
if !msh.IsConnected() {
|
||||
return nil, fmt.Errorf("cannot reinit, remote is not connected")
|
||||
}
|
||||
if shellType != packet.ShellType_bash && shellType != packet.ShellType_zsh {
|
||||
return nil, fmt.Errorf("invalid shell type %q", shellType)
|
||||
}
|
||||
reinitPk := packet.MakeReInitPacket()
|
||||
reinitPk.ReqId = uuid.New().String()
|
||||
reinitPk.ShellType = shellType
|
||||
resp, err := msh.PacketRpcRaw(ctx, reinitPk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -1097,21 +1120,26 @@ func (msh *MShellProc) ReInit(ctx context.Context) (*packet.InitPacketType, erro
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("no response")
|
||||
}
|
||||
initPk, ok := resp.(*packet.InitPacketType)
|
||||
ssPk, ok := resp.(*packet.ShellStatePacketType)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid reinit response (not an initpacket): %T", resp)
|
||||
if respPk, ok := resp.(*packet.ResponsePacketType); ok && respPk.Error != "" {
|
||||
return nil, fmt.Errorf("error reinitializing remote: %s", respPk.Error)
|
||||
}
|
||||
return nil, fmt.Errorf("invalid reinit response (not an shellstate packet): %T", resp)
|
||||
}
|
||||
if initPk.State == nil {
|
||||
return nil, fmt.Errorf("invalid reinit response initpk does not contain remote state")
|
||||
if ssPk.State == nil {
|
||||
return nil, fmt.Errorf("invalid reinit response shellstate packet does not contain remote state")
|
||||
}
|
||||
hval := initPk.State.GetHashVal(false)
|
||||
sstore.StoreStateBase(ctx, initPk.State)
|
||||
msh.WithLock(func() {
|
||||
msh.CurrentState = hval
|
||||
msh.StateMap[hval] = initPk.State
|
||||
})
|
||||
msh.updateRemoteStateVars(ctx, msh.RemoteId, initPk)
|
||||
return initPk, nil
|
||||
// TODO: maybe we don't need to save statebase here. should be possible to save it on demand
|
||||
// when it is actually used. complication from other functions that try to get the statebase
|
||||
// from the DB. probably need to route those through MShellProc.
|
||||
err = sstore.StoreStateBase(ctx, ssPk.State)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error storing remote state: %w", err)
|
||||
}
|
||||
msh.StateMap.SetCurrentState(ssPk.State.GetShellType(), ssPk.State)
|
||||
msh.WriteToPtyBuffer("initialized shell:%s state:%s\n", shellType, ssPk.State.GetHashVal(false))
|
||||
return ssPk, nil
|
||||
}
|
||||
|
||||
func (msh *MShellProc) StreamFile(ctx context.Context, streamPk *packet.StreamFilePacketType) (*packet.RpcResponseIter, error) {
|
||||
@ -1123,13 +1151,15 @@ func addScVarsToState(state *packet.ShellState) *packet.ShellState {
|
||||
return nil
|
||||
}
|
||||
rtn := *state
|
||||
envMap := shexec.DeclMapFromState(&rtn)
|
||||
envMap["PROMPT"] = &shexec.DeclareDeclType{Name: "PROMPT", Value: "1", Args: "x"}
|
||||
envMap["PROMPT_VERSION"] = &shexec.DeclareDeclType{Name: "PROMPT_VERSION", Value: scbase.WaveVersion, Args: "x"}
|
||||
envMap := shellenv.DeclMapFromState(&rtn)
|
||||
envMap["WAVETERM"] = &shellenv.DeclareDeclType{Name: "WAVETERM", Value: "1", Args: "x"}
|
||||
envMap["WAVETERM_VERSION"] = &shellenv.DeclareDeclType{Name: "WAVETERM_VERSION", Value: scbase.WaveVersion, Args: "x"}
|
||||
envMap["TERM_PROGRAM"] = &shellenv.DeclareDeclType{Name: "TERM_PROGRAM", Value: "waveterm", Args: "x"}
|
||||
envMap["TERM_PROGRAM_VERSION"] = &shellenv.DeclareDeclType{Name: "TERM_PROGRAM_VERSION", Value: scbase.WaveVersion, Args: "x"}
|
||||
if _, exists := envMap["LANG"]; !exists {
|
||||
envMap["LANG"] = &shexec.DeclareDeclType{Name: "LANG", Value: scbase.DetermineLang(), Args: "x"}
|
||||
envMap["LANG"] = &shellenv.DeclareDeclType{Name: "LANG", Value: scbase.DetermineLang(), Args: "x"}
|
||||
}
|
||||
rtn.ShellVars = shexec.SerializeDeclMap(envMap)
|
||||
rtn.ShellVars = shellenv.SerializeDeclMap(envMap)
|
||||
return &rtn
|
||||
}
|
||||
|
||||
@ -1139,10 +1169,11 @@ func stripScVarsFromState(state *packet.ShellState) *packet.ShellState {
|
||||
}
|
||||
rtn := *state
|
||||
rtn.HashVal = ""
|
||||
envMap := shexec.DeclMapFromState(&rtn)
|
||||
delete(envMap, "PROMPT")
|
||||
delete(envMap, "PROMPT_VERSION")
|
||||
rtn.ShellVars = shexec.SerializeDeclMap(envMap)
|
||||
envMap := shellenv.DeclMapFromState(&rtn)
|
||||
for key := range envVarsToStrip {
|
||||
delete(envMap, key)
|
||||
}
|
||||
rtn.ShellVars = shellenv.SerializeDeclMap(envMap)
|
||||
return &rtn
|
||||
}
|
||||
|
||||
@ -1158,12 +1189,23 @@ func stripScVarsFromStateDiff(stateDiff *packet.ShellStateDiff) *packet.ShellSta
|
||||
log.Printf("error decoding statediff in stripScVarsFromStateDiff: %v\n", err)
|
||||
return stateDiff
|
||||
}
|
||||
delete(mapDiff.ToAdd, "PROMPT")
|
||||
delete(mapDiff.ToAdd, "PROMPT_VERSION")
|
||||
for key := range envVarsToStrip {
|
||||
delete(mapDiff.ToAdd, key)
|
||||
}
|
||||
rtn.VarsDiff = mapDiff.Encode()
|
||||
return &rtn
|
||||
}
|
||||
|
||||
func (msh *MShellProc) getActiveShellTypes(ctx context.Context) ([]string, error) {
|
||||
shellPref := msh.GetShellPref()
|
||||
rtn := []string{shellPref}
|
||||
activeShells, err := sstore.GetRemoteActiveShells(ctx, msh.RemoteId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return utilfn.CombineStrArrays(rtn, activeShells), nil
|
||||
}
|
||||
|
||||
func (msh *MShellProc) Launch(interactive bool) {
|
||||
remoteCopy := msh.GetRemoteCopy()
|
||||
if remoteCopy.Archived {
|
||||
@ -1179,6 +1221,11 @@ func (msh *MShellProc) Launch(interactive bool) {
|
||||
msh.WriteToPtyBuffer("remote is already connecting, disconnect before trying to connect again\n")
|
||||
return
|
||||
}
|
||||
sapi, err := shellapi.MakeShellApi(msh.GetShellType())
|
||||
if err != nil {
|
||||
msh.WriteToPtyBuffer("*error, %v\n", err)
|
||||
return
|
||||
}
|
||||
istatus := msh.GetInstallStatus()
|
||||
if istatus == StatusConnecting {
|
||||
msh.WriteToPtyBuffer("remote is trying to install, cancel install before trying to connect again\n")
|
||||
@ -1205,7 +1252,7 @@ func (msh *MShellProc) Launch(interactive bool) {
|
||||
} else {
|
||||
cmdStr = MakeServerCommandStr()
|
||||
}
|
||||
ecmd := sshOpts.MakeSSHExecCmd(cmdStr)
|
||||
ecmd := sshOpts.MakeSSHExecCmd(cmdStr, sapi)
|
||||
cmdPty, err := msh.addControllingTty(ecmd)
|
||||
if err != nil {
|
||||
statusErr := fmt.Errorf("cannot attach controlling tty to mshell command: %w", err)
|
||||
@ -1237,7 +1284,6 @@ func (msh *MShellProc) Launch(interactive bool) {
|
||||
cproc, initPk, err := shexec.MakeClientProc(makeClientCtx, ecmd)
|
||||
// TODO check if initPk.State is not nil
|
||||
var mshellVersion string
|
||||
var stateBaseHash string
|
||||
var hitDeadline bool
|
||||
msh.WithLock(func() {
|
||||
msh.MakeClientCancelFn = nil
|
||||
@ -1255,16 +1301,9 @@ func (msh *MShellProc) Launch(interactive bool) {
|
||||
// only set NeedsMShellUpgrade if we got an InitPk
|
||||
msh.NeedsMShellUpgrade = true
|
||||
}
|
||||
msh.InitPkShellType = initPk.Shell
|
||||
}
|
||||
if initPk != nil && initPk.State != nil {
|
||||
hval := initPk.State.GetHashVal(false)
|
||||
msh.CurrentState = hval
|
||||
msh.StateMap[hval] = initPk.State
|
||||
sstore.StoreStateBase(context.Background(), initPk.State)
|
||||
stateBaseHash = hval
|
||||
} else {
|
||||
msh.CurrentState = ""
|
||||
}
|
||||
msh.StateMap.Clear()
|
||||
// no notify here, because we'll call notify in either case below
|
||||
})
|
||||
if err == context.Canceled {
|
||||
@ -1290,11 +1329,9 @@ func (msh *MShellProc) Launch(interactive bool) {
|
||||
return
|
||||
}
|
||||
msh.updateRemoteStateVars(context.Background(), msh.RemoteId, initPk)
|
||||
msh.WriteToPtyBuffer("connected state:%s\n", stateBaseHash)
|
||||
msh.WithLock(func() {
|
||||
msh.ServerProc = cproc
|
||||
msh.Status = StatusConnected
|
||||
go msh.NotifyRemoteUpdate()
|
||||
})
|
||||
go func() {
|
||||
exitErr := cproc.Cmd.Wait()
|
||||
@ -1308,15 +1345,40 @@ func (msh *MShellProc) Launch(interactive bool) {
|
||||
msh.WriteToPtyBuffer("*disconnected exitcode=%d\n", exitCode)
|
||||
}()
|
||||
go msh.ProcessPackets()
|
||||
msh.initActiveShells()
|
||||
go msh.NotifyRemoteUpdate()
|
||||
return
|
||||
}
|
||||
|
||||
func (msh *MShellProc) initActiveShells() {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancelFn()
|
||||
activeShells, err := msh.getActiveShellTypes(ctx)
|
||||
if err != nil {
|
||||
// we're not going to fail the connect for this error (it will be unusable, but technically connected)
|
||||
msh.WriteToPtyBuffer("*error getting active shells: %v\n", err)
|
||||
return
|
||||
}
|
||||
for _, shellType := range activeShells {
|
||||
_, err = msh.ReInit(ctx, shellType)
|
||||
if err != nil {
|
||||
msh.WriteToPtyBuffer("*error reiniting shell %q: %v\n", shellType, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (msh *MShellProc) IsConnected() bool {
|
||||
msh.Lock.Lock()
|
||||
defer msh.Lock.Unlock()
|
||||
return msh.Status == StatusConnected
|
||||
}
|
||||
|
||||
func (msh *MShellProc) GetShellType() string {
|
||||
msh.Lock.Lock()
|
||||
defer msh.Lock.Unlock()
|
||||
return msh.InitPkShellType
|
||||
}
|
||||
|
||||
func replaceHomePath(pathStr string, homeDir string) string {
|
||||
if homeDir == "" {
|
||||
return pathStr
|
||||
@ -1451,13 +1513,13 @@ func RunCommand(ctx context.Context, sessionId string, screenId string, remotePt
|
||||
// get current remote-instance state
|
||||
statePtr, err := sstore.GetRemoteStatePtr(ctx, sessionId, screenId, remotePtr)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("cannot get current remote stateptr: %w", err)
|
||||
return nil, nil, fmt.Errorf("cannot get current connection stateptr: %w", err)
|
||||
}
|
||||
if statePtr == nil {
|
||||
statePtr = msh.GetDefaultStatePtr()
|
||||
statePtr = msh.GetDefaultStatePtr(msh.GetShellPref())
|
||||
}
|
||||
if statePtr == nil {
|
||||
return nil, nil, fmt.Errorf("cannot run command, no valid remote stateptr")
|
||||
return nil, nil, fmt.Errorf("cannot run command, no valid connection stateptr")
|
||||
}
|
||||
currentState, err := sstore.GetFullState(ctx, *statePtr)
|
||||
if err != nil || currentState == nil {
|
||||
@ -1465,6 +1527,15 @@ func RunCommand(ctx context.Context, sessionId string, screenId string, remotePt
|
||||
}
|
||||
runPacket.State = addScVarsToState(currentState)
|
||||
runPacket.StateComplete = true
|
||||
runPacket.ShellType = currentState.GetShellType()
|
||||
// check to see if shellType is initialized
|
||||
if !msh.StateMap.HasShell(runPacket.ShellType) {
|
||||
// try to reinit the shell
|
||||
_, err := msh.ReInit(ctx, runPacket.ShellType)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error trying to initialize shell %q: %v", runPacket.ShellType, err)
|
||||
}
|
||||
}
|
||||
msh.ServerProc.Output.RegisterRpc(runPacket.ReqId)
|
||||
err = shexec.SendRunPacketAndRunData(ctx, msh.ServerProc.Input, runPacket)
|
||||
if err != nil {
|
||||
@ -1653,6 +1724,8 @@ func (msh *MShellProc) notifyHangups_nolock() {
|
||||
}
|
||||
|
||||
func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancelFn()
|
||||
// this will remove from RunningCmds and from PendingStateCmds
|
||||
defer msh.RemoveRunningCmd(donePk.CK)
|
||||
if donePk.FinalState != nil {
|
||||
@ -1661,12 +1734,12 @@ func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
|
||||
if donePk.FinalStateDiff != nil {
|
||||
donePk.FinalStateDiff = stripScVarsFromStateDiff(donePk.FinalStateDiff)
|
||||
}
|
||||
update, err := sstore.UpdateCmdDoneInfo(context.Background(), donePk.CK, donePk, sstore.CmdStatusDone)
|
||||
update, err := sstore.UpdateCmdDoneInfo(ctx, donePk.CK, donePk, sstore.CmdStatusDone)
|
||||
if err != nil {
|
||||
msh.WriteToPtyBuffer("*error updating cmddone: %v\n", err)
|
||||
return
|
||||
}
|
||||
screen, err := sstore.UpdateScreenFocusForDoneCmd(context.Background(), donePk.CK.GetGroupId(), donePk.CK.GetCmdId())
|
||||
screen, err := sstore.UpdateScreenFocusForDoneCmd(ctx, donePk.CK.GetGroupId(), donePk.CK.GetCmdId())
|
||||
if err != nil {
|
||||
msh.WriteToPtyBuffer("*error trying to update screen focus type: %v\n", err)
|
||||
// fall-through (nothing to do)
|
||||
@ -1678,7 +1751,7 @@ func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
|
||||
var statePtr *sstore.ShellStatePtr
|
||||
if donePk.FinalState != nil && rct != nil {
|
||||
feState := sstore.FeStateFromShellState(donePk.FinalState)
|
||||
remoteInst, err := sstore.UpdateRemoteState(context.Background(), rct.SessionId, rct.ScreenId, rct.RemotePtr, feState, donePk.FinalState, nil)
|
||||
remoteInst, err := sstore.UpdateRemoteState(ctx, rct.SessionId, rct.ScreenId, rct.RemotePtr, feState, donePk.FinalState, nil)
|
||||
if err != nil {
|
||||
msh.WriteToPtyBuffer("*error trying to update remotestate: %v\n", err)
|
||||
// fall-through (nothing to do)
|
||||
@ -1693,7 +1766,12 @@ func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
|
||||
msh.WriteToPtyBuffer("*error trying to update remotestate: %v\n", err)
|
||||
// fall-through (nothing to do)
|
||||
} else {
|
||||
remoteInst, err := sstore.UpdateRemoteState(context.Background(), rct.SessionId, rct.ScreenId, rct.RemotePtr, feState, nil, donePk.FinalStateDiff)
|
||||
stateDiff := donePk.FinalStateDiff
|
||||
fullState := msh.StateMap.GetStateByHash(stateDiff.GetShellType(), stateDiff.BaseHash)
|
||||
if fullState != nil {
|
||||
sstore.StoreStateBase(ctx, fullState)
|
||||
}
|
||||
remoteInst, err := sstore.UpdateRemoteState(ctx, rct.SessionId, rct.ScreenId, rct.RemotePtr, feState, nil, stateDiff)
|
||||
if err != nil {
|
||||
msh.WriteToPtyBuffer("*error trying to update remotestate: %v\n", err)
|
||||
// fall-through (nothing to do)
|
||||
@ -1707,7 +1785,7 @@ func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
|
||||
}
|
||||
}
|
||||
if statePtr != nil {
|
||||
err = sstore.UpdateCmdRtnState(context.Background(), donePk.CK, *statePtr)
|
||||
err = sstore.UpdateCmdRtnState(ctx, donePk.CK, *statePtr)
|
||||
if err != nil {
|
||||
msh.WriteToPtyBuffer("*error trying to update cmd rtnstate: %v\n", err)
|
||||
// fall-through (nothing to do)
|
||||
@ -1951,7 +2029,7 @@ func evalPromptEsc(escCode string, vars map[string]string, state *packet.ShellSt
|
||||
return ""
|
||||
}
|
||||
varName := escCode[2 : len(escCode)-1]
|
||||
varMap := shexec.ShellVarMapFromState(state)
|
||||
varMap := shellenv.ShellVarMapFromState(state)
|
||||
return varMap[varName]
|
||||
}
|
||||
if escCode == "h" {
|
||||
@ -2024,43 +2102,56 @@ func evalPromptEsc(escCode string, vars map[string]string, state *packet.ShellSt
|
||||
return "(" + escCode + ")"
|
||||
}
|
||||
|
||||
func (msh *MShellProc) getFullState(stateDiff *packet.ShellStateDiff) (*packet.ShellState, error) {
|
||||
baseState := msh.GetStateByHash(stateDiff.BaseHash)
|
||||
func (msh *MShellProc) getFullState(shellType string, stateDiff *packet.ShellStateDiff) (*packet.ShellState, error) {
|
||||
baseState := msh.StateMap.GetStateByHash(shellType, stateDiff.BaseHash)
|
||||
if baseState != nil && len(stateDiff.DiffHashArr) == 0 {
|
||||
newState, err := shexec.ApplyShellStateDiff(*baseState, *stateDiff)
|
||||
sapi, err := shellapi.MakeShellApi(baseState.GetShellType())
|
||||
newState, err := sapi.ApplyShellStateDiff(baseState, stateDiff)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &newState, nil
|
||||
return newState, nil
|
||||
} else {
|
||||
fullState, err := sstore.GetFullState(context.Background(), sstore.ShellStatePtr{BaseHash: stateDiff.BaseHash, DiffHashArr: stateDiff.DiffHashArr})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newState, err := shexec.ApplyShellStateDiff(*fullState, *stateDiff)
|
||||
return &newState, nil
|
||||
sapi, err := shellapi.MakeShellApi(fullState.GetShellType())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newState, err := sapi.ApplyShellStateDiff(fullState, stateDiff)
|
||||
return newState, nil
|
||||
}
|
||||
}
|
||||
|
||||
// internal func, first tries the StateMap, otherwise will fallback on sstore.GetFullState
|
||||
func (msh *MShellProc) getFeStateFromDiff(stateDiff *packet.ShellStateDiff) (map[string]string, error) {
|
||||
baseState := msh.GetStateByHash(stateDiff.BaseHash)
|
||||
baseState := msh.StateMap.GetStateByHash(stateDiff.GetShellType(), stateDiff.BaseHash)
|
||||
if baseState != nil && len(stateDiff.DiffHashArr) == 0 {
|
||||
newState, err := shexec.ApplyShellStateDiff(*baseState, *stateDiff)
|
||||
sapi, err := shellapi.MakeShellApi(baseState.GetShellType())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sstore.FeStateFromShellState(&newState), nil
|
||||
newState, err := sapi.ApplyShellStateDiff(baseState, stateDiff)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sstore.FeStateFromShellState(newState), nil
|
||||
} else {
|
||||
fullState, err := sstore.GetFullState(context.Background(), sstore.ShellStatePtr{BaseHash: stateDiff.BaseHash, DiffHashArr: stateDiff.DiffHashArr})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newState, err := shexec.ApplyShellStateDiff(*fullState, *stateDiff)
|
||||
sapi, err := shellapi.MakeShellApi(fullState.GetShellType())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sstore.FeStateFromShellState(&newState), nil
|
||||
newState, err := sapi.ApplyShellStateDiff(fullState, stateDiff)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sstore.FeStateFromShellState(newState), nil
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -11,10 +11,11 @@ import (
|
||||
|
||||
"github.com/alessio/shellescape"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellapi"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/simpleexpand"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
|
||||
"mvdan.cc/sh/v3/syntax"
|
||||
)
|
||||
|
||||
@ -106,27 +107,95 @@ const MaxDiffKeyLen = 40
|
||||
const MaxDiffValLen = 50
|
||||
|
||||
var IgnoreVars = map[string]bool{
|
||||
"PROMPT": true,
|
||||
"PROMPT_VERSION": true,
|
||||
"MSHELL": true,
|
||||
"MSHELL_VERSION": true,
|
||||
"WAVESHELL": true,
|
||||
"WAVESHELL_VERSION": true,
|
||||
"PROMPT": true,
|
||||
"PROMPT_VERSION": true,
|
||||
"MSHELL": true,
|
||||
"MSHELL_VERSION": true,
|
||||
"WAVESHELL": true,
|
||||
"WAVESHELL_VERSION": true,
|
||||
"WAVETERM": true,
|
||||
"WAVETERM_VERSION": true,
|
||||
"TERM_PROGRAM": true,
|
||||
"TERM_PROGRAM_VERSION": true,
|
||||
"TERM_SESSION_ID": true,
|
||||
}
|
||||
|
||||
func displayStateUpdateDiff(buf *bytes.Buffer, oldState packet.ShellState, newState packet.ShellState) {
|
||||
func makeBashAliasesDiff(buf *bytes.Buffer, oldAliases string, newAliases string) {
|
||||
newAliasMap, _ := ParseAliases(newAliases)
|
||||
oldAliasMap, _ := ParseAliases(oldAliases)
|
||||
for aliasName, newAliasVal := range newAliasMap {
|
||||
oldAliasVal, found := oldAliasMap[aliasName]
|
||||
if !found || newAliasVal != oldAliasVal {
|
||||
buf.WriteString(fmt.Sprintf("alias %s\n", utilfn.EllipsisStr(shellescape.Quote(aliasName), MaxDiffKeyLen)))
|
||||
}
|
||||
}
|
||||
for aliasName := range oldAliasMap {
|
||||
_, found := newAliasMap[aliasName]
|
||||
if !found {
|
||||
buf.WriteString(fmt.Sprintf("unalias %s\n", utilfn.EllipsisStr(shellescape.Quote(aliasName), MaxDiffKeyLen)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func makeZshAlisesDiff(buf *bytes.Buffer, oldAliases string, newAliases string) {
|
||||
newAliasMap, err := shellapi.DecodeZshMap([]byte(newAliases))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
oldAliasMap, err := shellapi.DecodeZshMap([]byte(oldAliases))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for aliasKey, newAliasVal := range newAliasMap {
|
||||
oldAliasVal, found := oldAliasMap[aliasKey]
|
||||
if !found || newAliasVal != oldAliasVal {
|
||||
buf.WriteString(fmt.Sprintf("%s %s=%s\n", aliasKey.ParamType, aliasKey.ParamName, utilfn.EllipsisStr(shellescape.Quote(newAliasVal), MaxDiffKeyLen)))
|
||||
}
|
||||
}
|
||||
for aliasKey := range oldAliasMap {
|
||||
_, found := newAliasMap[aliasKey]
|
||||
if !found {
|
||||
buf.WriteString(fmt.Sprintf("remove %s %s\n", aliasKey.ParamType, aliasKey.ParamName))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func makeZshFuncsDiff(buf *bytes.Buffer, oldFuncs string, newFuncs string) {
|
||||
newFuncMap, err := shellapi.DecodeZshMap([]byte(newFuncs))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
oldFuncMap, err := shellapi.DecodeZshMap([]byte(oldFuncs))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for funcKey, newFuncVal := range newFuncMap {
|
||||
oldFuncVal, found := oldFuncMap[funcKey]
|
||||
if !found || newFuncVal != oldFuncVal {
|
||||
buf.WriteString(fmt.Sprintf("%s %s\n", funcKey.ParamType, funcKey.ParamName))
|
||||
}
|
||||
}
|
||||
for funcKey := range oldFuncMap {
|
||||
_, found := newFuncMap[funcKey]
|
||||
if !found {
|
||||
buf.WriteString(fmt.Sprintf("remove %s %s\n", funcKey.ParamType, funcKey.ParamName))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func DisplayStateUpdateDiff(buf *bytes.Buffer, oldState packet.ShellState, newState packet.ShellState) {
|
||||
if newState.Cwd != oldState.Cwd {
|
||||
buf.WriteString(fmt.Sprintf("cwd %s\n", newState.Cwd))
|
||||
}
|
||||
if !bytes.Equal(newState.ShellVars, oldState.ShellVars) {
|
||||
newEnvMap := shexec.DeclMapFromState(&newState)
|
||||
oldEnvMap := shexec.DeclMapFromState(&oldState)
|
||||
newEnvMap := shellenv.DeclMapFromState(&newState)
|
||||
oldEnvMap := shellenv.DeclMapFromState(&oldState)
|
||||
for key, newVal := range newEnvMap {
|
||||
if IgnoreVars[key] {
|
||||
continue
|
||||
}
|
||||
oldVal, found := oldEnvMap[key]
|
||||
if !found || !shexec.DeclsEqual(false, oldVal, newVal) {
|
||||
if !found || !shellenv.DeclsEqual(false, oldVal, newVal) {
|
||||
var exportStr string
|
||||
if newVal.IsExport() {
|
||||
exportStr = "export "
|
||||
@ -134,7 +203,7 @@ func displayStateUpdateDiff(buf *bytes.Buffer, oldState packet.ShellState, newSt
|
||||
buf.WriteString(fmt.Sprintf("%s%s=%s\n", exportStr, utilfn.EllipsisStr(key, MaxDiffKeyLen), utilfn.EllipsisStr(newVal.Value, MaxDiffValLen)))
|
||||
}
|
||||
}
|
||||
for key, _ := range oldEnvMap {
|
||||
for key := range oldEnvMap {
|
||||
if IgnoreVars[key] {
|
||||
continue
|
||||
}
|
||||
@ -144,36 +213,31 @@ func displayStateUpdateDiff(buf *bytes.Buffer, oldState packet.ShellState, newSt
|
||||
}
|
||||
}
|
||||
}
|
||||
if newState.Aliases != oldState.Aliases {
|
||||
newAliasMap, _ := ParseAliases(newState.Aliases)
|
||||
oldAliasMap, _ := ParseAliases(oldState.Aliases)
|
||||
for aliasName, newAliasVal := range newAliasMap {
|
||||
oldAliasVal, found := oldAliasMap[aliasName]
|
||||
if !found || newAliasVal != oldAliasVal {
|
||||
buf.WriteString(fmt.Sprintf("alias %s\n", utilfn.EllipsisStr(shellescape.Quote(aliasName), MaxDiffKeyLen)))
|
||||
}
|
||||
}
|
||||
for aliasName, _ := range oldAliasMap {
|
||||
_, found := newAliasMap[aliasName]
|
||||
if !found {
|
||||
buf.WriteString(fmt.Sprintf("unalias %s\n", utilfn.EllipsisStr(shellescape.Quote(aliasName), MaxDiffKeyLen)))
|
||||
}
|
||||
if newState.GetShellType() == packet.ShellType_zsh {
|
||||
makeZshAlisesDiff(buf, oldState.Aliases, newState.Aliases)
|
||||
makeZshFuncsDiff(buf, oldState.Funcs, newState.Funcs)
|
||||
} else {
|
||||
makeBashAliasesDiff(buf, oldState.Aliases, newState.Aliases)
|
||||
makeBashFuncsDiff(newState, oldState, buf)
|
||||
}
|
||||
}
|
||||
|
||||
func makeBashFuncsDiff(newState packet.ShellState, oldState packet.ShellState, buf *bytes.Buffer) {
|
||||
if newState.Funcs == oldState.Funcs {
|
||||
return
|
||||
}
|
||||
newFuncMap, _ := ParseFuncs(newState.Funcs)
|
||||
oldFuncMap, _ := ParseFuncs(oldState.Funcs)
|
||||
for funcName, newFuncVal := range newFuncMap {
|
||||
oldFuncVal, found := oldFuncMap[funcName]
|
||||
if !found || newFuncVal != oldFuncVal {
|
||||
buf.WriteString(fmt.Sprintf("function %s\n", utilfn.EllipsisStr(shellescape.Quote(funcName), MaxDiffKeyLen)))
|
||||
}
|
||||
}
|
||||
if newState.Funcs != oldState.Funcs {
|
||||
newFuncMap, _ := ParseFuncs(newState.Funcs)
|
||||
oldFuncMap, _ := ParseFuncs(oldState.Funcs)
|
||||
for funcName, newFuncVal := range newFuncMap {
|
||||
oldFuncVal, found := oldFuncMap[funcName]
|
||||
if !found || newFuncVal != oldFuncVal {
|
||||
buf.WriteString(fmt.Sprintf("function %s\n", utilfn.EllipsisStr(shellescape.Quote(funcName), MaxDiffKeyLen)))
|
||||
}
|
||||
}
|
||||
for funcName, _ := range oldFuncMap {
|
||||
_, found := newFuncMap[funcName]
|
||||
if !found {
|
||||
buf.WriteString(fmt.Sprintf("unset -f %s\n", utilfn.EllipsisStr(shellescape.Quote(funcName), MaxDiffKeyLen)))
|
||||
}
|
||||
for funcName := range oldFuncMap {
|
||||
_, found := newFuncMap[funcName]
|
||||
if !found {
|
||||
buf.WriteString(fmt.Sprintf("unset -f %s\n", utilfn.EllipsisStr(shellescape.Quote(funcName), MaxDiffKeyLen)))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -201,6 +265,6 @@ func GetRtnStateDiff(ctx context.Context, screenId string, lineId string) ([]byt
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting rtn full state: %v", err)
|
||||
}
|
||||
displayStateUpdateDiff(&outputBytes, *initialState, *rtnState)
|
||||
DisplayStateUpdateDiff(&outputBytes, *initialState, *rtnState)
|
||||
return outputBytes.Bytes(), nil
|
||||
}
|
||||
|
@ -37,7 +37,7 @@ const WaveDevDirName = ".waveterm-dev" // must match emain.ts
|
||||
const WaveAppPathVarName = "WAVETERM_APP_PATH"
|
||||
const WaveVersion = "v0.5.3"
|
||||
const WaveAuthKeyFileName = "waveterm.authkey"
|
||||
const MShellVersion = "v0.3.0"
|
||||
const MShellVersion = "v0.4.0"
|
||||
|
||||
var SessionDirCache = make(map[string]string)
|
||||
var ScreenDirCache = make(map[string]string)
|
||||
|
@ -11,8 +11,8 @@ import (
|
||||
"github.com/alessio/shellescape"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
|
||||
)
|
||||
|
||||
const FeCommandPacketStr = "fecmd"
|
||||
|
@ -6,7 +6,7 @@ package shparse
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -8,7 +8,7 @@ import (
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
)
|
||||
|
||||
var noEscChars []bool
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
)
|
||||
|
||||
//
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
)
|
||||
|
||||
// $(ls f[*]); ./x
|
||||
|
@ -18,7 +18,8 @@ import (
|
||||
"github.com/sawka/txwrap"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellapi"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/dbutil"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
|
||||
)
|
||||
@ -217,8 +218,8 @@ func UpsertRemote(ctx context.Context, r *RemoteType) error {
|
||||
maxRemoteIdx := tx.GetInt(query)
|
||||
r.RemoteIdx = int64(maxRemoteIdx + 1)
|
||||
query = `INSERT INTO remote
|
||||
( remoteid, remotetype, remotealias, remotecanonicalname, remoteuser, remotehost, connectmode, autoinstall, sshopts, remoteopts, lastconnectts, archived, remoteidx, local, statevars, sshconfigsrc, openaiopts) VALUES
|
||||
(:remoteid,:remotetype,:remotealias,:remotecanonicalname,:remoteuser,:remotehost,:connectmode,:autoinstall,:sshopts,:remoteopts,:lastconnectts,:archived,:remoteidx,:local,:statevars,:sshconfigsrc,:openaiopts)`
|
||||
( remoteid, remotetype, remotealias, remotecanonicalname, remoteuser, remotehost, connectmode, autoinstall, sshopts, remoteopts, lastconnectts, archived, remoteidx, local, statevars, sshconfigsrc, openaiopts, shellpref) VALUES
|
||||
(:remoteid,:remotetype,:remotealias,:remotecanonicalname,:remoteuser,:remotehost,:connectmode,:autoinstall,:sshopts,:remoteopts,:lastconnectts,:archived,:remoteidx,:local,:statevars,:sshconfigsrc,:openaiopts,:shellpref)`
|
||||
tx.NamedExec(query, r.ToMap())
|
||||
return nil
|
||||
})
|
||||
@ -1288,11 +1289,12 @@ func GetRemoteInstance(ctx context.Context, sessionId string, screenId string, r
|
||||
return ri, nil
|
||||
}
|
||||
|
||||
// internal function for UpdateRemoteState
|
||||
// internal function for UpdateRemoteState (sets StateBaseHash, StateDiffHashArr, and ShellType)
|
||||
func updateRIWithState(ctx context.Context, ri *RemoteInstance, stateBase *packet.ShellState, stateDiff *packet.ShellStateDiff) error {
|
||||
if stateBase != nil {
|
||||
ri.StateBaseHash = stateBase.GetHashVal(false)
|
||||
ri.StateDiffHashArr = nil
|
||||
ri.ShellType = stateBase.GetShellType()
|
||||
err := StoreStateBase(ctx, stateBase)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -1300,6 +1302,7 @@ func updateRIWithState(ctx context.Context, ri *RemoteInstance, stateBase *packe
|
||||
} else if stateDiff != nil {
|
||||
ri.StateBaseHash = stateDiff.BaseHash
|
||||
ri.StateDiffHashArr = append(stateDiff.DiffHashArr, stateDiff.GetHashVal(false))
|
||||
ri.ShellType = stateDiff.GetShellType()
|
||||
err := StoreStateDiff(ctx, stateDiff)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -1340,18 +1343,18 @@ func UpdateRemoteState(ctx context.Context, sessionId string, screenId string, r
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
query = `INSERT INTO remote_instance ( riid, name, sessionid, screenid, remoteownerid, remoteid, festate, statebasehash, statediffhasharr)
|
||||
VALUES (:riid,:name,:sessionid,:screenid,:remoteownerid,:remoteid,:festate,:statebasehash,:statediffhasharr)`
|
||||
query = `INSERT INTO remote_instance ( riid, name, sessionid, screenid, remoteownerid, remoteid, festate, statebasehash, statediffhasharr, shelltype)
|
||||
VALUES (:riid,:name,:sessionid,:screenid,:remoteownerid,:remoteid,:festate,:statebasehash,:statediffhasharr,:shelltype)`
|
||||
tx.NamedExec(query, ri.ToMap())
|
||||
return nil
|
||||
} else {
|
||||
query = `UPDATE remote_instance SET festate = ?, statebasehash = ?, statediffhasharr = ? WHERE riid = ?`
|
||||
query = `UPDATE remote_instance SET festate = ?, statebasehash = ?, statediffhasharr = ?, shelltype = ? WHERE riid = ?`
|
||||
ri.FeState = feState
|
||||
err = updateRIWithState(tx.Context(), ri, stateBase, stateDiff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tx.Exec(query, quickJson(ri.FeState), ri.StateBaseHash, quickJsonArr(ri.StateDiffHashArr), ri.RIId)
|
||||
tx.Exec(query, quickJson(ri.FeState), ri.StateBaseHash, quickJsonArr(ri.StateDiffHashArr), ri.ShellType, ri.RIId)
|
||||
return nil
|
||||
}
|
||||
})
|
||||
@ -1730,9 +1733,11 @@ const (
|
||||
RemoteField_SSHKey = "sshkey" // string
|
||||
RemoteField_SSHPassword = "sshpassword" // string
|
||||
RemoteField_Color = "color" // string
|
||||
RemoteField_ShellPref = "shellpref" // string
|
||||
)
|
||||
|
||||
// editMap: alias, connectmode, autoinstall, sshkey, color, sshpassword (from constants)
|
||||
// note that all validation should have already happened outside of this function
|
||||
func UpdateRemote(ctx context.Context, remoteId string, editMap map[string]interface{}) (*RemoteType, error) {
|
||||
var rtn *RemoteType
|
||||
txErr := WithTx(ctx, func(tx *TxWrap) error {
|
||||
@ -1760,6 +1765,10 @@ func UpdateRemote(ctx context.Context, remoteId string, editMap map[string]inter
|
||||
query = `UPDATE remote SET sshopts = json_set(sshopts, '$.sshpassword', ?) WHERE remoteid = ?`
|
||||
tx.Exec(query, sshPassword, remoteId)
|
||||
}
|
||||
if shellPref, found := editMap[RemoteField_ShellPref]; found {
|
||||
query = `UPDATE remote SET shellpref = ? WHERE remoteid = ?`
|
||||
tx.Exec(query, shellPref, remoteId)
|
||||
}
|
||||
if color, found := editMap[RemoteField_Color]; found {
|
||||
query = `UPDATE remote SET remoteopts = json_set(remoteopts, '$.color', ?) WHERE remoteid = ?`
|
||||
tx.Exec(query, color, remoteId)
|
||||
@ -1958,22 +1967,26 @@ func GetFullState(ctx context.Context, ssPtr ShellStatePtr) (*packet.ShellState,
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sapi, err := shellapi.MakeShellApi(state.GetShellType())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for idx, diffHash := range ssPtr.DiffHashArr {
|
||||
query = `SELECT * FROM state_diff WHERE diffhash = ?`
|
||||
stateDiff := dbutil.GetMapGen[*StateDiff](tx, query, diffHash)
|
||||
if stateDiff == nil {
|
||||
return fmt.Errorf("ShellStateDiff %s not found", diffHash)
|
||||
}
|
||||
var ssDiff packet.ShellStateDiff
|
||||
ssDiff := &packet.ShellStateDiff{}
|
||||
err = ssDiff.DecodeShellStateDiff(stateDiff.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newState, err := shexec.ApplyShellStateDiff(*state, ssDiff)
|
||||
newState, err := sapi.ApplyShellStateDiff(state, ssDiff)
|
||||
if err != nil {
|
||||
return fmt.Errorf("GetFullState, diff[%d]:%s: %v", idx, diffHash, err)
|
||||
}
|
||||
state = &newState
|
||||
state = newState
|
||||
}
|
||||
return nil
|
||||
})
|
||||
@ -2750,3 +2763,15 @@ func SetWebPtyPos(ctx context.Context, screenId string, lineId string, ptyPos in
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func GetRemoteActiveShells(ctx context.Context, remoteId string) ([]string, error) {
|
||||
return WithTxRtn(ctx, func(tx *TxWrap) ([]string, error) {
|
||||
query := `SELECT * FROM remote_instance WHERE remoteid = ?`
|
||||
riArr := dbutil.SelectMapsGen[*RemoteInstance](tx, query, remoteId)
|
||||
shellTypeMap := make(map[string]bool)
|
||||
for _, ri := range riArr {
|
||||
shellTypeMap[ri.ShellType] = true
|
||||
}
|
||||
return utilfn.GetMapKeys(shellTypeMap), nil
|
||||
})
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
)
|
||||
|
||||
// global lock for all memory operations
|
||||
|
@ -22,10 +22,11 @@ import (
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
)
|
||||
|
||||
const MaxMigration = 29
|
||||
const MaxMigration = 30
|
||||
const MigratePrimaryScreenVersion = 9
|
||||
const CmdScreenSpecialMigration = 13
|
||||
const CmdLineSpecialMigration = 20
|
||||
const RISpecialMigration = 30
|
||||
|
||||
func MakeMigrate() (*migrate.Migrate, error) {
|
||||
fsVar, err := iofs.New(sh2db.MigrationFS, "migrations")
|
||||
@ -84,6 +85,12 @@ func MigrateUpStep(m *migrate.Migrate, newVersion uint) error {
|
||||
return fmt.Errorf("migrating to v%d: %w", newVersion, mErr)
|
||||
}
|
||||
}
|
||||
if newVersion == RISpecialMigration {
|
||||
mErr := RunMigration30()
|
||||
if mErr != nil {
|
||||
return fmt.Errorf("migrating to v%d: %w", newVersion, mErr)
|
||||
}
|
||||
}
|
||||
log.Printf("[db] migration v%d, elapsed %v\n", newVersion, time.Since(startTime))
|
||||
return nil
|
||||
}
|
||||
|
@ -25,7 +25,7 @@ import (
|
||||
"github.com/sawka/txwrap"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shexec"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/dbutil"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
|
||||
|
||||
@ -47,6 +47,10 @@ const LocalRemoteAlias = "local"
|
||||
const DefaultCwd = "~"
|
||||
const APITokenSentinel = "--apitoken--"
|
||||
|
||||
// defined here and not in packet.go since this value should never
|
||||
// be passed to waveshell (it should always get resolved prior to sending a run packet)
|
||||
const ShellTypePref_Detect = "detect"
|
||||
|
||||
const (
|
||||
LineTypeCmd = "cmd"
|
||||
LineTypeText = "text"
|
||||
@ -694,6 +698,7 @@ type RemoteInstance struct {
|
||||
RemoteOwnerId string `json:"remoteownerid"`
|
||||
RemoteId string `json:"remoteid"`
|
||||
FeState map[string]string `json:"festate"`
|
||||
ShellType string `json:"shelltype"`
|
||||
StateBaseHash string `json:"-"`
|
||||
StateDiffHashArr []string `json:"-"`
|
||||
|
||||
@ -741,15 +746,19 @@ func FeStateFromShellState(state *packet.ShellState) map[string]string {
|
||||
}
|
||||
rtn := make(map[string]string)
|
||||
rtn["cwd"] = state.Cwd
|
||||
envMap := shexec.EnvMapFromState(state)
|
||||
envMap := shellenv.EnvMapFromState(state)
|
||||
if envMap["VIRTUAL_ENV"] != "" {
|
||||
rtn["VIRTUAL_ENV"] = envMap["VIRTUAL_ENV"]
|
||||
}
|
||||
for key, val := range envMap {
|
||||
if strings.HasPrefix(key, "PROMPTVAR_") && rtn[key] != "" {
|
||||
if strings.HasPrefix(key, "PROMPTVAR_") && envMap[key] != "" {
|
||||
rtn[key] = val
|
||||
}
|
||||
}
|
||||
_, _, err := packet.ParseShellStateVersion(state.Version)
|
||||
if err != nil {
|
||||
rtn["invalidstate"] = "1"
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
@ -763,6 +772,7 @@ func (ri *RemoteInstance) FromMap(m map[string]interface{}) bool {
|
||||
quickSetJson(&ri.FeState, m, "festate")
|
||||
quickSetStr(&ri.StateBaseHash, m, "statebasehash")
|
||||
quickSetJsonArr(&ri.StateDiffHashArr, m, "statediffhasharr")
|
||||
quickSetStr(&ri.ShellType, m, "shelltype")
|
||||
return true
|
||||
}
|
||||
|
||||
@ -777,6 +787,7 @@ func (ri *RemoteInstance) ToMap() map[string]interface{} {
|
||||
rtn["festate"] = quickJson(ri.FeState)
|
||||
rtn["statebasehash"] = ri.StateBaseHash
|
||||
rtn["statediffhasharr"] = quickJsonArr(ri.StateDiffHashArr)
|
||||
rtn["shelltype"] = ri.ShellType
|
||||
return rtn
|
||||
}
|
||||
|
||||
@ -1014,6 +1025,9 @@ type RemoteRuntimeState struct {
|
||||
Local bool `json:"local,omitempty"`
|
||||
RemoteOpts *RemoteOptsType `json:"remoteopts,omitempty"`
|
||||
CanComplete bool `json:"cancomplete,omitempty"`
|
||||
ActiveShells []string `json:"activeshells,omitempty"`
|
||||
ShellPref string `json:"shellpref,omitempty"`
|
||||
DefaultShellType string `json:"defaultshelltype,omitempty"`
|
||||
}
|
||||
|
||||
func (state RemoteRuntimeState) IsConnected() bool {
|
||||
@ -1068,8 +1082,9 @@ type RemoteType struct {
|
||||
SSHOpts *SSHOpts `json:"sshopts"`
|
||||
StateVars map[string]string `json:"statevars"`
|
||||
SSHConfigSrc string `json:"sshconfigsrc"`
|
||||
ShellPref string `json:"shellpref"` // bash, zsh, or detect
|
||||
|
||||
// OpenAI fields
|
||||
// OpenAI fields (unused)
|
||||
OpenAIOpts *OpenAIOptsType `json:"openaiopts,omitempty"`
|
||||
}
|
||||
|
||||
@ -1125,6 +1140,7 @@ func (r *RemoteType) ToMap() map[string]interface{} {
|
||||
rtn["statevars"] = quickJson(r.StateVars)
|
||||
rtn["sshconfigsrc"] = r.SSHConfigSrc
|
||||
rtn["openaiopts"] = quickJson(r.OpenAIOpts)
|
||||
rtn["shellpref"] = r.ShellPref
|
||||
return rtn
|
||||
}
|
||||
|
||||
@ -1146,6 +1162,7 @@ func (r *RemoteType) FromMap(m map[string]interface{}) bool {
|
||||
quickSetJson(&r.StateVars, m, "statevars")
|
||||
quickSetStr(&r.SSHConfigSrc, m, "sshconfigsrc")
|
||||
quickSetJson(&r.OpenAIOpts, m, "openaiopts")
|
||||
quickSetStr(&r.ShellPref, m, "shellpref")
|
||||
return true
|
||||
}
|
||||
|
||||
@ -1308,6 +1325,7 @@ func EnsureLocalRemote(ctx context.Context) error {
|
||||
SSHOpts: &SSHOpts{Local: true},
|
||||
Local: true,
|
||||
SSHConfigSrc: SSHConfigSrcTypeManual,
|
||||
ShellPref: ShellTypePref_Detect,
|
||||
}
|
||||
err = UpsertRemote(ctx, localRemote)
|
||||
if err != nil {
|
||||
@ -1327,6 +1345,7 @@ func EnsureLocalRemote(ctx context.Context) error {
|
||||
RemoteOpts: &RemoteOptsType{Color: "red"},
|
||||
Local: true,
|
||||
SSHConfigSrc: SSHConfigSrcTypeManual,
|
||||
ShellPref: ShellTypePref_Detect,
|
||||
}
|
||||
err = UpsertRemote(ctx, sudoRemote)
|
||||
if err != nil {
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
|
||||
)
|
||||
|
||||
@ -34,6 +35,36 @@ func getSliceChunk[T any](slice []T, chunkSize int) ([]T, []T) {
|
||||
return slice[0:chunkSize], slice[chunkSize:]
|
||||
}
|
||||
|
||||
// we're going to mark any invalid basestate versions as "invalid"
|
||||
// so we can give a better error message for the FE and prompt a reset
|
||||
func RunMigration30() error {
|
||||
ctx := context.Background()
|
||||
startTime := time.Now()
|
||||
updateCount := 0
|
||||
txErr := WithTx(ctx, func(tx *TxWrap) error {
|
||||
query := `SELECT riid FROM remote_instance`
|
||||
riidArr := tx.SelectStrings(query)
|
||||
for _, riid := range riidArr {
|
||||
query = `SELECT version FROM state_base WHERE basehash = (SELECT statebasehash FROM remote_instance WHERE riid = ?)`
|
||||
version := tx.GetString(query, riid)
|
||||
_, _, err := packet.ParseShellStateVersion(version)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
// deal with bad versions by marking festate with an invalidshellstate flag
|
||||
query = `UPDATE remote_instance SET festate = json_set(festate, '$.invalidshellstate', '1') WHERE riid = ?`
|
||||
tx.Exec(query, riid)
|
||||
updateCount++
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if txErr != nil {
|
||||
return fmt.Errorf("error running remote-instance v30 migration: %w", txErr)
|
||||
}
|
||||
log.Printf("[db] remote-instance v30 migration done: %v (%d bad versions)\n", time.Since(startTime), updateCount)
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunMigration20() error {
|
||||
ctx := context.Background()
|
||||
startTime := time.Now()
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
)
|
||||
|
||||
var MainBus *UpdateBus = MakeUpdateBus()
|
||||
|
@ -1,249 +0,0 @@
|
||||
// Copyright 2023, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package utilfn
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"math"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
var HexDigits = []byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'}
|
||||
|
||||
func GetStrArr(v interface{}, field string) []string {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
m, ok := v.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
fieldVal := m[field]
|
||||
if fieldVal == nil {
|
||||
return nil
|
||||
}
|
||||
iarr, ok := fieldVal.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
var sarr []string
|
||||
for _, iv := range iarr {
|
||||
if sv, ok := iv.(string); ok {
|
||||
sarr = append(sarr, sv)
|
||||
}
|
||||
}
|
||||
return sarr
|
||||
}
|
||||
|
||||
func GetBool(v interface{}, field string) bool {
|
||||
if v == nil {
|
||||
return false
|
||||
}
|
||||
m, ok := v.(map[string]interface{})
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
fieldVal := m[field]
|
||||
if fieldVal == nil {
|
||||
return false
|
||||
}
|
||||
bval, ok := fieldVal.(bool)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return bval
|
||||
}
|
||||
|
||||
var needsQuoteRe = regexp.MustCompile(`[^\w@%:,./=+-]`)
|
||||
|
||||
// minimum maxlen=6
|
||||
func ShellQuote(val string, forceQuote bool, maxLen int) string {
|
||||
if maxLen < 6 {
|
||||
maxLen = 6
|
||||
}
|
||||
rtn := val
|
||||
if needsQuoteRe.MatchString(val) {
|
||||
rtn = "'" + strings.ReplaceAll(val, "'", `'"'"'`) + "'"
|
||||
}
|
||||
if strings.HasPrefix(rtn, "\"") || strings.HasPrefix(rtn, "'") {
|
||||
if len(rtn) > maxLen {
|
||||
return rtn[0:maxLen-4] + "..." + rtn[0:1]
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
if forceQuote {
|
||||
if len(rtn) > maxLen-2 {
|
||||
return "\"" + rtn[0:maxLen-5] + "...\""
|
||||
}
|
||||
return "\"" + rtn + "\""
|
||||
} else {
|
||||
if len(rtn) > maxLen {
|
||||
return rtn[0:maxLen-3] + "..."
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
}
|
||||
|
||||
func EllipsisStr(s string, maxLen int) string {
|
||||
if maxLen < 4 {
|
||||
maxLen = 4
|
||||
}
|
||||
if len(s) > maxLen {
|
||||
return s[0:maxLen-3] + "..."
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func LongestPrefix(root string, strs []string) string {
|
||||
if len(strs) == 0 {
|
||||
return root
|
||||
}
|
||||
if len(strs) == 1 {
|
||||
comp := strs[0]
|
||||
if len(comp) >= len(root) && strings.HasPrefix(comp, root) {
|
||||
if strings.HasSuffix(comp, "/") {
|
||||
return strs[0]
|
||||
}
|
||||
return strs[0]
|
||||
}
|
||||
}
|
||||
lcp := strs[0]
|
||||
for i := 1; i < len(strs); i++ {
|
||||
s := strs[i]
|
||||
for j := 0; j < len(lcp); j++ {
|
||||
if j >= len(s) || lcp[j] != s[j] {
|
||||
lcp = lcp[0:j]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(lcp) < len(root) || !strings.HasPrefix(lcp, root) {
|
||||
return root
|
||||
}
|
||||
return lcp
|
||||
}
|
||||
|
||||
func ContainsStr(strs []string, test string) bool {
|
||||
for _, s := range strs {
|
||||
if s == test {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func IsPrefix(strs []string, test string) bool {
|
||||
for _, s := range strs {
|
||||
if len(s) > len(test) && strings.HasPrefix(s, test) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// sentinel value for StrWithPos.Pos to indicate no position
|
||||
const NoStrPos = -1
|
||||
|
||||
type StrWithPos struct {
|
||||
Str string `json:"str"`
|
||||
Pos int `json:"pos"` // this is a 'rune' position (not a byte position)
|
||||
}
|
||||
|
||||
func (sp StrWithPos) String() string {
|
||||
return strWithCursor(sp.Str, sp.Pos)
|
||||
}
|
||||
|
||||
func ParseToSP(s string) StrWithPos {
|
||||
idx := strings.Index(s, "[*]")
|
||||
if idx == -1 {
|
||||
return StrWithPos{Str: s, Pos: NoStrPos}
|
||||
}
|
||||
return StrWithPos{Str: s[0:idx] + s[idx+3:], Pos: utf8.RuneCountInString(s[0:idx])}
|
||||
}
|
||||
|
||||
func strWithCursor(str string, pos int) string {
|
||||
if pos == NoStrPos {
|
||||
return str
|
||||
}
|
||||
if pos < 0 {
|
||||
// invalid position
|
||||
return "[*]_" + str
|
||||
}
|
||||
if pos > len(str) {
|
||||
// invalid position
|
||||
return str + "_[*]"
|
||||
}
|
||||
if pos == len(str) {
|
||||
return str + "[*]"
|
||||
}
|
||||
var rtn []rune
|
||||
for _, ch := range str {
|
||||
if len(rtn) == pos {
|
||||
rtn = append(rtn, '[', '*', ']')
|
||||
}
|
||||
rtn = append(rtn, ch)
|
||||
}
|
||||
return string(rtn)
|
||||
}
|
||||
|
||||
func (sp StrWithPos) Prepend(str string) StrWithPos {
|
||||
return StrWithPos{Str: str + sp.Str, Pos: utf8.RuneCountInString(str) + sp.Pos}
|
||||
}
|
||||
|
||||
func (sp StrWithPos) Append(str string) StrWithPos {
|
||||
return StrWithPos{Str: sp.Str + str, Pos: sp.Pos}
|
||||
}
|
||||
|
||||
// returns base64 hash of data
|
||||
func Sha1Hash(data []byte) string {
|
||||
hvalRaw := sha1.Sum(data)
|
||||
hval := base64.StdEncoding.EncodeToString(hvalRaw[:])
|
||||
return hval
|
||||
}
|
||||
|
||||
func ChunkSlice[T any](s []T, chunkSize int) [][]T {
|
||||
var rtn [][]T
|
||||
for len(rtn) > 0 {
|
||||
if len(s) <= chunkSize {
|
||||
rtn = append(rtn, s)
|
||||
break
|
||||
}
|
||||
rtn = append(rtn, s[:chunkSize])
|
||||
s = s[chunkSize:]
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
var ErrOverflow = errors.New("integer overflow")
|
||||
|
||||
// Add two int values, returning an error if the result overflows.
|
||||
func AddInt(left, right int) (int, error) {
|
||||
if right > 0 {
|
||||
if left > math.MaxInt-right {
|
||||
return 0, ErrOverflow
|
||||
}
|
||||
} else {
|
||||
if left < math.MinInt-right {
|
||||
return 0, ErrOverflow
|
||||
}
|
||||
}
|
||||
return left + right, nil
|
||||
}
|
||||
|
||||
// Add a slice of ints, returning an error if the result overflows.
|
||||
func AddIntSlice(vals ...int) (int, error) {
|
||||
var rtn int
|
||||
for _, v := range vals {
|
||||
var err error
|
||||
rtn, err = AddInt(rtn, v)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
@ -1,103 +0,0 @@
|
||||
// Copyright 2023, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package utilfn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const Str1 = `
|
||||
hello
|
||||
line #2
|
||||
more
|
||||
stuff
|
||||
apple
|
||||
`
|
||||
|
||||
const Str2 = `
|
||||
line #2
|
||||
apple
|
||||
grapes
|
||||
banana
|
||||
`
|
||||
|
||||
const Str3 = `
|
||||
more
|
||||
stuff
|
||||
banana
|
||||
coconut
|
||||
`
|
||||
|
||||
func testDiff(t *testing.T, str1 string, str2 string) {
|
||||
diffBytes := MakeDiff(str1, str2)
|
||||
fmt.Printf("diff-len: %d\n", len(diffBytes))
|
||||
out, err := ApplyDiff(str1, diffBytes)
|
||||
if err != nil {
|
||||
t.Errorf("error in diff: %v", err)
|
||||
return
|
||||
}
|
||||
if out != str2 {
|
||||
t.Errorf("bad diff output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiff(t *testing.T) {
|
||||
testDiff(t, Str1, Str2)
|
||||
testDiff(t, Str2, Str3)
|
||||
testDiff(t, Str1, Str3)
|
||||
testDiff(t, Str3, Str1)
|
||||
}
|
||||
|
||||
func testArithmetic(t *testing.T, fn func() (int, error), shouldError bool, expected int) {
|
||||
retVal, err := fn()
|
||||
if err != nil {
|
||||
if !shouldError {
|
||||
t.Errorf("unexpected error")
|
||||
}
|
||||
return
|
||||
}
|
||||
if shouldError {
|
||||
t.Errorf("expected error")
|
||||
return
|
||||
}
|
||||
if retVal != expected {
|
||||
t.Errorf("wrong return value")
|
||||
}
|
||||
}
|
||||
|
||||
func testAddInt(t *testing.T, shouldError bool, expected int, a int, b int) {
|
||||
testArithmetic(t, func() (int, error) { return AddInt(a, b) }, shouldError, expected)
|
||||
}
|
||||
|
||||
func TestAddInt(t *testing.T) {
|
||||
testAddInt(t, false, 3, 1, 2)
|
||||
testAddInt(t, true, 0, 1, math.MaxInt)
|
||||
testAddInt(t, true, 0, math.MinInt, -1)
|
||||
testAddInt(t, false, math.MaxInt-1, math.MaxInt, -1)
|
||||
testAddInt(t, false, math.MinInt+1, math.MinInt, 1)
|
||||
testAddInt(t, false, math.MaxInt, math.MaxInt, 0)
|
||||
testAddInt(t, true, 0, math.MinInt, -1)
|
||||
}
|
||||
|
||||
func testAddIntSlice(t *testing.T, shouldError bool, expected int, vals ...int) {
|
||||
testArithmetic(t, func() (int, error) { return AddIntSlice(vals...) }, shouldError, expected)
|
||||
}
|
||||
|
||||
func TestAddIntSlice(t *testing.T) {
|
||||
testAddIntSlice(t, false, 0)
|
||||
testAddIntSlice(t, false, 1, 1)
|
||||
testAddIntSlice(t, false, 3, 1, 2)
|
||||
testAddIntSlice(t, false, 6, 1, 2, 3)
|
||||
testAddIntSlice(t, true, 0, 1, math.MaxInt)
|
||||
testAddIntSlice(t, true, 0, 1, 2, math.MaxInt)
|
||||
testAddIntSlice(t, true, 0, math.MaxInt, 2, 1)
|
||||
testAddIntSlice(t, false, math.MaxInt, 0, 0, math.MaxInt)
|
||||
testAddIntSlice(t, true, 0, math.MinInt, -1)
|
||||
testAddIntSlice(t, false, math.MaxInt, math.MaxInt-3, 1, 2)
|
||||
testAddIntSlice(t, true, 0, math.MaxInt-2, 1, 2)
|
||||
testAddIntSlice(t, false, math.MinInt, math.MinInt+3, -1, -2)
|
||||
testAddIntSlice(t, true, 0, math.MinInt+2, -1, -2)
|
||||
}
|
Loading…
Reference in New Issue
Block a user