mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-02-22 02:41:23 +01:00
checkpoint on statediff. bug fixes. working on more robust error handling for packetsender
This commit is contained in:
parent
5a151369cb
commit
4481cddadc
@ -156,7 +156,7 @@ func readFullRunPacket(packetParser *packet.PacketParser) (*packet.RunPacketType
|
||||
|
||||
func handleSingle(fromServer bool) {
|
||||
packetParser := packet.MakePacketParser(os.Stdin)
|
||||
sender := packet.MakePacketSender(os.Stdout)
|
||||
sender := packet.MakePacketSender(os.Stdout, nil)
|
||||
defer func() {
|
||||
sender.Close()
|
||||
sender.WaitForDone()
|
||||
|
@ -2,9 +2,15 @@ package binpack
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
type Unpacker struct {
|
||||
R FullByteReader
|
||||
Err error
|
||||
}
|
||||
|
||||
type FullByteReader interface {
|
||||
io.ByteReader
|
||||
io.Reader
|
||||
@ -17,9 +23,11 @@ func PackValue(w io.Writer, barr []byte) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.Write(barr)
|
||||
if err != nil {
|
||||
return err
|
||||
if len(barr) > 0 {
|
||||
_, err = w.Write(barr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -36,6 +44,9 @@ func UnpackValue(r FullByteReader) ([]byte, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lenVal == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
rtnBuf := make([]byte, int(lenVal))
|
||||
_, err = io.ReadFull(r, rtnBuf)
|
||||
if err != nil {
|
||||
@ -51,3 +62,33 @@ func UnpackInt(r io.ByteReader) (int, error) {
|
||||
}
|
||||
return int(ival64), nil
|
||||
}
|
||||
|
||||
func (u *Unpacker) UnpackValue(name string) []byte {
|
||||
if u.Err != nil {
|
||||
return nil
|
||||
}
|
||||
rtn, err := UnpackValue(u.R)
|
||||
if err != nil {
|
||||
u.Err = fmt.Errorf("cannot unpack %s: %v", name, err)
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (u *Unpacker) UnpackInt(name string) int {
|
||||
if u.Err != nil {
|
||||
return 0
|
||||
}
|
||||
rtn, err := UnpackInt(u.R)
|
||||
if err != nil {
|
||||
u.Err = fmt.Errorf("cannot unpack %s: %v", name, err)
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (u *Unpacker) Error() error {
|
||||
return u.Err
|
||||
}
|
||||
|
||||
func MakeUnpacker(r FullByteReader) *Unpacker {
|
||||
return &Unpacker{R: r}
|
||||
}
|
||||
|
@ -49,9 +49,10 @@ const (
|
||||
CdPacketStr = "cd" // rpc
|
||||
CmdDataPacketStr = "cmddata" // rpc-response
|
||||
RawPacketStr = "raw"
|
||||
SpecialInputPacketStr = "sinput" // command
|
||||
CompGenPacketStr = "compgen" // rpc
|
||||
ReInitPacketStr = "reinit" // rpc
|
||||
SpecialInputPacketStr = "sinput" // command
|
||||
CompGenPacketStr = "compgen" // rpc
|
||||
ReInitPacketStr = "reinit" // rpc
|
||||
CmdFinalPacketStr = "cmdfinal" // command, pushed at the "end" of a command (fail-safe for no cmddone)
|
||||
)
|
||||
|
||||
const PacketSenderQueueSize = 20
|
||||
@ -80,6 +81,7 @@ func init() {
|
||||
TypeStrToFactory[DataEndPacketStr] = reflect.TypeOf(DataEndPacketType{})
|
||||
TypeStrToFactory[CompGenPacketStr] = reflect.TypeOf(CompGenPacketType{})
|
||||
TypeStrToFactory[ReInitPacketStr] = reflect.TypeOf(ReInitPacketType{})
|
||||
TypeStrToFactory[CmdFinalPacketStr] = reflect.TypeOf(CmdFinalPacketType{})
|
||||
|
||||
var _ RpcPacketType = (*RunPacketType)(nil)
|
||||
var _ RpcPacketType = (*GetCmdPacketType)(nil)
|
||||
@ -96,6 +98,7 @@ func init() {
|
||||
var _ CommandPacketType = (*DataAckPacketType)(nil)
|
||||
var _ CommandPacketType = (*CmdDonePacketType)(nil)
|
||||
var _ CommandPacketType = (*SpecialInputPacketType)(nil)
|
||||
var _ CommandPacketType = (*CmdFinalPacketType)(nil)
|
||||
}
|
||||
|
||||
func RegisterPacketType(typeStr string, rtype reflect.Type) {
|
||||
@ -111,19 +114,6 @@ func MakePacket(packetType string) (PacketType, error) {
|
||||
return rtn.Interface().(PacketType), nil
|
||||
}
|
||||
|
||||
type ShellState struct {
|
||||
Version string `json:"version,omitempty"`
|
||||
Cwd string `json:"cwd,omitempty"`
|
||||
ShellVars []byte `json:"shellvars,omitempty"`
|
||||
Aliases string `json:"aliases,omitempty"`
|
||||
Funcs string `json:"funcs,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (state ShellState) IsEmpty() bool {
|
||||
return state.Version == "" && state.Cwd == "" && len(state.ShellVars) == 0 && state.Aliases == "" && state.Funcs == "" && state.Error == ""
|
||||
}
|
||||
|
||||
type CmdDataPacketType struct {
|
||||
Type string `json:"type"`
|
||||
RespId string `json:"respid"`
|
||||
@ -509,13 +499,33 @@ func MakeDonePacket() *DonePacketType {
|
||||
return &DonePacketType{Type: DonePacketStr}
|
||||
}
|
||||
|
||||
type CmdFinalPacketType struct {
|
||||
Type string `json:"type"`
|
||||
Ts int64 `json:"ts"`
|
||||
CK base.CommandKey `json:"ck"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func (*CmdFinalPacketType) GetType() string {
|
||||
return CmdFinalPacketStr
|
||||
}
|
||||
|
||||
func (pk *CmdFinalPacketType) GetCK() base.CommandKey {
|
||||
return pk.CK
|
||||
}
|
||||
|
||||
func MakeCmdFinalPacket(ck base.CommandKey) *CmdFinalPacketType {
|
||||
return &CmdFinalPacketType{Type: CmdFinalPacketStr, CK: ck}
|
||||
}
|
||||
|
||||
type CmdDonePacketType struct {
|
||||
Type string `json:"type"`
|
||||
Ts int64 `json:"ts"`
|
||||
CK base.CommandKey `json:"ck"`
|
||||
ExitCode int `json:"exitcode"`
|
||||
DurationMs int64 `json:"durationms"`
|
||||
FinalState *ShellState `json:"state,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Ts int64 `json:"ts"`
|
||||
CK base.CommandKey `json:"ck"`
|
||||
ExitCode int `json:"exitcode"`
|
||||
DurationMs int64 `json:"durationms"`
|
||||
FinalState *ShellState `json:"finalstate,omitempty"`
|
||||
FinalStateDiff *ShellStateDiff `json:"finalstatediff,omitempty"`
|
||||
}
|
||||
|
||||
func (*CmdDonePacketType) GetType() string {
|
||||
@ -580,7 +590,8 @@ type RunPacketType struct {
|
||||
ReqId string `json:"reqid"`
|
||||
CK base.CommandKey `json:"ck"`
|
||||
Command string `json:"command"`
|
||||
State *ShellState `json:"state"`
|
||||
State *ShellState `json:"state,omitempty"`
|
||||
StateDiff *ShellStateDiff `json:"statediff,omitempty"`
|
||||
StateComplete bool `json:"statecomplete,omitempty"` // set to true if state is complete (the default env should not be set)
|
||||
UsePty bool `json:"usepty,omitempty"`
|
||||
TermOpts *TermOpts `json:"termopts,omitempty"`
|
||||
@ -696,13 +707,34 @@ func sanitizeBytes(buf []byte) {
|
||||
}
|
||||
}
|
||||
|
||||
type SendError struct {
|
||||
IsWriteError bool // fatal
|
||||
IsMarshalError bool // not fatal
|
||||
PacketType string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *SendError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
func (e *SendError) Error() string {
|
||||
if e.IsMarshalError {
|
||||
return fmt.Sprintf("SendPacket marshal-error '%s' packet: %v", e.PacketType, e.Err)
|
||||
} else if e.IsWriteError {
|
||||
return fmt.Sprintf("SendPacket write-error: %v", e.Err)
|
||||
} else {
|
||||
return e.Err.Error()
|
||||
}
|
||||
}
|
||||
|
||||
func SendPacket(w io.Writer, packet PacketType) error {
|
||||
if packet == nil {
|
||||
return nil
|
||||
}
|
||||
jsonBytes, err := json.Marshal(packet)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling '%s' packet: %w", packet.GetType(), err)
|
||||
return &SendError{IsMarshalError: true, PacketType: packet.GetType(), Err: err}
|
||||
}
|
||||
var outBuf bytes.Buffer
|
||||
outBuf.WriteByte('\n')
|
||||
@ -716,7 +748,7 @@ func SendPacket(w io.Writer, packet PacketType) error {
|
||||
sanitizeBytes(outBytes)
|
||||
_, err = w.Write(outBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
return &SendError{IsWriteError: true, PacketType: packet.GetType(), Err: err}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -726,18 +758,19 @@ func SendCmdError(w io.Writer, ck base.CommandKey, err error) error {
|
||||
}
|
||||
|
||||
type PacketSender struct {
|
||||
Lock *sync.Mutex
|
||||
SendCh chan PacketType
|
||||
Err error
|
||||
Done bool
|
||||
DoneCh chan bool
|
||||
Lock *sync.Mutex
|
||||
SendCh chan PacketType
|
||||
Done bool
|
||||
DoneCh chan bool
|
||||
ErrHandler func(*PacketSender, PacketType, error)
|
||||
}
|
||||
|
||||
func MakePacketSender(output io.Writer) *PacketSender {
|
||||
func MakePacketSender(output io.Writer, errHandler func(*PacketSender, PacketType, error)) *PacketSender {
|
||||
sender := &PacketSender{
|
||||
Lock: &sync.Mutex{},
|
||||
SendCh: make(chan PacketType, PacketSenderQueueSize),
|
||||
DoneCh: make(chan bool),
|
||||
Lock: &sync.Mutex{},
|
||||
SendCh: make(chan PacketType, PacketSenderQueueSize),
|
||||
DoneCh: make(chan bool),
|
||||
ErrHandler: errHandler,
|
||||
}
|
||||
go func() {
|
||||
defer close(sender.DoneCh)
|
||||
@ -745,9 +778,12 @@ func MakePacketSender(output io.Writer) *PacketSender {
|
||||
for pk := range sender.SendCh {
|
||||
err := SendPacket(output, pk)
|
||||
if err != nil {
|
||||
sender.Lock.Lock()
|
||||
sender.Err = err
|
||||
sender.Lock.Unlock()
|
||||
sender.goHandleError(pk, err)
|
||||
if serr, ok := err.(*SendError); ok && serr.IsMarshalError {
|
||||
// marshaler errors are recoverable
|
||||
continue
|
||||
}
|
||||
// write errors are not recoverable
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -755,6 +791,14 @@ func MakePacketSender(output io.Writer) *PacketSender {
|
||||
return sender
|
||||
}
|
||||
|
||||
func (sender *PacketSender) goHandleError(pk PacketType, err error) {
|
||||
sender.Lock.Lock()
|
||||
defer sender.Lock.Unlock()
|
||||
if sender.ErrHandler != nil {
|
||||
go sender.ErrHandler(sender, pk, err)
|
||||
}
|
||||
}
|
||||
|
||||
func MakeChannelPacketSender(packetCh chan PacketType) *PacketSender {
|
||||
sender := &PacketSender{
|
||||
Lock: &sync.Mutex{},
|
||||
@ -791,9 +835,6 @@ func (sender *PacketSender) checkStatus() error {
|
||||
if sender.Done {
|
||||
return fmt.Errorf("cannot send packet, sender write loop is closed")
|
||||
}
|
||||
if sender.Err != nil {
|
||||
return fmt.Errorf("cannot send packet, sender had error: %w", sender.Err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -833,7 +874,7 @@ func (sender *PacketSender) SendResponse(reqId string, data interface{}) error {
|
||||
return sender.SendPacket(pk)
|
||||
}
|
||||
|
||||
func (sender *PacketSender) SendMessage(fmtStr string, args ...interface{}) error {
|
||||
func (sender *PacketSender) SendMessageFmt(fmtStr string, args ...interface{}) error {
|
||||
return sender.SendPacket(MakeMessagePacket(fmt.Sprintf(fmtStr, args...)))
|
||||
}
|
||||
|
||||
|
131
pkg/packet/shellstate.go
Normal file
131
pkg/packet/shellstate.go
Normal file
@ -0,0 +1,131 @@
|
||||
package packet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/scripthaus-dev/mshell/pkg/binpack"
|
||||
"github.com/scripthaus-dev/mshell/pkg/statediff"
|
||||
)
|
||||
|
||||
const ShellStatePackVersion = 0
|
||||
const ShellStateDiffPackVersion = 0
|
||||
|
||||
type ShellState struct {
|
||||
Version string `json:"version"` // [type] [semver]
|
||||
Cwd string `json:"cwd,omitempty"`
|
||||
ShellVars []byte `json:"shellvars,omitempty"`
|
||||
Aliases string `json:"aliases,omitempty"`
|
||||
Funcs string `json:"funcs,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type ShellStateDiff struct {
|
||||
Version string `json:"version"` // [type] [semver]
|
||||
BaseHash string `json:"basehash"`
|
||||
Cwd string `json:"cwd,omitempty"`
|
||||
VarsDiff []byte `json:"shellvarsdiff,omitempty"` // vardiff
|
||||
AliasesDiff []byte `json:"aliasesdiff,omitempty"` // linediff
|
||||
FuncsDiff []byte `json:"funcsdiff,omitempty"` // linediff
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (state ShellState) IsEmpty() bool {
|
||||
return state.Version == "" && state.Cwd == "" && len(state.ShellVars) == 0 && state.Aliases == "" && state.Funcs == "" && state.Error == ""
|
||||
}
|
||||
|
||||
// returns (SHA1, encoded-state)
|
||||
func (state ShellState) EncodeAndHash() (string, []byte) {
|
||||
var buf bytes.Buffer
|
||||
binpack.PackInt(&buf, ShellStatePackVersion)
|
||||
binpack.PackValue(&buf, []byte(state.Version))
|
||||
binpack.PackValue(&buf, []byte(state.Cwd))
|
||||
binpack.PackValue(&buf, state.ShellVars)
|
||||
binpack.PackValue(&buf, []byte(state.Aliases))
|
||||
binpack.PackValue(&buf, []byte(state.Funcs))
|
||||
binpack.PackValue(&buf, []byte(state.Error))
|
||||
hvalRaw := sha1.Sum(buf.Bytes())
|
||||
hval := base64.StdEncoding.EncodeToString(hvalRaw[:])
|
||||
return hval, buf.Bytes()
|
||||
}
|
||||
|
||||
func (state ShellState) MarshalJSON() ([]byte, error) {
|
||||
_, encodedState := state.EncodeAndHash()
|
||||
return json.Marshal(encodedState)
|
||||
}
|
||||
|
||||
func (state *ShellState) UnmarshalJSON(jsonBytes []byte) error {
|
||||
var barr []byte
|
||||
err := json.Unmarshal(jsonBytes, &barr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buf := bytes.NewBuffer(barr)
|
||||
u := binpack.MakeUnpacker(buf)
|
||||
version := u.UnpackInt("ShellState pack version")
|
||||
if version != ShellStatePackVersion {
|
||||
return fmt.Errorf("invalid ShellState pack version: %d", version)
|
||||
}
|
||||
state.Version = string(u.UnpackValue("ShellState.Version"))
|
||||
state.Cwd = string(u.UnpackValue("ShellState.Cwd"))
|
||||
state.ShellVars = u.UnpackValue("ShellState.ShellVars")
|
||||
state.Aliases = string(u.UnpackValue("ShellState.Aliases"))
|
||||
state.Funcs = string(u.UnpackValue("ShellState.Funcs"))
|
||||
state.Error = string(u.UnpackValue("ShellState.Error"))
|
||||
return u.Error()
|
||||
}
|
||||
|
||||
func (sdiff ShellStateDiff) MarshalJSON() ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
binpack.PackInt(&buf, ShellStateDiffPackVersion)
|
||||
binpack.PackValue(&buf, []byte(sdiff.Version))
|
||||
binpack.PackValue(&buf, []byte(sdiff.BaseHash))
|
||||
binpack.PackValue(&buf, []byte(sdiff.Cwd))
|
||||
binpack.PackValue(&buf, sdiff.VarsDiff)
|
||||
binpack.PackValue(&buf, sdiff.AliasesDiff)
|
||||
binpack.PackValue(&buf, sdiff.FuncsDiff)
|
||||
binpack.PackValue(&buf, []byte(sdiff.Error))
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (sdiff *ShellStateDiff) UnmarshalJSON(jsonBytes []byte) error {
|
||||
var barr []byte
|
||||
err := json.Unmarshal(jsonBytes, &barr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buf := bytes.NewBuffer(barr)
|
||||
u := binpack.MakeUnpacker(buf)
|
||||
version := u.UnpackInt("ShellState pack version")
|
||||
if version != ShellStateDiffPackVersion {
|
||||
return fmt.Errorf("invalid ShellStateDiff pack version: %d", version)
|
||||
}
|
||||
sdiff.Version = string(u.UnpackValue("ShellStateDiff.Version"))
|
||||
sdiff.BaseHash = string(u.UnpackValue("ShellStateDiff.BaseHash"))
|
||||
sdiff.Cwd = string(u.UnpackValue("ShellStateDiff.Cwd"))
|
||||
sdiff.VarsDiff = u.UnpackValue("ShellStateDiff.VarsDiff")
|
||||
sdiff.AliasesDiff = u.UnpackValue("ShellStateDiff.AliasesDiff")
|
||||
sdiff.FuncsDiff = u.UnpackValue("ShellStateDiff.FuncsDiff")
|
||||
sdiff.Error = string(u.UnpackValue("ShellStateDiff.Error"))
|
||||
return u.Error()
|
||||
}
|
||||
|
||||
func (sdiff ShellStateDiff) Dump() {
|
||||
fmt.Printf("ShellStateDiff:\n")
|
||||
fmt.Printf(" version: %s\n", sdiff.Version)
|
||||
fmt.Printf(" base: %s\n", sdiff.BaseHash)
|
||||
var mdiff statediff.MapDiffType
|
||||
err := mdiff.Decode(sdiff.VarsDiff)
|
||||
if err != nil {
|
||||
fmt.Printf(" vars: error[%s]\n", err.Error())
|
||||
} else {
|
||||
mdiff.Dump()
|
||||
}
|
||||
fmt.Printf(" aliases: %d, funcs: %d\n", len(sdiff.AliasesDiff), len(sdiff.FuncsDiff))
|
||||
if sdiff.Error != "" {
|
||||
fmt.Printf(" error: %s\n", sdiff.Error)
|
||||
}
|
||||
}
|
@ -14,6 +14,7 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/alessio/shellescape"
|
||||
"github.com/scripthaus-dev/mshell/pkg/base"
|
||||
@ -23,11 +24,13 @@ import (
|
||||
|
||||
// TODO create unblockable packet-sender (backed by an array) for clientproc
|
||||
type MServer struct {
|
||||
Lock *sync.Mutex
|
||||
MainInput *packet.PacketParser
|
||||
Sender *packet.PacketSender
|
||||
ClientMap map[base.CommandKey]*shexec.ClientProc
|
||||
Debug bool
|
||||
Lock *sync.Mutex
|
||||
MainInput *packet.PacketParser
|
||||
Sender *packet.PacketSender
|
||||
ClientMap map[base.CommandKey]*shexec.ClientProc
|
||||
Debug bool
|
||||
StateMap map[string]*packet.ShellState // sha1->state
|
||||
CurrentState string // sha1
|
||||
}
|
||||
|
||||
func (m *MServer) Close() {
|
||||
@ -38,7 +41,7 @@ func (m *MServer) Close() {
|
||||
func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) {
|
||||
ck := pk.GetCK()
|
||||
if ck == "" {
|
||||
m.Sender.SendMessage(fmt.Sprintf("received '%s' packet without ck", pk.GetType()))
|
||||
m.Sender.SendMessageFmt("received '%s' packet without ck", pk.GetType())
|
||||
return
|
||||
}
|
||||
m.Lock.Lock()
|
||||
@ -137,12 +140,24 @@ func (m *MServer) runCompGen(compPk *packet.CompGenPacketType) {
|
||||
return
|
||||
}
|
||||
|
||||
func (m *MServer) setCurrentState(state *packet.ShellState) {
|
||||
if state == nil {
|
||||
return
|
||||
}
|
||||
hval, _ := state.EncodeAndHash()
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
m.StateMap[hval] = state
|
||||
m.CurrentState = hval
|
||||
}
|
||||
|
||||
func (m *MServer) reinit(reqId string) {
|
||||
initPk, err := shexec.MakeServerInitPacket()
|
||||
if err != nil {
|
||||
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error creating init packet: %w", err))
|
||||
return
|
||||
}
|
||||
m.setCurrentState(initPk.State)
|
||||
initPk.RespId = reqId
|
||||
m.Sender.SendPacket(initPk)
|
||||
}
|
||||
@ -170,6 +185,32 @@ func (m *MServer) ProcessRpcPacket(pk packet.RpcPacketType) {
|
||||
return
|
||||
}
|
||||
|
||||
func (m *MServer) getCurrentState() (string, *packet.ShellState) {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
return m.CurrentState, m.StateMap[m.CurrentState]
|
||||
}
|
||||
|
||||
func (m *MServer) clientPacketCallback(pk packet.PacketType) {
|
||||
if pk.GetType() != packet.CmdDonePacketStr {
|
||||
return
|
||||
}
|
||||
donePk := pk.(*packet.CmdDonePacketType)
|
||||
if donePk.FinalState == nil {
|
||||
return
|
||||
}
|
||||
stateHash, curState := m.getCurrentState()
|
||||
if curState == nil {
|
||||
return
|
||||
}
|
||||
diff, err := shexec.MakeShellStateDiff(*curState, stateHash, *donePk.FinalState)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
donePk.FinalState = nil
|
||||
donePk.FinalStateDiff = &diff
|
||||
}
|
||||
|
||||
func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
|
||||
if err := runPacket.CK.Validate("packet"); err != nil {
|
||||
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err))
|
||||
@ -190,16 +231,34 @@ func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
|
||||
m.Lock.Unlock()
|
||||
go func() {
|
||||
defer func() {
|
||||
r := recover()
|
||||
finalPk := packet.MakeCmdFinalPacket(runPacket.CK)
|
||||
finalPk.Ts = time.Now().UnixMilli()
|
||||
if r != nil {
|
||||
finalPk.Error = fmt.Sprintf("%s", r)
|
||||
}
|
||||
m.Sender.SendPacket(finalPk)
|
||||
m.Lock.Lock()
|
||||
delete(m.ClientMap, runPacket.CK)
|
||||
m.Lock.Unlock()
|
||||
cproc.Close()
|
||||
}()
|
||||
shexec.SendRunPacketAndRunData(context.Background(), cproc.Input, runPacket)
|
||||
cproc.ProxySingleOutput(runPacket.CK, m.Sender)
|
||||
cproc.ProxySingleOutput(runPacket.CK, m.Sender, m.clientPacketCallback)
|
||||
}()
|
||||
}
|
||||
|
||||
func (m *MServer) packetSenderErrorHandler(sender *packet.PacketSender, pk packet.PacketType, err error) {
|
||||
if serr, ok := err.(*packet.SendError); ok && serr.IsMarshalError {
|
||||
msg := packet.MakeMessagePacket(err.Error())
|
||||
if cpk, ok := pk.(packet.CommandPacketType); ok {
|
||||
msg.CK = cpk.GetCK()
|
||||
}
|
||||
sender.SendPacket(msg)
|
||||
}
|
||||
// otherwise ignore (we can't output anything for a I/O error)
|
||||
}
|
||||
|
||||
func RunServer() (int, error) {
|
||||
debug := false
|
||||
if len(os.Args) >= 3 && os.Args[2] == "--debug" {
|
||||
@ -208,19 +267,21 @@ func RunServer() (int, error) {
|
||||
server := &MServer{
|
||||
Lock: &sync.Mutex{},
|
||||
ClientMap: make(map[base.CommandKey]*shexec.ClientProc),
|
||||
StateMap: make(map[string]*packet.ShellState),
|
||||
Debug: debug,
|
||||
}
|
||||
if debug {
|
||||
packet.GlobalDebug = true
|
||||
}
|
||||
server.MainInput = packet.MakePacketParser(os.Stdin)
|
||||
server.Sender = packet.MakePacketSender(os.Stdout)
|
||||
server.Sender = packet.MakePacketSender(os.Stdout, server.packetSenderErrorHandler)
|
||||
defer server.Close()
|
||||
var err error
|
||||
initPacket, err := shexec.MakeServerInitPacket()
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
server.setCurrentState(initPacket.State)
|
||||
server.Sender.SendPacket(initPacket)
|
||||
builder := packet.MakeRunPacketBuilder()
|
||||
for pk := range server.MainInput.MainCh {
|
||||
@ -243,7 +304,7 @@ func RunServer() (int, error) {
|
||||
server.ProcessRpcPacket(rpcPk)
|
||||
continue
|
||||
}
|
||||
server.Sender.SendMessage(fmt.Sprintf("invalid packet '%s' sent to mshell server", packet.AsString(pk)))
|
||||
server.Sender.SendMessageFmt("invalid packet '%s' sent to mshell server", packet.AsString(pk))
|
||||
continue
|
||||
}
|
||||
return 0, nil
|
||||
|
@ -46,7 +46,7 @@ func MakeClientProc(ctx context.Context, ecmd *exec.Cmd) (*ClientProc, *packet.I
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("running local client: %w", err)
|
||||
}
|
||||
sender := packet.MakePacketSender(inputWriter)
|
||||
sender := packet.MakePacketSender(inputWriter, nil)
|
||||
stdoutPacketParser := packet.MakePacketParser(stdoutReader)
|
||||
stderrPacketParser := packet.MakePacketParser(stderrReader)
|
||||
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser)
|
||||
@ -108,9 +108,12 @@ func (cproc *ClientProc) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (cproc *ClientProc) ProxySingleOutput(ck base.CommandKey, sender *packet.PacketSender) {
|
||||
func (cproc *ClientProc) ProxySingleOutput(ck base.CommandKey, sender *packet.PacketSender, packetCallback func(packet.PacketType)) {
|
||||
sentDonePk := false
|
||||
for pk := range cproc.Output.MainCh {
|
||||
if packetCallback != nil {
|
||||
packetCallback(pk)
|
||||
}
|
||||
if pk.GetType() == packet.CmdDonePacketStr {
|
||||
sentDonePk = true
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/alessio/shellescape"
|
||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||
"github.com/scripthaus-dev/mshell/pkg/statediff"
|
||||
"mvdan.cc/sh/v3/expand"
|
||||
"mvdan.cc/sh/v3/syntax"
|
||||
)
|
||||
@ -189,6 +190,32 @@ func ParseDeclLine(envLine string) *DeclareDeclType {
|
||||
}
|
||||
}
|
||||
|
||||
func parseDeclLineToKV(envLine string) (string, string) {
|
||||
eqIdx := strings.Index(envLine, "=")
|
||||
if eqIdx == -1 {
|
||||
return "", ""
|
||||
}
|
||||
namePart := envLine[0:eqIdx]
|
||||
valPart := envLine[eqIdx+1:]
|
||||
return namePart, valPart
|
||||
}
|
||||
|
||||
func shellStateVarsToMap(shellVars []byte) map[string]string {
|
||||
if len(shellVars) == 0 {
|
||||
return nil
|
||||
}
|
||||
rtn := make(map[string]string)
|
||||
vars := bytes.Split(shellVars, []byte{0})
|
||||
for _, varLine := range vars {
|
||||
name, val := parseDeclLineToKV(string(varLine))
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
rtn[name] = val
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func DeclMapFromState(state *packet.ShellState) map[string]*DeclareDeclType {
|
||||
if state == nil {
|
||||
return nil
|
||||
@ -485,3 +512,24 @@ func DeclsEqual(compareName bool, d1 *DeclareDeclType, d2 *DeclareDeclType) bool
|
||||
}
|
||||
return d1.Value == d2.Value // this works even for assoc arrays because we normalize them when parsing
|
||||
}
|
||||
|
||||
func MakeShellStateDiff(oldState packet.ShellState, oldStateHash string, newState packet.ShellState) (packet.ShellStateDiff, error) {
|
||||
var rtn packet.ShellStateDiff
|
||||
rtn.BaseHash = oldStateHash
|
||||
if oldState.Version != newState.Version {
|
||||
return rtn, fmt.Errorf("cannot diff, states have different versions")
|
||||
}
|
||||
rtn.Version = newState.Version
|
||||
if oldState.Cwd != newState.Cwd {
|
||||
rtn.Cwd = newState.Cwd
|
||||
}
|
||||
if oldState.Error != newState.Error {
|
||||
rtn.Error = newState.Error
|
||||
}
|
||||
oldVars := shellStateVarsToMap(oldState.ShellVars)
|
||||
newVars := shellStateVarsToMap(newState.ShellVars)
|
||||
rtn.VarsDiff = statediff.MakeMapDiff(oldVars, newVars)
|
||||
rtn.AliasesDiff = statediff.MakeLineDiff(oldState.Aliases, newState.Aliases)
|
||||
rtn.FuncsDiff = statediff.MakeLineDiff(oldState.Funcs, newState.Funcs)
|
||||
return rtn, nil
|
||||
}
|
||||
|
@ -834,7 +834,7 @@ func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, fdContext FdCon
|
||||
stdoutPacketParser := packet.MakePacketParser(stdoutReader)
|
||||
stderrPacketParser := packet.MakePacketParser(stderrReader)
|
||||
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser)
|
||||
sender := packet.MakePacketSender(inputWriter)
|
||||
sender := packet.MakePacketSender(inputWriter, nil)
|
||||
versionOk := false
|
||||
for pk := range packetParser.MainCh {
|
||||
if pk.GetType() == packet.RawPacketStr {
|
||||
@ -986,14 +986,6 @@ func makeRcFileStr(pk *packet.RunPacketType) string {
|
||||
rcBuf.WriteString(pk.State.Aliases)
|
||||
rcBuf.WriteString("\n")
|
||||
}
|
||||
if pk.ReturnState {
|
||||
rcBuf.WriteString(`
|
||||
_scripthaus_exittrap () {
|
||||
` + GetShellStateCmd + `
|
||||
}
|
||||
trap _scripthaus_exittrap EXIT
|
||||
`)
|
||||
}
|
||||
return rcBuf.String()
|
||||
}
|
||||
|
||||
@ -1285,7 +1277,7 @@ func RunCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("cannot open runout file '%s': %w", fileNames.RunnerOutFile, err)
|
||||
}
|
||||
cmd.DetachedOutput = packet.MakePacketSender(cmd.RunnerOutFd)
|
||||
cmd.DetachedOutput = packet.MakePacketSender(cmd.RunnerOutFd, nil)
|
||||
ecmd, err := MakeDetachedExecCmd(pk, cmdTty)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
@ -19,7 +19,7 @@ type LineDiffType struct {
|
||||
NewData []string
|
||||
}
|
||||
|
||||
func (diff LineDiffType) dump() {
|
||||
func (diff LineDiffType) Dump() {
|
||||
fmt.Printf("DIFF:\n")
|
||||
pos := 1
|
||||
for _, entry := range diff.Lines {
|
||||
@ -68,7 +68,7 @@ func putUVarint(buf *bytes.Buffer, viBuf []byte, ival int) {
|
||||
// simple encoding
|
||||
// write varints. first version, then len, then len-number-of-varints, then fill the rest with newdata
|
||||
// [version] [len-varint] [varint]xlen... newdata (bytes)
|
||||
func (diff LineDiffType) encode() []byte {
|
||||
func (diff LineDiffType) Encode() []byte {
|
||||
var buf bytes.Buffer
|
||||
viBuf := make([]byte, binary.MaxVarintLen64)
|
||||
putUVarint(&buf, viBuf, LineDiffVersion)
|
||||
@ -86,7 +86,7 @@ func (diff LineDiffType) encode() []byte {
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func (rtn *LineDiffType) decode(diffBytes []byte) error {
|
||||
func (rtn *LineDiffType) Decode(diffBytes []byte) error {
|
||||
r := bytes.NewBuffer(diffBytes)
|
||||
version, err := binary.ReadUvarint(r)
|
||||
if err != nil {
|
||||
@ -164,12 +164,12 @@ func MakeLineDiff(str1 string, str2 string) []byte {
|
||||
str1Arr := strings.Split(str1, "\n")
|
||||
str2Arr := strings.Split(str2, "\n")
|
||||
diff := makeLineDiff(str1Arr, str2Arr)
|
||||
return diff.encode()
|
||||
return diff.Encode()
|
||||
}
|
||||
|
||||
func ApplyLineDiff(str1 string, diffBytes []byte) (string, error) {
|
||||
var diff LineDiffType
|
||||
err := diff.decode(diffBytes)
|
||||
err := diff.Decode(diffBytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ type MapDiffType struct {
|
||||
ToRemove []string
|
||||
}
|
||||
|
||||
func (diff MapDiffType) dump() {
|
||||
func (diff MapDiffType) Dump() {
|
||||
fmt.Printf("VAR-DIFF\n")
|
||||
for name, val := range diff.ToAdd {
|
||||
fmt.Printf(" add: %s=%s\n", name, val)
|
||||
@ -58,7 +58,7 @@ func (diff MapDiffType) apply(oldMap map[string]string) map[string]string {
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (diff MapDiffType) encode() []byte {
|
||||
func (diff MapDiffType) Encode() []byte {
|
||||
var buf bytes.Buffer
|
||||
viBuf := make([]byte, binary.MaxVarintLen64)
|
||||
putUVarint(&buf, viBuf, MapDiffVersion)
|
||||
@ -76,7 +76,7 @@ func (diff MapDiffType) encode() []byte {
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func (diff *MapDiffType) decode(diffBytes []byte) error {
|
||||
func (diff *MapDiffType) Decode(diffBytes []byte) error {
|
||||
r := bytes.NewBuffer(diffBytes)
|
||||
version, err := binary.ReadUvarint(r)
|
||||
if err != nil {
|
||||
@ -111,12 +111,12 @@ func (diff *MapDiffType) decode(diffBytes []byte) error {
|
||||
|
||||
func MakeMapDiff(m1 map[string]string, m2 map[string]string) []byte {
|
||||
diff := makeMapDiff(m1, m2)
|
||||
return diff.encode()
|
||||
return diff.Encode()
|
||||
}
|
||||
|
||||
func ApplyMapDiff(oldMap map[string]string, diffBytes []byte) (map[string]string, error) {
|
||||
var diff MapDiffType
|
||||
err := diff.decode(diffBytes)
|
||||
err := diff.Decode(diffBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -47,7 +47,7 @@ func testLineDiff(t *testing.T, str1 string, str2 string) {
|
||||
t.Errorf("bad diff output")
|
||||
}
|
||||
var dt LineDiffType
|
||||
err = dt.decode(diffBytes)
|
||||
err = dt.Decode(diffBytes)
|
||||
if err != nil {
|
||||
t.Errorf("error decoding diff: %v\n", err)
|
||||
}
|
||||
@ -86,8 +86,8 @@ func TestMapDiff(t *testing.T) {
|
||||
diffBytes := MakeMapDiff(m1, m2)
|
||||
fmt.Printf("mapdifflen: %d\n", len(diffBytes))
|
||||
var diff MapDiffType
|
||||
diff.decode(diffBytes)
|
||||
diff.dump()
|
||||
diff.Decode(diffBytes)
|
||||
diff.Dump()
|
||||
mcheck, err := ApplyMapDiff(m1, diffBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("error applying map diff: %v", err)
|
||||
|
Loading…
Reference in New Issue
Block a user