zsh reinit fixes (#477)

* reset command now initiates and completes async so there is feedback that something is happening when it takes a long time

* switch from standard rpc to rpciter

* checkpoint on reinit -- stream output, stats packet, logging to cmd pty, new endBytes for EOF

* make generic versions of endbytes scanner and channel output funcs

* update bash to use more modern state parsing (tricks learned from zsh)

* verbose mode, fix stats output message

* add a diff when verbose mode is on
This commit is contained in:
Mike Sawka 2024-03-19 16:38:38 -07:00 committed by GitHub
parent accb74ae0f
commit 5616c9abbb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 427 additions and 170 deletions

View File

@ -602,6 +602,7 @@ type ShellStatePacketType struct {
ShellType string `json:"shelltype"` ShellType string `json:"shelltype"`
RespId string `json:"respid,omitempty"` RespId string `json:"respid,omitempty"`
State *ShellState `json:"state"` State *ShellState `json:"state"`
Stats *ShellStateStats `json:"stats"`
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`
} }

View File

@ -19,6 +19,17 @@ import (
const ShellStatePackVersion = 0 const ShellStatePackVersion = 0
const ShellStateDiffPackVersion = 0 const ShellStateDiffPackVersion = 0
type ShellStateStats struct {
Version string `json:"version"`
AliasCount int `json:"aliascount"`
EnvCount int `json:"envcount"`
VarCount int `json:"varcount"`
FuncCount int `json:"funccount"`
HashVal string `json:"hashval"`
OutputSize int64 `json:"outputsize"`
StateSize int64 `json:"statesize"`
}
type ShellState struct { type ShellState struct {
Version string `json:"version"` // [type] [semver] Version string `json:"version"` // [type] [semver]
Cwd string `json:"cwd,omitempty"` Cwd string `json:"cwd,omitempty"`
@ -29,6 +40,10 @@ type ShellState struct {
HashVal string `json:"-"` HashVal string `json:"-"`
} }
func (state ShellState) ApproximateSize() int64 {
return int64(len(state.Version) + len(state.Cwd) + len(state.ShellVars) + len(state.Aliases) + len(state.Funcs) + len(state.Error))
}
type ShellStateDiff struct { type ShellStateDiff struct {
Version string `json:"version"` // [type] [semver] (note this should *always* be set even if the same as base) Version string `json:"version"` // [type] [semver] (note this should *always* be set even if the same as base)
BaseHash string `json:"basehash"` BaseHash string `json:"basehash"`

View File

@ -244,11 +244,10 @@ func (m *MServer) runCompGen(compPk *packet.CompGenPacketType) {
appendSlashes(comps) appendSlashes(comps)
} }
m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": hasMore}) m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": hasMore})
return
} }
func (m *MServer) reinit(reqId string, shellType string) { func (m *MServer) reinit(reqId string, shellType string) {
ssPk, err := shexec.MakeShellStatePacket(shellType) ssPk, err := m.MakeShellStatePacket(reqId, shellType)
if err != nil { if err != nil {
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error creating init packet: %w", err)) m.Sender.SendErrorResponse(reqId, fmt.Errorf("error creating init packet: %w", err))
return return
@ -262,6 +261,32 @@ func (m *MServer) reinit(reqId string, shellType string) {
m.Sender.SendPacket(ssPk) m.Sender.SendPacket(ssPk)
} }
func (m *MServer) MakeShellStatePacket(reqId string, shellType string) (*packet.ShellStatePacketType, error) {
sapi, err := shellapi.MakeShellApi(shellType)
if err != nil {
return nil, err
}
rtnCh := make(chan shellapi.ShellStateOutput, 1)
go sapi.GetShellState(rtnCh)
for ssOutput := range rtnCh {
if ssOutput.Error != "" {
return nil, errors.New(ssOutput.Error)
}
if ssOutput.ShellState != nil {
rtn := packet.MakeShellStatePacket()
rtn.State = ssOutput.ShellState
rtn.Stats = ssOutput.Stats
return rtn, nil
}
if ssOutput.Output != nil {
dataPk := packet.MakeFileDataPacket(reqId)
dataPk.Data = ssOutput.Output
m.Sender.SendPacket(dataPk)
}
}
return nil, nil
}
func makeTemp(path string, mode fs.FileMode) (*os.File, error) { func makeTemp(path string, mode fs.FileMode) (*os.File, error) {
dirName := filepath.Dir(path) dirName := filepath.Dir(path)
baseName := filepath.Base(path) baseName := filepath.Base(path)

View File

@ -16,6 +16,7 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/packet" "github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv" "github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
"github.com/wavetermdev/waveterm/waveshell/pkg/statediff" "github.com/wavetermdev/waveterm/waveshell/pkg/statediff"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
) )
const BaseBashOpts = `set +m; set +H; shopt -s extglob` const BaseBashOpts = `set +m; set +H; shopt -s extglob`
@ -48,7 +49,7 @@ func (b bashShellApi) GetShellType() string {
return packet.ShellType_bash return packet.ShellType_bash
} }
func (b bashShellApi) MakeExitTrap(fdNum int) string { func (b bashShellApi) MakeExitTrap(fdNum int) (string, []byte) {
return MakeBashExitTrap(fdNum) return MakeBashExitTrap(fdNum)
} }
@ -79,29 +80,15 @@ func (b bashShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty
return MakeBashShExecCommand(cmdStr, rcFileName, usePty) return MakeBashShExecCommand(cmdStr, rcFileName, usePty)
} }
func (b bashShellApi) GetShellState() chan ShellStateOutput { func (b bashShellApi) GetShellState(outCh chan ShellStateOutput) {
ch := make(chan ShellStateOutput, 1) GetBashShellState(outCh)
defer close(ch)
ssPk, err := GetBashShellState()
if err != nil {
ch <- ShellStateOutput{
Status: ShellStateOutputStatus_Done,
Error: err.Error(),
}
return ch
}
ch <- ShellStateOutput{
Status: ShellStateOutputStatus_Done,
ShellState: ssPk,
}
return ch
} }
func (b bashShellApi) GetBaseShellOpts() string { func (b bashShellApi) GetBaseShellOpts() string {
return BaseBashOpts return BaseBashOpts
} }
func (b bashShellApi) ParseShellStateOutput(output []byte) (*packet.ShellState, error) { func (b bashShellApi) ParseShellStateOutput(output []byte) (*packet.ShellState, *packet.ShellStateStats, error) {
return parseBashShellStateOutput(output) return parseBashShellStateOutput(output)
} }
@ -130,8 +117,32 @@ func (b bashShellApi) MakeRcFileStr(pk *packet.RunPacketType) string {
return rcBuf.String() return rcBuf.String()
} }
func GetBashShellStateCmd() string { func GetBashShellStateCmd(fdNum int) (string, []byte) {
return strings.Join(GetBashShellStateCmds, ` printf "\x00\x00";`) endBytes := utilfn.AppendNonZeroRandomBytes(nil, NumRandomEndBytes)
endBytes = append(endBytes, '\n')
cmdStr := strings.TrimSpace(`
exec 2> /dev/null;
exec > [%OUTPUTFD%];
printf "\x00\x00";
[%BASHVERSIONCMD%];
printf "\x00\x00";
pwd;
printf "\x00\x00";
declare -p $(compgen -A variable);
printf "\x00\x00";
alias -p;
printf "\x00\x00";
declare -f;
printf "\x00\x00";
[%GITBRANCHCMD%];
printf "\x00\x00";
printf "[%ENDBYTES%]";
`)
cmdStr = strings.ReplaceAll(cmdStr, "[%OUTPUTFD%]", fmt.Sprintf("/dev/fd/%d", fdNum))
cmdStr = strings.ReplaceAll(cmdStr, "[%BASHVERSIONCMD%]", BashShellVersionCmdStr)
cmdStr = strings.ReplaceAll(cmdStr, "[%GITBRANCHCMD%]", GetGitBranchCmdStr)
cmdStr = strings.ReplaceAll(cmdStr, "[%ENDBYTES%]", utilfn.ShellHexEscape(string(endBytes)))
return cmdStr, endBytes
} }
func execGetLocalBashShellVersion() string { func execGetLocalBashShellVersion() string {
@ -158,16 +169,34 @@ func GetLocalBashMajorVersion() string {
return localBashMajorVersion return localBashMajorVersion
} }
func GetBashShellState() (*packet.ShellState, error) { func GetBashShellState(outCh chan ShellStateOutput) {
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout) ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
defer cancelFn() defer cancelFn()
cmdStr := BaseBashOpts + "; " + GetBashShellStateCmd() defer close(outCh)
stateCmd, endBytes := GetBashShellStateCmd(StateOutputFdNum)
cmdStr := BaseBashOpts + "; " + stateCmd
ecmd := exec.CommandContext(ctx, GetLocalBashPath(), "-l", "-i", "-c", cmdStr) ecmd := exec.CommandContext(ctx, GetLocalBashPath(), "-l", "-i", "-c", cmdStr)
outputBytes, err := RunSimpleCmdInPty(ecmd) outputCh := make(chan []byte, 10)
if err != nil { var outputWg sync.WaitGroup
return nil, err outputWg.Add(1)
go func() {
defer outputWg.Done()
for outputBytes := range outputCh {
outCh <- ShellStateOutput{Output: outputBytes}
} }
return parseBashShellStateOutput(outputBytes) }()
outputBytes, err := StreamCommandWithExtraFd(ecmd, outputCh, StateOutputFdNum, endBytes)
outputWg.Wait()
if err != nil {
outCh <- ShellStateOutput{Error: err.Error()}
return
}
rtn, stats, err := parseBashShellStateOutput(outputBytes)
if err != nil {
outCh <- ShellStateOutput{Error: err.Error()}
return
}
outCh <- ShellStateOutput{ShellState: rtn, Stats: stats}
} }
func GetLocalBashPath() string { func GetLocalBashPath() string {
@ -190,19 +219,20 @@ func GetLocalZshPath() string {
return "zsh" return "zsh"
} }
func GetBashShellStateRedirectCommandStr(outputFdNum int) string { func GetBashShellStateRedirectCommandStr(outputFdNum int) (string, []byte) {
return fmt.Sprintf("cat <(%s) > /dev/fd/%d", GetBashShellStateCmd(), outputFdNum) cmdStr, endBytes := GetBashShellStateCmd(outputFdNum)
return cmdStr, endBytes
} }
func MakeBashExitTrap(fdNum int) string { func MakeBashExitTrap(fdNum int) (string, []byte) {
stateCmd := GetBashShellStateRedirectCommandStr(fdNum) stateCmd, endBytes := GetBashShellStateRedirectCommandStr(fdNum)
fmtStr := ` fmtStr := `
_waveshell_exittrap () { _waveshell_exittrap () {
%s %s
} }
trap _waveshell_exittrap EXIT trap _waveshell_exittrap EXIT
` `
return fmt.Sprintf(fmtStr, stateCmd) return fmt.Sprintf(fmtStr, stateCmd), endBytes
} }
func MakeBashShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd { func MakeBashShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd {

View File

@ -20,6 +20,19 @@ import (
"mvdan.cc/sh/v3/syntax" "mvdan.cc/sh/v3/syntax"
) )
const (
BashSection_Ignored = iota
BashSection_Version
BashSection_Cwd
BashSection_Vars
BashSection_Aliases
BashSection_Funcs
BashSection_PVars
BashSection_EndBytes
BashSection_Count // must be last
)
type DeclareDeclType = shellenv.DeclareDeclType type DeclareDeclType = shellenv.DeclareDeclType
func doCmdSubst(commandStr string, w io.Writer, word *syntax.CmdSubst) error { func doCmdSubst(commandStr string, w io.Writer, word *syntax.CmdSubst) error {
@ -214,38 +227,37 @@ func bashParseDeclareOutput(state *packet.ShellState, declareBytes []byte, pvarB
return nil return nil
} }
func parseBashShellStateOutput(outputBytes []byte) (*packet.ShellState, error) { func parseBashShellStateOutput(outputBytes []byte) (*packet.ShellState, *packet.ShellStateStats, error) {
if scbase.IsDevMode() && DebugState { if scbase.IsDevMode() && DebugState {
writeStateToFile(packet.ShellType_bash, outputBytes) writeStateToFile(packet.ShellType_bash, outputBytes)
} }
// 7 fields: ignored [0], version [1], cwd [2], env/vars [3], aliases [4], funcs [5], pvars [6] sections := bytes.Split(outputBytes, []byte{0, 0})
fields := bytes.Split(outputBytes, []byte{0, 0}) if len(sections) != BashSection_Count {
if len(fields) != 7 { return nil, nil, fmt.Errorf("invalid bash shell state output, wrong number of fields, fields=%d", len(sections))
return nil, fmt.Errorf("invalid bash shell state output, wrong number of fields, fields=%d", len(fields))
} }
rtn := &packet.ShellState{} rtn := &packet.ShellState{}
rtn.Version = strings.TrimSpace(string(fields[1])) rtn.Version = strings.TrimSpace(string(sections[BashSection_Version]))
if rtn.GetShellType() != packet.ShellType_bash { if rtn.GetShellType() != packet.ShellType_bash {
return nil, fmt.Errorf("invalid bash shell state output, wrong shell type: %q", rtn.Version) return nil, nil, fmt.Errorf("invalid bash shell state output, wrong shell type: %q", rtn.Version)
} }
if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil { if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil {
return nil, fmt.Errorf("invalid bash shell state output, invalid version: %v", err) return nil, nil, fmt.Errorf("invalid bash shell state output, invalid version: %v", err)
} }
cwdStr := string(fields[2]) cwdStr := string(sections[BashSection_Cwd])
if strings.HasSuffix(cwdStr, "\r\n") { if strings.HasSuffix(cwdStr, "\r\n") {
cwdStr = cwdStr[0 : len(cwdStr)-2] cwdStr = cwdStr[0 : len(cwdStr)-2]
} else if strings.HasSuffix(cwdStr, "\n") { } else {
cwdStr = cwdStr[0 : len(cwdStr)-1] cwdStr = strings.TrimSuffix(cwdStr, "\n")
} }
rtn.Cwd = string(cwdStr) rtn.Cwd = string(cwdStr)
err := bashParseDeclareOutput(rtn, fields[3], fields[6]) err := bashParseDeclareOutput(rtn, sections[BashSection_Vars], sections[BashSection_PVars])
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
rtn.Aliases = strings.ReplaceAll(string(fields[4]), "\r\n", "\n") rtn.Aliases = strings.ReplaceAll(string(sections[BashSection_Aliases]), "\r\n", "\n")
rtn.Funcs = strings.ReplaceAll(string(fields[5]), "\r\n", "\n") rtn.Funcs = strings.ReplaceAll(string(sections[BashSection_Funcs]), "\r\n", "\n")
rtn.Funcs = shellenv.RemoveFunc(rtn.Funcs, "_waveshell_exittrap") rtn.Funcs = shellenv.RemoveFunc(rtn.Funcs, "_waveshell_exittrap")
return rtn, nil return rtn, nil, nil
} }
func bashNormalize(d *DeclareDeclType) error { func bashNormalize(d *DeclareDeclType) error {

View File

@ -7,7 +7,6 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"io"
"os" "os"
"os/exec" "os/exec"
"os/user" "os/user"
@ -28,12 +27,14 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn" "github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
) )
const GetStateTimeout = 5 * time.Second const GetStateTimeout = 15 * time.Second
const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"` const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"`
const GetK8sContextCmdStr = `printf "K8SCONTEXT %s\x00" "$(kubectl config current-context 2>/dev/null)"` const GetK8sContextCmdStr = `printf "K8SCONTEXT %s\x00" "$(kubectl config current-context 2>/dev/null)"`
const GetK8sNamespaceCmdStr = `printf "K8SNAMESPACE %s\x00" "$(kubectl config view --minify --output 'jsonpath={..namespace}' 2>/dev/null)"` const GetK8sNamespaceCmdStr = `printf "K8SNAMESPACE %s\x00" "$(kubectl config view --minify --output 'jsonpath={..namespace}' 2>/dev/null)"`
const RunCommandFmt = `%s` const RunCommandFmt = `%s`
const DebugState = false const DebugState = false
const StateOutputFdNum = 20
const NumRandomEndBytes = 8
var userShellRegexp = regexp.MustCompile(`^UserShell: (.*)$`) var userShellRegexp = regexp.MustCompile(`^UserShell: (.*)$`)
@ -56,23 +57,23 @@ const (
) )
type ShellStateOutput struct { type ShellStateOutput struct {
Status string Output []byte
StderrOutput []byte
ShellState *packet.ShellState ShellState *packet.ShellState
Stats *packet.ShellStateStats
Error string Error string
} }
type ShellApi interface { type ShellApi interface {
GetShellType() string GetShellType() string
MakeExitTrap(fdNum int) string MakeExitTrap(fdNum int) (string, []byte)
GetLocalMajorVersion() string GetLocalMajorVersion() string
GetLocalShellPath() string GetLocalShellPath() string
GetRemoteShellPath() string GetRemoteShellPath() string
MakeRunCommand(cmdStr string, opts RunCommandOpts) string MakeRunCommand(cmdStr string, opts RunCommandOpts) string
MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd
GetShellState() chan ShellStateOutput GetShellState(chan ShellStateOutput)
GetBaseShellOpts() string GetBaseShellOpts() string
ParseShellStateOutput(output []byte) (*packet.ShellState, error) ParseShellStateOutput(output []byte) (*packet.ShellState, *packet.ShellStateStats, error)
MakeRcFileStr(pk *packet.RunPacketType) string MakeRcFileStr(pk *packet.RunPacketType) string
MakeShellStateDiff(oldState *packet.ShellState, oldStateHash string, newState *packet.ShellState) (*packet.ShellStateDiff, error) MakeShellStateDiff(oldState *packet.ShellState, oldStateHash string, newState *packet.ShellState) (*packet.ShellStateDiff, error)
ApplyShellStateDiff(oldState *packet.ShellState, diff *packet.ShellStateDiff) (*packet.ShellState, error) ApplyShellStateDiff(oldState *packet.ShellState, diff *packet.ShellStateDiff) (*packet.ShellState, error)
@ -153,12 +154,13 @@ func internalMacUserShell() string {
const FirstExtraFilesFdNum = 3 const FirstExtraFilesFdNum = 3
// returns output(stdout+stderr), extraFdOutput, error // returns output(stdout+stderr), extraFdOutput, error
func RunCommandWithExtraFd(ecmd *exec.Cmd, extraFdNum int) ([]byte, []byte, error) { func StreamCommandWithExtraFd(ecmd *exec.Cmd, outputCh chan []byte, extraFdNum int, endBytes []byte) ([]byte, error) {
defer close(outputCh)
ecmd.Env = os.Environ() ecmd.Env = os.Environ()
shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType)) shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType))
cmdPty, cmdTty, err := pty.Open() cmdPty, cmdTty, err := pty.Open()
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("opening new pty: %w", err) return nil, fmt.Errorf("opening new pty: %w", err)
} }
defer cmdTty.Close() defer cmdTty.Close()
defer cmdPty.Close() defer cmdPty.Close()
@ -171,42 +173,44 @@ func RunCommandWithExtraFd(ecmd *exec.Cmd, extraFdNum int) ([]byte, []byte, erro
ecmd.SysProcAttr.Setctty = true ecmd.SysProcAttr.Setctty = true
pipeReader, pipeWriter, err := os.Pipe() pipeReader, pipeWriter, err := os.Pipe()
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("could not create pipe: %w", err) return nil, fmt.Errorf("could not create pipe: %w", err)
} }
defer pipeWriter.Close() defer pipeWriter.Close()
defer pipeReader.Close() defer pipeReader.Close()
extraFiles := make([]*os.File, extraFdNum+1) extraFiles := make([]*os.File, extraFdNum+1)
extraFiles[extraFdNum] = pipeWriter extraFiles[extraFdNum] = pipeWriter
ecmd.ExtraFiles = extraFiles[FirstExtraFilesFdNum:] ecmd.ExtraFiles = extraFiles[FirstExtraFilesFdNum:]
defer pipeReader.Close() err = ecmd.Start()
ecmd.Start()
cmdTty.Close() cmdTty.Close()
pipeWriter.Close() pipeWriter.Close()
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
var outputWg sync.WaitGroup var outputWg sync.WaitGroup
var outputBuf bytes.Buffer
var extraFdOutputBuf bytes.Buffer var extraFdOutputBuf bytes.Buffer
outputWg.Add(2) outputWg.Add(2)
go func() { go func() {
// ignore error (/dev/ptmx has read error when process is done) // ignore error (/dev/ptmx has read error when process is done)
defer outputWg.Done() defer outputWg.Done()
io.Copy(&outputBuf, cmdPty) err := utilfn.CopyToChannel(outputCh, cmdPty)
if err != nil {
errStr := fmt.Sprintf("\r\nerror reading from pty: %v\r\n", err)
outputCh <- []byte(errStr)
}
}() }()
go func() { go func() {
defer outputWg.Done() defer outputWg.Done()
io.Copy(&extraFdOutputBuf, pipeReader) utilfn.CopyWithEndBytes(&extraFdOutputBuf, pipeReader, endBytes)
}() }()
exitErr := ecmd.Wait() exitErr := ecmd.Wait()
if exitErr != nil { if exitErr != nil {
return nil, nil, exitErr return nil, exitErr
} }
outputWg.Wait() outputWg.Wait()
return outputBuf.Bytes(), extraFdOutputBuf.Bytes(), nil return extraFdOutputBuf.Bytes(), nil
} }
func RunSimpleCmdInPty(ecmd *exec.Cmd) ([]byte, error) { func RunSimpleCmdInPty(ecmd *exec.Cmd, endBytes []byte) ([]byte, error) {
ecmd.Env = os.Environ() ecmd.Env = os.Environ()
shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType)) shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType))
cmdPty, cmdTty, err := pty.Open() cmdPty, cmdTty, err := pty.Open()
@ -231,8 +235,8 @@ func RunSimpleCmdInPty(ecmd *exec.Cmd) ([]byte, error) {
var outputBuf bytes.Buffer var outputBuf bytes.Buffer
go func() { go func() {
// ignore error (/dev/ptmx has read error when process is done) // ignore error (/dev/ptmx has read error when process is done)
io.Copy(&outputBuf, cmdPty) defer close(ioDone)
close(ioDone) utilfn.CopyWithEndBytes(&outputBuf, cmdPty, endBytes)
}() }()
exitErr := ecmd.Wait() exitErr := ecmd.Wait()
if exitErr != nil { if exitErr != nil {

View File

@ -8,7 +8,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"math/rand"
"os/exec" "os/exec"
"strings" "strings"
"sync" "sync"
@ -28,7 +27,6 @@ import (
const BaseZshOpts = `` const BaseZshOpts = ``
const ZshShellVersionCmdStr = `echo zsh v$ZSH_VERSION` const ZshShellVersionCmdStr = `echo zsh v$ZSH_VERSION`
const StateOutputFdNum = 20
const ( const (
ZshSection_Version = iota ZshSection_Version = iota
@ -41,6 +39,7 @@ const (
ZshSection_Funcs ZshSection_Funcs
ZshSection_PVars ZshSection_PVars
ZshSection_Prompt ZshSection_Prompt
ZshSection_EndBytes
ZshSection_NumFieldsExpected // must be last ZshSection_NumFieldsExpected // must be last
) )
@ -118,6 +117,11 @@ var ZshIgnoreVars = map[string]bool{
"zcurses_windows": true, "zcurses_windows": true,
// not listed, but we also exclude all ZFTP_* variables // not listed, but we also exclude all ZFTP_* variables
// powerlevel10k
"_GITSTATUS_CLIENT_PID_POWERLEVEL9K": true,
"GITSTATUS_DAEMON_PID_POWERLEVEL9K": true,
"_GITSTATUS_FILE_PREFIX_POWERLEVEL9K": true,
} }
var ZshIgnoreFuncs = map[string]bool{ var ZshIgnoreFuncs = map[string]bool{
@ -211,7 +215,7 @@ func (z zshShellApi) GetShellType() string {
return packet.ShellType_zsh return packet.ShellType_zsh
} }
func (z zshShellApi) MakeExitTrap(fdNum int) string { func (z zshShellApi) MakeExitTrap(fdNum int) (string, []byte) {
return MakeZshExitTrap(fdNum) return MakeZshExitTrap(fdNum)
} }
@ -242,25 +246,34 @@ func (z zshShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty
return exec.Command(GetLocalZshPath(), "-l", "-i", "-c", cmdStr) return exec.Command(GetLocalZshPath(), "-l", "-i", "-c", cmdStr)
} }
func (z zshShellApi) GetShellState() chan ShellStateOutput { func (z zshShellApi) GetShellState(outCh chan ShellStateOutput) {
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout) ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
defer cancelFn() defer cancelFn()
rtnCh := make(chan ShellStateOutput, 1) defer close(outCh)
defer close(rtnCh) stateCmd, endBytes := GetZshShellStateCmd(StateOutputFdNum)
cmdStr := BaseZshOpts + "; " + GetZshShellStateCmd(StateOutputFdNum) cmdStr := BaseZshOpts + "; " + stateCmd
ecmd := exec.CommandContext(ctx, GetLocalZshPath(), "-l", "-i", "-c", cmdStr) ecmd := exec.CommandContext(ctx, GetLocalZshPath(), "-l", "-i", "-c", cmdStr)
_, outputBytes, err := RunCommandWithExtraFd(ecmd, StateOutputFdNum) outputCh := make(chan []byte, 10)
if err != nil { var outputWg sync.WaitGroup
rtnCh <- ShellStateOutput{Status: ShellStateOutputStatus_Done, Error: err.Error()} outputWg.Add(1)
return rtnCh go func() {
defer outputWg.Done()
for outputBytes := range outputCh {
outCh <- ShellStateOutput{Output: outputBytes}
} }
rtn, err := z.ParseShellStateOutput(outputBytes) }()
outputBytes, err := StreamCommandWithExtraFd(ecmd, outputCh, StateOutputFdNum, endBytes)
outputWg.Wait()
if err != nil { if err != nil {
rtnCh <- ShellStateOutput{Status: ShellStateOutputStatus_Done, Error: err.Error()} outCh <- ShellStateOutput{Error: err.Error()}
return rtnCh return
} }
rtnCh <- ShellStateOutput{Status: ShellStateOutputStatus_Done, ShellState: rtn} rtn, stats, err := z.ParseShellStateOutput(outputBytes)
return rtnCh if err != nil {
outCh <- ShellStateOutput{Error: err.Error()}
return
}
outCh <- ShellStateOutput{ShellState: rtn, Stats: stats}
} }
func (z zshShellApi) GetBaseShellOpts() string { func (z zshShellApi) GetBaseShellOpts() string {
@ -437,19 +450,15 @@ func writeZshId(buf *bytes.Buffer, idStr string) {
const numRandomBytes = 4 const numRandomBytes = 4
// returns (cmd-string) // returns (cmd-string, endbytes)
func GetZshShellStateCmd(fdNum int) string { func GetZshShellStateCmd(fdNum int) (string, []byte) {
var sectionSeparator []byte var sectionSeparator []byte
// adding this extra "\n" helps with debuging and readability of output // adding this extra "\n" helps with debuging and readability of output
sectionSeparator = append(sectionSeparator, byte('\n')) sectionSeparator = append(sectionSeparator, byte('\n'))
for len(sectionSeparator) < numRandomBytes { sectionSeparator = utilfn.AppendNonZeroRandomBytes(sectionSeparator, numRandomBytes)
// any character *except* null (0)
rn := rand.Intn(256)
if rn > 0 && rn < 256 { // exclude 0, also helps to suppress security warning to have a guard here
sectionSeparator = append(sectionSeparator, byte(rn))
}
}
sectionSeparator = append(sectionSeparator, 0, 0) sectionSeparator = append(sectionSeparator, 0, 0)
endBytes := utilfn.AppendNonZeroRandomBytes(nil, NumRandomEndBytes)
endBytes = append(endBytes, byte('\n'))
// we have to use these crazy separators because zsh allows basically anything in // we have to use these crazy separators because zsh allows basically anything in
// variable names and values (including nulls). // variable names and values (including nulls).
// note that we don't need crazy separators for "env" or "typeset". // note that we don't need crazy separators for "env" or "typeset".
@ -511,6 +520,8 @@ printf "[%SECTIONSEP%]";
[%K8SNAMESPACE%] [%K8SNAMESPACE%]
printf "[%SECTIONSEP%]"; printf "[%SECTIONSEP%]";
print -P "$PS1" print -P "$PS1"
printf "[%SECTIONSEP%]";
printf "[%ENDBYTES%]"
` `
cmd = strings.TrimSpace(cmd) cmd = strings.TrimSpace(cmd)
cmd = strings.ReplaceAll(cmd, "[%ZSHVERSION%]", ZshShellVersionCmdStr) cmd = strings.ReplaceAll(cmd, "[%ZSHVERSION%]", ZshShellVersionCmdStr)
@ -520,17 +531,19 @@ print -P "$PS1"
cmd = strings.ReplaceAll(cmd, "[%PARTSEP%]", utilfn.ShellHexEscape(string(sectionSeparator[0:len(sectionSeparator)-1]))) cmd = strings.ReplaceAll(cmd, "[%PARTSEP%]", utilfn.ShellHexEscape(string(sectionSeparator[0:len(sectionSeparator)-1])))
cmd = strings.ReplaceAll(cmd, "[%SECTIONSEP%]", utilfn.ShellHexEscape(string(sectionSeparator))) cmd = strings.ReplaceAll(cmd, "[%SECTIONSEP%]", utilfn.ShellHexEscape(string(sectionSeparator)))
cmd = strings.ReplaceAll(cmd, "[%OUTPUTFD%]", fmt.Sprintf("/dev/fd/%d", fdNum)) cmd = strings.ReplaceAll(cmd, "[%OUTPUTFD%]", fmt.Sprintf("/dev/fd/%d", fdNum))
return cmd cmd = strings.ReplaceAll(cmd, "[%OUTPUTFDNUM%]", fmt.Sprintf("%d", fdNum))
cmd = strings.ReplaceAll(cmd, "[%ENDBYTES%]", utilfn.ShellHexEscape(string(endBytes)))
return cmd, endBytes
} }
func MakeZshExitTrap(fdNum int) string { func MakeZshExitTrap(fdNum int) (string, []byte) {
stateCmd := GetZshShellStateCmd(fdNum) stateCmd, endBytes := GetZshShellStateCmd(fdNum)
fmtStr := ` fmtStr := `
zshexit () { zshexit () {
%s %s
} }
` `
return fmt.Sprintf(fmtStr, stateCmd) return fmt.Sprintf(fmtStr, stateCmd), endBytes
} }
func execGetLocalZshShellVersion() string { func execGetLocalZshShellVersion() string {
@ -698,14 +711,14 @@ func makeZshFuncsStrForShellState(fnMap map[ZshParamKey]string) string {
return buf.String() return buf.String()
} }
func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, error) { func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, *packet.ShellStateStats, error) {
if scbase.IsDevMode() && DebugState { if scbase.IsDevMode() && DebugState {
writeStateToFile(packet.ShellType_zsh, outputBytes) writeStateToFile(packet.ShellType_zsh, outputBytes)
} }
firstZeroIdx := bytes.Index(outputBytes, []byte{0}) firstZeroIdx := bytes.Index(outputBytes, []byte{0})
firstDZeroIdx := bytes.Index(outputBytes, []byte{0, 0}) firstDZeroIdx := bytes.Index(outputBytes, []byte{0, 0})
if firstZeroIdx == -1 || firstDZeroIdx == -1 { if firstZeroIdx == -1 || firstDZeroIdx == -1 {
return nil, fmt.Errorf("invalid zsh shell state output, could not parse separator bytes") return nil, nil, fmt.Errorf("invalid zsh shell state output, could not parse separator bytes")
} }
versionStr := string(outputBytes[0:firstZeroIdx]) versionStr := string(outputBytes[0:firstZeroIdx])
sectionSeparator := outputBytes[firstZeroIdx+1 : firstDZeroIdx+2] sectionSeparator := outputBytes[firstZeroIdx+1 : firstDZeroIdx+2]
@ -714,15 +727,15 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta
sections := bytes.Split(outputBytes, sectionSeparator) sections := bytes.Split(outputBytes, sectionSeparator)
if len(sections) != ZshSection_NumFieldsExpected { if len(sections) != ZshSection_NumFieldsExpected {
base.Logf("invalid -- numfields\n") base.Logf("invalid -- numfields\n")
return nil, fmt.Errorf("invalid zsh shell state output, wrong number of sections, section=%d", len(sections)) return nil, nil, fmt.Errorf("invalid zsh shell state output, wrong number of sections, section=%d", len(sections))
} }
rtn := &packet.ShellState{} rtn := &packet.ShellState{}
rtn.Version = strings.TrimSpace(versionStr) rtn.Version = strings.TrimSpace(versionStr)
if rtn.GetShellType() != packet.ShellType_zsh { if rtn.GetShellType() != packet.ShellType_zsh {
return nil, fmt.Errorf("invalid zsh shell state output, wrong shell type") return nil, nil, fmt.Errorf("invalid zsh shell state output, wrong shell type")
} }
if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil { if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil {
return nil, fmt.Errorf("invalid zsh shell state output, invalid version: %v", err) return nil, nil, fmt.Errorf("invalid zsh shell state output, invalid version: %v", err)
} }
cwdStr := stripNewLineChars(string(sections[ZshSection_Cwd])) cwdStr := stripNewLineChars(string(sections[ZshSection_Cwd]))
rtn.Cwd = cwdStr rtn.Cwd = cwdStr
@ -730,7 +743,7 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta
zshDecls, err := parseZshDecls(sections[ZshSection_Vars]) zshDecls, err := parseZshDecls(sections[ZshSection_Vars])
if err != nil { if err != nil {
base.Logf("invalid - parsedecls %v\n", err) base.Logf("invalid - parsedecls %v\n", err)
return nil, err return nil, nil, err
} }
for _, decl := range zshDecls { for _, decl := range zshDecls {
if decl.IsZshScalarBound() { if decl.IsZshScalarBound() {
@ -746,7 +759,17 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta
pvarMap := parseExtVarOutput(sections[ZshSection_PVars], string(sections[ZshSection_Prompt]), string(sections[ZshSection_Mods])) pvarMap := parseExtVarOutput(sections[ZshSection_PVars], string(sections[ZshSection_Prompt]), string(sections[ZshSection_Mods]))
utilfn.CombineMaps(zshDecls, pvarMap) utilfn.CombineMaps(zshDecls, pvarMap)
rtn.ShellVars = shellenv.SerializeDeclMap(zshDecls) rtn.ShellVars = shellenv.SerializeDeclMap(zshDecls)
return rtn, nil stats := &packet.ShellStateStats{
Version: rtn.Version,
AliasCount: int(len(aliasMap)),
FuncCount: int(len(zshFuncs)),
VarCount: int(len(zshDecls)),
EnvCount: int(len(zshEnv)),
HashVal: rtn.GetHashVal(false),
OutputSize: int64(len(outputBytes)),
StateSize: rtn.ApproximateSize(),
}
return rtn, stats, nil
} }
func parseZshEnv(output []byte) map[string]string { func parseZshEnv(output []byte) map[string]string {

View File

@ -7,7 +7,6 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -95,6 +94,7 @@ type ReturnStateBuf struct {
Err error Err error
Reader *os.File Reader *os.File
FdNum int FdNum int
EndBytes []byte
DoneCh chan bool DoneCh chan bool
} }
@ -835,7 +835,8 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
cmd.ReturnState.FdNum = RtnStateFdNum cmd.ReturnState.FdNum = RtnStateFdNum
rtnStateWriter = pw rtnStateWriter = pw
defer pw.Close() defer pw.Close()
trapCmdStr := sapi.MakeExitTrap(cmd.ReturnState.FdNum) trapCmdStr, endBytes := sapi.MakeExitTrap(cmd.ReturnState.FdNum)
cmd.ReturnState.EndBytes = endBytes
rcFileStr += trapCmdStr rcFileStr += trapCmdStr
} }
shellVarMap := shellenv.ShellVarMapFromState(state) shellVarMap := shellenv.ShellVarMapFromState(state)
@ -1021,6 +1022,11 @@ func (rs *ReturnStateBuf) Run() {
} }
rs.Lock.Lock() rs.Lock.Lock()
rs.Buf = append(rs.Buf, buf[0:n]...) rs.Buf = append(rs.Buf, buf[0:n]...)
if bytes.HasSuffix(rs.Buf, rs.EndBytes) {
rs.Buf = rs.Buf[:len(rs.Buf)-len(rs.EndBytes)]
rs.Lock.Unlock()
break
}
rs.Lock.Unlock() rs.Lock.Unlock()
} }
} }
@ -1127,7 +1133,7 @@ func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType {
wlog.Logf("debug returnstate file %q\n", base.GetDebugReturnStateFileName()) wlog.Logf("debug returnstate file %q\n", base.GetDebugReturnStateFileName())
os.WriteFile(base.GetDebugReturnStateFileName(), c.ReturnState.Buf, 0666) os.WriteFile(base.GetDebugReturnStateFileName(), c.ReturnState.Buf, 0666)
} }
state, _ := c.SAPI.ParseShellStateOutput(c.ReturnState.Buf) // TODO what to do with error? state, _, _ := c.SAPI.ParseShellStateOutput(c.ReturnState.Buf) // TODO what to do with error?
donePacket.FinalState = state donePacket.FinalState = state
} }
endTs := time.Now() endTs := time.Now()
@ -1156,21 +1162,6 @@ func MakeInitPacket() *packet.InitPacketType {
return initPacket return initPacket
} }
func MakeShellStatePacket(shellType string) (*packet.ShellStatePacketType, error) {
sapi, err := shellapi.MakeShellApi(shellType)
if err != nil {
return nil, err
}
rtnCh := sapi.GetShellState()
ssOutput := <-rtnCh
if ssOutput.Error != "" {
return nil, errors.New(ssOutput.Error)
}
rtn := packet.MakeShellStatePacket()
rtn.State = ssOutput.ShellState
return rtn, nil
}
func MakeServerInitPacket() (*packet.InitPacketType, error) { func MakeServerInitPacket() (*packet.InitPacketType, error) {
var err error var err error
initPacket := MakeInitPacket() initPacket := MakeInitPacket()

View File

@ -10,7 +10,9 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"math" "math"
mathrand "math/rand"
"regexp" "regexp"
"sort" "sort"
"strings" "strings"
@ -552,3 +554,60 @@ func StrArrayToMap(sarr []string) map[string]bool {
} }
return m return m
} }
func AppendNonZeroRandomBytes(b []byte, randLen int) []byte {
if randLen <= 0 {
return b
}
numAdded := 0
for numAdded < randLen {
rn := mathrand.Intn(256)
if rn > 0 && rn < 256 { // exclude 0, also helps to suppress security warning to have a guard here
b = append(b, byte(rn))
numAdded++
}
}
return b
}
// returns (isEOF, error)
func CopyWithEndBytes(outputBuf *bytes.Buffer, reader io.Reader, endBytes []byte) (bool, error) {
buf := make([]byte, 4096)
for {
n, err := reader.Read(buf)
if n > 0 {
outputBuf.Write(buf[:n])
obytes := outputBuf.Bytes()
if bytes.HasSuffix(obytes, endBytes) {
outputBuf.Truncate(len(obytes) - len(endBytes))
return (err == io.EOF), nil
}
}
if err == io.EOF {
return true, nil
}
if err != nil {
return false, err
}
}
}
// does *not* close outputCh on EOF or error
func CopyToChannel(outputCh chan<- []byte, reader io.Reader) error {
buf := make([]byte, 4096)
for {
n, err := reader.Read(buf)
if n > 0 {
// copy so client can use []byte without it being overwritten
bufCopy := make([]byte, n)
copy(bufCopy, buf[:n])
outputCh <- bufCopy
}
if err == io.EOF {
return nil
}
if err != nil {
return err
}
}
}

