implement some ratelimits

This commit is contained in:
creeper123123321 2021-02-19 20:56:44 -03:00
parent b011a26d63
commit f74b6bb7f8
6 changed files with 67 additions and 13 deletions

View File

@ -37,4 +37,7 @@ object VIAaaSConfig : Config(File("config/viaaas.yml")) {
val forceOnlineMode: Boolean get() = this.getBoolean("force-online-mode", false)
val showVersionPing: Boolean get() = this.getBoolean("show-version-ping", true)
val showBrandInfo: Boolean get() = this.getBoolean("show-brand-info", true)
val rateLimitWs: Double get() = this.getDouble("rate-limit-ws", 1.0)
val rateLimitConnectionMc: Double get() = this.getDouble("rate-limit-connection-mc", 10.0)
val listeningWsLimit: Int get() = this.getInt("listening-ws-limit", 16)
}

View File

@ -6,10 +6,23 @@ import com.github.creeper123123321.viaaas.handler.MinecraftHandler
import com.github.creeper123123321.viaaas.mcLogger
import com.github.creeper123123321.viaaas.packet.Packet
import com.github.creeper123123321.viaaas.packet.handshake.Handshake
import com.google.common.cache.CacheBuilder
import com.google.common.cache.CacheLoader
import com.google.common.util.concurrent.RateLimiter
import io.netty.channel.ChannelHandlerContext
import us.myles.ViaVersion.packets.State
import java.net.InetAddress
import java.net.InetSocketAddress
import java.util.concurrent.TimeUnit
class HandshakeState : MinecraftConnectionState {
object RateLimit {
val rateLimitByIp = CacheBuilder.newBuilder()
.expireAfterAccess(1, TimeUnit.MINUTES)
.build(CacheLoader.from<InetAddress, RateLimiter> {
RateLimiter.create(VIAaaSConfig.rateLimitConnectionMc)
})
}
override val state: State
get() = State.HANDSHAKE
@ -23,6 +36,10 @@ class HandshakeState : MinecraftConnectionState {
else -> throw IllegalStateException("Invalid next state")
}
if (!RateLimit.rateLimitByIp.get((handler.remoteAddress as InetSocketAddress).address).tryAcquire()) {
throw IllegalStateException("Rate-limited")
}
val parsed = VIAaaSAddress().parse(packet.address.substringBefore(0.toChar()), VIAaaSConfig.hostName)
val backProto = parsed.protocol
val hadHostname = parsed.viaSuffix != null

View File

@ -1,11 +1,32 @@
package com.github.creeper123123321.viaaas.web
import com.github.creeper123123321.viaaas.config.VIAaaSConfig
import com.google.common.util.concurrent.RateLimiter
import io.ktor.websocket.*
import java.util.*
import java.util.concurrent.ConcurrentHashMap
data class WebClient(
val server: WebDashboardServer,
val ws: WebSocketServerSession,
val state: WebState,
) {
val listenedIds: MutableSet<UUID> = mutableSetOf()
)
val rateLimiter = RateLimiter.create(VIAaaSConfig.rateLimitWs)
fun listenId(uuid: UUID): Boolean {
if (listenedIds.size >= VIAaaSConfig.listeningWsLimit) return false // This is getting insane
server.listeners.computeIfAbsent(uuid) { Collections.newSetFromMap(ConcurrentHashMap()) }
.add(this)
listenedIds.add(uuid)
return true
}
fun unlistenId(uuid: UUID) {
server.listeners[uuid]?.remove(this)
if (server.listeners[uuid]?.isEmpty() == true) {
server.listeners.remove(uuid)
}
listenedIds.remove(uuid)
}
}

View File

