zsh reinit fixes (#477)

* reset command now initiates and completes async so there is feedback that something is happening when it takes a long time

* switch from standard rpc to rpciter

* checkpoint on reinit -- stream output, stats packet, logging to cmd pty, new endBytes for EOF

* make generic versions of endbytes scanner and channel output funcs

* update bash to use more modern state parsing (tricks learned from zsh)

* verbose mode, fix stats output message

* add a diff when verbose mode is on
This commit is contained in:
Mike Sawka 2024-03-19 16:38:38 -07:00 committed by GitHub
parent accb74ae0f
commit 5616c9abbb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 427 additions and 170 deletions

View File

@ -602,6 +602,7 @@ type ShellStatePacketType struct {
ShellType string `json:"shelltype"`
RespId string `json:"respid,omitempty"`
State *ShellState `json:"state"`
Stats *ShellStateStats `json:"stats"`
Error string `json:"error,omitempty"`
}

View File

@ -19,6 +19,17 @@ import (
const ShellStatePackVersion = 0
const ShellStateDiffPackVersion = 0
type ShellStateStats struct {
Version string `json:"version"`
AliasCount int `json:"aliascount"`
EnvCount int `json:"envcount"`
VarCount int `json:"varcount"`
FuncCount int `json:"funccount"`
HashVal string `json:"hashval"`
OutputSize int64 `json:"outputsize"`
StateSize int64 `json:"statesize"`
}
type ShellState struct {
Version string `json:"version"` // [type] [semver]
Cwd string `json:"cwd,omitempty"`
@ -29,6 +40,10 @@ type ShellState struct {
HashVal string `json:"-"`
}
func (state ShellState) ApproximateSize() int64 {
return int64(len(state.Version) + len(state.Cwd) + len(state.ShellVars) + len(state.Aliases) + len(state.Funcs) + len(state.Error))
}
type ShellStateDiff struct {
Version string `json:"version"` // [type] [semver] (note this should *always* be set even if the same as base)
BaseHash string `json:"basehash"`

View File

@ -244,11 +244,10 @@ func (m *MServer) runCompGen(compPk *packet.CompGenPacketType) {
appendSlashes(comps)
}
m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": hasMore})
return
}
func (m *MServer) reinit(reqId string, shellType string) {
ssPk, err := shexec.MakeShellStatePacket(shellType)
ssPk, err := m.MakeShellStatePacket(reqId, shellType)
if err != nil {
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error creating init packet: %w", err))
return
@ -262,6 +261,32 @@ func (m *MServer) reinit(reqId string, shellType string) {
m.Sender.SendPacket(ssPk)
}
func (m *MServer) MakeShellStatePacket(reqId string, shellType string) (*packet.ShellStatePacketType, error) {
sapi, err := shellapi.MakeShellApi(shellType)
if err != nil {
return nil, err
}
rtnCh := make(chan shellapi.ShellStateOutput, 1)
go sapi.GetShellState(rtnCh)
for ssOutput := range rtnCh {
if ssOutput.Error != "" {
return nil, errors.New(ssOutput.Error)
}
if ssOutput.ShellState != nil {
rtn := packet.MakeShellStatePacket()
rtn.State = ssOutput.ShellState
rtn.Stats = ssOutput.Stats
return rtn, nil
}
if ssOutput.Output != nil {
dataPk := packet.MakeFileDataPacket(reqId)
dataPk.Data = ssOutput.Output
m.Sender.SendPacket(dataPk)
}
}
return nil, nil
}
func makeTemp(path string, mode fs.FileMode) (*os.File, error) {
dirName := filepath.Dir(path)
baseName := filepath.Base(path)

View File

@ -16,6 +16,7 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
"github.com/wavetermdev/waveterm/waveshell/pkg/statediff"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
)
const BaseBashOpts = `set +m; set +H; shopt -s extglob`
@ -48,7 +49,7 @@ func (b bashShellApi) GetShellType() string {
return packet.ShellType_bash
}
func (b bashShellApi) MakeExitTrap(fdNum int) string {
func (b bashShellApi) MakeExitTrap(fdNum int) (string, []byte) {
return MakeBashExitTrap(fdNum)
}
@ -79,29 +80,15 @@ func (b bashShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty
return MakeBashShExecCommand(cmdStr, rcFileName, usePty)
}
func (b bashShellApi) GetShellState() chan ShellStateOutput {
ch := make(chan ShellStateOutput, 1)
defer close(ch)
ssPk, err := GetBashShellState()
if err != nil {
ch <- ShellStateOutput{
Status: ShellStateOutputStatus_Done,
Error: err.Error(),
}
return ch
}
ch <- ShellStateOutput{
Status: ShellStateOutputStatus_Done,
ShellState: ssPk,
}
return ch
func (b bashShellApi) GetShellState(outCh chan ShellStateOutput) {
GetBashShellState(outCh)
}
func (b bashShellApi) GetBaseShellOpts() string {
return BaseBashOpts
}
func (b bashShellApi) ParseShellStateOutput(output []byte) (*packet.ShellState, error) {
func (b bashShellApi) ParseShellStateOutput(output []byte) (*packet.ShellState, *packet.ShellStateStats, error) {
return parseBashShellStateOutput(output)
}
@ -130,8 +117,32 @@ func (b bashShellApi) MakeRcFileStr(pk *packet.RunPacketType) string {
return rcBuf.String()
}
func GetBashShellStateCmd() string {
return strings.Join(GetBashShellStateCmds, ` printf "\x00\x00";`)
func GetBashShellStateCmd(fdNum int) (string, []byte) {
endBytes := utilfn.AppendNonZeroRandomBytes(nil, NumRandomEndBytes)
endBytes = append(endBytes, '\n')
cmdStr := strings.TrimSpace(`
exec 2> /dev/null;
exec > [%OUTPUTFD%];
printf "\x00\x00";
[%BASHVERSIONCMD%];
printf "\x00\x00";
pwd;
printf "\x00\x00";
declare -p $(compgen -A variable);
printf "\x00\x00";
alias -p;
printf "\x00\x00";
declare -f;
printf "\x00\x00";
[%GITBRANCHCMD%];
printf "\x00\x00";
printf "[%ENDBYTES%]";
`)
cmdStr = strings.ReplaceAll(cmdStr, "[%OUTPUTFD%]", fmt.Sprintf("/dev/fd/%d", fdNum))
cmdStr = strings.ReplaceAll(cmdStr, "[%BASHVERSIONCMD%]", BashShellVersionCmdStr)
cmdStr = strings.ReplaceAll(cmdStr, "[%GITBRANCHCMD%]", GetGitBranchCmdStr)
cmdStr = strings.ReplaceAll(cmdStr, "[%ENDBYTES%]", utilfn.ShellHexEscape(string(endBytes)))
return cmdStr, endBytes
}
func execGetLocalBashShellVersion() string {
@ -158,16 +169,34 @@ func GetLocalBashMajorVersion() string {
return localBashMajorVersion
}
func GetBashShellState() (*packet.ShellState, error) {
func GetBashShellState(outCh chan ShellStateOutput) {
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
defer cancelFn()
cmdStr := BaseBashOpts + "; " + GetBashShellStateCmd()
defer close(outCh)
stateCmd, endBytes := GetBashShellStateCmd(StateOutputFdNum)
cmdStr := BaseBashOpts + "; " + stateCmd
ecmd := exec.CommandContext(ctx, GetLocalBashPath(), "-l", "-i", "-c", cmdStr)
outputBytes, err := RunSimpleCmdInPty(ecmd)
if err != nil {
return nil, err
outputCh := make(chan []byte, 10)
var outputWg sync.WaitGroup
outputWg.Add(1)
go func() {
defer outputWg.Done()
for outputBytes := range outputCh {
outCh <- ShellStateOutput{Output: outputBytes}
}
return parseBashShellStateOutput(outputBytes)
}()
outputBytes, err := StreamCommandWithExtraFd(ecmd, outputCh, StateOutputFdNum, endBytes)
outputWg.Wait()
if err != nil {
outCh <- ShellStateOutput{Error: err.Error()}
return
}
rtn, stats, err := parseBashShellStateOutput(outputBytes)
if err != nil {
outCh <- ShellStateOutput{Error: err.Error()}
return
}
outCh <- ShellStateOutput{ShellState: rtn, Stats: stats}
}
func GetLocalBashPath() string {
@ -190,19 +219,20 @@ func GetLocalZshPath() string {
return "zsh"
}
func GetBashShellStateRedirectCommandStr(outputFdNum int) string {
return fmt.Sprintf("cat <(%s) > /dev/fd/%d", GetBashShellStateCmd(), outputFdNum)
func GetBashShellStateRedirectCommandStr(outputFdNum int) (string, []byte) {
cmdStr, endBytes := GetBashShellStateCmd(outputFdNum)
return cmdStr, endBytes
}
func MakeBashExitTrap(fdNum int) string {
stateCmd := GetBashShellStateRedirectCommandStr(fdNum)
func MakeBashExitTrap(fdNum int) (string, []byte) {
stateCmd, endBytes := GetBashShellStateRedirectCommandStr(fdNum)
fmtStr := `
_waveshell_exittrap () {
%s
}
trap _waveshell_exittrap EXIT
`
return fmt.Sprintf(fmtStr, stateCmd)
return fmt.Sprintf(fmtStr, stateCmd), endBytes
}
func MakeBashShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd {

View File

@ -20,6 +20,19 @@ import (
"mvdan.cc/sh/v3/syntax"
)
const (
BashSection_Ignored = iota
BashSection_Version
BashSection_Cwd
BashSection_Vars
BashSection_Aliases
BashSection_Funcs
BashSection_PVars
BashSection_EndBytes
BashSection_Count // must be last
)
type DeclareDeclType = shellenv.DeclareDeclType
func doCmdSubst(commandStr string, w io.Writer, word *syntax.CmdSubst) error {
@ -214,38 +227,37 @@ func bashParseDeclareOutput(state *packet.ShellState, declareBytes []byte, pvarB
return nil
}
func parseBashShellStateOutput(outputBytes []byte) (*packet.ShellState, error) {
func parseBashShellStateOutput(outputBytes []byte) (*packet.ShellState, *packet.ShellStateStats, error) {
if scbase.IsDevMode() && DebugState {
writeStateToFile(packet.ShellType_bash, outputBytes)
}
// 7 fields: ignored [0], version [1], cwd [2], env/vars [3], aliases [4], funcs [5], pvars [6]
fields := bytes.Split(outputBytes, []byte{0, 0})
if len(fields) != 7 {
return nil, fmt.Errorf("invalid bash shell state output, wrong number of fields, fields=%d", len(fields))
sections := bytes.Split(outputBytes, []byte{0, 0})
if len(sections) != BashSection_Count {
return nil, nil, fmt.Errorf("invalid bash shell state output, wrong number of fields, fields=%d", len(sections))
}
rtn := &packet.ShellState{}
rtn.Version = strings.TrimSpace(string(fields[1]))
rtn.Version = strings.TrimSpace(string(sections[BashSection_Version]))
if rtn.GetShellType() != packet.ShellType_bash {
return nil, fmt.Errorf("invalid bash shell state output, wrong shell type: %q", rtn.Version)
return nil, nil, fmt.Errorf("invalid bash shell state output, wrong shell type: %q", rtn.Version)
}
if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil {
return nil, fmt.Errorf("invalid bash shell state output, invalid version: %v", err)
return nil, nil, fmt.Errorf("invalid bash shell state output, invalid version: %v", err)
}
cwdStr := string(fields[2])
cwdStr := string(sections[BashSection_Cwd])
if strings.HasSuffix(cwdStr, "\r\n") {
cwdStr = cwdStr[0 : len(cwdStr)-2]
} else if strings.HasSuffix(cwdStr, "\n") {
cwdStr = cwdStr[0 : len(cwdStr)-1]
} else {
cwdStr = strings.TrimSuffix(cwdStr, "\n")
}
rtn.Cwd = string(cwdStr)
err := bashParseDeclareOutput(rtn, fields[3], fields[6])
err := bashParseDeclareOutput(rtn, sections[BashSection_Vars], sections[BashSection_PVars])
if err != nil {
return nil, err
return nil, nil, err
}
rtn.Aliases = strings.ReplaceAll(string(fields[4]), "\r\n", "\n")
rtn.Funcs = strings.ReplaceAll(string(fields[5]), "\r\n", "\n")
rtn.Aliases = strings.ReplaceAll(string(sections[BashSection_Aliases]), "\r\n", "\n")
rtn.Funcs = strings.ReplaceAll(string(sections[BashSection_Funcs]), "\r\n", "\n")
rtn.Funcs = shellenv.RemoveFunc(rtn.Funcs, "_waveshell_exittrap")
return rtn, nil
return rtn, nil, nil
}
func bashNormalize(d *DeclareDeclType) error {

View File

@ -7,7 +7,6 @@ import (
"bytes"
"context"
"fmt"
"io"
"os"
"os/exec"
"os/user"
@ -28,12 +27,14 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
)
const GetStateTimeout = 5 * time.Second
const GetStateTimeout = 15 * time.Second
const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"`
const GetK8sContextCmdStr = `printf "K8SCONTEXT %s\x00" "$(kubectl config current-context 2>/dev/null)"`
const GetK8sNamespaceCmdStr = `printf "K8SNAMESPACE %s\x00" "$(kubectl config view --minify --output 'jsonpath={..namespace}' 2>/dev/null)"`
const RunCommandFmt = `%s`
const DebugState = false
const StateOutputFdNum = 20
const NumRandomEndBytes = 8
var userShellRegexp = regexp.MustCompile(`^UserShell: (.*)$`)
@ -56,23 +57,23 @@ const (
)
type ShellStateOutput struct {
Status string
StderrOutput []byte
Output []byte
ShellState *packet.ShellState
Stats *packet.ShellStateStats
Error string
}
type ShellApi interface {
GetShellType() string
MakeExitTrap(fdNum int) string
MakeExitTrap(fdNum int) (string, []byte)
GetLocalMajorVersion() string
GetLocalShellPath() string
GetRemoteShellPath() string
MakeRunCommand(cmdStr string, opts RunCommandOpts) string
MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd
GetShellState() chan ShellStateOutput
GetShellState(chan ShellStateOutput)
GetBaseShellOpts() string
ParseShellStateOutput(output []byte) (*packet.ShellState, error)
ParseShellStateOutput(output []byte) (*packet.ShellState, *packet.ShellStateStats, error)
MakeRcFileStr(pk *packet.RunPacketType) string
MakeShellStateDiff(oldState *packet.ShellState, oldStateHash string, newState *packet.ShellState) (*packet.ShellStateDiff, error)
ApplyShellStateDiff(oldState *packet.ShellState, diff *packet.ShellStateDiff) (*packet.ShellState, error)
@ -153,12 +154,13 @@ func internalMacUserShell() string {
const FirstExtraFilesFdNum = 3
// returns output(stdout+stderr), extraFdOutput, error
func RunCommandWithExtraFd(ecmd *exec.Cmd, extraFdNum int) ([]byte, []byte, error) {
func StreamCommandWithExtraFd(ecmd *exec.Cmd, outputCh chan []byte, extraFdNum int, endBytes []byte) ([]byte, error) {
defer close(outputCh)
ecmd.Env = os.Environ()
shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType))
cmdPty, cmdTty, err := pty.Open()
if err != nil {
return nil, nil, fmt.Errorf("opening new pty: %w", err)
return nil, fmt.Errorf("opening new pty: %w", err)
}
defer cmdTty.Close()
defer cmdPty.Close()
@ -171,42 +173,44 @@ func RunCommandWithExtraFd(ecmd *exec.Cmd, extraFdNum int) ([]byte, []byte, erro
ecmd.SysProcAttr.Setctty = true
pipeReader, pipeWriter, err := os.Pipe()
if err != nil {
return nil, nil, fmt.Errorf("could not create pipe: %w", err)
return nil, fmt.Errorf("could not create pipe: %w", err)
}
defer pipeWriter.Close()
defer pipeReader.Close()
extraFiles := make([]*os.File, extraFdNum+1)
extraFiles[extraFdNum] = pipeWriter
ecmd.ExtraFiles = extraFiles[FirstExtraFilesFdNum:]
defer pipeReader.Close()
ecmd.Start()
err = ecmd.Start()
cmdTty.Close()
pipeWriter.Close()
if err != nil {
return nil, nil, err
return nil, err
}
var outputWg sync.WaitGroup
var outputBuf bytes.Buffer
var extraFdOutputBuf bytes.Buffer
outputWg.Add(2)
go func() {
// ignore error (/dev/ptmx has read error when process is done)
defer outputWg.Done()
io.Copy(&outputBuf, cmdPty)
err := utilfn.CopyToChannel(outputCh, cmdPty)
if err != nil {
errStr := fmt.Sprintf("\r\nerror reading from pty: %v\r\n", err)
outputCh <- []byte(errStr)
}
}()
go func() {
defer outputWg.Done()
io.Copy(&extraFdOutputBuf, pipeReader)
utilfn.CopyWithEndBytes(&extraFdOutputBuf, pipeReader, endBytes)
}()
exitErr := ecmd.Wait()
if exitErr != nil {
return nil, nil, exitErr
return nil, exitErr
}
outputWg.Wait()
return outputBuf.Bytes(), extraFdOutputBuf.Bytes(), nil
return extraFdOutputBuf.Bytes(), nil
}
func RunSimpleCmdInPty(ecmd *exec.Cmd) ([]byte, error) {
func RunSimpleCmdInPty(ecmd *exec.Cmd, endBytes []byte) ([]byte, error) {
ecmd.Env = os.Environ()
shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType))
cmdPty, cmdTty, err := pty.Open()
@ -231,8 +235,8 @@ func RunSimpleCmdInPty(ecmd *exec.Cmd) ([]byte, error) {
var outputBuf bytes.Buffer
go func() {
// ignore error (/dev/ptmx has read error when process is done)
io.Copy(&outputBuf, cmdPty)
close(ioDone)
defer close(ioDone)
utilfn.CopyWithEndBytes(&outputBuf, cmdPty, endBytes)
}()
exitErr := ecmd.Wait()
if exitErr != nil {

View File

@ -8,7 +8,6 @@ import (
"context"
"errors"
"fmt"
"math/rand"
"os/exec"
"strings"
"sync"
@ -28,7 +27,6 @@ import (
const BaseZshOpts = ``
const ZshShellVersionCmdStr = `echo zsh v$ZSH_VERSION`
const StateOutputFdNum = 20
const (
ZshSection_Version = iota
@ -41,6 +39,7 @@ const (
ZshSection_Funcs
ZshSection_PVars
ZshSection_Prompt
ZshSection_EndBytes
ZshSection_NumFieldsExpected // must be last
)
@ -118,6 +117,11 @@ var ZshIgnoreVars = map[string]bool{
"zcurses_windows": true,
// not listed, but we also exclude all ZFTP_* variables
// powerlevel10k
"_GITSTATUS_CLIENT_PID_POWERLEVEL9K": true,
"GITSTATUS_DAEMON_PID_POWERLEVEL9K": true,
"_GITSTATUS_FILE_PREFIX_POWERLEVEL9K": true,
}
var ZshIgnoreFuncs = map[string]bool{
@ -211,7 +215,7 @@ func (z zshShellApi) GetShellType() string {
return packet.ShellType_zsh
}
func (z zshShellApi) MakeExitTrap(fdNum int) string {
func (z zshShellApi) MakeExitTrap(fdNum int) (string, []byte) {
return MakeZshExitTrap(fdNum)
}
@ -242,25 +246,34 @@ func (z zshShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty
return exec.Command(GetLocalZshPath(), "-l", "-i", "-c", cmdStr)
}
func (z zshShellApi) GetShellState() chan ShellStateOutput {
func (z zshShellApi) GetShellState(outCh chan ShellStateOutput) {
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
defer cancelFn()
rtnCh := make(chan ShellStateOutput, 1)
defer close(rtnCh)
cmdStr := BaseZshOpts + "; " + GetZshShellStateCmd(StateOutputFdNum)
defer close(outCh)
stateCmd, endBytes := GetZshShellStateCmd(StateOutputFdNum)
cmdStr := BaseZshOpts + "; " + stateCmd
ecmd := exec.CommandContext(ctx, GetLocalZshPath(), "-l", "-i", "-c", cmdStr)
_, outputBytes, err := RunCommandWithExtraFd(ecmd, StateOutputFdNum)
if err != nil {
rtnCh <- ShellStateOutput{Status: ShellStateOutputStatus_Done, Error: err.Error()}
return rtnCh
outputCh := make(chan []byte, 10)
var outputWg sync.WaitGroup
outputWg.Add(1)
go func() {
defer outputWg.Done()
for outputBytes := range outputCh {
outCh <- ShellStateOutput{Output: outputBytes}
}
rtn, err := z.ParseShellStateOutput(outputBytes)
}()
outputBytes, err := StreamCommandWithExtraFd(ecmd, outputCh, StateOutputFdNum, endBytes)
outputWg.Wait()
if err != nil {
rtnCh <- ShellStateOutput{Status: ShellStateOutputStatus_Done, Error: err.Error()}
return rtnCh
outCh <- ShellStateOutput{Error: err.Error()}
return
}
rtnCh <- ShellStateOutput{Status: ShellStateOutputStatus_Done, ShellState: rtn}
return rtnCh
rtn, stats, err := z.ParseShellStateOutput(outputBytes)
if err != nil {
outCh <- ShellStateOutput{Error: err.Error()}
return
}
outCh <- ShellStateOutput{ShellState: rtn, Stats: stats}
}
func (z zshShellApi) GetBaseShellOpts() string {
@ -437,19 +450,15 @@ func writeZshId(buf *bytes.Buffer, idStr string) {
const numRandomBytes = 4
// returns (cmd-string)
func GetZshShellStateCmd(fdNum int) string {
// returns (cmd-string, endbytes)
func GetZshShellStateCmd(fdNum int) (string, []byte) {
var sectionSeparator []byte
// adding this extra "\n" helps with debuging and readability of output
sectionSeparator = append(sectionSeparator, byte('\n'))
for len(sectionSeparator) < numRandomBytes {
// any character *except* null (0)
rn := rand.Intn(256)
if rn > 0 && rn < 256 { // exclude 0, also helps to suppress security warning to have a guard here
sectionSeparator = append(sectionSeparator, byte(rn))
}
}
sectionSeparator = utilfn.AppendNonZeroRandomBytes(sectionSeparator, numRandomBytes)
sectionSeparator = append(sectionSeparator, 0, 0)
endBytes := utilfn.AppendNonZeroRandomBytes(nil, NumRandomEndBytes)
endBytes = append(endBytes, byte('\n'))
// we have to use these crazy separators because zsh allows basically anything in
// variable names and values (including nulls).
// note that we don't need crazy separators for "env" or "typeset".
@ -511,6 +520,8 @@ printf "[%SECTIONSEP%]";
[%K8SNAMESPACE%]
printf "[%SECTIONSEP%]";
print -P "$PS1"
printf "[%SECTIONSEP%]";
printf "[%ENDBYTES%]"
`
cmd = strings.TrimSpace(cmd)
cmd = strings.ReplaceAll(cmd, "[%ZSHVERSION%]", ZshShellVersionCmdStr)
@ -520,17 +531,19 @@ print -P "$PS1"
cmd = strings.ReplaceAll(cmd, "[%PARTSEP%]", utilfn.ShellHexEscape(string(sectionSeparator[0:len(sectionSeparator)-1])))
cmd = strings.ReplaceAll(cmd, "[%SECTIONSEP%]", utilfn.ShellHexEscape(string(sectionSeparator)))
cmd = strings.ReplaceAll(cmd, "[%OUTPUTFD%]", fmt.Sprintf("/dev/fd/%d", fdNum))
return cmd
cmd = strings.ReplaceAll(cmd, "[%OUTPUTFDNUM%]", fmt.Sprintf("%d", fdNum))
cmd = strings.ReplaceAll(cmd, "[%ENDBYTES%]", utilfn.ShellHexEscape(string(endBytes)))
return cmd, endBytes
}
func MakeZshExitTrap(fdNum int) string {
stateCmd := GetZshShellStateCmd(fdNum)
func MakeZshExitTrap(fdNum int) (string, []byte) {
stateCmd, endBytes := GetZshShellStateCmd(fdNum)
fmtStr := `
zshexit () {
%s
}
`
return fmt.Sprintf(fmtStr, stateCmd)
return fmt.Sprintf(fmtStr, stateCmd), endBytes
}
func execGetLocalZshShellVersion() string {
@ -698,14 +711,14 @@ func makeZshFuncsStrForShellState(fnMap map[ZshParamKey]string) string {
return buf.String()
}
func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, error) {
func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, *packet.ShellStateStats, error) {
if scbase.IsDevMode() && DebugState {
writeStateToFile(packet.ShellType_zsh, outputBytes)
}
firstZeroIdx := bytes.Index(outputBytes, []byte{0})
firstDZeroIdx := bytes.Index(outputBytes, []byte{0, 0})
if firstZeroIdx == -1 || firstDZeroIdx == -1 {
return nil, fmt.Errorf("invalid zsh shell state output, could not parse separator bytes")
return nil, nil, fmt.Errorf("invalid zsh shell state output, could not parse separator bytes")
}
versionStr := string(outputBytes[0:firstZeroIdx])
sectionSeparator := outputBytes[firstZeroIdx+1 : firstDZeroIdx+2]
@ -714,15 +727,15 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta
sections := bytes.Split(outputBytes, sectionSeparator)
if len(sections) != ZshSection_NumFieldsExpected {
base.Logf("invalid -- numfields\n")
return nil, fmt.Errorf("invalid zsh shell state output, wrong number of sections, section=%d", len(sections))
return nil, nil, fmt.Errorf("invalid zsh shell state output, wrong number of sections, section=%d", len(sections))
}
rtn := &packet.ShellState{}
rtn.Version = strings.TrimSpace(versionStr)
if rtn.GetShellType() != packet.ShellType_zsh {
return nil, fmt.Errorf("invalid zsh shell state output, wrong shell type")
return nil, nil, fmt.Errorf("invalid zsh shell state output, wrong shell type")
}
if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil {
return nil, fmt.Errorf("invalid zsh shell state output, invalid version: %v", err)
return nil, nil, fmt.Errorf("invalid zsh shell state output, invalid version: %v", err)
}
cwdStr := stripNewLineChars(string(sections[ZshSection_Cwd]))
rtn.Cwd = cwdStr
@ -730,7 +743,7 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta
zshDecls, err := parseZshDecls(sections[ZshSection_Vars])
if err != nil {
base.Logf("invalid - parsedecls %v\n", err)
return nil, err
return nil, nil, err
}
for _, decl := range zshDecls {
if decl.IsZshScalarBound() {
@ -746,7 +759,17 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta
pvarMap := parseExtVarOutput(sections[ZshSection_PVars], string(sections[ZshSection_Prompt]), string(sections[ZshSection_Mods]))
utilfn.CombineMaps(zshDecls, pvarMap)
rtn.ShellVars = shellenv.SerializeDeclMap(zshDecls)
return rtn, nil
stats := &packet.ShellStateStats{
Version: rtn.Version,
AliasCount: int(len(aliasMap)),
FuncCount: int(len(zshFuncs)),
VarCount: int(len(zshDecls)),
EnvCount: int(len(zshEnv)),
HashVal: rtn.GetHashVal(false),
OutputSize: int64(len(outputBytes)),
StateSize: rtn.ApproximateSize(),
}
return rtn, stats, nil
}
func parseZshEnv(output []byte) map[string]string {

View File

@ -7,7 +7,6 @@ import (
"bytes"
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"os"
@ -95,6 +94,7 @@ type ReturnStateBuf struct {
Err error
Reader *os.File
FdNum int
EndBytes []byte
DoneCh chan bool
}
@ -835,7 +835,8 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
cmd.ReturnState.FdNum = RtnStateFdNum
rtnStateWriter = pw
defer pw.Close()
trapCmdStr := sapi.MakeExitTrap(cmd.ReturnState.FdNum)
trapCmdStr, endBytes := sapi.MakeExitTrap(cmd.ReturnState.FdNum)
cmd.ReturnState.EndBytes = endBytes
rcFileStr += trapCmdStr
}
shellVarMap := shellenv.ShellVarMapFromState(state)
@ -1021,6 +1022,11 @@ func (rs *ReturnStateBuf) Run() {
}
rs.Lock.Lock()
rs.Buf = append(rs.Buf, buf[0:n]...)
if bytes.HasSuffix(rs.Buf, rs.EndBytes) {
rs.Buf = rs.Buf[:len(rs.Buf)-len(rs.EndBytes)]
rs.Lock.Unlock()
break
}
rs.Lock.Unlock()
}
}
@ -1127,7 +1133,7 @@ func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType {
wlog.Logf("debug returnstate file %q\n", base.GetDebugReturnStateFileName())
os.WriteFile(base.GetDebugReturnStateFileName(), c.ReturnState.Buf, 0666)
}
state, _ := c.SAPI.ParseShellStateOutput(c.ReturnState.Buf) // TODO what to do with error?
state, _, _ := c.SAPI.ParseShellStateOutput(c.ReturnState.Buf) // TODO what to do with error?
donePacket.FinalState = state
}
endTs := time.Now()
@ -1156,21 +1162,6 @@ func MakeInitPacket() *packet.InitPacketType {
return initPacket
}
func MakeShellStatePacket(shellType string) (*packet.ShellStatePacketType, error) {
sapi, err := shellapi.MakeShellApi(shellType)
if err != nil {
return nil, err
}
rtnCh := sapi.GetShellState()
ssOutput := <-rtnCh
if ssOutput.Error != "" {
return nil, errors.New(ssOutput.Error)
}
rtn := packet.MakeShellStatePacket()
rtn.State = ssOutput.ShellState
return rtn, nil
}
func MakeServerInitPacket() (*packet.InitPacketType, error) {
var err error
initPacket := MakeInitPacket()

View File

@ -10,7 +10,9 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"math"
mathrand "math/rand"
"regexp"
"sort"
"strings"
@ -552,3 +554,60 @@ func StrArrayToMap(sarr []string) map[string]bool {
}
return m
}
func AppendNonZeroRandomBytes(b []byte, randLen int) []byte {
if randLen <= 0 {
return b
}
numAdded := 0
for numAdded < randLen {
rn := mathrand.Intn(256)
if rn > 0 && rn < 256 { // exclude 0, also helps to suppress security warning to have a guard here
b = append(b, byte(rn))
numAdded++
}
}
return b
}
// returns (isEOF, error)
func CopyWithEndBytes(outputBuf *bytes.Buffer, reader io.Reader, endBytes []byte) (bool, error) {
buf := make([]byte, 4096)
for {
n, err := reader.Read(buf)
if n > 0 {
outputBuf.Write(buf[:n])
obytes := outputBuf.Bytes()
if bytes.HasSuffix(obytes, endBytes) {
outputBuf.Truncate(len(obytes) - len(endBytes))
return (err == io.EOF), nil
}
}
if err == io.EOF {
return true, nil
}
if err != nil {
return false, err
}
}
}
// does *not* close outputCh on EOF or error
func CopyToChannel(outputCh chan<- []byte, reader io.Reader) error {
buf := make([]byte, 4096)
for {
n, err := reader.Read(buf)
if n > 0 {
// copy so client can use []byte without it being overwritten
bufCopy := make([]byte, n)
copy(bufCopy, buf[:n])
outputCh <- bufCopy
}
if err == io.EOF {
return nil
}
if err != nil {
return err
}
}
}

View File

@ -41,6 +41,7 @@ import (
"github.com/wavetermdev/waveterm/wavesrv/pkg/releasechecker"
"github.com/wavetermdev/waveterm/wavesrv/pkg/remote"
"github.com/wavetermdev/waveterm/wavesrv/pkg/remote/openai"
"github.com/wavetermdev/waveterm/wavesrv/pkg/rtnstate"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbus"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
@ -1648,8 +1649,12 @@ func CopyFileCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scb
}
var outputPos int64
outputStr := fmt.Sprintf("Copying [%v]:%v to [%v]:%v\r\n", sourceRemoteId.DisplayName, sourceFullPath, destRemoteId.DisplayName, destFullPath)
termopts := sstore.TermOpts{Rows: shellutil.DefaultTermRows, Cols: shellutil.DefaultTermCols, FlexRows: true, MaxPtySize: remote.DefaultMaxPtySize}
cmd, err := makeDynCmd(ctx, "copy file", ids, pk.GetRawStr(), termopts)
termOpts, err := GetUITermOpts(pk.UIContext.WinSize, DefaultPTERM)
if err != nil {
return nil, fmt.Errorf("cannot make termopts: %w", err)
}
pkTermOpts := convertTermOpts(termOpts)
cmd, err := makeDynCmd(ctx, "copy file", ids, pk.GetRawStr(), *pkTermOpts)
writeStringToPty(ctx, cmd, outputStr, &outputPos)
if err != nil {
// TODO tricky error since the command was a success, but we can't show the output
@ -3655,11 +3660,14 @@ func SessionCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbu
return update, nil
}
func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbus.UpdatePacket, error) {
func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (rtnUpdate scbus.UpdatePacket, rtnErr error) {
ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen|R_Remote)
if err != nil {
return nil, err
}
if !ids.Remote.MShell.IsConnected() {
return nil, fmt.Errorf("cannot reinit, remote is not connected")
}
shellType := ids.Remote.ShellType
if pk.Kwargs["shell"] != "" {
shellArg := pk.Kwargs["shell"]
@ -3668,33 +3676,76 @@ func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (
}
shellType = shellArg
}
ssPk, err := ids.Remote.MShell.ReInit(ctx, shellType)
verbose := resolveBool(pk.Kwargs["verbose"], false)
termOpts, err := GetUITermOpts(pk.UIContext.WinSize, DefaultPTERM)
if err != nil {
return nil, err
return nil, fmt.Errorf("cannot make termopts: %w", err)
}
if ssPk == nil || ssPk.State == nil {
return nil, fmt.Errorf("invalid initpk received from remote (no remote state)")
}
feState := sstore.FeStateFromShellState(ssPk.State)
remoteInst, err := sstore.UpdateRemoteState(ctx, ids.SessionId, ids.ScreenId, ids.Remote.RemotePtr, feState, ssPk.State, nil)
pkTermOpts := convertTermOpts(termOpts)
cmd, err := makeDynCmd(ctx, "reset", ids, pk.GetRawStr(), *pkTermOpts)
if err != nil {
return nil, err
}
outputStr := fmt.Sprintf("reset remote state (shell:%s)", ssPk.State.GetShellType())
cmd, err := makeStaticCmd(ctx, "reset", ids, pk.GetRawStr(), []byte(outputStr))
if err != nil {
// TODO tricky error since the command was a success, but we can't show the output
return nil, err
}
update, err := addLineForCmd(ctx, "/reset", false, ids, cmd, "", nil)
if err != nil {
// TODO tricky error since the command was a success, but we can't show the output
return nil, err
}
update.AddUpdate(sstore.MakeSessionUpdateForRemote(ids.SessionId, remoteInst), sstore.InteractiveUpdate(pk.Interactive))
go doResetCommand(ids, shellType, cmd, verbose)
return update, nil
}
func doResetCommand(ids resolvedIds, shellType string, cmd *sstore.CmdType, verbose bool) {
ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second)
defer cancelFn()
startTime := time.Now()
var outputPos int64
var rtnErr error
exitSuccess := true
defer func() {
if rtnErr != nil {
exitSuccess = false
writeStringToPty(ctx, cmd, fmt.Sprintf("\r\nerror: %v", rtnErr), &outputPos)
}
deferWriteCmdStatus(ctx, cmd, startTime, exitSuccess, outputPos)
}()
dataFn := func(data []byte) {
writeStringToPty(ctx, cmd, string(data), &outputPos)
}
origStatePtr := ids.Remote.MShell.GetDefaultStatePtr(shellType)
ssPk, err := ids.Remote.MShell.ReInit(ctx, shellType, dataFn, verbose)
if err != nil {
rtnErr = err
return
}
if ssPk == nil || ssPk.State == nil {
rtnErr = fmt.Errorf("invalid initpk received from remote (no remote state)")
return
}
feState := sstore.FeStateFromShellState(ssPk.State)
remoteInst, err := sstore.UpdateRemoteState(ctx, ids.SessionId, ids.ScreenId, ids.Remote.RemotePtr, feState, ssPk.State, nil)
if err != nil {
rtnErr = err
return
}
newStatePtr := ids.Remote.MShell.GetDefaultStatePtr(shellType)
if verbose && origStatePtr != nil && newStatePtr != nil {
statePtrDiff := fmt.Sprintf("oldstate: %v, newstate: %v\r\n", origStatePtr.BaseHash, newStatePtr.BaseHash)
writeStringToPty(ctx, cmd, statePtrDiff, &outputPos)
origFullState, _ := sstore.GetFullState(ctx, *origStatePtr)
newFullState, _ := sstore.GetFullState(ctx, *newStatePtr)
if origFullState != nil && newFullState != nil {
var diffBuf bytes.Buffer
rtnstate.DisplayStateUpdateDiff(&diffBuf, *origFullState, *newFullState)
diffStr := diffBuf.String()
diffStr = strings.ReplaceAll(diffStr, "\n", "\r\n")
writeStringToPty(ctx, cmd, diffStr, &outputPos)
}
}
update := scbus.MakeUpdatePacket()
update.AddUpdate(sstore.MakeSessionUpdateForRemote(ids.SessionId, remoteInst))
scbus.MainUpdateBus.DoUpdate(update)
}
func ResetCwdCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbus.UpdatePacket, error) {
ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen|R_Remote)
if err != nil {

View File

@ -196,7 +196,7 @@ func (msh *MShellProc) EnsureShellType(ctx context.Context, shellType string) er
return nil
}
// try to reinit the shell
_, err := msh.ReInit(ctx, shellType)
_, err := msh.ReInit(ctx, shellType, nil, false)
if err != nil {
return fmt.Errorf("error trying to initialize shell %q: %v", shellType, err)
}
@ -1401,33 +1401,60 @@ func makeReinitErrorUpdate(shellType string) sstore.ActivityUpdate {
return rtn
}
func (msh *MShellProc) ReInit(ctx context.Context, shellType string) (*packet.ShellStatePacketType, error) {
func (msh *MShellProc) ReInit(ctx context.Context, shellType string, dataFn func([]byte), verbose bool) (rtnPk *packet.ShellStatePacketType, rtnErr error) {
if !msh.IsConnected() {
return nil, fmt.Errorf("cannot reinit, remote is not connected")
}
if shellType != packet.ShellType_bash && shellType != packet.ShellType_zsh {
return nil, fmt.Errorf("invalid shell type %q", shellType)
}
if dataFn == nil {
dataFn = func([]byte) {}
}
defer func() {
if rtnErr != nil {
sstore.UpdateActivityWrap(ctx, makeReinitErrorUpdate(shellType), "reiniterror")
}
}()
startTs := time.Now()
reinitPk := packet.MakeReInitPacket()
reinitPk.ReqId = uuid.New().String()
reinitPk.ShellType = shellType
resp, err := msh.PacketRpcRaw(ctx, reinitPk)
rpcIter, err := msh.PacketRpcIter(ctx, reinitPk)
if err != nil {
return nil, err
}
defer rpcIter.Close()
var ssPk *packet.ShellStatePacketType
for {
resp, err := rpcIter.Next(ctx)
if err != nil {
return nil, err
}
if resp == nil {
return nil, fmt.Errorf("no response")
return nil, fmt.Errorf("channel closed with no response")
}
ssPk, ok := resp.(*packet.ShellStatePacketType)
if !ok {
sstore.UpdateActivityWrap(ctx, makeReinitErrorUpdate(shellType), "reiniterror")
if respPk, ok := resp.(*packet.ResponsePacketType); ok && respPk.Error != "" {
var ok bool
ssPk, ok = resp.(*packet.ShellStatePacketType)
if ok {
break
}
respPk, ok := resp.(*packet.ResponsePacketType)
if ok {
if respPk.Error != "" {
return nil, fmt.Errorf("error reinitializing remote: %s", respPk.Error)
}
return nil, fmt.Errorf("invalid reinit response (not an shellstate packet): %T", resp)
return nil, fmt.Errorf("invalid response from waveshell")
}
if ssPk.State == nil {
sstore.UpdateActivityWrap(ctx, makeReinitErrorUpdate(shellType), "reiniterror")
dataPk, ok := resp.(*packet.FileDataPacketType)
if ok {
dataFn(dataPk.Data)
continue
}
invalidPkStr := fmt.Sprintf("\r\ninvalid packettype from waveshell: %s\r\n", resp.GetType())
dataFn([]byte(invalidPkStr))
}
if ssPk == nil || ssPk.State == nil {
return nil, fmt.Errorf("invalid reinit response shellstate packet does not contain remote state")
}
// TODO: maybe we don't need to save statebase here. should be possible to save it on demand
@ -1438,10 +1465,29 @@ func (msh *MShellProc) ReInit(ctx context.Context, shellType string) (*packet.Sh
return nil, fmt.Errorf("error storing remote state: %w", err)
}
msh.StateMap.SetCurrentState(ssPk.State.GetShellType(), ssPk.State)
msh.WriteToPtyBuffer("initialized shell:%s state:%s\n", shellType, ssPk.State.GetHashVal(false))
timeDur := time.Since(startTs)
dataFn([]byte(makeShellInitOutputMsg(verbose, ssPk.State, ssPk.Stats, timeDur, false)))
msh.WriteToPtyBuffer("%s", makeShellInitOutputMsg(false, ssPk.State, ssPk.Stats, timeDur, true))
return ssPk, nil
}
func makeShellInitOutputMsg(verbose bool, state *packet.ShellState, stats *packet.ShellStateStats, dur time.Duration, ptyMsg bool) string {
if !verbose || ptyMsg {
if ptyMsg {
return fmt.Sprintf("initialized state shell:%s statehash:%s %dms\n", state.GetShellType(), state.GetHashVal(false), dur.Milliseconds())
} else {
return fmt.Sprintf("initialized connection state (shell:%s)\r\n", state.GetShellType())
}
}
var buf bytes.Buffer
buf.WriteString("-----\r\n")
buf.WriteString(fmt.Sprintf("initialized connection shell:%s statehash:%s %dms\r\n", state.GetShellType(), state.GetHashVal(false), dur.Milliseconds()))
if stats != nil {
buf.WriteString(fmt.Sprintf(" outsize:%s size:%s env:%d, vars:%d, aliases:%d, funcs:%d\r\n", scbase.NumFormatDec(stats.OutputSize), scbase.NumFormatDec(stats.StateSize), stats.EnvCount, stats.VarCount, stats.AliasCount, stats.FuncCount))
}
return buf.String()
}
func (msh *MShellProc) WriteFile(ctx context.Context, writePk *packet.WriteFilePacketType) (*packet.RpcResponseIter, error) {
return msh.PacketRpcIter(ctx, writePk)
}
@ -1690,7 +1736,7 @@ func (msh *MShellProc) initActiveShells() {
return
}
for _, shellType := range activeShells {
_, err = msh.ReInit(ctx, shellType)
_, err = msh.ReInit(ctx, shellType, nil, false)
if err != nil {
msh.WriteToPtyBuffer("*error reiniting shell %q: %v\n", shellType, err)
}