diff --git a/waveshell/pkg/packet/packet.go b/waveshell/pkg/packet/packet.go index 6419a1d1e..d204075eb 100644 --- a/waveshell/pkg/packet/packet.go +++ b/waveshell/pkg/packet/packet.go @@ -598,11 +598,12 @@ func MakeLogPacket(entry wlog.LogEntry) *LogPacketType { } type ShellStatePacketType struct { - Type string `json:"type"` - ShellType string `json:"shelltype"` - RespId string `json:"respid,omitempty"` - State *ShellState `json:"state"` - Error string `json:"error,omitempty"` + Type string `json:"type"` + ShellType string `json:"shelltype"` + RespId string `json:"respid,omitempty"` + State *ShellState `json:"state"` + Stats *ShellStateStats `json:"stats"` + Error string `json:"error,omitempty"` } func (*ShellStatePacketType) GetType() string { diff --git a/waveshell/pkg/packet/shellstate.go b/waveshell/pkg/packet/shellstate.go index ff3ef76fe..af4b04446 100644 --- a/waveshell/pkg/packet/shellstate.go +++ b/waveshell/pkg/packet/shellstate.go @@ -19,6 +19,17 @@ import ( const ShellStatePackVersion = 0 const ShellStateDiffPackVersion = 0 +type ShellStateStats struct { + Version string `json:"version"` + AliasCount int `json:"aliascount"` + EnvCount int `json:"envcount"` + VarCount int `json:"varcount"` + FuncCount int `json:"funccount"` + HashVal string `json:"hashval"` + OutputSize int64 `json:"outputsize"` + StateSize int64 `json:"statesize"` +} + type ShellState struct { Version string `json:"version"` // [type] [semver] Cwd string `json:"cwd,omitempty"` @@ -29,6 +40,10 @@ type ShellState struct { HashVal string `json:"-"` } +func (state ShellState) ApproximateSize() int64 { + return int64(len(state.Version) + len(state.Cwd) + len(state.ShellVars) + len(state.Aliases) + len(state.Funcs) + len(state.Error)) +} + type ShellStateDiff struct { Version string `json:"version"` // [type] [semver] (note this should *always* be set even if the same as base) BaseHash string `json:"basehash"` diff --git a/waveshell/pkg/server/server.go b/waveshell/pkg/server/server.go index 9fc098ce0..6c956e32a 100644 --- a/waveshell/pkg/server/server.go +++ b/waveshell/pkg/server/server.go @@ -244,11 +244,10 @@ func (m *MServer) runCompGen(compPk *packet.CompGenPacketType) { appendSlashes(comps) } m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": hasMore}) - return } func (m *MServer) reinit(reqId string, shellType string) { - ssPk, err := shexec.MakeShellStatePacket(shellType) + ssPk, err := m.MakeShellStatePacket(reqId, shellType) if err != nil { m.Sender.SendErrorResponse(reqId, fmt.Errorf("error creating init packet: %w", err)) return @@ -262,6 +261,32 @@ func (m *MServer) reinit(reqId string, shellType string) { m.Sender.SendPacket(ssPk) } +func (m *MServer) MakeShellStatePacket(reqId string, shellType string) (*packet.ShellStatePacketType, error) { + sapi, err := shellapi.MakeShellApi(shellType) + if err != nil { + return nil, err + } + rtnCh := make(chan shellapi.ShellStateOutput, 1) + go sapi.GetShellState(rtnCh) + for ssOutput := range rtnCh { + if ssOutput.Error != "" { + return nil, errors.New(ssOutput.Error) + } + if ssOutput.ShellState != nil { + rtn := packet.MakeShellStatePacket() + rtn.State = ssOutput.ShellState + rtn.Stats = ssOutput.Stats + return rtn, nil + } + if ssOutput.Output != nil { + dataPk := packet.MakeFileDataPacket(reqId) + dataPk.Data = ssOutput.Output + m.Sender.SendPacket(dataPk) + } + } + return nil, nil +} + func makeTemp(path string, mode fs.FileMode) (*os.File, error) { dirName := filepath.Dir(path) baseName := filepath.Base(path) diff --git a/waveshell/pkg/shellapi/bashapi.go b/waveshell/pkg/shellapi/bashapi.go index efb53f977..567d3fb79 100644 --- a/waveshell/pkg/shellapi/bashapi.go +++ b/waveshell/pkg/shellapi/bashapi.go @@ -16,6 +16,7 @@ import ( "github.com/wavetermdev/waveterm/waveshell/pkg/packet" "github.com/wavetermdev/waveterm/waveshell/pkg/shellenv" "github.com/wavetermdev/waveterm/waveshell/pkg/statediff" + "github.com/wavetermdev/waveterm/waveshell/pkg/utilfn" ) const BaseBashOpts = `set +m; set +H; shopt -s extglob` @@ -48,7 +49,7 @@ func (b bashShellApi) GetShellType() string { return packet.ShellType_bash } -func (b bashShellApi) MakeExitTrap(fdNum int) string { +func (b bashShellApi) MakeExitTrap(fdNum int) (string, []byte) { return MakeBashExitTrap(fdNum) } @@ -79,29 +80,15 @@ func (b bashShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty return MakeBashShExecCommand(cmdStr, rcFileName, usePty) } -func (b bashShellApi) GetShellState() chan ShellStateOutput { - ch := make(chan ShellStateOutput, 1) - defer close(ch) - ssPk, err := GetBashShellState() - if err != nil { - ch <- ShellStateOutput{ - Status: ShellStateOutputStatus_Done, - Error: err.Error(), - } - return ch - } - ch <- ShellStateOutput{ - Status: ShellStateOutputStatus_Done, - ShellState: ssPk, - } - return ch +func (b bashShellApi) GetShellState(outCh chan ShellStateOutput) { + GetBashShellState(outCh) } func (b bashShellApi) GetBaseShellOpts() string { return BaseBashOpts } -func (b bashShellApi) ParseShellStateOutput(output []byte) (*packet.ShellState, error) { +func (b bashShellApi) ParseShellStateOutput(output []byte) (*packet.ShellState, *packet.ShellStateStats, error) { return parseBashShellStateOutput(output) } @@ -130,8 +117,32 @@ func (b bashShellApi) MakeRcFileStr(pk *packet.RunPacketType) string { return rcBuf.String() } -func GetBashShellStateCmd() string { - return strings.Join(GetBashShellStateCmds, ` printf "\x00\x00";`) +func GetBashShellStateCmd(fdNum int) (string, []byte) { + endBytes := utilfn.AppendNonZeroRandomBytes(nil, NumRandomEndBytes) + endBytes = append(endBytes, '\n') + cmdStr := strings.TrimSpace(` +exec 2> /dev/null; +exec > [%OUTPUTFD%]; +printf "\x00\x00"; +[%BASHVERSIONCMD%]; +printf "\x00\x00"; +pwd; +printf "\x00\x00"; +declare -p $(compgen -A variable); +printf "\x00\x00"; +alias -p; +printf "\x00\x00"; +declare -f; +printf "\x00\x00"; +[%GITBRANCHCMD%]; +printf "\x00\x00"; +printf "[%ENDBYTES%]"; +`) + cmdStr = strings.ReplaceAll(cmdStr, "[%OUTPUTFD%]", fmt.Sprintf("/dev/fd/%d", fdNum)) + cmdStr = strings.ReplaceAll(cmdStr, "[%BASHVERSIONCMD%]", BashShellVersionCmdStr) + cmdStr = strings.ReplaceAll(cmdStr, "[%GITBRANCHCMD%]", GetGitBranchCmdStr) + cmdStr = strings.ReplaceAll(cmdStr, "[%ENDBYTES%]", utilfn.ShellHexEscape(string(endBytes))) + return cmdStr, endBytes } func execGetLocalBashShellVersion() string { @@ -158,16 +169,34 @@ func GetLocalBashMajorVersion() string { return localBashMajorVersion } -func GetBashShellState() (*packet.ShellState, error) { +func GetBashShellState(outCh chan ShellStateOutput) { ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout) defer cancelFn() - cmdStr := BaseBashOpts + "; " + GetBashShellStateCmd() + defer close(outCh) + stateCmd, endBytes := GetBashShellStateCmd(StateOutputFdNum) + cmdStr := BaseBashOpts + "; " + stateCmd ecmd := exec.CommandContext(ctx, GetLocalBashPath(), "-l", "-i", "-c", cmdStr) - outputBytes, err := RunSimpleCmdInPty(ecmd) + outputCh := make(chan []byte, 10) + var outputWg sync.WaitGroup + outputWg.Add(1) + go func() { + defer outputWg.Done() + for outputBytes := range outputCh { + outCh <- ShellStateOutput{Output: outputBytes} + } + }() + outputBytes, err := StreamCommandWithExtraFd(ecmd, outputCh, StateOutputFdNum, endBytes) + outputWg.Wait() if err != nil { - return nil, err + outCh <- ShellStateOutput{Error: err.Error()} + return } - return parseBashShellStateOutput(outputBytes) + rtn, stats, err := parseBashShellStateOutput(outputBytes) + if err != nil { + outCh <- ShellStateOutput{Error: err.Error()} + return + } + outCh <- ShellStateOutput{ShellState: rtn, Stats: stats} } func GetLocalBashPath() string { @@ -190,19 +219,20 @@ func GetLocalZshPath() string { return "zsh" } -func GetBashShellStateRedirectCommandStr(outputFdNum int) string { - return fmt.Sprintf("cat <(%s) > /dev/fd/%d", GetBashShellStateCmd(), outputFdNum) +func GetBashShellStateRedirectCommandStr(outputFdNum int) (string, []byte) { + cmdStr, endBytes := GetBashShellStateCmd(outputFdNum) + return cmdStr, endBytes } -func MakeBashExitTrap(fdNum int) string { - stateCmd := GetBashShellStateRedirectCommandStr(fdNum) +func MakeBashExitTrap(fdNum int) (string, []byte) { + stateCmd, endBytes := GetBashShellStateRedirectCommandStr(fdNum) fmtStr := ` _waveshell_exittrap () { %s } trap _waveshell_exittrap EXIT ` - return fmt.Sprintf(fmtStr, stateCmd) + return fmt.Sprintf(fmtStr, stateCmd), endBytes } func MakeBashShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd { diff --git a/waveshell/pkg/shellapi/bashparser.go b/waveshell/pkg/shellapi/bashparser.go index 2143e3290..f56b33b37 100644 --- a/waveshell/pkg/shellapi/bashparser.go +++ b/waveshell/pkg/shellapi/bashparser.go @@ -20,6 +20,19 @@ import ( "mvdan.cc/sh/v3/syntax" ) +const ( + BashSection_Ignored = iota + BashSection_Version + BashSection_Cwd + BashSection_Vars + BashSection_Aliases + BashSection_Funcs + BashSection_PVars + BashSection_EndBytes + + BashSection_Count // must be last +) + type DeclareDeclType = shellenv.DeclareDeclType func doCmdSubst(commandStr string, w io.Writer, word *syntax.CmdSubst) error { @@ -214,38 +227,37 @@ func bashParseDeclareOutput(state *packet.ShellState, declareBytes []byte, pvarB return nil } -func parseBashShellStateOutput(outputBytes []byte) (*packet.ShellState, error) { +func parseBashShellStateOutput(outputBytes []byte) (*packet.ShellState, *packet.ShellStateStats, error) { if scbase.IsDevMode() && DebugState { writeStateToFile(packet.ShellType_bash, outputBytes) } - // 7 fields: ignored [0], version [1], cwd [2], env/vars [3], aliases [4], funcs [5], pvars [6] - fields := bytes.Split(outputBytes, []byte{0, 0}) - if len(fields) != 7 { - return nil, fmt.Errorf("invalid bash shell state output, wrong number of fields, fields=%d", len(fields)) + sections := bytes.Split(outputBytes, []byte{0, 0}) + if len(sections) != BashSection_Count { + return nil, nil, fmt.Errorf("invalid bash shell state output, wrong number of fields, fields=%d", len(sections)) } rtn := &packet.ShellState{} - rtn.Version = strings.TrimSpace(string(fields[1])) + rtn.Version = strings.TrimSpace(string(sections[BashSection_Version])) if rtn.GetShellType() != packet.ShellType_bash { - return nil, fmt.Errorf("invalid bash shell state output, wrong shell type: %q", rtn.Version) + return nil, nil, fmt.Errorf("invalid bash shell state output, wrong shell type: %q", rtn.Version) } if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil { - return nil, fmt.Errorf("invalid bash shell state output, invalid version: %v", err) + return nil, nil, fmt.Errorf("invalid bash shell state output, invalid version: %v", err) } - cwdStr := string(fields[2]) + cwdStr := string(sections[BashSection_Cwd]) if strings.HasSuffix(cwdStr, "\r\n") { cwdStr = cwdStr[0 : len(cwdStr)-2] - } else if strings.HasSuffix(cwdStr, "\n") { - cwdStr = cwdStr[0 : len(cwdStr)-1] + } else { + cwdStr = strings.TrimSuffix(cwdStr, "\n") } rtn.Cwd = string(cwdStr) - err := bashParseDeclareOutput(rtn, fields[3], fields[6]) + err := bashParseDeclareOutput(rtn, sections[BashSection_Vars], sections[BashSection_PVars]) if err != nil { - return nil, err + return nil, nil, err } - rtn.Aliases = strings.ReplaceAll(string(fields[4]), "\r\n", "\n") - rtn.Funcs = strings.ReplaceAll(string(fields[5]), "\r\n", "\n") + rtn.Aliases = strings.ReplaceAll(string(sections[BashSection_Aliases]), "\r\n", "\n") + rtn.Funcs = strings.ReplaceAll(string(sections[BashSection_Funcs]), "\r\n", "\n") rtn.Funcs = shellenv.RemoveFunc(rtn.Funcs, "_waveshell_exittrap") - return rtn, nil + return rtn, nil, nil } func bashNormalize(d *DeclareDeclType) error { diff --git a/waveshell/pkg/shellapi/shellapi.go b/waveshell/pkg/shellapi/shellapi.go index 1388f6cf7..c9e7fc018 100644 --- a/waveshell/pkg/shellapi/shellapi.go +++ b/waveshell/pkg/shellapi/shellapi.go @@ -7,7 +7,6 @@ import ( "bytes" "context" "fmt" - "io" "os" "os/exec" "os/user" @@ -28,12 +27,14 @@ import ( "github.com/wavetermdev/waveterm/waveshell/pkg/utilfn" ) -const GetStateTimeout = 5 * time.Second +const GetStateTimeout = 15 * time.Second const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"` const GetK8sContextCmdStr = `printf "K8SCONTEXT %s\x00" "$(kubectl config current-context 2>/dev/null)"` const GetK8sNamespaceCmdStr = `printf "K8SNAMESPACE %s\x00" "$(kubectl config view --minify --output 'jsonpath={..namespace}' 2>/dev/null)"` const RunCommandFmt = `%s` const DebugState = false +const StateOutputFdNum = 20 +const NumRandomEndBytes = 8 var userShellRegexp = regexp.MustCompile(`^UserShell: (.*)$`) @@ -56,23 +57,23 @@ const ( ) type ShellStateOutput struct { - Status string - StderrOutput []byte - ShellState *packet.ShellState - Error string + Output []byte + ShellState *packet.ShellState + Stats *packet.ShellStateStats + Error string } type ShellApi interface { GetShellType() string - MakeExitTrap(fdNum int) string + MakeExitTrap(fdNum int) (string, []byte) GetLocalMajorVersion() string GetLocalShellPath() string GetRemoteShellPath() string MakeRunCommand(cmdStr string, opts RunCommandOpts) string MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd - GetShellState() chan ShellStateOutput + GetShellState(chan ShellStateOutput) GetBaseShellOpts() string - ParseShellStateOutput(output []byte) (*packet.ShellState, error) + ParseShellStateOutput(output []byte) (*packet.ShellState, *packet.ShellStateStats, error) MakeRcFileStr(pk *packet.RunPacketType) string MakeShellStateDiff(oldState *packet.ShellState, oldStateHash string, newState *packet.ShellState) (*packet.ShellStateDiff, error) ApplyShellStateDiff(oldState *packet.ShellState, diff *packet.ShellStateDiff) (*packet.ShellState, error) @@ -153,12 +154,13 @@ func internalMacUserShell() string { const FirstExtraFilesFdNum = 3 // returns output(stdout+stderr), extraFdOutput, error -func RunCommandWithExtraFd(ecmd *exec.Cmd, extraFdNum int) ([]byte, []byte, error) { +func StreamCommandWithExtraFd(ecmd *exec.Cmd, outputCh chan []byte, extraFdNum int, endBytes []byte) ([]byte, error) { + defer close(outputCh) ecmd.Env = os.Environ() shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType)) cmdPty, cmdTty, err := pty.Open() if err != nil { - return nil, nil, fmt.Errorf("opening new pty: %w", err) + return nil, fmt.Errorf("opening new pty: %w", err) } defer cmdTty.Close() defer cmdPty.Close() @@ -171,42 +173,44 @@ func RunCommandWithExtraFd(ecmd *exec.Cmd, extraFdNum int) ([]byte, []byte, erro ecmd.SysProcAttr.Setctty = true pipeReader, pipeWriter, err := os.Pipe() if err != nil { - return nil, nil, fmt.Errorf("could not create pipe: %w", err) + return nil, fmt.Errorf("could not create pipe: %w", err) } defer pipeWriter.Close() defer pipeReader.Close() extraFiles := make([]*os.File, extraFdNum+1) extraFiles[extraFdNum] = pipeWriter ecmd.ExtraFiles = extraFiles[FirstExtraFilesFdNum:] - defer pipeReader.Close() - ecmd.Start() + err = ecmd.Start() cmdTty.Close() pipeWriter.Close() if err != nil { - return nil, nil, err + return nil, err } var outputWg sync.WaitGroup - var outputBuf bytes.Buffer var extraFdOutputBuf bytes.Buffer outputWg.Add(2) go func() { // ignore error (/dev/ptmx has read error when process is done) defer outputWg.Done() - io.Copy(&outputBuf, cmdPty) + err := utilfn.CopyToChannel(outputCh, cmdPty) + if err != nil { + errStr := fmt.Sprintf("\r\nerror reading from pty: %v\r\n", err) + outputCh <- []byte(errStr) + } }() go func() { defer outputWg.Done() - io.Copy(&extraFdOutputBuf, pipeReader) + utilfn.CopyWithEndBytes(&extraFdOutputBuf, pipeReader, endBytes) }() exitErr := ecmd.Wait() if exitErr != nil { - return nil, nil, exitErr + return nil, exitErr } outputWg.Wait() - return outputBuf.Bytes(), extraFdOutputBuf.Bytes(), nil + return extraFdOutputBuf.Bytes(), nil } -func RunSimpleCmdInPty(ecmd *exec.Cmd) ([]byte, error) { +func RunSimpleCmdInPty(ecmd *exec.Cmd, endBytes []byte) ([]byte, error) { ecmd.Env = os.Environ() shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType)) cmdPty, cmdTty, err := pty.Open() @@ -231,8 +235,8 @@ func RunSimpleCmdInPty(ecmd *exec.Cmd) ([]byte, error) { var outputBuf bytes.Buffer go func() { // ignore error (/dev/ptmx has read error when process is done) - io.Copy(&outputBuf, cmdPty) - close(ioDone) + defer close(ioDone) + utilfn.CopyWithEndBytes(&outputBuf, cmdPty, endBytes) }() exitErr := ecmd.Wait() if exitErr != nil { diff --git a/waveshell/pkg/shellapi/zshapi.go b/waveshell/pkg/shellapi/zshapi.go index b4b3fa6af..aa6ad1c70 100644 --- a/waveshell/pkg/shellapi/zshapi.go +++ b/waveshell/pkg/shellapi/zshapi.go @@ -8,7 +8,6 @@ import ( "context" "errors" "fmt" - "math/rand" "os/exec" "strings" "sync" @@ -28,7 +27,6 @@ import ( const BaseZshOpts = `` const ZshShellVersionCmdStr = `echo zsh v$ZSH_VERSION` -const StateOutputFdNum = 20 const ( ZshSection_Version = iota @@ -41,6 +39,7 @@ const ( ZshSection_Funcs ZshSection_PVars ZshSection_Prompt + ZshSection_EndBytes ZshSection_NumFieldsExpected // must be last ) @@ -118,6 +117,11 @@ var ZshIgnoreVars = map[string]bool{ "zcurses_windows": true, // not listed, but we also exclude all ZFTP_* variables + + // powerlevel10k + "_GITSTATUS_CLIENT_PID_POWERLEVEL9K": true, + "GITSTATUS_DAEMON_PID_POWERLEVEL9K": true, + "_GITSTATUS_FILE_PREFIX_POWERLEVEL9K": true, } var ZshIgnoreFuncs = map[string]bool{ @@ -211,7 +215,7 @@ func (z zshShellApi) GetShellType() string { return packet.ShellType_zsh } -func (z zshShellApi) MakeExitTrap(fdNum int) string { +func (z zshShellApi) MakeExitTrap(fdNum int) (string, []byte) { return MakeZshExitTrap(fdNum) } @@ -242,25 +246,34 @@ func (z zshShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty return exec.Command(GetLocalZshPath(), "-l", "-i", "-c", cmdStr) } -func (z zshShellApi) GetShellState() chan ShellStateOutput { +func (z zshShellApi) GetShellState(outCh chan ShellStateOutput) { ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout) defer cancelFn() - rtnCh := make(chan ShellStateOutput, 1) - defer close(rtnCh) - cmdStr := BaseZshOpts + "; " + GetZshShellStateCmd(StateOutputFdNum) + defer close(outCh) + stateCmd, endBytes := GetZshShellStateCmd(StateOutputFdNum) + cmdStr := BaseZshOpts + "; " + stateCmd ecmd := exec.CommandContext(ctx, GetLocalZshPath(), "-l", "-i", "-c", cmdStr) - _, outputBytes, err := RunCommandWithExtraFd(ecmd, StateOutputFdNum) + outputCh := make(chan []byte, 10) + var outputWg sync.WaitGroup + outputWg.Add(1) + go func() { + defer outputWg.Done() + for outputBytes := range outputCh { + outCh <- ShellStateOutput{Output: outputBytes} + } + }() + outputBytes, err := StreamCommandWithExtraFd(ecmd, outputCh, StateOutputFdNum, endBytes) + outputWg.Wait() if err != nil { - rtnCh <- ShellStateOutput{Status: ShellStateOutputStatus_Done, Error: err.Error()} - return rtnCh + outCh <- ShellStateOutput{Error: err.Error()} + return } - rtn, err := z.ParseShellStateOutput(outputBytes) + rtn, stats, err := z.ParseShellStateOutput(outputBytes) if err != nil { - rtnCh <- ShellStateOutput{Status: ShellStateOutputStatus_Done, Error: err.Error()} - return rtnCh + outCh <- ShellStateOutput{Error: err.Error()} + return } - rtnCh <- ShellStateOutput{Status: ShellStateOutputStatus_Done, ShellState: rtn} - return rtnCh + outCh <- ShellStateOutput{ShellState: rtn, Stats: stats} } func (z zshShellApi) GetBaseShellOpts() string { @@ -437,19 +450,15 @@ func writeZshId(buf *bytes.Buffer, idStr string) { const numRandomBytes = 4 -// returns (cmd-string) -func GetZshShellStateCmd(fdNum int) string { +// returns (cmd-string, endbytes) +func GetZshShellStateCmd(fdNum int) (string, []byte) { var sectionSeparator []byte // adding this extra "\n" helps with debuging and readability of output sectionSeparator = append(sectionSeparator, byte('\n')) - for len(sectionSeparator) < numRandomBytes { - // any character *except* null (0) - rn := rand.Intn(256) - if rn > 0 && rn < 256 { // exclude 0, also helps to suppress security warning to have a guard here - sectionSeparator = append(sectionSeparator, byte(rn)) - } - } + sectionSeparator = utilfn.AppendNonZeroRandomBytes(sectionSeparator, numRandomBytes) sectionSeparator = append(sectionSeparator, 0, 0) + endBytes := utilfn.AppendNonZeroRandomBytes(nil, NumRandomEndBytes) + endBytes = append(endBytes, byte('\n')) // we have to use these crazy separators because zsh allows basically anything in // variable names and values (including nulls). // note that we don't need crazy separators for "env" or "typeset". @@ -511,6 +520,8 @@ printf "[%SECTIONSEP%]"; [%K8SNAMESPACE%] printf "[%SECTIONSEP%]"; print -P "$PS1" +printf "[%SECTIONSEP%]"; +printf "[%ENDBYTES%]" ` cmd = strings.TrimSpace(cmd) cmd = strings.ReplaceAll(cmd, "[%ZSHVERSION%]", ZshShellVersionCmdStr) @@ -520,17 +531,19 @@ print -P "$PS1" cmd = strings.ReplaceAll(cmd, "[%PARTSEP%]", utilfn.ShellHexEscape(string(sectionSeparator[0:len(sectionSeparator)-1]))) cmd = strings.ReplaceAll(cmd, "[%SECTIONSEP%]", utilfn.ShellHexEscape(string(sectionSeparator))) cmd = strings.ReplaceAll(cmd, "[%OUTPUTFD%]", fmt.Sprintf("/dev/fd/%d", fdNum)) - return cmd + cmd = strings.ReplaceAll(cmd, "[%OUTPUTFDNUM%]", fmt.Sprintf("%d", fdNum)) + cmd = strings.ReplaceAll(cmd, "[%ENDBYTES%]", utilfn.ShellHexEscape(string(endBytes))) + return cmd, endBytes } -func MakeZshExitTrap(fdNum int) string { - stateCmd := GetZshShellStateCmd(fdNum) +func MakeZshExitTrap(fdNum int) (string, []byte) { + stateCmd, endBytes := GetZshShellStateCmd(fdNum) fmtStr := ` zshexit () { %s } ` - return fmt.Sprintf(fmtStr, stateCmd) + return fmt.Sprintf(fmtStr, stateCmd), endBytes } func execGetLocalZshShellVersion() string { @@ -698,14 +711,14 @@ func makeZshFuncsStrForShellState(fnMap map[ZshParamKey]string) string { return buf.String() } -func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, error) { +func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, *packet.ShellStateStats, error) { if scbase.IsDevMode() && DebugState { writeStateToFile(packet.ShellType_zsh, outputBytes) } firstZeroIdx := bytes.Index(outputBytes, []byte{0}) firstDZeroIdx := bytes.Index(outputBytes, []byte{0, 0}) if firstZeroIdx == -1 || firstDZeroIdx == -1 { - return nil, fmt.Errorf("invalid zsh shell state output, could not parse separator bytes") + return nil, nil, fmt.Errorf("invalid zsh shell state output, could not parse separator bytes") } versionStr := string(outputBytes[0:firstZeroIdx]) sectionSeparator := outputBytes[firstZeroIdx+1 : firstDZeroIdx+2] @@ -714,15 +727,15 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta sections := bytes.Split(outputBytes, sectionSeparator) if len(sections) != ZshSection_NumFieldsExpected { base.Logf("invalid -- numfields\n") - return nil, fmt.Errorf("invalid zsh shell state output, wrong number of sections, section=%d", len(sections)) + return nil, nil, fmt.Errorf("invalid zsh shell state output, wrong number of sections, section=%d", len(sections)) } rtn := &packet.ShellState{} rtn.Version = strings.TrimSpace(versionStr) if rtn.GetShellType() != packet.ShellType_zsh { - return nil, fmt.Errorf("invalid zsh shell state output, wrong shell type") + return nil, nil, fmt.Errorf("invalid zsh shell state output, wrong shell type") } if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil { - return nil, fmt.Errorf("invalid zsh shell state output, invalid version: %v", err) + return nil, nil, fmt.Errorf("invalid zsh shell state output, invalid version: %v", err) } cwdStr := stripNewLineChars(string(sections[ZshSection_Cwd])) rtn.Cwd = cwdStr @@ -730,7 +743,7 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta zshDecls, err := parseZshDecls(sections[ZshSection_Vars]) if err != nil { base.Logf("invalid - parsedecls %v\n", err) - return nil, err + return nil, nil, err } for _, decl := range zshDecls { if decl.IsZshScalarBound() { @@ -746,7 +759,17 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta pvarMap := parseExtVarOutput(sections[ZshSection_PVars], string(sections[ZshSection_Prompt]), string(sections[ZshSection_Mods])) utilfn.CombineMaps(zshDecls, pvarMap) rtn.ShellVars = shellenv.SerializeDeclMap(zshDecls) - return rtn, nil + stats := &packet.ShellStateStats{ + Version: rtn.Version, + AliasCount: int(len(aliasMap)), + FuncCount: int(len(zshFuncs)), + VarCount: int(len(zshDecls)), + EnvCount: int(len(zshEnv)), + HashVal: rtn.GetHashVal(false), + OutputSize: int64(len(outputBytes)), + StateSize: rtn.ApproximateSize(), + } + return rtn, stats, nil } func parseZshEnv(output []byte) map[string]string { diff --git a/waveshell/pkg/shexec/shexec.go b/waveshell/pkg/shexec/shexec.go index 72a291a14..9ff6211da 100644 --- a/waveshell/pkg/shexec/shexec.go +++ b/waveshell/pkg/shexec/shexec.go @@ -7,7 +7,6 @@ import ( "bytes" "context" "encoding/base64" - "errors" "fmt" "io" "os" @@ -89,13 +88,14 @@ func MakeInstallCommandStr() string { type MShellBinaryReaderFn func(version string, goos string, goarch string) (io.ReadCloser, error) type ReturnStateBuf struct { - Lock *sync.Mutex - Buf []byte - Done bool - Err error - Reader *os.File - FdNum int - DoneCh chan bool + Lock *sync.Mutex + Buf []byte + Done bool + Err error + Reader *os.File + FdNum int + EndBytes []byte + DoneCh chan bool } func MakeReturnStateBuf() *ReturnStateBuf { @@ -835,7 +835,8 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro cmd.ReturnState.FdNum = RtnStateFdNum rtnStateWriter = pw defer pw.Close() - trapCmdStr := sapi.MakeExitTrap(cmd.ReturnState.FdNum) + trapCmdStr, endBytes := sapi.MakeExitTrap(cmd.ReturnState.FdNum) + cmd.ReturnState.EndBytes = endBytes rcFileStr += trapCmdStr } shellVarMap := shellenv.ShellVarMapFromState(state) @@ -1021,6 +1022,11 @@ func (rs *ReturnStateBuf) Run() { } rs.Lock.Lock() rs.Buf = append(rs.Buf, buf[0:n]...) + if bytes.HasSuffix(rs.Buf, rs.EndBytes) { + rs.Buf = rs.Buf[:len(rs.Buf)-len(rs.EndBytes)] + rs.Lock.Unlock() + break + } rs.Lock.Unlock() } } @@ -1127,7 +1133,7 @@ func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType { wlog.Logf("debug returnstate file %q\n", base.GetDebugReturnStateFileName()) os.WriteFile(base.GetDebugReturnStateFileName(), c.ReturnState.Buf, 0666) } - state, _ := c.SAPI.ParseShellStateOutput(c.ReturnState.Buf) // TODO what to do with error? + state, _, _ := c.SAPI.ParseShellStateOutput(c.ReturnState.Buf) // TODO what to do with error? donePacket.FinalState = state } endTs := time.Now() @@ -1156,21 +1162,6 @@ func MakeInitPacket() *packet.InitPacketType { return initPacket } -func MakeShellStatePacket(shellType string) (*packet.ShellStatePacketType, error) { - sapi, err := shellapi.MakeShellApi(shellType) - if err != nil { - return nil, err - } - rtnCh := sapi.GetShellState() - ssOutput := <-rtnCh - if ssOutput.Error != "" { - return nil, errors.New(ssOutput.Error) - } - rtn := packet.MakeShellStatePacket() - rtn.State = ssOutput.ShellState - return rtn, nil -} - func MakeServerInitPacket() (*packet.InitPacketType, error) { var err error initPacket := MakeInitPacket() diff --git a/waveshell/pkg/utilfn/utilfn.go b/waveshell/pkg/utilfn/utilfn.go index 643192307..8e8638e36 100644 --- a/waveshell/pkg/utilfn/utilfn.go +++ b/waveshell/pkg/utilfn/utilfn.go @@ -10,7 +10,9 @@ import ( "encoding/json" "errors" "fmt" + "io" "math" + mathrand "math/rand" "regexp" "sort" "strings" @@ -552,3 +554,60 @@ func StrArrayToMap(sarr []string) map[string]bool { } return m } + +func AppendNonZeroRandomBytes(b []byte, randLen int) []byte { + if randLen <= 0 { + return b + } + numAdded := 0 + for numAdded < randLen { + rn := mathrand.Intn(256) + if rn > 0 && rn < 256 { // exclude 0, also helps to suppress security warning to have a guard here + b = append(b, byte(rn)) + numAdded++ + } + } + return b +} + +// returns (isEOF, error) +func CopyWithEndBytes(outputBuf *bytes.Buffer, reader io.Reader, endBytes []byte) (bool, error) { + buf := make([]byte, 4096) + for { + n, err := reader.Read(buf) + if n > 0 { + outputBuf.Write(buf[:n]) + obytes := outputBuf.Bytes() + if bytes.HasSuffix(obytes, endBytes) { + outputBuf.Truncate(len(obytes) - len(endBytes)) + return (err == io.EOF), nil + } + } + if err == io.EOF { + return true, nil + } + if err != nil { + return false, err + } + } +} + +// does *not* close outputCh on EOF or error +func CopyToChannel(outputCh chan<- []byte, reader io.Reader) error { + buf := make([]byte, 4096) + for { + n, err := reader.Read(buf) + if n > 0 { + // copy so client can use []byte without it being overwritten + bufCopy := make([]byte, n) + copy(bufCopy, buf[:n]) + outputCh <- bufCopy + } + if err == io.EOF { + return nil + } + if err != nil { + return err + } + } +} diff --git a/wavesrv/pkg/cmdrunner/cmdrunner.go b/wavesrv/pkg/cmdrunner/cmdrunner.go index 58a1f05cf..5b45c2166 100644 --- a/wavesrv/pkg/cmdrunner/cmdrunner.go +++ b/wavesrv/pkg/cmdrunner/cmdrunner.go @@ -41,6 +41,7 @@ import ( "github.com/wavetermdev/waveterm/wavesrv/pkg/releasechecker" "github.com/wavetermdev/waveterm/wavesrv/pkg/remote" "github.com/wavetermdev/waveterm/wavesrv/pkg/remote/openai" + "github.com/wavetermdev/waveterm/wavesrv/pkg/rtnstate" "github.com/wavetermdev/waveterm/wavesrv/pkg/scbase" "github.com/wavetermdev/waveterm/wavesrv/pkg/scbus" "github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket" @@ -1648,8 +1649,12 @@ func CopyFileCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scb } var outputPos int64 outputStr := fmt.Sprintf("Copying [%v]:%v to [%v]:%v\r\n", sourceRemoteId.DisplayName, sourceFullPath, destRemoteId.DisplayName, destFullPath) - termopts := sstore.TermOpts{Rows: shellutil.DefaultTermRows, Cols: shellutil.DefaultTermCols, FlexRows: true, MaxPtySize: remote.DefaultMaxPtySize} - cmd, err := makeDynCmd(ctx, "copy file", ids, pk.GetRawStr(), termopts) + termOpts, err := GetUITermOpts(pk.UIContext.WinSize, DefaultPTERM) + if err != nil { + return nil, fmt.Errorf("cannot make termopts: %w", err) + } + pkTermOpts := convertTermOpts(termOpts) + cmd, err := makeDynCmd(ctx, "copy file", ids, pk.GetRawStr(), *pkTermOpts) writeStringToPty(ctx, cmd, outputStr, &outputPos) if err != nil { // TODO tricky error since the command was a success, but we can't show the output @@ -3655,11 +3660,14 @@ func SessionCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbu return update, nil } -func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbus.UpdatePacket, error) { +func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (rtnUpdate scbus.UpdatePacket, rtnErr error) { ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen|R_Remote) if err != nil { return nil, err } + if !ids.Remote.MShell.IsConnected() { + return nil, fmt.Errorf("cannot reinit, remote is not connected") + } shellType := ids.Remote.ShellType if pk.Kwargs["shell"] != "" { shellArg := pk.Kwargs["shell"] @@ -3668,33 +3676,76 @@ func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) ( } shellType = shellArg } - ssPk, err := ids.Remote.MShell.ReInit(ctx, shellType) + verbose := resolveBool(pk.Kwargs["verbose"], false) + termOpts, err := GetUITermOpts(pk.UIContext.WinSize, DefaultPTERM) if err != nil { - return nil, err + return nil, fmt.Errorf("cannot make termopts: %w", err) } - if ssPk == nil || ssPk.State == nil { - return nil, fmt.Errorf("invalid initpk received from remote (no remote state)") - } - feState := sstore.FeStateFromShellState(ssPk.State) - remoteInst, err := sstore.UpdateRemoteState(ctx, ids.SessionId, ids.ScreenId, ids.Remote.RemotePtr, feState, ssPk.State, nil) + pkTermOpts := convertTermOpts(termOpts) + cmd, err := makeDynCmd(ctx, "reset", ids, pk.GetRawStr(), *pkTermOpts) if err != nil { return nil, err } - outputStr := fmt.Sprintf("reset remote state (shell:%s)", ssPk.State.GetShellType()) - cmd, err := makeStaticCmd(ctx, "reset", ids, pk.GetRawStr(), []byte(outputStr)) - if err != nil { - // TODO tricky error since the command was a success, but we can't show the output - return nil, err - } update, err := addLineForCmd(ctx, "/reset", false, ids, cmd, "", nil) if err != nil { - // TODO tricky error since the command was a success, but we can't show the output return nil, err } - update.AddUpdate(sstore.MakeSessionUpdateForRemote(ids.SessionId, remoteInst), sstore.InteractiveUpdate(pk.Interactive)) + go doResetCommand(ids, shellType, cmd, verbose) return update, nil } +func doResetCommand(ids resolvedIds, shellType string, cmd *sstore.CmdType, verbose bool) { + ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second) + defer cancelFn() + startTime := time.Now() + var outputPos int64 + var rtnErr error + exitSuccess := true + defer func() { + if rtnErr != nil { + exitSuccess = false + writeStringToPty(ctx, cmd, fmt.Sprintf("\r\nerror: %v", rtnErr), &outputPos) + } + deferWriteCmdStatus(ctx, cmd, startTime, exitSuccess, outputPos) + }() + dataFn := func(data []byte) { + writeStringToPty(ctx, cmd, string(data), &outputPos) + } + origStatePtr := ids.Remote.MShell.GetDefaultStatePtr(shellType) + ssPk, err := ids.Remote.MShell.ReInit(ctx, shellType, dataFn, verbose) + if err != nil { + rtnErr = err + return + } + if ssPk == nil || ssPk.State == nil { + rtnErr = fmt.Errorf("invalid initpk received from remote (no remote state)") + return + } + feState := sstore.FeStateFromShellState(ssPk.State) + remoteInst, err := sstore.UpdateRemoteState(ctx, ids.SessionId, ids.ScreenId, ids.Remote.RemotePtr, feState, ssPk.State, nil) + if err != nil { + rtnErr = err + return + } + newStatePtr := ids.Remote.MShell.GetDefaultStatePtr(shellType) + if verbose && origStatePtr != nil && newStatePtr != nil { + statePtrDiff := fmt.Sprintf("oldstate: %v, newstate: %v\r\n", origStatePtr.BaseHash, newStatePtr.BaseHash) + writeStringToPty(ctx, cmd, statePtrDiff, &outputPos) + origFullState, _ := sstore.GetFullState(ctx, *origStatePtr) + newFullState, _ := sstore.GetFullState(ctx, *newStatePtr) + if origFullState != nil && newFullState != nil { + var diffBuf bytes.Buffer + rtnstate.DisplayStateUpdateDiff(&diffBuf, *origFullState, *newFullState) + diffStr := diffBuf.String() + diffStr = strings.ReplaceAll(diffStr, "\n", "\r\n") + writeStringToPty(ctx, cmd, diffStr, &outputPos) + } + } + update := scbus.MakeUpdatePacket() + update.AddUpdate(sstore.MakeSessionUpdateForRemote(ids.SessionId, remoteInst)) + scbus.MainUpdateBus.DoUpdate(update) +} + func ResetCwdCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbus.UpdatePacket, error) { ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen|R_Remote) if err != nil { diff --git a/wavesrv/pkg/remote/remote.go b/wavesrv/pkg/remote/remote.go index 97c249d11..a1c658c7f 100644 --- a/wavesrv/pkg/remote/remote.go +++ b/wavesrv/pkg/remote/remote.go @@ -196,7 +196,7 @@ func (msh *MShellProc) EnsureShellType(ctx context.Context, shellType string) er return nil } // try to reinit the shell - _, err := msh.ReInit(ctx, shellType) + _, err := msh.ReInit(ctx, shellType, nil, false) if err != nil { return fmt.Errorf("error trying to initialize shell %q: %v", shellType, err) } @@ -1401,33 +1401,60 @@ func makeReinitErrorUpdate(shellType string) sstore.ActivityUpdate { return rtn } -func (msh *MShellProc) ReInit(ctx context.Context, shellType string) (*packet.ShellStatePacketType, error) { +func (msh *MShellProc) ReInit(ctx context.Context, shellType string, dataFn func([]byte), verbose bool) (rtnPk *packet.ShellStatePacketType, rtnErr error) { if !msh.IsConnected() { return nil, fmt.Errorf("cannot reinit, remote is not connected") } if shellType != packet.ShellType_bash && shellType != packet.ShellType_zsh { return nil, fmt.Errorf("invalid shell type %q", shellType) } + if dataFn == nil { + dataFn = func([]byte) {} + } + defer func() { + if rtnErr != nil { + sstore.UpdateActivityWrap(ctx, makeReinitErrorUpdate(shellType), "reiniterror") + } + }() + startTs := time.Now() reinitPk := packet.MakeReInitPacket() reinitPk.ReqId = uuid.New().String() reinitPk.ShellType = shellType - resp, err := msh.PacketRpcRaw(ctx, reinitPk) + rpcIter, err := msh.PacketRpcIter(ctx, reinitPk) if err != nil { return nil, err } - if resp == nil { - return nil, fmt.Errorf("no response") - } - ssPk, ok := resp.(*packet.ShellStatePacketType) - if !ok { - sstore.UpdateActivityWrap(ctx, makeReinitErrorUpdate(shellType), "reiniterror") - if respPk, ok := resp.(*packet.ResponsePacketType); ok && respPk.Error != "" { - return nil, fmt.Errorf("error reinitializing remote: %s", respPk.Error) + defer rpcIter.Close() + var ssPk *packet.ShellStatePacketType + for { + resp, err := rpcIter.Next(ctx) + if err != nil { + return nil, err } - return nil, fmt.Errorf("invalid reinit response (not an shellstate packet): %T", resp) + if resp == nil { + return nil, fmt.Errorf("channel closed with no response") + } + var ok bool + ssPk, ok = resp.(*packet.ShellStatePacketType) + if ok { + break + } + respPk, ok := resp.(*packet.ResponsePacketType) + if ok { + if respPk.Error != "" { + return nil, fmt.Errorf("error reinitializing remote: %s", respPk.Error) + } + return nil, fmt.Errorf("invalid response from waveshell") + } + dataPk, ok := resp.(*packet.FileDataPacketType) + if ok { + dataFn(dataPk.Data) + continue + } + invalidPkStr := fmt.Sprintf("\r\ninvalid packettype from waveshell: %s\r\n", resp.GetType()) + dataFn([]byte(invalidPkStr)) } - if ssPk.State == nil { - sstore.UpdateActivityWrap(ctx, makeReinitErrorUpdate(shellType), "reiniterror") + if ssPk == nil || ssPk.State == nil { return nil, fmt.Errorf("invalid reinit response shellstate packet does not contain remote state") } // TODO: maybe we don't need to save statebase here. should be possible to save it on demand @@ -1438,10 +1465,29 @@ func (msh *MShellProc) ReInit(ctx context.Context, shellType string) (*packet.Sh return nil, fmt.Errorf("error storing remote state: %w", err) } msh.StateMap.SetCurrentState(ssPk.State.GetShellType(), ssPk.State) - msh.WriteToPtyBuffer("initialized shell:%s state:%s\n", shellType, ssPk.State.GetHashVal(false)) + timeDur := time.Since(startTs) + dataFn([]byte(makeShellInitOutputMsg(verbose, ssPk.State, ssPk.Stats, timeDur, false))) + msh.WriteToPtyBuffer("%s", makeShellInitOutputMsg(false, ssPk.State, ssPk.Stats, timeDur, true)) return ssPk, nil } +func makeShellInitOutputMsg(verbose bool, state *packet.ShellState, stats *packet.ShellStateStats, dur time.Duration, ptyMsg bool) string { + if !verbose || ptyMsg { + if ptyMsg { + return fmt.Sprintf("initialized state shell:%s statehash:%s %dms\n", state.GetShellType(), state.GetHashVal(false), dur.Milliseconds()) + } else { + return fmt.Sprintf("initialized connection state (shell:%s)\r\n", state.GetShellType()) + } + } + var buf bytes.Buffer + buf.WriteString("-----\r\n") + buf.WriteString(fmt.Sprintf("initialized connection shell:%s statehash:%s %dms\r\n", state.GetShellType(), state.GetHashVal(false), dur.Milliseconds())) + if stats != nil { + buf.WriteString(fmt.Sprintf(" outsize:%s size:%s env:%d, vars:%d, aliases:%d, funcs:%d\r\n", scbase.NumFormatDec(stats.OutputSize), scbase.NumFormatDec(stats.StateSize), stats.EnvCount, stats.VarCount, stats.AliasCount, stats.FuncCount)) + } + return buf.String() +} + func (msh *MShellProc) WriteFile(ctx context.Context, writePk *packet.WriteFilePacketType) (*packet.RpcResponseIter, error) { return msh.PacketRpcIter(ctx, writePk) } @@ -1690,7 +1736,7 @@ func (msh *MShellProc) initActiveShells() { return } for _, shellType := range activeShells { - _, err = msh.ReInit(ctx, shellType) + _, err = msh.ReInit(ctx, shellType, nil, false) if err != nil { msh.WriteToPtyBuffer("*error reiniting shell %q: %v\n", shellType, err) }