@ -27,6 +27,10 @@ class WebDashboardServer {
.expireAfterAccess(10, TimeUnit.DAYS)
.build<UUID, UUID>()
fun generateToken(account: UUID): UUID {
return UUID.randomUUID().also { loginTokens.put(it, account) }
}
// Minecraft account -> WebClient
val listeners = ConcurrentHashMap<UUID, MutableSet<WebClient>>()
val usernameIdCache = CacheBuilder.newBuilder()
@ -81,6 +85,7 @@ class WebDashboardServer {
suspend fun onMessage(ws: WebSocketSession, msg: String) {
val client = clients[ws]!!
client.rateLimiter.acquire()
client.state.onMessage(client, msg)
}

View File

@ -11,7 +11,6 @@ import io.ktor.http.*
import io.ktor.http.cio.websocket.*
import java.net.URLEncoder
import java.util.*
import java.util.concurrent.ConcurrentHashMap
class WebLogin : WebState {
override suspend fun start(webClient: WebClient) {
@ -26,10 +25,9 @@ class WebLogin : WebState {
"offline_login" -> {
// todo add some spam check
val username = obj.get("username").asString
val token = UUID.randomUUID()
val uuid = generateOfflinePlayerUuid(username)
webClient.server.loginTokens.put(token, uuid)
val token = webClient.server.generateToken(uuid)
webClient.ws.send(
"""{"action": "login_result", "success": true,
| "username": "$username", "uuid": "$uuid", "token": "$token"}""".trimMargin()
@ -47,11 +45,10 @@ class WebLogin : WebState {
)
if (check.getAsJsonPrimitive("valid").asBoolean) {
val token = UUID.randomUUID()
val mcIdUser = check.get("username").asString
val uuid = webClient.server.usernameIdCache.get(mcIdUser)
webClient.server.loginTokens.put(token, uuid)
val token = webClient.server.generateToken(uuid)
webClient.ws.send(
"""{"action": "login_result", "success": true,
| "username": "$mcIdUser", "uuid": "$uuid", "token": "$token"}""".trimMargin()
@ -66,18 +63,23 @@ class WebLogin : WebState {
"listen_login_requests" -> {
val token = UUID.fromString(obj.getAsJsonPrimitive("token").asString)
val user = webClient.server.loginTokens.getIfPresent(token)
if (user != null) {
if (user != null && webClient.listenId(user)) {
webClient.ws.send("""{"action": "listen_login_requests_result", "token": "$token", "success": true, "user": "$user"}""")
webClient.listenedIds.add(user)
webClient.server.listeners.computeIfAbsent(user) { Collections.newSetFromMap(ConcurrentHashMap()) }
.add(webClient)
webLogger.info("${webClient.ws.call.request.local.remoteHost} (O: ${webClient.ws.call.request.origin.remoteHost}) listening for logins for $user")
} else {
webClient.server.loginTokens.invalidate(token)
webClient.ws.send("""{"action": "listen_login_requests_result", "token": "$token", "success": false}""")
webLogger.info("${webClient.ws.call.request.local.remoteHost} (O: ${webClient.ws.call.request.origin.remoteHost}) failed token")
}
}
"unlisten_login_requests" -> {
val uuid = UUID.fromString(obj.getAsJsonPrimitive("uuid").asString)
webClient.unlistenId(uuid)
}
"invalidate_token" -> {
val token = UUID.fromString(obj.getAsJsonPrimitive("token").asString)
webClient.server.loginTokens.invalidate(token)
}
"session_hash_response" -> {
val hash = obj.get("session_hash").asString
webClient.server.pendingSessionHashes.getIfPresent(hash)?.complete(null)
@ -89,7 +91,7 @@ class WebLogin : WebState {
}
override suspend fun disconnected(webClient: WebClient) {
webClient.listenedIds.forEach { webClient.server.listeners[it]?.remove(webClient) }
webClient.listenedIds.forEach { webClient.unlistenId(it) }
}
override suspend fun onException(webClient: WebClient, exception: java.lang.Exception) {

View File

@ -31,4 +31,10 @@ force-online-mode: false
# Shows player and server version in player list
show-version-ping: true
# Shows info in server brand (F3)
show-brand-info: true
show-brand-info: true
# Rates limits websocket messages per second. Messages will be waiting for process
rate-limit-ws: 1.5
# Rate limits new front-end connections per second per ip. Will disconnect when hit
rate-limit-connection-mc: 10.0
# Limits how many usernames a websocket connection can listen to.
listening-ws-limit: 10