From f74b6bb7f88b7547150e0ed84ee3277ec316ea58 Mon Sep 17 00:00:00 2001 From: creeper123123321 <7974274+creeper123123321@users.noreply.github.com> Date: Fri, 19 Feb 2021 20:56:44 -0300 Subject: [PATCH] implement some ratelimits --- .../viaaas/config/VIAaaSConfig.kt | 3 +++ .../viaaas/handler/state/HandshakeState.kt | 17 +++++++++++++ .../creeper123123321/viaaas/web/WebClient.kt | 23 +++++++++++++++++- .../viaaas/web/WebDashboardServer.kt | 5 ++++ .../creeper123123321/viaaas/web/WebLogin.kt | 24 ++++++++++--------- src/main/resources/viaaas.yml | 8 ++++++- 6 files changed, 67 insertions(+), 13 deletions(-) diff --git a/src/main/kotlin/com/github/creeper123123321/viaaas/config/VIAaaSConfig.kt b/src/main/kotlin/com/github/creeper123123321/viaaas/config/VIAaaSConfig.kt index 7ed4e83..fcb145b 100644 --- a/src/main/kotlin/com/github/creeper123123321/viaaas/config/VIAaaSConfig.kt +++ b/src/main/kotlin/com/github/creeper123123321/viaaas/config/VIAaaSConfig.kt @@ -37,4 +37,7 @@ object VIAaaSConfig : Config(File("config/viaaas.yml")) { val forceOnlineMode: Boolean get() = this.getBoolean("force-online-mode", false) val showVersionPing: Boolean get() = this.getBoolean("show-version-ping", true) val showBrandInfo: Boolean get() = this.getBoolean("show-brand-info", true) + val rateLimitWs: Double get() = this.getDouble("rate-limit-ws", 1.0) + val rateLimitConnectionMc: Double get() = this.getDouble("rate-limit-connection-mc", 10.0) + val listeningWsLimit: Int get() = this.getInt("listening-ws-limit", 16) } \ No newline at end of file diff --git a/src/main/kotlin/com/github/creeper123123321/viaaas/handler/state/HandshakeState.kt b/src/main/kotlin/com/github/creeper123123321/viaaas/handler/state/HandshakeState.kt index 567dea2..8b75332 100644 --- a/src/main/kotlin/com/github/creeper123123321/viaaas/handler/state/HandshakeState.kt +++ b/src/main/kotlin/com/github/creeper123123321/viaaas/handler/state/HandshakeState.kt @@ -6,10 +6,23 @@ import com.github.creeper123123321.viaaas.handler.MinecraftHandler import com.github.creeper123123321.viaaas.mcLogger import com.github.creeper123123321.viaaas.packet.Packet import com.github.creeper123123321.viaaas.packet.handshake.Handshake +import com.google.common.cache.CacheBuilder +import com.google.common.cache.CacheLoader +import com.google.common.util.concurrent.RateLimiter import io.netty.channel.ChannelHandlerContext import us.myles.ViaVersion.packets.State +import java.net.InetAddress +import java.net.InetSocketAddress +import java.util.concurrent.TimeUnit class HandshakeState : MinecraftConnectionState { + object RateLimit { + val rateLimitByIp = CacheBuilder.newBuilder() + .expireAfterAccess(1, TimeUnit.MINUTES) + .build(CacheLoader.from { + RateLimiter.create(VIAaaSConfig.rateLimitConnectionMc) + }) + } override val state: State get() = State.HANDSHAKE @@ -23,6 +36,10 @@ class HandshakeState : MinecraftConnectionState { else -> throw IllegalStateException("Invalid next state") } + if (!RateLimit.rateLimitByIp.get((handler.remoteAddress as InetSocketAddress).address).tryAcquire()) { + throw IllegalStateException("Rate-limited") + } + val parsed = VIAaaSAddress().parse(packet.address.substringBefore(0.toChar()), VIAaaSConfig.hostName) val backProto = parsed.protocol val hadHostname = parsed.viaSuffix != null diff --git a/src/main/kotlin/com/github/creeper123123321/viaaas/web/WebClient.kt b/src/main/kotlin/com/github/creeper123123321/viaaas/web/WebClient.kt index 6e157d6..ef3bcbf 100644 --- a/src/main/kotlin/com/github/creeper123123321/viaaas/web/WebClient.kt +++ b/src/main/kotlin/com/github/creeper123123321/viaaas/web/WebClient.kt @@ -1,11 +1,32 @@ package com.github.creeper123123321.viaaas.web +import com.github.creeper123123321.viaaas.config.VIAaaSConfig +import com.google.common.util.concurrent.RateLimiter import io.ktor.websocket.* import java.util.* +import java.util.concurrent.ConcurrentHashMap data class WebClient( val server: WebDashboardServer, val ws: WebSocketServerSession, val state: WebState, +) { val listenedIds: MutableSet = mutableSetOf() -) \ No newline at end of file + val rateLimiter = RateLimiter.create(VIAaaSConfig.rateLimitWs) + + fun listenId(uuid: UUID): Boolean { + if (listenedIds.size >= VIAaaSConfig.listeningWsLimit) return false // This is getting insane + server.listeners.computeIfAbsent(uuid) { Collections.newSetFromMap(ConcurrentHashMap()) } + .add(this) + listenedIds.add(uuid) + return true + } + + fun unlistenId(uuid: UUID) { + server.listeners[uuid]?.remove(this) + if (server.listeners[uuid]?.isEmpty() == true) { + server.listeners.remove(uuid) + } + listenedIds.remove(uuid) + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/creeper123123321/viaaas/web/WebDashboardServer.kt b/src/main/kotlin/com/github/creeper123123321/viaaas/web/WebDashboardServer.kt index 71ed6bf..c5b5f97 100644 --- a/src/main/kotlin/com/github/creeper123123321/viaaas/web/WebDashboardServer.kt +++ b/src/main/kotlin/com/github/creeper123123321/viaaas/web/WebDashboardServer.kt @@ -27,6 +27,10 @@ class WebDashboardServer { .expireAfterAccess(10, TimeUnit.DAYS) .build() + fun generateToken(account: UUID): UUID { + return UUID.randomUUID().also { loginTokens.put(it, account) } + } + // Minecraft account -> WebClient val listeners = ConcurrentHashMap>() val usernameIdCache = CacheBuilder.newBuilder() @@ -81,6 +85,7 @@ class WebDashboardServer { suspend fun onMessage(ws: WebSocketSession, msg: String) { val client = clients[ws]!! + client.rateLimiter.acquire() client.state.onMessage(client, msg) } diff --git a/src/main/kotlin/com/github/creeper123123321/viaaas/web/WebLogin.kt b/src/main/kotlin/com/github/creeper123123321/viaaas/web/WebLogin.kt index 8a0d881..d2cea69 100644 --- a/src/main/kotlin/com/github/creeper123123321/viaaas/web/WebLogin.kt +++ b/src/main/kotlin/com/github/creeper123123321/viaaas/web/WebLogin.kt @@ -11,7 +11,6 @@ import io.ktor.http.* import io.ktor.http.cio.websocket.* import java.net.URLEncoder import java.util.* -import java.util.concurrent.ConcurrentHashMap class WebLogin : WebState { override suspend fun start(webClient: WebClient) { @@ -26,10 +25,9 @@ class WebLogin : WebState { "offline_login" -> { // todo add some spam check val username = obj.get("username").asString - val token = UUID.randomUUID() val uuid = generateOfflinePlayerUuid(username) - webClient.server.loginTokens.put(token, uuid) + val token = webClient.server.generateToken(uuid) webClient.ws.send( """{"action": "login_result", "success": true, | "username": "$username", "uuid": "$uuid", "token": "$token"}""".trimMargin() @@ -47,11 +45,10 @@ class WebLogin : WebState { ) if (check.getAsJsonPrimitive("valid").asBoolean) { - val token = UUID.randomUUID() val mcIdUser = check.get("username").asString val uuid = webClient.server.usernameIdCache.get(mcIdUser) - webClient.server.loginTokens.put(token, uuid) + val token = webClient.server.generateToken(uuid) webClient.ws.send( """{"action": "login_result", "success": true, | "username": "$mcIdUser", "uuid": "$uuid", "token": "$token"}""".trimMargin() @@ -66,18 +63,23 @@ class WebLogin : WebState { "listen_login_requests" -> { val token = UUID.fromString(obj.getAsJsonPrimitive("token").asString) val user = webClient.server.loginTokens.getIfPresent(token) - if (user != null) { + if (user != null && webClient.listenId(user)) { webClient.ws.send("""{"action": "listen_login_requests_result", "token": "$token", "success": true, "user": "$user"}""") - webClient.listenedIds.add(user) - webClient.server.listeners.computeIfAbsent(user) { Collections.newSetFromMap(ConcurrentHashMap()) } - .add(webClient) - webLogger.info("${webClient.ws.call.request.local.remoteHost} (O: ${webClient.ws.call.request.origin.remoteHost}) listening for logins for $user") } else { + webClient.server.loginTokens.invalidate(token) webClient.ws.send("""{"action": "listen_login_requests_result", "token": "$token", "success": false}""") webLogger.info("${webClient.ws.call.request.local.remoteHost} (O: ${webClient.ws.call.request.origin.remoteHost}) failed token") } } + "unlisten_login_requests" -> { + val uuid = UUID.fromString(obj.getAsJsonPrimitive("uuid").asString) + webClient.unlistenId(uuid) + } + "invalidate_token" -> { + val token = UUID.fromString(obj.getAsJsonPrimitive("token").asString) + webClient.server.loginTokens.invalidate(token) + } "session_hash_response" -> { val hash = obj.get("session_hash").asString webClient.server.pendingSessionHashes.getIfPresent(hash)?.complete(null) @@ -89,7 +91,7 @@ class WebLogin : WebState { } override suspend fun disconnected(webClient: WebClient) { - webClient.listenedIds.forEach { webClient.server.listeners[it]?.remove(webClient) } + webClient.listenedIds.forEach { webClient.unlistenId(it) } } override suspend fun onException(webClient: WebClient, exception: java.lang.Exception) { diff --git a/src/main/resources/viaaas.yml b/src/main/resources/viaaas.yml index 3eaca87..db62336 100644 --- a/src/main/resources/viaaas.yml +++ b/src/main/resources/viaaas.yml @@ -31,4 +31,10 @@ force-online-mode: false # Shows player and server version in player list show-version-ping: true # Shows info in server brand (F3) -show-brand-info: true \ No newline at end of file +show-brand-info: true +# Rates limits websocket messages per second. Messages will be waiting for process +rate-limit-ws: 1.5 +# Rate limits new front-end connections per second per ip. Will disconnect when hit +rate-limit-connection-mc: 10.0 +# Limits how many usernames a websocket connection can listen to. +listening-ws-limit: 10 \ No newline at end of file