zsh fixes (zmodload, k8s, new var type, track zsh prompt) (#473)

* checkpoint.  track zmodloads, get the expanded PS1 value, change how we store and deal with 'prompt vars', add commands to track k8s context

* fix some warnings, change IsPVar to IsExtVar, checkpoint

* parse zmodloads into state

* restore zmodloads, be more careful around readonly variable setting to avoid errors
This commit is contained in:
Mike Sawka 2024-03-18 22:51:16 -07:00 committed by GitHub
parent d2f5d87194
commit 901dcccaa5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 215 additions and 45 deletions

View File

@ -113,6 +113,9 @@ func (b bashShellApi) MakeRcFileStr(pk *packet.RunPacketType) string {
if varDecl.IsExport() || varDecl.IsReadOnly() {
continue
}
if varDecl.IsExtVar {
continue
}
rcBuf.WriteString(BashDeclareStmt(varDecl))
rcBuf.WriteString("\n")
}

View File

@ -205,7 +205,7 @@ func bashParseDeclareOutput(state *packet.ShellState, declareBytes []byte, pvarB
declMap[decl.Name] = decl
}
}
pvarMap := parsePVarOutput(pvarBytes, false)
pvarMap := parseExtVarOutput(pvarBytes, "", "")
utilfn.CombineMaps(declMap, pvarMap)
state.ShellVars = shellenv.SerializeDeclMap(declMap) // this writes out the decls in a canonical order
if firstParseErr != nil {

View File

@ -25,10 +25,13 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellutil"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
)
const GetStateTimeout = 5 * time.Second
const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"`
const GetK8sContextCmdStr = `printf "K8SCONTEXT %s\x00" "$(kubectl config current-context 2>/dev/null)"`
const GetK8sNamespaceCmdStr = `printf "K8SNAMESPACE %s\x00" "$(kubectl config view --minify --output 'jsonpath={..namespace}' 2>/dev/null)"`
const RunCommandFmt = `%s`
const DebugState = false
@ -239,7 +242,7 @@ func RunSimpleCmdInPty(ecmd *exec.Cmd) ([]byte, error) {
return outputBuf.Bytes(), nil
}
func parsePVarOutput(pvarBytes []byte, isZsh bool) map[string]*DeclareDeclType {
func parseExtVarOutput(pvarBytes []byte, promptOutput string, zmodsOutput string) map[string]*DeclareDeclType {
declMap := make(map[string]*DeclareDeclType)
pvars := bytes.Split(pvarBytes, []byte{0})
for _, pvarBA := range pvars {
@ -251,11 +254,35 @@ func parsePVarOutput(pvarBytes []byte, isZsh bool) map[string]*DeclareDeclType {
if pvarFields[0] == "" {
continue
}
decl := &DeclareDeclType{IsZshDecl: isZsh, Args: "x"}
if pvarFields[1] == "" {
continue
}
decl := &DeclareDeclType{IsExtVar: true}
decl.Name = "PROMPTVAR_" + pvarFields[0]
decl.Value = shellescape.Quote(pvarFields[1])
declMap[decl.Name] = decl
}
if promptOutput != "" {
decl := &DeclareDeclType{IsExtVar: true}
decl.Name = "PROMPTVAR_PS1"
decl.Value = promptOutput
declMap[decl.Name] = decl
}
if zmodsOutput != "" {
var zmods []string
lines := strings.Split(zmodsOutput, "\n")
for _, line := range lines {
fields := strings.Fields(line)
if len(fields) != 2 || fields[0] != "zmodload" {
continue
}
zmods = append(zmods, fields[1])
}
decl := &DeclareDeclType{IsExtVar: true}
decl.Name = ZModsVarName
decl.Value = utilfn.QuickJson(zmods)
declMap[decl.Name] = decl
}
return declMap
}

View File

@ -30,6 +30,21 @@ const BaseZshOpts = ``
const ZshShellVersionCmdStr = `echo zsh v$ZSH_VERSION`
const StateOutputFdNum = 20
const (
ZshSection_Version = iota
ZshSection_Cwd
ZshSection_Env
ZshSection_Mods
ZshSection_Vars
ZshSection_Aliases
ZshSection_Fpath
ZshSection_Funcs
ZshSection_PVars
ZshSection_Prompt
ZshSection_NumFieldsExpected // must be last
)
// TODO these need updating
const RunZshSudoCommandFmt = `sudo -n -C %d zsh /dev/fd/%d`
const RunZshSudoPasswordCommandFmt = `cat /dev/fd/%d | sudo -k -S -C %d zsh -c "echo '[from-mshell]'; exec %d>&-; zsh /dev/fd/%d < /dev/fd/%d"`
@ -54,6 +69,7 @@ var ZshIgnoreVars = map[string]bool{
"SHLVL": true,
"TTY": true,
"ZDOTDIR": true,
"PPID": true,
"epochtime": true,
"langinfo": true,
"keymaps": true,
@ -77,6 +93,8 @@ var ZshIgnoreVars = map[string]bool{
"funcsourcetrace": true,
"funcstack": true,
"functrace": true,
"nameddirs": true,
"userdirs": true,
"parameters": true,
"commands": true,
"functions": true,
@ -86,6 +104,25 @@ var ZshIgnoreVars = map[string]bool{
"_comps": true,
"_patcomps": true,
"_postpatcomps": true,
// zsh/system
"errnos": true,
"sysparams": true,
// zsh/curses
"ZCURSES_COLORS": true,
"ZCURSES_COLOR_PAIRS": true,
"zcurses_attrs": true,
"zcurses_colors": true,
"zcurses_keycodes": true,
"zcurses_windows": true,
// not listed, but we also exclude all ZFTP_* variables
}
var ZshIgnoreFuncs = map[string]bool{
"zftp_chpwd": true,
"zftp_progress": true,
}
// only options we restore (other than ZshForceOptions)
@ -131,11 +168,13 @@ var ZshUnsetVars = []string{
"ZSH_EXECUTION_STRING",
}
var ZshLoadMods = []string{
"zsh/parameter",
"zsh/langinfo",
var ZshForceLoadMods = map[string]bool{
"zsh/parameter": true,
"zsh/langinfo": true,
}
const ZModsVarName = "WAVESTATE_ZMODS"
// do not use these directly, call GetLocalMajorVersion()
var localZshMajorVersionOnce = &sync.Once{}
var localZshMajorVersion = ""
@ -273,14 +312,29 @@ func (z zshShellApi) MakeRcFileStr(pk *packet.RunPacketType) string {
rcBuf.WriteString(fmt.Sprintf("unsetopt %s\n", optName))
}
}
for _, modName := range ZshLoadMods {
for modName := range ZshForceLoadMods {
rcBuf.WriteString(fmt.Sprintf("zmodload %s\n", modName))
}
modDecl := getDeclByName(varDecls, ZModsVarName)
if modDecl != nil {
modsArr := utilfn.QuickParseJson[[]string](modDecl.Value)
for _, modName := range modsArr {
if !ZshForceLoadMods[modName] {
rcBuf.WriteString(fmt.Sprintf("zmodload %s\n", modName))
}
}
}
var postDecls []*shellenv.DeclareDeclType
for _, varDecl := range varDecls {
if ZshIgnoreVars[varDecl.Name] {
continue
}
if strings.HasPrefix(varDecl.Name, "ZFTP_") {
continue
}
if varDecl.IsExtVar {
continue
}
if ZshUniqueArrayVars[varDecl.Name] && !varDecl.IsUniqueArray() {
varDecl.AddFlag("U")
}
@ -332,6 +386,9 @@ func (z zshShellApi) MakeRcFileStr(pk *packet.RunPacketType) string {
rcBuf.WriteString("# error decoding zsh functions\n")
} else {
for fnKey, fnValue := range fnMap {
if ZshIgnoreFuncs[fnKey.ParamName] {
continue
}
if fnValue == ZshFnAutoLoad {
rcBuf.WriteString(fmt.Sprintf("autoload %s\n", shellescape.Quote(fnKey.ParamName)))
} else {
@ -411,6 +468,8 @@ pwd;
printf "[%SECTIONSEP%]";
env -0;
printf "[%SECTIONSEP%]";
zmodload -L
printf "[%SECTIONSEP%]";
typeset -p +H -m '*';
printf "[%SECTIONSEP%]";
for var in "${(@k)aliases}"; do
@ -448,10 +507,16 @@ for var in "${(@k)dis_functions_source}"; do
done
printf "[%SECTIONSEP%]";
[%GITBRANCH%]
[%K8SCONTEXT%]
[%K8SNAMESPACE%]
printf "[%SECTIONSEP%]";
print -P "$PS1"
`
cmd = strings.TrimSpace(cmd)
cmd = strings.ReplaceAll(cmd, "[%ZSHVERSION%]", ZshShellVersionCmdStr)
cmd = strings.ReplaceAll(cmd, "[%GITBRANCH%]", GetGitBranchCmdStr)
cmd = strings.ReplaceAll(cmd, "[%K8SCONTEXT%]", GetK8sContextCmdStr)
cmd = strings.ReplaceAll(cmd, "[%K8SNAMESPACE%]", GetK8sNamespaceCmdStr)
cmd = strings.ReplaceAll(cmd, "[%PARTSEP%]", utilfn.ShellHexEscape(string(sectionSeparator[0:len(sectionSeparator)-1])))
cmd = strings.ReplaceAll(cmd, "[%SECTIONSEP%]", utilfn.ShellHexEscape(string(sectionSeparator)))
cmd = strings.ReplaceAll(cmd, "[%OUTPUTFD%]", fmt.Sprintf("/dev/fd/%d", fdNum))
@ -599,6 +664,9 @@ func ParseZshFunctions(fpathArr []string, fnBytes []byte, partSeparator []byte)
if fnName == "zshexit" {
continue
}
if ZshIgnoreFuncs[fnName] {
continue
}
if fnType == "functions" || fnType == "dis_functions" {
fnBody[ZshParamKey{ParamType: fnType, ParamName: fnName}] = fnValue
}
@ -609,10 +677,13 @@ func ParseZshFunctions(fpathArr []string, fnBytes []byte, partSeparator []byte)
// ok, so the trick here is that we want to only include functions that are *not* autoloaded
// the ones that are pending autoloading or come from a source file in fpath, can just be set to autoload
for fnKey := range fnBody {
var inFpath bool
source := fnSource[fnKey.ParamName]
if isSourceFileInFpath(fpathArr, source) {
fnBody[fnKey] = ZshFnAutoLoad
} else if strings.TrimSpace(fnBody[fnKey]) == ZshAutoloadFnBody {
if source != "" {
inFpath = isSourceFileInFpath(fpathArr, source)
}
isAutoloadFnBody := strings.TrimSpace(fnBody[fnKey]) == ZshAutoloadFnBody
if inFpath || isAutoloadFnBody {
fnBody[fnKey] = ZshFnAutoLoad
}
}
@ -639,11 +710,11 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta
versionStr := string(outputBytes[0:firstZeroIdx])
sectionSeparator := outputBytes[firstZeroIdx+1 : firstDZeroIdx+2]
partSeparator := sectionSeparator[0 : len(sectionSeparator)-1]
// 8 fields: version [0], cwd [1], env [2], vars [3], aliases [4], fpath [5], functions [6], pvars [7]
fields := bytes.Split(outputBytes, sectionSeparator)
if len(fields) != 8 {
// sections: see ZshSection_* consts
sections := bytes.Split(outputBytes, sectionSeparator)
if len(sections) != ZshSection_NumFieldsExpected {
base.Logf("invalid -- numfields\n")
return nil, fmt.Errorf("invalid zsh shell state output, wrong number of fields, fields=%d", len(fields))
return nil, fmt.Errorf("invalid zsh shell state output, wrong number of sections, section=%d", len(sections))
}
rtn := &packet.ShellState{}
rtn.Version = strings.TrimSpace(versionStr)
@ -653,10 +724,10 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta
if _, _, err := packet.ParseShellStateVersion(rtn.Version); err != nil {
return nil, fmt.Errorf("invalid zsh shell state output, invalid version: %v", err)
}
cwdStr := stripNewLineChars(string(fields[1]))
cwdStr := stripNewLineChars(string(sections[ZshSection_Cwd]))
rtn.Cwd = cwdStr
zshEnv := parseZshEnv(fields[2])
zshDecls, err := parseZshDecls(fields[3])
zshEnv := parseZshEnv(sections[ZshSection_Env])
zshDecls, err := parseZshDecls(sections[ZshSection_Vars])
if err != nil {
base.Logf("invalid - parsedecls %v\n", err)
return nil, err
@ -666,16 +737,15 @@ func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellSta
decl.ZshEnvValue = zshEnv[decl.ZshBoundScalar]
}
}
aliasMap := parseZshAliasStateOutput(fields[4], partSeparator)
aliasMap := parseZshAliasStateOutput(sections[ZshSection_Aliases], partSeparator)
rtn.Aliases = string(EncodeZshMap(aliasMap))
fpathStr := stripNewLineChars(string(string(fields[5])))
fpathStr := stripNewLineChars(string(string(sections[ZshSection_Fpath])))
fpathArr := strings.Split(fpathStr, ":")
zshFuncs := ParseZshFunctions(fpathArr, fields[6], partSeparator)
zshFuncs := ParseZshFunctions(fpathArr, sections[ZshSection_Funcs], partSeparator)
rtn.Funcs = string(EncodeZshMap(zshFuncs))
pvarMap := parsePVarOutput(fields[7], true)
pvarMap := parseExtVarOutput(sections[ZshSection_PVars], string(sections[ZshSection_Prompt]), string(sections[ZshSection_Mods]))
utilfn.CombineMaps(zshDecls, pvarMap)
rtn.ShellVars = shellenv.SerializeDeclMap(zshDecls)
base.Logf("parse shellstate done\n")
return rtn, nil
}

View File

@ -22,6 +22,7 @@ const (
type DeclareDeclType struct {
IsZshDecl bool
IsExtVar bool // set for "special" wave internal variables
Args string
Name string
@ -36,31 +37,31 @@ type DeclareDeclType struct {
}
func (d *DeclareDeclType) IsExport() bool {
return strings.Index(d.Args, "x") >= 0
return strings.Contains(d.Args, "x")
}
func (d *DeclareDeclType) IsReadOnly() bool {
return strings.Index(d.Args, "r") >= 0
return strings.Contains(d.Args, "r")
}
func (d *DeclareDeclType) IsZshScalarBound() bool {
return strings.Index(d.Args, "T") >= 0
return strings.Contains(d.Args, "T")
}
func (d *DeclareDeclType) IsArray() bool {
return strings.Index(d.Args, "a") >= 0
return strings.Contains(d.Args, "a")
}
func (d *DeclareDeclType) IsAssocArray() bool {
return strings.Index(d.Args, "A") >= 0
return strings.Contains(d.Args, "A")
}
func (d *DeclareDeclType) IsUniqueArray() bool {
return d.IsArray() && strings.Index(d.Args, "U") >= 0
return d.IsArray() && strings.Contains(d.Args, "U")
}
func (d *DeclareDeclType) AddFlag(flag string) {
if strings.Index(d.Args, flag) >= 0 {
if strings.Contains(d.Args, flag) {
return
}
d.Args += flag
@ -101,13 +102,13 @@ func (d *DeclareDeclType) SortZshFlags() {
}
func (d *DeclareDeclType) DataType() string {
if strings.Index(d.Args, "a") >= 0 {
if strings.Contains(d.Args, "a") {
return DeclTypeArray
}
if strings.Index(d.Args, "A") >= 0 {
if strings.Contains(d.Args, "A") {
return DeclTypeAssocArray
}
if strings.Index(d.Args, "i") >= 0 {
if strings.Contains(d.Args, "i") {
return DeclTypeInt
}
return DeclTypeNormal
@ -124,7 +125,15 @@ func FindVarDecl(decls []*DeclareDeclType, name string) *DeclareDeclType {
// NOTE Serialize no longer writes the final null byte
func (d *DeclareDeclType) Serialize() []byte {
if d.IsZshDecl {
if d.IsExtVar {
parts := []string{
"e1",
d.Args,
d.Name,
d.Value,
}
return utilfn.EncodeStringArray(parts)
} else if d.IsZshDecl {
d.SortZshFlags()
parts := []string{
"z1",
@ -149,6 +158,15 @@ func (d *DeclareDeclType) Serialize() []byte {
// return []byte(rtn)
}
func (d *DeclareDeclType) UnescapedValue() string {
if d.IsExtVar {
return d.Value
}
ectx := simpleexpand.SimpleExpandContext{}
rtn, _ := simpleexpand.SimpleExpandPartialWord(ectx, d.Value, false)
return rtn
}
func DeclsEqual(compareName bool, d1 *DeclareDeclType, d2 *DeclareDeclType) bool {
if d1.IsExport() != d2.IsExport() {
return false
@ -164,7 +182,8 @@ func DeclsEqual(compareName bool, d1 *DeclareDeclType, d2 *DeclareDeclType) bool
// envline should be valid
func parseDeclLine(envLineBytes []byte) *DeclareDeclType {
if utilfn.EncodedStringArrayHasFirstKey(envLineBytes, "z1") {
esFirstVal := utilfn.EncodedStringArrayGetFirstVal(envLineBytes)
if esFirstVal == "z1" {
parts, err := utilfn.DecodeStringArray(envLineBytes)
if err != nil {
return nil
@ -180,7 +199,7 @@ func parseDeclLine(envLineBytes []byte) *DeclareDeclType {
ZshBoundScalar: parts[4],
ZshEnvValue: parts[5],
}
} else if utilfn.EncodedStringArrayHasFirstKey(envLineBytes, "b1") {
} else if esFirstVal == "b1" {
parts, err := utilfn.DecodeStringArray(envLineBytes)
if err != nil {
return nil
@ -193,8 +212,25 @@ func parseDeclLine(envLineBytes []byte) *DeclareDeclType {
Name: parts[2],
Value: parts[3],
}
} else if esFirstVal == "e1" {
parts, err := utilfn.DecodeStringArray(envLineBytes)
if err != nil {
return nil
}
// legacy decoding (v0)
if len(parts) != 4 {
return nil
}
return &DeclareDeclType{
IsExtVar: true,
Args: parts[1],
Name: parts[2],
Value: parts[3],
}
} else if esFirstVal == "p1" {
// deprecated
return nil
}
// legacy decoding (v0) (not an encoded string array)
envLine := string(envLineBytes)
eqIdx := strings.Index(envLine, "=")
if eqIdx == -1 {

View File

@ -7,6 +7,7 @@ import (
"bytes"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math"
@ -400,7 +401,7 @@ func DecodeStringArray(barr []byte) ([]string, error) {
return rtn, nil
}
func EncodedStringArrayHasFirstKey(encoded []byte, firstKey string) bool {
func EncodedStringArrayHasFirstVal(encoded []byte, firstKey string) bool {
firstKeyBytes := NullEncodeStr(firstKey)
if !bytes.HasPrefix(encoded, firstKeyBytes) {
return false
@ -411,6 +412,18 @@ func EncodedStringArrayHasFirstKey(encoded []byte, firstKey string) bool {
return false
}
// on encoding error returns ""
// this is used to perform logic on first value without decoding the entire array
func EncodedStringArrayGetFirstVal(encoded []byte) string {
sepIdx := bytes.IndexByte(encoded, nullEncodeSepByte)
if sepIdx == -1 {
str, _ := NullDecodeStr(encoded)
return str
}
str, _ := NullDecodeStr(encoded[0:sepIdx])
return str
}
// encodes a string, removing null/zero bytes (and separators '|')
// a zero byte is encoded as "\0", a '\' is encoded as "\\", sep is encoded as "\s"
// allows for easy double splitting (first on \x00, and next on "|")
@ -520,3 +533,22 @@ func CombineStrArrays(sarr1 []string, sarr2 []string) []string {
}
return rtn
}
func QuickJson(v interface{}) string {
barr, _ := json.Marshal(v)
return string(barr)
}
func QuickParseJson[T any](s string) T {
var v T
_ = json.Unmarshal([]byte(s), &v)
return v
}
func StrArrayToMap(sarr []string) map[string]bool {
m := make(map[string]bool)
for _, s := range sarr {
m[s] = true
}
return m
}

View File

@ -214,6 +214,7 @@ var literalRtnStateCommands = []string{
"enable",
"disable",
"function",
"zmodload",
}
func getCallExprLitArg(callExpr *syntax.CallExpr, argNum int) string {

View File

@ -751,16 +751,17 @@ func FeStateFromShellState(state *packet.ShellState) map[string]string {
}
rtn := make(map[string]string)
rtn["cwd"] = state.Cwd
envMap := shellenv.EnvMapFromState(state)
if envMap["VIRTUAL_ENV"] != "" {
rtn["VIRTUAL_ENV"] = envMap["VIRTUAL_ENV"]
declMap := shellenv.DeclMapFromState(state)
if decl, ok := declMap["VIRTUAL_ENV"]; ok {
rtn["VIRTUAL_ENV"] = decl.UnescapedValue()
}
if envMap["CONDA_DEFAULT_ENV"] != "" {
rtn["CONDA_DEFAULT_ENV"] = envMap["CONDA_DEFAULT_ENV"]
if decl, ok := declMap["CONDA_DEFAULT_ENV"]; ok {
rtn["CONDA_DEFAULT_ENV"] = decl.UnescapedValue()
}
for key, val := range envMap {
if strings.HasPrefix(key, "PROMPTVAR_") && envMap[key] != "" {
rtn[key] = val
for _, decl := range declMap {
// works for both legacy and new IsExtVar decls
if strings.HasPrefix(decl.Name, "PROMPTVAR_") {
rtn[decl.Name] = decl.UnescapedValue()
}
}
_, _, err := packet.ParseShellStateVersion(state.Version)