This commit is contained in:
J. Nick Koston 2023-11-27 22:19:29 -06:00
parent 361ddebeaf
commit 8ea12a299c
No known key found for this signature in database
7 changed files with 24 additions and 24 deletions

View File

@ -13,7 +13,7 @@ cdef class APIFrameHelper:
cdef APIConnection _connection
cdef object _transport
cdef public object _writer
cdef public object _ready_future
cdef public object ready_future
cdef bytes _buffer
cdef unsigned int _buffer_len
cdef unsigned int _pos

View File

@ -34,7 +34,7 @@ class APIFrameHelper:
"_connection",
"_transport",
"_writer",
"_ready_future",
"ready_future",
"_buffer",
"_buffer_len",
"_pos",
@ -54,7 +54,7 @@ class APIFrameHelper:
self._connection = connection
self._transport: asyncio.Transport | None = None
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
self._ready_future = self._loop.create_future()
self.ready_future = self._loop.create_future()
self._buffer: bytes | None = None
self._buffer_len = 0
self._pos = 0
@ -65,9 +65,9 @@ class APIFrameHelper:
"""Set the log name."""
self._log_name = log_name
def _set_ready_future_exception(self, exc: Exception | type[Exception]) -> None:
if not self._ready_future.done():
self._ready_future.set_exception(exc)
def _setready_future_exception(self, exc: Exception | type[Exception]) -> None:
if not self.ready_future.done():
self.ready_future.set_exception(exc)
def _add_to_buffer(self, data: bytes | bytearray | memoryview) -> None:
"""Add data to the buffer."""
@ -138,7 +138,7 @@ class APIFrameHelper:
def get_handshake_future(self) -> None:
"""Get the handshake future."""
return self._ready_future
return self.ready_future
@abstractmethod
def write_packets(
@ -159,7 +159,7 @@ class APIFrameHelper:
self.close()
def _handle_error(self, exc: Exception) -> None:
self._set_ready_future_exception(exc)
self._setready_future_exception(exc)
self._connection.report_fatal_error(exc)
def connection_lost(self, exc: Exception | None) -> None:

View File

@ -104,7 +104,7 @@ class APINoiseFrameHelper(APIFrameHelper):
# Make sure we set the ready event if its not already set
# so that we don't block forever on the ready event if we
# are waiting for the handshake to complete.
self._set_ready_future_exception(
self._setready_future_exception(
APIConnectionError(f"{self._log_name}: Connection closed")
)
self._state = NOISE_STATE_CLOSED
@ -279,7 +279,7 @@ class APINoiseFrameHelper(APIFrameHelper):
noise_protocol.cipher_state_encrypt.encrypt_with_ad, # pylint: disable=no-member
None,
)
self._ready_future.set_result(None)
self.ready_future.set_result(None)
def write_packets(
self, packets: list[tuple[int, bytes]], debug_enabled: bool

View File

@ -39,7 +39,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Handle a new connection."""
super().connection_made(transport)
self._ready_future.set_result(None)
self.ready_future.set_result(None)
def write_packets(
self, packets: list[tuple[int, bytes]], debug_enabled: bool

View File

@ -400,7 +400,7 @@ class APIConnection:
# Set the frame helper right away to ensure
# the socket gets closed if we fail to handshake
self._frame_helper = fh
future = self._frame_helper.get_handshake_future()
future = self._frame_helper.ready_future
handshake_handle = self._loop.call_at(
self._loop.time() + HANDSHAKE_TIMEOUT, handle_timeout, future
)

View File

@ -312,7 +312,7 @@ async def test_noise_protector_event_loop(byte_type: Any) -> None:
mock_data_received(helper, byte_type(bytes.fromhex(pkt)))
with pytest.raises(InvalidEncryptionKeyAPIError):
await helper.perform_handshake(30)
await helper.ready_future
@pytest.mark.asyncio
@ -343,7 +343,7 @@ async def test_noise_frame_helper_incorrect_key():
mock_data_received(helper, bytes.fromhex(pkt))
with pytest.raises(InvalidEncryptionKeyAPIError):
await helper.perform_handshake(30)
await helper.ready_future
@pytest.mark.asyncio
@ -376,7 +376,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
mock_data_received(helper, in_pkt[i : i + 1])
with pytest.raises(InvalidEncryptionKeyAPIError):
await helper.perform_handshake(30)
await helper.ready_future
@pytest.mark.asyncio
@ -407,7 +407,7 @@ async def test_noise_incorrect_name():
mock_data_received(helper, bytes.fromhex(pkt))
with pytest.raises(BadNameAPIError):
await helper.perform_handshake(30)
await helper.ready_future
@pytest.mark.asyncio
@ -431,7 +431,7 @@ async def test_noise_timeout():
for pkt in outgoing_packets:
helper.mock_write_frame(bytes.fromhex(pkt))
task = asyncio.create_task(helper.perform_handshake(30))
task = asyncio.create_task(helper.ready_future)
await asyncio.sleep(0)
async_fire_time_changed(utcnow() + timedelta(seconds=60))
await asyncio.sleep(0)
@ -478,7 +478,7 @@ async def test_noise_frame_helper_handshake_failure():
proto = _mock_responder_proto(psk_bytes)
handshake_task = asyncio.create_task(helper.perform_handshake(30))
handshake_task = asyncio.create_task(helper.ready_future)
await asyncio.sleep(0) # let the task run to read the hello packet
assert len(writes) == 1
@ -528,7 +528,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
proto = _mock_responder_proto(psk_bytes)
handshake_task = asyncio.create_task(helper.perform_handshake(30))
handshake_task = asyncio.create_task(helper.ready_future)
await asyncio.sleep(0) # let the task run to read the hello packet
assert len(writes) == 1
@ -591,7 +591,7 @@ async def test_noise_frame_helper_bad_encryption(
proto = _mock_responder_proto(psk_bytes)
handshake_task = asyncio.create_task(helper.perform_handshake(30))
handshake_task = asyncio.create_task(helper.ready_future)
await asyncio.sleep(0) # let the task run to read the hello packet
assert len(writes) == 1
@ -638,7 +638,7 @@ async def test_init_plaintext_with_wrong_preamble(conn: APIConnection):
conn._socket = MagicMock()
await conn._connect_init_frame_helper()
loop.call_soon(conn._frame_helper._ready_future.set_result, None)
loop.call_soon(conn._frame_helper.ready_future.set_result, None)
conn.connection_state = ConnectionState.CONNECTED
task = asyncio.create_task(conn._connect_hello_login(login=True))
@ -687,7 +687,7 @@ async def test_noise_frame_helper_empty_hello():
log_name="test",
)
handshake_task = asyncio.create_task(helper.perform_handshake(30))
handshake_task = asyncio.create_task(helper.ready_future)
hello_pkt_with_header = _make_noise_hello_pkt(b"")
mock_data_received(helper, hello_pkt_with_header)
@ -708,7 +708,7 @@ async def test_noise_frame_helper_wrong_protocol():
log_name="test",
)
handshake_task = asyncio.create_task(helper.perform_handshake(30))
handshake_task = asyncio.create_task(helper.ready_future)
# wrong protocol 5 instead of 1
hello_pkt_with_header = _make_noise_hello_pkt(b"\x05servicetest\0")

View File

@ -161,7 +161,7 @@ async def test_requires_encryption_propagates(conn: APIConnection):
conn._socket = MagicMock()
await conn._connect_init_frame_helper()
loop.call_soon(conn._frame_helper._ready_future.set_result, None)
loop.call_soon(conn._frame_helper.ready_future.set_result, None)
conn.connection_state = ConnectionState.CONNECTED
with pytest.raises(RequiresEncryptionAPIError):