Add compression tests to test_connection.
This commit is contained in:
parent
cab8d56746
commit
cf464d2da2
|
@ -13,14 +13,18 @@ import logging
|
|||
import socket
|
||||
import json
|
||||
import sys
|
||||
import zlib
|
||||
|
||||
VERSIONS = sorted(SUPPORTED_MINECRAFT_VERSIONS.items(), key=lambda i: i[1])
|
||||
THREAD_TIMEOUT_S = 5
|
||||
|
||||
|
||||
class _ConnectTest(unittest.TestCase):
|
||||
compression_threshold = None
|
||||
|
||||
def _test_connect(self, client_version=None, server_version=None):
|
||||
server = FakeServer(minecraft_version=server_version)
|
||||
server = FakeServer(minecraft_version=server_version,
|
||||
compression_threshold=self.compression_threshold)
|
||||
addr = "localhost"
|
||||
port = server.listen_socket.getsockname()[1]
|
||||
|
||||
|
@ -114,6 +118,14 @@ class ConnectNewToNewTest(_ConnectTest):
|
|||
self._test_connect(VERSIONS[-1][1], VERSIONS[-1][0])
|
||||
|
||||
|
||||
class ConnectCompressionLowTest(ConnectNewToNewTest):
|
||||
compression_threshold = 0
|
||||
|
||||
|
||||
class ConnectCompressionHighTest(ConnectNewToNewTest):
|
||||
compression_threshold = 256
|
||||
|
||||
|
||||
class PingTest(_ConnectTest):
|
||||
def runTest(self):
|
||||
self._test_connect()
|
||||
|
@ -138,17 +150,23 @@ class PingTest(_ConnectTest):
|
|||
|
||||
class FakeServer(threading.Thread):
|
||||
__slots__ = 'context', 'minecraft_version', 'listen_socket', \
|
||||
'compression_threshold', 'compression_enabled', \
|
||||
'packets_login', 'packets_playing', 'packets_status', \
|
||||
'packets',
|
||||
'packets'
|
||||
|
||||
def __init__(self, minecraft_version=None, continue_after_status=True):
|
||||
def __init__(self, minecraft_version=None, continue_after_status=True,
|
||||
compression_threshold=None):
|
||||
if minecraft_version is None:
|
||||
minecraft_version = VERSIONS[-1][0]
|
||||
|
||||
self.minecraft_version = minecraft_version
|
||||
self.continue_after_status = continue_after_status
|
||||
self.compression_threshold = compression_threshold
|
||||
|
||||
protocol_version = SUPPORTED_MINECRAFT_VERSIONS[minecraft_version]
|
||||
self.context = connection.ConnectionContext(
|
||||
protocol_version=protocol_version)
|
||||
self.compression_enabled = False
|
||||
|
||||
self.packets_handshake = {
|
||||
p.get_id(self.context): p for p in
|
||||
|
@ -232,9 +250,16 @@ class FakeServer(threading.Thread):
|
|||
packet = self.read_packet_filtered(client_file)
|
||||
assert isinstance(packet, packets.LoginStartPacket)
|
||||
|
||||
if self.compression_threshold is not None:
|
||||
self.write_packet(packets.SetCompressionPacket(
|
||||
self.context, threshold=self.compression_threshold),
|
||||
client_socket)
|
||||
self.compression_enabled = True
|
||||
|
||||
packet = packets.LoginSuccessPacket(
|
||||
self.context, UUID='{fake uuid}', Username=packet.name)
|
||||
self.write_packet(packet, client_socket)
|
||||
|
||||
self.run_playing(client_socket, client_file)
|
||||
|
||||
def run_playing(self, client_socket, client_file):
|
||||
|
@ -315,8 +340,19 @@ class FakeServer(threading.Thread):
|
|||
while len(buffer.get_writable()) < length:
|
||||
buffer.send(client_file.read(length - len(buffer.get_writable())))
|
||||
buffer.reset_cursor()
|
||||
if self.compression_enabled:
|
||||
data_length = types.VarInt.read(buffer)
|
||||
if data_length > 0:
|
||||
data = zlib.decompress(buffer.read())
|
||||
assert len(data) == data_length, \
|
||||
'%s != %s' % (len(data), data_length)
|
||||
buffer.reset()
|
||||
buffer.send(data)
|
||||
buffer.reset_cursor()
|
||||
return buffer
|
||||
|
||||
def write_packet(self, packet, client_socket):
|
||||
packet.write(client_socket)
|
||||
kwds = {'compression_threshold': self.compression_threshold} \
|
||||
if self.compression_enabled else {}
|
||||
logging.debug('[S-> ] %s' % packet)
|
||||
packet.write(client_socket, **kwds)
|
||||
|
|
Loading…
Reference in New Issue