mirror of
https://github.com/wavetermdev/waveterm.git
synced 2024-12-22 16:48:23 +01:00
zsh reinit fixes (#477)
* reset command now initiates and completes async so there is feedback that something is happening when it takes a long time * switch from standard rpc to rpciter * checkpoint on reinit -- stream output, stats packet, logging to cmd pty, new endBytes for EOF * make generic versions of endbytes scanner and channel output funcs * update bash to use more modern state parsing (tricks learned from zsh) * verbose mode, fix stats output message * add a diff when verbose mode is on
This commit is contained in:
parent
accb74ae0f
commit
5616c9abbb
@ -602,6 +602,7 @@ type ShellStatePacketType struct {
|
|||||||
ShellType string `json:"shelltype"`
|
ShellType string `json:"shelltype"`
|
||||||
RespId string `json:"respid,omitempty"`
|
RespId string `json:"respid,omitempty"`
|
||||||
State *ShellState `json:"state"`
|
State *ShellState `json:"state"`
|
||||||
|
Stats *ShellStateStats `json:"stats"`
|
||||||
Error string `json:"error,omitempty"`
|
Error string `json:"error,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -19,6 +19,17 @@ import (
|
|||||||
const ShellStatePackVersion = 0
|
const ShellStatePackVersion = 0
|
||||||
const ShellStateDiffPackVersion = 0
|
const ShellStateDiffPackVersion = 0
|
||||||
|
|
||||||
|
type ShellStateStats struct {
|
||||||
|
Version string `json:"version"`
|
||||||
|
AliasCount int `json:"aliascount"`
|
||||||
|
EnvCount int `json:"envcount"`
|
||||||
|
VarCount int `json:"varcount"`
|
||||||
|
FuncCount int `json:"funccount"`
|
||||||
|
HashVal string `json:"hashval"`
|
||||||
|
OutputSize int64 `json:"outputsize"`
|
||||||
|
StateSize int64 `json:"statesize"`
|
||||||
|
}
|
||||||
|
|
||||||
type ShellState struct {
|
type ShellState struct {
|
||||||
Version string `json:"version"` // [type] [semver]
|
Version string `json:"version"` // [type] [semver]
|
||||||
Cwd string `json:"cwd,omitempty"`
|
Cwd string `json:"cwd,omitempty"`
|
||||||
@ -29,6 +40,10 @@ type ShellState struct {
|
|||||||
HashVal string `json:"-"`
|
HashVal string `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (state ShellState) ApproximateSize() int64 {
|
||||||
|
return int64(len(state.Version) + len(state.Cwd) + len(state.ShellVars) + len(state.Aliases) + len(state.Funcs) + len(state.Error))
|
||||||
|
}
|
||||||
|
|
||||||
type ShellStateDiff struct {
|
type ShellStateDiff struct {
|
||||||
Version string `json:"version"` // [type] [semver] (note this should *always* be set even if the same as base)
|
Version string `json:"version"` // [type] [semver] (note this should *always* be set even if the same as base)
|
||||||
BaseHash string `json:"basehash"`
|
BaseHash string `json:"basehash"`
|
||||||
|
@ -244,11 +244,10 @@ func (m *MServer) runCompGen(compPk *packet.CompGenPacketType) {
|
|||||||
appendSlashes(comps)
|
appendSlashes(comps)
|
||||||
}
|
}
|
||||||
m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": hasMore})
|
m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": hasMore})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MServer) reinit(reqId string, shellType string) {
|
func (m *MServer) reinit(reqId string, shellType string) {
|
||||||
ssPk, err := shexec.MakeShellStatePacket(shellType)
|
ssPk, err := m.MakeShellStatePacket(reqId, shellType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error creating init packet: %w", err))
|
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error creating init packet: %w", err))
|
||||||
return
|
return
|
||||||
@ -262,6 +261,32 @@ func (m *MServer) reinit(reqId string, shellType string) {
|
|||||||
m.Sender.SendPacket(ssPk)
|
m.Sender.SendPacket(ssPk)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MServer) MakeShellStatePacket(reqId string, shellType string) (*packet.ShellStatePacketType, error) {
|
||||||
|
sapi, err := shellapi.MakeShellApi(shellType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
rtnCh := make(chan shellapi.ShellStateOutput, 1)
|
||||||
|
go sapi.GetShellState(rtnCh)
|
||||||
|
for ssOutput := range rtnCh {
|
||||||
|
if ssOutput.Error != "" {
|
||||||
|
return nil, errors.New(ssOutput.Error)
|
||||||
|
}
|
||||||
|
if ssOutput.ShellState != nil {
|
||||||
|
rtn := packet.MakeShellStatePacket()
|
||||||
|
rtn.State = ssOutput.ShellState
|
||||||
|
rtn.Stats = ssOutput.Stats
|
||||||
|
return rtn, nil
|
||||||
|
}
|
||||||
|
if ssOutput.Output != nil {
|
||||||
|
dataPk := packet.MakeFileDataPacket(reqId)
|
||||||
|
dataPk.Data = ssOutput.Output
|
||||||
|
m.Sender.SendPacket(dataPk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func makeTemp(path string, mode fs.FileMode) (*os.File, error) {
|
func makeTemp(path string, mode fs.FileMode) (*os.File, error) {
|
||||||
dirName := filepath.Dir(path)
|
dirName := filepath.Dir(path)
|
||||||
baseName := filepath.Base(path)
|
baseName := filepath.Base(path)
|
||||||
|
@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
|
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
|
||||||
"github.com/wavetermdev/waveterm/waveshell/pkg/statediff"
|
"github.com/wavetermdev/waveterm/waveshell/pkg/statediff"
|
||||||
|
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||||
)
|
)
|
||||||
|
|
||||||
const BaseBashOpts = `set +m; set +H; shopt -s extglob`
|
const BaseBashOpts = `set +m; set +H; shopt -s extglob`
|
||||||
@ -48,7 +49,7 @@ func (b bashShellApi) GetShellType() string {
|
|||||||
return packet.ShellType_bash
|
return packet.ShellType_bash
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b bashShellApi) MakeExitTrap(fdNum int) string {
|
func (b bashShellApi) MakeExitTrap(fdNum int) (string, []byte) {
|
||||||
return MakeBashExitTrap(fdNum)
|
return MakeBashExitTrap(fdNum)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -79,29 +80,15 @@ func (b bashShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty
|
|||||||
return MakeBashShExecCommand(cmdStr, rcFileName, usePty)
|
return MakeBashShExecCommand(cmdStr, rcFileName, usePty)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b bashShellApi) GetShellState() chan ShellStateOutput {
|
func (b bashShellApi) GetShellState(outCh chan ShellStateOutput) {
|
||||||
ch := make(chan ShellStateOutput, 1)
|
GetBashShellState(outCh)
|
||||||
defer close(ch)
|
|
||||||
ssPk, err := GetBashShellState()
|
|
||||||
if err != nil {
|
|
||||||
ch <- ShellStateOutput{
|
|
||||||
Status: ShellStateOutputStatus_Done,
|
|
||||||
Error: err.Error(),
|
|
||||||
}
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
ch <- ShellStateOutput{
|
|
||||||
Status: ShellStateOutputStatus_Done,
|
|
||||||
ShellState: ssPk,
|
|
||||||
}
|
|
||||||
return ch
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b bashShellApi) GetBaseShellOpts() string {
|
func (b bashShellApi) GetBaseShellOpts() string {
|
||||||
return BaseBashOpts
|
return BaseBashOpts
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b bashShellApi) ParseShellStateOutput(output []byte) (*packet.ShellState, error) {
|
func (b bashShellApi) ParseShellStateOutput(output []byte) (*packet.ShellState, *packet.ShellStateStats, error) {
|
||||||
return parseBashShellStateOutput(output)
|
return parseBashShellStateOutput(output)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -130,8 +117,32 @@ func (b bashShellApi) MakeRcFileStr(pk *packet.RunPacketType) string {
|
|||||||
return rcBuf.String()
|
return rcBuf.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetBashShellStateCmd() string {
|
func GetBashShellStateCmd(fdNum int) (string, []byte) {
|
||||||
return strings.Join(GetBashShellStateCmds, ` printf "\x00\x00";`)
|
endBytes := utilfn.AppendNonZeroRandomBytes(nil, NumRandomEndBytes)
|
||||||
|
endBytes = append(endBytes, '\n')
|
||||||
|
cmdStr := strings.TrimSpace(`
|
||||||
|
exec 2> /dev/null;
|
||||||
|
exec > [%OUTPUTFD%];
|
||||||
|
printf "\x00\x00";
|
||||||
|
[%BASHVERSIONCMD%];
|
||||||
|
printf "\x00\x00";
|
||||||
|
pwd;
|
||||||
|
printf "\x00\x00";
|
||||||
|
declare -p $(compgen -A variable);
|
||||||
|
printf "\x00\x00";
|
||||||
|
alias -p;
|
||||||
|
printf "\x00\x00";
|
||||||
|
declare -f;
|
||||||
|
printf "\x00\x00";
|
||||||
|
[%GITBRANCHCMD%];
|
||||||
|
printf "\x00\x00";
|
||||||
|
printf "[%ENDBYTES%]";
|
||||||
|
`)
|
||||||
|
cmdStr = strings.ReplaceAll(cmdStr, "[%OUTPUTFD%]", fmt.Sprintf("/dev/fd/%d", fdNum))
|
||||||
|
cmdStr = strings.ReplaceAll(cmdStr, "[%BASHVERSIONCMD%]", BashShellVersionCmdStr)
|
||||||
|
cmdStr = strings.ReplaceAll(cmdStr, "[%GITBRANCHCMD%]", GetGitBranchCmdStr)
|
||||||
|
cmdStr = strings.ReplaceAll(cmdStr, "[%ENDBYTES%]", utilfn.ShellHexEscape(string(endBytes)))
|
||||||
|
return cmdStr, endBytes
|
||||||
}
|
}
|
||||||
|
|
||||||
func execGetLocalBashShellVersion() string {
|
func execGetLocalBashShellVersion() string {
|
||||||
@ -158,16 +169,34 @@ func GetLocalBashMajorVersion() string {
|
|||||||
return localBashMajorVersion
|
return localBashMajorVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetBashShellState() (*packet.ShellState, error) {
|
func GetBashShellState(outCh chan ShellStateOutput) {
|
||||||
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
|
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
|
||||||
defer cancelFn()
|
defer cancelFn()
|
||||||
cmdStr := BaseBashOpts + "; " + GetBashShellStateCmd()
|
defer close(outCh)
|
||||||
|
stateCmd, endBytes := GetBashShellStateCmd(StateOutputFdNum)
|
||||||
|
cmdStr := BaseBashOpts + "; " + stateCmd
|
||||||
ecmd := exec.CommandContext(ctx, GetLocalBashPath(), "-l", "-i", "-c", cmdStr)
|
ecmd := exec.CommandContext(ctx, GetLocalBashPath(), "-l", "-i", "-c", cmdStr)
|
||||||
outputBytes, err := RunSimpleCmdInPty(ecmd)
|
outputCh := make(chan []byte, 10)
|
||||||
if err != nil {
|
var outputWg sync.WaitGroup
|
||||||
return nil, err
|
outputWg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer outputWg.Done()
|
||||||
|
for outputBytes := range outputCh {
|
||||||
|
outCh <- ShellStateOutput{Output: outputBytes}
|
||||||
}
|
}
|
||||||
return parseBashShellStateOutput(outputBytes)
|
}()
|
||||||
|
outputBytes, err := StreamCommandWithExtraFd(ecmd, outputCh, StateOutputFdNum, endBytes)
|
||||||
|
outputWg.Wait()
|
||||||
|
if err != nil {
|
||||||
|
outCh <- ShellStateOutput{Error: err.Error()}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rtn, stats, err := parseBashShellStateOutput(outputBytes)
|
||||||
|
if err != nil {
|
||||||
|
outCh <- ShellStateOutput{Error: err.Error()}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
outCh <- ShellStateOutput{ShellState: rtn, Stats: stats}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetLocalBashPath() string {
|
func GetLocalBashPath() string {
|
||||||
@ -190,19 +219,20 @@ func GetLocalZshPath() string {
|
|||||||
return "zsh"
|
return "zsh"
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetBashShellStateRedirectCommandStr(outputFdNum int) string {
|
func GetBashShellStateRedirectCommandStr(outputFdNum int) (string, []byte) {
|
||||||
return fmt.Sprintf("cat <(%s) > /dev/fd/%d", GetBashShellStateCmd(), outputFdNum)
|
cmdStr, endBytes := GetBashShellStateCmd(outputFdNum)
|
||||||
|
return cmdStr, endBytes
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeBashExitTrap(fdNum int) string {
|
func MakeBashExitTrap(fdNum int) (string, []byte) {
|
||||||
stateCmd := GetBashShellStateRedirectCommandStr(fdNum)
|
stateCmd, endBytes := GetBashShellStateRedirectCommandStr(fdNum)
|
||||||
fmtStr := `
|
fmtStr := `
|
||||||
_waveshell_exittrap () {
|
_waveshell_exittrap () {
|
||||||
%s
|
%s
|
||||||
}
|
}
|
||||||
trap _waveshell_exittrap EXIT
|
trap _waveshell_exittrap EXIT
|
||||||
`
|
`
|
||||||
return fmt.Sprintf(fmtStr, stateCmd)
|
return fmt.Sprintf(fmtStr, stateCmd), endBytes
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeBashShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd {
|
func MakeBashShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd {
|
||||||
|
@ -20,6 +20,19 @@ import (
|
|||||||
"mvdan.cc/sh/v3/syntax"
|
"mvdan.cc/sh/v3/syntax"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
BashSection_Ignored = iota
|
||||||
|
BashSection_Version
|
||||||
|
BashSection_Cwd
|
||||||
|
BashSection_Vars
|
||||||
|
BashSection_Aliases
|
||||||
|
BashSection_Funcs
|
||||||
|
BashSection_PVars
|
||||||
|
BashSection_EndBytes
|
||||||
|
|
||||||
|
BashSection_Count // must be last
|
||||||
|
)
|
||||||
|
|
||||||
type DeclareDeclType = shellenv.DeclareDeclType
|
type DeclareDeclType = shellenv.DeclareDeclType
|
||||||
|
|
||||||
func doCmdSubst(commandStr string, w io.Writer, word *syntax.CmdSubst) error {
|
func doCmdSubst(commandStr string, w io.Writer, word *syntax.CmdSubst) error {
|
||||||
@ -214,38 +227,37 @@ func bashParseDeclareOutput(state *packet.ShellState, declareBytes []byte, pvarB
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseBashShellStateOutput(outputBytes []byte) (*packet.ShellState, error) {
|
func parseBashShellStateOutput(outputBytes []byte) (*packet.ShellState, *packet.ShellStateStats, error) {
|
||||||
if scbase.IsDevMode() && DebugState {
|
if scbase.IsDevMode() && DebugState {
|
||||||
writeStateToFile(packet.ShellType_bash, outputBytes)
|
writeStateToFile(packet.ShellType_bash, outputBytes)
|
||||||
}
|
}
|
||||||
// 7 fields: ignored [0], version [1], cwd [2], env/vars [3], aliases [4], funcs [5], pvars [6]
|
sections := bytes.Split(outputBytes, []byte{0, 0})
|
||||||
fields := bytes.Split(outputBytes, []byte{0, 0})
|
if len(sections) != BashSection_Count {
|
||||||
if len(fields) != 7 {
|
return nil, nil, fmt.Errorf("invalid bash shell state output, wrong number of fields, fields=%d", len(sections))
|
||||||
return nil, fmt.Errorf("invalid bash shell state output, wrong number of fields, fields=%d", len(fields))
|
|
||||||
}
|
}
|
||||||
rtn := &packet.ShellState{}
|
rtn := &packet.ShellState{}
|
||||||
rtn.Version = strings.TrimSpace(string(fields[1]))
|
rtn.Version = strings.TrimSpace(string(sections[BashSection_Version]))
|
||||||
if rtn.GetShellType() != packet.ShellType_bash {
|
if rtn.GetShellType() != packet.ShellType_bash {
|
||||||
return nil, fmt.Errorf("invalid bash shell state output, wrong shell type: %q", rtn.Version)
|
return nil, nil, fmt.Errorf("invalid bash shell state output, wrong shell type: %q", rtn.Version)
|
||||||
}
|
}
|
||||||
if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil {
|
if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil {
|
||||||
return nil, fmt.Errorf("invalid bash shell state output, invalid version: %v", err)
|
return nil, nil, fmt.Errorf("invalid bash shell state output, invalid version: %v", err)
|
||||||
}
|
}
|
||||||
cwdStr := string(fields[2])
|
cwdStr := string(sections[BashSection_Cwd])
|
||||||
if strings.HasSuffix(cwdStr, "\r\n") {
|
if strings.HasSuffix(cwdStr, "\r\n") {
|
||||||
cwdStr = cwdStr[0 : len(cwdStr)-2]
|
cwdStr = cwdStr[0 : len(cwdStr)-2]
|
||||||
} else if strings.HasSuffix(cwdStr, "\n") {
|
} else {
|
||||||
cwdStr = cwdStr[0 : len(cwdStr)-1]
|
cwdStr = strings.TrimSuffix(cwdStr, "\n")
|
||||||
}
|
}
|
||||||
rtn.Cwd = string(cwdStr)
|
rtn.Cwd = string(cwdStr)
|
||||||
err := bashParseDeclareOutput(rtn, fields[3], fields[6])
|
err := bashParseDeclareOutput(rtn, sections[BashSection_Vars], sections[BashSection_PVars])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
rtn.Aliases = strings.ReplaceAll(string(fields[4]), "\r\n", "\n")
|
rtn.Aliases = strings.ReplaceAll(string(sections[BashSection_Aliases]), "\r\n", "\n")
|
||||||
rtn.Funcs = strings.ReplaceAll(string(fields[5]), "\r\n", "\n")
|
rtn.Funcs = strings.ReplaceAll(string(sections[BashSection_Funcs]), "\r\n", "\n")
|
||||||
rtn.Funcs = shellenv.RemoveFunc(rtn.Funcs, "_waveshell_exittrap")
|
rtn.Funcs = shellenv.RemoveFunc(rtn.Funcs, "_waveshell_exittrap")
|
||||||
return rtn, nil
|
return rtn, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func bashNormalize(d *DeclareDeclType) error {
|
func bashNormalize(d *DeclareDeclType) error {
|
||||||
|
@ -7,7 +7,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"os/user"
|
"os/user"
|
||||||
@ -28,12 +27,14 @@ import (
|
|||||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||||
)
|
)
|
||||||
|
|
||||||
const GetStateTimeout = 5 * time.Second
|
const GetStateTimeout = 15 * time.Second
|
||||||
const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"`
|
const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"`
|
||||||
const GetK8sContextCmdStr = `printf "K8SCONTEXT %s\x00" "$(kubectl config current-context 2>/dev/null)"`
|
const GetK8sContextCmdStr = `printf "K8SCONTEXT %s\x00" "$(kubectl config current-context 2>/dev/null)"`
|
||||||
const GetK8sNamespaceCmdStr = `printf "K8SNAMESPACE %s\x00" "$(kubectl config view --minify --output 'jsonpath={..namespace}' 2>/dev/null)"`
|
const GetK8sNamespaceCmdStr = `printf "K8SNAMESPACE %s\x00" "$(kubectl config view --minify --output 'jsonpath={..namespace}' 2>/dev/null)"`
|
||||||
const RunCommandFmt = `%s`
|
const RunCommandFmt = `%s`
|
||||||
const DebugState = false
|
const DebugState = false
|
||||||
|
const StateOutputFdNum = 20
|
||||||
|
const NumRandomEndBytes = 8
|
||||||
|
|
||||||
var userShellRegexp = regexp.MustCompile(`^UserShell: (.*)$`)
|
var userShellRegexp = regexp.MustCompile(`^UserShell: (.*)$`)
|
||||||
|
|
||||||
@ -56,23 +57,23 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type ShellStateOutput struct {
|
type ShellStateOutput struct {
|
||||||
Status string
|
Output []byte
|
||||||
StderrOutput []byte
|
|
||||||
ShellState *packet.ShellState
|
ShellState *packet.ShellState
|
||||||
|
Stats *packet.ShellStateStats
|
||||||
Error string
|
Error string
|
||||||
}
|
}
|
||||||
|
|
||||||
type ShellApi interface {
|
type ShellApi interface {
|
||||||
GetShellType() string
|
GetShellType() string
|
||||||
MakeExitTrap(fdNum int) string
|
MakeExitTrap(fdNum int) (string, []byte)
|
||||||
GetLocalMajorVersion() string
|
GetLocalMajorVersion() string
|
||||||
GetLocalShellPath() string
|
GetLocalShellPath() string
|
||||||
GetRemoteShellPath() string
|
GetRemoteShellPath() string
|
||||||
MakeRunCommand(cmdStr string, opts RunCommandOpts) string
|
MakeRunCommand(cmdStr string, opts RunCommandOpts) string
|
||||||
MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd
|
MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd
|
||||||
GetShellState() chan ShellStateOutput
|
GetShellState(chan ShellStateOutput)
|
||||||
GetBaseShellOpts() string
|
GetBaseShellOpts() string
|
||||||
ParseShellStateOutput(output []byte) (*packet.ShellState, error)
|
ParseShellStateOutput(output []byte) (*packet.ShellState, *packet.ShellStateStats, error)
|
||||||
MakeRcFileStr(pk *packet.RunPacketType) string
|
MakeRcFileStr(pk *packet.RunPacketType) string
|
||||||
MakeShellStateDiff(oldState *packet.ShellState, oldStateHash string, newState *packet.ShellState) (*packet.ShellStateDiff, error)
|
MakeShellStateDiff(oldState *packet.ShellState, oldStateHash string, newState *packet.ShellState) (*packet.ShellStateDiff, error)
|
||||||
ApplyShellStateDiff(oldState *packet.ShellState, diff *packet.ShellStateDiff) (*packet.ShellState, error)
|
ApplyShellStateDiff(oldState *packet.ShellState, diff *packet.ShellStateDiff) (*packet.ShellState, error)
|
||||||
@ -153,12 +154,13 @@ func internalMacUserShell() string {
|
|||||||
const FirstExtraFilesFdNum = 3
|
const FirstExtraFilesFdNum = 3
|
||||||
|
|
||||||
// returns output(stdout+stderr), extraFdOutput, error
|
// returns output(stdout+stderr), extraFdOutput, error
|
||||||
func RunCommandWithExtraFd(ecmd *exec.Cmd, extraFdNum int) ([]byte, []byte, error) {
|
func StreamCommandWithExtraFd(ecmd *exec.Cmd, outputCh chan []byte, extraFdNum int, endBytes []byte) ([]byte, error) {
|
||||||
|
defer close(outputCh)
|
||||||
ecmd.Env = os.Environ()
|
ecmd.Env = os.Environ()
|
||||||
shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType))
|
shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType))
|
||||||
cmdPty, cmdTty, err := pty.Open()
|
cmdPty, cmdTty, err := pty.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("opening new pty: %w", err)
|
return nil, fmt.Errorf("opening new pty: %w", err)
|
||||||
}
|
}
|
||||||
defer cmdTty.Close()
|
defer cmdTty.Close()
|
||||||
defer cmdPty.Close()
|
defer cmdPty.Close()
|
||||||
@ -171,42 +173,44 @@ func RunCommandWithExtraFd(ecmd *exec.Cmd, extraFdNum int) ([]byte, []byte, erro
|
|||||||
ecmd.SysProcAttr.Setctty = true
|
ecmd.SysProcAttr.Setctty = true
|
||||||
pipeReader, pipeWriter, err := os.Pipe()
|
pipeReader, pipeWriter, err := os.Pipe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("could not create pipe: %w", err)
|
return nil, fmt.Errorf("could not create pipe: %w", err)
|
||||||
}
|
}
|
||||||
defer pipeWriter.Close()
|
defer pipeWriter.Close()
|
||||||
defer pipeReader.Close()
|
defer pipeReader.Close()
|
||||||
extraFiles := make([]*os.File, extraFdNum+1)
|
extraFiles := make([]*os.File, extraFdNum+1)
|
||||||
extraFiles[extraFdNum] = pipeWriter
|
extraFiles[extraFdNum] = pipeWriter
|
||||||
ecmd.ExtraFiles = extraFiles[FirstExtraFilesFdNum:]
|
ecmd.ExtraFiles = extraFiles[FirstExtraFilesFdNum:]
|
||||||
defer pipeReader.Close()
|
err = ecmd.Start()
|
||||||
ecmd.Start()
|
|
||||||
cmdTty.Close()
|
cmdTty.Close()
|
||||||
pipeWriter.Close()
|
pipeWriter.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var outputWg sync.WaitGroup
|
var outputWg sync.WaitGroup
|
||||||
var outputBuf bytes.Buffer
|
|
||||||
var extraFdOutputBuf bytes.Buffer
|
var extraFdOutputBuf bytes.Buffer
|
||||||
outputWg.Add(2)
|
outputWg.Add(2)
|
||||||
go func() {
|
go func() {
|
||||||
// ignore error (/dev/ptmx has read error when process is done)
|
// ignore error (/dev/ptmx has read error when process is done)
|
||||||
defer outputWg.Done()
|
defer outputWg.Done()
|
||||||
io.Copy(&outputBuf, cmdPty)
|
err := utilfn.CopyToChannel(outputCh, cmdPty)
|
||||||
|
if err != nil {
|
||||||
|
errStr := fmt.Sprintf("\r\nerror reading from pty: %v\r\n", err)
|
||||||
|
outputCh <- []byte(errStr)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
defer outputWg.Done()
|
defer outputWg.Done()
|
||||||
io.Copy(&extraFdOutputBuf, pipeReader)
|
utilfn.CopyWithEndBytes(&extraFdOutputBuf, pipeReader, endBytes)
|
||||||
}()
|
}()
|
||||||
exitErr := ecmd.Wait()
|
exitErr := ecmd.Wait()
|
||||||
if exitErr != nil {
|
if exitErr != nil {
|
||||||
return nil, nil, exitErr
|
return nil, exitErr
|
||||||
}
|
}
|
||||||
outputWg.Wait()
|
outputWg.Wait()
|
||||||
return outputBuf.Bytes(), extraFdOutputBuf.Bytes(), nil
|
return extraFdOutputBuf.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunSimpleCmdInPty(ecmd *exec.Cmd) ([]byte, error) {
|
func RunSimpleCmdInPty(ecmd *exec.Cmd, endBytes []byte) ([]byte, error) {
|
||||||
ecmd.Env = os.Environ()
|
ecmd.Env = os.Environ()
|
||||||
shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType))
|
shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType))
|
||||||
cmdPty, cmdTty, err := pty.Open()
|
cmdPty, cmdTty, err := pty.Open()
|
||||||
@ -231,8 +235,8 @@ func RunSimpleCmdInPty(ecmd *exec.Cmd) ([]byte, error) {
|
|||||||
var outputBuf bytes.Buffer
|
var outputBuf bytes.Buffer
|
||||||
go func() {
|
go func() {
|
||||||
// ignore error (/dev/ptmx has read error when process is done)
|
// ignore error (/dev/ptmx has read error when process is done)
|
||||||
io.Copy(&outputBuf, cmdPty)
|
defer close(ioDone)
|
||||||
close(ioDone)
|
utilfn.CopyWithEndBytes(&outputBuf, cmdPty, endBytes)
|
||||||
}()
|
}()
|
||||||
exitErr := ecmd.Wait()
|
exitErr := ecmd.Wait()
|
||||||
if exitErr != nil {
|
if exitErr != nil {
|
||||||
|
@ -8,7 +8,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -28,7 +27,6 @@ import (
|
|||||||
const BaseZshOpts = ``
|
const BaseZshOpts = ``
|
||||||
|
|
||||||
const ZshShellVersionCmdStr = `echo zsh v$ZSH_VERSION`
|
const ZshShellVersionCmdStr = `echo zsh v$ZSH_VERSION`
|
||||||
const StateOutputFdNum = 20
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ZshSection_Version = iota
|
ZshSection_Version = iota
|
||||||
@ -41,6 +39,7 @@ const (
|
|||||||
ZshSection_Funcs
|
ZshSection_Funcs
|
||||||
ZshSection_PVars
|
ZshSection_PVars
|
||||||
ZshSection_Prompt
|
ZshSection_Prompt
|
||||||
|
ZshSection_EndBytes
|
||||||
|
|
||||||
ZshSection_NumFieldsExpected // must be last
|
ZshSection_NumFieldsExpected // must be last
|
||||||
)
|
)
|
||||||
@ -118,6 +117,11 @@ var ZshIgnoreVars = map[string]bool{
|
|||||||
"zcurses_windows": true,
|
"zcurses_windows": true,
|
||||||
|
|
||||||
// not listed, but we also exclude all ZFTP_* variables
|
// not listed, but we also exclude all ZFTP_* variables
|
||||||
|
|
||||||
|
// powerlevel10k
|
||||||
|
"_GITSTATUS_CLIENT_PID_POWERLEVEL9K": true,
|
||||||
|
"GITSTATUS_DAEMON_PID_POWERLEVEL9K": true,
|
||||||
|
"_GITSTATUS_FILE_PREFIX_POWERLEVEL9K": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
var ZshIgnoreFuncs = map[string]bool{
|
var ZshIgnoreFuncs = map[string]bool{
|
||||||
@ -211,7 +215,7 @@ func (z zshShellApi) GetShellType() string {
|
|||||||
return packet.ShellType_zsh
|
return packet.ShellType_zsh
|
||||||
}
|
}
|
||||||
|
|
||||||
func (z zshShellApi) MakeExitTrap(fdNum int) string {
|
func (z zshShellApi) MakeExitTrap(fdNum int) (string, []byte) {
|
||||||
return MakeZshExitTrap(fdNum)
|
return MakeZshExitTrap(fdNum)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -242,25 +246,34 @@ func (z zshShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty
|
|||||||
return exec.Command(GetLocalZshPath(), "-l", "-i", "-c", cmdStr)
|
return exec.Command(GetLocalZshPath(), "-l", "-i", "-c", cmdStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (z zshShellApi) GetShellState() chan ShellStateOutput {
|
func (z zshShellApi) GetShellState(outCh chan ShellStateOutput) {
|
||||||
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
|
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
|
||||||
defer cancelFn()
|
defer cancelFn()
|
||||||
rtnCh := make(chan ShellStateOutput, 1)
|
defer close(outCh)
|
||||||
defer close(rtnCh)
|
stateCmd, endBytes := GetZshShellStateCmd(StateOutputFdNum)
|
||||||
cmdStr := BaseZshOpts + "; " + GetZshShellStateCmd(StateOutputFdNum)
|
cmdStr := BaseZshOpts + "; " + stateCmd
|
||||||
ecmd := exec.CommandContext(ctx, GetLocalZshPath(), "-l", "-i", "-c", cmdStr)
|
ecmd := exec.CommandContext(ctx, GetLocalZshPath(), "-l", "-i", "-c", cmdStr)
|
||||||
_, outputBytes, err := RunCommandWithExtraFd(ecmd, StateOutputFdNum)
|
outputCh := make(chan []byte, 10)
|
||||||
if err != nil {
|
var outputWg sync.WaitGroup
|
||||||
rtnCh <- ShellStateOutput{Status: ShellStateOutputStatus_Done, Error: err.Error()}
|
outputWg.Add(1)
|
||||||
return rtnCh
|
go func() {
|
||||||
|
defer outputWg.Done()
|
||||||
|
for outputBytes := range outputCh {
|
||||||
|
outCh <- ShellStateOutput{Output: outputBytes}
|
||||||
}
|
}
|
||||||
rtn, err := z.ParseShellStateOutput(outputBytes)
|
}()
|
||||||
|
outputBytes, err := StreamCommandWithExtraFd(ecmd, outputCh, StateOutputFdNum, endBytes)
|
||||||
|
outputWg.Wait()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rtnCh <- ShellStateOutput{Status: ShellStateOutputStatus_Done, Error: err.Error()}
|
outCh <- ShellStateOutput{Error: err.Error()}
|
||||||
return rtnCh
|
return
|
||||||
}
|
}
|
||||||
rtnCh <- ShellStateOutput{Status: ShellStateOutputStatus_Done, ShellState: rtn}
|
rtn, stats, err := z.ParseShellStateOutput(outputBytes)
|
||||||
return rtnCh
|
if err != nil {
|
||||||
|
outCh <- ShellStateOutput{Error: err.Error()}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
outCh <- ShellStateOutput{ShellState: rtn, Stats: stats}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (z zshShellApi) GetBaseShellOpts() string {
|
func (z zshShellApi) GetBaseShellOpts() string {
|
||||||
@ -437,19 +450,15 @@ func writeZshId(buf *bytes.Buffer, idStr string) {
|
|||||||
|
|
||||||
const numRandomBytes = 4
|
const numRandomBytes = 4
|
||||||
|
|
||||||
// returns (cmd-string)
|
// returns (cmd-string, endbytes)
|
||||||
func GetZshShellStateCmd(fdNum int) string {
|
func GetZshShellStateCmd(fdNum int) (string, []byte) {
|
||||||
var sectionSeparator []byte
|
var sectionSeparator []byte
|
||||||
// adding this extra "\n" helps with debuging and readability of output
|
// adding this extra "\n" helps with debuging and readability of output
|
||||||
sectionSeparator = append(sectionSeparator, byte('\n'))
|
sectionSeparator = append(sectionSeparator, byte('\n'))
|
||||||
for len(sectionSeparator) < numRandomBytes {
|
sectionSeparator = utilfn.AppendNonZeroRandomBytes(sectionSeparator, numRandomBytes)
|
||||||
// any character *except* null (0)
|
|
||||||
rn := rand.Intn(256)
|
|
||||||
if rn > 0 && rn < 256 { // exclude 0, also helps to suppress security warning to have a guard here
|
|
||||||
sectionSeparator = append(sectionSeparator, byte(rn))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sectionSeparator = append(sectionSeparator, 0, 0)
|
sectionSeparator = append(sectionSeparator, 0, 0)
|
||||||
|
endBytes := utilfn.AppendNonZeroRandomBytes(nil, NumRandomEndBytes)
|
||||||
|
endBytes = append(endBytes, byte('\n'))
|
||||||
// we have to use these crazy separators because zsh allows basically anything in
|
// we have to use these crazy separators because zsh allows basically anything in
|
||||||
// variable names and values (including nulls).
|
// variable names and values (including nulls).
|
||||||
// note that we don't need crazy separators for "env" or "typeset".
|
// note that we don't need crazy separators for "env" or "typeset".
|
||||||
@ -511,6 +520,8 @@ printf "[%SECTIONSEP%]";
|
|||||||
[%K8SNAMESPACE%]
|
[%K8SNAMESPACE%]
|
||||||
printf "[%SECTIONSEP%]";
|
printf "[%SECTIONSEP%]";
|
||||||
print -P "$PS1"
|
print -P "$PS1"
|
||||||
|
printf "[%SECTIONSEP%]";
|
||||||
|
printf "[%ENDBYTES%]"
|
||||||
`
|
`
|
||||||
cmd = strings.TrimSpace(cmd)
|
cmd = strings.TrimSpace(cmd)
|
||||||
cmd = strings.ReplaceAll(cmd, "[%ZSHVERSION%]", ZshShellVersionCmdStr)
|
cmd = strings.ReplaceAll(cmd, "[%ZSHVERSION%]", ZshShellVersionCmdStr)
|
||||||
@ -520,17 +531,19 @@ print -P "$PS1"
|
|||||||
cmd = strings.ReplaceAll(cmd, "[%PARTSEP%]", utilfn.ShellHexEscape(string(sectionSeparator[0:len(sectionSeparator)-1])))
|
cmd = strings.ReplaceAll(cmd, "[%PARTSEP%]", utilfn.ShellHexEscape(string(sectionSeparator[0:len(sectionSeparator)-1])))
|
||||||
cmd = strings.ReplaceAll(cmd, "[%SECTIONSEP%]", utilfn.ShellHexEscape(string(sectionSeparator)))
|
cmd = strings.ReplaceAll(cmd, "[%SECTIONSEP%]", utilfn.ShellHexEscape(string(sectionSeparator)))
|
||||||
cmd = strings.ReplaceAll(cmd, "[%OUTPUTFD%]", fmt.Sprintf("/dev/fd/%d", fdNum))
|
cmd = strings.ReplaceAll(cmd, "[%OUTPUTFD%]", fmt.Sprintf("/dev/fd/%d", fdNum))
|
||||||
return cmd
|
cmd = strings.ReplaceAll(cmd, "[%OUTPUTFDNUM%]", fmt.Sprintf("%d", fdNum))
|
||||||
|
cmd = strings.ReplaceAll(cmd, "[%ENDBYTES%]", utilfn.ShellHexEscape(string(endBytes)))
|
||||||
|
return cmd, endBytes
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeZshExitTrap(fdNum int) string {
|
func MakeZshExitTrap(fdNum int) (string, []byte) {
|
||||||
stateCmd := GetZshShellStateCmd(fdNum)
|
stateCmd, endBytes := GetZshShellStateCmd(fdNum)
|
||||||
fmtStr := `
|
fmtStr := `
|
||||||
zshexit () {
|
zshexit () {
|
||||||
%s
|
%s
|
||||||
}
|
}
|
||||||
`
|
`
|
||||||
return fmt.Sprintf(fmtStr, stateCmd)
|
return fmt.Sprintf(fmtStr, stateCmd), endBytes
|
||||||
}
|
}
|
||||||
|
|
||||||
func execGetLocalZshShellVersion() string {
|
func execGetLocalZshShellVersion() string {
|
||||||
@ -698,14 +711,14 @@ func makeZshFuncsStrForShellState(fnMap map[ZshParamKey]string) string {
|
|||||||
return buf.String()
|
return buf.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, error) {
|
func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, *packet.ShellStateStats, error) {
|
||||||
if scbase.IsDevMode() && DebugState {
|
if scbase.IsDevMode() && DebugState {
|
||||||
writeStateToFile(packet.ShellType_zsh, outputBytes)
|
writeStateToFile(packet.ShellType_zsh, outputBytes)
|
||||||
}
|
}
|
||||||
firstZeroIdx := bytes.Index(outputBytes, []byte{0})
|
firstZeroIdx := bytes.Index(outputBytes, []byte{0})
|
||||||
firstDZeroIdx := bytes.Index(outputBytes, []byte{0, 0})
|
firstDZeroIdx := bytes.Index(outputBytes, []byte{0, 0})
|
||||||
if firstZeroIdx == -1 || firstDZeroIdx == -1 {
|
if firstZeroIdx == -1 || firstDZeroIdx == -1 {
|
||||||
return nil, fmt.Errorf("invalid zsh shell state output, could not parse separator bytes")
|
return nil, nil, fmt.Errorf("invalid zsh shell state output, could not parse separator bytes")
|
||||||
}
|
}
|
||||||
versionStr := string(outputBytes[0:firstZeroIdx])
|
versionStr := string(outputBytes[0:firstZeroIdx])
|
||||||
sectionSeparator := outputBytes[firstZeroIdx+1 : firstDZeroIdx+2]
|
sectionSeparator := outputBytes[firstZeroIdx+1 : firstDZeroIdx+2]
|
||||||
@ -714,15 +727,15 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta
|
|||||||
sections := bytes.Split(outputBytes, sectionSeparator)
|
sections := bytes.Split(outputBytes, sectionSeparator)
|
||||||
if len(sections) != ZshSection_NumFieldsExpected {
|
if len(sections) != ZshSection_NumFieldsExpected {
|
||||||
base.Logf("invalid -- numfields\n")
|
base.Logf("invalid -- numfields\n")
|
||||||
return nil, fmt.Errorf("invalid zsh shell state output, wrong number of sections, section=%d", len(sections))
|
return nil, nil, fmt.Errorf("invalid zsh shell state output, wrong number of sections, section=%d", len(sections))
|
||||||
}
|
}
|
||||||
rtn := &packet.ShellState{}
|
rtn := &packet.ShellState{}
|
||||||
rtn.Version = strings.TrimSpace(versionStr)
|
rtn.Version = strings.TrimSpace(versionStr)
|
||||||
if rtn.GetShellType() != packet.ShellType_zsh {
|
if rtn.GetShellType() != packet.ShellType_zsh {
|
||||||
return nil, fmt.Errorf("invalid zsh shell state output, wrong shell type")
|
return nil, nil, fmt.Errorf("invalid zsh shell state output, wrong shell type")
|
||||||
}
|
}
|
||||||
if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil {
|
if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil {
|
||||||
return nil, fmt.Errorf("invalid zsh shell state output, invalid version: %v", err)
|
return nil, nil, fmt.Errorf("invalid zsh shell state output, invalid version: %v", err)
|
||||||
}
|
}
|
||||||
cwdStr := stripNewLineChars(string(sections[ZshSection_Cwd]))
|
cwdStr := stripNewLineChars(string(sections[ZshSection_Cwd]))
|
||||||
rtn.Cwd = cwdStr
|
rtn.Cwd = cwdStr
|
||||||
@ -730,7 +743,7 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta
|
|||||||
zshDecls, err := parseZshDecls(sections[ZshSection_Vars])
|
zshDecls, err := parseZshDecls(sections[ZshSection_Vars])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
base.Logf("invalid - parsedecls %v\n", err)
|
base.Logf("invalid - parsedecls %v\n", err)
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
for _, decl := range zshDecls {
|
for _, decl := range zshDecls {
|
||||||
if decl.IsZshScalarBound() {
|
if decl.IsZshScalarBound() {
|
||||||
@ -746,7 +759,17 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta
|
|||||||
pvarMap := parseExtVarOutput(sections[ZshSection_PVars], string(sections[ZshSection_Prompt]), string(sections[ZshSection_Mods]))
|
pvarMap := parseExtVarOutput(sections[ZshSection_PVars], string(sections[ZshSection_Prompt]), string(sections[ZshSection_Mods]))
|
||||||
utilfn.CombineMaps(zshDecls, pvarMap)
|
utilfn.CombineMaps(zshDecls, pvarMap)
|
||||||
rtn.ShellVars = shellenv.SerializeDeclMap(zshDecls)
|
rtn.ShellVars = shellenv.SerializeDeclMap(zshDecls)
|
||||||
return rtn, nil
|
stats := &packet.ShellStateStats{
|
||||||
|
Version: rtn.Version,
|
||||||
|
AliasCount: int(len(aliasMap)),
|
||||||
|
FuncCount: int(len(zshFuncs)),
|
||||||
|
VarCount: int(len(zshDecls)),
|
||||||
|
EnvCount: int(len(zshEnv)),
|
||||||
|
HashVal: rtn.GetHashVal(false),
|
||||||
|
OutputSize: int64(len(outputBytes)),
|
||||||
|
StateSize: rtn.ApproximateSize(),
|
||||||
|
}
|
||||||
|
return rtn, stats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseZshEnv(output []byte) map[string]string {
|
func parseZshEnv(output []byte) map[string]string {
|
||||||
|
@ -7,7 +7,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
@ -95,6 +94,7 @@ type ReturnStateBuf struct {
|
|||||||
Err error
|
Err error
|
||||||
Reader *os.File
|
Reader *os.File
|
||||||
FdNum int
|
FdNum int
|
||||||
|
EndBytes []byte
|
||||||
DoneCh chan bool
|
DoneCh chan bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -835,7 +835,8 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
|
|||||||
cmd.ReturnState.FdNum = RtnStateFdNum
|
cmd.ReturnState.FdNum = RtnStateFdNum
|
||||||
rtnStateWriter = pw
|
rtnStateWriter = pw
|
||||||
defer pw.Close()
|
defer pw.Close()
|
||||||
trapCmdStr := sapi.MakeExitTrap(cmd.ReturnState.FdNum)
|
trapCmdStr, endBytes := sapi.MakeExitTrap(cmd.ReturnState.FdNum)
|
||||||
|
cmd.ReturnState.EndBytes = endBytes
|
||||||
rcFileStr += trapCmdStr
|
rcFileStr += trapCmdStr
|
||||||
}
|
}
|
||||||
shellVarMap := shellenv.ShellVarMapFromState(state)
|
shellVarMap := shellenv.ShellVarMapFromState(state)
|
||||||
@ -1021,6 +1022,11 @@ func (rs *ReturnStateBuf) Run() {
|
|||||||
}
|
}
|
||||||
rs.Lock.Lock()
|
rs.Lock.Lock()
|
||||||
rs.Buf = append(rs.Buf, buf[0:n]...)
|
rs.Buf = append(rs.Buf, buf[0:n]...)
|
||||||
|
if bytes.HasSuffix(rs.Buf, rs.EndBytes) {
|
||||||
|
rs.Buf = rs.Buf[:len(rs.Buf)-len(rs.EndBytes)]
|
||||||
|
rs.Lock.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
rs.Lock.Unlock()
|
rs.Lock.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1127,7 +1133,7 @@ func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType {
|
|||||||
wlog.Logf("debug returnstate file %q\n", base.GetDebugReturnStateFileName())
|
wlog.Logf("debug returnstate file %q\n", base.GetDebugReturnStateFileName())
|
||||||
os.WriteFile(base.GetDebugReturnStateFileName(), c.ReturnState.Buf, 0666)
|
os.WriteFile(base.GetDebugReturnStateFileName(), c.ReturnState.Buf, 0666)
|
||||||
}
|
}
|
||||||
state, _ := c.SAPI.ParseShellStateOutput(c.ReturnState.Buf) // TODO what to do with error?
|
state, _, _ := c.SAPI.ParseShellStateOutput(c.ReturnState.Buf) // TODO what to do with error?
|
||||||
donePacket.FinalState = state
|
donePacket.FinalState = state
|
||||||
}
|
}
|
||||||
endTs := time.Now()
|
endTs := time.Now()
|
||||||
@ -1156,21 +1162,6 @@ func MakeInitPacket() *packet.InitPacketType {
|
|||||||
return initPacket
|
return initPacket
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeShellStatePacket(shellType string) (*packet.ShellStatePacketType, error) {
|
|
||||||
sapi, err := shellapi.MakeShellApi(shellType)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
rtnCh := sapi.GetShellState()
|
|
||||||
ssOutput := <-rtnCh
|
|
||||||
if ssOutput.Error != "" {
|
|
||||||
return nil, errors.New(ssOutput.Error)
|
|
||||||
}
|
|
||||||
rtn := packet.MakeShellStatePacket()
|
|
||||||
rtn.State = ssOutput.ShellState
|
|
||||||
return rtn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func MakeServerInitPacket() (*packet.InitPacketType, error) {
|
func MakeServerInitPacket() (*packet.InitPacketType, error) {
|
||||||
var err error
|
var err error
|
||||||
initPacket := MakeInitPacket()
|
initPacket := MakeInitPacket()
|
||||||
|
@ -10,7 +10,9 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"math"
|
"math"
|
||||||
|
mathrand "math/rand"
|
||||||
"regexp"
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
@ -552,3 +554,60 @@ func StrArrayToMap(sarr []string) map[string]bool {
|
|||||||
}
|
}
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func AppendNonZeroRandomBytes(b []byte, randLen int) []byte {
|
||||||
|
if randLen <= 0 {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
numAdded := 0
|
||||||
|
for numAdded < randLen {
|
||||||
|
rn := mathrand.Intn(256)
|
||||||
|
if rn > 0 && rn < 256 { // exclude 0, also helps to suppress security warning to have a guard here
|
||||||
|
b = append(b, byte(rn))
|
||||||
|
numAdded++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// returns (isEOF, error)
|
||||||
|
func CopyWithEndBytes(outputBuf *bytes.Buffer, reader io.Reader, endBytes []byte) (bool, error) {
|
||||||
|
buf := make([]byte, 4096)
|
||||||
|
for {
|
||||||
|
n, err := reader.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
outputBuf.Write(buf[:n])
|
||||||
|
obytes := outputBuf.Bytes()
|
||||||
|
if bytes.HasSuffix(obytes, endBytes) {
|
||||||
|
outputBuf.Truncate(len(obytes) - len(endBytes))
|
||||||
|
return (err == io.EOF), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err == io.EOF {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// does *not* close outputCh on EOF or error
|
||||||
|
func CopyToChannel(outputCh chan<- []byte, reader io.Reader) error {
|
||||||
|
buf := make([]byte, 4096)
|
||||||
|
for {
|
||||||
|
n, err := reader.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
// copy so client can use []byte without it being overwritten
|
||||||
|
bufCopy := make([]byte, n)
|
||||||
|
copy(bufCopy, buf[:n])
|
||||||
|
outputCh <- bufCopy
|
||||||
|
}
|
||||||
|
if err == io.EOF {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -41,6 +41,7 @@ import (
|
|||||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/releasechecker"
|
"github.com/wavetermdev/waveterm/wavesrv/pkg/releasechecker"
|
||||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/remote"
|
"github.com/wavetermdev/waveterm/wavesrv/pkg/remote"
|
||||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/remote/openai"
|
"github.com/wavetermdev/waveterm/wavesrv/pkg/remote/openai"
|
||||||
|
"github.com/wavetermdev/waveterm/wavesrv/pkg/rtnstate"
|
||||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
|
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
|
||||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbus"
|
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbus"
|
||||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
|
"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
|
||||||
@ -1648,8 +1649,12 @@ func CopyFileCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scb
|
|||||||
}
|
}
|
||||||
var outputPos int64
|
var outputPos int64
|
||||||
outputStr := fmt.Sprintf("Copying [%v]:%v to [%v]:%v\r\n", sourceRemoteId.DisplayName, sourceFullPath, destRemoteId.DisplayName, destFullPath)
|
outputStr := fmt.Sprintf("Copying [%v]:%v to [%v]:%v\r\n", sourceRemoteId.DisplayName, sourceFullPath, destRemoteId.DisplayName, destFullPath)
|
||||||
termopts := sstore.TermOpts{Rows: shellutil.DefaultTermRows, Cols: shellutil.DefaultTermCols, FlexRows: true, MaxPtySize: remote.DefaultMaxPtySize}
|
termOpts, err := GetUITermOpts(pk.UIContext.WinSize, DefaultPTERM)
|
||||||
cmd, err := makeDynCmd(ctx, "copy file", ids, pk.GetRawStr(), termopts)
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot make termopts: %w", err)
|
||||||
|
}
|
||||||
|
pkTermOpts := convertTermOpts(termOpts)
|
||||||
|
cmd, err := makeDynCmd(ctx, "copy file", ids, pk.GetRawStr(), *pkTermOpts)
|
||||||
writeStringToPty(ctx, cmd, outputStr, &outputPos)
|
writeStringToPty(ctx, cmd, outputStr, &outputPos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO tricky error since the command was a success, but we can't show the output
|
// TODO tricky error since the command was a success, but we can't show the output
|
||||||
@ -3655,11 +3660,14 @@ func SessionCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbu
|
|||||||
return update, nil
|
return update, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbus.UpdatePacket, error) {
|
func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (rtnUpdate scbus.UpdatePacket, rtnErr error) {
|
||||||
ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen|R_Remote)
|
ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen|R_Remote)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if !ids.Remote.MShell.IsConnected() {
|
||||||
|
return nil, fmt.Errorf("cannot reinit, remote is not connected")
|
||||||
|
}
|
||||||
shellType := ids.Remote.ShellType
|
shellType := ids.Remote.ShellType
|
||||||
if pk.Kwargs["shell"] != "" {
|
if pk.Kwargs["shell"] != "" {
|
||||||
shellArg := pk.Kwargs["shell"]
|
shellArg := pk.Kwargs["shell"]
|
||||||
@ -3668,33 +3676,76 @@ func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (
|
|||||||
}
|
}
|
||||||
shellType = shellArg
|
shellType = shellArg
|
||||||
}
|
}
|
||||||
ssPk, err := ids.Remote.MShell.ReInit(ctx, shellType)
|
verbose := resolveBool(pk.Kwargs["verbose"], false)
|
||||||
|
termOpts, err := GetUITermOpts(pk.UIContext.WinSize, DefaultPTERM)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("cannot make termopts: %w", err)
|
||||||
}
|
}
|
||||||
if ssPk == nil || ssPk.State == nil {
|
pkTermOpts := convertTermOpts(termOpts)
|
||||||
return nil, fmt.Errorf("invalid initpk received from remote (no remote state)")
|
cmd, err := makeDynCmd(ctx, "reset", ids, pk.GetRawStr(), *pkTermOpts)
|
||||||
}
|
|
||||||
feState := sstore.FeStateFromShellState(ssPk.State)
|
|
||||||
remoteInst, err := sstore.UpdateRemoteState(ctx, ids.SessionId, ids.ScreenId, ids.Remote.RemotePtr, feState, ssPk.State, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
outputStr := fmt.Sprintf("reset remote state (shell:%s)", ssPk.State.GetShellType())
|
|
||||||
cmd, err := makeStaticCmd(ctx, "reset", ids, pk.GetRawStr(), []byte(outputStr))
|
|
||||||
if err != nil {
|
|
||||||
// TODO tricky error since the command was a success, but we can't show the output
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
update, err := addLineForCmd(ctx, "/reset", false, ids, cmd, "", nil)
|
update, err := addLineForCmd(ctx, "/reset", false, ids, cmd, "", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO tricky error since the command was a success, but we can't show the output
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
update.AddUpdate(sstore.MakeSessionUpdateForRemote(ids.SessionId, remoteInst), sstore.InteractiveUpdate(pk.Interactive))
|
go doResetCommand(ids, shellType, cmd, verbose)
|
||||||
return update, nil
|
return update, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func doResetCommand(ids resolvedIds, shellType string, cmd *sstore.CmdType, verbose bool) {
|
||||||
|
ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancelFn()
|
||||||
|
startTime := time.Now()
|
||||||
|
var outputPos int64
|
||||||
|
var rtnErr error
|
||||||
|
exitSuccess := true
|
||||||
|
defer func() {
|
||||||
|
if rtnErr != nil {
|
||||||
|
exitSuccess = false
|
||||||
|
writeStringToPty(ctx, cmd, fmt.Sprintf("\r\nerror: %v", rtnErr), &outputPos)
|
||||||
|
}
|
||||||
|
deferWriteCmdStatus(ctx, cmd, startTime, exitSuccess, outputPos)
|
||||||
|
}()
|
||||||
|
dataFn := func(data []byte) {
|
||||||
|
writeStringToPty(ctx, cmd, string(data), &outputPos)
|
||||||
|
}
|
||||||
|
origStatePtr := ids.Remote.MShell.GetDefaultStatePtr(shellType)
|
||||||
|
ssPk, err := ids.Remote.MShell.ReInit(ctx, shellType, dataFn, verbose)
|
||||||
|
if err != nil {
|
||||||
|
rtnErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ssPk == nil || ssPk.State == nil {
|
||||||
|
rtnErr = fmt.Errorf("invalid initpk received from remote (no remote state)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
feState := sstore.FeStateFromShellState(ssPk.State)
|
||||||
|
remoteInst, err := sstore.UpdateRemoteState(ctx, ids.SessionId, ids.ScreenId, ids.Remote.RemotePtr, feState, ssPk.State, nil)
|
||||||
|
if err != nil {
|
||||||
|
rtnErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
newStatePtr := ids.Remote.MShell.GetDefaultStatePtr(shellType)
|
||||||
|
if verbose && origStatePtr != nil && newStatePtr != nil {
|
||||||
|
statePtrDiff := fmt.Sprintf("oldstate: %v, newstate: %v\r\n", origStatePtr.BaseHash, newStatePtr.BaseHash)
|
||||||
|
writeStringToPty(ctx, cmd, statePtrDiff, &outputPos)
|
||||||
|
origFullState, _ := sstore.GetFullState(ctx, *origStatePtr)
|
||||||
|
newFullState, _ := sstore.GetFullState(ctx, *newStatePtr)
|
||||||
|
if origFullState != nil && newFullState != nil {
|
||||||
|
var diffBuf bytes.Buffer
|
||||||
|
rtnstate.DisplayStateUpdateDiff(&diffBuf, *origFullState, *newFullState)
|
||||||
|
diffStr := diffBuf.String()
|
||||||
|
diffStr = strings.ReplaceAll(diffStr, "\n", "\r\n")
|
||||||
|
writeStringToPty(ctx, cmd, diffStr, &outputPos)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
update := scbus.MakeUpdatePacket()
|
||||||
|
update.AddUpdate(sstore.MakeSessionUpdateForRemote(ids.SessionId, remoteInst))
|
||||||
|
scbus.MainUpdateBus.DoUpdate(update)
|
||||||
|
}
|
||||||
|
|
||||||
func ResetCwdCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbus.UpdatePacket, error) {
|
func ResetCwdCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbus.UpdatePacket, error) {
|
||||||
ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen|R_Remote)
|
ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen|R_Remote)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -196,7 +196,7 @@ func (msh *MShellProc) EnsureShellType(ctx context.Context, shellType string) er
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// try to reinit the shell
|
// try to reinit the shell
|
||||||
_, err := msh.ReInit(ctx, shellType)
|
_, err := msh.ReInit(ctx, shellType, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error trying to initialize shell %q: %v", shellType, err)
|
return fmt.Errorf("error trying to initialize shell %q: %v", shellType, err)
|
||||||
}
|
}
|
||||||
@ -1401,33 +1401,60 @@ func makeReinitErrorUpdate(shellType string) sstore.ActivityUpdate {
|
|||||||
return rtn
|
return rtn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (msh *MShellProc) ReInit(ctx context.Context, shellType string) (*packet.ShellStatePacketType, error) {
|
func (msh *MShellProc) ReInit(ctx context.Context, shellType string, dataFn func([]byte), verbose bool) (rtnPk *packet.ShellStatePacketType, rtnErr error) {
|
||||||
if !msh.IsConnected() {
|
if !msh.IsConnected() {
|
||||||
return nil, fmt.Errorf("cannot reinit, remote is not connected")
|
return nil, fmt.Errorf("cannot reinit, remote is not connected")
|
||||||
}
|
}
|
||||||
if shellType != packet.ShellType_bash && shellType != packet.ShellType_zsh {
|
if shellType != packet.ShellType_bash && shellType != packet.ShellType_zsh {
|
||||||
return nil, fmt.Errorf("invalid shell type %q", shellType)
|
return nil, fmt.Errorf("invalid shell type %q", shellType)
|
||||||
}
|
}
|
||||||
|
if dataFn == nil {
|
||||||
|
dataFn = func([]byte) {}
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if rtnErr != nil {
|
||||||
|
sstore.UpdateActivityWrap(ctx, makeReinitErrorUpdate(shellType), "reiniterror")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
startTs := time.Now()
|
||||||
reinitPk := packet.MakeReInitPacket()
|
reinitPk := packet.MakeReInitPacket()
|
||||||
reinitPk.ReqId = uuid.New().String()
|
reinitPk.ReqId = uuid.New().String()
|
||||||
reinitPk.ShellType = shellType
|
reinitPk.ShellType = shellType
|
||||||
resp, err := msh.PacketRpcRaw(ctx, reinitPk)
|
rpcIter, err := msh.PacketRpcIter(ctx, reinitPk)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rpcIter.Close()
|
||||||
|
var ssPk *packet.ShellStatePacketType
|
||||||
|
for {
|
||||||
|
resp, err := rpcIter.Next(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
return nil, fmt.Errorf("no response")
|
return nil, fmt.Errorf("channel closed with no response")
|
||||||
}
|
}
|
||||||
ssPk, ok := resp.(*packet.ShellStatePacketType)
|
var ok bool
|
||||||
if !ok {
|
ssPk, ok = resp.(*packet.ShellStatePacketType)
|
||||||
sstore.UpdateActivityWrap(ctx, makeReinitErrorUpdate(shellType), "reiniterror")
|
if ok {
|
||||||
if respPk, ok := resp.(*packet.ResponsePacketType); ok && respPk.Error != "" {
|
break
|
||||||
|
}
|
||||||
|
respPk, ok := resp.(*packet.ResponsePacketType)
|
||||||
|
if ok {
|
||||||
|
if respPk.Error != "" {
|
||||||
return nil, fmt.Errorf("error reinitializing remote: %s", respPk.Error)
|
return nil, fmt.Errorf("error reinitializing remote: %s", respPk.Error)
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("invalid reinit response (not an shellstate packet): %T", resp)
|
return nil, fmt.Errorf("invalid response from waveshell")
|
||||||
}
|
}
|
||||||
if ssPk.State == nil {
|
dataPk, ok := resp.(*packet.FileDataPacketType)
|
||||||
sstore.UpdateActivityWrap(ctx, makeReinitErrorUpdate(shellType), "reiniterror")
|
if ok {
|
||||||
|
dataFn(dataPk.Data)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
invalidPkStr := fmt.Sprintf("\r\ninvalid packettype from waveshell: %s\r\n", resp.GetType())
|
||||||
|
dataFn([]byte(invalidPkStr))
|
||||||
|
}
|
||||||
|
if ssPk == nil || ssPk.State == nil {
|
||||||
return nil, fmt.Errorf("invalid reinit response shellstate packet does not contain remote state")
|
return nil, fmt.Errorf("invalid reinit response shellstate packet does not contain remote state")
|
||||||
}
|
}
|
||||||
// TODO: maybe we don't need to save statebase here. should be possible to save it on demand
|
// TODO: maybe we don't need to save statebase here. should be possible to save it on demand
|
||||||
@ -1438,10 +1465,29 @@ func (msh *MShellProc) ReInit(ctx context.Context, shellType string) (*packet.Sh
|
|||||||
return nil, fmt.Errorf("error storing remote state: %w", err)
|
return nil, fmt.Errorf("error storing remote state: %w", err)
|
||||||
}
|
}
|
||||||
msh.StateMap.SetCurrentState(ssPk.State.GetShellType(), ssPk.State)
|
msh.StateMap.SetCurrentState(ssPk.State.GetShellType(), ssPk.State)
|
||||||
msh.WriteToPtyBuffer("initialized shell:%s state:%s\n", shellType, ssPk.State.GetHashVal(false))
|
timeDur := time.Since(startTs)
|
||||||
|
dataFn([]byte(makeShellInitOutputMsg(verbose, ssPk.State, ssPk.Stats, timeDur, false)))
|
||||||
|
msh.WriteToPtyBuffer("%s", makeShellInitOutputMsg(false, ssPk.State, ssPk.Stats, timeDur, true))
|
||||||
return ssPk, nil
|
return ssPk, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func makeShellInitOutputMsg(verbose bool, state *packet.ShellState, stats *packet.ShellStateStats, dur time.Duration, ptyMsg bool) string {
|
||||||
|
if !verbose || ptyMsg {
|
||||||
|
if ptyMsg {
|
||||||
|
return fmt.Sprintf("initialized state shell:%s statehash:%s %dms\n", state.GetShellType(), state.GetHashVal(false), dur.Milliseconds())
|
||||||
|
} else {
|
||||||
|
return fmt.Sprintf("initialized connection state (shell:%s)\r\n", state.GetShellType())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
buf.WriteString("-----\r\n")
|
||||||
|
buf.WriteString(fmt.Sprintf("initialized connection shell:%s statehash:%s %dms\r\n", state.GetShellType(), state.GetHashVal(false), dur.Milliseconds()))
|
||||||
|
if stats != nil {
|
||||||
|
buf.WriteString(fmt.Sprintf(" outsize:%s size:%s env:%d, vars:%d, aliases:%d, funcs:%d\r\n", scbase.NumFormatDec(stats.OutputSize), scbase.NumFormatDec(stats.StateSize), stats.EnvCount, stats.VarCount, stats.AliasCount, stats.FuncCount))
|
||||||
|
}
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
func (msh *MShellProc) WriteFile(ctx context.Context, writePk *packet.WriteFilePacketType) (*packet.RpcResponseIter, error) {
|
func (msh *MShellProc) WriteFile(ctx context.Context, writePk *packet.WriteFilePacketType) (*packet.RpcResponseIter, error) {
|
||||||
return msh.PacketRpcIter(ctx, writePk)
|
return msh.PacketRpcIter(ctx, writePk)
|
||||||
}
|
}
|
||||||
@ -1690,7 +1736,7 @@ func (msh *MShellProc) initActiveShells() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, shellType := range activeShells {
|
for _, shellType := range activeShells {
|
||||||
_, err = msh.ReInit(ctx, shellType)
|
_, err = msh.ReInit(ctx, shellType, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msh.WriteToPtyBuffer("*error reiniting shell %q: %v\n", shellType, err)
|
msh.WriteToPtyBuffer("*error reiniting shell %q: %v\n", shellType, err)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user