zsh fixes (zmodload, k8s, new var type, track zsh prompt) (#473)

* checkpoint.  track zmodloads, get the expanded PS1 value, change how we store and deal with 'prompt vars', add commands to track k8s context

* fix some warnings, change IsPVar to IsExtVar, checkpoint

* parse zmodloads into state

* restore zmodloads, be more careful around readonly variable setting to avoid errors
This commit is contained in:
Mike Sawka 2024-03-18 22:51:16 -07:00 committed by GitHub
parent d2f5d87194
commit 901dcccaa5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 215 additions and 45 deletions

View File

@ -113,6 +113,9 @@ func (b bashShellApi) MakeRcFileStr(pk *packet.RunPacketType) string {
if varDecl.IsExport() || varDecl.IsReadOnly() { if varDecl.IsExport() || varDecl.IsReadOnly() {
continue continue
} }
if varDecl.IsExtVar {
continue
}
rcBuf.WriteString(BashDeclareStmt(varDecl)) rcBuf.WriteString(BashDeclareStmt(varDecl))
rcBuf.WriteString("\n") rcBuf.WriteString("\n")
} }

View File

@ -205,7 +205,7 @@ func bashParseDeclareOutput(state *packet.ShellState, declareBytes []byte, pvarB
declMap[decl.Name] = decl declMap[decl.Name] = decl
} }
} }
pvarMap := parsePVarOutput(pvarBytes, false) pvarMap := parseExtVarOutput(pvarBytes, "", "")
utilfn.CombineMaps(declMap, pvarMap) utilfn.CombineMaps(declMap, pvarMap)
state.ShellVars = shellenv.SerializeDeclMap(declMap) // this writes out the decls in a canonical order state.ShellVars = shellenv.SerializeDeclMap(declMap) // this writes out the decls in a canonical order
if firstParseErr != nil { if firstParseErr != nil {

View File

@ -25,10 +25,13 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/base" "github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet" "github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellutil" "github.com/wavetermdev/waveterm/waveshell/pkg/shellutil"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
) )
const GetStateTimeout = 5 * time.Second const GetStateTimeout = 5 * time.Second
const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"` const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"`
const GetK8sContextCmdStr = `printf "K8SCONTEXT %s\x00" "$(kubectl config current-context 2>/dev/null)"`
const GetK8sNamespaceCmdStr = `printf "K8SNAMESPACE %s\x00" "$(kubectl config view --minify --output 'jsonpath={..namespace}' 2>/dev/null)"`
const RunCommandFmt = `%s` const RunCommandFmt = `%s`
const DebugState = false const DebugState = false
@ -239,7 +242,7 @@ func RunSimpleCmdInPty(ecmd *exec.Cmd) ([]byte, error) {
return outputBuf.Bytes(), nil return outputBuf.Bytes(), nil
} }
func parsePVarOutput(pvarBytes []byte, isZsh bool) map[string]*DeclareDeclType { func parseExtVarOutput(pvarBytes []byte, promptOutput string, zmodsOutput string) map[string]*DeclareDeclType {
declMap := make(map[string]*DeclareDeclType) declMap := make(map[string]*DeclareDeclType)
pvars := bytes.Split(pvarBytes, []byte{0}) pvars := bytes.Split(pvarBytes, []byte{0})
for _, pvarBA := range pvars { for _, pvarBA := range pvars {
@ -251,11 +254,35 @@ func parsePVarOutput(pvarBytes []byte, isZsh bool) map[string]*DeclareDeclType {
if pvarFields[0] == "" { if pvarFields[0] == "" {
continue continue
} }
decl := &DeclareDeclType{IsZshDecl: isZsh, Args: "x"} if pvarFields[1] == "" {
continue
}
decl := &DeclareDeclType{IsExtVar: true}
decl.Name = "PROMPTVAR_" + pvarFields[0] decl.Name = "PROMPTVAR_" + pvarFields[0]
decl.Value = shellescape.Quote(pvarFields[1]) decl.Value = shellescape.Quote(pvarFields[1])
declMap[decl.Name] = decl declMap[decl.Name] = decl
} }
if promptOutput != "" {
decl := &DeclareDeclType{IsExtVar: true}
decl.Name = "PROMPTVAR_PS1"
decl.Value = promptOutput
declMap[decl.Name] = decl
}
if zmodsOutput != "" {
var zmods []string
lines := strings.Split(zmodsOutput, "\n")
for _, line := range lines {
fields := strings.Fields(line)
if len(fields) != 2 || fields[0] != "zmodload" {
continue
}
zmods = append(zmods, fields[1])
}
decl := &DeclareDeclType{IsExtVar: true}
decl.Name = ZModsVarName
decl.Value = utilfn.QuickJson(zmods)
declMap[decl.Name] = decl
}
return declMap return declMap
} }

View File

@ -30,6 +30,21 @@ const BaseZshOpts = ``
const ZshShellVersionCmdStr = `echo zsh v$ZSH_VERSION` const ZshShellVersionCmdStr = `echo zsh v$ZSH_VERSION`
const StateOutputFdNum = 20 const StateOutputFdNum = 20
const (
ZshSection_Version = iota
ZshSection_Cwd
ZshSection_Env
ZshSection_Mods
ZshSection_Vars
ZshSection_Aliases
ZshSection_Fpath
ZshSection_Funcs
ZshSection_PVars
ZshSection_Prompt
ZshSection_NumFieldsExpected // must be last
)
// TODO these need updating // TODO these need updating
const RunZshSudoCommandFmt = `sudo -n -C %d zsh /dev/fd/%d` const RunZshSudoCommandFmt = `sudo -n -C %d zsh /dev/fd/%d`
const RunZshSudoPasswordCommandFmt = `cat /dev/fd/%d | sudo -k -S -C %d zsh -c "echo '[from-mshell]'; exec %d>&-; zsh /dev/fd/%d < /dev/fd/%d"` const RunZshSudoPasswordCommandFmt = `cat /dev/fd/%d | sudo -k -S -C %d zsh -c "echo '[from-mshell]'; exec %d>&-; zsh /dev/fd/%d < /dev/fd/%d"`
@ -54,6 +69,7 @@ var ZshIgnoreVars = map[string]bool{
"SHLVL": true, "SHLVL": true,
"TTY": true, "TTY": true,
"ZDOTDIR": true, "ZDOTDIR": true,
"PPID": true,
"epochtime": true, "epochtime": true,
"langinfo": true, "langinfo": true,
"keymaps": true, "keymaps": true,
@ -77,6 +93,8 @@ var ZshIgnoreVars = map[string]bool{
"funcsourcetrace": true, "funcsourcetrace": true,
"funcstack": true, "funcstack": true,
"functrace": true, "functrace": true,
"nameddirs": true,
"userdirs": true,
"parameters": true, "parameters": true,
"commands": true, "commands": true,
"functions": true, "functions": true,
@ -86,6 +104,25 @@ var ZshIgnoreVars = map[string]bool{
"_comps": true, "_comps": true,
"_patcomps": true, "_patcomps": true,
"_postpatcomps": true, "_postpatcomps": true,
// zsh/system
"errnos": true,
"sysparams": true,
// zsh/curses
"ZCURSES_COLORS": true,
"ZCURSES_COLOR_PAIRS": true,
"zcurses_attrs": true,
"zcurses_colors": true,
"zcurses_keycodes": true,
"zcurses_windows": true,
// not listed, but we also exclude all ZFTP_* variables
}
var ZshIgnoreFuncs = map[string]bool{
"zftp_chpwd": true,
"zftp_progress": true,
} }
// only options we restore (other than ZshForceOptions) // only options we restore (other than ZshForceOptions)
@ -131,11 +168,13 @@ var ZshUnsetVars = []string{
"ZSH_EXECUTION_STRING", "ZSH_EXECUTION_STRING",
} }
var ZshLoadMods = []string{ var ZshForceLoadMods = map[string]bool{
"zsh/parameter", "zsh/parameter": true,
"zsh/langinfo", "zsh/langinfo": true,
} }
const ZModsVarName = "WAVESTATE_ZMODS"
// do not use these directly, call GetLocalMajorVersion() // do not use these directly, call GetLocalMajorVersion()
var localZshMajorVersionOnce = &sync.Once{} var localZshMajorVersionOnce = &sync.Once{}
var localZshMajorVersion = "" var localZshMajorVersion = ""
@ -273,14 +312,29 @@ func (z zshShellApi) MakeRcFileStr(pk *packet.RunPacketType) string {
rcBuf.WriteString(fmt.Sprintf("unsetopt %s\n", optName)) rcBuf.WriteString(fmt.Sprintf("unsetopt %s\n", optName))
} }
} }
for _, modName := range ZshLoadMods { for modName := range ZshForceLoadMods {
rcBuf.WriteString(fmt.Sprintf("zmodload %s\n", modName)) rcBuf.WriteString(fmt.Sprintf("zmodload %s\n", modName))
} }
modDecl := getDeclByName(varDecls, ZModsVarName)
if modDecl != nil {
modsArr := utilfn.QuickParseJson[[]string](modDecl.Value)
for _, modName := range modsArr {
if !ZshForceLoadMods[modName] {
rcBuf.WriteString(fmt.Sprintf("zmodload %s\n", modName))
}
}
}
var postDecls []*shellenv.DeclareDeclType var postDecls []*shellenv.DeclareDeclType
for _, varDecl := range varDecls { for _, varDecl := range varDecls {
if ZshIgnoreVars[varDecl.Name] { if ZshIgnoreVars[varDecl.Name] {
continue continue
} }
if strings.HasPrefix(varDecl.Name, "ZFTP_") {
continue
}
if varDecl.IsExtVar {
continue
}
if ZshUniqueArrayVars[varDecl.Name] && !varDecl.IsUniqueArray() { if ZshUniqueArrayVars[varDecl.Name] && !varDecl.IsUniqueArray() {
varDecl.AddFlag("U") varDecl.AddFlag("U")
} }
@ -332,6 +386,9 @@ func (z zshShellApi) MakeRcFileStr(pk *packet.RunPacketType) string {
rcBuf.WriteString("# error decoding zsh functions\n") rcBuf.WriteString("# error decoding zsh functions\n")
} else { } else {
for fnKey, fnValue := range fnMap { for fnKey, fnValue := range fnMap {
if ZshIgnoreFuncs[fnKey.ParamName] {
continue
}
if fnValue == ZshFnAutoLoad { if fnValue == ZshFnAutoLoad {
rcBuf.WriteString(fmt.Sprintf("autoload %s\n", shellescape.Quote(fnKey.ParamName))) rcBuf.WriteString(fmt.Sprintf("autoload %s\n", shellescape.Quote(fnKey.ParamName)))
} else { } else {
@ -411,6 +468,8 @@ pwd;
printf "[%SECTIONSEP%]"; printf "[%SECTIONSEP%]";
env -0; env -0;
printf "[%SECTIONSEP%]"; printf "[%SECTIONSEP%]";
zmodload -L
printf "[%SECTIONSEP%]";
typeset -p +H -m '*'; typeset -p +H -m '*';
printf "[%SECTIONSEP%]"; printf "[%SECTIONSEP%]";
for var in "${(@k)aliases}"; do for var in "${(@k)aliases}"; do
@ -448,10 +507,16 @@ for var in "${(@k)dis_functions_source}"; do
done done
printf "[%SECTIONSEP%]"; printf "[%SECTIONSEP%]";
[%GITBRANCH%] [%GITBRANCH%]
[%K8SCONTEXT%]
[%K8SNAMESPACE%]
printf "[%SECTIONSEP%]";
print -P "$PS1"
` `
cmd = strings.TrimSpace(cmd) cmd = strings.TrimSpace(cmd)
cmd = strings.ReplaceAll(cmd, "[%ZSHVERSION%]", ZshShellVersionCmdStr) cmd = strings.ReplaceAll(cmd, "[%ZSHVERSION%]", ZshShellVersionCmdStr)
cmd = strings.ReplaceAll(cmd, "[%GITBRANCH%]", GetGitBranchCmdStr) cmd = strings.ReplaceAll(cmd, "[%GITBRANCH%]", GetGitBranchCmdStr)
cmd = strings.ReplaceAll(cmd, "[%K8SCONTEXT%]", GetK8sContextCmdStr)
cmd = strings.ReplaceAll(cmd, "[%K8SNAMESPACE%]", GetK8sNamespaceCmdStr)
cmd = strings.ReplaceAll(cmd, "[%PARTSEP%]", utilfn.ShellHexEscape(string(sectionSeparator[0:len(sectionSeparator)-1]))) cmd = strings.ReplaceAll(cmd, "[%PARTSEP%]", utilfn.ShellHexEscape(string(sectionSeparator[0:len(sectionSeparator)-1])))
cmd = strings.ReplaceAll(cmd, "[%SECTIONSEP%]", utilfn.ShellHexEscape(string(sectionSeparator))) cmd = strings.ReplaceAll(cmd, "[%SECTIONSEP%]", utilfn.ShellHexEscape(string(sectionSeparator)))
cmd = strings.ReplaceAll(cmd, "[%OUTPUTFD%]", fmt.Sprintf("/dev/fd/%d", fdNum)) cmd = strings.ReplaceAll(cmd, "[%OUTPUTFD%]", fmt.Sprintf("/dev/fd/%d", fdNum))
@ -599,6 +664,9 @@ func ParseZshFunctions(fpathArr []string, fnBytes []byte, partSeparator []byte)
if fnName == "zshexit" { if fnName == "zshexit" {
continue continue
} }
if ZshIgnoreFuncs[fnName] {
continue
}
if fnType == "functions" || fnType == "dis_functions" { if fnType == "functions" || fnType == "dis_functions" {
fnBody[ZshParamKey{ParamType: fnType, ParamName: fnName}] = fnValue fnBody[ZshParamKey{ParamType: fnType, ParamName: fnName}] = fnValue
} }
@ -609,10 +677,13 @@ func ParseZshFunctions(fpathArr []string, fnBytes []byte, partSeparator []byte)
// ok, so the trick here is that we want to only include functions that are *not* autoloaded // ok, so the trick here is that we want to only include functions that are *not* autoloaded
// the ones that are pending autoloading or come from a source file in fpath, can just be set to autoload // the ones that are pending autoloading or come from a source file in fpath, can just be set to autoload
for fnKey := range fnBody { for fnKey := range fnBody {
var inFpath bool
source := fnSource[fnKey.ParamName] source := fnSource[fnKey.ParamName]
if isSourceFileInFpath(fpathArr, source) { if source != "" {
fnBody[fnKey] = ZshFnAutoLoad inFpath = isSourceFileInFpath(fpathArr, source)
} else if strings.TrimSpace(fnBody[fnKey]) == ZshAutoloadFnBody { }
isAutoloadFnBody := strings.TrimSpace(fnBody[fnKey]) == ZshAutoloadFnBody
if inFpath || isAutoloadFnBody {
fnBody[fnKey] = ZshFnAutoLoad fnBody[fnKey] = ZshFnAutoLoad
} }
} }
@ -639,11 +710,11 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta
versionStr := string(outputBytes[0:firstZeroIdx]) versionStr := string(outputBytes[0:firstZeroIdx])
sectionSeparator := outputBytes[firstZeroIdx+1 : firstDZeroIdx+2] sectionSeparator := outputBytes[firstZeroIdx+1 : firstDZeroIdx+2]
partSeparator := sectionSeparator[0 : len(sectionSeparator)-1] partSeparator := sectionSeparator[0 : len(sectionSeparator)-1]
// 8 fields: version [0], cwd [1], env [2], vars [3], aliases [4], fpath [5], functions [6], pvars [7] // sections: see ZshSection_* consts
fields := bytes.Split(outputBytes, sectionSeparator) sections := bytes.Split(outputBytes, sectionSeparator)
if len(fields) != 8 { if len(sections) != ZshSection_NumFieldsExpected {
base.Logf("invalid -- numfields\n") base.Logf("invalid -- numfields\n")
return nil, fmt.Errorf("invalid zsh shell state output, wrong number of fields, fields=%d", len(fields)) return nil, fmt.Errorf("invalid zsh shell state output, wrong number of sections, section=%d", len(sections))
} }
rtn := &packet.ShellState{} rtn := &packet.ShellState{}
rtn.Version = strings.TrimSpace(versionStr) rtn.Version = strings.TrimSpace(versionStr)
@ -653,10 +724,10 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta
if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil { if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil {
return nil, fmt.Errorf("invalid zsh shell state output, invalid version: %v", err) return nil, fmt.Errorf("invalid zsh shell state output, invalid version: %v", err)
} }
cwdStr := stripNewLineChars(string(fields[1])) cwdStr := stripNewLineChars(string(sections[ZshSection_Cwd]))
rtn.Cwd = cwdStr rtn.Cwd = cwdStr
zshEnv := parseZshEnv(fields[2]) zshEnv := parseZshEnv(sections[ZshSection_Env])
zshDecls, err := parseZshDecls(fields[3]) zshDecls, err := parseZshDecls(sections[ZshSection_Vars])
if err != nil { if err != nil {
base.Logf("invalid - parsedecls %v\n", err) base.Logf("invalid - parsedecls %v\n", err)
return nil, err return nil, err
@ -666,16 +737,15 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta
decl.ZshEnvValue = zshEnv[decl.ZshBoundScalar] decl.ZshEnvValue = zshEnv[decl.ZshBoundScalar]
} }
} }
aliasMap := parseZshAliasStateOutput(fields[4], partSeparator) aliasMap := parseZshAliasStateOutput(sections[ZshSection_Aliases], partSeparator)
rtn.Aliases = string(EncodeZshMap(aliasMap)) rtn.Aliases = string(EncodeZshMap(aliasMap))
fpathStr := stripNewLineChars(string(string(fields[5]))) fpathStr := stripNewLineChars(string(string(sections[ZshSection_Fpath])))
fpathArr := strings.Split(fpathStr, ":") fpathArr := strings.Split(fpathStr, ":")
zshFuncs := ParseZshFunctions(fpathArr, fields[6], partSeparator) zshFuncs := ParseZshFunctions(fpathArr, sections[ZshSection_Funcs], partSeparator)
rtn.Funcs = string(EncodeZshMap(zshFuncs)) rtn.Funcs = string(EncodeZshMap(zshFuncs))
pvarMap := parsePVarOutput(fields[7], true) pvarMap := parseExtVarOutput(sections[ZshSection_PVars], string(sections[ZshSection_Prompt]), string(sections[ZshSection_Mods]))
utilfn.CombineMaps(zshDecls, pvarMap) utilfn.CombineMaps(zshDecls, pvarMap)
rtn.ShellVars = shellenv.SerializeDeclMap(zshDecls) rtn.ShellVars = shellenv.SerializeDeclMap(zshDecls)
base.Logf("parse shellstate done\n")
return rtn, nil return rtn, nil
} }

View File

@ -22,6 +22,7 @@ const (
type DeclareDeclType struct { type DeclareDeclType struct {
IsZshDecl bool IsZshDecl bool
IsExtVar bool // set for "special" wave internal variables
Args string Args string
Name string Name string
@ -36,31 +37,31 @@ type DeclareDeclType struct {
} }
func (d *DeclareDeclType) IsExport() bool { func (d *DeclareDeclType) IsExport() bool {
return strings.Index(d.Args, "x") >= 0 return strings.Contains(d.Args, "x")
} }
func (d *DeclareDeclType) IsReadOnly() bool { func (d *DeclareDeclType) IsReadOnly() bool {
return strings.Index(d.Args, "r") >= 0 return strings.Contains(d.Args, "r")
} }
func (d *DeclareDeclType) IsZshScalarBound() bool { func (d *DeclareDeclType) IsZshScalarBound() bool {
return strings.Index(d.Args, "T") >= 0 return strings.Contains(d.Args, "T")
} }
func (d *DeclareDeclType) IsArray() bool { func (d *DeclareDeclType) IsArray() bool {
return strings.Index(d.Args, "a") >= 0 return strings.Contains(d.Args, "a")
} }
func (d *DeclareDeclType) IsAssocArray() bool { func (d *DeclareDeclType) IsAssocArray() bool {
return strings.Index(d.Args, "A") >= 0 return strings.Contains(d.Args, "A")
} }
func (d *DeclareDeclType) IsUniqueArray() bool { func (d *DeclareDeclType) IsUniqueArray() bool {
return d.IsArray() && strings.Index(d.Args, "U") >= 0 return d.IsArray() && strings.Contains(d.Args, "U")
} }
func (d *DeclareDeclType) AddFlag(flag string) { func (d *DeclareDeclType) AddFlag(flag string) {
if strings.Index(d.Args, flag) >= 0 { if strings.Contains(d.Args, flag) {
return return
} }
d.Args += flag d.Args += flag
@ -101,13 +102,13 @@ func (d *DeclareDeclType) SortZshFlags() {
} }
func (d *DeclareDeclType) DataType() string { func (d *DeclareDeclType) DataType() string {
if strings.Index(d.Args, "a") >= 0 { if strings.Contains(d.Args, "a") {
return DeclTypeArray return DeclTypeArray
} }
if strings.Index(d.Args, "A") >= 0 { if strings.Contains(d.Args, "A") {
return DeclTypeAssocArray return DeclTypeAssocArray
} }
if strings.Index(d.Args, "i") >= 0 { if strings.Contains(d.Args, "i") {
return DeclTypeInt return DeclTypeInt
} }
return DeclTypeNormal return DeclTypeNormal
@ -124,7 +125,15 @@ func FindVarDecl(decls []*DeclareDeclType, name string) *DeclareDeclType {
// NOTE Serialize no longer writes the final null byte // NOTE Serialize no longer writes the final null byte
func (d *DeclareDeclType) Serialize() []byte { func (d *DeclareDeclType) Serialize() []byte {
if d.IsZshDecl { if d.IsExtVar {
parts := []string{
"e1",
d.Args,
d.Name,
d.Value,
}
return utilfn.EncodeStringArray(parts)
} else if d.IsZshDecl {
d.SortZshFlags() d.SortZshFlags()
parts := []string{ parts := []string{
"z1", "z1",
@ -149,6 +158,15 @@ func (d *DeclareDeclType) Serialize() []byte {
// return []byte(rtn) // return []byte(rtn)
} }
func (d *DeclareDeclType) UnescapedValue() string {
if d.IsExtVar {
return d.Value
}
ectx := simpleexpand.SimpleExpandContext{}
rtn, _ := simpleexpand.SimpleExpandPartialWord(ectx, d.Value, false)
return rtn
}
func DeclsEqual(compareName bool, d1 *DeclareDeclType, d2 *DeclareDeclType) bool { func DeclsEqual(compareName bool, d1 *DeclareDeclType, d2 *DeclareDeclType) bool {
if d1.IsExport() != d2.IsExport() { if d1.IsExport() != d2.IsExport() {
return false return false
@ -164,7 +182,8 @@ func DeclsEqual(compareName bool, d1 *DeclareDeclType, d2 *DeclareDeclType) bool
// envline should be valid // envline should be valid
func parseDeclLine(envLineBytes []byte) *DeclareDeclType { func parseDeclLine(envLineBytes []byte) *DeclareDeclType {
if utilfn.EncodedStringArrayHasFirstKey(envLineBytes, "z1") { esFirstVal := utilfn.EncodedStringArrayGetFirstVal(envLineBytes)
if esFirstVal == "z1" {
parts, err := utilfn.DecodeStringArray(envLineBytes) parts, err := utilfn.DecodeStringArray(envLineBytes)
if err != nil { if err != nil {
return nil return nil
@ -180,7 +199,7 @@ func parseDeclLine(envLineBytes []byte) *DeclareDeclType {
ZshBoundScalar: parts[4], ZshBoundScalar: parts[4],
ZshEnvValue: parts[5], ZshEnvValue: parts[5],
} }
} else if utilfn.EncodedStringArrayHasFirstKey(envLineBytes, "b1") { } else if esFirstVal == "b1" {
parts, err := utilfn.DecodeStringArray(envLineBytes) parts, err := utilfn.DecodeStringArray(envLineBytes)
if err != nil { if err != nil {
return nil return nil
@ -193,8 +212,25 @@ func parseDeclLine(envLineBytes []byte) *DeclareDeclType {
Name: parts[2], Name: parts[2],
Value: parts[3], Value: parts[3],
} }
} else if esFirstVal == "e1" {
parts, err := utilfn.DecodeStringArray(envLineBytes)
if err != nil {
return nil
}
if len(parts) != 4 {
return nil
}
return &DeclareDeclType{
IsExtVar: true,
Args: parts[1],
Name: parts[2],
Value: parts[3],
}
} else if esFirstVal == "p1" {
// deprecated
return nil
} }
// legacy decoding (v0) // legacy decoding (v0) (not an encoded string array)
envLine := string(envLineBytes) envLine := string(envLineBytes)
eqIdx := strings.Index(envLine, "=") eqIdx := strings.Index(envLine, "=")
if eqIdx == -1 { if eqIdx == -1 {

View File

@ -7,6 +7,7 @@ import (
"bytes" "bytes"
"crypto/sha1" "crypto/sha1"
"encoding/base64" "encoding/base64"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"math" "math"
@ -400,7 +401,7 @@ func DecodeStringArray(barr []byte) ([]string, error) {
return rtn, nil return rtn, nil
} }
func EncodedStringArrayHasFirstKey(encoded []byte, firstKey string) bool { func EncodedStringArrayHasFirstVal(encoded []byte, firstKey string) bool {
firstKeyBytes := NullEncodeStr(firstKey) firstKeyBytes := NullEncodeStr(firstKey)
if !bytes.HasPrefix(encoded, firstKeyBytes) { if !bytes.HasPrefix(encoded, firstKeyBytes) {
return false return false
@ -411,6 +412,18 @@ func EncodedStringArrayHasFirstKey(encoded []byte, firstKey string) bool {
return false return false
} }
// on encoding error returns ""
// this is used to perform logic on first value without decoding the entire array
func EncodedStringArrayGetFirstVal(encoded []byte) string {
sepIdx := bytes.IndexByte(encoded, nullEncodeSepByte)
if sepIdx == -1 {
str, _ := NullDecodeStr(encoded)
return str
}
str, _ := NullDecodeStr(encoded[0:sepIdx])
return str
}
// encodes a string, removing null/zero bytes (and separators '|') // encodes a string, removing null/zero bytes (and separators '|')
// a zero byte is encoded as "\0", a '\' is encoded as "\\", sep is encoded as "\s" // a zero byte is encoded as "\0", a '\' is encoded as "\\", sep is encoded as "\s"
// allows for easy double splitting (first on \x00, and next on "|") // allows for easy double splitting (first on \x00, and next on "|")
@ -520,3 +533,22 @@ func CombineStrArrays(sarr1 []string, sarr2 []string) []string {
} }
return rtn return rtn
} }
func QuickJson(v interface{}) string {
barr, _ := json.Marshal(v)
return string(barr)
}
func QuickParseJson[T any](s string) T {
var v T
_ = json.Unmarshal([]byte(s), &v)
return v
}
func StrArrayToMap(sarr []string) map[string]bool {
m := make(map[string]bool)
for _, s := range sarr {
m[s] = true
}
return m
}

View File

@ -214,6 +214,7 @@ var literalRtnStateCommands = []string{
"enable", "enable",
"disable", "disable",
"function", "function",
"zmodload",
} }
func getCallExprLitArg(callExpr *syntax.CallExpr, argNum int) string { func getCallExprLitArg(callExpr *syntax.CallExpr, argNum int) string {

View File

@ -751,16 +751,17 @@ func FeStateFromShellState(state *packet.ShellState) map[string]string {
} }
rtn := make(map[string]string) rtn := make(map[string]string)
rtn["cwd"] = state.Cwd rtn["cwd"] = state.Cwd
envMap := shellenv.EnvMapFromState(state) declMap := shellenv.DeclMapFromState(state)
if envMap["VIRTUAL_ENV"] != "" { if decl, ok := declMap["VIRTUAL_ENV"]; ok {
rtn["VIRTUAL_ENV"] = envMap["VIRTUAL_ENV"] rtn["VIRTUAL_ENV"] = decl.UnescapedValue()
} }
if envMap["CONDA_DEFAULT_ENV"] != "" { if decl, ok := declMap["CONDA_DEFAULT_ENV"]; ok {
rtn["CONDA_DEFAULT_ENV"] = envMap["CONDA_DEFAULT_ENV"] rtn["CONDA_DEFAULT_ENV"] = decl.UnescapedValue()
} }
for key, val := range envMap { for _, decl := range declMap {
if strings.HasPrefix(key, "PROMPTVAR_") && envMap[key] != "" { // works for both legacy and new IsExtVar decls
rtn[key] = val if strings.HasPrefix(decl.Name, "PROMPTVAR_") {
rtn[decl.Name] = decl.UnescapedValue()
} }
} }
_, _, err := packet.ParseShellStateVersion(state.Version) _, _, err := packet.ParseShellStateVersion(state.Version)