fix #153 implement #156

This commit is contained in:
creeper123123321 2021-06-12 12:14:11 -03:00
parent 2879a8ace5
commit 7b891d513b
7 changed files with 153 additions and 118 deletions

View File

@ -37,7 +37,7 @@ compileKotlin.kotlinOptions.jvmTarget = "11"
val gitVersion: groovy.lang.Closure<String> by extra val gitVersion: groovy.lang.Closure<String> by extra
group = "com.github.creeper123123321.viaaas" group = "com.github.creeper123123321.viaaas"
version = "0.4.1+" + try { version = "0.4.2+" + try {
gitVersion() gitVersion()
} catch (e: Exception) { } catch (e: Exception) {
"unknown" "unknown"

View File

@ -10,10 +10,26 @@ import com.viaversion.aas.util.StacklessException
import com.viaversion.viaversion.api.protocol.version.ProtocolVersion import com.viaversion.viaversion.api.protocol.version.ProtocolVersion
import com.viaversion.viaversion.api.type.Type import com.viaversion.viaversion.api.type.Type
import io.ktor.client.request.* import io.ktor.client.request.*
import io.ktor.server.netty.*
import io.netty.buffer.ByteBuf import io.netty.buffer.ByteBuf
import io.netty.channel.Channel import io.netty.channel.Channel
import io.netty.channel.ChannelFactory
import io.netty.channel.ChannelFutureListener import io.netty.channel.ChannelFutureListener
import io.netty.channel.EventLoopGroup
import io.netty.channel.epoll.*
import io.netty.channel.kqueue.*
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.DatagramChannel
import io.netty.channel.socket.ServerSocketChannel
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioDatagramChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.channel.socket.nio.NioSocketChannel
import io.netty.handler.codec.DecoderException import io.netty.handler.codec.DecoderException
import io.netty.handler.codec.dns.DefaultDnsQuestion
import io.netty.handler.codec.dns.DefaultDnsRawRecord
import io.netty.handler.codec.dns.DefaultDnsRecordDecoder
import io.netty.handler.codec.dns.DnsRecordType
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import java.math.BigInteger import java.math.BigInteger
import java.net.InetAddress import java.net.InetAddress
@ -28,7 +44,6 @@ import java.util.concurrent.TimeUnit
import javax.crypto.Cipher import javax.crypto.Cipher
import javax.crypto.spec.IvParameterSpec import javax.crypto.spec.IvParameterSpec
import javax.crypto.spec.SecretKeySpec import javax.crypto.spec.SecretKeySpec
import javax.naming.directory.InitialDirContext
val badLength = DecoderException("Invalid length!") val badLength = DecoderException("Invalid length!")
val mcLogger = LoggerFactory.getLogger("VIAaaS MC") val mcLogger = LoggerFactory.getLogger("VIAaaS MC")
@ -37,18 +52,27 @@ val viaaasLogger = LoggerFactory.getLogger("VIAaaS")
val secureRandom = if (VIAaaSConfig.useStrongRandom) SecureRandom.getInstanceStrong() else SecureRandom() val secureRandom = if (VIAaaSConfig.useStrongRandom) SecureRandom.getInstanceStrong() else SecureRandom()
fun resolveSrv(hostAndPort: HostAndPort): HostAndPort { suspend fun resolveSrv(hostAndPort: HostAndPort): HostAndPort {
if (hostAndPort.host.endsWith(".onion", ignoreCase = true)) return hostAndPort if (hostAndPort.host.endsWith(".onion", ignoreCase = true)) return hostAndPort
if (hostAndPort.port == 25565) { if (hostAndPort.port == 25565) {
try { try {
// https://github.com/GeyserMC/Geyser/blob/99e72f35b308542cf0dbfb5b58816503c3d6a129/connector/src/main/java/org/geysermc/connector/GeyserConnector.java // stolen from PacketLib (MIT) https://github.com/Camotoy/PacketLib/blob/312cff5f975be54cf2d92208ae2947dbda8b9f59/src/main/java/com/github/steveice10/packetlib/tcp/TcpClientSession.java
val attr = InitialDirContext() dnsResolver
.getAttributes("dns:///_minecraft._tcp.${hostAndPort.host}", arrayOf("SRV"))["SRV"] .resolveAll(DefaultDnsQuestion("_minecraft._tcp.${hostAndPort.host}", DnsRecordType.SRV))
if (attr != null && attr.size() > 0) { .suspendAwait()
val record = (attr.get(0) as String).split(" ") .forEach { record ->
return HostAndPort.fromParts(record[3], record[2].toInt()) if (record is DefaultDnsRawRecord && record.type() == DnsRecordType.SRV) {
val content = record.content()
content.skipBytes(4)
val port = content.readUnsignedShort()
val address = DefaultDnsRecordDecoder.decodeName(content)
return HostAndPort.fromParts(address, port)
} }
} catch (ignored: Exception) { // DuckDNS workaround }
} catch (e: Exception) {
viaaasLogger.debug("Couldn't resolve SRV", e)
} }
} }
return hostAndPort return hostAndPort
@ -169,8 +193,10 @@ fun ByteBuf.readByteArray(length: Int) = ByteArray(length).also { readBytes(it)
suspend fun hasJoined(username: String, hash: String): JsonObject { suspend fun hasJoined(username: String, hash: String): JsonObject {
return try { return try {
httpClient.get("https://sessionserver.mojang.com/session/minecraft/hasJoined?username=" + httpClient.get(
UrlEscapers.urlFormParameterEscaper().escape(username) + "&serverId=$hash") "https://sessionserver.mojang.com/session/minecraft/hasJoined?username=" +
UrlEscapers.urlFormParameterEscaper().escape(username) + "&serverId=$hash"
)
} catch (e: Exception) { } catch (e: Exception) {
throw StacklessException("Couldn't authenticate with session servers", e) throw StacklessException("Couldn't authenticate with session servers", e)
} }
@ -190,3 +216,48 @@ fun sha512Hex(data: ByteArray): String {
.asUByteArray() .asUByteArray()
.joinToString("") { it.toString(16).padStart(2, '0') } .joinToString("") { it.toString(16).padStart(2, '0') }
} }
fun eventLoopGroup(): EventLoopGroup {
if (VIAaaSConfig.isNativeTransportMc) {
if (Epoll.isAvailable()) return EpollEventLoopGroup()
if (KQueue.isAvailable()) return KQueueEventLoopGroup()
}
return NioEventLoopGroup()
}
fun channelServerSocketFactory(eventLoop: EventLoopGroup): ChannelFactory<ServerSocketChannel> {
return when (eventLoop) {
is EpollEventLoopGroup -> ChannelFactory { EpollServerSocketChannel() }
is KQueueEventLoopGroup -> ChannelFactory { KQueueServerSocketChannel() }
else -> ChannelFactory { NioServerSocketChannel() }
}
}
fun channelSocketFactory(eventLoop: EventLoopGroup): ChannelFactory<SocketChannel> {
return when (eventLoop) {
is EpollEventLoopGroup -> ChannelFactory { EpollSocketChannel() }
is KQueueEventLoopGroup -> ChannelFactory { KQueueSocketChannel() }
else -> ChannelFactory { NioSocketChannel() }
}
}
fun channelDatagramFactory(eventLoop: EventLoopGroup): ChannelFactory<DatagramChannel> {
return when (eventLoop) {
is EpollEventLoopGroup -> ChannelFactory { EpollDatagramChannel() }
is KQueueEventLoopGroup -> ChannelFactory { KQueueDatagramChannel() }
else -> ChannelFactory { NioDatagramChannel() }
}
}
fun reverseLookup(address: InetAddress): String {
val bytes = address.address
return if (bytes.size == 4) {
// IPv4
bytes.reversed()
.joinToString(".") { it.toUByte().toString() } + ".in-addr.arpa"
} else { // IPv6
bytes.flatMap { it.toUByte().toString(16).padStart(2, '0').toCharArray().map { it.toString() } }
.asReversed()
.joinToString(".") + ".ip6.arpa"
}
}

View File

@ -23,20 +23,10 @@ import io.ktor.network.tls.certificates.*
import io.ktor.server.engine.* import io.ktor.server.engine.*
import io.ktor.server.netty.* import io.ktor.server.netty.*
import io.netty.bootstrap.ServerBootstrap import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.* import io.netty.channel.ChannelFuture
import io.netty.channel.epoll.Epoll import io.netty.channel.ChannelOption
import io.netty.channel.epoll.EpollEventLoopGroup import io.netty.channel.WriteBufferWaterMark
import io.netty.channel.epoll.EpollServerSocketChannel import io.netty.resolver.dns.DnsNameResolverBuilder
import io.netty.channel.epoll.EpollSocketChannel
import io.netty.channel.kqueue.KQueue
import io.netty.channel.kqueue.KQueueEventLoopGroup
import io.netty.channel.kqueue.KQueueServerSocketChannel
import io.netty.channel.kqueue.KQueueSocketChannel
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.ServerSocketChannel
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.channel.socket.nio.NioSocketChannel
import io.netty.util.concurrent.Future import io.netty.util.concurrent.Future
import org.apache.logging.log4j.Level import org.apache.logging.log4j.Level
import org.apache.logging.log4j.io.IoBuilder import org.apache.logging.log4j.io.IoBuilder
@ -68,34 +58,13 @@ val mcCryptoKey = KeyPairGenerator.getInstance("RSA").let {
it.genKeyPair() it.genKeyPair()
} }
fun eventLoopGroup(): EventLoopGroup {
if (VIAaaSConfig.isNativeTransportMc) {
if (Epoll.isAvailable()) return EpollEventLoopGroup()
if (KQueue.isAvailable()) return KQueueEventLoopGroup()
}
return NioEventLoopGroup()
}
fun channelServerSocketFactory(eventLoop: EventLoopGroup): ChannelFactory<ServerSocketChannel> {
return when (eventLoop) {
is EpollEventLoopGroup -> ChannelFactory { EpollServerSocketChannel() }
is KQueueEventLoopGroup -> ChannelFactory { KQueueServerSocketChannel() }
else -> ChannelFactory { NioServerSocketChannel() }
}
}
fun channelSocketFactory(eventLoop: EventLoopGroup): ChannelFactory<SocketChannel> {
return when (eventLoop) {
is EpollEventLoopGroup -> ChannelFactory { EpollSocketChannel() }
is KQueueEventLoopGroup -> ChannelFactory { KQueueSocketChannel() }
else -> ChannelFactory { NioSocketChannel() }
}
}
val parentLoop = eventLoopGroup() val parentLoop = eventLoopGroup()
val childLoop = eventLoopGroup() val childLoop = eventLoopGroup()
var chFuture: ChannelFuture? = null var chFuture: ChannelFuture? = null
var ktorServer: NettyApplicationEngine? = null var ktorServer: NettyApplicationEngine? = null
val dnsResolver = DnsNameResolverBuilder(childLoop.next())
.channelFactory(channelDatagramFactory(childLoop))
.build()
fun main(args: Array<String>) { fun main(args: Array<String>) {
try { try {
@ -162,9 +131,10 @@ private fun initVia() {
) )
MappingDataLoader.enableMappingsCache() MappingDataLoader.enableMappingsCache()
(Via.getManager() as ViaManagerImpl).init() (Via.getManager() as ViaManagerImpl).init()
ProtocolVersion.register(-2, "AUTO")
AspirinRewind.init(ViaRewindConfigImpl(File("config/viarewind.yml"))) AspirinRewind.init(ViaRewindConfigImpl(File("config/viarewind.yml")))
AspirinBackwards.init(File("config/viabackwards")) AspirinBackwards.init(File("config/viabackwards"))
ProtocolVersion.register(-2, "AUTO")
registerAspirinProtocols() registerAspirinProtocols()
} }

View File

@ -113,7 +113,6 @@ class LoginState : MinecraftConnectionState {
frontHandler.endRemoteAddress, frontHandler.endRemoteAddress,
handler.data.backHandler!!.endRemoteAddress handler.data.backHandler!!.endRemoteAddress
).await() ).await()
if (!handler.data.frontChannel.isActive) return@launch
val cryptoResponse = CryptoResponse() val cryptoResponse = CryptoResponse()
cryptoResponse.encryptedKey = encryptRsa(backPublicKey, backKey) cryptoResponse.encryptedKey = encryptRsa(backPublicKey, backKey)
@ -179,7 +178,7 @@ class LoginState : MinecraftConnectionState {
loginStart.username = backName!! loginStart.username = backName!!
send(handler.data.backChannel!!, loginStart, true) send(handler.data.backChannel!!, loginStart, true)
} catch (e: Exception) { } catch (e: Exception) {
handler.data.frontChannel.pipeline().fireExceptionCaught(StacklessException("Login error: $e", e)) handler.data.frontChannel.pipeline().fireExceptionCaught(e)
} }
} }
} }

View File

@ -10,29 +10,25 @@ import com.viaversion.aas.handler.autoprotocol.ProtocolDetector
import com.viaversion.aas.handler.forward import com.viaversion.aas.handler.forward
import com.viaversion.aas.util.StacklessException import com.viaversion.aas.util.StacklessException
import com.viaversion.viaversion.api.protocol.packet.State import com.viaversion.viaversion.api.protocol.packet.State
import io.ktor.server.netty.*
import io.netty.bootstrap.Bootstrap import io.netty.bootstrap.Bootstrap
import io.netty.channel.Channel import io.netty.channel.Channel
import io.netty.channel.ChannelFutureListener
import io.netty.channel.ChannelOption import io.netty.channel.ChannelOption
import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.SocketChannel
import io.netty.resolver.NoopAddressResolverGroup import io.netty.resolver.NoopAddressResolverGroup
import kotlinx.coroutines.future.await import kotlinx.coroutines.future.await
import kotlinx.coroutines.launch
import kotlinx.coroutines.withTimeoutOrNull import kotlinx.coroutines.withTimeoutOrNull
import java.net.Inet4Address import java.net.Inet4Address
import java.net.InetAddress
import java.net.InetSocketAddress import java.net.InetSocketAddress
import java.util.concurrent.CompletableFuture
private fun createBackChannel( private suspend fun createBackChannel(
handler: MinecraftHandler, handler: MinecraftHandler,
socketAddr: InetSocketAddress, socketAddr: InetSocketAddress,
state: State, state: State,
extraData: String? extraData: String?
): CompletableFuture<Channel> { ): Channel {
val future = CompletableFuture<Channel>()
val loop = handler.data.frontChannel.eventLoop() val loop = handler.data.frontChannel.eventLoop()
Bootstrap() val channel = Bootstrap()
.handler(BackEndInit(handler.data)) .handler(BackEndInit(handler.data))
.channelFactory(channelSocketFactory(loop.parent())) .channelFactory(channelSocketFactory(loop.parent()))
.group(loop) .group(loop)
@ -42,14 +38,12 @@ private fun createBackChannel(
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 10_000) // We need to show the error before the client timeout .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 10_000) // We need to show the error before the client timeout
.resolver(NoopAddressResolverGroup.INSTANCE) .resolver(NoopAddressResolverGroup.INSTANCE)
.connect(socketAddr) .connect(socketAddr)
.addListener(ChannelFutureListener { .also { it.suspendAwait() }
try { .channel()
if (!it.isSuccess) throw it.cause()
mcLogger.info("+ ${handler.endRemoteAddress} -> $socketAddr") mcLogger.info("+ ${handler.endRemoteAddress} -> $socketAddr")
handler.data.backChannel = it.channel() as SocketChannel handler.data.backChannel = channel as SocketChannel
handler.coroutineScope.launch {
if (handler.data.viaBackServerVer == null) { if (handler.data.viaBackServerVer == null) {
try { try {
val detectedProtocol = withTimeoutOrNull(10_000) { val detectedProtocol = withTimeoutOrNull(10_000) {
@ -76,13 +70,8 @@ private fun createBackChannel(
forward(handler, packet, true) forward(handler, packet, true)
handler.data.frontChannel.setAutoRead(true) handler.data.frontChannel.setAutoRead(true)
future.complete(it.channel())
} return channel
} catch (e: Exception) {
future.completeExceptionally(it.cause())
}
})
return future
} }
private suspend fun tryBackAddresses( private suspend fun tryBackAddresses(
@ -102,7 +91,7 @@ private suspend fun tryBackAddresses(
throw StacklessException("Not allowed") throw StacklessException("Not allowed")
} }
createBackChannel(handler, socketAddr, state, extraData).await() createBackChannel(handler, socketAddr, state, extraData)
return // Finally it worked! return // Finally it worked!
} catch (e: Exception) { } catch (e: Exception) {
latestException = e latestException = e
@ -112,7 +101,7 @@ private suspend fun tryBackAddresses(
throw latestException ?: StacklessException("No address found") throw latestException ?: StacklessException("No address found")
} }
private fun resolveBackendAddresses(hostAndPort: HostAndPort): List<InetSocketAddress> { private suspend fun resolveBackendAddresses(hostAndPort: HostAndPort): List<InetSocketAddress> {
val srvResolved = resolveSrv(hostAndPort) val srvResolved = resolveSrv(hostAndPort)
val removedEndDot = srvResolved.host.replace(Regex("\\.$"), "") val removedEndDot = srvResolved.host.replace(Regex("\\.$"), "")
@ -120,7 +109,9 @@ private fun resolveBackendAddresses(hostAndPort: HostAndPort): List<InetSocketAd
return when { return when {
removedEndDot.endsWith(".onion", ignoreCase = true) -> removedEndDot.endsWith(".onion", ignoreCase = true) ->
listOf(InetSocketAddress.createUnresolved(removedEndDot, srvResolved.port)) listOf(InetSocketAddress.createUnresolved(removedEndDot, srvResolved.port))
else -> InetAddress.getAllByName(srvResolved.host) else -> dnsResolver
.resolveAll(srvResolved.host)
.suspendAwait()
.groupBy { it is Inet4Address } .groupBy { it is Inet4Address }
.toSortedMap() // I'm sorry, IPv4, but my true love is IPv6... We can still be friends though... .toSortedMap() // I'm sorry, IPv4, but my true love is IPv6... We can still be friends though...
.map { InetSocketAddress(it.value.random(), srvResolved.port) } .map { InetSocketAddress(it.value.random(), srvResolved.port) }
@ -137,10 +128,10 @@ suspend fun connectBack(
try { try {
val addresses = resolveBackendAddresses(HostAndPort.fromParts(address, port)) val addresses = resolveBackendAddresses(HostAndPort.fromParts(address, port))
if (addresses.isEmpty()) throw StacklessException("Hostname has no IP address") if (addresses.isEmpty()) throw StacklessException("Hostname has no IP addresses")
tryBackAddresses(handler, addresses, state, extraData) tryBackAddresses(handler, addresses, state, extraData)
} catch (e: Exception) { } catch (e: Exception) {
throw StacklessException("Couldn't connect: " + e, e) throw StacklessException("Couldn't connect: $e", e)
} }
} }

View File

@ -8,16 +8,18 @@ import com.google.common.cache.CacheLoader
import com.google.common.collect.MultimapBuilder import com.google.common.collect.MultimapBuilder
import com.google.common.collect.Multimaps import com.google.common.collect.Multimaps
import com.google.gson.JsonObject import com.google.gson.JsonObject
import com.viaversion.aas.*
import com.viaversion.aas.config.VIAaaSConfig import com.viaversion.aas.config.VIAaaSConfig
import com.viaversion.aas.httpClient
import com.viaversion.aas.parseUndashedId
import com.viaversion.aas.util.StacklessException import com.viaversion.aas.util.StacklessException
import com.viaversion.aas.webLogger
import io.ipinfo.api.IPInfo import io.ipinfo.api.IPInfo
import io.ipinfo.api.model.IPResponse import io.ipinfo.api.model.IPResponse
import io.ktor.client.request.* import io.ktor.client.request.*
import io.ktor.http.cio.websocket.* import io.ktor.http.cio.websocket.*
import io.ktor.server.netty.*
import io.ktor.websocket.* import io.ktor.websocket.*
import io.netty.handler.codec.dns.DefaultDnsQuestion
import io.netty.handler.codec.dns.DnsPtrRecord
import io.netty.handler.codec.dns.DnsRecordType
import kotlinx.coroutines.* import kotlinx.coroutines.*
import kotlinx.coroutines.future.asCompletableFuture import kotlinx.coroutines.future.asCompletableFuture
import java.net.InetSocketAddress import java.net.InetSocketAddress
@ -28,6 +30,7 @@ import java.util.*
import java.util.concurrent.CompletableFuture import java.util.concurrent.CompletableFuture
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import kotlin.coroutines.coroutineContext
class WebDashboardServer { class WebDashboardServer {
// I don't think i'll need more than 1k/day // I don't think i'll need more than 1k/day
@ -84,23 +87,24 @@ class WebDashboardServer {
if (!listeners.containsKey(id)) { if (!listeners.containsKey(id)) {
future.completeExceptionally(StacklessException("No browser listening")) future.completeExceptionally(StacklessException("No browser listening"))
} else { } else {
coroutineScope { CoroutineScope(coroutineContext).apply {
launch(Dispatchers.IO) { launch(Dispatchers.IO) {
var info: IPResponse? = null var info: IPResponse? = null
var ptr: String? = null
(address as? InetSocketAddress)?.let { (address as? InetSocketAddress)?.let {
try { try {
val ipLookup = async(Dispatchers.IO) { val ipLookup = async(Dispatchers.IO) {
ipInfo.lookupIP(it.address?.hostAddress?.substringBefore("%")) ipInfo.lookupIP(it.address!!.hostAddress!!.substringBefore("%"))
}
val reverseLookup = async(Dispatchers.IO) {
it.address?.hostName
} }
val dnsQuery = dnsResolver.resolveAll(
DefaultDnsQuestion(reverseLookup(it.address), DnsRecordType.PTR)
)
info = ipLookup.await() info = ipLookup.await()
reverseLookup.await() ptr = dnsQuery.suspendAwait().first { it is DnsPtrRecord }?.name()
} catch (ignored: Exception) { } catch (ignored: Exception) {
} }
} }
val msg = "Requester: $id $address (${info?.org}, ${info?.city}, ${info?.region}, " + val msg = "Requester: $id $address ($ptr) (${info?.org}, ${info?.city}, ${info?.region}, " +
"${info?.countryCode})\nBackend: $backAddress" "${info?.countryCode})\nBackend: $backAddress"
listeners[id]?.forEach { listeners[id]?.forEach {
it.ws.send(JsonObject().also { it.ws.send(JsonObject().also {

View File

@ -106,7 +106,7 @@ class WebLogin : WebState {
} }
"session_hash_response" -> { "session_hash_response" -> {
val hash = obj.get("session_hash").asString val hash = obj.get("session_hash").asString
webClient.server.sessionHashCallbacks.getIfPresent(hash)?.complete(null) webClient.server.sessionHashCallbacks.getIfPresent(hash)?.complete(Unit)
} }
else -> throw StacklessException("invalid action!") else -> throw StacklessException("invalid action!")
} }