Fix issue #109 and add regression test.

This commit is contained in:
joo 2018-10-12 16:47:55 +01:00
parent eb302094aa
commit 48e1003f42
5 changed files with 70 additions and 12 deletions

View File

@ -352,6 +352,8 @@ class Connection(object):
self.socket = socket.socket(ai_faml, ai_type, ai_prot) self.socket = socket.socket(ai_faml, ai_type, ai_prot)
self.socket.connect(ai_addr) self.socket.connect(ai_addr)
self.file_object = self.socket.makefile("rb", 0) self.file_object = self.socket.makefile("rb", 0)
self.options.compression_enabled = False
self.options.compression_threshold = -1
self.connected = True self.connected = True
def disconnect(self, immediate=False): def disconnect(self, immediate=False):
@ -500,7 +502,7 @@ class NetworkingThread(threading.Thread):
# Ignore the earlier exception if a disconnect packet is # Ignore the earlier exception if a disconnect packet is
# received, as it may have been caused by trying to write to # received, as it may have been caused by trying to write to
# thw closed socket, which does not represent a program error. # the closed socket, which does not represent a program error.
if exc_info is not None and packet.packet_name == "disconnect": if exc_info is not None and packet.packet_name == "disconnect":
exc_info = None exc_info = None

View File

@ -119,6 +119,8 @@ class Packet(object):
@property @property
def fields(self): def fields(self):
""" An iterable of the names of the packet's fields, or None. """ """ An iterable of the names of the packet's fields, or None. """
if self.definition is None:
return None
return (field for defn in self.definition for field in defn) return (field for defn in self.definition for field in defn)
def field_string(self, field): def field_string(self, field):

View File

@ -326,14 +326,14 @@ class FakeServer(object):
""" """
__slots__ = 'listen_socket', 'compression_threshold', 'context', \ __slots__ = 'listen_socket', 'compression_threshold', 'context', \
'minecraft_version', 'client_handler_type', \ 'minecraft_version', 'client_handler_type', 'server_type', \
'packets_handshake', 'packets_login', 'packets_playing', \ 'packets_handshake', 'packets_login', 'packets_playing', \
'packets_status', 'lock', 'stopping', 'private_key', \ 'packets_status', 'lock', 'stopping', 'private_key', \
'public_key_bytes', 'public_key_bytes', 'test_case',
def __init__(self, minecraft_version=None, compression_threshold=None, def __init__(self, minecraft_version=None, compression_threshold=None,
client_handler_type=FakeClientHandler, private_key=None, client_handler_type=FakeClientHandler, private_key=None,
public_key_bytes=None): public_key_bytes=None, test_case=None):
if minecraft_version is None: if minecraft_version is None:
minecraft_version = VERSIONS[-1][0] minecraft_version = VERSIONS[-1][0]
@ -352,6 +352,7 @@ class FakeServer(object):
self.client_handler_type = client_handler_type self.client_handler_type = client_handler_type
self.private_key = private_key self.private_key = private_key
self.public_key_bytes = public_key_bytes self.public_key_bytes = public_key_bytes
self.test_case = test_case
self.packets_handshake = { self.packets_handshake = {
p.get_id(self.context): p for p in p.get_id(self.context): p for p in
@ -427,6 +428,9 @@ class _FakeServerTest(unittest.TestCase):
# The set of Minecraft version names or protocol version numbers that the # The set of Minecraft version names or protocol version numbers that the
# client will support. If None, the client supports all possible versions. # client will support. If None, the client supports all possible versions.
server_type = FakeServer
# A subclass of FakeServer to be used in tests.
client_handler_type = FakeClientHandler client_handler_type = FakeClientHandler
# A subclass of FakeClientHandler to be used in tests. # A subclass of FakeClientHandler to be used in tests.
@ -464,13 +468,16 @@ class _FakeServerTest(unittest.TestCase):
client.connect() client.connect()
def _test_connect(self, client_versions=None, server_version=None, def _test_connect(self, client_versions=None, server_version=None,
client_handler_type=None, connection_type=None, server_type=None, client_handler_type=None,
compression_threshold=None, private_key=None, connection_type=None, compression_threshold=None,
public_key_bytes=None, ignore_extra_exceptions=None): private_key=None, public_key_bytes=None,
ignore_extra_exceptions=None):
if client_versions is None: if client_versions is None:
client_versions = self.client_versions client_versions = self.client_versions
if server_version is None: if server_version is None:
server_version = self.server_version server_version = self.server_version
if server_type is None:
server_type = self.server_type
if client_handler_type is None: if client_handler_type is None:
client_handler_type = self.client_handler_type client_handler_type = self.client_handler_type
if connection_type is None: if connection_type is None:
@ -484,11 +491,12 @@ class _FakeServerTest(unittest.TestCase):
if ignore_extra_exceptions is None: if ignore_extra_exceptions is None:
ignore_extra_exceptions = self.ignore_extra_exceptions ignore_extra_exceptions = self.ignore_extra_exceptions
server = FakeServer(minecraft_version=server_version, server = server_type(minecraft_version=server_version,
compression_threshold=compression_threshold, compression_threshold=compression_threshold,
client_handler_type=client_handler_type, client_handler_type=client_handler_type,
private_key=private_key, private_key=private_key,
public_key_bytes=public_key_bytes) public_key_bytes=public_key_bytes,
test_case=self)
addr = "localhost" addr = "localhost"
port = server.listen_socket.getsockname()[1] port = server.listen_socket.getsockname()[1]

View File

@ -32,6 +32,46 @@ class ConnectTest(fake_server._FakeServerTest):
raise fake_server.FakeServerDisconnect raise fake_server.FakeServerDisconnect
class ReconnectTest(ConnectTest):
phase = 0
def _start_client(self, client):
def handle_login_disconnect(packet):
if 'Please reconnect' in packet.json_data:
# Override the default behaviour of raising a fatal exception.
client.disconnect()
client.connect()
raise IgnorePacket
client.register_packet_listener(
handle_login_disconnect, clientbound.login.DisconnectPacket,
early=True)
def handle_play_disconnect(packet):
if 'Please reconnect' in packet.json_data:
client.connect()
elif 'Test successful' in packet.json_data:
raise fake_server.FakeServerTestSuccess
client.register_packet_listener(
handle_play_disconnect, clientbound.play.DisconnectPacket)
client.connect()
class client_handler_type(fake_server.FakeClientHandler):
def handle_login(self, packet):
if self.server.test_case.phase == 0:
self.server.test_case.phase = 1
raise fake_server.FakeServerDisconnect('Please reconnect (0).')
super(ReconnectTest.client_handler_type, self).handle_login(packet)
def handle_play_start(self):
if self.server.test_case.phase == 1:
self.server.test_case.phase = 2
raise fake_server.FakeServerDisconnect('Please reconnect (1).')
else:
assert self.server.test_case.phase == 2
raise fake_server.FakeServerDisconnect('Test successful (2).')
class PingTest(ConnectTest): class PingTest(ConnectTest):
def _start_client(self, client): def _start_client(self, client):
def handle_ping(latency_ms): def handle_ping(latency_ms):

View File

@ -131,6 +131,12 @@ class EncryptedCompressedConnection(EncryptedConnection,
pass pass
# Regression test for <https://github.com/ammaraskar/pyCraft/issues/109>.
class EncryptedCompressedReconnect(test_connection.ReconnectTest,
EncryptedCompressedConnection):
pass
class MockSocket(object): class MockSocket(object):
def __init__(self, encryptor, decryptor): def __init__(self, encryptor, decryptor):