From 87d5a33c760227c18a75ef86ac5002c9136a8817 Mon Sep 17 00:00:00 2001 From: TheMode Date: Sat, 29 Jan 2022 10:18:23 +0100 Subject: [PATCH] Fix corruption when receiving multiple compressed packets (#611) Signed-off-by: TheMode --- .../player/PlayerSocketConnection.java | 30 ++++++------ .../minestom/server/utils/PacketUtils.java | 21 ++++----- .../server/network/SocketReadTest.java | 46 +++++++++++-------- 3 files changed, 48 insertions(+), 49 deletions(-) diff --git a/src/main/java/net/minestom/server/network/player/PlayerSocketConnection.java b/src/main/java/net/minestom/server/network/player/PlayerSocketConnection.java index 9f93b467c..912adde5d 100644 --- a/src/main/java/net/minestom/server/network/player/PlayerSocketConnection.java +++ b/src/main/java/net/minestom/server/network/player/PlayerSocketConnection.java @@ -99,23 +99,19 @@ public class PlayerSocketConnection extends PlayerConnection { } // Read all packets try { - var result = PacketUtils.readPackets(readBuffer, compressed); - this.cacheBuffer = result.remaining(); - for (var packet : result.packets()) { - var id = packet.id(); - var payload = packet.payload(); - try { - packetProcessor.process(this, id, payload); - } catch (Exception e) { - // Error while reading the packet - MinecraftServer.getExceptionManager().handleException(e); - break; - } finally { - if (payload.position() != payload.limit()) { - LOGGER.warn("WARNING: Packet 0x{} not fully read ({})", Integer.toHexString(id), payload); - } - } - } + this.cacheBuffer = PacketUtils.readPackets(readBuffer, compressed, + (id, payload) -> { + try { + packetProcessor.process(this, id, payload); + } catch (Exception e) { + // Error while reading the packet + MinecraftServer.getExceptionManager().handleException(e); + } finally { + if (payload.position() != payload.limit()) { + LOGGER.warn("WARNING: Packet 0x{} not fully read ({})", Integer.toHexString(id), payload); + } + } + }); } catch (DataFormatException e) { MinecraftServer.getExceptionManager().handleException(e); disconnect(); diff --git a/src/main/java/net/minestom/server/utils/PacketUtils.java b/src/main/java/net/minestom/server/utils/PacketUtils.java index 032edacfe..c69493c80 100644 --- a/src/main/java/net/minestom/server/utils/PacketUtils.java +++ b/src/main/java/net/minestom/server/utils/PacketUtils.java @@ -29,11 +29,10 @@ import org.jetbrains.annotations.Nullable; import java.nio.BufferUnderflowException; import java.nio.ByteBuffer; -import java.util.ArrayList; import java.util.Collection; -import java.util.List; import java.util.Objects; import java.util.concurrent.ConcurrentMap; +import java.util.function.BiConsumer; import java.util.function.Predicate; import java.util.zip.DataFormatException; import java.util.zip.Deflater; @@ -161,8 +160,8 @@ public final class PacketUtils { } @ApiStatus.Internal - public static ReadResult readPackets(@NotNull BinaryBuffer readBuffer, boolean compressed) throws DataFormatException { - List packets = new ArrayList<>(); + public static @Nullable BinaryBuffer readPackets(@NotNull BinaryBuffer readBuffer, boolean compressed, + BiConsumer payloadConsumer) throws DataFormatException { BinaryBuffer remaining = null; while (readBuffer.readableBytes() > 0) { final var beginMark = readBuffer.mark(); @@ -196,7 +195,11 @@ public final class PacketUtils { // Slice packet ByteBuffer payload = content.asByteBuffer(content.readerOffset(), decompressedSize); final int packetId = Utils.readVarInt(payload); - packets.add(new PacketPayload(packetId, payload)); + try { + payloadConsumer.accept(packetId, payload); + } catch (Exception e) { + // Empty + } // Position buffer to read the next packet readBuffer.readerOffset(readerStart + packetLength); } catch (BufferUnderflowException e) { @@ -205,13 +208,7 @@ public final class PacketUtils { break; } } - return new ReadResult(packets, remaining); - } - - public record ReadResult(List packets, BinaryBuffer remaining) { - } - - public record PacketPayload(int id, ByteBuffer payload) { + return remaining; } public static void writeFramedPacket(@NotNull ByteBuffer buffer, diff --git a/src/test/java/net/minestom/server/network/SocketReadTest.java b/src/test/java/net/minestom/server/network/SocketReadTest.java index 99b64c26f..2f5094440 100644 --- a/src/test/java/net/minestom/server/network/SocketReadTest.java +++ b/src/test/java/net/minestom/server/network/SocketReadTest.java @@ -1,5 +1,6 @@ package net.minestom.server.network; +import it.unimi.dsi.fastutil.Pair; import net.minestom.server.network.packet.client.play.ClientPluginMessagePacket; import net.minestom.server.utils.PacketUtils; import net.minestom.server.utils.Utils; @@ -9,6 +10,9 @@ import net.minestom.server.utils.binary.PooledBuffers; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; import java.util.zip.DataFormatException; import static org.junit.jupiter.api.Assertions.*; @@ -26,14 +30,15 @@ public class SocketReadTest { var wrapper = BinaryBuffer.wrap(buffer); wrapper.reset(0, buffer.position()); - var result = PacketUtils.readPackets(wrapper, compressed); - assertNull(result.remaining()); + List> packets = new ArrayList<>(); + var remaining = PacketUtils.readPackets(wrapper, compressed, + (integer, payload) -> packets.add(Pair.of(integer, payload))); + assertNull(remaining); - var packets = result.packets(); assertEquals(1, packets.size()); var rawPacket = packets.get(0); - assertEquals(0x0A, rawPacket.id()); - var readPacket = new ClientPluginMessagePacket(new BinaryReader(rawPacket.payload())); + assertEquals(0x0A, rawPacket.left()); + var readPacket = new ClientPluginMessagePacket(new BinaryReader(rawPacket.right())); assertEquals("channel", readPacket.channel()); assertEquals(2000, readPacket.data().length); } @@ -50,14 +55,15 @@ public class SocketReadTest { var wrapper = BinaryBuffer.wrap(buffer); wrapper.reset(0, buffer.position()); - var result = PacketUtils.readPackets(wrapper, compressed); - assertNull(result.remaining()); + List> packets = new ArrayList<>(); + var remaining = PacketUtils.readPackets(wrapper, compressed, + (integer, payload) -> packets.add(Pair.of(integer, payload))); + assertNull(remaining); - var packets = result.packets(); assertEquals(2, packets.size()); for (var rawPacket : packets) { - assertEquals(0x0A, rawPacket.id()); - var readPacket = new ClientPluginMessagePacket(new BinaryReader(rawPacket.payload())); + assertEquals(0x0A, rawPacket.left()); + var readPacket = new ClientPluginMessagePacket(new BinaryReader(rawPacket.right())); assertEquals("channel", readPacket.channel()); assertEquals(2000, readPacket.data().length); } @@ -77,16 +83,16 @@ public class SocketReadTest { var wrapper = BinaryBuffer.wrap(buffer); wrapper.reset(0, buffer.position()); - var result = PacketUtils.readPackets(wrapper, compressed); - var remaining = result.remaining(); + List> packets = new ArrayList<>(); + var remaining = PacketUtils.readPackets(wrapper, compressed, + (integer, payload) -> packets.add(Pair.of(integer, payload))); assertNotNull(remaining); assertEquals(Utils.getVarIntSize(200), remaining.readableBytes()); - var packets = result.packets(); assertEquals(1, packets.size()); var rawPacket = packets.get(0); - assertEquals(0x0A, rawPacket.id()); - var readPacket = new ClientPluginMessagePacket(new BinaryReader(rawPacket.payload())); + assertEquals(0x0A, rawPacket.left()); + var readPacket = new ClientPluginMessagePacket(new BinaryReader(rawPacket.right())); assertEquals("channel", readPacket.channel()); assertEquals(2000, readPacket.data().length); } @@ -105,16 +111,16 @@ public class SocketReadTest { var wrapper = BinaryBuffer.wrap(buffer); wrapper.reset(0, buffer.position()); - var result = PacketUtils.readPackets(wrapper, compressed); - var remaining = result.remaining(); + List> packets = new ArrayList<>(); + var remaining = PacketUtils.readPackets(wrapper, compressed, + (integer, payload) -> packets.add(Pair.of(integer, payload))); assertNotNull(remaining); assertEquals(1, remaining.readableBytes()); - var packets = result.packets(); assertEquals(1, packets.size()); var rawPacket = packets.get(0); - assertEquals(0x0A, rawPacket.id()); - var readPacket = new ClientPluginMessagePacket(new BinaryReader(rawPacket.payload())); + assertEquals(0x0A, rawPacket.left()); + var readPacket = new ClientPluginMessagePacket(new BinaryReader(rawPacket.right())); assertEquals("channel", readPacket.channel()); assertEquals(2000, readPacket.data().length); }