Add full connection tests with encryption enabled.

This commit is contained in:
joo 2018-05-27 02:35:13 +01:00
parent ebee077303
commit ab9ca6dfee
3 changed files with 112 additions and 28 deletions

View File

@ -73,6 +73,9 @@ class EncryptedFileObjectWrapper(object):
def fileno(self):
return self.actual_file_object.fileno()
def close(self):
self.actual_file_object.close()
class EncryptedSocketWrapper(object):
def __init__(self, socket, encryptor, decryptor):
@ -88,3 +91,9 @@ class EncryptedSocketWrapper(object):
def fileno(self):
return self.actual_socket.fileno()
def close(self):
return self.actual_socket.close()
def shutdown(self, *args, **kwds):
return self.actual_socket.shutdown(*args, **kwds)

View File

@ -6,7 +6,11 @@ from minecraft.networking import types
from minecraft.networking import packets
from minecraft.networking.packets import clientbound
from minecraft.networking.packets import serverbound
from minecraft.networking.encryption import (
create_AES_cipher, EncryptedFileObjectWrapper, EncryptedSocketWrapper
)
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from future.utils import raise_
import unittest
@ -155,6 +159,9 @@ class FakeClientHandler(object):
packet = self.read_packet()
assert isinstance(packet, serverbound.login.LoginStartPacket)
if self.server.private_key is not None:
self._run_login_encryption()
if self.server.compression_threshold is not None:
self.write_packet(clientbound.login.SetCompressionPacket(
threshold=self.server.compression_threshold))
@ -163,12 +170,30 @@ class FakeClientHandler(object):
self.user_name = packet.name
self.user_uuid = uuid.UUID(bytes=hashlib.md5(
('OfflinePlayer:%s' % self.user_name).encode('utf8')).digest())
self.write_packet(clientbound.login.LoginSuccessPacket(
UUID=str(self.user_uuid), Username=self.user_name))
self._run_playing()
def _run_login_encryption(self):
# Set up protocol encryption with the client, then return.
server_token = b'\x89\x82\x9a\x01' # Guaranteed to be random.
self.write_packet(clientbound.login.EncryptionRequestPacket(
server_id='', verify_token=server_token,
public_key=self.server.public_key_bytes))
packet = self.read_packet()
assert isinstance(packet, serverbound.login.EncryptionResponsePacket)
private_key = self.server.private_key
client_token = private_key.decrypt(packet.verify_token, PKCS1v15())
assert client_token == server_token
shared_secret = private_key.decrypt(packet.shared_secret, PKCS1v15())
cipher = create_AES_cipher(shared_secret)
enc, dec = cipher.encryptor(), cipher.decryptor()
self.socket = EncryptedSocketWrapper(self.socket, enc, dec)
self.socket_file = EncryptedFileObjectWrapper(self.socket_file, dec)
def _run_playing(self):
# Enter the playing state of the connection.
self.packets = self.server.packets_playing
@ -247,27 +272,37 @@ class FakeServer(object):
The server listens on a local TCP socket and accepts client connections
in serial, in a single-threaded manner. It responds to status queries,
performs handshake and login, and, by default, echoes any chat messages
back to the client until it disconnects.1~
back to the client until it disconnects.
The behaviour of the server can be customised by writing subclasses of
FakeClientHandler, overriding its public methods of the form
'handle_*', and providing the class to the FakeServer as its
'client_handler_type'.
If 'private_key' is not None, it must be an instance of
'cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey',
'public_key_bytes' must be the corresponding public key serialised in
DER format with PKCS1 encoding, and encryption will be enabled for all
client sessions; otherwise, if it is None, encryption is disabled.
"""
__slots__ = 'listen_socket', 'compression_threshold', 'context', \
'minecraft_version', 'client_handler_type', \
'packets_handshake', 'packets_login', 'packets_playing', \
'packets_status', 'lock', 'stopping'
'packets_status', 'lock', 'stopping', 'private_key', \
'public_key_bytes',
def __init__(self, minecraft_version=None, compression_threshold=None,
client_handler_type=FakeClientHandler):
client_handler_type=FakeClientHandler, private_key=None,
public_key_bytes=None):
if minecraft_version is None:
minecraft_version = VERSIONS[-1][0]
self.minecraft_version = minecraft_version
self.compression_threshold = compression_threshold
self.client_handler_type = client_handler_type
self.private_key = private_key
self.public_key_bytes = public_key_bytes
protocol_version = SUPPORTED_MINECRAFT_VERSIONS[minecraft_version]
self.context = connection.ConnectionContext(
@ -329,11 +364,11 @@ class _FakeServerTest(unittest.TestCase):
if a 'JoinGamePacket' is received before a 'DisconnectPacket'.
Customise by making subclasses that:
1. Overrides the attributes present in this class, where desired, so
that they will apply to all tests; and/or
1. Override the attributes present in this class, where desired, so
that they will apply to all tests; and
2. Define tests (or override 'runTest') to call '_test_connect' with
the arguments specified as necessary to override class attributes.
3. Overrides '_start_client' in order to set event listeners and
the necessary arguments to override class attributes; and
3. Override '_start_client' in order to set event listeners and
change the connection mode, if necessary.
To terminate the test and indicate that it finished successfully, a
client packet handler or a handler method of the 'FakeClientHandler'
@ -354,6 +389,12 @@ class _FakeServerTest(unittest.TestCase):
# The compression threshold that the server will dictate.
# If None, compression is disabled.
private_key = None
# The RSA private key used by the server: see 'FakeServer'.
public_key_bytes = None
# The serialised RSA public key used by the server: see 'FakeServer'.
def _start_client(self, client):
game_joined = [False]
@ -371,7 +412,8 @@ class _FakeServerTest(unittest.TestCase):
client.connect()
def _test_connect(self, client_versions=None, server_version=None,
client_handler_type=None, compression_threshold=None):
client_handler_type=None, compression_threshold=None,
private_key=None, public_key_bytes=None):
if client_versions is None:
client_versions = self.client_versions
if server_version is None:
@ -380,10 +422,16 @@ class _FakeServerTest(unittest.TestCase):
compression_threshold = self.compression_threshold
if client_handler_type is None:
client_handler_type = self.client_handler_type
if private_key is None:
private_key = self.private_key
if public_key_bytes is None:
public_key_bytes = self.public_key_bytes
server = FakeServer(minecraft_version=server_version,
compression_threshold=compression_threshold,
client_handler_type=client_handler_type)
client_handler_type=client_handler_type,
private_key=private_key,
public_key_bytes=public_key_bytes)
addr = "localhost"
port = server.listen_socket.getsockname()[1]

View File

@ -11,6 +11,8 @@ from minecraft.networking.encryption import (
EncryptedFileObjectWrapper,
EncryptedSocketWrapper
)
from minecraft.networking.packets import clientbound
from tests import test_connection
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
@ -20,6 +22,24 @@ KEY_LOCATION = os.path.join(os.path.dirname(os.path.realpath(__file__)),
"encryption")
def setUpModule():
global private_key, public_key, token
with open(os.path.join(KEY_LOCATION, "priv_key.bin"), "rb") as f:
private_key = f.read()
private_key = load_der_private_key(private_key, None, default_backend())
with open(os.path.join(KEY_LOCATION, "pub_key.bin"), "rb") as f:
public_key = f.read()
token = generate_shared_secret()
def tearDownModule():
global private_key, public_key, token
del private_key, public_key, token
class Hashing(unittest.TestCase):
test_data = {'Notch': '4ed1f46bbe04bc756bcb17c0c7ce3e4632f06a48',
'jeb_': '-7c9d5b0044c130109a5d7b5fb5c317c02b4e28c1',
@ -34,31 +54,19 @@ class Hashing(unittest.TestCase):
class Encryption(unittest.TestCase):
def setUp(self):
with open(os.path.join(KEY_LOCATION, "priv_key.bin"), "rb") as f:
self.private_key = f.read()
self.private_key = load_der_private_key(self.private_key, None,
default_backend())
with open(os.path.join(KEY_LOCATION, "pub_key.bin"), "rb") as f:
self.public_key = f.read()
self.token = generate_shared_secret()
def test_token_secret_encryption(self):
secret = generate_shared_secret()
token, encrypted_secret = encrypt_token_and_secret(self.public_key,
self.token, secret)
decrypted_token = self.private_key.decrypt(token,
PKCS1v15())
decrypted_secret = self.private_key.decrypt(encrypted_secret,
PKCS1v15())
encrypted_token, encrypted_secret = \
encrypt_token_and_secret(public_key, token, secret)
decrypted_token = private_key.decrypt(encrypted_token, PKCS1v15())
decrypted_secret = private_key.decrypt(encrypted_secret, PKCS1v15())
self.assertEquals(self.token, decrypted_token)
self.assertEquals(token, decrypted_token)
self.assertEquals(secret, decrypted_secret)
def test_generate_hash(self):
verification_hash = generate_verification_hash(
u"", "secret".encode('utf-8'), self.public_key)
u"", "secret".encode('utf-8'), public_key)
self.assertEquals("1f142e737a84a974a5f2a22f6174a78d80fd97f5",
verification_hash)
@ -104,6 +112,25 @@ class Encryption(unittest.TestCase):
self.assertEqual(test_data, mock_socket.received)
class EncryptedConnection(test_connection.ConnectTest):
def test_connect(self):
self._test_connect(private_key=private_key,
public_key_bytes=public_key)
def _start_client(self, client):
def handle_login_success(_packet):
assert isinstance(client.socket, EncryptedSocketWrapper)
assert isinstance(client.file_object, EncryptedFileObjectWrapper)
client.register_packet_listener(
handle_login_success, clientbound.login.LoginSuccessPacket)
super(EncryptedConnection, self)._start_client(client)
class EncryptedCompressedConnection(EncryptedConnection,
test_connection.ConnectCompressionLowTest):
pass
class MockSocket(object):
def __init__(self, encryptor, decryptor):