Added more tests, fixed compat.py to run on py3 as well

This commit is contained in:
Ammar Askar 2015-04-12 07:26:12 +05:00
parent b2ccc754f4
commit a5a76a8e1c
5 changed files with 101 additions and 59 deletions

View File

@ -15,5 +15,5 @@ both Python2 and Python3 while using the same codebase.
try: try:
input = raw_input input = raw_input
except NameError: except NameError:
pass input = input
# pylint: enable=undefined-variable,redefined-builtin,invalid-name # pylint: enable=undefined-variable,redefined-builtin,invalid-name

View File

@ -54,9 +54,14 @@ class Packet(object):
id = -0x01 id = -0x01
definition = [] definition = []
def __init__(self): def __init__(self, **kwargs):
pass pass
def set_values(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
return self
def read(self, file_object): def read(self, file_object):
for field in self.definition: for field in self.definition:
for var_name, data_type in field.items(): for var_name, data_type in field.items():

View File

@ -8,11 +8,11 @@ import struct
class Type(object): class Type(object):
@staticmethod @staticmethod
def read(file_object): def read(file_object):
pass raise NotImplementedError("Base data type not serializable")
@staticmethod @staticmethod
def send(value, socket): def send(value, socket):
pass raise NotImplementedError("Base data type not serializable")
# ========================================================= # =========================================================
@ -84,7 +84,7 @@ class VarInt(Type):
number = 0 number = 0
for i in range(5): for i in range(5):
byte = socket.recv(1) byte = socket.recv(1)
if byte == "": if byte == "" or len(byte) == 0:
raise RuntimeError("Socket disconnected") raise RuntimeError("Socket disconnected")
byte = ord(byte) byte = ord(byte)
number |= (byte & 0x7F) << 7 * i number |= (byte & 0x7F) << 7 * i

78
tests/test_packets.py Normal file
View File

@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
import unittest
import string
from zlib import decompress
from random import choice
from minecraft.networking.types import VarInt
from minecraft.networking.packets import (
PacketBuffer, ChatPacket, KeepAlivePacket, PacketListener)
class PacketSerializatonTest(unittest.TestCase):
def test_packet(self):
packet = ChatPacket()
packet.message = u"κόσμε"
packet_buffer = PacketBuffer()
packet.write(packet_buffer)
packet_buffer.reset_cursor()
# Read the length and packet id
VarInt.read(packet_buffer)
packet_id = VarInt.read(packet_buffer)
self.assertEqual(packet_id, packet.id)
deserialized = ChatPacket()
deserialized.read(packet_buffer)
self.assertEqual(packet.message, deserialized.message)
def test_compressed_packet(self):
msg = ''.join(choice(string.ascii_lowercase) for i in range(500))
packet = ChatPacket()
packet.message = msg
self.write_read_packet(packet, 20)
self.write_read_packet(packet, -1)
def write_read_packet(self, packet, compression_threshold):
packet_buffer = PacketBuffer()
packet.write(packet_buffer, compression_threshold)
packet_buffer.reset_cursor()
VarInt.read(packet_buffer)
compressed_size = VarInt.read(packet_buffer)
if compressed_size > 0:
decompressed = decompress(packet_buffer.read(compressed_size))
packet_buffer.reset()
packet_buffer.send(decompressed)
packet_buffer.reset_cursor()
packet_id = VarInt.read(packet_buffer)
self.assertEqual(packet_id, packet.id)
deserialized = ChatPacket()
deserialized.read(packet_buffer)
self.assertEqual(packet.message, deserialized.message)
class PacketListenerTest(unittest.TestCase):
def test_listener(self):
message = "hello world"
def test_packet(chat_packet):
self.assertEqual(chat_packet.message, message)
listener = PacketListener(test_packet, ChatPacket)
packet = ChatPacket().set_values(message=message)
uncalled_packet = KeepAlivePacket().set_values(keep_alive_id=0)
listener.call_packet(packet)
listener.call_packet(uncalled_packet)

View File

@ -1,15 +1,12 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import unittest import unittest
import string
from random import choice
from zlib import decompress
from minecraft.networking.types import ( from minecraft.networking.types import (
Type, Boolean, UnsignedByte, Byte, Short, UnsignedShort, Type, Boolean, UnsignedByte, Byte, Short, UnsignedShort,
Integer, VarInt, Long, Float, Double, ShortPrefixedByteArray, Integer, VarInt, Long, Float, Double, ShortPrefixedByteArray,
VarIntPrefixedByteArray, String as StringType VarIntPrefixedByteArray, String as StringType
) )
from minecraft.networking.packets import PacketBuffer, ChatPacket from minecraft.networking.packets import PacketBuffer
TEST_DATA = { TEST_DATA = {
@ -47,6 +44,18 @@ class SerializationTest(unittest.TestCase):
else: else:
self.assertEqual(test_data, deserialized) self.assertEqual(test_data, deserialized)
def test_exceptions(self):
base_type = Type()
with self.assertRaises(NotImplementedError):
base_type.read(None)
with self.assertRaises(NotImplementedError):
base_type.send(None, None)
empty_socket = PacketBuffer()
with self.assertRaises(RuntimeError):
VarInt.read_socket(empty_socket)
def test_varint(self): def test_varint(self):
self.assertEqual(VarInt.size(2), 1) self.assertEqual(VarInt.size(2), 1)
self.assertEqual(VarInt.size(1250), 2) self.assertEqual(VarInt.size(1250), 2)
@ -56,53 +65,3 @@ class SerializationTest(unittest.TestCase):
packet_buffer.reset_cursor() packet_buffer.reset_cursor()
self.assertEqual(VarInt.read_socket(packet_buffer), 50000) self.assertEqual(VarInt.read_socket(packet_buffer), 50000)
def test_packet(self):
packet = ChatPacket()
packet.message = u"κόσμε"
packet_buffer = PacketBuffer()
packet.write(packet_buffer)
packet_buffer.reset_cursor()
# Read the length and packet id
VarInt.read(packet_buffer)
packet_id = VarInt.read(packet_buffer)
self.assertEqual(packet_id, packet.id)
deserialized = ChatPacket()
deserialized.read(packet_buffer)
self.assertEqual(packet.message, deserialized.message)
def test_compressed_packet(self):
msg = ''.join(choice(string.ascii_lowercase) for i in range(500))
packet = ChatPacket()
packet.message = msg
self.write_read_packet(packet, 20)
self.write_read_packet(packet, -1)
def write_read_packet(self, packet, compression_threshold):
packet_buffer = PacketBuffer()
packet.write(packet_buffer, compression_threshold)
packet_buffer.reset_cursor()
VarInt.read(packet_buffer)
compressed_size = VarInt.read(packet_buffer)
if compressed_size > 0:
decompressed = decompress(packet_buffer.read(compressed_size))
packet_buffer.reset()
packet_buffer.send(decompressed)
packet_buffer.reset_cursor()
packet_id = VarInt.read(packet_buffer)
self.assertEqual(packet_id, packet.id)
deserialized = ChatPacket()
deserialized.read(packet_buffer)
self.assertEqual(packet.message, deserialized.message)