View File

@ -41,6 +41,7 @@ import (
"github.com/wavetermdev/waveterm/wavesrv/pkg/releasechecker" "github.com/wavetermdev/waveterm/wavesrv/pkg/releasechecker"
"github.com/wavetermdev/waveterm/wavesrv/pkg/remote" "github.com/wavetermdev/waveterm/wavesrv/pkg/remote"
"github.com/wavetermdev/waveterm/wavesrv/pkg/remote/openai" "github.com/wavetermdev/waveterm/wavesrv/pkg/remote/openai"
"github.com/wavetermdev/waveterm/wavesrv/pkg/rtnstate"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase" "github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbus" "github.com/wavetermdev/waveterm/wavesrv/pkg/scbus"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket" "github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
@ -1648,8 +1649,12 @@ func CopyFileCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scb
} }
var outputPos int64 var outputPos int64
outputStr := fmt.Sprintf("Copying [%v]:%v to [%v]:%v\r\n", sourceRemoteId.DisplayName, sourceFullPath, destRemoteId.DisplayName, destFullPath) outputStr := fmt.Sprintf("Copying [%v]:%v to [%v]:%v\r\n", sourceRemoteId.DisplayName, sourceFullPath, destRemoteId.DisplayName, destFullPath)
termopts := sstore.TermOpts{Rows: shellutil.DefaultTermRows, Cols: shellutil.DefaultTermCols, FlexRows: true, MaxPtySize: remote.DefaultMaxPtySize} termOpts, err := GetUITermOpts(pk.UIContext.WinSize, DefaultPTERM)
cmd, err := makeDynCmd(ctx, "copy file", ids, pk.GetRawStr(), termopts) if err != nil {
return nil, fmt.Errorf("cannot make termopts: %w", err)
}
pkTermOpts := convertTermOpts(termOpts)
cmd, err := makeDynCmd(ctx, "copy file", ids, pk.GetRawStr(), *pkTermOpts)
writeStringToPty(ctx, cmd, outputStr, &outputPos) writeStringToPty(ctx, cmd, outputStr, &outputPos)
if err != nil { if err != nil {
// TODO tricky error since the command was a success, but we can't show the output // TODO tricky error since the command was a success, but we can't show the output
@ -3655,11 +3660,14 @@ func SessionCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbu
return update, nil return update, nil
} }
func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbus.UpdatePacket, error) { func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (rtnUpdate scbus.UpdatePacket, rtnErr error) {
ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen|R_Remote) ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen|R_Remote)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !ids.Remote.MShell.IsConnected() {
return nil, fmt.Errorf("cannot reinit, remote is not connected")
}
shellType := ids.Remote.ShellType shellType := ids.Remote.ShellType
if pk.Kwargs["shell"] != "" { if pk.Kwargs["shell"] != "" {
shellArg := pk.Kwargs["shell"] shellArg := pk.Kwargs["shell"]
@ -3668,33 +3676,76 @@ func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (
} }
shellType = shellArg shellType = shellArg
} }
ssPk, err := ids.Remote.MShell.ReInit(ctx, shellType) verbose := resolveBool(pk.Kwargs["verbose"], false)
termOpts, err := GetUITermOpts(pk.UIContext.WinSize, DefaultPTERM)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("cannot make termopts: %w", err)
} }
if ssPk == nil || ssPk.State == nil { pkTermOpts := convertTermOpts(termOpts)
return nil, fmt.Errorf("invalid initpk received from remote (no remote state)") cmd, err := makeDynCmd(ctx, "reset", ids, pk.GetRawStr(), *pkTermOpts)
}
feState := sstore.FeStateFromShellState(ssPk.State)
remoteInst, err := sstore.UpdateRemoteState(ctx, ids.SessionId, ids.ScreenId, ids.Remote.RemotePtr, feState, ssPk.State, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
outputStr := fmt.Sprintf("reset remote state (shell:%s)", ssPk.State.GetShellType())
cmd, err := makeStaticCmd(ctx, "reset", ids, pk.GetRawStr(), []byte(outputStr))
if err != nil {
// TODO tricky error since the command was a success, but we can't show the output
return nil, err
}
update, err := addLineForCmd(ctx, "/reset", false, ids, cmd, "", nil) update, err := addLineForCmd(ctx, "/reset", false, ids, cmd, "", nil)
if err != nil { if err != nil {
// TODO tricky error since the command was a success, but we can't show the output
return nil, err return nil, err
} }
update.AddUpdate(sstore.MakeSessionUpdateForRemote(ids.SessionId, remoteInst), sstore.InteractiveUpdate(pk.Interactive)) go doResetCommand(ids, shellType, cmd, verbose)
return update, nil return update, nil
} }
func doResetCommand(ids resolvedIds, shellType string, cmd *sstore.CmdType, verbose bool) {
ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second)
defer cancelFn()
startTime := time.Now()
var outputPos int64
var rtnErr error
exitSuccess := true
defer func() {
if rtnErr != nil {
exitSuccess = false
writeStringToPty(ctx, cmd, fmt.Sprintf("\r\nerror: %v", rtnErr), &outputPos)
}
deferWriteCmdStatus(ctx, cmd, startTime, exitSuccess, outputPos)
}()
dataFn := func(data []byte) {
writeStringToPty(ctx, cmd, string(data), &outputPos)
}
origStatePtr := ids.Remote.MShell.GetDefaultStatePtr(shellType)
ssPk, err := ids.Remote.MShell.ReInit(ctx, shellType, dataFn, verbose)
if err != nil {
rtnErr = err
return
}
if ssPk == nil || ssPk.State == nil {
rtnErr = fmt.Errorf("invalid initpk received from remote (no remote state)")
return
}
feState := sstore.FeStateFromShellState(ssPk.State)
remoteInst, err := sstore.UpdateRemoteState(ctx, ids.SessionId, ids.ScreenId, ids.Remote.RemotePtr, feState, ssPk.State, nil)
if err != nil {
rtnErr = err
return
}
newStatePtr := ids.Remote.MShell.GetDefaultStatePtr(shellType)
if verbose && origStatePtr != nil && newStatePtr != nil {
statePtrDiff := fmt.Sprintf("oldstate: %v, newstate: %v\r\n", origStatePtr.BaseHash, newStatePtr.BaseHash)
writeStringToPty(ctx, cmd, statePtrDiff, &outputPos)
origFullState, _ := sstore.GetFullState(ctx, *origStatePtr)
newFullState, _ := sstore.GetFullState(ctx, *newStatePtr)
if origFullState != nil && newFullState != nil {
var diffBuf bytes.Buffer
rtnstate.DisplayStateUpdateDiff(&diffBuf, *origFullState, *newFullState)
diffStr := diffBuf.String()
diffStr = strings.ReplaceAll(diffStr, "\n", "\r\n")
writeStringToPty(ctx, cmd, diffStr, &outputPos)
}
}
update := scbus.MakeUpdatePacket()
update.AddUpdate(sstore.MakeSessionUpdateForRemote(ids.SessionId, remoteInst))
scbus.MainUpdateBus.DoUpdate(update)
}
func ResetCwdCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbus.UpdatePacket, error) { func ResetCwdCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbus.UpdatePacket, error) {
ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen|R_Remote) ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen|R_Remote)
if err != nil { if err != nil {

View File

@ -196,7 +196,7 @@ func (msh *MShellProc) EnsureShellType(ctx context.Context, shellType string) er
return nil return nil
} }
// try to reinit the shell // try to reinit the shell
_, err := msh.ReInit(ctx, shellType) _, err := msh.ReInit(ctx, shellType, nil, false)
if err != nil { if err != nil {
return fmt.Errorf("error trying to initialize shell %q: %v", shellType, err) return fmt.Errorf("error trying to initialize shell %q: %v", shellType, err)
} }
@ -1401,33 +1401,60 @@ func makeReinitErrorUpdate(shellType string) sstore.ActivityUpdate {
return rtn return rtn
} }
func (msh *MShellProc) ReInit(ctx context.Context, shellType string) (*packet.ShellStatePacketType, error) { func (msh *MShellProc) ReInit(ctx context.Context, shellType string, dataFn func([]byte), verbose bool) (rtnPk *packet.ShellStatePacketType, rtnErr error) {
if !msh.IsConnected() { if !msh.IsConnected() {
return nil, fmt.Errorf("cannot reinit, remote is not connected") return nil, fmt.Errorf("cannot reinit, remote is not connected")
} }
if shellType != packet.ShellType_bash && shellType != packet.ShellType_zsh { if shellType != packet.ShellType_bash && shellType != packet.ShellType_zsh {
return nil, fmt.Errorf("invalid shell type %q", shellType) return nil, fmt.Errorf("invalid shell type %q", shellType)
} }
if dataFn == nil {
dataFn = func([]byte) {}
}
defer func() {
if rtnErr != nil {
sstore.UpdateActivityWrap(ctx, makeReinitErrorUpdate(shellType), "reiniterror")
}
}()
startTs := time.Now()
reinitPk := packet.MakeReInitPacket() reinitPk := packet.MakeReInitPacket()
reinitPk.ReqId = uuid.New().String() reinitPk.ReqId = uuid.New().String()
reinitPk.ShellType = shellType reinitPk.ShellType = shellType
resp, err := msh.PacketRpcRaw(ctx, reinitPk) rpcIter, err := msh.PacketRpcIter(ctx, reinitPk)
if err != nil {
return nil, err
}
defer rpcIter.Close()
var ssPk *packet.ShellStatePacketType
for {
resp, err := rpcIter.Next(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if resp == nil { if resp == nil {
return nil, fmt.Errorf("no response") return nil, fmt.Errorf("channel closed with no response")
} }
ssPk, ok := resp.(*packet.ShellStatePacketType) var ok bool
if !ok { ssPk, ok = resp.(*packet.ShellStatePacketType)
sstore.UpdateActivityWrap(ctx, makeReinitErrorUpdate(shellType), "reiniterror") if ok {
if respPk, ok := resp.(*packet.ResponsePacketType); ok && respPk.Error != "" { break
}
respPk, ok := resp.(*packet.ResponsePacketType)
if ok {
if respPk.Error != "" {
return nil, fmt.Errorf("error reinitializing remote: %s", respPk.Error) return nil, fmt.Errorf("error reinitializing remote: %s", respPk.Error)
} }
return nil, fmt.Errorf("invalid reinit response (not an shellstate packet): %T", resp) return nil, fmt.Errorf("invalid response from waveshell")
} }
if ssPk.State == nil { dataPk, ok := resp.(*packet.FileDataPacketType)
sstore.UpdateActivityWrap(ctx, makeReinitErrorUpdate(shellType), "reiniterror") if ok {
dataFn(dataPk.Data)
continue
}
invalidPkStr := fmt.Sprintf("\r\ninvalid packettype from waveshell: %s\r\n", resp.GetType())
dataFn([]byte(invalidPkStr))
}
if ssPk == nil || ssPk.State == nil {
return nil, fmt.Errorf("invalid reinit response shellstate packet does not contain remote state") return nil, fmt.Errorf("invalid reinit response shellstate packet does not contain remote state")
} }
// TODO: maybe we don't need to save statebase here. should be possible to save it on demand // TODO: maybe we don't need to save statebase here. should be possible to save it on demand
@ -1438,10 +1465,29 @@ func (msh *MShellProc) ReInit(ctx context.Context, shellType string) (*packet.Sh
return nil, fmt.Errorf("error storing remote state: %w", err) return nil, fmt.Errorf("error storing remote state: %w", err)
} }
msh.StateMap.SetCurrentState(ssPk.State.GetShellType(), ssPk.State) msh.StateMap.SetCurrentState(ssPk.State.GetShellType(), ssPk.State)
msh.WriteToPtyBuffer("initialized shell:%s state:%s\n", shellType, ssPk.State.GetHashVal(false)) timeDur := time.Since(startTs)
dataFn([]byte(makeShellInitOutputMsg(verbose, ssPk.State, ssPk.Stats, timeDur, false)))
msh.WriteToPtyBuffer("%s", makeShellInitOutputMsg(false, ssPk.State, ssPk.Stats, timeDur, true))
return ssPk, nil return ssPk, nil
} }
func makeShellInitOutputMsg(verbose bool, state *packet.ShellState, stats *packet.ShellStateStats, dur time.Duration, ptyMsg bool) string {
if !verbose || ptyMsg {
if ptyMsg {
return fmt.Sprintf("initialized state shell:%s statehash:%s %dms\n", state.GetShellType(), state.GetHashVal(false), dur.Milliseconds())
} else {
return fmt.Sprintf("initialized connection state (shell:%s)\r\n", state.GetShellType())
}
}
var buf bytes.Buffer
buf.WriteString("-----\r\n")
buf.WriteString(fmt.Sprintf("initialized connection shell:%s statehash:%s %dms\r\n", state.GetShellType(), state.GetHashVal(false), dur.Milliseconds()))
if stats != nil {
buf.WriteString(fmt.Sprintf(" outsize:%s size:%s env:%d, vars:%d, aliases:%d, funcs:%d\r\n", scbase.NumFormatDec(stats.OutputSize), scbase.NumFormatDec(stats.StateSize), stats.EnvCount, stats.VarCount, stats.AliasCount, stats.FuncCount))
}
return buf.String()
}
func (msh *MShellProc) WriteFile(ctx context.Context, writePk *packet.WriteFilePacketType) (*packet.RpcResponseIter, error) { func (msh *MShellProc) WriteFile(ctx context.Context, writePk *packet.WriteFilePacketType) (*packet.RpcResponseIter, error) {
return msh.PacketRpcIter(ctx, writePk) return msh.PacketRpcIter(ctx, writePk)
} }
@ -1690,7 +1736,7 @@ func (msh *MShellProc) initActiveShells() {
return return
} }
for _, shellType := range activeShells { for _, shellType := range activeShells {
_, err = msh.ReInit(ctx, shellType) _, err = msh.ReInit(ctx, shellType, nil, false)
if err != nil { if err != nil {
msh.WriteToPtyBuffer("*error reiniting shell %q: %v\n", shellType, err) msh.WriteToPtyBuffer("*error reiniting shell %q: %v\n", shellType, err)
} }