fix #153 implement #156

This commit is contained in:
creeper123123321 2021-06-12 12:14:11 -03:00
parent 2879a8ace5
commit 7b891d513b
7 changed files with 153 additions and 118 deletions

View File

@ -37,7 +37,7 @@ compileKotlin.kotlinOptions.jvmTarget = "11"
val gitVersion: groovy.lang.Closure<String> by extra
group = "com.github.creeper123123321.viaaas"
version = "0.4.1+" + try {
version = "0.4.2+" + try {
gitVersion()
} catch (e: Exception) {
"unknown"

View File

@ -10,10 +10,26 @@ import com.viaversion.aas.util.StacklessException
import com.viaversion.viaversion.api.protocol.version.ProtocolVersion
import com.viaversion.viaversion.api.type.Type
import io.ktor.client.request.*
import io.ktor.server.netty.*
import io.netty.buffer.ByteBuf
import io.netty.channel.Channel
import io.netty.channel.ChannelFactory
import io.netty.channel.ChannelFutureListener
import io.netty.channel.EventLoopGroup
import io.netty.channel.epoll.*
import io.netty.channel.kqueue.*
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.DatagramChannel
import io.netty.channel.socket.ServerSocketChannel
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioDatagramChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.channel.socket.nio.NioSocketChannel
import io.netty.handler.codec.DecoderException
import io.netty.handler.codec.dns.DefaultDnsQuestion
import io.netty.handler.codec.dns.DefaultDnsRawRecord
import io.netty.handler.codec.dns.DefaultDnsRecordDecoder
import io.netty.handler.codec.dns.DnsRecordType
import org.slf4j.LoggerFactory
import java.math.BigInteger
import java.net.InetAddress
@ -28,7 +44,6 @@ import java.util.concurrent.TimeUnit
import javax.crypto.Cipher
import javax.crypto.spec.IvParameterSpec
import javax.crypto.spec.SecretKeySpec
import javax.naming.directory.InitialDirContext
val badLength = DecoderException("Invalid length!")
val mcLogger = LoggerFactory.getLogger("VIAaaS MC")
@ -37,18 +52,27 @@ val viaaasLogger = LoggerFactory.getLogger("VIAaaS")
val secureRandom = if (VIAaaSConfig.useStrongRandom) SecureRandom.getInstanceStrong() else SecureRandom()
fun resolveSrv(hostAndPort: HostAndPort): HostAndPort {
suspend fun resolveSrv(hostAndPort: HostAndPort): HostAndPort {
if (hostAndPort.host.endsWith(".onion", ignoreCase = true)) return hostAndPort
if (hostAndPort.port == 25565) {
try {
// https://github.com/GeyserMC/Geyser/blob/99e72f35b308542cf0dbfb5b58816503c3d6a129/connector/src/main/java/org/geysermc/connector/GeyserConnector.java
val attr = InitialDirContext()
.getAttributes("dns:///_minecraft._tcp.${hostAndPort.host}", arrayOf("SRV"))["SRV"]
if (attr != null && attr.size() > 0) {
val record = (attr.get(0) as String).split(" ")
return HostAndPort.fromParts(record[3], record[2].toInt())
}
} catch (ignored: Exception) { // DuckDNS workaround
// stolen from PacketLib (MIT) https://github.com/Camotoy/PacketLib/blob/312cff5f975be54cf2d92208ae2947dbda8b9f59/src/main/java/com/github/steveice10/packetlib/tcp/TcpClientSession.java
dnsResolver
.resolveAll(DefaultDnsQuestion("_minecraft._tcp.${hostAndPort.host}", DnsRecordType.SRV))
.suspendAwait()
.forEach { record ->
if (record is DefaultDnsRawRecord && record.type() == DnsRecordType.SRV) {
val content = record.content()
content.skipBytes(4)
val port = content.readUnsignedShort()
val address = DefaultDnsRecordDecoder.decodeName(content)
return HostAndPort.fromParts(address, port)
}
}
} catch (e: Exception) {
viaaasLogger.debug("Couldn't resolve SRV", e)
}
}
return hostAndPort
@ -169,8 +193,10 @@ fun ByteBuf.readByteArray(length: Int) = ByteArray(length).also { readBytes(it)
suspend fun hasJoined(username: String, hash: String): JsonObject {
return try {
httpClient.get("https://sessionserver.mojang.com/session/minecraft/hasJoined?username=" +
UrlEscapers.urlFormParameterEscaper().escape(username) + "&serverId=$hash")
httpClient.get(
"https://sessionserver.mojang.com/session/minecraft/hasJoined?username=" +
UrlEscapers.urlFormParameterEscaper().escape(username) + "&serverId=$hash"
)
} catch (e: Exception) {
throw StacklessException("Couldn't authenticate with session servers", e)
}
@ -189,4 +215,49 @@ fun sha512Hex(data: ByteArray): String {
return MessageDigest.getInstance("SHA-512").digest(data)
.asUByteArray()
.joinToString("") { it.toString(16).padStart(2, '0') }
}
fun eventLoopGroup(): EventLoopGroup {
if (VIAaaSConfig.isNativeTransportMc) {
if (Epoll.isAvailable()) return EpollEventLoopGroup()
if (KQueue.isAvailable()) return KQueueEventLoopGroup()
}
return NioEventLoopGroup()
}
fun channelServerSocketFactory(eventLoop: EventLoopGroup): ChannelFactory<ServerSocketChannel> {
return when (eventLoop) {
is EpollEventLoopGroup -> ChannelFactory { EpollServerSocketChannel() }
is KQueueEventLoopGroup -> ChannelFactory { KQueueServerSocketChannel() }
else -> ChannelFactory { NioServerSocketChannel() }
}
}
fun channelSocketFactory(eventLoop: EventLoopGroup): ChannelFactory<SocketChannel> {
return when (eventLoop) {
is EpollEventLoopGroup -> ChannelFactory { EpollSocketChannel() }
is KQueueEventLoopGroup -> ChannelFactory { KQueueSocketChannel() }
else -> ChannelFactory { NioSocketChannel() }
}
}
fun channelDatagramFactory(eventLoop: EventLoopGroup): ChannelFactory<DatagramChannel> {
return when (eventLoop) {
is EpollEventLoopGroup -> ChannelFactory { EpollDatagramChannel() }
is KQueueEventLoopGroup -> ChannelFactory { KQueueDatagramChannel() }
else -> ChannelFactory { NioDatagramChannel() }
}
}
fun reverseLookup(address: InetAddress): String {
val bytes = address.address
return if (bytes.size == 4) {
// IPv4
bytes.reversed()
.joinToString(".") { it.toUByte().toString() } + ".in-addr.arpa"
} else { // IPv6
bytes.flatMap { it.toUByte().toString(16).padStart(2, '0').toCharArray().map { it.toString() } }
.asReversed()
.joinToString(".") + ".ip6.arpa"
}
}

View File

@ -23,20 +23,10 @@ import io.ktor.network.tls.certificates.*
import io.ktor.server.engine.*
import io.ktor.server.netty.*
import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.*
import io.netty.channel.epoll.Epoll
import io.netty.channel.epoll.EpollEventLoopGroup
import io.netty.channel.epoll.EpollServerSocketChannel
import io.netty.channel.epoll.EpollSocketChannel
import io.netty.channel.kqueue.KQueue
import io.netty.channel.kqueue.KQueueEventLoopGroup
import io.netty.channel.kqueue.KQueueServerSocketChannel
import io.netty.channel.kqueue.KQueueSocketChannel
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.ServerSocketChannel
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.channel.socket.nio.NioSocketChannel
import io.netty.channel.ChannelFuture
import io.netty.channel.ChannelOption
import io.netty.channel.WriteBufferWaterMark
import io.netty.resolver.dns.DnsNameResolverBuilder
import io.netty.util.concurrent.Future
import org.apache.logging.log4j.Level
import org.apache.logging.log4j.io.IoBuilder
@ -68,34 +58,13 @@ val mcCryptoKey = KeyPairGenerator.getInstance("RSA").let {
it.genKeyPair()
}
fun eventLoopGroup(): EventLoopGroup {
if (VIAaaSConfig.isNativeTransportMc) {
if (Epoll.isAvailable()) return EpollEventLoopGroup()
if (KQueue.isAvailable()) return KQueueEventLoopGroup()
}
return NioEventLoopGroup()
}
fun channelServerSocketFactory(eventLoop: EventLoopGroup): ChannelFactory<ServerSocketChannel> {
return when (eventLoop) {
is EpollEventLoopGroup -> ChannelFactory { EpollServerSocketChannel() }
is KQueueEventLoopGroup -> ChannelFactory { KQueueServerSocketChannel() }
else -> ChannelFactory { NioServerSocketChannel() }
}
}
fun channelSocketFactory(eventLoop: EventLoopGroup): ChannelFactory<SocketChannel> {
return when (eventLoop) {
is EpollEventLoopGroup -> ChannelFactory { EpollSocketChannel() }
is KQueueEventLoopGroup -> ChannelFactory { KQueueSocketChannel() }
else -> ChannelFactory { NioSocketChannel() }
}
}
val parentLoop = eventLoopGroup()
val childLoop = eventLoopGroup()
var chFuture: ChannelFuture? = null
var ktorServer: NettyApplicationEngine? = null
val dnsResolver = DnsNameResolverBuilder(childLoop.next())
.channelFactory(channelDatagramFactory(childLoop))
.build()
fun main(args: Array<String>) {
try {
@ -107,9 +76,9 @@ fun main(args: Array<String>) {
initFuture.complete(Unit)
addShutdownHook()
Thread { VIAaaSConsole.start() }.start()
serverFinishing.join()
} catch (e: Exception) {
e.printStackTrace()
@ -162,9 +131,10 @@ private fun initVia() {
)
MappingDataLoader.enableMappingsCache()
(Via.getManager() as ViaManagerImpl).init()
ProtocolVersion.register(-2, "AUTO")
AspirinRewind.init(ViaRewindConfigImpl(File("config/viarewind.yml")))
AspirinBackwards.init(File("config/viabackwards"))
ProtocolVersion.register(-2, "AUTO")
registerAspirinProtocols()
}

View File

@ -113,7 +113,6 @@ class LoginState : MinecraftConnectionState {
frontHandler.endRemoteAddress,
handler.data.backHandler!!.endRemoteAddress
).await()
if (!handler.data.frontChannel.isActive) return@launch
val cryptoResponse = CryptoResponse()
cryptoResponse.encryptedKey = encryptRsa(backPublicKey, backKey)
@ -179,7 +178,7 @@ class LoginState : MinecraftConnectionState {
loginStart.username = backName!!
send(handler.data.backChannel!!, loginStart, true)
} catch (e: Exception) {
handler.data.frontChannel.pipeline().fireExceptionCaught(StacklessException("Login error: $e", e))
handler.data.frontChannel.pipeline().fireExceptionCaught(e)
}
}
}

View File

@ -10,29 +10,25 @@ import com.viaversion.aas.handler.autoprotocol.ProtocolDetector
import com.viaversion.aas.handler.forward
import com.viaversion.aas.util.StacklessException
import com.viaversion.viaversion.api.protocol.packet.State
import io.ktor.server.netty.*
import io.netty.bootstrap.Bootstrap
import io.netty.channel.Channel
import io.netty.channel.ChannelFutureListener
import io.netty.channel.ChannelOption
import io.netty.channel.socket.SocketChannel
import io.netty.resolver.NoopAddressResolverGroup
import kotlinx.coroutines.future.await
import kotlinx.coroutines.launch
import kotlinx.coroutines.withTimeoutOrNull
import java.net.Inet4Address
import java.net.InetAddress
import java.net.InetSocketAddress
import java.util.concurrent.CompletableFuture
private fun createBackChannel(
private suspend fun createBackChannel(
handler: MinecraftHandler,
socketAddr: InetSocketAddress,
state: State,
extraData: String?
): CompletableFuture<Channel> {
val future = CompletableFuture<Channel>()
): Channel {
val loop = handler.data.frontChannel.eventLoop()
Bootstrap()
val channel = Bootstrap()
.handler(BackEndInit(handler.data))
.channelFactory(channelSocketFactory(loop.parent()))
.group(loop)
@ -42,47 +38,40 @@ private fun createBackChannel(
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 10_000) // We need to show the error before the client timeout
.resolver(NoopAddressResolverGroup.INSTANCE)
.connect(socketAddr)
.addListener(ChannelFutureListener {
try {
if (!it.isSuccess) throw it.cause()
.also { it.suspendAwait() }
.channel()
mcLogger.info("+ ${handler.endRemoteAddress} -> $socketAddr")
handler.data.backChannel = it.channel() as SocketChannel
mcLogger.info("+ ${handler.endRemoteAddress} -> $socketAddr")
handler.data.backChannel = channel as SocketChannel
handler.coroutineScope.launch {
if (handler.data.viaBackServerVer == null) {
try {
val detectedProtocol = withTimeoutOrNull(10_000) {
ProtocolDetector.detectVersion(socketAddr).await()
}
if (detectedProtocol != null && detectedProtocol.version != -1) {
handler.data.viaBackServerVer = detectedProtocol.version
} else {
handler.data.viaBackServerVer = -1 // fallback
}
} catch (e: Exception) {
mcLogger.warn("Failed to auto-detect version for $socketAddr: $e")
mcLogger.debug("Stacktrace: ", e)
}
}
val packet = Handshake()
packet.nextState = state
packet.protocolId = handler.data.frontVer!!
packet.address = socketAddr.hostString + if (extraData != null) 0.toChar() + extraData else ""
packet.port = socketAddr.port
forward(handler, packet, true)
handler.data.frontChannel.setAutoRead(true)
future.complete(it.channel())
}
} catch (e: Exception) {
future.completeExceptionally(it.cause())
if (handler.data.viaBackServerVer == null) {
try {
val detectedProtocol = withTimeoutOrNull(10_000) {
ProtocolDetector.detectVersion(socketAddr).await()
}
})
return future
if (detectedProtocol != null && detectedProtocol.version != -1) {
handler.data.viaBackServerVer = detectedProtocol.version
} else {
handler.data.viaBackServerVer = -1 // fallback
}
} catch (e: Exception) {
mcLogger.warn("Failed to auto-detect version for $socketAddr: $e")
mcLogger.debug("Stacktrace: ", e)
}
}
val packet = Handshake()
packet.nextState = state
packet.protocolId = handler.data.frontVer!!
packet.address = socketAddr.hostString + if (extraData != null) 0.toChar() + extraData else ""
packet.port = socketAddr.port
forward(handler, packet, true)
handler.data.frontChannel.setAutoRead(true)
return channel
}
private suspend fun tryBackAddresses(
@ -102,7 +91,7 @@ private suspend fun tryBackAddresses(
throw StacklessException("Not allowed")
}
createBackChannel(handler, socketAddr, state, extraData).await()
createBackChannel(handler, socketAddr, state, extraData)
return // Finally it worked!
} catch (e: Exception) {
latestException = e
@ -112,7 +101,7 @@ private suspend fun tryBackAddresses(
throw latestException ?: StacklessException("No address found")
}
private fun resolveBackendAddresses(hostAndPort: HostAndPort): List<InetSocketAddress> {
private suspend fun resolveBackendAddresses(hostAndPort: HostAndPort): List<InetSocketAddress> {
val srvResolved = resolveSrv(hostAndPort)
val removedEndDot = srvResolved.host.replace(Regex("\\.$"), "")
@ -120,7 +109,9 @@ private fun resolveBackendAddresses(hostAndPort: HostAndPort): List<InetSocketAd
return when {
removedEndDot.endsWith(".onion", ignoreCase = true) ->
listOf(InetSocketAddress.createUnresolved(removedEndDot, srvResolved.port))
else -> InetAddress.getAllByName(srvResolved.host)
else -> dnsResolver
.resolveAll(srvResolved.host)
.suspendAwait()
.groupBy { it is Inet4Address }
.toSortedMap() // I'm sorry, IPv4, but my true love is IPv6... We can still be friends though...
.map { InetSocketAddress(it.value.random(), srvResolved.port) }
@ -137,10 +128,10 @@ suspend fun connectBack(
try {
val addresses = resolveBackendAddresses(HostAndPort.fromParts(address, port))
if (addresses.isEmpty()) throw StacklessException("Hostname has no IP address")
if (addresses.isEmpty()) throw StacklessException("Hostname has no IP addresses")
tryBackAddresses(handler, addresses, state, extraData)
} catch (e: Exception) {
throw StacklessException("Couldn't connect: " + e, e)
throw StacklessException("Couldn't connect: $e", e)
}
}

View File

@ -8,16 +8,18 @@ import com.google.common.cache.CacheLoader
import com.google.common.collect.MultimapBuilder
import com.google.common.collect.Multimaps
import com.google.gson.JsonObject
import com.viaversion.aas.*
import com.viaversion.aas.config.VIAaaSConfig
import com.viaversion.aas.httpClient
import com.viaversion.aas.parseUndashedId
import com.viaversion.aas.util.StacklessException
import com.viaversion.aas.webLogger
import io.ipinfo.api.IPInfo
import io.ipinfo.api.model.IPResponse
import io.ktor.client.request.*
import io.ktor.http.cio.websocket.*
import io.ktor.server.netty.*
import io.ktor.websocket.*
import io.netty.handler.codec.dns.DefaultDnsQuestion
import io.netty.handler.codec.dns.DnsPtrRecord
import io.netty.handler.codec.dns.DnsRecordType
import kotlinx.coroutines.*
import kotlinx.coroutines.future.asCompletableFuture
import java.net.InetSocketAddress
@ -28,6 +30,7 @@ import java.util.*
import java.util.concurrent.CompletableFuture
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.TimeUnit
import kotlin.coroutines.coroutineContext
class WebDashboardServer {
// I don't think i'll need more than 1k/day
@ -84,23 +87,24 @@ class WebDashboardServer {
if (!listeners.containsKey(id)) {
future.completeExceptionally(StacklessException("No browser listening"))
} else {
coroutineScope {
CoroutineScope(coroutineContext).apply {
launch(Dispatchers.IO) {
var info: IPResponse? = null
var ptr: String? = null
(address as? InetSocketAddress)?.let {
try {
val ipLookup = async(Dispatchers.IO) {
ipInfo.lookupIP(it.address?.hostAddress?.substringBefore("%"))
}
val reverseLookup = async(Dispatchers.IO) {
it.address?.hostName
ipInfo.lookupIP(it.address!!.hostAddress!!.substringBefore("%"))
}
val dnsQuery = dnsResolver.resolveAll(
DefaultDnsQuestion(reverseLookup(it.address), DnsRecordType.PTR)
)
info = ipLookup.await()
reverseLookup.await()
ptr = dnsQuery.suspendAwait().first { it is DnsPtrRecord }?.name()
} catch (ignored: Exception) {
}
}
val msg = "Requester: $id $address (${info?.org}, ${info?.city}, ${info?.region}, " +
val msg = "Requester: $id $address ($ptr) (${info?.org}, ${info?.city}, ${info?.region}, " +
"${info?.countryCode})\nBackend: $backAddress"
listeners[id]?.forEach {
it.ws.send(JsonObject().also {

View File

@ -106,7 +106,7 @@ class WebLogin : WebState {
}
"session_hash_response" -> {
val hash = obj.get("session_hash").asString
webClient.server.sessionHashCallbacks.getIfPresent(hash)?.complete(null)
webClient.server.sessionHashCallbacks.getIfPresent(hash)?.complete(Unit)
}
else -> throw StacklessException("invalid action!")
}