Packetparser Ignore Spurious Invalid Input at Beginning of Stream (#140)

* take a stab at fixing #99. ignore invalid output before we see a real packet.  the complication here was ensuring we always output a real packet in every flow so we don't actually lose valid errors.

* add ping packets to prime the parser (when in ignoreUntilValid mode)
This commit is contained in:
Mike Sawka 2023-12-18 12:42:40 -08:00 committed by GitHub
parent ce252d479b
commit a639d72e30
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 35 additions and 14 deletions

View File

@ -156,7 +156,7 @@ func readFullRunPacket(packetParser *packet.PacketParser) (*packet.RunPacketType
}
func handleSingle(fromServer bool) {
packetParser := packet.MakePacketParser(os.Stdin, false)
packetParser := packet.MakePacketParser(os.Stdin, nil)
sender := packet.MakePacketSender(os.Stdout, nil)
defer func() {
sender.Close()

View File

@ -177,13 +177,22 @@ func (p *PacketParser) SetErr(err error) {
}
}
func MakePacketParser(input io.Reader, rpcHandler bool) *PacketParser {
type PacketParserOpts struct {
RpcHandler bool
IgnoreUntilValid bool
}
func MakePacketParser(input io.Reader, opts *PacketParserOpts) *PacketParser {
if opts == nil {
opts = &PacketParserOpts{}
}
parser := &PacketParser{
Lock: &sync.Mutex{},
MainCh: make(chan PacketType),
RpcMap: make(map[string]*RpcEntry),
RpcHandler: rpcHandler,
RpcHandler: opts.RpcHandler,
}
ignoreUntilValid := opts.IgnoreUntilValid
bufReader := bufio.NewReader(input)
go func() {
defer func() {
@ -204,11 +213,15 @@ func MakePacketParser(input io.Reader, rpcHandler bool) *PacketParser {
// ##[len][json]\n
// ##14{"hello":true}\n
// ##N{...}
hasPrefix := strings.HasPrefix(line, "##")
bracePos := strings.Index(line, "{")
if !strings.HasPrefix(line, "##") || bracePos == -1 {
parser.MainCh <- MakeRawPacket(line[:len(line)-1])
if !hasPrefix || bracePos == -1 {
if !ignoreUntilValid {
parser.MainCh <- MakeRawPacket(line[:len(line)-1])
}
continue
}
ignoreUntilValid = false
packetLen := -1
if line[2:bracePos] != "N" {
packetLen, err = strconv.Atoi(line[2:bracePos])

View File

@ -717,7 +717,7 @@ func RunServer() (int, error) {
if debug {
packet.GlobalDebug = true
}
server.MainInput = packet.MakePacketParser(os.Stdin, false)
server.MainInput = packet.MakePacketParser(os.Stdin, nil)
server.Sender = packet.MakePacketSender(os.Stdout, server.packetSenderErrorHandler)
defer server.Close()
var err error

View File

@ -50,8 +50,8 @@ func MakeClientProc(ctx context.Context, ecmd *exec.Cmd) (*ClientProc, *packet.I
return nil, nil, fmt.Errorf("running local client: %w", err)
}
sender := packet.MakePacketSender(inputWriter, nil)
stdoutPacketParser := packet.MakePacketParser(stdoutReader, false)
stderrPacketParser := packet.MakePacketParser(stderrReader, false)
stdoutPacketParser := packet.MakePacketParser(stdoutReader, &packet.PacketParserOpts{IgnoreUntilValid: true})
stderrPacketParser := packet.MakePacketParser(stderrReader, nil)
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, true)
cproc := &ClientProc{
Cmd: ecmd,

View File

@ -754,7 +754,7 @@ func RunInstallFromCmd(ctx context.Context, ecmd *exec.Cmd, tryDetect bool, mshe
if mshellStream != nil {
sendMShellBinary(inputWriter, mshellStream)
}
packetParser := packet.MakePacketParser(stdoutReader, false)
packetParser := packet.MakePacketParser(stdoutReader, nil)
err = ecmd.Start()
if err != nil {
return fmt.Errorf("running ssh command: %w", err)
@ -887,8 +887,8 @@ func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, fdContext FdCon
return nil, fmt.Errorf("running ssh command: %w", err)
}
defer cmd.Close()
stdoutPacketParser := packet.MakePacketParser(stdoutReader, false)
stderrPacketParser := packet.MakePacketParser(stderrReader, false)
stdoutPacketParser := packet.MakePacketParser(stdoutReader, nil)
stderrPacketParser := packet.MakePacketParser(stderrReader, nil)
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, false)
sender := packet.MakePacketSender(inputWriter, nil)
versionOk := false

View File

@ -43,6 +43,11 @@ const RemoteTermCols = 80
const PtyReadBufSize = 100
const RemoteConnectTimeout = 15 * time.Second
// we add this ping packet to the MShellServer Commands in order to deal with spurious SSH output
// basically we guarantee the parser will see a valid packet (either an init error or a ping)
// so we can pass ignoreUntilValid to PacketParser
const PrintPingPacket = `printf "\n##N{\"type\": \"ping\"}\n"`
const MShellServerCommandFmt = `
PATH=$PATH:~/.mshell;
which mshell-[%VERSION%] > /dev/null;
@ -50,6 +55,7 @@ if [[ "$?" -ne 0 ]]
then
printf "\n##N{\"type\": \"init\", \"notfound\": true, \"uname\": \"%s | %s\"}\n" "$(uname -s)" "$(uname -m)"
else
[%PINGPACKET%]
mshell-[%VERSION%] --server
fi
`
@ -60,14 +66,16 @@ func MakeLocalMShellCommandStr(isSudo bool) (string, error) {
return "", err
}
if isSudo {
return fmt.Sprintf("sudo %s --server", mshellPath), nil
return fmt.Sprintf(`%s; sudo %s --server`, PrintPingPacket, mshellPath), nil
} else {
return fmt.Sprintf("%s --server", mshellPath), nil
return fmt.Sprintf(`%s; %s --server`, PrintPingPacket, mshellPath), nil
}
}
func MakeServerCommandStr() string {
return strings.ReplaceAll(MShellServerCommandFmt, "[%VERSION%]", semver.MajorMinor(scbase.MShellVersion))
rtn := strings.ReplaceAll(MShellServerCommandFmt, "[%VERSION%]", semver.MajorMinor(scbase.MShellVersion))
rtn = strings.ReplaceAll(rtn, "[%PINGPACKET%]", PrintPingPacket)
return rtn
}
const (