Fixup state splitting

This commit is contained in:
Nassim Jahnke 2023-09-25 18:59:15 +10:00
parent 854696abff
commit 3997ea70f7
No known key found for this signature in database
GPG Key ID: EF6771C01F6EF02F
5 changed files with 55 additions and 65 deletions

View File

@ -69,10 +69,15 @@ public interface ProtocolPipeline extends SimpleProtocol {
/** /**
* Returns the list of protocols this pipeline contains. * Returns the list of protocols this pipeline contains.
* *
* @return list of protocols in this pipe * @return immutable list of protocols in this pipe
*/ */
List<Protocol> pipes(); List<Protocol> pipes();
/**
* Returns the list of protocols this pipeline contains in reversed order.
*
* @return immutable list of protocols in reversed direction
*/
List<Protocol> reversedPipes(); List<Protocol> reversedPipes();
/** /**

View File

@ -17,7 +17,6 @@
*/ */
package com.viaversion.viaversion.protocol; package com.viaversion.viaversion.protocol;
import com.google.common.collect.Sets;
import com.viaversion.viaversion.api.Via; import com.viaversion.viaversion.api.Via;
import com.viaversion.viaversion.api.connection.UserConnection; import com.viaversion.viaversion.api.connection.UserConnection;
import com.viaversion.viaversion.api.debug.DebugHandler; import com.viaversion.viaversion.api.debug.DebugHandler;
@ -31,9 +30,9 @@ import com.viaversion.viaversion.api.protocol.packet.State;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CopyOnWriteArrayList;
import java.util.logging.Level; import java.util.logging.Level;
import org.checkerframework.checker.nullness.qual.Nullable; import org.checkerframework.checker.nullness.qual.Nullable;
@ -43,9 +42,10 @@ public class ProtocolPipelineImpl extends AbstractSimpleProtocol implements Prot
/** /**
* Protocol list ordered from client to server transforation with the base protocols at the end. * Protocol list ordered from client to server transforation with the base protocols at the end.
*/ */
private List<Protocol> protocolList; private final List<Protocol> protocolList = new CopyOnWriteArrayList<>();
private List<Protocol> reversedProtocolList; private final Set<Class<? extends Protocol>> protocolSet = new HashSet<>();
private Set<Class<? extends Protocol>> protocolSet; private List<Protocol> reversedProtocolList = new CopyOnWriteArrayList<>();
private int baseProtocols;
public ProtocolPipelineImpl(UserConnection userConnection) { public ProtocolPipelineImpl(UserConnection userConnection) {
this.userConnection = userConnection; this.userConnection = userConnection;
@ -55,16 +55,12 @@ public class ProtocolPipelineImpl extends AbstractSimpleProtocol implements Prot
@Override @Override
protected void registerPackets() { protected void registerPackets() {
protocolList = new CopyOnWriteArrayList<>();
reversedProtocolList = new CopyOnWriteArrayList<>();
// Create a backing set for faster contains calls with larger pipes
protocolSet = Sets.newSetFromMap(new ConcurrentHashMap<>());
// This is a pipeline so we register basic pipes // This is a pipeline so we register basic pipes
final Protocol<?, ?, ?, ?> baseProtocol = Via.getManager().getProtocolManager().getBaseProtocol(); final Protocol<?, ?, ?, ?> baseProtocol = Via.getManager().getProtocolManager().getBaseProtocol();
protocolList.add(baseProtocol); protocolList.add(baseProtocol);
reversedProtocolList.add(baseProtocol); reversedProtocolList.add(baseProtocol);
protocolSet.add(baseProtocol.getClass()); protocolSet.add(baseProtocol.getClass());
baseProtocols++;
} }
@Override @Override
@ -73,55 +69,38 @@ public class ProtocolPipelineImpl extends AbstractSimpleProtocol implements Prot
} }
@Override @Override
public void add(Protocol protocol) { public synchronized void add(final Protocol protocol) {
if (protocol.isBaseProtocol()) {
// Add base protocol on top of previous ones
protocolList.add(baseProtocols, protocol);
reversedProtocolList.add(baseProtocols, protocol);
baseProtocols++;
} else {
protocolList.add(protocol); protocolList.add(protocol);
reversedProtocolList.add(0, protocol);
}
protocolSet.add(protocol.getClass()); protocolSet.add(protocol.getClass());
protocol.init(userConnection); protocol.init(userConnection);
if (!protocol.isBaseProtocol()) {
moveBaseProtocolsToTail(protocolList);
}
setReversedProtocolList();
}
private void setReversedProtocolList() {
final List<Protocol> reversedProtocolList = new ArrayList<>(protocolList);
Collections.reverse(this.reversedProtocolList);
moveBaseProtocolsToTail(reversedProtocolList);
this.reversedProtocolList = new CopyOnWriteArrayList<>(reversedProtocolList);
} }
@Override @Override
public void add(Collection<Protocol> protocols) { public synchronized void add(final Collection<Protocol> protocols) {
protocolList.addAll(protocols); protocolList.addAll(protocols);
for (Protocol protocol : protocols) { for (final Protocol protocol : protocols) {
protocol.init(userConnection); protocol.init(userConnection);
this.protocolSet.add(protocol.getClass()); protocolSet.add(protocol.getClass());
} }
moveBaseProtocolsToTail(protocolList); refreshReversedList();
setReversedProtocolList();
} }
private List<Protocol> filterBaseProtocols(final List<Protocol> protocols) { private synchronized void refreshReversedList() {
final List<Protocol> baseProtocols = new ArrayList<>(); final List<Protocol> protocols = new ArrayList<>(protocolList.subList(0, this.baseProtocols));
for (final Protocol protocol : protocolList) { final List<Protocol> additionalProtocols = new ArrayList<>(protocolList.subList(this.baseProtocols, protocolList.size()));
if (protocol.isBaseProtocol()) { Collections.reverse(additionalProtocols);
baseProtocols.add(protocol); protocols.addAll(additionalProtocols);
} reversedProtocolList = new CopyOnWriteArrayList<>(protocols);
}
return baseProtocols;
}
private void moveBaseProtocolsToTail(final List<Protocol> protocols) {
// Move base Protocols to the end, so the login packets can be modified by other protocols
final List<Protocol> baseProtocols = filterBaseProtocols(protocols);
if (!baseProtocols.isEmpty()) {
protocols.removeAll(baseProtocols);
protocols.addAll(baseProtocols);
}
} }
@Override @Override
@ -134,7 +113,7 @@ public class ProtocolPipelineImpl extends AbstractSimpleProtocol implements Prot
} }
// Apply protocols // Apply protocols
packetWrapper.apply(direction, state, 0, protocolListFor(direction), true); packetWrapper.apply(direction, state, 0, protocolListFor(direction));
super.transform(direction, state, packetWrapper); super.transform(direction, state, packetWrapper);
if (debugHandler.enabled() && debugHandler.logPostPacketTransform() && debugHandler.shouldLog(packetWrapper, direction)) { if (debugHandler.enabled() && debugHandler.logPostPacketTransform() && debugHandler.shouldLog(packetWrapper, direction)) {
@ -143,7 +122,7 @@ public class ProtocolPipelineImpl extends AbstractSimpleProtocol implements Prot
} }
private List<Protocol> protocolListFor(final Direction direction) { private List<Protocol> protocolListFor(final Direction direction) {
return direction == Direction.CLIENTBOUND ? reversedProtocolList : protocolList; return Collections.unmodifiableList(direction == Direction.SERVERBOUND ? protocolList : reversedProtocolList);
} }
private void logPacket(Direction direction, State state, PacketWrapper packetWrapper, int originalID) { private void logPacket(Direction direction, State state, PacketWrapper packetWrapper, int originalID) {
@ -185,12 +164,12 @@ public class ProtocolPipelineImpl extends AbstractSimpleProtocol implements Prot
@Override @Override
public List<Protocol> pipes() { public List<Protocol> pipes() {
return protocolList; return Collections.unmodifiableList(protocolList);
} }
@Override @Override
public List<Protocol> reversedPipes() { public List<Protocol> reversedPipes() {
return reversedProtocolList; return Collections.unmodifiableList(reversedProtocolList);
} }
@Override @Override

View File

@ -309,14 +309,18 @@ public class PacketWrapperImpl implements PacketWrapper {
* @throws Exception if it fails to write * @throws Exception if it fails to write
*/ */
private ByteBuf constructPacket(Class<? extends Protocol> packetProtocol, boolean skipCurrentPipeline, Direction direction) throws Exception { private ByteBuf constructPacket(Class<? extends Protocol> packetProtocol, boolean skipCurrentPipeline, Direction direction) throws Exception {
// Apply current pipeline - for outgoing protocol, the collection will be reversed in the apply method
final ProtocolInfo protocolInfo = user().getProtocolInfo(); final ProtocolInfo protocolInfo = user().getProtocolInfo();
final boolean reverse = direction == Direction.CLIENTBOUND; final List<Protocol> pipes = direction == Direction.SERVERBOUND ? protocolInfo.getPipeline().pipes() : protocolInfo.getPipeline().reversedPipes();
final List<Protocol> pipes = reverse ? protocolInfo.getPipeline().reversedPipes() : protocolInfo.getPipeline().pipes(); final List<Protocol> protocols = new ArrayList<>();
final Protocol[] protocols = pipes.toArray(PROTOCOL_ARRAY);
int index = -1; int index = -1;
for (int i = 0; i < protocols.length; i++) { for (int i = 0; i < pipes.size(); i++) {
if (protocols[i].getClass() == packetProtocol) { // Always add base protocols to the head
final Protocol protocol = pipes.get(i);
if (protocol.isBaseProtocol()) {
protocols.add(protocol);
}
if (protocol.getClass() == packetProtocol) {
index = i; index = i;
break; break;
} }
@ -328,14 +332,17 @@ public class PacketWrapperImpl implements PacketWrapper {
} }
if (skipCurrentPipeline) { if (skipCurrentPipeline) {
index = reverse ? index - 1 : index + 1; index = Math.min(index + 1, pipes.size());
} }
// Add remaining protocols on top
protocols.addAll(pipes.subList(index, pipes.size()));
// Reset reader before we start // Reset reader before we start
resetReader(); resetReader();
// Apply other protocols // Apply other protocols
apply(direction, protocolInfo.getState(direction), index, protocols, true); apply(direction, protocolInfo.getState(direction), 0, protocols);
final ByteBuf output = inputBuffer == null ? user().getChannel().alloc().buffer() : inputBuffer.alloc().buffer(); final ByteBuf output = inputBuffer == null ? user().getChannel().alloc().buffer() : inputBuffer.alloc().buffer();
try { try {
writeToBuffer(output); writeToBuffer(output);

View File

@ -119,9 +119,8 @@ public class BaseProtocol1_7 extends AbstractProtocol {
// Login Success Packet // Login Success Packet
registerClientbound(ClientboundLoginPackets.GAME_PROFILE, wrapper -> { registerClientbound(ClientboundLoginPackets.GAME_PROFILE, wrapper -> {
ProtocolInfo info = wrapper.user().getProtocolInfo(); ProtocolInfo info = wrapper.user().getProtocolInfo();
info.setServerState(State.PLAY); if (info.getProtocolVersion() < ProtocolVersion.v1_20_2.getVersion()) { // On 1.20.2+, wait for the login ack
if (info.getProtocolVersion() < ProtocolVersion.v1_20_2.getVersion()) { // 1.20.2+ clients will send a login ack first info.setState(State.PLAY);
info.setClientState(State.PLAY);
} }
UUID uuid = passthroughLoginUUID(wrapper); UUID uuid = passthroughLoginUUID(wrapper);
@ -166,7 +165,7 @@ public class BaseProtocol1_7 extends AbstractProtocol {
registerServerbound(ServerboundLoginPackets.LOGIN_ACKNOWLEDGED, wrapper -> { registerServerbound(ServerboundLoginPackets.LOGIN_ACKNOWLEDGED, wrapper -> {
final ProtocolInfo info = wrapper.user().getProtocolInfo(); final ProtocolInfo info = wrapper.user().getProtocolInfo();
info.setClientState(State.CONFIGURATION); info.setState(State.CONFIGURATION);
}); });
} }

View File

@ -209,7 +209,7 @@ public final class Protocol1_20_2To1_20 extends AbstractProtocol<ClientboundPack
// Map some of them to their configuration state counterparts, but make sure to let join game through // Map some of them to their configuration state counterparts, but make sure to let join game through
final int unmappedId = packetWrapper.getId(); final int unmappedId = packetWrapper.getId();
if (unmappedId == ClientboundPackets1_19_4.JOIN_GAME.getId()) { if (unmappedId == ClientboundPackets1_19_4.JOIN_GAME.getId()) {
super.transform(direction, state, packetWrapper); super.transform(direction, State.PLAY, packetWrapper);
return; return;
} }