diff --git a/src/main/kotlin/com/viaversion/aas/handler/state/LoginState.kt b/src/main/kotlin/com/viaversion/aas/handler/state/LoginState.kt index f82af9f..bf256c0 100644 --- a/src/main/kotlin/com/viaversion/aas/handler/state/LoginState.kt +++ b/src/main/kotlin/com/viaversion/aas/handler/state/LoginState.kt @@ -12,6 +12,9 @@ import com.viaversion.aas.handler.MinecraftHandler import com.viaversion.aas.handler.forward import com.viaversion.aas.util.StacklessException import com.viaversion.viaversion.api.protocol.packet.State +import com.viaversion.viaversion.api.protocol.version.ProtocolVersion +import com.viaversion.viaversion.api.type.Type +import io.netty.buffer.ByteBufAllocator import io.netty.channel.Channel import io.netty.channel.ChannelHandlerContext import kotlinx.coroutines.Dispatchers @@ -19,6 +22,7 @@ import kotlinx.coroutines.future.await import kotlinx.coroutines.launch import java.util.* import java.util.concurrent.CompletableFuture +import java.util.concurrent.ThreadLocalRandom class LoginState : ConnectionState { val callbackPlayerId = CompletableFuture() @@ -34,12 +38,14 @@ class LoginState : ConnectionState { get() = State.LOGIN override val logDc: Boolean get() = true + var callbackPluginReauth = CompletableFuture() + var pendingReauth: Int? = null override fun handlePacket(handler: MinecraftHandler, ctx: ChannelHandlerContext, packet: Packet) { when (packet) { is LoginStart -> handleLoginStart(handler, packet) is CryptoResponse -> handleCryptoResponse(handler, packet) - is PluginResponse -> forward(handler, packet) + is PluginResponse -> handlePluginResponse(handler, packet) is LoginDisconnect -> forward(handler, packet) is CryptoRequest -> handleCryptoRequest(handler, packet) is LoginSuccess -> handleLoginSuccess(handler, packet) @@ -49,6 +55,16 @@ class LoginState : ConnectionState { } } + private fun handlePluginResponse(handler: MinecraftHandler, packet: PluginResponse) { + if (packet.id == pendingReauth) { + callbackPluginReauth.complete(packet.success) + pendingReauth = null + + return + } + forward(handler, packet) + } + private fun handleLoginSuccess(handler: MinecraftHandler, loginSuccess: LoginSuccess) { handler.data.state = PlayState forward(handler, loginSuccess) @@ -89,6 +105,28 @@ class LoginState : ConnectionState { send(frontChannel, cryptoRequest, true) } + fun reauthMessage(handler: MinecraftHandler, backName: String, backHash: String): CompletableFuture { + if (handler.data.frontVer!! < ProtocolVersion.v1_13.version) { + callbackPluginReauth.complete(false) + } else { + val buf = ByteBufAllocator.DEFAULT.buffer() + try { + Type.STRING.write(buf, backName) + Type.STRING.write(buf, backHash) + + val packet = PluginRequest() + packet.id = ThreadLocalRandom.current().nextInt() + packet.channel = "viaaas:reauth" + packet.data = readRemainingBytes(buf) + send(handler.data.frontChannel, packet, true) + pendingReauth = packet.id + } finally { + buf.release() + } + } + return callbackPluginReauth + } + fun handleCryptoRequest(handler: MinecraftHandler, cryptoRequest: CryptoRequest) { val backServerId = cryptoRequest.serverId val backPublicKey = cryptoRequest.publicKey @@ -108,13 +146,16 @@ class LoginState : ConnectionState { val backHash = generateServerHash(backServerId, backKey, backPublicKey) mcLogger.info("Session req: ${handler.data.frontHandler.endRemoteAddress} ($playerId $frontName) $backName") - AspirinServer.viaWebServer.requestSessionJoin( - playerId, - backName!!, - backHash, - frontHandler.endRemoteAddress, - handler.data.backHandler!!.endRemoteAddress - ).await() + val pluginReauthed = reauthMessage(handler, backName!!, backHash).await() + if (!pluginReauthed) { + AspirinServer.viaWebServer.requestSessionJoin( + playerId, + backName!!, + backHash, + frontHandler.endRemoteAddress, + handler.data.backHandler!!.endRemoteAddress + ).await() + } val cryptoResponse = CryptoResponse() cryptoResponse.encryptedKey = encryptRsa(backPublicKey, backKey)