mirror of
https://github.com/ammaraskar/pyCraft.git
synced 2025-03-25 04:49:03 +01:00
Fully test the encryption package, and add a test for uncompressed packets when packet compression is enabled
This commit is contained in:
parent
89f788f3ea
commit
ea11461e76
@ -1,7 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from hashlib import sha1
|
from hashlib import sha1
|
||||||
from cryptography.hazmat.backends import default_backend
|
from cryptography.hazmat.backends import default_backend
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
||||||
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
|
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
|
||||||
from cryptography.hazmat.primitives.serialization import load_der_public_key
|
from cryptography.hazmat.primitives.serialization import load_der_public_key
|
||||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||||
@ -28,9 +27,6 @@ def encrypt_token_and_secret(pubkey, verification_token, shared_secret):
|
|||||||
"""
|
"""
|
||||||
pubkey = load_der_public_key(pubkey, default_backend())
|
pubkey = load_der_public_key(pubkey, default_backend())
|
||||||
|
|
||||||
if not isinstance(pubkey, rsa.RSAPublicKey):
|
|
||||||
raise RuntimeError("Public key provided by server not an RSA key")
|
|
||||||
|
|
||||||
encrypted_token = pubkey.encrypt(verification_token, PKCS1v15())
|
encrypted_token = pubkey.encrypt(verification_token, PKCS1v15())
|
||||||
encrypted_secret = pubkey.encrypt(shared_secret, PKCS1v15())
|
encrypted_secret = pubkey.encrypt(shared_secret, PKCS1v15())
|
||||||
return encrypted_token, encrypted_secret
|
return encrypted_token, encrypted_secret
|
||||||
@ -56,15 +52,14 @@ def minecraft_sha1_hash_digest(sha1_hash):
|
|||||||
def _number_from_bytes(b, signed=False):
|
def _number_from_bytes(b, signed=False):
|
||||||
try:
|
try:
|
||||||
return int.from_bytes(b, byteorder='big', signed=signed)
|
return int.from_bytes(b, byteorder='big', signed=signed)
|
||||||
except AttributeError:
|
except AttributeError: # pragma: no cover
|
||||||
pass
|
# py-2 compatibility
|
||||||
|
if len(b) == 0:
|
||||||
if len(b) == 0:
|
b = b'\x00'
|
||||||
b = b'\x00'
|
num = int(str(b).encode('hex'), 16)
|
||||||
num = int(str(b).encode('hex'), 16)
|
if signed and (ord(b[0]) & 0x80):
|
||||||
if signed and (ord(b[0]) & 0x80):
|
num -= 2 ** (len(b) * 8)
|
||||||
num -= 2 ** (len(b) * 8)
|
return num
|
||||||
return num
|
|
||||||
|
|
||||||
|
|
||||||
class EncryptedFileObjectWrapper(object):
|
class EncryptedFileObjectWrapper(object):
|
||||||
|
@ -77,7 +77,7 @@ class Packet(object):
|
|||||||
|
|
||||||
# compression_threshold of None means compression is disabled
|
# compression_threshold of None means compression is disabled
|
||||||
if compression_threshold is not None:
|
if compression_threshold is not None:
|
||||||
if len(packet_buffer.get_writable()) > compression_threshold:
|
if len(packet_buffer.get_writable()) > compression_threshold != -1:
|
||||||
# compress the current payload
|
# compress the current payload
|
||||||
compressed_data = compress(packet_buffer.get_writable())
|
compressed_data = compress(packet_buffer.get_writable())
|
||||||
packet_buffer.reset()
|
packet_buffer.reset()
|
||||||
@ -87,8 +87,10 @@ class Packet(object):
|
|||||||
packet_buffer.send(compressed_data)
|
packet_buffer.send(compressed_data)
|
||||||
else:
|
else:
|
||||||
# write out a 0 to indicate uncompressed data
|
# write out a 0 to indicate uncompressed data
|
||||||
|
packet_data = packet_buffer.get_writable()
|
||||||
packet_buffer.reset()
|
packet_buffer.reset()
|
||||||
VarInt.send(0, packet_buffer)
|
VarInt.send(0, packet_buffer)
|
||||||
|
packet_buffer.send(packet_data)
|
||||||
|
|
||||||
VarInt.send(len(packet_buffer.get_writable()), socket) # Packet Size
|
VarInt.send(len(packet_buffer.get_writable()), socket) # Packet Size
|
||||||
socket.send(packet_buffer.get_writable()) # Packet Payload
|
socket.send(packet_buffer.get_writable()) # Packet Payload
|
||||||
|
@ -9,6 +9,7 @@ from minecraft.networking.encryption import (
|
|||||||
generate_verification_hash,
|
generate_verification_hash,
|
||||||
create_AES_cipher,
|
create_AES_cipher,
|
||||||
EncryptedFileObjectWrapper,
|
EncryptedFileObjectWrapper,
|
||||||
|
EncryptedSocketWrapper
|
||||||
)
|
)
|
||||||
|
|
||||||
from cryptography.hazmat.backends import default_backend
|
from cryptography.hazmat.backends import default_backend
|
||||||
@ -55,7 +56,7 @@ class Encryption(unittest.TestCase):
|
|||||||
self.assertEquals(self.token, decrypted_token)
|
self.assertEquals(self.token, decrypted_token)
|
||||||
self.assertEquals(secret, decrypted_secret)
|
self.assertEquals(secret, decrypted_secret)
|
||||||
|
|
||||||
def generate_hash_test(self):
|
def test_generate_hash(self):
|
||||||
verification_hash = generate_verification_hash(
|
verification_hash = generate_verification_hash(
|
||||||
b"", "secret".encode('utf-8'), self.public_key)
|
b"", "secret".encode('utf-8'), self.public_key)
|
||||||
self.assertEquals("1f142e737a84a974a5f2a22f6174a78d80fd97f5",
|
self.assertEquals("1f142e737a84a974a5f2a22f6174a78d80fd97f5",
|
||||||
@ -76,4 +77,50 @@ class Encryption(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(test_data, decrypted_data)
|
self.assertEqual(test_data, decrypted_data)
|
||||||
|
|
||||||
# TODO: test for the socket wrapper
|
def test_socket_wrapper(self):
|
||||||
|
secret = generate_shared_secret()
|
||||||
|
|
||||||
|
cipher = create_AES_cipher(secret)
|
||||||
|
encryptor = cipher.encryptor()
|
||||||
|
decryptor = cipher.decryptor()
|
||||||
|
|
||||||
|
server_cipher = create_AES_cipher(secret)
|
||||||
|
server_encryptor = server_cipher.encryptor()
|
||||||
|
server_decryptor = server_cipher.decryptor()
|
||||||
|
|
||||||
|
mock_socket = MockSocket(server_encryptor, server_decryptor)
|
||||||
|
wrapper = EncryptedSocketWrapper(mock_socket, encryptor, decryptor)
|
||||||
|
|
||||||
|
self.assertEqual(wrapper.fileno(), 0)
|
||||||
|
|
||||||
|
# Ensure that the 12 bytes we receive are the same as the 12 bytes
|
||||||
|
# sent by the server, after undergoing encryption
|
||||||
|
self.assertEqual(wrapper.recv(12), mock_socket.raw_data[:12])
|
||||||
|
|
||||||
|
# Ensure that hello reaches the server properly after undergoing
|
||||||
|
# encryption
|
||||||
|
test_data = "hello".encode('utf-8')
|
||||||
|
wrapper.send(test_data)
|
||||||
|
self.assertEqual(test_data, mock_socket.received)
|
||||||
|
|
||||||
|
|
||||||
|
class MockSocket(object):
|
||||||
|
|
||||||
|
def __init__(self, encryptor, decryptor):
|
||||||
|
self.raw_data = os.urandom(100)
|
||||||
|
self.encryptor = encryptor
|
||||||
|
self.decryptor = decryptor
|
||||||
|
self.received = None
|
||||||
|
|
||||||
|
# when we receive data from the server
|
||||||
|
# it'll be encrypted
|
||||||
|
def recv(self, length):
|
||||||
|
return self.encryptor.update(self.raw_data[:length])
|
||||||
|
|
||||||
|
# decrypt the data as it reaches
|
||||||
|
# the server side
|
||||||
|
def send(self, data):
|
||||||
|
self.received = self.decryptor.update(data)
|
||||||
|
|
||||||
|
def fileno(self):
|
||||||
|
return 0
|
||||||
|
@ -77,12 +77,16 @@ class SerializationTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_compressed_packet(self):
|
def test_compressed_packet(self):
|
||||||
msg = ''.join(choice(string.ascii_lowercase) for i in range(500))
|
msg = ''.join(choice(string.ascii_lowercase) for i in range(500))
|
||||||
|
|
||||||
packet = ChatPacket()
|
packet = ChatPacket()
|
||||||
packet.message = msg
|
packet.message = msg
|
||||||
|
|
||||||
|
self.write_read_packet(packet, 20)
|
||||||
|
self.write_read_packet(packet, -1)
|
||||||
|
|
||||||
|
def write_read_packet(self, packet, compression_threshold):
|
||||||
|
|
||||||
packet_buffer = PacketBuffer()
|
packet_buffer = PacketBuffer()
|
||||||
packet.write(packet_buffer, compression_threshold=20)
|
packet.write(packet_buffer, compression_threshold)
|
||||||
|
|
||||||
packet_buffer.reset_cursor()
|
packet_buffer.reset_cursor()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user