mirror of
https://github.com/wavetermdev/waveterm.git
synced 2024-12-21 16:38:23 +01:00
zsh reinit fixes (#477)
* reset command now initiates and completes async so there is feedback that something is happening when it takes a long time * switch from standard rpc to rpciter * checkpoint on reinit -- stream output, stats packet, logging to cmd pty, new endBytes for EOF * make generic versions of endbytes scanner and channel output funcs * update bash to use more modern state parsing (tricks learned from zsh) * verbose mode, fix stats output message * add a diff when verbose mode is on
This commit is contained in:
parent
accb74ae0f
commit
5616c9abbb
@ -598,11 +598,12 @@ func MakeLogPacket(entry wlog.LogEntry) *LogPacketType {
|
||||
}
|
||||
|
||||
type ShellStatePacketType struct {
|
||||
Type string `json:"type"`
|
||||
ShellType string `json:"shelltype"`
|
||||
RespId string `json:"respid,omitempty"`
|
||||
State *ShellState `json:"state"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Type string `json:"type"`
|
||||
ShellType string `json:"shelltype"`
|
||||
RespId string `json:"respid,omitempty"`
|
||||
State *ShellState `json:"state"`
|
||||
Stats *ShellStateStats `json:"stats"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (*ShellStatePacketType) GetType() string {
|
||||
|
@ -19,6 +19,17 @@ import (
|
||||
const ShellStatePackVersion = 0
|
||||
const ShellStateDiffPackVersion = 0
|
||||
|
||||
type ShellStateStats struct {
|
||||
Version string `json:"version"`
|
||||
AliasCount int `json:"aliascount"`
|
||||
EnvCount int `json:"envcount"`
|
||||
VarCount int `json:"varcount"`
|
||||
FuncCount int `json:"funccount"`
|
||||
HashVal string `json:"hashval"`
|
||||
OutputSize int64 `json:"outputsize"`
|
||||
StateSize int64 `json:"statesize"`
|
||||
}
|
||||
|
||||
type ShellState struct {
|
||||
Version string `json:"version"` // [type] [semver]
|
||||
Cwd string `json:"cwd,omitempty"`
|
||||
@ -29,6 +40,10 @@ type ShellState struct {
|
||||
HashVal string `json:"-"`
|
||||
}
|
||||
|
||||
func (state ShellState) ApproximateSize() int64 {
|
||||
return int64(len(state.Version) + len(state.Cwd) + len(state.ShellVars) + len(state.Aliases) + len(state.Funcs) + len(state.Error))
|
||||
}
|
||||
|
||||
type ShellStateDiff struct {
|
||||
Version string `json:"version"` // [type] [semver] (note this should *always* be set even if the same as base)
|
||||
BaseHash string `json:"basehash"`
|
||||
|
@ -244,11 +244,10 @@ func (m *MServer) runCompGen(compPk *packet.CompGenPacketType) {
|
||||
appendSlashes(comps)
|
||||
}
|
||||
m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": hasMore})
|
||||
return
|
||||
}
|
||||
|
||||
func (m *MServer) reinit(reqId string, shellType string) {
|
||||
ssPk, err := shexec.MakeShellStatePacket(shellType)
|
||||
ssPk, err := m.MakeShellStatePacket(reqId, shellType)
|
||||
if err != nil {
|
||||
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error creating init packet: %w", err))
|
||||
return
|
||||
@ -262,6 +261,32 @@ func (m *MServer) reinit(reqId string, shellType string) {
|
||||
m.Sender.SendPacket(ssPk)
|
||||
}
|
||||
|
||||
func (m *MServer) MakeShellStatePacket(reqId string, shellType string) (*packet.ShellStatePacketType, error) {
|
||||
sapi, err := shellapi.MakeShellApi(shellType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rtnCh := make(chan shellapi.ShellStateOutput, 1)
|
||||
go sapi.GetShellState(rtnCh)
|
||||
for ssOutput := range rtnCh {
|
||||
if ssOutput.Error != "" {
|
||||
return nil, errors.New(ssOutput.Error)
|
||||
}
|
||||
if ssOutput.ShellState != nil {
|
||||
rtn := packet.MakeShellStatePacket()
|
||||
rtn.State = ssOutput.ShellState
|
||||
rtn.Stats = ssOutput.Stats
|
||||
return rtn, nil
|
||||
}
|
||||
if ssOutput.Output != nil {
|
||||
dataPk := packet.MakeFileDataPacket(reqId)
|
||||
dataPk.Data = ssOutput.Output
|
||||
m.Sender.SendPacket(dataPk)
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func makeTemp(path string, mode fs.FileMode) (*os.File, error) {
|
||||
dirName := filepath.Dir(path)
|
||||
baseName := filepath.Base(path)
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/statediff"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
)
|
||||
|
||||
const BaseBashOpts = `set +m; set +H; shopt -s extglob`
|
||||
@ -48,7 +49,7 @@ func (b bashShellApi) GetShellType() string {
|
||||
return packet.ShellType_bash
|
||||
}
|
||||
|
||||
func (b bashShellApi) MakeExitTrap(fdNum int) string {
|
||||
func (b bashShellApi) MakeExitTrap(fdNum int) (string, []byte) {
|
||||
return MakeBashExitTrap(fdNum)
|
||||
}
|
||||
|
||||
@ -79,29 +80,15 @@ func (b bashShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty
|
||||
return MakeBashShExecCommand(cmdStr, rcFileName, usePty)
|
||||
}
|
||||
|
||||
func (b bashShellApi) GetShellState() chan ShellStateOutput {
|
||||
ch := make(chan ShellStateOutput, 1)
|
||||
defer close(ch)
|
||||
ssPk, err := GetBashShellState()
|
||||
if err != nil {
|
||||
ch <- ShellStateOutput{
|
||||
Status: ShellStateOutputStatus_Done,
|
||||
Error: err.Error(),
|
||||
}
|
||||
return ch
|
||||
}
|
||||
ch <- ShellStateOutput{
|
||||
Status: ShellStateOutputStatus_Done,
|
||||
ShellState: ssPk,
|
||||
}
|
||||
return ch
|
||||
func (b bashShellApi) GetShellState(outCh chan ShellStateOutput) {
|
||||
GetBashShellState(outCh)
|
||||
}
|
||||
|
||||
func (b bashShellApi) GetBaseShellOpts() string {
|
||||
return BaseBashOpts
|
||||
}
|
||||
|
||||
func (b bashShellApi) ParseShellStateOutput(output []byte) (*packet.ShellState, error) {
|
||||
func (b bashShellApi) ParseShellStateOutput(output []byte) (*packet.ShellState, *packet.ShellStateStats, error) {
|
||||
return parseBashShellStateOutput(output)
|
||||
}
|
||||
|
||||
@ -130,8 +117,32 @@ func (b bashShellApi) MakeRcFileStr(pk *packet.RunPacketType) string {
|
||||
return rcBuf.String()
|
||||
}
|
||||
|
||||
func GetBashShellStateCmd() string {
|
||||
return strings.Join(GetBashShellStateCmds, ` printf "\x00\x00";`)
|
||||
func GetBashShellStateCmd(fdNum int) (string, []byte) {
|
||||
endBytes := utilfn.AppendNonZeroRandomBytes(nil, NumRandomEndBytes)
|
||||
endBytes = append(endBytes, '\n')
|
||||
cmdStr := strings.TrimSpace(`
|
||||
exec 2> /dev/null;
|
||||
exec > [%OUTPUTFD%];
|
||||
printf "\x00\x00";
|
||||
[%BASHVERSIONCMD%];
|
||||
printf "\x00\x00";
|
||||
pwd;
|
||||
printf "\x00\x00";
|
||||
declare -p $(compgen -A variable);
|
||||
printf "\x00\x00";
|
||||
alias -p;
|
||||
printf "\x00\x00";
|
||||
declare -f;
|
||||
printf "\x00\x00";
|
||||
[%GITBRANCHCMD%];
|
||||
printf "\x00\x00";
|
||||
printf "[%ENDBYTES%]";
|
||||
`)
|
||||
cmdStr = strings.ReplaceAll(cmdStr, "[%OUTPUTFD%]", fmt.Sprintf("/dev/fd/%d", fdNum))
|
||||
cmdStr = strings.ReplaceAll(cmdStr, "[%BASHVERSIONCMD%]", BashShellVersionCmdStr)
|
||||
cmdStr = strings.ReplaceAll(cmdStr, "[%GITBRANCHCMD%]", GetGitBranchCmdStr)
|
||||
cmdStr = strings.ReplaceAll(cmdStr, "[%ENDBYTES%]", utilfn.ShellHexEscape(string(endBytes)))
|
||||
return cmdStr, endBytes
|
||||
}
|
||||
|
||||
func execGetLocalBashShellVersion() string {
|
||||
@ -158,16 +169,34 @@ func GetLocalBashMajorVersion() string {
|
||||
return localBashMajorVersion
|
||||
}
|
||||
|
||||
func GetBashShellState() (*packet.ShellState, error) {
|
||||
func GetBashShellState(outCh chan ShellStateOutput) {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
|
||||
defer cancelFn()
|
||||
cmdStr := BaseBashOpts + "; " + GetBashShellStateCmd()
|
||||
defer close(outCh)
|
||||
stateCmd, endBytes := GetBashShellStateCmd(StateOutputFdNum)
|
||||
cmdStr := BaseBashOpts + "; " + stateCmd
|
||||
ecmd := exec.CommandContext(ctx, GetLocalBashPath(), "-l", "-i", "-c", cmdStr)
|
||||
outputBytes, err := RunSimpleCmdInPty(ecmd)
|
||||
outputCh := make(chan []byte, 10)
|
||||
var outputWg sync.WaitGroup
|
||||
outputWg.Add(1)
|
||||
go func() {
|
||||
defer outputWg.Done()
|
||||
for outputBytes := range outputCh {
|
||||
outCh <- ShellStateOutput{Output: outputBytes}
|
||||
}
|
||||
}()
|
||||
outputBytes, err := StreamCommandWithExtraFd(ecmd, outputCh, StateOutputFdNum, endBytes)
|
||||
outputWg.Wait()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
outCh <- ShellStateOutput{Error: err.Error()}
|
||||
return
|
||||
}
|
||||
return parseBashShellStateOutput(outputBytes)
|
||||
rtn, stats, err := parseBashShellStateOutput(outputBytes)
|
||||
if err != nil {
|
||||
outCh <- ShellStateOutput{Error: err.Error()}
|
||||
return
|
||||
}
|
||||
outCh <- ShellStateOutput{ShellState: rtn, Stats: stats}
|
||||
}
|
||||
|
||||
func GetLocalBashPath() string {
|
||||
@ -190,19 +219,20 @@ func GetLocalZshPath() string {
|
||||
return "zsh"
|
||||
}
|
||||
|
||||
func GetBashShellStateRedirectCommandStr(outputFdNum int) string {
|
||||
return fmt.Sprintf("cat <(%s) > /dev/fd/%d", GetBashShellStateCmd(), outputFdNum)
|
||||
func GetBashShellStateRedirectCommandStr(outputFdNum int) (string, []byte) {
|
||||
cmdStr, endBytes := GetBashShellStateCmd(outputFdNum)
|
||||
return cmdStr, endBytes
|
||||
}
|
||||
|
||||
func MakeBashExitTrap(fdNum int) string {
|
||||
stateCmd := GetBashShellStateRedirectCommandStr(fdNum)
|
||||
func MakeBashExitTrap(fdNum int) (string, []byte) {
|
||||
stateCmd, endBytes := GetBashShellStateRedirectCommandStr(fdNum)
|
||||
fmtStr := `
|
||||
_waveshell_exittrap () {
|
||||
%s
|
||||
}
|
||||
trap _waveshell_exittrap EXIT
|
||||
`
|
||||
return fmt.Sprintf(fmtStr, stateCmd)
|
||||
return fmt.Sprintf(fmtStr, stateCmd), endBytes
|
||||
}
|
||||
|
||||
func MakeBashShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd {
|
||||
|
@ -20,6 +20,19 @@ import (
|
||||
"mvdan.cc/sh/v3/syntax"
|
||||
)
|
||||
|
||||
const (
|
||||
BashSection_Ignored = iota
|
||||
BashSection_Version
|
||||
BashSection_Cwd
|
||||
BashSection_Vars
|
||||
BashSection_Aliases
|
||||
BashSection_Funcs
|
||||
BashSection_PVars
|
||||
BashSection_EndBytes
|
||||
|
||||
BashSection_Count // must be last
|
||||
)
|
||||
|
||||
type DeclareDeclType = shellenv.DeclareDeclType
|
||||
|
||||
func doCmdSubst(commandStr string, w io.Writer, word *syntax.CmdSubst) error {
|
||||
@ -214,38 +227,37 @@ func bashParseDeclareOutput(state *packet.ShellState, declareBytes []byte, pvarB
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseBashShellStateOutput(outputBytes []byte) (*packet.ShellState, error) {
|
||||
func parseBashShellStateOutput(outputBytes []byte) (*packet.ShellState, *packet.ShellStateStats, error) {
|
||||
if scbase.IsDevMode() && DebugState {
|
||||
writeStateToFile(packet.ShellType_bash, outputBytes)
|
||||
}
|
||||
// 7 fields: ignored [0], version [1], cwd [2], env/vars [3], aliases [4], funcs [5], pvars [6]
|
||||
fields := bytes.Split(outputBytes, []byte{0, 0})
|
||||
if len(fields) != 7 {
|
||||
return nil, fmt.Errorf("invalid bash shell state output, wrong number of fields, fields=%d", len(fields))
|
||||
sections := bytes.Split(outputBytes, []byte{0, 0})
|
||||
if len(sections) != BashSection_Count {
|
||||
return nil, nil, fmt.Errorf("invalid bash shell state output, wrong number of fields, fields=%d", len(sections))
|
||||
}
|
||||
rtn := &packet.ShellState{}
|
||||
rtn.Version = strings.TrimSpace(string(fields[1]))
|
||||
rtn.Version = strings.TrimSpace(string(sections[BashSection_Version]))
|
||||
if rtn.GetShellType() != packet.ShellType_bash {
|
||||
return nil, fmt.Errorf("invalid bash shell state output, wrong shell type: %q", rtn.Version)
|
||||
return nil, nil, fmt.Errorf("invalid bash shell state output, wrong shell type: %q", rtn.Version)
|
||||
}
|
||||
if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil {
|
||||
return nil, fmt.Errorf("invalid bash shell state output, invalid version: %v", err)
|
||||
return nil, nil, fmt.Errorf("invalid bash shell state output, invalid version: %v", err)
|
||||
}
|
||||
cwdStr := string(fields[2])
|
||||
cwdStr := string(sections[BashSection_Cwd])
|
||||
if strings.HasSuffix(cwdStr, "\r\n") {
|
||||
cwdStr = cwdStr[0 : len(cwdStr)-2]
|
||||
} else if strings.HasSuffix(cwdStr, "\n") {
|
||||
cwdStr = cwdStr[0 : len(cwdStr)-1]
|
||||
} else {
|
||||
cwdStr = strings.TrimSuffix(cwdStr, "\n")
|
||||
}
|
||||
rtn.Cwd = string(cwdStr)
|
||||
err := bashParseDeclareOutput(rtn, fields[3], fields[6])
|
||||
err := bashParseDeclareOutput(rtn, sections[BashSection_Vars], sections[BashSection_PVars])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
rtn.Aliases = strings.ReplaceAll(string(fields[4]), "\r\n", "\n")
|
||||
rtn.Funcs = strings.ReplaceAll(string(fields[5]), "\r\n", "\n")
|
||||
rtn.Aliases = strings.ReplaceAll(string(sections[BashSection_Aliases]), "\r\n", "\n")
|
||||
rtn.Funcs = strings.ReplaceAll(string(sections[BashSection_Funcs]), "\r\n", "\n")
|
||||
rtn.Funcs = shellenv.RemoveFunc(rtn.Funcs, "_waveshell_exittrap")
|
||||
return rtn, nil
|
||||
return rtn, nil, nil
|
||||
}
|
||||
|
||||
func bashNormalize(d *DeclareDeclType) error {
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
@ -28,12 +27,14 @@ import (
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
)
|
||||
|
||||
const GetStateTimeout = 5 * time.Second
|
||||
const GetStateTimeout = 15 * time.Second
|
||||
const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"`
|
||||
const GetK8sContextCmdStr = `printf "K8SCONTEXT %s\x00" "$(kubectl config current-context 2>/dev/null)"`
|
||||
const GetK8sNamespaceCmdStr = `printf "K8SNAMESPACE %s\x00" "$(kubectl config view --minify --output 'jsonpath={..namespace}' 2>/dev/null)"`
|
||||
const RunCommandFmt = `%s`
|
||||
const DebugState = false
|
||||
const StateOutputFdNum = 20
|
||||
const NumRandomEndBytes = 8
|
||||
|
||||
var userShellRegexp = regexp.MustCompile(`^UserShell: (.*)$`)
|
||||
|
||||
@ -56,23 +57,23 @@ const (
|
||||
)
|
||||
|
||||
type ShellStateOutput struct {
|
||||
Status string
|
||||
StderrOutput []byte
|
||||
ShellState *packet.ShellState
|
||||
Error string
|
||||
Output []byte
|
||||
ShellState *packet.ShellState
|
||||
Stats *packet.ShellStateStats
|
||||
Error string
|
||||
}
|
||||
|
||||
type ShellApi interface {
|
||||
GetShellType() string
|
||||
MakeExitTrap(fdNum int) string
|
||||
MakeExitTrap(fdNum int) (string, []byte)
|
||||
GetLocalMajorVersion() string
|
||||
GetLocalShellPath() string
|
||||
GetRemoteShellPath() string
|
||||
MakeRunCommand(cmdStr string, opts RunCommandOpts) string
|
||||
MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd
|
||||
GetShellState() chan ShellStateOutput
|
||||
GetShellState(chan ShellStateOutput)
|
||||
GetBaseShellOpts() string
|
||||
ParseShellStateOutput(output []byte) (*packet.ShellState, error)
|
||||
ParseShellStateOutput(output []byte) (*packet.ShellState, *packet.ShellStateStats, error)
|
||||
MakeRcFileStr(pk *packet.RunPacketType) string
|
||||
MakeShellStateDiff(oldState *packet.ShellState, oldStateHash string, newState *packet.ShellState) (*packet.ShellStateDiff, error)
|
||||
ApplyShellStateDiff(oldState *packet.ShellState, diff *packet.ShellStateDiff) (*packet.ShellState, error)
|
||||
@ -153,12 +154,13 @@ func internalMacUserShell() string {
|
||||
const FirstExtraFilesFdNum = 3
|
||||
|
||||
// returns output(stdout+stderr), extraFdOutput, error
|
||||
func RunCommandWithExtraFd(ecmd *exec.Cmd, extraFdNum int) ([]byte, []byte, error) {
|
||||
func StreamCommandWithExtraFd(ecmd *exec.Cmd, outputCh chan []byte, extraFdNum int, endBytes []byte) ([]byte, error) {
|
||||
defer close(outputCh)
|
||||
ecmd.Env = os.Environ()
|
||||
shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType))
|
||||
cmdPty, cmdTty, err := pty.Open()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("opening new pty: %w", err)
|
||||
return nil, fmt.Errorf("opening new pty: %w", err)
|
||||
}
|
||||
defer cmdTty.Close()
|
||||
defer cmdPty.Close()
|
||||
@ -171,42 +173,44 @@ func RunCommandWithExtraFd(ecmd *exec.Cmd, extraFdNum int) ([]byte, []byte, erro
|
||||
ecmd.SysProcAttr.Setctty = true
|
||||
pipeReader, pipeWriter, err := os.Pipe()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("could not create pipe: %w", err)
|
||||
return nil, fmt.Errorf("could not create pipe: %w", err)
|
||||
}
|
||||
defer pipeWriter.Close()
|
||||
defer pipeReader.Close()
|
||||
extraFiles := make([]*os.File, extraFdNum+1)
|
||||
extraFiles[extraFdNum] = pipeWriter
|
||||
ecmd.ExtraFiles = extraFiles[FirstExtraFilesFdNum:]
|
||||
defer pipeReader.Close()
|
||||
ecmd.Start()
|
||||
err = ecmd.Start()
|
||||
cmdTty.Close()
|
||||
pipeWriter.Close()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
var outputWg sync.WaitGroup
|
||||
var outputBuf bytes.Buffer
|
||||
var extraFdOutputBuf bytes.Buffer
|
||||
outputWg.Add(2)
|
||||
go func() {
|
||||
// ignore error (/dev/ptmx has read error when process is done)
|
||||
defer outputWg.Done()
|
||||
io.Copy(&outputBuf, cmdPty)
|
||||
err := utilfn.CopyToChannel(outputCh, cmdPty)
|
||||
if err != nil {
|
||||
errStr := fmt.Sprintf("\r\nerror reading from pty: %v\r\n", err)
|
||||
outputCh <- []byte(errStr)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer outputWg.Done()
|
||||
io.Copy(&extraFdOutputBuf, pipeReader)
|
||||
utilfn.CopyWithEndBytes(&extraFdOutputBuf, pipeReader, endBytes)
|
||||
}()
|
||||
exitErr := ecmd.Wait()
|
||||
if exitErr != nil {
|
||||
return nil, nil, exitErr
|
||||
return nil, exitErr
|
||||
}
|
||||
outputWg.Wait()
|
||||
return outputBuf.Bytes(), extraFdOutputBuf.Bytes(), nil
|
||||
return extraFdOutputBuf.Bytes(), nil
|
||||
}
|
||||
|
||||
func RunSimpleCmdInPty(ecmd *exec.Cmd) ([]byte, error) {
|
||||
func RunSimpleCmdInPty(ecmd *exec.Cmd, endBytes []byte) ([]byte, error) {
|
||||
ecmd.Env = os.Environ()
|
||||
shellutil.UpdateCmdEnv(ecmd, shellutil.MShellEnvVars(shellutil.DefaultTermType))
|
||||
cmdPty, cmdTty, err := pty.Open()
|
||||
@ -231,8 +235,8 @@ func RunSimpleCmdInPty(ecmd *exec.Cmd) ([]byte, error) {
|
||||
var outputBuf bytes.Buffer
|
||||
go func() {
|
||||
// ignore error (/dev/ptmx has read error when process is done)
|
||||
io.Copy(&outputBuf, cmdPty)
|
||||
close(ioDone)
|
||||
defer close(ioDone)
|
||||
utilfn.CopyWithEndBytes(&outputBuf, cmdPty, endBytes)
|
||||
}()
|
||||
exitErr := ecmd.Wait()
|
||||
if exitErr != nil {
|
||||
|
@ -8,7 +8,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -28,7 +27,6 @@ import (
|
||||
const BaseZshOpts = ``
|
||||
|
||||
const ZshShellVersionCmdStr = `echo zsh v$ZSH_VERSION`
|
||||
const StateOutputFdNum = 20
|
||||
|
||||
const (
|
||||
ZshSection_Version = iota
|
||||
@ -41,6 +39,7 @@ const (
|
||||
ZshSection_Funcs
|
||||
ZshSection_PVars
|
||||
ZshSection_Prompt
|
||||
ZshSection_EndBytes
|
||||
|
||||
ZshSection_NumFieldsExpected // must be last
|
||||
)
|
||||
@ -118,6 +117,11 @@ var ZshIgnoreVars = map[string]bool{
|
||||
"zcurses_windows": true,
|
||||
|
||||
// not listed, but we also exclude all ZFTP_* variables
|
||||
|
||||
// powerlevel10k
|
||||
"_GITSTATUS_CLIENT_PID_POWERLEVEL9K": true,
|
||||
"GITSTATUS_DAEMON_PID_POWERLEVEL9K": true,
|
||||
"_GITSTATUS_FILE_PREFIX_POWERLEVEL9K": true,
|
||||
}
|
||||
|
||||
var ZshIgnoreFuncs = map[string]bool{
|
||||
@ -211,7 +215,7 @@ func (z zshShellApi) GetShellType() string {
|
||||
return packet.ShellType_zsh
|
||||
}
|
||||
|
||||
func (z zshShellApi) MakeExitTrap(fdNum int) string {
|
||||
func (z zshShellApi) MakeExitTrap(fdNum int) (string, []byte) {
|
||||
return MakeZshExitTrap(fdNum)
|
||||
}
|
||||
|
||||
@ -242,25 +246,34 @@ func (z zshShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty
|
||||
return exec.Command(GetLocalZshPath(), "-l", "-i", "-c", cmdStr)
|
||||
}
|
||||
|
||||
func (z zshShellApi) GetShellState() chan ShellStateOutput {
|
||||
func (z zshShellApi) GetShellState(outCh chan ShellStateOutput) {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
|
||||
defer cancelFn()
|
||||
rtnCh := make(chan ShellStateOutput, 1)
|
||||
defer close(rtnCh)
|
||||
cmdStr := BaseZshOpts + "; " + GetZshShellStateCmd(StateOutputFdNum)
|
||||
defer close(outCh)
|
||||
stateCmd, endBytes := GetZshShellStateCmd(StateOutputFdNum)
|
||||
cmdStr := BaseZshOpts + "; " + stateCmd
|
||||
ecmd := exec.CommandContext(ctx, GetLocalZshPath(), "-l", "-i", "-c", cmdStr)
|
||||
_, outputBytes, err := RunCommandWithExtraFd(ecmd, StateOutputFdNum)
|
||||
outputCh := make(chan []byte, 10)
|
||||
var outputWg sync.WaitGroup
|
||||
outputWg.Add(1)
|
||||
go func() {
|
||||
defer outputWg.Done()
|
||||
for outputBytes := range outputCh {
|
||||
outCh <- ShellStateOutput{Output: outputBytes}
|
||||
}
|
||||
}()
|
||||
outputBytes, err := StreamCommandWithExtraFd(ecmd, outputCh, StateOutputFdNum, endBytes)
|
||||
outputWg.Wait()
|
||||
if err != nil {
|
||||
rtnCh <- ShellStateOutput{Status: ShellStateOutputStatus_Done, Error: err.Error()}
|
||||
return rtnCh
|
||||
outCh <- ShellStateOutput{Error: err.Error()}
|
||||
return
|
||||
}
|
||||
rtn, err := z.ParseShellStateOutput(outputBytes)
|
||||
rtn, stats, err := z.ParseShellStateOutput(outputBytes)
|
||||
if err != nil {
|
||||
rtnCh <- ShellStateOutput{Status: ShellStateOutputStatus_Done, Error: err.Error()}
|
||||
return rtnCh
|
||||
outCh <- ShellStateOutput{Error: err.Error()}
|
||||
return
|
||||
}
|
||||
rtnCh <- ShellStateOutput{Status: ShellStateOutputStatus_Done, ShellState: rtn}
|
||||
return rtnCh
|
||||
outCh <- ShellStateOutput{ShellState: rtn, Stats: stats}
|
||||
}
|
||||
|
||||
func (z zshShellApi) GetBaseShellOpts() string {
|
||||
@ -437,19 +450,15 @@ func writeZshId(buf *bytes.Buffer, idStr string) {
|
||||
|
||||
const numRandomBytes = 4
|
||||
|
||||
// returns (cmd-string)
|
||||
func GetZshShellStateCmd(fdNum int) string {
|
||||
// returns (cmd-string, endbytes)
|
||||
func GetZshShellStateCmd(fdNum int) (string, []byte) {
|
||||
var sectionSeparator []byte
|
||||
// adding this extra "\n" helps with debuging and readability of output
|
||||
sectionSeparator = append(sectionSeparator, byte('\n'))
|
||||
for len(sectionSeparator) < numRandomBytes {
|
||||
// any character *except* null (0)
|
||||
rn := rand.Intn(256)
|
||||
if rn > 0 && rn < 256 { // exclude 0, also helps to suppress security warning to have a guard here
|
||||
sectionSeparator = append(sectionSeparator, byte(rn))
|
||||
}
|
||||
}
|
||||
sectionSeparator = utilfn.AppendNonZeroRandomBytes(sectionSeparator, numRandomBytes)
|
||||
sectionSeparator = append(sectionSeparator, 0, 0)
|
||||
endBytes := utilfn.AppendNonZeroRandomBytes(nil, NumRandomEndBytes)
|
||||
endBytes = append(endBytes, byte('\n'))
|
||||
// we have to use these crazy separators because zsh allows basically anything in
|
||||
// variable names and values (including nulls).
|
||||
// note that we don't need crazy separators for "env" or "typeset".
|
||||
@ -511,6 +520,8 @@ printf "[%SECTIONSEP%]";
|
||||
[%K8SNAMESPACE%]
|
||||
printf "[%SECTIONSEP%]";
|
||||
print -P "$PS1"
|
||||
printf "[%SECTIONSEP%]";
|
||||
printf "[%ENDBYTES%]"
|
||||
`
|
||||
cmd = strings.TrimSpace(cmd)
|
||||
cmd = strings.ReplaceAll(cmd, "[%ZSHVERSION%]", ZshShellVersionCmdStr)
|
||||
@ -520,17 +531,19 @@ print -P "$PS1"
|
||||
cmd = strings.ReplaceAll(cmd, "[%PARTSEP%]", utilfn.ShellHexEscape(string(sectionSeparator[0:len(sectionSeparator)-1])))
|
||||
cmd = strings.ReplaceAll(cmd, "[%SECTIONSEP%]", utilfn.ShellHexEscape(string(sectionSeparator)))
|
||||
cmd = strings.ReplaceAll(cmd, "[%OUTPUTFD%]", fmt.Sprintf("/dev/fd/%d", fdNum))
|
||||
return cmd
|
||||
cmd = strings.ReplaceAll(cmd, "[%OUTPUTFDNUM%]", fmt.Sprintf("%d", fdNum))
|
||||
cmd = strings.ReplaceAll(cmd, "[%ENDBYTES%]", utilfn.ShellHexEscape(string(endBytes)))
|
||||
return cmd, endBytes
|
||||
}
|
||||
|
||||
func MakeZshExitTrap(fdNum int) string {
|
||||
stateCmd := GetZshShellStateCmd(fdNum)
|
||||
func MakeZshExitTrap(fdNum int) (string, []byte) {
|
||||
stateCmd, endBytes := GetZshShellStateCmd(fdNum)
|
||||
fmtStr := `
|
||||
zshexit () {
|
||||
%s
|
||||
}
|
||||
`
|
||||
return fmt.Sprintf(fmtStr, stateCmd)
|
||||
return fmt.Sprintf(fmtStr, stateCmd), endBytes
|
||||
}
|
||||
|
||||
func execGetLocalZshShellVersion() string {
|
||||
@ -698,14 +711,14 @@ func makeZshFuncsStrForShellState(fnMap map[ZshParamKey]string) string {
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, error) {
|
||||
func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, *packet.ShellStateStats, error) {
|
||||
if scbase.IsDevMode() && DebugState {
|
||||
writeStateToFile(packet.ShellType_zsh, outputBytes)
|
||||
}
|
||||
firstZeroIdx := bytes.Index(outputBytes, []byte{0})
|
||||
firstDZeroIdx := bytes.Index(outputBytes, []byte{0, 0})
|
||||
if firstZeroIdx == -1 || firstDZeroIdx == -1 {
|
||||
return nil, fmt.Errorf("invalid zsh shell state output, could not parse separator bytes")
|
||||
return nil, nil, fmt.Errorf("invalid zsh shell state output, could not parse separator bytes")
|
||||
}
|
||||
versionStr := string(outputBytes[0:firstZeroIdx])
|
||||
sectionSeparator := outputBytes[firstZeroIdx+1 : firstDZeroIdx+2]
|
||||
@ -714,15 +727,15 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta
|
||||
sections := bytes.Split(outputBytes, sectionSeparator)
|
||||
if len(sections) != ZshSection_NumFieldsExpected {
|
||||
base.Logf("invalid -- numfields\n")
|
||||
return nil, fmt.Errorf("invalid zsh shell state output, wrong number of sections, section=%d", len(sections))
|
||||
return nil, nil, fmt.Errorf("invalid zsh shell state output, wrong number of sections, section=%d", len(sections))
|
||||
}
|
||||
rtn := &packet.ShellState{}
|
||||
rtn.Version = strings.TrimSpace(versionStr)
|
||||
if rtn.GetShellType() != packet.ShellType_zsh {
|
||||
return nil, fmt.Errorf("invalid zsh shell state output, wrong shell type")
|
||||
return nil, nil, fmt.Errorf("invalid zsh shell state output, wrong shell type")
|
||||
}
|
||||
if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil {
|
||||
return nil, fmt.Errorf("invalid zsh shell state output, invalid version: %v", err)
|
||||
return nil, nil, fmt.Errorf("invalid zsh shell state output, invalid version: %v", err)
|
||||
}
|
||||
cwdStr := stripNewLineChars(string(sections[ZshSection_Cwd]))
|
||||
rtn.Cwd = cwdStr
|
||||
@ -730,7 +743,7 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta
|
||||
zshDecls, err := parseZshDecls(sections[ZshSection_Vars])
|
||||
if err != nil {
|
||||
base.Logf("invalid - parsedecls %v\n", err)
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
for _, decl := range zshDecls {
|
||||
if decl.IsZshScalarBound() {
|
||||
@ -746,7 +759,17 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta
|
||||
pvarMap := parseExtVarOutput(sections[ZshSection_PVars], string(sections[ZshSection_Prompt]), string(sections[ZshSection_Mods]))
|
||||
utilfn.CombineMaps(zshDecls, pvarMap)
|
||||
rtn.ShellVars = shellenv.SerializeDeclMap(zshDecls)
|
||||
return rtn, nil
|
||||
stats := &packet.ShellStateStats{
|
||||
Version: rtn.Version,
|
||||
AliasCount: int(len(aliasMap)),
|
||||
FuncCount: int(len(zshFuncs)),
|
||||
VarCount: int(len(zshDecls)),
|
||||
EnvCount: int(len(zshEnv)),
|
||||
HashVal: rtn.GetHashVal(false),
|
||||
OutputSize: int64(len(outputBytes)),
|
||||
StateSize: rtn.ApproximateSize(),
|
||||
}
|
||||
return rtn, stats, nil
|
||||
}
|
||||
|
||||
func parseZshEnv(output []byte) map[string]string {
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
@ -89,13 +88,14 @@ func MakeInstallCommandStr() string {
|
||||
type MShellBinaryReaderFn func(version string, goos string, goarch string) (io.ReadCloser, error)
|
||||
|
||||
type ReturnStateBuf struct {
|
||||
Lock *sync.Mutex
|
||||
Buf []byte
|
||||
Done bool
|
||||
Err error
|
||||
Reader *os.File
|
||||
FdNum int
|
||||
DoneCh chan bool
|
||||
Lock *sync.Mutex
|
||||
Buf []byte
|
||||
Done bool
|
||||
Err error
|
||||
Reader *os.File
|
||||
FdNum int
|
||||
EndBytes []byte
|
||||
DoneCh chan bool
|
||||
}
|
||||
|
||||
func MakeReturnStateBuf() *ReturnStateBuf {
|
||||
@ -835,7 +835,8 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
|
||||
cmd.ReturnState.FdNum = RtnStateFdNum
|
||||
rtnStateWriter = pw
|
||||
defer pw.Close()
|
||||
trapCmdStr := sapi.MakeExitTrap(cmd.ReturnState.FdNum)
|
||||
trapCmdStr, endBytes := sapi.MakeExitTrap(cmd.ReturnState.FdNum)
|
||||
cmd.ReturnState.EndBytes = endBytes
|
||||
rcFileStr += trapCmdStr
|
||||
}
|
||||
shellVarMap := shellenv.ShellVarMapFromState(state)
|
||||
@ -1021,6 +1022,11 @@ func (rs *ReturnStateBuf) Run() {
|
||||
}
|
||||
rs.Lock.Lock()
|
||||
rs.Buf = append(rs.Buf, buf[0:n]...)
|
||||
if bytes.HasSuffix(rs.Buf, rs.EndBytes) {
|
||||
rs.Buf = rs.Buf[:len(rs.Buf)-len(rs.EndBytes)]
|
||||
rs.Lock.Unlock()
|
||||
break
|
||||
}
|
||||
rs.Lock.Unlock()
|
||||
}
|
||||
}
|
||||
@ -1127,7 +1133,7 @@ func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType {
|
||||
wlog.Logf("debug returnstate file %q\n", base.GetDebugReturnStateFileName())
|
||||
os.WriteFile(base.GetDebugReturnStateFileName(), c.ReturnState.Buf, 0666)
|
||||
}
|
||||
state, _ := c.SAPI.ParseShellStateOutput(c.ReturnState.Buf) // TODO what to do with error?
|
||||
state, _, _ := c.SAPI.ParseShellStateOutput(c.ReturnState.Buf) // TODO what to do with error?
|
||||
donePacket.FinalState = state
|
||||
}
|
||||
endTs := time.Now()
|
||||
@ -1156,21 +1162,6 @@ func MakeInitPacket() *packet.InitPacketType {
|
||||
return initPacket
|
||||
}
|
||||
|
||||
func MakeShellStatePacket(shellType string) (*packet.ShellStatePacketType, error) {
|
||||
sapi, err := shellapi.MakeShellApi(shellType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rtnCh := sapi.GetShellState()
|
||||
ssOutput := <-rtnCh
|
||||
if ssOutput.Error != "" {
|
||||
return nil, errors.New(ssOutput.Error)
|
||||
}
|
||||
rtn := packet.MakeShellStatePacket()
|
||||
rtn.State = ssOutput.ShellState
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func MakeServerInitPacket() (*packet.InitPacketType, error) {
|
||||
var err error
|
||||
initPacket := MakeInitPacket()
|
||||
|
@ -10,7 +10,9 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
mathrand "math/rand"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
@ -552,3 +554,60 @@ func StrArrayToMap(sarr []string) map[string]bool {
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func AppendNonZeroRandomBytes(b []byte, randLen int) []byte {
|
||||
if randLen <= 0 {
|
||||
return b
|
||||
}
|
||||
numAdded := 0
|
||||
for numAdded < randLen {
|
||||
rn := mathrand.Intn(256)
|
||||
if rn > 0 && rn < 256 { // exclude 0, also helps to suppress security warning to have a guard here
|
||||
b = append(b, byte(rn))
|
||||
numAdded++
|
||||
}
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// returns (isEOF, error)
|
||||
func CopyWithEndBytes(outputBuf *bytes.Buffer, reader io.Reader, endBytes []byte) (bool, error) {
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := reader.Read(buf)
|
||||
if n > 0 {
|
||||
outputBuf.Write(buf[:n])
|
||||
obytes := outputBuf.Bytes()
|
||||
if bytes.HasSuffix(obytes, endBytes) {
|
||||
outputBuf.Truncate(len(obytes) - len(endBytes))
|
||||
return (err == io.EOF), nil
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
return true, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// does *not* close outputCh on EOF or error
|
||||
func CopyToChannel(outputCh chan<- []byte, reader io.Reader) error {
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := reader.Read(buf)
|
||||
if n > 0 {
|
||||
// copy so client can use []byte without it being overwritten
|
||||
bufCopy := make([]byte, n)
|
||||
copy(bufCopy, buf[:n])
|
||||
outputCh <- bufCopy
|
||||
}
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -41,6 +41,7 @@ import (
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/releasechecker"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/remote"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/remote/openai"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/rtnstate"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbus"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
|
||||
@ -1648,8 +1649,12 @@ func CopyFileCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scb
|
||||
}
|
||||
var outputPos int64
|
||||
outputStr := fmt.Sprintf("Copying [%v]:%v to [%v]:%v\r\n", sourceRemoteId.DisplayName, sourceFullPath, destRemoteId.DisplayName, destFullPath)
|
||||
termopts := sstore.TermOpts{Rows: shellutil.DefaultTermRows, Cols: shellutil.DefaultTermCols, FlexRows: true, MaxPtySize: remote.DefaultMaxPtySize}
|
||||
cmd, err := makeDynCmd(ctx, "copy file", ids, pk.GetRawStr(), termopts)
|
||||
termOpts, err := GetUITermOpts(pk.UIContext.WinSize, DefaultPTERM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot make termopts: %w", err)
|
||||
}
|
||||
pkTermOpts := convertTermOpts(termOpts)
|
||||
cmd, err := makeDynCmd(ctx, "copy file", ids, pk.GetRawStr(), *pkTermOpts)
|
||||
writeStringToPty(ctx, cmd, outputStr, &outputPos)
|
||||
if err != nil {
|
||||
// TODO tricky error since the command was a success, but we can't show the output
|
||||
@ -3655,11 +3660,14 @@ func SessionCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbu
|
||||
return update, nil
|
||||
}
|
||||
|
||||
func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbus.UpdatePacket, error) {
|
||||
func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (rtnUpdate scbus.UpdatePacket, rtnErr error) {
|
||||
ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen|R_Remote)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ids.Remote.MShell.IsConnected() {
|
||||
return nil, fmt.Errorf("cannot reinit, remote is not connected")
|
||||
}
|
||||
shellType := ids.Remote.ShellType
|
||||
if pk.Kwargs["shell"] != "" {
|
||||
shellArg := pk.Kwargs["shell"]
|
||||
@ -3668,33 +3676,76 @@ func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (
|
||||
}
|
||||
shellType = shellArg
|
||||
}
|
||||
ssPk, err := ids.Remote.MShell.ReInit(ctx, shellType)
|
||||
verbose := resolveBool(pk.Kwargs["verbose"], false)
|
||||
termOpts, err := GetUITermOpts(pk.UIContext.WinSize, DefaultPTERM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("cannot make termopts: %w", err)
|
||||
}
|
||||
if ssPk == nil || ssPk.State == nil {
|
||||
return nil, fmt.Errorf("invalid initpk received from remote (no remote state)")
|
||||
}
|
||||
feState := sstore.FeStateFromShellState(ssPk.State)
|
||||
remoteInst, err := sstore.UpdateRemoteState(ctx, ids.SessionId, ids.ScreenId, ids.Remote.RemotePtr, feState, ssPk.State, nil)
|
||||
pkTermOpts := convertTermOpts(termOpts)
|
||||
cmd, err := makeDynCmd(ctx, "reset", ids, pk.GetRawStr(), *pkTermOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outputStr := fmt.Sprintf("reset remote state (shell:%s)", ssPk.State.GetShellType())
|
||||
cmd, err := makeStaticCmd(ctx, "reset", ids, pk.GetRawStr(), []byte(outputStr))
|
||||
if err != nil {
|
||||
// TODO tricky error since the command was a success, but we can't show the output
|
||||
return nil, err
|
||||
}
|
||||
update, err := addLineForCmd(ctx, "/reset", false, ids, cmd, "", nil)
|
||||
if err != nil {
|
||||
// TODO tricky error since the command was a success, but we can't show the output
|
||||
return nil, err
|
||||
}
|
||||
update.AddUpdate(sstore.MakeSessionUpdateForRemote(ids.SessionId, remoteInst), sstore.InteractiveUpdate(pk.Interactive))
|
||||
go doResetCommand(ids, shellType, cmd, verbose)
|
||||
return update, nil
|
||||
}
|
||||
|
||||
func doResetCommand(ids resolvedIds, shellType string, cmd *sstore.CmdType, verbose bool) {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancelFn()
|
||||
startTime := time.Now()
|
||||
var outputPos int64
|
||||
var rtnErr error
|
||||
exitSuccess := true
|
||||
defer func() {
|
||||
if rtnErr != nil {
|
||||
exitSuccess = false
|
||||
writeStringToPty(ctx, cmd, fmt.Sprintf("\r\nerror: %v", rtnErr), &outputPos)
|
||||
}
|
||||
deferWriteCmdStatus(ctx, cmd, startTime, exitSuccess, outputPos)
|
||||
}()
|
||||
dataFn := func(data []byte) {
|
||||
writeStringToPty(ctx, cmd, string(data), &outputPos)
|
||||
}
|
||||
origStatePtr := ids.Remote.MShell.GetDefaultStatePtr(shellType)
|
||||
ssPk, err := ids.Remote.MShell.ReInit(ctx, shellType, dataFn, verbose)
|
||||
if err != nil {
|
||||
rtnErr = err
|
||||
return
|
||||
}
|
||||
if ssPk == nil || ssPk.State == nil {
|
||||
rtnErr = fmt.Errorf("invalid initpk received from remote (no remote state)")
|
||||
return
|
||||
}
|
||||
feState := sstore.FeStateFromShellState(ssPk.State)
|
||||
remoteInst, err := sstore.UpdateRemoteState(ctx, ids.SessionId, ids.ScreenId, ids.Remote.RemotePtr, feState, ssPk.State, nil)
|
||||
if err != nil {
|
||||
rtnErr = err
|
||||
return
|
||||
}
|
||||
newStatePtr := ids.Remote.MShell.GetDefaultStatePtr(shellType)
|
||||
if verbose && origStatePtr != nil && newStatePtr != nil {
|
||||
statePtrDiff := fmt.Sprintf("oldstate: %v, newstate: %v\r\n", origStatePtr.BaseHash, newStatePtr.BaseHash)
|
||||
writeStringToPty(ctx, cmd, statePtrDiff, &outputPos)
|
||||
origFullState, _ := sstore.GetFullState(ctx, *origStatePtr)
|
||||
newFullState, _ := sstore.GetFullState(ctx, *newStatePtr)
|
||||
if origFullState != nil && newFullState != nil {
|
||||
var diffBuf bytes.Buffer
|
||||
rtnstate.DisplayStateUpdateDiff(&diffBuf, *origFullState, *newFullState)
|
||||
diffStr := diffBuf.String()
|
||||
diffStr = strings.ReplaceAll(diffStr, "\n", "\r\n")
|
||||
writeStringToPty(ctx, cmd, diffStr, &outputPos)
|
||||
}
|
||||
}
|
||||
update := scbus.MakeUpdatePacket()
|
||||
update.AddUpdate(sstore.MakeSessionUpdateForRemote(ids.SessionId, remoteInst))
|
||||
scbus.MainUpdateBus.DoUpdate(update)
|
||||
}
|
||||
|
||||
func ResetCwdCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbus.UpdatePacket, error) {
|
||||
ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen|R_Remote)
|
||||
if err != nil {
|
||||
|
@ -196,7 +196,7 @@ func (msh *MShellProc) EnsureShellType(ctx context.Context, shellType string) er
|
||||
return nil
|
||||
}
|
||||
// try to reinit the shell
|
||||
_, err := msh.ReInit(ctx, shellType)
|
||||
_, err := msh.ReInit(ctx, shellType, nil, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error trying to initialize shell %q: %v", shellType, err)
|
||||
}
|
||||
@ -1401,33 +1401,60 @@ func makeReinitErrorUpdate(shellType string) sstore.ActivityUpdate {
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (msh *MShellProc) ReInit(ctx context.Context, shellType string) (*packet.ShellStatePacketType, error) {
|
||||
func (msh *MShellProc) ReInit(ctx context.Context, shellType string, dataFn func([]byte), verbose bool) (rtnPk *packet.ShellStatePacketType, rtnErr error) {
|
||||
if !msh.IsConnected() {
|
||||
return nil, fmt.Errorf("cannot reinit, remote is not connected")
|
||||
}
|
||||
if shellType != packet.ShellType_bash && shellType != packet.ShellType_zsh {
|
||||
return nil, fmt.Errorf("invalid shell type %q", shellType)
|
||||
}
|
||||
if dataFn == nil {
|
||||
dataFn = func([]byte) {}
|
||||
}
|
||||
defer func() {
|
||||
if rtnErr != nil {
|
||||
sstore.UpdateActivityWrap(ctx, makeReinitErrorUpdate(shellType), "reiniterror")
|
||||
}
|
||||
}()
|
||||
startTs := time.Now()
|
||||
reinitPk := packet.MakeReInitPacket()
|
||||
reinitPk.ReqId = uuid.New().String()
|
||||
reinitPk.ShellType = shellType
|
||||
resp, err := msh.PacketRpcRaw(ctx, reinitPk)
|
||||
rpcIter, err := msh.PacketRpcIter(ctx, reinitPk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("no response")
|
||||
}
|
||||
ssPk, ok := resp.(*packet.ShellStatePacketType)
|
||||
if !ok {
|
||||
sstore.UpdateActivityWrap(ctx, makeReinitErrorUpdate(shellType), "reiniterror")
|
||||
if respPk, ok := resp.(*packet.ResponsePacketType); ok && respPk.Error != "" {
|
||||
return nil, fmt.Errorf("error reinitializing remote: %s", respPk.Error)
|
||||
defer rpcIter.Close()
|
||||
var ssPk *packet.ShellStatePacketType
|
||||
for {
|
||||
resp, err := rpcIter.Next(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, fmt.Errorf("invalid reinit response (not an shellstate packet): %T", resp)
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("channel closed with no response")
|
||||
}
|
||||
var ok bool
|
||||
ssPk, ok = resp.(*packet.ShellStatePacketType)
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
respPk, ok := resp.(*packet.ResponsePacketType)
|
||||
if ok {
|
||||
if respPk.Error != "" {
|
||||
return nil, fmt.Errorf("error reinitializing remote: %s", respPk.Error)
|
||||
}
|
||||
return nil, fmt.Errorf("invalid response from waveshell")
|
||||
}
|
||||
dataPk, ok := resp.(*packet.FileDataPacketType)
|
||||
if ok {
|
||||
dataFn(dataPk.Data)
|
||||
continue
|
||||
}
|
||||
invalidPkStr := fmt.Sprintf("\r\ninvalid packettype from waveshell: %s\r\n", resp.GetType())
|
||||
dataFn([]byte(invalidPkStr))
|
||||
}
|
||||
if ssPk.State == nil {
|
||||
sstore.UpdateActivityWrap(ctx, makeReinitErrorUpdate(shellType), "reiniterror")
|
||||
if ssPk == nil || ssPk.State == nil {
|
||||
return nil, fmt.Errorf("invalid reinit response shellstate packet does not contain remote state")
|
||||
}
|
||||
// TODO: maybe we don't need to save statebase here. should be possible to save it on demand
|
||||
@ -1438,10 +1465,29 @@ func (msh *MShellProc) ReInit(ctx context.Context, shellType string) (*packet.Sh
|
||||
return nil, fmt.Errorf("error storing remote state: %w", err)
|
||||
}
|
||||
msh.StateMap.SetCurrentState(ssPk.State.GetShellType(), ssPk.State)
|
||||
msh.WriteToPtyBuffer("initialized shell:%s state:%s\n", shellType, ssPk.State.GetHashVal(false))
|
||||
timeDur := time.Since(startTs)
|
||||
dataFn([]byte(makeShellInitOutputMsg(verbose, ssPk.State, ssPk.Stats, timeDur, false)))
|
||||
msh.WriteToPtyBuffer("%s", makeShellInitOutputMsg(false, ssPk.State, ssPk.Stats, timeDur, true))
|
||||
return ssPk, nil
|
||||
}
|
||||
|
||||
func makeShellInitOutputMsg(verbose bool, state *packet.ShellState, stats *packet.ShellStateStats, dur time.Duration, ptyMsg bool) string {
|
||||
if !verbose || ptyMsg {
|
||||
if ptyMsg {
|
||||
return fmt.Sprintf("initialized state shell:%s statehash:%s %dms\n", state.GetShellType(), state.GetHashVal(false), dur.Milliseconds())
|
||||
} else {
|
||||
return fmt.Sprintf("initialized connection state (shell:%s)\r\n", state.GetShellType())
|
||||
}
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("-----\r\n")
|
||||
buf.WriteString(fmt.Sprintf("initialized connection shell:%s statehash:%s %dms\r\n", state.GetShellType(), state.GetHashVal(false), dur.Milliseconds()))
|
||||
if stats != nil {
|
||||
buf.WriteString(fmt.Sprintf(" outsize:%s size:%s env:%d, vars:%d, aliases:%d, funcs:%d\r\n", scbase.NumFormatDec(stats.OutputSize), scbase.NumFormatDec(stats.StateSize), stats.EnvCount, stats.VarCount, stats.AliasCount, stats.FuncCount))
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (msh *MShellProc) WriteFile(ctx context.Context, writePk *packet.WriteFilePacketType) (*packet.RpcResponseIter, error) {
|
||||
return msh.PacketRpcIter(ctx, writePk)
|
||||
}
|
||||
@ -1690,7 +1736,7 @@ func (msh *MShellProc) initActiveShells() {
|
||||
return
|
||||
}
|
||||
for _, shellType := range activeShells {
|
||||
_, err = msh.ReInit(ctx, shellType)
|
||||
_, err = msh.ReInit(ctx, shellType, nil, false)
|
||||
if err != nil {
|
||||
msh.WriteToPtyBuffer("*error reiniting shell %q: %v\n", shellType, err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user