diff --git a/src/main/java/net/minestom/server/network/PacketProcessor.java b/src/main/java/net/minestom/server/network/PacketProcessor.java index 9a8a19a39..25968c5af 100644 --- a/src/main/java/net/minestom/server/network/PacketProcessor.java +++ b/src/main/java/net/minestom/server/network/PacketProcessor.java @@ -6,7 +6,7 @@ import net.minestom.server.network.packet.client.ClientPacket; import net.minestom.server.network.packet.client.ClientPacketsHandler; import net.minestom.server.network.packet.client.ClientPreplayPacket; import net.minestom.server.network.packet.client.handshake.HandshakePacket; -import net.minestom.server.network.player.PlayerSocketConnection; +import net.minestom.server.network.player.PlayerConnection; import net.minestom.server.utils.binary.BinaryReader; import org.jetbrains.annotations.NotNull; @@ -27,36 +27,31 @@ public record PacketProcessor(@NotNull ClientPacketsHandler statusHandler, new ClientPacketsHandler.Play()); } - public void process(@NotNull PlayerSocketConnection connection, int packetId, ByteBuffer body) { + public @NotNull ClientPacket create(@NotNull ConnectionState connectionState, int packetId, ByteBuffer body) { + BinaryReader binaryReader = new BinaryReader(body); + return switch (connectionState) { + case PLAY -> playHandler.create(packetId, binaryReader); + case LOGIN -> loginHandler.create(packetId, binaryReader); + case STATUS -> statusHandler.create(packetId, binaryReader); + case UNKNOWN -> { + assert packetId == 0; + yield new HandshakePacket(binaryReader); + } + }; + } + + public void process(@NotNull PlayerConnection connection, int packetId, ByteBuffer body) { if (MinecraftServer.getRateLimit() > 0) { // Increment packet count (checked in PlayerConnection#update) connection.getPacketCounter().incrementAndGet(); } - BinaryReader binaryReader = new BinaryReader(body); - final ConnectionState connectionState = connection.getConnectionState(); - if (connectionState == ConnectionState.UNKNOWN) { - // Should be handshake packet - if (packetId == 0) { - final HandshakePacket handshakePacket = new HandshakePacket(binaryReader); - handshakePacket.process(connection); - } - return; - } - switch (connectionState) { - case PLAY -> { - final Player player = connection.getPlayer(); - ClientPacket playPacket = playHandler.create(packetId, binaryReader); - assert player != null; - player.addPacketToQueue(playPacket); - } - case LOGIN -> { - final ClientPreplayPacket loginPacket = (ClientPreplayPacket) loginHandler.create(packetId, binaryReader); - loginPacket.process(connection); - } - case STATUS -> { - final ClientPreplayPacket statusPacket = (ClientPreplayPacket) statusHandler.create(packetId, binaryReader); - statusPacket.process(connection); - } + var packet = create(connection.getConnectionState(), packetId, body); + if (packet instanceof ClientPreplayPacket prePlayPacket) { + prePlayPacket.process(connection); + } else { + final Player player = connection.getPlayer(); + assert player != null; + player.addPacketToQueue(packet); } } } diff --git a/src/main/java/net/minestom/server/network/player/PlayerSocketConnection.java b/src/main/java/net/minestom/server/network/player/PlayerSocketConnection.java index 673f24182..9d8198720 100644 --- a/src/main/java/net/minestom/server/network/player/PlayerSocketConnection.java +++ b/src/main/java/net/minestom/server/network/player/PlayerSocketConnection.java @@ -12,7 +12,6 @@ import net.minestom.server.network.packet.server.*; import net.minestom.server.network.packet.server.login.SetCompressionPacket; import net.minestom.server.network.socket.Worker; import net.minestom.server.utils.PacketUtils; -import net.minestom.server.utils.Utils; import net.minestom.server.utils.binary.BinaryBuffer; import net.minestom.server.utils.binary.PooledBuffers; import net.minestom.server.utils.validate.Check; @@ -28,7 +27,6 @@ import javax.crypto.SecretKey; import javax.crypto.ShortBufferException; import java.io.IOException; import java.net.SocketAddress; -import java.nio.BufferUnderflowException; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; import java.nio.channels.SocketChannel; @@ -36,7 +34,6 @@ import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; import java.util.zip.DataFormatException; -import java.util.zip.Inflater; /** * Represents a socket connection. @@ -102,61 +99,27 @@ public class PlayerSocketConnection extends PlayerConnection { } } // Read all packets - while (readBuffer.readableBytes() > 0) { - final var beginMark = readBuffer.mark(); - try { - // Ensure that the buffer contains the full packet (or wait for next socket read) - final int packetLength = readBuffer.readVarInt(); - final int readerStart = readBuffer.readerOffset(); - if (!readBuffer.canRead(packetLength)) { - // Integrity fail - throw new BufferUnderflowException(); - } - // Read packet https://wiki.vg/Protocol#Packet_format - BinaryBuffer content = readBuffer; - int decompressedSize = packetLength; - if (compressed) { - final int dataLength = readBuffer.readVarInt(); - final int payloadLength = packetLength - (readBuffer.readerOffset() - readerStart); - if (dataLength == 0) { - // Data is too small to be compressed, payload is following - decompressedSize = payloadLength; - } else { - // Decompress to content buffer - content = workerContext.contentBuffer.clear(); - decompressedSize = dataLength; - Inflater inflater = workerContext.inflater; - inflater.setInput(readBuffer.asByteBuffer(readBuffer.readerOffset(), payloadLength)); - inflater.inflate(content.asByteBuffer(0, dataLength)); - inflater.reset(); - } - } - // Process packet - ByteBuffer payload = content.asByteBuffer(content.readerOffset(), decompressedSize); - final int packetId = Utils.readVarInt(payload); + try { + var result = PacketUtils.readPackets(readBuffer, compressed, workerContext); + this.cacheBuffer = result.remaining(); + for (var packet : result.packets()) { + var id = packet.id(); + var payload = packet.payload(); try { - packetProcessor.process(this, packetId, payload); + packetProcessor.process(this, id, payload); } catch (Exception e) { // Error while reading the packet MinecraftServer.getExceptionManager().handleException(e); break; } finally { if (payload.position() != payload.limit()) { - LOGGER.warn("WARNING: Packet 0x{} not fully read ({}), {}", - Integer.toHexString(packetId), payload, this); + LOGGER.warn("WARNING: Packet 0x{} not fully read ({})", Integer.toHexString(id), payload); } } - // Position buffer to read the next packet - readBuffer.readerOffset(readerStart + packetLength); - } catch (BufferUnderflowException e) { - readBuffer.reset(beginMark); - this.cacheBuffer = BinaryBuffer.copy(readBuffer); - break; - } catch (DataFormatException e) { - MinecraftServer.getExceptionManager().handleException(e); - disconnect(); - return; } + } catch (DataFormatException e) { + MinecraftServer.getExceptionManager().handleException(e); + disconnect(); } } diff --git a/src/main/java/net/minestom/server/utils/PacketUtils.java b/src/main/java/net/minestom/server/utils/PacketUtils.java index 58d906f8e..c5ea7db5d 100644 --- a/src/main/java/net/minestom/server/utils/PacketUtils.java +++ b/src/main/java/net/minestom/server/utils/PacketUtils.java @@ -20,6 +20,7 @@ import net.minestom.server.network.packet.server.ServerPacket; import net.minestom.server.network.player.PlayerConnection; import net.minestom.server.network.player.PlayerSocketConnection; import net.minestom.server.network.socket.Server; +import net.minestom.server.network.socket.Worker; import net.minestom.server.utils.binary.BinaryBuffer; import net.minestom.server.utils.binary.BinaryWriter; import net.minestom.server.utils.binary.PooledBuffers; @@ -28,12 +29,17 @@ import org.jetbrains.annotations.ApiStatus; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; +import java.nio.BufferUnderflowException; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Collection; +import java.util.List; import java.util.Objects; import java.util.concurrent.ConcurrentMap; import java.util.function.Predicate; +import java.util.zip.DataFormatException; import java.util.zip.Deflater; +import java.util.zip.Inflater; /** * Utils class for packets. Including writing a {@link ServerPacket} into a {@link ByteBuffer} @@ -167,6 +173,61 @@ public final class PacketUtils { } } + @ApiStatus.Internal + public static ReadResult readPackets(@NotNull BinaryBuffer readBuffer, boolean compressed, + @NotNull Worker.Context context) throws DataFormatException { + List packets = new ArrayList<>(); + BinaryBuffer remaining = null; + while (readBuffer.readableBytes() > 0) { + final var beginMark = readBuffer.mark(); + try { + // Ensure that the buffer contains the full packet (or wait for next socket read) + final int packetLength = readBuffer.readVarInt(); + final int readerStart = readBuffer.readerOffset(); + if (!readBuffer.canRead(packetLength)) { + // Integrity fail + throw new BufferUnderflowException(); + } + // Read packet https://wiki.vg/Protocol#Packet_format + BinaryBuffer content = readBuffer; + int decompressedSize = packetLength; + if (compressed) { + final int dataLength = readBuffer.readVarInt(); + final int payloadLength = packetLength - (readBuffer.readerOffset() - readerStart); + if (dataLength == 0) { + // Data is too small to be compressed, payload is following + decompressedSize = payloadLength; + } else { + // Decompress to content buffer + content = context.contentBuffer.clear(); + decompressedSize = dataLength; + Inflater inflater = context.inflater; + inflater.setInput(readBuffer.asByteBuffer(readBuffer.readerOffset(), payloadLength)); + inflater.inflate(content.asByteBuffer(0, dataLength)); + inflater.reset(); + } + } + // Slice packet + ByteBuffer payload = content.asByteBuffer(content.readerOffset(), decompressedSize); + final int packetId = Utils.readVarInt(payload); + packets.add(new PacketPayload(packetId, payload)); + // Position buffer to read the next packet + readBuffer.readerOffset(readerStart + packetLength); + } catch (BufferUnderflowException e) { + readBuffer.reset(beginMark); + remaining = BinaryBuffer.copy(readBuffer); + break; + } + } + return new ReadResult(packets, remaining); + } + + public record ReadResult(List packets, BinaryBuffer remaining) { + } + + public record PacketPayload(int id, ByteBuffer payload) { + } + public static void writeFramedPacket(@NotNull ByteBuffer buffer, @NotNull ServerPacket packet, boolean compression) {