fix invalid packet types due to state mismatch when calling packet events (#2568)

This commit is contained in:
Pasqual Koschmieder 2023-10-25 02:56:38 +02:00 committed by GitHub
parent 03d7be13d0
commit af33a2ab41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 218 additions and 38 deletions

View File

@ -46,6 +46,8 @@ public interface ListenerInvoker {
*
* @param packet - the packet.
* @return The packet type.
* @deprecated use {@link com.comphenix.protocol.injector.packet.PacketRegistry#getPacketType(PacketType.Protocol, Class)} instead.
*/
@Deprecated
PacketType getPacketType(Object packet);
}

View File

@ -118,7 +118,9 @@ public class StructureCache {
*
* @param packetType - packet type.
* @return A structure modifier.
* @deprecated use {@link #getStructure(PacketType)} instead.
*/
@Deprecated
public static StructureModifier<Object> getStructure(Class<?> packetType) {
// Get the ID from the class
PacketType type = PacketRegistry.getPacketType(packetType);

View File

@ -1,5 +1,6 @@
package com.comphenix.protocol.injector.netty;
import com.comphenix.protocol.PacketType;
import com.comphenix.protocol.PacketType.Protocol;
import com.comphenix.protocol.events.NetworkMarker;
import org.bukkit.entity.Player;
@ -49,8 +50,21 @@ public interface Injector {
* Retrieve the current protocol state.
*
* @return The current protocol.
* @deprecated use {@link #getCurrentProtocol(PacketType.Sender)} instead.
*/
Protocol getCurrentProtocol();
@Deprecated
default Protocol getCurrentProtocol() {
return this.getCurrentProtocol(PacketType.Sender.SERVER);
}
/**
* Retrieve the current protocol state. Note that since 1.20.2 the client and server direction can be in different
* protocol states.
*
* @param sender the side for which the state should be resolved.
* @return The current protocol.
*/
Protocol getCurrentProtocol(PacketType.Sender sender);
/**
* Retrieve the network marker associated with a given packet.

View File

@ -0,0 +1,133 @@
package com.comphenix.protocol.injector.netty.channel;
import com.comphenix.protocol.PacketType;
import com.comphenix.protocol.reflect.FuzzyReflection;
import com.comphenix.protocol.reflect.accessors.Accessors;
import com.comphenix.protocol.reflect.accessors.FieldAccessor;
import com.comphenix.protocol.reflect.fuzzy.FuzzyFieldContract;
import com.comphenix.protocol.utility.MinecraftReflection;
import io.netty.channel.Channel;
import io.netty.util.AttributeKey;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.List;
import java.util.function.BiFunction;
@SuppressWarnings("unchecked")
final class ChannelProtocolUtil {
public static final BiFunction<Channel, PacketType.Sender, PacketType.Protocol> PROTOCOL_RESOLVER;
static {
Class<?> networkManagerClass = MinecraftReflection.getNetworkManagerClass();
List<Field> attributeKeys = FuzzyReflection.fromClass(networkManagerClass, true).getFieldList(FuzzyFieldContract.newBuilder()
.typeExact(AttributeKey.class)
.requireModifier(Modifier.STATIC)
.declaringClassExactType(networkManagerClass)
.build());
BiFunction<Channel, PacketType.Sender, Object> baseResolver = null;
if (attributeKeys.size() == 1) {
// if there is only one attribute key we can assume it's the correct one (1.8 - 1.20.1)
Object protocolKey = Accessors.getFieldAccessor(attributeKeys.get(0)).get(null);
baseResolver = new Pre1_20_2DirectResolver((AttributeKey<Object>) protocolKey);
} else if (attributeKeys.size() > 1) {
// most likely 1.20.2+: 1 protocol key per protocol direction
AttributeKey<Object> serverBoundKey = null;
AttributeKey<Object> clientBoundKey = null;
for (Field keyField : attributeKeys) {
AttributeKey<Object> key = (AttributeKey<Object>) Accessors.getFieldAccessor(keyField).get(null);
if (key.name().equals("protocol")) {
// legacy (pre 1.20.2 name) - fall back to the old behaviour
baseResolver = new Pre1_20_2DirectResolver(key);
break;
}
if (key.name().contains("protocol")) {
// one of the two protocol keys for 1.20.2
if (key.name().contains("server")) {
serverBoundKey = key;
} else {
clientBoundKey = key;
}
}
}
if (baseResolver == null) {
if ((serverBoundKey == null || clientBoundKey == null)) {
// neither pre 1.20.2 key nor 1.20.2+ keys are available
throw new ExceptionInInitializerError("Unable to resolve protocol state attribute keys");
} else {
baseResolver = new Post1_20_2WrappedResolver(serverBoundKey, clientBoundKey);
}
}
} else {
throw new ExceptionInInitializerError("Unable to resolve protocol state attribute key(s)");
}
// decorate the base resolver by wrapping its return value into our packet type value
PROTOCOL_RESOLVER = baseResolver.andThen(nmsProtocol -> PacketType.Protocol.fromVanilla((Enum<?>) nmsProtocol));
}
private static final class Pre1_20_2DirectResolver implements BiFunction<Channel, PacketType.Sender, Object> {
private final AttributeKey<Object> attributeKey;
public Pre1_20_2DirectResolver(AttributeKey<Object> attributeKey) {
this.attributeKey = attributeKey;
}
@Override
public Object apply(Channel channel, PacketType.Sender sender) {
return channel.attr(this.attributeKey).get();
}
}
private static final class Post1_20_2WrappedResolver implements BiFunction<Channel, PacketType.Sender, Object> {
private final AttributeKey<Object> serverBoundKey;
private final AttributeKey<Object> clientBoundKey;
// lazy initialized when needed
private FieldAccessor protocolAccessor;
public Post1_20_2WrappedResolver(AttributeKey<Object> serverBoundKey, AttributeKey<Object> clientBoundKey) {
this.serverBoundKey = serverBoundKey;
this.clientBoundKey = clientBoundKey;
}
@Override
public Object apply(Channel channel, PacketType.Sender sender) {
AttributeKey<Object> key = this.getKeyForSender(sender);
Object codecData = channel.attr(key).get();
if (codecData == null) {
return null;
}
FieldAccessor protocolAccessor = this.getProtocolAccessor(codecData.getClass());
return protocolAccessor.get(codecData);
}
private AttributeKey<Object> getKeyForSender(PacketType.Sender sender) {
switch (sender) {
case SERVER:
return this.clientBoundKey;
case CLIENT:
return this.serverBoundKey;
default:
throw new IllegalArgumentException("Illegal packet sender " + sender.name());
}
}
private FieldAccessor getProtocolAccessor(Class<?> codecClass) {
if (this.protocolAccessor == null) {
Class<?> enumProtocolClass = MinecraftReflection.getEnumProtocolClass();
this.protocolAccessor = Accessors.getFieldAccessor(codecClass, enumProtocolClass, true);
}
return this.protocolAccessor;
}
}
}

View File

@ -1,5 +1,6 @@
package com.comphenix.protocol.injector.netty.channel;
import com.comphenix.protocol.PacketType;
import com.comphenix.protocol.PacketType.Protocol;
import com.comphenix.protocol.events.NetworkMarker;
import com.comphenix.protocol.injector.netty.Injector;
@ -42,7 +43,7 @@ final class EmptyInjector implements Injector {
}
@Override
public Protocol getCurrentProtocol() {
public Protocol getCurrentProtocol(PacketType.Sender sender) {
return Protocol.HANDSHAKING;
}

View File

@ -2,8 +2,6 @@ package com.comphenix.protocol.injector.netty.channel;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Map.Entry;
import java.util.NoSuchElementException;
@ -110,7 +108,6 @@ public class NettyChannelInjector implements Injector {
// lazy initialized fields, if we don't need them we don't bother about them
private Object playerConnection;
private FieldAccessor protocolAccessor;
public NettyChannelInjector(
Player player,
@ -322,17 +319,8 @@ public class NettyChannelInjector implements Injector {
}
@Override
public Protocol getCurrentProtocol() {
// ensure that the accessor to the protocol field is available
if (this.protocolAccessor == null) {
this.protocolAccessor = Accessors.getFieldAccessor(
this.networkManager.getClass(),
MinecraftReflection.getEnumProtocolClass(),
true);
}
Object nmsProtocol = this.protocolAccessor.get(this.networkManager);
return Protocol.fromVanilla((Enum<?>) nmsProtocol);
public Protocol getCurrentProtocol(PacketType.Sender sender) {
return ChannelProtocolUtil.PROTOCOL_RESOLVER.apply(this.wrappedChannel, sender);
}
@Override

View File

@ -28,7 +28,6 @@ import com.comphenix.protocol.reflect.accessors.FieldAccessor;
import com.comphenix.protocol.reflect.fuzzy.FuzzyFieldContract;
import com.comphenix.protocol.reflect.fuzzy.FuzzyMethodContract;
import com.comphenix.protocol.utility.MinecraftReflection;
import com.comphenix.protocol.utility.Util;
import com.comphenix.protocol.wrappers.Pair;
import io.netty.channel.ChannelFuture;
import org.bukkit.Server;
@ -93,7 +92,8 @@ public class NetworkManagerInjector implements ChannelListener {
Class<?> packetClass = packet.getClass();
if (marker != null || MinecraftReflection.isBundlePacket(packetClass) || outboundListeners.contains(packetClass)) {
// wrap packet and construct the event
PacketContainer container = new PacketContainer(PacketRegistry.getPacketType(packetClass), packet);
PacketType.Protocol currentProtocol = injector.getCurrentProtocol(PacketType.Sender.SERVER);
PacketContainer container = new PacketContainer(PacketRegistry.getPacketType(currentProtocol, packetClass), packet);
PacketEvent packetEvent = PacketEvent.fromServer(this, container, marker, injector.getPlayer());
// post to all listeners, then return the packet event we constructed
@ -111,7 +111,8 @@ public class NetworkManagerInjector implements ChannelListener {
Class<?> packetClass = packet.getClass();
if (marker != null || inboundListeners.contains(packetClass)) {
// wrap the packet and construct the event
PacketContainer container = new PacketContainer(PacketRegistry.getPacketType(packetClass), packet);
PacketType.Protocol currentProtocol = injector.getCurrentProtocol(PacketType.Sender.CLIENT);
PacketContainer container = new PacketContainer(PacketRegistry.getPacketType(currentProtocol, packetClass), packet);
PacketEvent packetEvent = PacketEvent.fromClient(this, container, marker, injector.getPlayer());
// post to all listeners, then return the packet event we constructed
@ -238,7 +239,6 @@ public class NetworkManagerInjector implements ChannelListener {
// just reset to the list we wrapped originally
ListeningList ourList = (ListeningList) currentFieldValue;
List<Object> original = ourList.getOriginal();
//noinspection SynchronizationOnLocalVariableOrMethodParameter
synchronized (original) {
// revert the injection from all values of the list
ourList.unProcessAll();

View File

@ -48,7 +48,9 @@ public class PacketRegistry {
protected static class Register {
// The main lookup table
final Map<PacketType, Optional<Class<?>>> typeToClass = new ConcurrentHashMap<>();
final Map<Class<?>, PacketType> classToType = new ConcurrentHashMap<>();
final Map<PacketType.Protocol, Map<Class<?>, PacketType>> protocolClassToType = new ConcurrentHashMap<>();
volatile Set<PacketType> serverPackets = new HashSet<>();
volatile Set<PacketType> clientPackets = new HashSet<>();
@ -58,7 +60,10 @@ public class PacketRegistry {
public void registerPacket(PacketType type, Class<?> clazz, Sender sender) {
typeToClass.put(type, Optional.of(clazz));
classToType.put(clazz, type);
protocolClassToType.computeIfAbsent(type.getProtocol(), __ -> new ConcurrentHashMap<>()).put(clazz, type);
if (sender == Sender.CLIENT) {
clientPackets.add(type);
} else {
@ -430,7 +435,9 @@ public class PacketRegistry {
* Retrieve the packet type of a given packet.
* @param packet - the class of the packet.
* @return The packet type, or NULL if not found.
* @deprecated major issues due to packets with shared classes being registered in multiple states.
*/
@Deprecated
public static PacketType getPacketType(Class<?> packet) {
initialize();
@ -440,7 +447,24 @@ public class PacketRegistry {
return REGISTER.classToType.get(packet);
}
/**
* Retrieve the associated packet type for a packet class in the given protocol state.
*
* @param protocol the protocol state to retrieve the packet from.
* @param packet the class identifying the packet type.
* @return the packet type associated with the given class in the given protocol state, or null if not found.
*/
public static PacketType getPacketType(PacketType.Protocol protocol, Class<?> packet) {
initialize();
if (MinecraftReflection.isBundlePacket(packet)) {
return PacketType.Play.Server.BUNDLE;
}
Map<Class<?>, PacketType> classToTypesForProtocol = REGISTER.protocolClassToType.get(protocol);
return classToTypesForProtocol == null ? null : classToTypesForProtocol.get(packet);
}
/**
* Retrieve the packet type of a given packet.
* @param packet - the class of the packet.

View File

@ -17,11 +17,6 @@
package com.comphenix.protocol.reflect;
import com.comphenix.protocol.PacketType;
import com.comphenix.protocol.injector.StructureCache;
import com.comphenix.protocol.injector.packet.PacketRegistry;
import com.comphenix.protocol.utility.MinecraftReflection;
import com.comphenix.protocol.utility.StreamSerializer;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.HashMap;
@ -46,18 +41,6 @@ public class ObjectWriter {
* @return A structure modifier for the given type.
*/
private StructureModifier<Object> getModifier(Class<?> type) {
Class<?> packetClass = MinecraftReflection.getPacketClass();
// Handle subclasses of the packet class with our custom structure cache, if possible
if (!type.equals(packetClass) && packetClass.isAssignableFrom(type)) {
// might be a packet, but some packets are not registered (for example PacketPlayInFlying, only the subtypes are present)
PacketType packetType = PacketRegistry.getPacketType(type);
if (packetType != null) {
// packet is present, delegate to the cache
return StructureCache.getStructure(packetType);
}
}
// Create the structure modifier if we haven't already
StructureModifier<Object> modifier = CACHE.get(type);
if (modifier == null) {

View File

@ -0,0 +1,33 @@
package com.comphenix.protocol.injector.netty.channel;
import com.comphenix.protocol.BukkitInitialization;
import com.comphenix.protocol.PacketType;
import io.netty.channel.Channel;
import io.netty.channel.local.LocalServerChannel;
import net.minecraft.network.EnumProtocol;
import net.minecraft.network.NetworkManager;
import net.minecraft.network.protocol.EnumProtocolDirection;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
public class ChannelProtocolUtilTest {
@BeforeAll
public static void beforeClass() {
BukkitInitialization.initializeAll();
}
@Test
public void testProtocolResolving() {
Channel channel = new LocalServerChannel();
channel.attr(NetworkManager.e).set(EnumProtocol.e.b(EnumProtocolDirection.a)); // ATTRIBUTE_SERVERBOUND_PROTOCOL -> Protocol.CONFIG.codec(SERVERBOUND)
channel.attr(NetworkManager.f).set(EnumProtocol.b.b(EnumProtocolDirection.b)); // ATTRIBUTE_CLIENTBOUND_PROTOCOL -> Protocol.PLAY.codec(CLIENTBOUND)
PacketType.Protocol serverBoundProtocol = ChannelProtocolUtil.PROTOCOL_RESOLVER.apply(channel, PacketType.Sender.CLIENT);
Assertions.assertEquals(PacketType.Protocol.CONFIGURATION, serverBoundProtocol);
PacketType.Protocol clientBoundProtocol = ChannelProtocolUtil.PROTOCOL_RESOLVER.apply(channel, PacketType.Sender.SERVER);
Assertions.assertEquals(PacketType.Protocol.PLAY, clientBoundProtocol);
}
}