diff --git a/build.gradle.kts b/build.gradle.kts index d25e884..9111313 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -37,7 +37,7 @@ compileKotlin.kotlinOptions.jvmTarget = "11" val gitVersion: groovy.lang.Closure by extra group = "com.github.creeper123123321.viaaas" -version = "0.4.1+" + try { +version = "0.4.2+" + try { gitVersion() } catch (e: Exception) { "unknown" diff --git a/src/main/kotlin/com/viaversion/aas/Util.kt b/src/main/kotlin/com/viaversion/aas/Util.kt index 9416bb4..00b4500 100644 --- a/src/main/kotlin/com/viaversion/aas/Util.kt +++ b/src/main/kotlin/com/viaversion/aas/Util.kt @@ -10,10 +10,26 @@ import com.viaversion.aas.util.StacklessException import com.viaversion.viaversion.api.protocol.version.ProtocolVersion import com.viaversion.viaversion.api.type.Type import io.ktor.client.request.* +import io.ktor.server.netty.* import io.netty.buffer.ByteBuf import io.netty.channel.Channel +import io.netty.channel.ChannelFactory import io.netty.channel.ChannelFutureListener +import io.netty.channel.EventLoopGroup +import io.netty.channel.epoll.* +import io.netty.channel.kqueue.* +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.socket.DatagramChannel +import io.netty.channel.socket.ServerSocketChannel +import io.netty.channel.socket.SocketChannel +import io.netty.channel.socket.nio.NioDatagramChannel +import io.netty.channel.socket.nio.NioServerSocketChannel +import io.netty.channel.socket.nio.NioSocketChannel import io.netty.handler.codec.DecoderException +import io.netty.handler.codec.dns.DefaultDnsQuestion +import io.netty.handler.codec.dns.DefaultDnsRawRecord +import io.netty.handler.codec.dns.DefaultDnsRecordDecoder +import io.netty.handler.codec.dns.DnsRecordType import org.slf4j.LoggerFactory import java.math.BigInteger import java.net.InetAddress @@ -28,7 +44,6 @@ import java.util.concurrent.TimeUnit import javax.crypto.Cipher import javax.crypto.spec.IvParameterSpec import javax.crypto.spec.SecretKeySpec -import javax.naming.directory.InitialDirContext val badLength = DecoderException("Invalid length!") val mcLogger = LoggerFactory.getLogger("VIAaaS MC") @@ -37,18 +52,27 @@ val viaaasLogger = LoggerFactory.getLogger("VIAaaS") val secureRandom = if (VIAaaSConfig.useStrongRandom) SecureRandom.getInstanceStrong() else SecureRandom() -fun resolveSrv(hostAndPort: HostAndPort): HostAndPort { +suspend fun resolveSrv(hostAndPort: HostAndPort): HostAndPort { if (hostAndPort.host.endsWith(".onion", ignoreCase = true)) return hostAndPort if (hostAndPort.port == 25565) { try { - // https://github.com/GeyserMC/Geyser/blob/99e72f35b308542cf0dbfb5b58816503c3d6a129/connector/src/main/java/org/geysermc/connector/GeyserConnector.java - val attr = InitialDirContext() - .getAttributes("dns:///_minecraft._tcp.${hostAndPort.host}", arrayOf("SRV"))["SRV"] - if (attr != null && attr.size() > 0) { - val record = (attr.get(0) as String).split(" ") - return HostAndPort.fromParts(record[3], record[2].toInt()) - } - } catch (ignored: Exception) { // DuckDNS workaround + // stolen from PacketLib (MIT) https://github.com/Camotoy/PacketLib/blob/312cff5f975be54cf2d92208ae2947dbda8b9f59/src/main/java/com/github/steveice10/packetlib/tcp/TcpClientSession.java + dnsResolver + .resolveAll(DefaultDnsQuestion("_minecraft._tcp.${hostAndPort.host}", DnsRecordType.SRV)) + .suspendAwait() + .forEach { record -> + if (record is DefaultDnsRawRecord && record.type() == DnsRecordType.SRV) { + val content = record.content() + + content.skipBytes(4) + val port = content.readUnsignedShort() + val address = DefaultDnsRecordDecoder.decodeName(content) + + return HostAndPort.fromParts(address, port) + } + } + } catch (e: Exception) { + viaaasLogger.debug("Couldn't resolve SRV", e) } } return hostAndPort @@ -169,8 +193,10 @@ fun ByteBuf.readByteArray(length: Int) = ByteArray(length).also { readBytes(it) suspend fun hasJoined(username: String, hash: String): JsonObject { return try { - httpClient.get("https://sessionserver.mojang.com/session/minecraft/hasJoined?username=" + - UrlEscapers.urlFormParameterEscaper().escape(username) + "&serverId=$hash") + httpClient.get( + "https://sessionserver.mojang.com/session/minecraft/hasJoined?username=" + + UrlEscapers.urlFormParameterEscaper().escape(username) + "&serverId=$hash" + ) } catch (e: Exception) { throw StacklessException("Couldn't authenticate with session servers", e) } @@ -189,4 +215,49 @@ fun sha512Hex(data: ByteArray): String { return MessageDigest.getInstance("SHA-512").digest(data) .asUByteArray() .joinToString("") { it.toString(16).padStart(2, '0') } +} + +fun eventLoopGroup(): EventLoopGroup { + if (VIAaaSConfig.isNativeTransportMc) { + if (Epoll.isAvailable()) return EpollEventLoopGroup() + if (KQueue.isAvailable()) return KQueueEventLoopGroup() + } + return NioEventLoopGroup() +} + +fun channelServerSocketFactory(eventLoop: EventLoopGroup): ChannelFactory { + return when (eventLoop) { + is EpollEventLoopGroup -> ChannelFactory { EpollServerSocketChannel() } + is KQueueEventLoopGroup -> ChannelFactory { KQueueServerSocketChannel() } + else -> ChannelFactory { NioServerSocketChannel() } + } +} + +fun channelSocketFactory(eventLoop: EventLoopGroup): ChannelFactory { + return when (eventLoop) { + is EpollEventLoopGroup -> ChannelFactory { EpollSocketChannel() } + is KQueueEventLoopGroup -> ChannelFactory { KQueueSocketChannel() } + else -> ChannelFactory { NioSocketChannel() } + } +} + +fun channelDatagramFactory(eventLoop: EventLoopGroup): ChannelFactory { + return when (eventLoop) { + is EpollEventLoopGroup -> ChannelFactory { EpollDatagramChannel() } + is KQueueEventLoopGroup -> ChannelFactory { KQueueDatagramChannel() } + else -> ChannelFactory { NioDatagramChannel() } + } +} + +fun reverseLookup(address: InetAddress): String { + val bytes = address.address + return if (bytes.size == 4) { + // IPv4 + bytes.reversed() + .joinToString(".") { it.toUByte().toString() } + ".in-addr.arpa" + } else { // IPv6 + bytes.flatMap { it.toUByte().toString(16).padStart(2, '0').toCharArray().map { it.toString() } } + .asReversed() + .joinToString(".") + ".ip6.arpa" + } } \ No newline at end of file diff --git a/src/main/kotlin/com/viaversion/aas/VIAaaS.kt b/src/main/kotlin/com/viaversion/aas/VIAaaS.kt index 0b9f75b..7249aff 100644 --- a/src/main/kotlin/com/viaversion/aas/VIAaaS.kt +++ b/src/main/kotlin/com/viaversion/aas/VIAaaS.kt @@ -23,20 +23,10 @@ import io.ktor.network.tls.certificates.* import io.ktor.server.engine.* import io.ktor.server.netty.* import io.netty.bootstrap.ServerBootstrap -import io.netty.channel.* -import io.netty.channel.epoll.Epoll -import io.netty.channel.epoll.EpollEventLoopGroup -import io.netty.channel.epoll.EpollServerSocketChannel -import io.netty.channel.epoll.EpollSocketChannel -import io.netty.channel.kqueue.KQueue -import io.netty.channel.kqueue.KQueueEventLoopGroup -import io.netty.channel.kqueue.KQueueServerSocketChannel -import io.netty.channel.kqueue.KQueueSocketChannel -import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.socket.ServerSocketChannel -import io.netty.channel.socket.SocketChannel -import io.netty.channel.socket.nio.NioServerSocketChannel -import io.netty.channel.socket.nio.NioSocketChannel +import io.netty.channel.ChannelFuture +import io.netty.channel.ChannelOption +import io.netty.channel.WriteBufferWaterMark +import io.netty.resolver.dns.DnsNameResolverBuilder import io.netty.util.concurrent.Future import org.apache.logging.log4j.Level import org.apache.logging.log4j.io.IoBuilder @@ -68,34 +58,13 @@ val mcCryptoKey = KeyPairGenerator.getInstance("RSA").let { it.genKeyPair() } -fun eventLoopGroup(): EventLoopGroup { - if (VIAaaSConfig.isNativeTransportMc) { - if (Epoll.isAvailable()) return EpollEventLoopGroup() - if (KQueue.isAvailable()) return KQueueEventLoopGroup() - } - return NioEventLoopGroup() -} - -fun channelServerSocketFactory(eventLoop: EventLoopGroup): ChannelFactory { - return when (eventLoop) { - is EpollEventLoopGroup -> ChannelFactory { EpollServerSocketChannel() } - is KQueueEventLoopGroup -> ChannelFactory { KQueueServerSocketChannel() } - else -> ChannelFactory { NioServerSocketChannel() } - } -} - -fun channelSocketFactory(eventLoop: EventLoopGroup): ChannelFactory { - return when (eventLoop) { - is EpollEventLoopGroup -> ChannelFactory { EpollSocketChannel() } - is KQueueEventLoopGroup -> ChannelFactory { KQueueSocketChannel() } - else -> ChannelFactory { NioSocketChannel() } - } -} - val parentLoop = eventLoopGroup() val childLoop = eventLoopGroup() var chFuture: ChannelFuture? = null var ktorServer: NettyApplicationEngine? = null +val dnsResolver = DnsNameResolverBuilder(childLoop.next()) + .channelFactory(channelDatagramFactory(childLoop)) + .build() fun main(args: Array) { try { @@ -107,9 +76,9 @@ fun main(args: Array) { initFuture.complete(Unit) addShutdownHook() - + Thread { VIAaaSConsole.start() }.start() - + serverFinishing.join() } catch (e: Exception) { e.printStackTrace() @@ -162,9 +131,10 @@ private fun initVia() { ) MappingDataLoader.enableMappingsCache() (Via.getManager() as ViaManagerImpl).init() - ProtocolVersion.register(-2, "AUTO") AspirinRewind.init(ViaRewindConfigImpl(File("config/viarewind.yml"))) AspirinBackwards.init(File("config/viabackwards")) + + ProtocolVersion.register(-2, "AUTO") registerAspirinProtocols() } diff --git a/src/main/kotlin/com/viaversion/aas/handler/state/LoginState.kt b/src/main/kotlin/com/viaversion/aas/handler/state/LoginState.kt index df90af6..8a7a034 100644 --- a/src/main/kotlin/com/viaversion/aas/handler/state/LoginState.kt +++ b/src/main/kotlin/com/viaversion/aas/handler/state/LoginState.kt @@ -113,7 +113,6 @@ class LoginState : MinecraftConnectionState { frontHandler.endRemoteAddress, handler.data.backHandler!!.endRemoteAddress ).await() - if (!handler.data.frontChannel.isActive) return@launch val cryptoResponse = CryptoResponse() cryptoResponse.encryptedKey = encryptRsa(backPublicKey, backKey) @@ -179,7 +178,7 @@ class LoginState : MinecraftConnectionState { loginStart.username = backName!! send(handler.data.backChannel!!, loginStart, true) } catch (e: Exception) { - handler.data.frontChannel.pipeline().fireExceptionCaught(StacklessException("Login error: $e", e)) + handler.data.frontChannel.pipeline().fireExceptionCaught(e) } } } diff --git a/src/main/kotlin/com/viaversion/aas/handler/state/Util.kt b/src/main/kotlin/com/viaversion/aas/handler/state/Util.kt index 9c19338..cbdf05f 100644 --- a/src/main/kotlin/com/viaversion/aas/handler/state/Util.kt +++ b/src/main/kotlin/com/viaversion/aas/handler/state/Util.kt @@ -10,29 +10,25 @@ import com.viaversion.aas.handler.autoprotocol.ProtocolDetector import com.viaversion.aas.handler.forward import com.viaversion.aas.util.StacklessException import com.viaversion.viaversion.api.protocol.packet.State +import io.ktor.server.netty.* import io.netty.bootstrap.Bootstrap import io.netty.channel.Channel -import io.netty.channel.ChannelFutureListener import io.netty.channel.ChannelOption import io.netty.channel.socket.SocketChannel import io.netty.resolver.NoopAddressResolverGroup import kotlinx.coroutines.future.await -import kotlinx.coroutines.launch import kotlinx.coroutines.withTimeoutOrNull import java.net.Inet4Address -import java.net.InetAddress import java.net.InetSocketAddress -import java.util.concurrent.CompletableFuture -private fun createBackChannel( +private suspend fun createBackChannel( handler: MinecraftHandler, socketAddr: InetSocketAddress, state: State, extraData: String? -): CompletableFuture { - val future = CompletableFuture() +): Channel { val loop = handler.data.frontChannel.eventLoop() - Bootstrap() + val channel = Bootstrap() .handler(BackEndInit(handler.data)) .channelFactory(channelSocketFactory(loop.parent())) .group(loop) @@ -42,47 +38,40 @@ private fun createBackChannel( .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 10_000) // We need to show the error before the client timeout .resolver(NoopAddressResolverGroup.INSTANCE) .connect(socketAddr) - .addListener(ChannelFutureListener { - try { - if (!it.isSuccess) throw it.cause() + .also { it.suspendAwait() } + .channel() - mcLogger.info("+ ${handler.endRemoteAddress} -> $socketAddr") - handler.data.backChannel = it.channel() as SocketChannel + mcLogger.info("+ ${handler.endRemoteAddress} -> $socketAddr") + handler.data.backChannel = channel as SocketChannel - handler.coroutineScope.launch { - if (handler.data.viaBackServerVer == null) { - try { - val detectedProtocol = withTimeoutOrNull(10_000) { - ProtocolDetector.detectVersion(socketAddr).await() - } - - if (detectedProtocol != null && detectedProtocol.version != -1) { - handler.data.viaBackServerVer = detectedProtocol.version - } else { - handler.data.viaBackServerVer = -1 // fallback - } - } catch (e: Exception) { - mcLogger.warn("Failed to auto-detect version for $socketAddr: $e") - mcLogger.debug("Stacktrace: ", e) - } - } - - val packet = Handshake() - packet.nextState = state - packet.protocolId = handler.data.frontVer!! - packet.address = socketAddr.hostString + if (extraData != null) 0.toChar() + extraData else "" - packet.port = socketAddr.port - - forward(handler, packet, true) - - handler.data.frontChannel.setAutoRead(true) - future.complete(it.channel()) - } - } catch (e: Exception) { - future.completeExceptionally(it.cause()) + if (handler.data.viaBackServerVer == null) { + try { + val detectedProtocol = withTimeoutOrNull(10_000) { + ProtocolDetector.detectVersion(socketAddr).await() } - }) - return future + + if (detectedProtocol != null && detectedProtocol.version != -1) { + handler.data.viaBackServerVer = detectedProtocol.version + } else { + handler.data.viaBackServerVer = -1 // fallback + } + } catch (e: Exception) { + mcLogger.warn("Failed to auto-detect version for $socketAddr: $e") + mcLogger.debug("Stacktrace: ", e) + } + } + + val packet = Handshake() + packet.nextState = state + packet.protocolId = handler.data.frontVer!! + packet.address = socketAddr.hostString + if (extraData != null) 0.toChar() + extraData else "" + packet.port = socketAddr.port + + forward(handler, packet, true) + + handler.data.frontChannel.setAutoRead(true) + + return channel } private suspend fun tryBackAddresses( @@ -102,7 +91,7 @@ private suspend fun tryBackAddresses( throw StacklessException("Not allowed") } - createBackChannel(handler, socketAddr, state, extraData).await() + createBackChannel(handler, socketAddr, state, extraData) return // Finally it worked! } catch (e: Exception) { latestException = e @@ -112,7 +101,7 @@ private suspend fun tryBackAddresses( throw latestException ?: StacklessException("No address found") } -private fun resolveBackendAddresses(hostAndPort: HostAndPort): List { +private suspend fun resolveBackendAddresses(hostAndPort: HostAndPort): List { val srvResolved = resolveSrv(hostAndPort) val removedEndDot = srvResolved.host.replace(Regex("\\.$"), "") @@ -120,7 +109,9 @@ private fun resolveBackendAddresses(hostAndPort: HostAndPort): List listOf(InetSocketAddress.createUnresolved(removedEndDot, srvResolved.port)) - else -> InetAddress.getAllByName(srvResolved.host) + else -> dnsResolver + .resolveAll(srvResolved.host) + .suspendAwait() .groupBy { it is Inet4Address } .toSortedMap() // I'm sorry, IPv4, but my true love is IPv6... We can still be friends though... .map { InetSocketAddress(it.value.random(), srvResolved.port) } @@ -137,10 +128,10 @@ suspend fun connectBack( try { val addresses = resolveBackendAddresses(HostAndPort.fromParts(address, port)) - if (addresses.isEmpty()) throw StacklessException("Hostname has no IP address") + if (addresses.isEmpty()) throw StacklessException("Hostname has no IP addresses") tryBackAddresses(handler, addresses, state, extraData) } catch (e: Exception) { - throw StacklessException("Couldn't connect: " + e, e) + throw StacklessException("Couldn't connect: $e", e) } } diff --git a/src/main/kotlin/com/viaversion/aas/web/WebDashboardServer.kt b/src/main/kotlin/com/viaversion/aas/web/WebDashboardServer.kt index d02b6d0..d96c4b6 100644 --- a/src/main/kotlin/com/viaversion/aas/web/WebDashboardServer.kt +++ b/src/main/kotlin/com/viaversion/aas/web/WebDashboardServer.kt @@ -8,16 +8,18 @@ import com.google.common.cache.CacheLoader import com.google.common.collect.MultimapBuilder import com.google.common.collect.Multimaps import com.google.gson.JsonObject +import com.viaversion.aas.* import com.viaversion.aas.config.VIAaaSConfig -import com.viaversion.aas.httpClient -import com.viaversion.aas.parseUndashedId import com.viaversion.aas.util.StacklessException -import com.viaversion.aas.webLogger import io.ipinfo.api.IPInfo import io.ipinfo.api.model.IPResponse import io.ktor.client.request.* import io.ktor.http.cio.websocket.* +import io.ktor.server.netty.* import io.ktor.websocket.* +import io.netty.handler.codec.dns.DefaultDnsQuestion +import io.netty.handler.codec.dns.DnsPtrRecord +import io.netty.handler.codec.dns.DnsRecordType import kotlinx.coroutines.* import kotlinx.coroutines.future.asCompletableFuture import java.net.InetSocketAddress @@ -28,6 +30,7 @@ import java.util.* import java.util.concurrent.CompletableFuture import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.TimeUnit +import kotlin.coroutines.coroutineContext class WebDashboardServer { // I don't think i'll need more than 1k/day @@ -84,23 +87,24 @@ class WebDashboardServer { if (!listeners.containsKey(id)) { future.completeExceptionally(StacklessException("No browser listening")) } else { - coroutineScope { + CoroutineScope(coroutineContext).apply { launch(Dispatchers.IO) { var info: IPResponse? = null + var ptr: String? = null (address as? InetSocketAddress)?.let { try { val ipLookup = async(Dispatchers.IO) { - ipInfo.lookupIP(it.address?.hostAddress?.substringBefore("%")) - } - val reverseLookup = async(Dispatchers.IO) { - it.address?.hostName + ipInfo.lookupIP(it.address!!.hostAddress!!.substringBefore("%")) } + val dnsQuery = dnsResolver.resolveAll( + DefaultDnsQuestion(reverseLookup(it.address), DnsRecordType.PTR) + ) info = ipLookup.await() - reverseLookup.await() + ptr = dnsQuery.suspendAwait().first { it is DnsPtrRecord }?.name() } catch (ignored: Exception) { } } - val msg = "Requester: $id $address (${info?.org}, ${info?.city}, ${info?.region}, " + + val msg = "Requester: $id $address ($ptr) (${info?.org}, ${info?.city}, ${info?.region}, " + "${info?.countryCode})\nBackend: $backAddress" listeners[id]?.forEach { it.ws.send(JsonObject().also { diff --git a/src/main/kotlin/com/viaversion/aas/web/WebLogin.kt b/src/main/kotlin/com/viaversion/aas/web/WebLogin.kt index f4b31b6..6b8e125 100644 --- a/src/main/kotlin/com/viaversion/aas/web/WebLogin.kt +++ b/src/main/kotlin/com/viaversion/aas/web/WebLogin.kt @@ -106,7 +106,7 @@ class WebLogin : WebState { } "session_hash_response" -> { val hash = obj.get("session_hash").asString - webClient.server.sessionHashCallbacks.getIfPresent(hash)?.complete(null) + webClient.server.sessionHashCallbacks.getIfPresent(hash)?.complete(Unit) } else -> throw StacklessException("invalid action!") }