mirror of
https://github.com/ammaraskar/pyCraft.git
synced 2025-03-23 12:00:12 +01:00
Fully test the encryption package, and add a test for uncompressed packets when packet compression is enabled
This commit is contained in:
parent
89f788f3ea
commit
ea11461e76
@ -1,7 +1,6 @@
|
||||
import os
|
||||
from hashlib import sha1
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
|
||||
from cryptography.hazmat.primitives.serialization import load_der_public_key
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
@ -28,9 +27,6 @@ def encrypt_token_and_secret(pubkey, verification_token, shared_secret):
|
||||
"""
|
||||
pubkey = load_der_public_key(pubkey, default_backend())
|
||||
|
||||
if not isinstance(pubkey, rsa.RSAPublicKey):
|
||||
raise RuntimeError("Public key provided by server not an RSA key")
|
||||
|
||||
encrypted_token = pubkey.encrypt(verification_token, PKCS1v15())
|
||||
encrypted_secret = pubkey.encrypt(shared_secret, PKCS1v15())
|
||||
return encrypted_token, encrypted_secret
|
||||
@ -56,15 +52,14 @@ def minecraft_sha1_hash_digest(sha1_hash):
|
||||
def _number_from_bytes(b, signed=False):
|
||||
try:
|
||||
return int.from_bytes(b, byteorder='big', signed=signed)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
if len(b) == 0:
|
||||
b = b'\x00'
|
||||
num = int(str(b).encode('hex'), 16)
|
||||
if signed and (ord(b[0]) & 0x80):
|
||||
num -= 2 ** (len(b) * 8)
|
||||
return num
|
||||
except AttributeError: # pragma: no cover
|
||||
# py-2 compatibility
|
||||
if len(b) == 0:
|
||||
b = b'\x00'
|
||||
num = int(str(b).encode('hex'), 16)
|
||||
if signed and (ord(b[0]) & 0x80):
|
||||
num -= 2 ** (len(b) * 8)
|
||||
return num
|
||||
|
||||
|
||||
class EncryptedFileObjectWrapper(object):
|
||||
|
@ -77,7 +77,7 @@ class Packet(object):
|
||||
|
||||
# compression_threshold of None means compression is disabled
|
||||
if compression_threshold is not None:
|
||||
if len(packet_buffer.get_writable()) > compression_threshold:
|
||||
if len(packet_buffer.get_writable()) > compression_threshold != -1:
|
||||
# compress the current payload
|
||||
compressed_data = compress(packet_buffer.get_writable())
|
||||
packet_buffer.reset()
|
||||
@ -87,8 +87,10 @@ class Packet(object):
|
||||
packet_buffer.send(compressed_data)
|
||||
else:
|
||||
# write out a 0 to indicate uncompressed data
|
||||
packet_data = packet_buffer.get_writable()
|
||||
packet_buffer.reset()
|
||||
VarInt.send(0, packet_buffer)
|
||||
packet_buffer.send(packet_data)
|
||||
|
||||
VarInt.send(len(packet_buffer.get_writable()), socket) # Packet Size
|
||||
socket.send(packet_buffer.get_writable()) # Packet Payload
|
||||
|
@ -9,6 +9,7 @@ from minecraft.networking.encryption import (
|
||||
generate_verification_hash,
|
||||
create_AES_cipher,
|
||||
EncryptedFileObjectWrapper,
|
||||
EncryptedSocketWrapper
|
||||
)
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
@ -55,7 +56,7 @@ class Encryption(unittest.TestCase):
|
||||
self.assertEquals(self.token, decrypted_token)
|
||||
self.assertEquals(secret, decrypted_secret)
|
||||
|
||||
def generate_hash_test(self):
|
||||
def test_generate_hash(self):
|
||||
verification_hash = generate_verification_hash(
|
||||
b"", "secret".encode('utf-8'), self.public_key)
|
||||
self.assertEquals("1f142e737a84a974a5f2a22f6174a78d80fd97f5",
|
||||
@ -76,4 +77,50 @@ class Encryption(unittest.TestCase):
|
||||
|
||||
self.assertEqual(test_data, decrypted_data)
|
||||
|
||||
# TODO: test for the socket wrapper
|
||||
def test_socket_wrapper(self):
|
||||
secret = generate_shared_secret()
|
||||
|
||||
cipher = create_AES_cipher(secret)
|
||||
encryptor = cipher.encryptor()
|
||||
decryptor = cipher.decryptor()
|
||||
|
||||
server_cipher = create_AES_cipher(secret)
|
||||
server_encryptor = server_cipher.encryptor()
|
||||
server_decryptor = server_cipher.decryptor()
|
||||
|
||||
mock_socket = MockSocket(server_encryptor, server_decryptor)
|
||||
wrapper = EncryptedSocketWrapper(mock_socket, encryptor, decryptor)
|
||||
|
||||
self.assertEqual(wrapper.fileno(), 0)
|
||||
|
||||
# Ensure that the 12 bytes we receive are the same as the 12 bytes
|
||||
# sent by the server, after undergoing encryption
|
||||
self.assertEqual(wrapper.recv(12), mock_socket.raw_data[:12])
|
||||
|
||||
# Ensure that hello reaches the server properly after undergoing
|
||||
# encryption
|
||||
test_data = "hello".encode('utf-8')
|
||||
wrapper.send(test_data)
|
||||
self.assertEqual(test_data, mock_socket.received)
|
||||
|
||||
|
||||
class MockSocket(object):
|
||||
|
||||
def __init__(self, encryptor, decryptor):
|
||||
self.raw_data = os.urandom(100)
|
||||
self.encryptor = encryptor
|
||||
self.decryptor = decryptor
|
||||
self.received = None
|
||||
|
||||
# when we receive data from the server
|
||||
# it'll be encrypted
|
||||
def recv(self, length):
|
||||
return self.encryptor.update(self.raw_data[:length])
|
||||
|
||||
# decrypt the data as it reaches
|
||||
# the server side
|
||||
def send(self, data):
|
||||
self.received = self.decryptor.update(data)
|
||||
|
||||
def fileno(self):
|
||||
return 0
|
||||
|
@ -77,12 +77,16 @@ class SerializationTest(unittest.TestCase):
|
||||
|
||||
def test_compressed_packet(self):
|
||||
msg = ''.join(choice(string.ascii_lowercase) for i in range(500))
|
||||
|
||||
packet = ChatPacket()
|
||||
packet.message = msg
|
||||
|
||||
self.write_read_packet(packet, 20)
|
||||
self.write_read_packet(packet, -1)
|
||||
|
||||
def write_read_packet(self, packet, compression_threshold):
|
||||
|
||||
packet_buffer = PacketBuffer()
|
||||
packet.write(packet_buffer, compression_threshold=20)
|
||||
packet.write(packet_buffer, compression_threshold)
|
||||
|
||||
packet_buffer.reset_cursor()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user