Handle legacy server list ping for handshake

This commit is contained in:
Geoff Bourne 2019-07-13 15:40:34 -05:00
parent 3699931af0
commit 4c99daafa3
4 changed files with 300 additions and 78 deletions

View File

@ -1,20 +1,38 @@
package mcproto package mcproto
import ( import (
"bufio"
"bytes" "bytes"
"errors" "encoding/binary"
"github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/text/encoding/unicode"
"golang.org/x/text/transform"
"io" "io"
"net" "net"
"strings" "strings"
"time" "time"
) )
func ReadPacket(reader io.Reader, addr net.Addr) (*Packet, error) { func ReadPacket(reader io.Reader, addr net.Addr, state State) (*Packet, error) {
logrus. logrus.
WithField("client", addr). WithField("client", addr).
Debug("Reading packet") Debug("Reading packet")
if state == StateHandshaking {
bufReader := bufio.NewReader(reader)
data, err := bufReader.Peek(1)
if err != nil {
return nil, err
}
if data[0] == PacketIdLegacyServerListPing {
return ReadLegacyServerListPing(bufReader, addr)
} else {
reader = bufReader
}
}
frame, err := ReadFrame(reader, addr) frame, err := ReadFrame(reader, addr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -38,6 +56,97 @@ func ReadPacket(reader io.Reader, addr net.Addr) (*Packet, error) {
return packet, nil return packet, nil
} }
func ReadLegacyServerListPing(reader *bufio.Reader, addr net.Addr) (*Packet, error) {
logrus.
WithField("client", addr).
Debug("Reading legacy server list ping")
packetId, err := reader.ReadByte()
if err != nil {
return nil, err
}
if packetId != PacketIdLegacyServerListPing {
return nil, errors.Errorf("expected legacy server listing ping packet ID, got %x", packetId)
}
payload, err := reader.ReadByte()
if err != nil {
return nil, err
}
if payload != 0x01 {
return nil, errors.Errorf("expected payload=1 from legacy server listing ping, got %x", payload)
}
packetIdForPluginMsg, err := reader.ReadByte()
if err != nil {
return nil, err
}
if packetIdForPluginMsg != 0xFA {
return nil, errors.Errorf("expected packetIdForPluginMsg=0xFA from legacy server listing ping, got %x", packetIdForPluginMsg)
}
messageNameShortLen, err := ReadUnsignedShort(reader)
if err != nil {
return nil, err
}
if messageNameShortLen != 11 {
return nil, errors.Errorf("expected messageNameShortLen=11 from legacy server listing ping, got %d", messageNameShortLen)
}
messageName, err := ReadUTF16BEString(reader, messageNameShortLen)
if messageName != "MC|PingHost" {
return nil, errors.Errorf("expected messageName=MC|PingHost, got %s", messageName)
}
remainingLen, err := ReadUnsignedShort(reader)
remainingReader := io.LimitReader(reader, int64(remainingLen))
protocolVersion, err := ReadByte(remainingReader)
if err != nil {
return nil, err
}
hostnameLen, err := ReadUnsignedShort(remainingReader)
if err != nil {
return nil, err
}
hostname, err := ReadUTF16BEString(remainingReader, hostnameLen)
if err != nil {
return nil, err
}
port, err := ReadUnsignedInt(remainingReader)
if err != nil {
return nil, err
}
return &Packet{
PacketID: PacketIdLegacyServerListPing,
Length: 0,
Data: &LegacyServerListPing{
ProtocolVersion: int(protocolVersion),
ServerAddress: hostname,
ServerPort: uint16(port),
},
}, nil
}
func ReadUTF16BEString(reader io.Reader, symbolLen uint16) (string, error) {
bsUtf16be := make([]byte, symbolLen*2)
_, err := io.ReadFull(reader, bsUtf16be)
if err != nil {
return "", err
}
result, _, err := transform.Bytes(unicode.UTF16(unicode.BigEndian, unicode.IgnoreBOM).NewDecoder(), bsUtf16be)
if err != nil {
return "", err
}
return string(result), nil
}
func ReadFrame(reader io.Reader, addr net.Addr) (*Frame, error) { func ReadFrame(reader io.Reader, addr net.Addr) (*Frame, error) {
logrus. logrus.
WithField("client", addr). WithField("client", addr).
@ -136,25 +245,43 @@ func ReadString(reader io.Reader) (string, error) {
return strBuilder.String(), nil return strBuilder.String(), nil
} }
func ReadByte(reader io.Reader) (byte, error) {
buf := make([]byte, 1)
_, err := reader.Read(buf)
if err != nil {
return 0, err
} else {
return buf[0], nil
}
}
func ReadUnsignedShort(reader io.Reader) (uint16, error) { func ReadUnsignedShort(reader io.Reader) (uint16, error) {
upper := make([]byte, 1) var value uint16
_, err := reader.Read(upper) err := binary.Read(reader, binary.BigEndian, &value)
if err != nil { if err != nil {
return 0, err return 0, err
} }
lower := make([]byte, 1) return value, nil
_, err = reader.Read(lower) }
func ReadUnsignedInt(reader io.Reader) (uint32, error) {
var value uint32
err := binary.Read(reader, binary.BigEndian, &value)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return value, nil
return (uint16(upper[0]) << 8) | uint16(lower[0]), nil
} }
func ReadHandshake(data []byte) (*Handshake, error) { func ReadHandshake(data interface{}) (*Handshake, error) {
dataBytes, ok := data.([]byte)
if !ok {
return nil, errors.New("data is not expected byte slice")
}
handshake := &Handshake{} handshake := &Handshake{}
buffer := bytes.NewBuffer(data) buffer := bytes.NewBuffer(dataBytes)
var err error var err error
handshake.ProtocolVersion, err = ReadVarInt(buffer) handshake.ProtocolVersion, err = ReadVarInt(buffer)

36
mcproto/read_test.go Normal file
View File

@ -0,0 +1,36 @@
package mcproto
import (
"bytes"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
)
func TestReadVarInt(t *testing.T) {
tests := []struct {
Name string
Input []byte
Expected int
}{
{
Name: "Single byte",
Input: []byte{0xFA, 0x00},
Expected: 0x7A,
},
{
Name: "Two byte",
Input: []byte{0x81, 0x04},
Expected: 0x0201,
},
}
for _, tt := range tests {
t.Run(tt.Name, func(t *testing.T) {
result, err := ReadVarInt(bytes.NewBuffer(tt.Input))
require.NoError(t, err)
assert.Equal(t, tt.Expected, result)
})
}
}

View File

@ -7,6 +7,12 @@ type Frame struct {
Payload []byte Payload []byte
} }
type State int
const (
StateHandshaking = iota
)
var trimLimit = 64 var trimLimit = 64
func trimBytes(data []byte) ([]byte, string) { func trimBytes(data []byte) ([]byte, string) {
@ -25,15 +31,23 @@ func (f *Frame) String() string {
type Packet struct { type Packet struct {
Length int Length int
PacketID int PacketID int
Data []byte // Data is either a byte slice of raw content or a parsed message
Data interface{}
} }
func (p *Packet) String() string { func (p *Packet) String() string {
trimmed, cont := trimBytes(p.Data) if dataBytes, ok := p.Data.([]byte); ok {
trimmed, cont := trimBytes(dataBytes)
return fmt.Sprintf("Frame:[len=%d, packetId=%d, data=%#X%s]", p.Length, p.PacketID, trimmed, cont) return fmt.Sprintf("Frame:[len=%d, packetId=%d, data=%#X%s]", p.Length, p.PacketID, trimmed, cont)
} else {
return fmt.Sprintf("Frame:[len=%d, packetId=%d, data=%+v]", p.Length, p.PacketID, p.Data)
}
} }
const PacketIdHandshake = 0x00 const (
PacketIdHandshake = 0x00
PacketIdLegacyServerListPing = 0xFE
)
type Handshake struct { type Handshake struct {
ProtocolVersion int ProtocolVersion int
@ -42,6 +56,12 @@ type Handshake struct {
NextState int NextState int
} }
type LegacyServerListPing struct {
ProtocolVersion int
ServerAddress string
ServerPort uint16
}
type ByteReader interface { type ByteReader interface {
ReadByte() (byte, error) ReadByte() (byte, error)
} }

View File

@ -12,7 +12,7 @@ import (
) )
const ( const (
handshakeTimeout = 2 * time.Second handshakeTimeout = 5 * time.Second
) )
var noDeadline time.Time var noDeadline time.Time
@ -21,9 +21,12 @@ type IConnector interface {
StartAcceptingConnections(ctx context.Context, listenAddress string, connRateLimit int) error StartAcceptingConnections(ctx context.Context, listenAddress string, connRateLimit int) error
} }
var Connector IConnector = &connectorImpl{} var Connector IConnector = &connectorImpl{
state: mcproto.StateHandshaking,
}
type connectorImpl struct { type connectorImpl struct {
state mcproto.State
} }
func (c *connectorImpl) StartAcceptingConnections(ctx context.Context, listenAddress string, connRateLimit int) error { func (c *connectorImpl) StartAcceptingConnections(ctx context.Context, listenAddress string, connRateLimit int) error {
@ -70,71 +73,65 @@ func (c *connectorImpl) HandleConnection(ctx context.Context, frontendConn net.C
logrus. logrus.
WithField("client", clientAddr). WithField("client", clientAddr).
Info("Got connection") Info("Got connection")
defer logrus.WithField("client", clientAddr).Debug("Closing frontend connection")
inspectionBuffer := new(bytes.Buffer) inspectionBuffer := new(bytes.Buffer)
inspectionReader := io.TeeReader(frontendConn, inspectionBuffer) inspectionReader := io.TeeReader(frontendConn, inspectionBuffer)
if err := frontendConn.SetReadDeadline(time.Now().Add(handshakeTimeout)); err != nil { /* if err := frontendConn.SetReadDeadline(time.Now().Add(handshakeTimeout)); err != nil {
logrus. logrus.
WithError(err). WithError(err).
WithField("client", clientAddr). WithField("client", clientAddr).
Error("Failed to set read deadline") Error("Failed to set read deadline")
return return
} }
packet, err := mcproto.ReadPacket(inspectionReader, clientAddr) */packet, err := mcproto.ReadPacket(inspectionReader, clientAddr, c.state)
if err != nil { if err != nil {
logrus.WithError(err).WithField("clientAddr", clientAddr).Error("Failed to read packet") logrus.WithError(err).WithField("clientAddr", clientAddr).Error("Failed to read packet")
return return
} }
logrus.WithFields(logrus.Fields{"length": packet.Length, "packetID": packet.PacketID}).Info("Got packet") logrus.
WithField("client", clientAddr).
WithField("length", packet.Length).
WithField("packetID", packet.PacketID).
Debug("Got packet")
if packet.PacketID == mcproto.PacketIdHandshake { if packet.PacketID == mcproto.PacketIdHandshake {
handshake, err := mcproto.ReadHandshake(packet.Data) handshake, err := mcproto.ReadHandshake(packet.Data)
if err != nil { if err != nil {
logrus.WithError(err).WithField("clientAddr", clientAddr).Error("Failed to read handshake") logrus.WithError(err).WithField("clientAddr", clientAddr).
Error("Failed to read handshake")
return return
} }
logrus.WithFields(logrus.Fields{
"protocolVersion": handshake.ProtocolVersion,
"server": handshake.ServerAddress,
"serverPort": handshake.ServerPort,
"nextState": handshake.NextState,
}).Info("Got handshake")
backendHostPort := Routes.FindBackendForServerAddress(handshake.ServerAddress)
if backendHostPort == "" {
logrus.WithField("serverAddress", handshake.ServerAddress).Warn("Unable to find registered backend")
return
}
logrus.WithField("backendHostPort", backendHostPort).Info("Connecting to backend")
backendConn, err := net.Dial("tcp", backendHostPort)
if err != nil {
logrus.WithError(err).WithFields(logrus.Fields{
"serverAddress": handshake.ServerAddress,
"backend": backendHostPort,
}).Warn("Unable to connect to backend")
return
}
amount, err := io.Copy(backendConn, inspectionBuffer)
if err != nil {
logrus.WithError(err).Error("Failed to write handshake to backend connection")
return
}
logrus.WithField("amount", amount).Debug("Relayed handshake to backend")
if err = frontendConn.SetReadDeadline(noDeadline); err != nil {
logrus. logrus.
WithError(err).
WithField("client", clientAddr). WithField("client", clientAddr).
Error("Failed to clear read deadline") WithField("handshake", handshake).
Debug("Got handshake")
serverAddress := handshake.ServerAddress
c.findAndConnectBackend(ctx, frontendConn, clientAddr, inspectionBuffer, serverAddress)
} else if packet.PacketID == mcproto.PacketIdLegacyServerListPing {
handshake, ok := packet.Data.(*mcproto.LegacyServerListPing)
if !ok {
logrus.
WithField("client", clientAddr).
WithField("packet", packet).
Warn("Unexpected data type for PacketIdLegacyServerListPing")
return return
} }
pumpConnections(ctx, frontendConn, backendConn)
logrus.
WithField("client", clientAddr).
WithField("handshake", handshake).
Debug("Got legacy server list ping")
serverAddress := handshake.ServerAddress
c.findAndConnectBackend(ctx, frontendConn, clientAddr, inspectionBuffer, serverAddress)
} else { } else {
logrus. logrus.
WithField("client", clientAddr). WithField("client", clientAddr).
@ -144,17 +141,58 @@ func (c *connectorImpl) HandleConnection(ctx context.Context, frontendConn net.C
} }
} }
func (c *connectorImpl) findAndConnectBackend(ctx context.Context, frontendConn net.Conn,
clientAddr net.Addr, preReadContent io.Reader, serverAddress string) {
backendHostPort := Routes.FindBackendForServerAddress(serverAddress)
if backendHostPort == "" {
logrus.WithField("serverAddress", serverAddress).Warn("Unable to find registered backend")
return
}
logrus.
WithField("client", clientAddr).
WithField("server", serverAddress).
WithField("backendHostPort", backendHostPort).
Info("Connecting to backend")
backendConn, err := net.Dial("tcp", backendHostPort)
if err != nil {
logrus.
WithError(err).
WithField("client", clientAddr).
WithField("serverAddress", serverAddress).
WithField("backend", backendHostPort).
Warn("Unable to connect to backend")
return
}
amount, err := io.Copy(backendConn, preReadContent)
if err != nil {
logrus.WithError(err).Error("Failed to write handshake to backend connection")
return
}
logrus.WithField("amount", amount).Debug("Relayed handshake to backend")
if err = frontendConn.SetReadDeadline(noDeadline); err != nil {
logrus.
WithError(err).
WithField("client", clientAddr).
Error("Failed to clear read deadline")
return
}
pumpConnections(ctx, frontendConn, backendConn)
return
}
func pumpConnections(ctx context.Context, frontendConn, backendConn net.Conn) { func pumpConnections(ctx context.Context, frontendConn, backendConn net.Conn) {
//noinspection GoUnhandledErrorResult //noinspection GoUnhandledErrorResult
defer backendConn.Close() defer backendConn.Close()
clientAddr := frontendConn.RemoteAddr() clientAddr := frontendConn.RemoteAddr()
defer logrus.WithField("client", clientAddr).Debug("Closing backend connection")
errors := make(chan error, 2) errors := make(chan error, 2)
go pumpFrames(backendConn, frontendConn, errors, "backend", "frontend", clientAddr) go pumpFrames(backendConn, frontendConn, errors, "backend", "frontend", clientAddr)
go pumpFrames(frontendConn, backendConn, errors, "frontend", "backend", clientAddr) go pumpFrames(frontendConn, backendConn, errors, "frontend", "backend", clientAddr)
for {
select { select {
case err := <-errors: case err := <-errors:
if err != io.EOF { if err != io.EOF {
@ -163,21 +201,22 @@ func pumpConnections(ctx context.Context, frontendConn, backendConn net.Conn) {
Error("Error observed on connection relay") Error("Error observed on connection relay")
} }
return
case <-ctx.Done(): case <-ctx.Done():
return logrus.Debug("Observed context cancellation")
}
} }
} }
func pumpFrames(incoming io.Reader, outgoing io.Writer, errors chan<- error, from, to string, clientAddr net.Addr) { func pumpFrames(incoming io.Reader, outgoing io.Writer, errors chan<- error, from, to string, clientAddr net.Addr) {
amount, err := io.Copy(outgoing, incoming) amount, err := io.Copy(outgoing, incoming)
if err != nil {
errors <- err
}
logrus. logrus.
WithField("client", clientAddr). WithField("client", clientAddr).
WithField("amount", amount). WithField("amount", amount).
Infof("Finished relay %s->%s", from, to) Infof("Finished relay %s->%s", from, to)
if err != nil {
errors <- err
} else {
// successful io.Copy return nil error, not EOF...to simulate that to trigger outer handling
errors <- io.EOF
}
} }