Change version negotiator to use a status query.

This commit is contained in:
joo 2016-11-22 12:13:09 +00:00
parent d72f05c8b0
commit bf17f99083
3 changed files with 187 additions and 135 deletions

View File

@ -7,3 +7,7 @@ class YggdrasilError(Exception):
"""
Base `Exception` for the Yggdrasil authentication service.
"""
class VersionMismatch(Exception):
pass

View File

@ -10,16 +10,19 @@ import timeit
import select
import sys
import json
import re
from future.utils import raise_
from ..compat import unicode
from .types import VarInt
from . import packets
from . import encryption
from .. import SUPPORTED_PROTOCOL_VERSIONS
from .. import SUPPORTED_MINECRAFT_VERSIONS
from ..exceptions import VersionMismatch
STATE_STATUS = 1
STATE_PLAYING = 2
class ConnectionContext(object):
@ -68,8 +71,10 @@ class Connection(object):
server is assumed to be running in offline mode.
:param username: Username string; only applicable in offline mode.
:param initial_version: A Minecraft version string or protocol version
number to use as a first guess when connecting
to the server.
number to use if the server's protocol version
cannot be determined. (Although it is now
somewhat inaccurate, this name is retained for
backward compatibility.)
:param allowed_versions: A set of versions, each being a Minecraft
version string or protocol version number,
restricting the versions that the client may
@ -105,15 +110,16 @@ class Connection(object):
if allowed_versions is None:
self.allowed_proto_versions = set(SUPPORTED_PROTOCOL_VERSIONS)
else:
allowed_version = set(map(proto_version, allowed_versions))
self.allowed_proto_versions = allowed_version
allowed_versions = set(map(proto_version, allowed_versions))
self.allowed_proto_versions = allowed_versions
if initial_version is None:
initial_proto_version = max(self.allowed_proto_versions)
self.default_proto_version = max(self.allowed_proto_versions)
else:
initial_proto_version = proto_version(initial_version)
self.default_proto_version = proto_version(initial_version)
self.context = ConnectionContext(
protocol_version=initial_proto_version)
protocol_version=max(self.allowed_proto_versions))
self.options = _ConnectionOptions()
self.options.address = address
@ -130,12 +136,13 @@ class Connection(object):
def _start_network_thread(self):
"""May safely be called multiple times."""
if self.networking_thread is not None:
if not self.networking_thread.interrupt:
return
self.networking_thread.join()
self.networking_thread = NetworkingThread(self)
self.networking_thread.start()
if self.networking_thread is None:
self.networking_thread = NetworkingThread(self)
self.networking_thread.start()
elif self.networking_thread.interrupt:
# This thread will wait until the previous thread exits, and then
# set `networking_thread' to itself.
NetworkingThread(self, previous=self.networking_thread).start()
def write_packet(self, packet, force=False):
"""Writes a packet to the server.
@ -202,7 +209,7 @@ class Connection(object):
False, to prevent measurement of the latency.
"""
self._connect()
self._handshake(1)
self._handshake(next_state=STATE_STATUS)
self._start_network_thread()
self.reactor = StatusReactor(self, do_ping=handle_ping is not False)
@ -225,22 +232,35 @@ class Connection(object):
Attempt to begin connecting to the server.
May safely be called multiple times after the first, i.e. to reconnect.
"""
# Hold the lock throughout, in case connect() is called from the
# networking thread while another connection is in progress.
with self._write_lock:
# We hold the lock throughout, as connect() may be called by both
# the network thread and a parent thread simultaneously, during
# automatic version negotiation.
# It is important that this is set correctly even when connecting
# in status mode, as some servers, e.g. SpigotMC with the
# ProtocolSupport plugin, use it to determine the correct response.
self.context.protocol_version = max(self.allowed_proto_versions)
self.spawned = False
self._connect()
self._handshake()
login_start_packet = packets.LoginStartPacket()
if self.auth_token:
login_start_packet.name = self.auth_token.profile.name
if len(self.allowed_proto_versions) == 1:
# There is exactly one allowed protocol version, so skip the
# process of determining the server's version, and immediately
# connect.
self._handshake(next_state=STATE_PLAYING)
login_start_packet = packets.LoginStartPacket()
if self.auth_token:
login_start_packet.name = self.auth_token.profile.name
else:
login_start_packet.name = self.username
self.write_packet(login_start_packet)
self.reactor = LoginReactor(self)
else:
login_start_packet.name = self.username
self.write_packet(login_start_packet)
self.reactor = LoginReactor(self)
# Determine the server's protocol version by first performing a
# status query.
self._handshake(next_state=STATE_STATUS)
self.write_packet(packets.RequestPacket())
self.reactor = PlayingStatusReactor(self)
self._start_network_thread()
def _connect(self):
@ -257,13 +277,10 @@ class Connection(object):
def disconnect(self):
""" Terminate the existing server connection, if there is one. """
if self.networking_thread is None:
self._disconnect()
else:
# The networking thread will call _disconnect() later.
self.networking_thread.interrupt = True
if self.networking_thread is not None:
with self._write_lock:
self.networking_thread.interrupt = True
def _disconnect(self):
if self.socket is not None:
if hasattr(self.socket, 'actual_socket'):
# pylint: disable=no-member
@ -273,11 +290,13 @@ class Connection(object):
try:
actual_socket.shutdown(socket.SHUT_RDWR)
except socket.error:
pass
finally:
actual_socket.close()
self.socket = None
def _handshake(self, next_state=2):
def _handshake(self, next_state=STATE_PLAYING):
handshake = packets.HandShakePacket()
handshake.protocol_version = self.context.protocol_version
handshake.server_address = self.options.address
@ -291,21 +310,30 @@ class Connection(object):
exc.exc_info = exc_info # For backward compatibility.
except (TypeError, AttributeError):
pass
if self.reactor.handle_exception(exc, exc_info):
return
self.exception, self.exc_info = exc, exc_info
if self.handle_exception is None:
raise_(*exc_info)
elif self.handle_exception is not False:
self.handle_exception(exc, exc_info)
def _react(self, packet):
self.reactor.react(packet)
class NetworkingThread(threading.Thread):
def __init__(self, connection):
def __init__(self, connection, previous=None):
threading.Thread.__init__(self)
self.interrupt = False
self.connection = connection
self.name = "Networking Thread"
self.daemon = True
self.previous_thread = previous
def run(self):
try:
self._run()
@ -315,11 +343,13 @@ class NetworkingThread(threading.Thread):
self.connection.networking_thread = None
def _run(self):
while True:
if self.interrupt:
self.connection._disconnect()
break
if self.previous_thread is not None:
if self.previous_thread.is_alive():
self.previous_thread.join()
self.previous_thread = None
self.connection.networking_thread = self
while not self.interrupt:
# Attempt to write out as many as 300 packets as possible every
# 0.05 seconds (20 ticks per second)
num_packets = 0
@ -336,36 +366,27 @@ class NetworkingThread(threading.Thread):
# Read and react to as many as 50 packets
num_packets = 0
packet = self.connection.reactor.read_packet(
self.connection.file_object)
while packet:
while num_packets < 50 and not self.interrupt:
packet = self.connection.reactor.read_packet(
self.connection.file_object)
if not packet:
break
num_packets += 1
# Do not raise an IOError if it occurred while a disconnect
# packet was received, as this may be part of an orderly
# disconnection.
if packet.packet_name == 'disconnect' and \
exc_info is not None and \
isinstance(exc_info[1], IOError):
exc_info is not None and isinstance(exc_info[1], IOError):
exc_info = None
try:
self.connection.reactor.react(packet)
self.connection._react(packet)
for listener in self.connection.packet_listeners:
listener.call_packet(packet)
except IgnorePacket:
pass
if num_packets >= 50:
break
if self.interrupt:
self.connection._disconnect()
break
packet = self.connection.reactor.read_packet(
self.connection.file_object)
if exc_info is not None:
raise_(*exc_info)
@ -440,6 +461,13 @@ class PacketReactor(object):
def react(self, packet):
raise NotImplementedError("Call to base reactor")
""" Called when an exception is raised in the networking thread. If this
method returns True, the default action will be prevented and the
exception ignored (but the networking thread will still terminate).
"""
def handle_exception(self, exc, exc_info):
return False
class LoginReactor(PacketReactor):
get_clientbound_packets = staticmethod(packets.state_login_clientbound)
@ -477,42 +505,7 @@ class LoginReactor(PacketReactor):
self.connection.file_object, decryptor)
if packet.packet_name == "disconnect":
# Test for a disconnect packet indicating a version mismatch.
# (Note: it is known how the disconnect messages are formatted for
# official servers within SUPPORTED_MINECRAFT_VERSIONS, but in case
# new versions are added, this section may need to be updated.)
try:
data = json.loads(packet.json_data)
except ValueError:
pass
if isinstance(data, dict) and 'text' in data:
data = data['text']
if not isinstance(data, (str, unicode)):
return
match = re.match(
r"(Outdated client! Please use"
r"|Outdated server! I'm still on) (?P<version>.*)", data)
if not match:
return
self.connection.allowed_proto_versions.remove(
self.connection.context.protocol_version)
version = match.group('version')
if version in SUPPORTED_MINECRAFT_VERSIONS:
new_version = SUPPORTED_MINECRAFT_VERSIONS[version]
elif data.startswith('Outdated client!'):
new_version = max(SUPPORTED_PROTOCOL_VERSIONS)
elif data.startswith('Outdated server!'):
new_version = min(SUPPORTED_PROTOCOL_VERSIONS)
if new_version in self.connection.allowed_proto_versions:
# Ignore this disconnect packet and reconnect with the new
# protocol version, making it appear (on the client side) as if
# the client had initially connected with the (hopefully)
# correct version.
self.connection.context.protocol_version = new_version
self.connection.connect()
raise IgnorePacket
self.connection.disconnect()
if packet.packet_name == "login success":
self.connection.reactor = PlayingReactor(self.connection)
@ -536,19 +529,19 @@ class PlayingReactor(PacketReactor):
self.connection.write_packet(keep_alive_packet)
if packet.packet_name == "player position and look":
teleport_confirm = packets.TeleportConfirmPacket()
teleport_confirm.teleport_id = packet.teleport_id
self.connection.write_packet(teleport_confirm)
'''
position_response = packets.PositionAndLookPacket()
position_response.x = packet.x
position_response.feet_y = packet.y
position_response.z = packet.z
position_response.yaw = packet.yaw
position_response.pitch = packet.pitch
position_response.on_ground = True
self.connection.write_packet(position_response)
'''
if self.connection.context.protocol_version >= 107:
teleport_confirm = packets.TeleportConfirmPacket()
teleport_confirm.teleport_id = packet.teleport_id
self.connection.write_packet(teleport_confirm)
else:
position_response = packets.PositionAndLookPacket()
position_response.x = packet.x
position_response.feet_y = packet.y
position_response.z = packet.z
position_response.yaw = packet.yaw
position_response.pitch = packet.pitch
position_response.on_ground = True
self.connection.write_packet(position_response)
self.connection.spawned = True
if packet.packet_name == "disconnect":
@ -564,6 +557,7 @@ class StatusReactor(PacketReactor):
def react(self, packet):
if packet.packet_name == "response":
status_dict = json.loads(packet.json_response)
if self.do_ping:
ping_packet = packets.PingPacket()
# NOTE: it may be better to depend on the `monotonic' package
@ -572,7 +566,7 @@ class StatusReactor(PacketReactor):
self.connection.write_packet(ping_packet)
else:
self.connection.disconnect()
self.handle_status(json.loads(packet.json_response))
self.handle_status(status_dict)
elif packet.packet_name == "ping" and self.do_ping:
now = int(1000 * timeit.default_timer())
@ -584,3 +578,42 @@ class StatusReactor(PacketReactor):
def handle_ping(self, latency_ms):
print('Ping: %d ms' % latency_ms)
class PlayingStatusReactor(StatusReactor):
def __init__(self, connection):
super(PlayingStatusReactor, self).__init__(connection, do_ping=False)
def handle_status(self, status):
if status == {}:
# This can occur when we connect to a Mojang server while it is
# still initialising, so it must not cause the client to connect
# with the default version.
raise IOError('Invalid server status.')
elif 'version' not in status or 'protocol' not in status['version']:
return self.handle_failure()
proto = status['version']['protocol']
if proto not in self.connection.allowed_proto_versions:
vstr = ('%d (%s)' % (proto, status['version']['name'])) \
if 'name' in status['version'] else str(proto)
sstr = 'supported, but not allowed for this connection' \
if proto in SUPPORTED_PROTOCOL_VERSIONS else 'not supported'
raise VersionMismatch("Server's protocol version of %s is %s."
% (vstr, sstr))
self.handle_proto_version(proto)
def handle_proto_version(self, proto_version):
self.connection.allowed_proto_versions = {proto_version}
self.connection.connect()
def handle_failure(self):
self.handle_proto_version(self.connection.default_proto_version)
def handle_exception(self, exc, exc_info):
if isinstance(exc, EOFError):
# An exception of this type may indicate that the server does not
# properly support status queries, so we treat it as non-fatal.
self.handle_failure()
return True

View File

@ -23,13 +23,33 @@ class _ConnectTest(unittest.TestCase):
server = FakeServer(minecraft_version=server_version)
addr, port = server.listen_socket.getsockname()
cond = threading.Condition()
def handle_client_exception(exc, exc_info):
with cond:
cond.exc_info = exc_info
cond.notify_all()
def client_write(packet, *args, **kwds):
def packet_write(*args, **kwds):
logging.debug('[C-> ] %s' % packet)
return real_packet_write(*args, **kwds)
real_packet_write = packet.write
packet.write = packet_write
return real_client_write(packet, *args, **kwds)
def client_react(packet, *args, **kwds):
logging.debug('[ ->C] %s' % packet)
return real_client_react(packet, *args, **kwds)
client = connection.Connection(
addr, port, username='User', initial_version=client_version,
handle_exception=False)
client.register_packet_listener(
lambda packet: logging.debug('[ ->C] %s' % packet), packets.Packet)
handle_exception=handle_client_exception)
real_client_react = client._react
real_client_write = client.write_packet
client.write_packet = client_write
client._react = client_react
cond = threading.Condition()
try:
with cond:
server_thread = threading.Thread(
@ -39,22 +59,20 @@ class _ConnectTest(unittest.TestCase):
server_thread.daemon = True
server_thread.start()
client_thread = threading.Thread(
name='_ConnectTest client',
target=self._test_connect_client,
args=(client, cond))
client_thread.daemon = True
client_thread.start()
self._test_connect_client(client, cond)
cond.wait()
if cond.exc_info is not None:
cond.exc_info = Ellipsis
cond.wait(THREAD_TIMEOUT_S)
if cond.exc_info is Ellipsis:
self.fail('Timed out.')
elif cond.exc_info is not None:
raise_(*cond.exc_info)
finally:
# Wait for all threads to exit.
for thread in server_thread, client_thread:
if thread.is_alive():
for thread in server_thread, client.networking_thread:
if thread is not None and thread.is_alive():
thread.join(THREAD_TIMEOUT_S)
if thread.is_alive():
if thread is not None and thread.is_alive():
if cond.exc_info is None:
self.fail('Thread "%s" timed out.' % thread.name)
else:
@ -63,11 +81,6 @@ class _ConnectTest(unittest.TestCase):
def _test_connect_client(self, client, cond):
client.connect()
client.networking_thread.join()
if getattr(client, 'exception', None) is not None:
with cond:
cond.exc_info = client.exception.exc_info
cond.notify_all()
def _test_connect_server(self, server, cond):
try:
@ -112,14 +125,9 @@ class PingTest(_ConnectTest):
cond.notify_all()
client.status(handle_status=False, handle_ping=handle_ping)
client.networking_thread.join()
if getattr(client, 'exception', None) is not None:
with cond:
cond.exc_info = client.exception.exc_info
cond.notify_all()
def _test_connect_server(self, server, cond):
try:
server.continue_after_status = False
server.run()
except:
with cond:
@ -132,10 +140,11 @@ class FakeServer(threading.Thread):
'packets_login', 'packets_playing', 'packets_status', \
'packets',
def __init__(self, minecraft_version=None):
def __init__(self, minecraft_version=None, continue_after_status=True):
if minecraft_version is None:
minecraft_version = VERSIONS[-1][0]
self.minecraft_version = minecraft_version
self.continue_after_status = continue_after_status
protocol_version = SUPPORTED_MINECRAFT_VERSIONS[minecraft_version]
self.context = connection.ConnectionContext(
protocol_version=protocol_version)
@ -172,11 +181,16 @@ class FakeServer(threading.Thread):
running = True
while running:
client_socket, addr = self.listen_socket.accept()
logging.debug('[ ++ ] Client %s connected to server.' % (addr,))
client_file = client_socket.makefile('rb', 0)
try:
running = self.run_handshake(client_socket, client_file)
finally:
except:
raise
else:
client_socket.shutdown(socket.SHUT_RDWR)
logging.debug('[ -- ] Client %s disconnected.' % (addr,))
finally:
client_socket.close()
client_file.close()
@ -195,6 +209,7 @@ class FakeServer(threading.Thread):
def run_handshake_status(self, packet, client_socket, client_file):
self.run_status(client_socket, client_file)
return self.continue_after_status
def run_handshake_play(self, packet, client_socket, client_file):
if packet.protocol_version == self.context.protocol_version:
@ -226,7 +241,7 @@ class FakeServer(threading.Thread):
packet = packets.JoinGamePacket(
self.context, entity_id=0, game_mode=0, dimension=0, difficulty=2,
max_players=255, level_type='default', reduced_debug_info=False)
max_players=1, level_type='default', reduced_debug_info=False)
self.write_packet(packet, client_socket)
keep_alive_id = 1076048782
@ -290,7 +305,7 @@ class FakeServer(threading.Thread):
packet.read(buffer)
else:
packet = packets.Packet(self.context, id=packet_id)
logging.debug('[C->S] %s' % packet)
logging.debug('[ ->S] %s' % packet)
return packet
def read_packet_buffer(self, client_file):