checkpoint on statediff. bug fixes. working on more robust error handling for packetsender

This commit is contained in:
sawka 2022-11-27 13:47:18 -08:00
parent 5a151369cb
commit 4481cddadc
11 changed files with 396 additions and 79 deletions

View File

@ -156,7 +156,7 @@ func readFullRunPacket(packetParser *packet.PacketParser) (*packet.RunPacketType
func handleSingle(fromServer bool) {
packetParser := packet.MakePacketParser(os.Stdin)
sender := packet.MakePacketSender(os.Stdout)
sender := packet.MakePacketSender(os.Stdout, nil)
defer func() {
sender.Close()
sender.WaitForDone()

View File

@ -2,9 +2,15 @@ package binpack
import (
"encoding/binary"
"fmt"
"io"
)
type Unpacker struct {
R FullByteReader
Err error
}
type FullByteReader interface {
io.ByteReader
io.Reader
@ -17,9 +23,11 @@ func PackValue(w io.Writer, barr []byte) error {
if err != nil {
return err
}
_, err = w.Write(barr)
if err != nil {
return err
if len(barr) > 0 {
_, err = w.Write(barr)
if err != nil {
return err
}
}
return nil
}
@ -36,6 +44,9 @@ func UnpackValue(r FullByteReader) ([]byte, error) {
if err != nil {
return nil, err
}
if lenVal == 0 {
return nil, nil
}
rtnBuf := make([]byte, int(lenVal))
_, err = io.ReadFull(r, rtnBuf)
if err != nil {
@ -51,3 +62,33 @@ func UnpackInt(r io.ByteReader) (int, error) {
}
return int(ival64), nil
}
func (u *Unpacker) UnpackValue(name string) []byte {
if u.Err != nil {
return nil
}
rtn, err := UnpackValue(u.R)
if err != nil {
u.Err = fmt.Errorf("cannot unpack %s: %v", name, err)
}
return rtn
}
func (u *Unpacker) UnpackInt(name string) int {
if u.Err != nil {
return 0
}
rtn, err := UnpackInt(u.R)
if err != nil {
u.Err = fmt.Errorf("cannot unpack %s: %v", name, err)
}
return rtn
}
func (u *Unpacker) Error() error {
return u.Err
}
func MakeUnpacker(r FullByteReader) *Unpacker {
return &Unpacker{R: r}
}

View File

@ -49,9 +49,10 @@ const (
CdPacketStr = "cd" // rpc
CmdDataPacketStr = "cmddata" // rpc-response
RawPacketStr = "raw"
SpecialInputPacketStr = "sinput" // command
CompGenPacketStr = "compgen" // rpc
ReInitPacketStr = "reinit" // rpc
SpecialInputPacketStr = "sinput" // command
CompGenPacketStr = "compgen" // rpc
ReInitPacketStr = "reinit" // rpc
CmdFinalPacketStr = "cmdfinal" // command, pushed at the "end" of a command (fail-safe for no cmddone)
)
const PacketSenderQueueSize = 20
@ -80,6 +81,7 @@ func init() {
TypeStrToFactory[DataEndPacketStr] = reflect.TypeOf(DataEndPacketType{})
TypeStrToFactory[CompGenPacketStr] = reflect.TypeOf(CompGenPacketType{})
TypeStrToFactory[ReInitPacketStr] = reflect.TypeOf(ReInitPacketType{})
TypeStrToFactory[CmdFinalPacketStr] = reflect.TypeOf(CmdFinalPacketType{})
var _ RpcPacketType = (*RunPacketType)(nil)
var _ RpcPacketType = (*GetCmdPacketType)(nil)
@ -96,6 +98,7 @@ func init() {
var _ CommandPacketType = (*DataAckPacketType)(nil)
var _ CommandPacketType = (*CmdDonePacketType)(nil)
var _ CommandPacketType = (*SpecialInputPacketType)(nil)
var _ CommandPacketType = (*CmdFinalPacketType)(nil)
}
func RegisterPacketType(typeStr string, rtype reflect.Type) {
@ -111,19 +114,6 @@ func MakePacket(packetType string) (PacketType, error) {
return rtn.Interface().(PacketType), nil
}
type ShellState struct {
Version string `json:"version,omitempty"`
Cwd string `json:"cwd,omitempty"`
ShellVars []byte `json:"shellvars,omitempty"`
Aliases string `json:"aliases,omitempty"`
Funcs string `json:"funcs,omitempty"`
Error string `json:"error,omitempty"`
}
func (state ShellState) IsEmpty() bool {
return state.Version == "" && state.Cwd == "" && len(state.ShellVars) == 0 && state.Aliases == "" && state.Funcs == "" && state.Error == ""
}
type CmdDataPacketType struct {
Type string `json:"type"`
RespId string `json:"respid"`
@ -509,13 +499,33 @@ func MakeDonePacket() *DonePacketType {
return &DonePacketType{Type: DonePacketStr}
}
type CmdFinalPacketType struct {
Type string `json:"type"`
Ts int64 `json:"ts"`
CK base.CommandKey `json:"ck"`
Error string `json:"error"`
}
func (*CmdFinalPacketType) GetType() string {
return CmdFinalPacketStr
}
func (pk *CmdFinalPacketType) GetCK() base.CommandKey {
return pk.CK
}
func MakeCmdFinalPacket(ck base.CommandKey) *CmdFinalPacketType {
return &CmdFinalPacketType{Type: CmdFinalPacketStr, CK: ck}
}
type CmdDonePacketType struct {
Type string `json:"type"`
Ts int64 `json:"ts"`
CK base.CommandKey `json:"ck"`
ExitCode int `json:"exitcode"`
DurationMs int64 `json:"durationms"`
FinalState *ShellState `json:"state,omitempty"`
Type string `json:"type"`
Ts int64 `json:"ts"`
CK base.CommandKey `json:"ck"`
ExitCode int `json:"exitcode"`
DurationMs int64 `json:"durationms"`
FinalState *ShellState `json:"finalstate,omitempty"`
FinalStateDiff *ShellStateDiff `json:"finalstatediff,omitempty"`
}
func (*CmdDonePacketType) GetType() string {
@ -580,7 +590,8 @@ type RunPacketType struct {
ReqId string `json:"reqid"`
CK base.CommandKey `json:"ck"`
Command string `json:"command"`
State *ShellState `json:"state"`
State *ShellState `json:"state,omitempty"`
StateDiff *ShellStateDiff `json:"statediff,omitempty"`
StateComplete bool `json:"statecomplete,omitempty"` // set to true if state is complete (the default env should not be set)
UsePty bool `json:"usepty,omitempty"`
TermOpts *TermOpts `json:"termopts,omitempty"`
@ -696,13 +707,34 @@ func sanitizeBytes(buf []byte) {
}
}
type SendError struct {
IsWriteError bool // fatal
IsMarshalError bool // not fatal
PacketType string
Err error
}
func (e *SendError) Unwrap() error {
return e.Err
}
func (e *SendError) Error() string {
if e.IsMarshalError {
return fmt.Sprintf("SendPacket marshal-error '%s' packet: %v", e.PacketType, e.Err)
} else if e.IsWriteError {
return fmt.Sprintf("SendPacket write-error: %v", e.Err)
} else {
return e.Err.Error()
}
}
func SendPacket(w io.Writer, packet PacketType) error {
if packet == nil {
return nil
}
jsonBytes, err := json.Marshal(packet)
if err != nil {
return fmt.Errorf("marshaling '%s' packet: %w", packet.GetType(), err)
return &SendError{IsMarshalError: true, PacketType: packet.GetType(), Err: err}
}
var outBuf bytes.Buffer
outBuf.WriteByte('\n')
@ -716,7 +748,7 @@ func SendPacket(w io.Writer, packet PacketType) error {
sanitizeBytes(outBytes)
_, err = w.Write(outBytes)
if err != nil {
return err
return &SendError{IsWriteError: true, PacketType: packet.GetType(), Err: err}
}
return nil
}
@ -726,18 +758,19 @@ func SendCmdError(w io.Writer, ck base.CommandKey, err error) error {
}
type PacketSender struct {
Lock *sync.Mutex
SendCh chan PacketType
Err error
Done bool
DoneCh chan bool
Lock *sync.Mutex
SendCh chan PacketType
Done bool
DoneCh chan bool
ErrHandler func(*PacketSender, PacketType, error)
}
func MakePacketSender(output io.Writer) *PacketSender {
func MakePacketSender(output io.Writer, errHandler func(*PacketSender, PacketType, error)) *PacketSender {
sender := &PacketSender{
Lock: &sync.Mutex{},
SendCh: make(chan PacketType, PacketSenderQueueSize),
DoneCh: make(chan bool),
Lock: &sync.Mutex{},
SendCh: make(chan PacketType, PacketSenderQueueSize),
DoneCh: make(chan bool),
ErrHandler: errHandler,
}
go func() {
defer close(sender.DoneCh)
@ -745,9 +778,12 @@ func MakePacketSender(output io.Writer) *PacketSender {
for pk := range sender.SendCh {
err := SendPacket(output, pk)
if err != nil {
sender.Lock.Lock()
sender.Err = err
sender.Lock.Unlock()
sender.goHandleError(pk, err)
if serr, ok := err.(*SendError); ok && serr.IsMarshalError {
// marshaler errors are recoverable
continue
}
// write errors are not recoverable
return
}
}
@ -755,6 +791,14 @@ func MakePacketSender(output io.Writer) *PacketSender {
return sender
}
func (sender *PacketSender) goHandleError(pk PacketType, err error) {
sender.Lock.Lock()
defer sender.Lock.Unlock()
if sender.ErrHandler != nil {
go sender.ErrHandler(sender, pk, err)
}
}
func MakeChannelPacketSender(packetCh chan PacketType) *PacketSender {
sender := &PacketSender{
Lock: &sync.Mutex{},
@ -791,9 +835,6 @@ func (sender *PacketSender) checkStatus() error {
if sender.Done {
return fmt.Errorf("cannot send packet, sender write loop is closed")
}
if sender.Err != nil {
return fmt.Errorf("cannot send packet, sender had error: %w", sender.Err)
}
return nil
}
@ -833,7 +874,7 @@ func (sender *PacketSender) SendResponse(reqId string, data interface{}) error {
return sender.SendPacket(pk)
}
func (sender *PacketSender) SendMessage(fmtStr string, args ...interface{}) error {
func (sender *PacketSender) SendMessageFmt(fmtStr string, args ...interface{}) error {
return sender.SendPacket(MakeMessagePacket(fmt.Sprintf(fmtStr, args...)))
}

131
pkg/packet/shellstate.go Normal file
View File

@ -0,0 +1,131 @@
package packet
import (
"bytes"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/scripthaus-dev/mshell/pkg/binpack"
"github.com/scripthaus-dev/mshell/pkg/statediff"
)
const ShellStatePackVersion = 0
const ShellStateDiffPackVersion = 0
type ShellState struct {
Version string `json:"version"` // [type] [semver]
Cwd string `json:"cwd,omitempty"`
ShellVars []byte `json:"shellvars,omitempty"`
Aliases string `json:"aliases,omitempty"`
Funcs string `json:"funcs,omitempty"`
Error string `json:"error,omitempty"`
}
type ShellStateDiff struct {
Version string `json:"version"` // [type] [semver]
BaseHash string `json:"basehash"`
Cwd string `json:"cwd,omitempty"`
VarsDiff []byte `json:"shellvarsdiff,omitempty"` // vardiff
AliasesDiff []byte `json:"aliasesdiff,omitempty"` // linediff
FuncsDiff []byte `json:"funcsdiff,omitempty"` // linediff
Error string `json:"error,omitempty"`
}
func (state ShellState) IsEmpty() bool {
return state.Version == "" && state.Cwd == "" && len(state.ShellVars) == 0 && state.Aliases == "" && state.Funcs == "" && state.Error == ""
}
// returns (SHA1, encoded-state)
func (state ShellState) EncodeAndHash() (string, []byte) {
var buf bytes.Buffer
binpack.PackInt(&buf, ShellStatePackVersion)
binpack.PackValue(&buf, []byte(state.Version))
binpack.PackValue(&buf, []byte(state.Cwd))
binpack.PackValue(&buf, state.ShellVars)
binpack.PackValue(&buf, []byte(state.Aliases))
binpack.PackValue(&buf, []byte(state.Funcs))
binpack.PackValue(&buf, []byte(state.Error))
hvalRaw := sha1.Sum(buf.Bytes())
hval := base64.StdEncoding.EncodeToString(hvalRaw[:])
return hval, buf.Bytes()
}
func (state ShellState) MarshalJSON() ([]byte, error) {
_, encodedState := state.EncodeAndHash()
return json.Marshal(encodedState)
}
func (state *ShellState) UnmarshalJSON(jsonBytes []byte) error {
var barr []byte
err := json.Unmarshal(jsonBytes, &barr)
if err != nil {
return err
}
buf := bytes.NewBuffer(barr)
u := binpack.MakeUnpacker(buf)
version := u.UnpackInt("ShellState pack version")
if version != ShellStatePackVersion {
return fmt.Errorf("invalid ShellState pack version: %d", version)
}
state.Version = string(u.UnpackValue("ShellState.Version"))
state.Cwd = string(u.UnpackValue("ShellState.Cwd"))
state.ShellVars = u.UnpackValue("ShellState.ShellVars")
state.Aliases = string(u.UnpackValue("ShellState.Aliases"))
state.Funcs = string(u.UnpackValue("ShellState.Funcs"))
state.Error = string(u.UnpackValue("ShellState.Error"))
return u.Error()
}
func (sdiff ShellStateDiff) MarshalJSON() ([]byte, error) {
var buf bytes.Buffer
binpack.PackInt(&buf, ShellStateDiffPackVersion)
binpack.PackValue(&buf, []byte(sdiff.Version))
binpack.PackValue(&buf, []byte(sdiff.BaseHash))
binpack.PackValue(&buf, []byte(sdiff.Cwd))
binpack.PackValue(&buf, sdiff.VarsDiff)
binpack.PackValue(&buf, sdiff.AliasesDiff)
binpack.PackValue(&buf, sdiff.FuncsDiff)
binpack.PackValue(&buf, []byte(sdiff.Error))
return buf.Bytes(), nil
}
func (sdiff *ShellStateDiff) UnmarshalJSON(jsonBytes []byte) error {
var barr []byte
err := json.Unmarshal(jsonBytes, &barr)
if err != nil {
return err
}
buf := bytes.NewBuffer(barr)
u := binpack.MakeUnpacker(buf)
version := u.UnpackInt("ShellState pack version")
if version != ShellStateDiffPackVersion {
return fmt.Errorf("invalid ShellStateDiff pack version: %d", version)
}
sdiff.Version = string(u.UnpackValue("ShellStateDiff.Version"))
sdiff.BaseHash = string(u.UnpackValue("ShellStateDiff.BaseHash"))
sdiff.Cwd = string(u.UnpackValue("ShellStateDiff.Cwd"))
sdiff.VarsDiff = u.UnpackValue("ShellStateDiff.VarsDiff")
sdiff.AliasesDiff = u.UnpackValue("ShellStateDiff.AliasesDiff")
sdiff.FuncsDiff = u.UnpackValue("ShellStateDiff.FuncsDiff")
sdiff.Error = string(u.UnpackValue("ShellStateDiff.Error"))
return u.Error()
}
func (sdiff ShellStateDiff) Dump() {
fmt.Printf("ShellStateDiff:\n")
fmt.Printf(" version: %s\n", sdiff.Version)
fmt.Printf(" base: %s\n", sdiff.BaseHash)
var mdiff statediff.MapDiffType
err := mdiff.Decode(sdiff.VarsDiff)
if err != nil {
fmt.Printf(" vars: error[%s]\n", err.Error())
} else {
mdiff.Dump()
}
fmt.Printf(" aliases: %d, funcs: %d\n", len(sdiff.AliasesDiff), len(sdiff.FuncsDiff))
if sdiff.Error != "" {
fmt.Printf(" error: %s\n", sdiff.Error)
}
}

View File

@ -14,6 +14,7 @@ import (
"sort"
"strings"
"sync"
"time"
"github.com/alessio/shellescape"
"github.com/scripthaus-dev/mshell/pkg/base"
@ -23,11 +24,13 @@ import (
// TODO create unblockable packet-sender (backed by an array) for clientproc
type MServer struct {
Lock *sync.Mutex
MainInput *packet.PacketParser
Sender *packet.PacketSender
ClientMap map[base.CommandKey]*shexec.ClientProc
Debug bool
Lock *sync.Mutex
MainInput *packet.PacketParser
Sender *packet.PacketSender
ClientMap map[base.CommandKey]*shexec.ClientProc
Debug bool
StateMap map[string]*packet.ShellState // sha1->state
CurrentState string // sha1
}
func (m *MServer) Close() {
@ -38,7 +41,7 @@ func (m *MServer) Close() {
func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) {
ck := pk.GetCK()
if ck == "" {
m.Sender.SendMessage(fmt.Sprintf("received '%s' packet without ck", pk.GetType()))
m.Sender.SendMessageFmt("received '%s' packet without ck", pk.GetType())
return
}
m.Lock.Lock()
@ -137,12 +140,24 @@ func (m *MServer) runCompGen(compPk *packet.CompGenPacketType) {
return
}
func (m *MServer) setCurrentState(state *packet.ShellState) {
if state == nil {
return
}
hval, _ := state.EncodeAndHash()
m.Lock.Lock()
defer m.Lock.Unlock()
m.StateMap[hval] = state
m.CurrentState = hval
}
func (m *MServer) reinit(reqId string) {
initPk, err := shexec.MakeServerInitPacket()
if err != nil {
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error creating init packet: %w", err))
return
}
m.setCurrentState(initPk.State)
initPk.RespId = reqId
m.Sender.SendPacket(initPk)
}
@ -170,6 +185,32 @@ func (m *MServer) ProcessRpcPacket(pk packet.RpcPacketType) {
return
}
func (m *MServer) getCurrentState() (string, *packet.ShellState) {
m.Lock.Lock()
defer m.Lock.Unlock()
return m.CurrentState, m.StateMap[m.CurrentState]
}
func (m *MServer) clientPacketCallback(pk packet.PacketType) {
if pk.GetType() != packet.CmdDonePacketStr {
return
}
donePk := pk.(*packet.CmdDonePacketType)
if donePk.FinalState == nil {
return
}
stateHash, curState := m.getCurrentState()
if curState == nil {
return
}
diff, err := shexec.MakeShellStateDiff(*curState, stateHash, *donePk.FinalState)
if err != nil {
return
}
donePk.FinalState = nil
donePk.FinalStateDiff = &diff
}
func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
if err := runPacket.CK.Validate("packet"); err != nil {
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err))
@ -190,16 +231,34 @@ func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
m.Lock.Unlock()
go func() {
defer func() {
r := recover()
finalPk := packet.MakeCmdFinalPacket(runPacket.CK)
finalPk.Ts = time.Now().UnixMilli()
if r != nil {
finalPk.Error = fmt.Sprintf("%s", r)
}
m.Sender.SendPacket(finalPk)
m.Lock.Lock()
delete(m.ClientMap, runPacket.CK)
m.Lock.Unlock()
cproc.Close()
}()
shexec.SendRunPacketAndRunData(context.Background(), cproc.Input, runPacket)
cproc.ProxySingleOutput(runPacket.CK, m.Sender)
cproc.ProxySingleOutput(runPacket.CK, m.Sender, m.clientPacketCallback)
}()
}
func (m *MServer) packetSenderErrorHandler(sender *packet.PacketSender, pk packet.PacketType, err error) {
if serr, ok := err.(*packet.SendError); ok && serr.IsMarshalError {
msg := packet.MakeMessagePacket(err.Error())
if cpk, ok := pk.(packet.CommandPacketType); ok {
msg.CK = cpk.GetCK()
}
sender.SendPacket(msg)
}
// otherwise ignore (we can't output anything for a I/O error)
}
func RunServer() (int, error) {
debug := false
if len(os.Args) >= 3 && os.Args[2] == "--debug" {
@ -208,19 +267,21 @@ func RunServer() (int, error) {
server := &MServer{
Lock: &sync.Mutex{},
ClientMap: make(map[base.CommandKey]*shexec.ClientProc),
StateMap: make(map[string]*packet.ShellState),
Debug: debug,
}
if debug {
packet.GlobalDebug = true
}
server.MainInput = packet.MakePacketParser(os.Stdin)
server.Sender = packet.MakePacketSender(os.Stdout)
server.Sender = packet.MakePacketSender(os.Stdout, server.packetSenderErrorHandler)
defer server.Close()
var err error
initPacket, err := shexec.MakeServerInitPacket()
if err != nil {
return 1, err
}
server.setCurrentState(initPacket.State)
server.Sender.SendPacket(initPacket)
builder := packet.MakeRunPacketBuilder()
for pk := range server.MainInput.MainCh {
@ -243,7 +304,7 @@ func RunServer() (int, error) {
server.ProcessRpcPacket(rpcPk)
continue
}
server.Sender.SendMessage(fmt.Sprintf("invalid packet '%s' sent to mshell server", packet.AsString(pk)))
server.Sender.SendMessageFmt("invalid packet '%s' sent to mshell server", packet.AsString(pk))
continue
}
return 0, nil

View File

@ -46,7 +46,7 @@ func MakeClientProc(ctx context.Context, ecmd *exec.Cmd) (*ClientProc, *packet.I
if err != nil {
return nil, nil, fmt.Errorf("running local client: %w", err)
}
sender := packet.MakePacketSender(inputWriter)
sender := packet.MakePacketSender(inputWriter, nil)
stdoutPacketParser := packet.MakePacketParser(stdoutReader)
stderrPacketParser := packet.MakePacketParser(stderrReader)
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser)
@ -108,9 +108,12 @@ func (cproc *ClientProc) Close() {
}
}
func (cproc *ClientProc) ProxySingleOutput(ck base.CommandKey, sender *packet.PacketSender) {
func (cproc *ClientProc) ProxySingleOutput(ck base.CommandKey, sender *packet.PacketSender, packetCallback func(packet.PacketType)) {
sentDonePk := false
for pk := range cproc.Output.MainCh {
if packetCallback != nil {
packetCallback(pk)
}
if pk.GetType() == packet.CmdDonePacketStr {
sentDonePk = true
}

View File

@ -10,6 +10,7 @@ import (
"github.com/alessio/shellescape"
"github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/mshell/pkg/statediff"
"mvdan.cc/sh/v3/expand"
"mvdan.cc/sh/v3/syntax"
)
@ -189,6 +190,32 @@ func ParseDeclLine(envLine string) *DeclareDeclType {
}
}
func parseDeclLineToKV(envLine string) (string, string) {
eqIdx := strings.Index(envLine, "=")
if eqIdx == -1 {
return "", ""
}
namePart := envLine[0:eqIdx]
valPart := envLine[eqIdx+1:]
return namePart, valPart
}
func shellStateVarsToMap(shellVars []byte) map[string]string {
if len(shellVars) == 0 {
return nil
}
rtn := make(map[string]string)
vars := bytes.Split(shellVars, []byte{0})
for _, varLine := range vars {
name, val := parseDeclLineToKV(string(varLine))
if name == "" {
continue
}
rtn[name] = val
}
return rtn
}
func DeclMapFromState(state *packet.ShellState) map[string]*DeclareDeclType {
if state == nil {
return nil
@ -485,3 +512,24 @@ func DeclsEqual(compareName bool, d1 *DeclareDeclType, d2 *DeclareDeclType) bool
}
return d1.Value == d2.Value // this works even for assoc arrays because we normalize them when parsing
}
func MakeShellStateDiff(oldState packet.ShellState, oldStateHash string, newState packet.ShellState) (packet.ShellStateDiff, error) {
var rtn packet.ShellStateDiff
rtn.BaseHash = oldStateHash
if oldState.Version != newState.Version {
return rtn, fmt.Errorf("cannot diff, states have different versions")
}
rtn.Version = newState.Version
if oldState.Cwd != newState.Cwd {
rtn.Cwd = newState.Cwd
}
if oldState.Error != newState.Error {
rtn.Error = newState.Error
}
oldVars := shellStateVarsToMap(oldState.ShellVars)
newVars := shellStateVarsToMap(newState.ShellVars)
rtn.VarsDiff = statediff.MakeMapDiff(oldVars, newVars)
rtn.AliasesDiff = statediff.MakeLineDiff(oldState.Aliases, newState.Aliases)
rtn.FuncsDiff = statediff.MakeLineDiff(oldState.Funcs, newState.Funcs)
return rtn, nil
}

View File

@ -834,7 +834,7 @@ func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, fdContext FdCon
stdoutPacketParser := packet.MakePacketParser(stdoutReader)
stderrPacketParser := packet.MakePacketParser(stderrReader)
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser)
sender := packet.MakePacketSender(inputWriter)
sender := packet.MakePacketSender(inputWriter, nil)
versionOk := false
for pk := range packetParser.MainCh {
if pk.GetType() == packet.RawPacketStr {
@ -986,14 +986,6 @@ func makeRcFileStr(pk *packet.RunPacketType) string {
rcBuf.WriteString(pk.State.Aliases)
rcBuf.WriteString("\n")
}
if pk.ReturnState {
rcBuf.WriteString(`
_scripthaus_exittrap () {
` + GetShellStateCmd + `
}
trap _scripthaus_exittrap EXIT
`)
}
return rcBuf.String()
}
@ -1285,7 +1277,7 @@ func RunCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (
if err != nil {
return nil, nil, fmt.Errorf("cannot open runout file '%s': %w", fileNames.RunnerOutFile, err)
}
cmd.DetachedOutput = packet.MakePacketSender(cmd.RunnerOutFd)
cmd.DetachedOutput = packet.MakePacketSender(cmd.RunnerOutFd, nil)
ecmd, err := MakeDetachedExecCmd(pk, cmdTty)
if err != nil {
return nil, nil, err

View File

@ -19,7 +19,7 @@ type LineDiffType struct {
NewData []string
}
func (diff LineDiffType) dump() {
func (diff LineDiffType) Dump() {
fmt.Printf("DIFF:\n")
pos := 1
for _, entry := range diff.Lines {
@ -68,7 +68,7 @@ func putUVarint(buf *bytes.Buffer, viBuf []byte, ival int) {
// simple encoding
// write varints. first version, then len, then len-number-of-varints, then fill the rest with newdata
// [version] [len-varint] [varint]xlen... newdata (bytes)
func (diff LineDiffType) encode() []byte {
func (diff LineDiffType) Encode() []byte {
var buf bytes.Buffer
viBuf := make([]byte, binary.MaxVarintLen64)
putUVarint(&buf, viBuf, LineDiffVersion)
@ -86,7 +86,7 @@ func (diff LineDiffType) encode() []byte {
return buf.Bytes()
}
func (rtn *LineDiffType) decode(diffBytes []byte) error {
func (rtn *LineDiffType) Decode(diffBytes []byte) error {
r := bytes.NewBuffer(diffBytes)
version, err := binary.ReadUvarint(r)
if err != nil {
@ -164,12 +164,12 @@ func MakeLineDiff(str1 string, str2 string) []byte {
str1Arr := strings.Split(str1, "\n")
str2Arr := strings.Split(str2, "\n")
diff := makeLineDiff(str1Arr, str2Arr)
return diff.encode()
return diff.Encode()
}
func ApplyLineDiff(str1 string, diffBytes []byte) (string, error) {
var diff LineDiffType
err := diff.decode(diffBytes)
err := diff.Decode(diffBytes)
if err != nil {
return "", err
}

View File

@ -15,7 +15,7 @@ type MapDiffType struct {
ToRemove []string
}
func (diff MapDiffType) dump() {
func (diff MapDiffType) Dump() {
fmt.Printf("VAR-DIFF\n")
for name, val := range diff.ToAdd {
fmt.Printf(" add: %s=%s\n", name, val)
@ -58,7 +58,7 @@ func (diff MapDiffType) apply(oldMap map[string]string) map[string]string {
return rtn
}
func (diff MapDiffType) encode() []byte {
func (diff MapDiffType) Encode() []byte {
var buf bytes.Buffer
viBuf := make([]byte, binary.MaxVarintLen64)
putUVarint(&buf, viBuf, MapDiffVersion)
@ -76,7 +76,7 @@ func (diff MapDiffType) encode() []byte {
return buf.Bytes()
}
func (diff *MapDiffType) decode(diffBytes []byte) error {
func (diff *MapDiffType) Decode(diffBytes []byte) error {
r := bytes.NewBuffer(diffBytes)
version, err := binary.ReadUvarint(r)
if err != nil {
@ -111,12 +111,12 @@ func (diff *MapDiffType) decode(diffBytes []byte) error {
func MakeMapDiff(m1 map[string]string, m2 map[string]string) []byte {
diff := makeMapDiff(m1, m2)
return diff.encode()
return diff.Encode()
}
func ApplyMapDiff(oldMap map[string]string, diffBytes []byte) (map[string]string, error) {
var diff MapDiffType
err := diff.decode(diffBytes)
err := diff.Decode(diffBytes)
if err != nil {
return nil, err
}

View File

@ -47,7 +47,7 @@ func testLineDiff(t *testing.T, str1 string, str2 string) {
t.Errorf("bad diff output")
}
var dt LineDiffType
err = dt.decode(diffBytes)
err = dt.Decode(diffBytes)
if err != nil {
t.Errorf("error decoding diff: %v\n", err)
}
@ -86,8 +86,8 @@ func TestMapDiff(t *testing.T) {
diffBytes := MakeMapDiff(m1, m2)
fmt.Printf("mapdifflen: %d\n", len(diffBytes))
var diff MapDiffType
diff.decode(diffBytes)
diff.dump()
diff.Decode(diffBytes)
diff.Dump()
mcheck, err := ApplyMapDiff(m1, diffBytes)
if err != nil {
t.Fatalf("error applying map diff: %v", err)