Packetparser Ignore Spurious Invalid Input at Beginning of Stream (#140)

* take a stab at fixing #99. ignore invalid output before we see a real packet.  the complication here was ensuring we always output a real packet in every flow so we don't actually lose valid errors.

* add ping packets to prime the parser (when in ignoreUntilValid mode)
This commit is contained in:
Mike Sawka 2023-12-18 12:42:40 -08:00 committed by GitHub
parent ce252d479b
commit a639d72e30
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 35 additions and 14 deletions

View File

@ -156,7 +156,7 @@ func readFullRunPacket(packetParser *packet.PacketParser) (*packet.RunPacketType
} }
func handleSingle(fromServer bool) { func handleSingle(fromServer bool) {
packetParser := packet.MakePacketParser(os.Stdin, false) packetParser := packet.MakePacketParser(os.Stdin, nil)
sender := packet.MakePacketSender(os.Stdout, nil) sender := packet.MakePacketSender(os.Stdout, nil)
defer func() { defer func() {
sender.Close() sender.Close()

View File

@ -177,13 +177,22 @@ func (p *PacketParser) SetErr(err error) {
} }
} }
func MakePacketParser(input io.Reader, rpcHandler bool) *PacketParser { type PacketParserOpts struct {
RpcHandler bool
IgnoreUntilValid bool
}
func MakePacketParser(input io.Reader, opts *PacketParserOpts) *PacketParser {
if opts == nil {
opts = &PacketParserOpts{}
}
parser := &PacketParser{ parser := &PacketParser{
Lock: &sync.Mutex{}, Lock: &sync.Mutex{},
MainCh: make(chan PacketType), MainCh: make(chan PacketType),
RpcMap: make(map[string]*RpcEntry), RpcMap: make(map[string]*RpcEntry),
RpcHandler: rpcHandler, RpcHandler: opts.RpcHandler,
} }
ignoreUntilValid := opts.IgnoreUntilValid
bufReader := bufio.NewReader(input) bufReader := bufio.NewReader(input)
go func() { go func() {
defer func() { defer func() {
@ -204,11 +213,15 @@ func MakePacketParser(input io.Reader, rpcHandler bool) *PacketParser {
// ##[len][json]\n // ##[len][json]\n
// ##14{"hello":true}\n // ##14{"hello":true}\n
// ##N{...} // ##N{...}
hasPrefix := strings.HasPrefix(line, "##")
bracePos := strings.Index(line, "{") bracePos := strings.Index(line, "{")
if !strings.HasPrefix(line, "##") || bracePos == -1 { if !hasPrefix || bracePos == -1 {
parser.MainCh <- MakeRawPacket(line[:len(line)-1]) if !ignoreUntilValid {
parser.MainCh <- MakeRawPacket(line[:len(line)-1])
}
continue continue
} }
ignoreUntilValid = false
packetLen := -1 packetLen := -1
if line[2:bracePos] != "N" { if line[2:bracePos] != "N" {
packetLen, err = strconv.Atoi(line[2:bracePos]) packetLen, err = strconv.Atoi(line[2:bracePos])

View File

@ -717,7 +717,7 @@ func RunServer() (int, error) {
if debug { if debug {
packet.GlobalDebug = true packet.GlobalDebug = true
} }
server.MainInput = packet.MakePacketParser(os.Stdin, false) server.MainInput = packet.MakePacketParser(os.Stdin, nil)
server.Sender = packet.MakePacketSender(os.Stdout, server.packetSenderErrorHandler) server.Sender = packet.MakePacketSender(os.Stdout, server.packetSenderErrorHandler)
defer server.Close() defer server.Close()
var err error var err error

View File

@ -50,8 +50,8 @@ func MakeClientProc(ctx context.Context, ecmd *exec.Cmd) (*ClientProc, *packet.I
return nil, nil, fmt.Errorf("running local client: %w", err) return nil, nil, fmt.Errorf("running local client: %w", err)
} }
sender := packet.MakePacketSender(inputWriter, nil) sender := packet.MakePacketSender(inputWriter, nil)
stdoutPacketParser := packet.MakePacketParser(stdoutReader, false) stdoutPacketParser := packet.MakePacketParser(stdoutReader, &packet.PacketParserOpts{IgnoreUntilValid: true})
stderrPacketParser := packet.MakePacketParser(stderrReader, false) stderrPacketParser := packet.MakePacketParser(stderrReader, nil)
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, true) packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, true)
cproc := &ClientProc{ cproc := &ClientProc{
Cmd: ecmd, Cmd: ecmd,

View File

@ -754,7 +754,7 @@ func RunInstallFromCmd(ctx context.Context, ecmd *exec.Cmd, tryDetect bool, mshe
if mshellStream != nil { if mshellStream != nil {
sendMShellBinary(inputWriter, mshellStream) sendMShellBinary(inputWriter, mshellStream)
} }
packetParser := packet.MakePacketParser(stdoutReader, false) packetParser := packet.MakePacketParser(stdoutReader, nil)
err = ecmd.Start() err = ecmd.Start()
if err != nil { if err != nil {
return fmt.Errorf("running ssh command: %w", err) return fmt.Errorf("running ssh command: %w", err)
@ -887,8 +887,8 @@ func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, fdContext FdCon
return nil, fmt.Errorf("running ssh command: %w", err) return nil, fmt.Errorf("running ssh command: %w", err)
} }
defer cmd.Close() defer cmd.Close()
stdoutPacketParser := packet.MakePacketParser(stdoutReader, false) stdoutPacketParser := packet.MakePacketParser(stdoutReader, nil)
stderrPacketParser := packet.MakePacketParser(stderrReader, false) stderrPacketParser := packet.MakePacketParser(stderrReader, nil)
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, false) packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, false)
sender := packet.MakePacketSender(inputWriter, nil) sender := packet.MakePacketSender(inputWriter, nil)
versionOk := false versionOk := false

View File

@ -43,6 +43,11 @@ const RemoteTermCols = 80
const PtyReadBufSize = 100 const PtyReadBufSize = 100
const RemoteConnectTimeout = 15 * time.Second const RemoteConnectTimeout = 15 * time.Second
// we add this ping packet to the MShellServer Commands in order to deal with spurious SSH output
// basically we guarantee the parser will see a valid packet (either an init error or a ping)
// so we can pass ignoreUntilValid to PacketParser
const PrintPingPacket = `printf "\n##N{\"type\": \"ping\"}\n"`
const MShellServerCommandFmt = ` const MShellServerCommandFmt = `
PATH=$PATH:~/.mshell; PATH=$PATH:~/.mshell;
which mshell-[%VERSION%] > /dev/null; which mshell-[%VERSION%] > /dev/null;
@ -50,6 +55,7 @@ if [[ "$?" -ne 0 ]]
then then
printf "\n##N{\"type\": \"init\", \"notfound\": true, \"uname\": \"%s | %s\"}\n" "$(uname -s)" "$(uname -m)" printf "\n##N{\"type\": \"init\", \"notfound\": true, \"uname\": \"%s | %s\"}\n" "$(uname -s)" "$(uname -m)"
else else
[%PINGPACKET%]
mshell-[%VERSION%] --server mshell-[%VERSION%] --server
fi fi
` `
@ -60,14 +66,16 @@ func MakeLocalMShellCommandStr(isSudo bool) (string, error) {
return "", err return "", err
} }
if isSudo { if isSudo {
return fmt.Sprintf("sudo %s --server", mshellPath), nil return fmt.Sprintf(`%s; sudo %s --server`, PrintPingPacket, mshellPath), nil
} else { } else {
return fmt.Sprintf("%s --server", mshellPath), nil return fmt.Sprintf(`%s; %s --server`, PrintPingPacket, mshellPath), nil
} }
} }
func MakeServerCommandStr() string { func MakeServerCommandStr() string {
return strings.ReplaceAll(MShellServerCommandFmt, "[%VERSION%]", semver.MajorMinor(scbase.MShellVersion)) rtn := strings.ReplaceAll(MShellServerCommandFmt, "[%VERSION%]", semver.MajorMinor(scbase.MShellVersion))
rtn = strings.ReplaceAll(rtn, "[%PINGPACKET%]", PrintPingPacket)
return rtn
} }
const ( const (