Implement full Server List Ping capability with test.

This commit is contained in:
joo 2016-11-20 05:03:23 +00:00
parent 56b06ca80f
commit 115693f8c3
3 changed files with 173 additions and 32 deletions

View File

@ -6,6 +6,7 @@ from zlib import decompress
import threading import threading
import socket import socket
import time import time
import timeit
import select import select
import sys import sys
import json import json
@ -47,7 +48,7 @@ class Connection(object):
def __init__( def __init__(
self, self,
address, address,
port, port=25565,
auth_token=None, auth_token=None,
username=None, username=None,
initial_version=None, initial_version=None,
@ -114,9 +115,12 @@ class Connection(object):
def _start_network_thread(self): def _start_network_thread(self):
"""May safely be called multiple times.""" """May safely be called multiple times."""
if self.networking_thread is None: if self.networking_thread is not None:
self.networking_thread = NetworkingThread(self) if not self.networking_thread.interrupt:
self.networking_thread.start() return
self.networking_thread.join()
self.networking_thread = NetworkingThread(self)
self.networking_thread.start()
def write_packet(self, packet, force=False): def write_packet(self, packet, force=False):
"""Writes a packet to the server. """Writes a packet to the server.
@ -170,11 +174,33 @@ class Connection(object):
packet.write(self.socket) packet.write(self.socket)
return True return True
def status(self): def status(self, handle_status=None, handle_ping=False):
"""Issue a status request to the server and then disconnect.
:param handle_status: a function to be called with the status
dictionary None for the default behaviour of
printing the dictionary to standard output, or
False to ignore the result.
:param handle_ping: a function to be called with the measured latency
in milliseconds, None for the default handler,
which prints the latency to standard outout, or
False, to prevent measurement of the latency.
"""
self._connect() self._connect()
self._handshake(1) self._handshake(1)
self._start_network_thread() self._start_network_thread()
self.reactor = StatusReactor(self)
self.reactor = StatusReactor(self, do_ping=handle_ping is not False)
if handle_status is False:
self.reactor.handle_status = lambda *args, **kwds: None
elif handle_status is not None:
self.reactor.handle_status = handle_status
if handle_ping is False:
self.reactor.handle_ping = lambda *args, **kwds: None
elif handle_ping is not None:
self.reactor.handle_ping = handle_ping
request_packet = packets.RequestPacket() request_packet = packets.RequestPacket()
self.write_packet(request_packet) self.write_packet(request_packet)
@ -190,8 +216,6 @@ class Connection(object):
# automatic version negotiation. # automatic version negotiation.
self.spawned = False self.spawned = False
self._outgoing_packet_queue = deque()
self._connect() self._connect()
self._handshake() self._handshake()
login_start_packet = packets.LoginStartPacket() login_start_packet = packets.LoginStartPacket()
@ -211,10 +235,33 @@ class Connection(object):
# since it's "guaranteed" to read the number of bytes specified, # since it's "guaranteed" to read the number of bytes specified,
# the socket itself will mostly be used to write data upstream to # the socket itself will mostly be used to write data upstream to
# the server. # the server.
self._outgoing_packet_queue = deque()
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.connect((self.options.address, self.options.port)) self.socket.connect((self.options.address, self.options.port))
self.file_object = self.socket.makefile("rb", 0) self.file_object = self.socket.makefile("rb", 0)
def disconnect(self):
""" Terminate the existing server connection, if there is one. """
if self.networking_thread is None:
self._disconnect()
else:
# The networking thread will call _disconnect() later.
self.networking_thread.interrupt = True
def _disconnect(self):
if self.socket is not None:
if hasattr(self.socket, 'actual_socket'):
# pylint: disable=no-member
actual_socket = self.socket.actual_socket
else:
actual_socket = self.socket
try:
actual_socket.shutdown(socket.SHUT_RDWR)
finally:
actual_socket.close()
self.socket = None
def _handshake(self, next_state=2): def _handshake(self, next_state=2):
handshake = packets.HandShakePacket() handshake = packets.HandShakePacket()
handshake.protocol_version = self.context.protocol_version handshake.protocol_version = self.context.protocol_version
@ -236,14 +283,16 @@ class NetworkingThread(threading.Thread):
def run(self): def run(self):
try: try:
self._run() self._run()
except: except Exception as e:
ty, ex, tb = sys.exc_info() e.exc_info = sys.exc_info()
ex.exc_info = ty, ex, tb self.connection.exception = e
self.connection.exception = ex finally:
self.connection.networking_thread = None
def _run(self): def _run(self):
while True: while True:
if self.interrupt: if self.interrupt:
self.connection._disconnect()
break break
# Attempt to write out as many as 300 packets as possible every # Attempt to write out as many as 300 packets as possible every
@ -285,6 +334,10 @@ class NetworkingThread(threading.Thread):
if num_packets >= 50: if num_packets >= 50:
break break
if self.interrupt:
self.connection._disconnect()
break
packet = self.connection.reactor.read_packet( packet = self.connection.reactor.read_packet(
self.connection.file_object) self.connection.file_object)
@ -473,22 +526,36 @@ class PlayingReactor(PacketReactor):
''' '''
self.connection.spawned = True self.connection.spawned = True
'''
if packet.packet_name == "disconnect": if packet.packet_name == "disconnect":
print(packet.json_data) # TODO: handle propagating this back self.connection.disconnect()
'''
class StatusReactor(PacketReactor): class StatusReactor(PacketReactor):
get_clientbound_packets = staticmethod(packets.state_status_clientbound) get_clientbound_packets = staticmethod(packets.state_status_clientbound)
def __init__(self, connection, do_ping=False):
super(StatusReactor, self).__init__(connection)
self.do_ping = do_ping
def react(self, packet): def react(self, packet):
if packet.packet_name == "response": if packet.packet_name == "response":
print(json.loads(packet.json_response)) if self.do_ping:
ping_packet = packets.PingPacket()
# NOTE: it may be better to depend on the `monotonic' package
# or something similar for more accurate time measurement.
ping_packet.time = int(1000 * timeit.default_timer())
self.connection.write_packet(ping_packet)
else:
self.connection.disconnect()
self.handle_status(json.loads(packet.json_response))
ping_packet = packets.PingPacket() elif packet.packet_name == "ping" and self.do_ping:
ping_packet.time = int(time.time()) now = int(1000 * timeit.default_timer())
self.connection.write_packet(ping_packet) self.connection.disconnect()
self.handle_ping(now - packet.time)
self.connection.networking_thread.interrupt = True def handle_status(self, status_dict):
# TODO: More shutdown? idk print(status_dict)
def handle_ping(self, latency_ms):
print('Ping: %d ms' % latency_ms)

View File

@ -86,7 +86,7 @@ class VarInt(Type):
for i in range(5): for i in range(5):
byte = file_object.read(1) byte = file_object.read(1)
if len(byte) < 1: if len(byte) < 1:
raise RuntimeError("Unexpected end of message.") raise EOFError("Unexpected end of message.")
byte = ord(byte) byte = ord(byte)
number |= (byte & 0x7F) << 7 * i number |= (byte & 0x7F) << 7 * i
if not byte & 0x80: if not byte & 0x80:

View File

@ -19,24 +19,27 @@ THREAD_TIMEOUT_S = 5
class _ConnectTest(unittest.TestCase): class _ConnectTest(unittest.TestCase):
def _test_connect(self, client_version, server_version): def _test_connect(self, client_version=None, server_version=None):
server = FakeServer(minecraft_version=server_version) server = FakeServer(minecraft_version=server_version)
addr, port = server.listen_socket.getsockname() addr, port = server.listen_socket.getsockname()
client = connection.Connection( client = connection.Connection(
addr, port, username='User', initial_version=client_version) addr, port, username='User', initial_version=client_version)
client.register_packet_listener(
lambda packet: logging.debug('[ ->C] %s' % packet), packets.Packet)
cond = threading.Condition() cond = threading.Condition()
try: try:
with cond: with cond:
server_thread = threading.Thread( server_thread = threading.Thread(
name='test_connection server', name='_ConnectTest server',
target=self._test_connect_server, target=self._test_connect_server,
args=(server, cond)) args=(server, cond))
server_thread.daemon = True server_thread.daemon = True
server_thread.start() server_thread.start()
client_thread = threading.Thread( client_thread = threading.Thread(
name='test_connection client', name='_ConnectTest client',
target=self._test_connect_client, target=self._test_connect_client,
args=(client, cond)) args=(client, cond))
client_thread.daemon = True client_thread.daemon = True
@ -58,10 +61,6 @@ class _ConnectTest(unittest.TestCase):
break break
def _test_connect_client(self, client, cond): def _test_connect_client(self, client, cond):
def handle_packet(packet):
logging.debug('[ ->C] %s' % packet)
client.register_packet_listener(handle_packet, packets.Packet)
client.connect() client.connect()
client.networking_thread.join() client.networking_thread.join()
if getattr(client, 'exception', None) is not None: if getattr(client, 'exception', None) is not None:
@ -100,13 +99,41 @@ class ConnectNewToNewTest(_ConnectTest):
self._test_connect(VERSIONS[-1][1], VERSIONS[-1][0]) self._test_connect(VERSIONS[-1][1], VERSIONS[-1][0])
class PingTest(_ConnectTest):
def runTest(self):
self._test_connect()
def _test_connect_client(self, client, cond):
def handle_ping(latency_ms):
assert 0 <= latency_ms < 60000
with cond:
cond.exc_info = None
cond.notify_all()
client.status(handle_status=False, handle_ping=handle_ping)
client.networking_thread.join()
if getattr(client, 'exception', None) is not None:
with cond:
cond.exc_info = client.exception.exc_info
cond.notify_all()
def _test_connect_server(self, server, cond):
try:
server.run()
except:
with cond:
cond.exc_info = sys.exc_info()
cond.notify_all()
class FakeServer(threading.Thread): class FakeServer(threading.Thread):
__slots__ = 'context', 'minecraft_version', 'listen_socket', \ __slots__ = 'context', 'minecraft_version', 'listen_socket', \
'packets_login', 'packets_playing', 'packets', 'packets_login', 'packets_playing', 'packets_status', \
'packets',
def __init__(self, minecraft_version=None): def __init__(self, minecraft_version=None):
if minecraft_version is None: if minecraft_version is None:
minecraft_version = SUPPORTED_MINECRAFT_VERSIONS.keys()[-1] minecraft_version = VERSIONS[-1][0]
self.minecraft_version = minecraft_version self.minecraft_version = minecraft_version
protocol_version = SUPPORTED_MINECRAFT_VERSIONS[minecraft_version] protocol_version = SUPPORTED_MINECRAFT_VERSIONS[minecraft_version]
self.context = connection.ConnectionContext( self.context = connection.ConnectionContext(
@ -115,13 +142,19 @@ class FakeServer(threading.Thread):
self.packets_handshake = { self.packets_handshake = {
p.get_id(self.context): p for p in p.get_id(self.context): p for p in
packets.state_handshake_serverbound(self.context)} packets.state_handshake_serverbound(self.context)}
self.packets_login = { self.packets_login = {
p.get_id(self.context): p for p in p.get_id(self.context): p for p in
packets.state_login_serverbound(self.context)} packets.state_login_serverbound(self.context)}
self.packets_playing = { self.packets_playing = {
p.get_id(self.context): p for p in p.get_id(self.context): p for p in
packets.state_playing_serverbound(self.context)} packets.state_playing_serverbound(self.context)}
self.packets_status = {
p.get_id(self.context): p for p in
packets.state_status_serverbound(self.context)}
self.listen_socket = socket.socket() self.listen_socket = socket.socket()
self.listen_socket.bind(('0.0.0.0', 0)) self.listen_socket.bind(('0.0.0.0', 0))
self.listen_socket.listen(0) self.listen_socket.listen(0)
@ -150,9 +183,20 @@ class FakeServer(threading.Thread):
self.packets = self.packets_handshake self.packets = self.packets_handshake
packet = self.read_packet_filtered(client_file) packet = self.read_packet_filtered(client_file)
assert isinstance(packet, packets.HandShakePacket) assert isinstance(packet, packets.HandShakePacket)
if packet.next_state == 1:
return self.run_handshake_status(
packet, client_socket, client_file)
elif packet.next_state == 2:
return self.run_handshake_play(
packet, client_socket, client_file)
else:
raise AssertionError('Unknown state: %s' % packet.next_state)
def run_handshake_status(self, packet, client_socket, client_file):
self.run_status(client_socket, client_file)
def run_handshake_play(self, packet, client_socket, client_file):
if packet.protocol_version == self.context.protocol_version: if packet.protocol_version == self.context.protocol_version:
assert packet.next_state == 2
self.run_login(client_socket, client_file) self.run_login(client_socket, client_file)
else: else:
if packet.protocol_version < self.context.protocol_version: if packet.protocol_version < self.context.protocol_version:
@ -198,6 +242,36 @@ class FakeServer(threading.Thread):
self.write_packet(packet, client_socket) self.write_packet(packet, client_socket)
return False return False
def run_status(self, client_socket, client_file):
self.packets = self.packets_status
packet = self.read_packet(client_file)
assert isinstance(packet, packets.RequestPacket)
packet = packets.ResponsePacket(self.context)
packet.json_response = json.dumps({
'version': {
'name': self.minecraft_version,
'protocol': self.context.protocol_version},
'players': {
'max': 1,
'online': 0,
'sample': []},
'description': {
'text': 'FakeServer'}})
self.write_packet(packet, client_socket)
try:
packet = self.read_packet(client_file)
except EOFError:
return False
assert isinstance(packet, packets.PingPacket)
res_packet = packets.PingPacketResponse(self.context)
res_packet.time = packet.time
self.write_packet(res_packet, client_socket)
return False
def read_packet_filtered(self, client_file): def read_packet_filtered(self, client_file):
while True: while True:
packet = self.read_packet(client_file) packet = self.read_packet(client_file)