updates/bugfixes for statediff

This commit is contained in:
sawka 2022-11-28 18:05:54 -08:00
parent 605d0899cf
commit bdd8381b01
6 changed files with 111 additions and 21 deletions

View File

@ -536,12 +536,15 @@ func main() {
os.Exit(1) os.Exit(1)
} }
} else if firstArg == "--single" { } else if firstArg == "--single" {
base.InitDebugLog("single")
handleSingle(false) handleSingle(false)
return return
} else if firstArg == "--single-from-server" { } else if firstArg == "--single-from-server" {
base.InitDebugLog("single")
handleSingle(true) handleSingle(true)
return return
} else if firstArg == "--server" { } else if firstArg == "--server" {
base.InitDebugLog("server")
rtnCode, err := server.RunServer() rtnCode, err := server.RunServer()
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "[error] %v\n", err) fmt.Fprintf(os.Stderr, "[error] %v\n", err)

View File

@ -11,6 +11,7 @@ import (
"fmt" "fmt"
"io" "io"
"io/fs" "io/fs"
"log"
"os" "os"
"os/exec" "os/exec"
"path" "path"
@ -33,9 +34,13 @@ const SessionsDirBaseName = "sessions"
const MShellVersion = "v0.2.0" const MShellVersion = "v0.2.0"
const RemoteIdFile = "remoteid" const RemoteIdFile = "remoteid"
const DefaultMShellInstallBinDir = "/opt/mshell/bin" const DefaultMShellInstallBinDir = "/opt/mshell/bin"
const LogFileName = "mshell.log"
const ForceDebugLog = false
var sessionDirCache = make(map[string]string) var sessionDirCache = make(map[string]string)
var baseLock = &sync.Mutex{} var baseLock = &sync.Mutex{}
var DebugLogEnabled = false
var DebugLogger *log.Logger
type CommandFileNames struct { type CommandFileNames struct {
PtyOutFile string PtyOutFile string
@ -56,6 +61,31 @@ func (ckey CommandKey) IsEmpty() bool {
return string(ckey) == "" return string(ckey) == ""
} }
func Logf(fmtStr string, args ...interface{}) {
if (!DebugLogEnabled && !ForceDebugLog) || DebugLogger == nil {
return
}
DebugLogger.Printf(fmtStr, args...)
}
func InitDebugLog(prefix string) {
homeDir := GetMShellHomeDir()
err := os.MkdirAll(homeDir, 0777)
if err != nil {
return
}
logFile := path.Join(homeDir, LogFileName)
fd, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return
}
DebugLogger = log.New(fd, prefix+" ", log.LstdFlags)
}
func SetEnableDebugLog(enable bool) {
DebugLogEnabled = enable
}
func (ckey CommandKey) GetSessionId() string { func (ckey CommandKey) GetSessionId() string {
slashIdx := strings.Index(string(ckey), "/") slashIdx := strings.Index(string(ckey), "/")
if slashIdx == -1 { if slashIdx == -1 {

View File

@ -154,19 +154,39 @@ func (sdiff *ShellStateDiff) GetHashVal(force bool) string {
return sdiff.HashVal return sdiff.HashVal
} }
func (sdiff ShellStateDiff) Dump() { func (sdiff ShellStateDiff) Dump(vars bool, aliases bool, funcs bool) {
fmt.Printf("ShellStateDiff:\n") fmt.Printf("ShellStateDiff:\n")
fmt.Printf(" version: %s\n", sdiff.Version) fmt.Printf(" version: %s\n", sdiff.Version)
fmt.Printf(" base: %s\n", sdiff.BaseHash) fmt.Printf(" base: %s\n", sdiff.BaseHash)
var mdiff statediff.MapDiffType fmt.Printf(" vars: %d, aliases: %d, funcs: %d\n", len(sdiff.VarsDiff), len(sdiff.AliasesDiff), len(sdiff.FuncsDiff))
err := mdiff.Decode(sdiff.VarsDiff)
if err != nil {
fmt.Printf(" vars: error[%s]\n", err.Error())
} else {
mdiff.Dump()
}
fmt.Printf(" aliases: %d, funcs: %d\n", len(sdiff.AliasesDiff), len(sdiff.FuncsDiff))
if sdiff.Error != "" { if sdiff.Error != "" {
fmt.Printf(" error: %s\n", sdiff.Error) fmt.Printf(" error: %s\n", sdiff.Error)
} }
if vars {
var mdiff statediff.MapDiffType
err := mdiff.Decode(sdiff.VarsDiff)
if err != nil {
fmt.Printf(" vars: error[%s]\n", err.Error())
} else {
mdiff.Dump()
}
}
if aliases && len(sdiff.AliasesDiff) > 0 {
var ldiff statediff.LineDiffType
err := ldiff.Decode(sdiff.AliasesDiff)
if err != nil {
fmt.Printf(" aliases: error[%s]\n", err.Error())
} else {
ldiff.Dump()
}
}
if funcs && len(sdiff.FuncsDiff) > 0 {
var ldiff statediff.LineDiffType
err := ldiff.Decode(sdiff.FuncsDiff)
if err != nil {
fmt.Printf(" funcs: error[%s]\n", err.Error())
} else {
ldiff.Dump()
}
}
} }

View File

@ -9,7 +9,9 @@ import (
"strings" "strings"
"github.com/alessio/shellescape" "github.com/alessio/shellescape"
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/packet" "github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/mshell/pkg/simpleexpand"
"github.com/scripthaus-dev/mshell/pkg/statediff" "github.com/scripthaus-dev/mshell/pkg/statediff"
"mvdan.cc/sh/v3/expand" "mvdan.cc/sh/v3/expand"
"mvdan.cc/sh/v3/syntax" "mvdan.cc/sh/v3/syntax"
@ -157,10 +159,6 @@ func (d *DeclareDeclType) Serialize() string {
return fmt.Sprintf("%s|%s=%s\x00", d.Args, d.Name, d.Value) return fmt.Sprintf("%s|%s=%s\x00", d.Args, d.Name, d.Value)
} }
func (d *DeclareDeclType) EnvString() string {
return d.Name + "=" + d.Value
}
func (d *DeclareDeclType) DeclareStmt() string { func (d *DeclareDeclType) DeclareStmt() string {
var argsStr string var argsStr string
if d.Args == "" { if d.Args == "" {
@ -274,11 +272,12 @@ func EnvMapFromState(state *packet.ShellState) map[string]string {
return nil return nil
} }
rtn := make(map[string]string) rtn := make(map[string]string)
ectx := simpleexpand.SimpleExpandContext{}
vars := bytes.Split(state.ShellVars, []byte{0}) vars := bytes.Split(state.ShellVars, []byte{0})
for _, varLine := range vars { for _, varLine := range vars {
decl := ParseDeclLine(string(varLine)) decl := ParseDeclLine(string(varLine))
if decl != nil && decl.IsExport() { if decl != nil && decl.IsExport() {
rtn[decl.Name] = decl.Value rtn[decl.Name], _ = simpleexpand.SimpleExpandPartialWord(ectx, decl.Value, false)
} }
} }
return rtn return rtn
@ -289,11 +288,12 @@ func ShellVarMapFromState(state *packet.ShellState) map[string]string {
return nil return nil
} }
rtn := make(map[string]string) rtn := make(map[string]string)
ectx := simpleexpand.SimpleExpandContext{}
vars := bytes.Split(state.ShellVars, []byte{0}) vars := bytes.Split(state.ShellVars, []byte{0})
for _, varLine := range vars { for _, varLine := range vars {
decl := ParseDeclLine(string(varLine)) decl := ParseDeclLine(string(varLine))
if decl != nil { if decl != nil {
rtn[decl.Name] = decl.Value rtn[decl.Name], _ = simpleexpand.SimpleExpandPartialWord(ectx, decl.Value, false)
} }
} }
return rtn return rtn
@ -438,6 +438,7 @@ func ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, error) {
if strings.Index(rtn.Version, "bash") == -1 { if strings.Index(rtn.Version, "bash") == -1 {
return nil, fmt.Errorf("invalid shell state output, only bash is supported") return nil, fmt.Errorf("invalid shell state output, only bash is supported")
} }
rtn.Version = rtn.Version
cwdStr := string(fields[1]) cwdStr := string(fields[1])
if strings.HasSuffix(cwdStr, "\r\n") { if strings.HasSuffix(cwdStr, "\r\n") {
cwdStr = cwdStr[0 : len(cwdStr)-2] cwdStr = cwdStr[0 : len(cwdStr)-2]
@ -451,9 +452,35 @@ func ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, error) {
} }
rtn.Aliases = strings.ReplaceAll(string(fields[3]), "\r\n", "\n") rtn.Aliases = strings.ReplaceAll(string(fields[3]), "\r\n", "\n")
rtn.Funcs = strings.ReplaceAll(string(fields[4]), "\r\n", "\n") rtn.Funcs = strings.ReplaceAll(string(fields[4]), "\r\n", "\n")
rtn.Funcs = removeFunc(rtn.Funcs, "_scripthaus_exittrap")
lines := strings.Split(rtn.Funcs, "\n")
for _, line := range lines {
base.Logf("func-line: [%s]\n", line)
}
return rtn, nil return rtn, nil
} }
func removeFunc(funcs string, toRemove string) string {
lines := strings.Split(funcs, "\n")
var newLines []string
removeLine := fmt.Sprintf("%s ()", toRemove)
doingRemove := false
for _, line := range lines {
if line == removeLine {
doingRemove = true
continue
}
if doingRemove {
if line == "}" {
doingRemove = false
}
continue
}
newLines = append(newLines, line)
}
return strings.Join(newLines, "\n")
}
func (d *DeclareDeclType) normalize() error { func (d *DeclareDeclType) normalize() error {
if d.DataType() == DeclTypeAssocArray { if d.DataType() == DeclTypeAssocArray {
return d.normalizeAssocArrayDecl() return d.normalizeAssocArrayDecl()
@ -565,9 +592,7 @@ func MakeShellStateDiff(oldState packet.ShellState, oldStateHash string, newStat
if oldState.Cwd != newState.Cwd { if oldState.Cwd != newState.Cwd {
rtn.Cwd = newState.Cwd rtn.Cwd = newState.Cwd
} }
if oldState.Error != newState.Error { rtn.Error = newState.Error
rtn.Error = newState.Error
}
oldVars := shellStateVarsToMap(oldState.ShellVars) oldVars := shellStateVarsToMap(oldState.ShellVars)
newVars := shellStateVarsToMap(newState.ShellVars) newVars := shellStateVarsToMap(newState.ShellVars)
rtn.VarsDiff = statediff.MakeMapDiff(oldVars, newVars) rtn.VarsDiff = statediff.MakeMapDiff(oldVars, newVars)
@ -580,7 +605,10 @@ func ApplyShellStateDiff(oldState packet.ShellState, diff packet.ShellStateDiff)
var rtnState packet.ShellState var rtnState packet.ShellState
var err error var err error
rtnState.Version = oldState.Version rtnState.Version = oldState.Version
rtnState.Cwd = diff.Cwd rtnState.Cwd = oldState.Cwd
if diff.Cwd != "" {
rtnState.Cwd = diff.Cwd
}
rtnState.Error = diff.Error rtnState.Error = diff.Error
oldVars := shellStateVarsToMap(oldState.ShellVars) oldVars := shellStateVarsToMap(oldState.ShellVars)
newVars, err := statediff.ApplyMapDiff(oldVars, diff.VarsDiff) newVars, err := statediff.ApplyMapDiff(oldVars, diff.VarsDiff)

View File

@ -93,12 +93,18 @@ func expandLiteralPlus(buf *bytes.Buffer, info *SimpleExpandInfo, litVal string,
func expandSQANSILiteral(buf *bytes.Buffer, litVal string) { func expandSQANSILiteral(buf *bytes.Buffer, litVal string) {
// no info specials // no info specials
if strings.HasSuffix(litVal, "'") {
litVal = litVal[0 : len(litVal)-1]
}
str, _, _ := expand.Format(nil, litVal, nil) str, _, _ := expand.Format(nil, litVal, nil)
buf.WriteString(str) buf.WriteString(str)
} }
func expandSQLiteral(buf *bytes.Buffer, litVal string) { func expandSQLiteral(buf *bytes.Buffer, litVal string) {
// no info specials // no info specials
if strings.HasSuffix(litVal, "'") {
litVal = litVal[0 : len(litVal)-1]
}
buf.WriteString(litVal) buf.WriteString(litVal)
} }
@ -125,6 +131,9 @@ func expandDQLiteral(buf *bytes.Buffer, info *SimpleExpandInfo, litVal string) {
lastDollar = false lastDollar = false
continue continue
} }
if ch == '"' {
break
}
// similar to expandLiteral, but no globbing // similar to expandLiteral, but no globbing
if ch == '`' { if ch == '`' {

View File

@ -18,10 +18,10 @@ type MapDiffType struct {
func (diff MapDiffType) Dump() { func (diff MapDiffType) Dump() {
fmt.Printf("VAR-DIFF\n") fmt.Printf("VAR-DIFF\n")
for name, val := range diff.ToAdd { for name, val := range diff.ToAdd {
fmt.Printf(" add: %s=%s\n", name, val) fmt.Printf(" add[%s] %s\n", name, val)
} }
for _, name := range diff.ToRemove { for _, name := range diff.ToRemove {
fmt.Printf(" rem: %s\n", name) fmt.Printf(" rem[%s]\n", name)
} }
} }