tweak
This commit is contained in:
parent
361ddebeaf
commit
8ea12a299c
|
@ -13,7 +13,7 @@ cdef class APIFrameHelper:
|
|||
cdef APIConnection _connection
|
||||
cdef object _transport
|
||||
cdef public object _writer
|
||||
cdef public object _ready_future
|
||||
cdef public object ready_future
|
||||
cdef bytes _buffer
|
||||
cdef unsigned int _buffer_len
|
||||
cdef unsigned int _pos
|
||||
|
|
|
@ -34,7 +34,7 @@ class APIFrameHelper:
|
|||
"_connection",
|
||||
"_transport",
|
||||
"_writer",
|
||||
"_ready_future",
|
||||
"ready_future",
|
||||
"_buffer",
|
||||
"_buffer_len",
|
||||
"_pos",
|
||||
|
@ -54,7 +54,7 @@ class APIFrameHelper:
|
|||
self._connection = connection
|
||||
self._transport: asyncio.Transport | None = None
|
||||
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
|
||||
self._ready_future = self._loop.create_future()
|
||||
self.ready_future = self._loop.create_future()
|
||||
self._buffer: bytes | None = None
|
||||
self._buffer_len = 0
|
||||
self._pos = 0
|
||||
|
@ -65,9 +65,9 @@ class APIFrameHelper:
|
|||
"""Set the log name."""
|
||||
self._log_name = log_name
|
||||
|
||||
def _set_ready_future_exception(self, exc: Exception | type[Exception]) -> None:
|
||||
if not self._ready_future.done():
|
||||
self._ready_future.set_exception(exc)
|
||||
def _setready_future_exception(self, exc: Exception | type[Exception]) -> None:
|
||||
if not self.ready_future.done():
|
||||
self.ready_future.set_exception(exc)
|
||||
|
||||
def _add_to_buffer(self, data: bytes | bytearray | memoryview) -> None:
|
||||
"""Add data to the buffer."""
|
||||
|
@ -138,7 +138,7 @@ class APIFrameHelper:
|
|||
|
||||
def get_handshake_future(self) -> None:
|
||||
"""Get the handshake future."""
|
||||
return self._ready_future
|
||||
return self.ready_future
|
||||
|
||||
@abstractmethod
|
||||
def write_packets(
|
||||
|
@ -159,7 +159,7 @@ class APIFrameHelper:
|
|||
self.close()
|
||||
|
||||
def _handle_error(self, exc: Exception) -> None:
|
||||
self._set_ready_future_exception(exc)
|
||||
self._setready_future_exception(exc)
|
||||
self._connection.report_fatal_error(exc)
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
|
|
|
@ -104,7 +104,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||
# Make sure we set the ready event if its not already set
|
||||
# so that we don't block forever on the ready event if we
|
||||
# are waiting for the handshake to complete.
|
||||
self._set_ready_future_exception(
|
||||
self._setready_future_exception(
|
||||
APIConnectionError(f"{self._log_name}: Connection closed")
|
||||
)
|
||||
self._state = NOISE_STATE_CLOSED
|
||||
|
@ -279,7 +279,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||
noise_protocol.cipher_state_encrypt.encrypt_with_ad, # pylint: disable=no-member
|
||||
None,
|
||||
)
|
||||
self._ready_future.set_result(None)
|
||||
self.ready_future.set_result(None)
|
||||
|
||||
def write_packets(
|
||||
self, packets: list[tuple[int, bytes]], debug_enabled: bool
|
||||
|
|
|
@ -39,7 +39,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
"""Handle a new connection."""
|
||||
super().connection_made(transport)
|
||||
self._ready_future.set_result(None)
|
||||
self.ready_future.set_result(None)
|
||||
|
||||
def write_packets(
|
||||
self, packets: list[tuple[int, bytes]], debug_enabled: bool
|
||||
|
|
|
@ -400,7 +400,7 @@ class APIConnection:
|
|||
# Set the frame helper right away to ensure
|
||||
# the socket gets closed if we fail to handshake
|
||||
self._frame_helper = fh
|
||||
future = self._frame_helper.get_handshake_future()
|
||||
future = self._frame_helper.ready_future
|
||||
handshake_handle = self._loop.call_at(
|
||||
self._loop.time() + HANDSHAKE_TIMEOUT, handle_timeout, future
|
||||
)
|
||||
|
|
|
@ -312,7 +312,7 @@ async def test_noise_protector_event_loop(byte_type: Any) -> None:
|
|||
mock_data_received(helper, byte_type(bytes.fromhex(pkt)))
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
await helper.perform_handshake(30)
|
||||
await helper.ready_future
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -343,7 +343,7 @@ async def test_noise_frame_helper_incorrect_key():
|
|||
mock_data_received(helper, bytes.fromhex(pkt))
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
await helper.perform_handshake(30)
|
||||
await helper.ready_future
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -376,7 +376,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
|
|||
mock_data_received(helper, in_pkt[i : i + 1])
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
await helper.perform_handshake(30)
|
||||
await helper.ready_future
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -407,7 +407,7 @@ async def test_noise_incorrect_name():
|
|||
mock_data_received(helper, bytes.fromhex(pkt))
|
||||
|
||||
with pytest.raises(BadNameAPIError):
|
||||
await helper.perform_handshake(30)
|
||||
await helper.ready_future
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -431,7 +431,7 @@ async def test_noise_timeout():
|
|||
for pkt in outgoing_packets:
|
||||
helper.mock_write_frame(bytes.fromhex(pkt))
|
||||
|
||||
task = asyncio.create_task(helper.perform_handshake(30))
|
||||
task = asyncio.create_task(helper.ready_future)
|
||||
await asyncio.sleep(0)
|
||||
async_fire_time_changed(utcnow() + timedelta(seconds=60))
|
||||
await asyncio.sleep(0)
|
||||
|
@ -478,7 +478,7 @@ async def test_noise_frame_helper_handshake_failure():
|
|||
|
||||
proto = _mock_responder_proto(psk_bytes)
|
||||
|
||||
handshake_task = asyncio.create_task(helper.perform_handshake(30))
|
||||
handshake_task = asyncio.create_task(helper.ready_future)
|
||||
await asyncio.sleep(0) # let the task run to read the hello packet
|
||||
|
||||
assert len(writes) == 1
|
||||
|
@ -528,7 +528,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
|||
|
||||
proto = _mock_responder_proto(psk_bytes)
|
||||
|
||||
handshake_task = asyncio.create_task(helper.perform_handshake(30))
|
||||
handshake_task = asyncio.create_task(helper.ready_future)
|
||||
await asyncio.sleep(0) # let the task run to read the hello packet
|
||||
|
||||
assert len(writes) == 1
|
||||
|
@ -591,7 +591,7 @@ async def test_noise_frame_helper_bad_encryption(
|
|||
|
||||
proto = _mock_responder_proto(psk_bytes)
|
||||
|
||||
handshake_task = asyncio.create_task(helper.perform_handshake(30))
|
||||
handshake_task = asyncio.create_task(helper.ready_future)
|
||||
await asyncio.sleep(0) # let the task run to read the hello packet
|
||||
|
||||
assert len(writes) == 1
|
||||
|
@ -638,7 +638,7 @@ async def test_init_plaintext_with_wrong_preamble(conn: APIConnection):
|
|||
|
||||
conn._socket = MagicMock()
|
||||
await conn._connect_init_frame_helper()
|
||||
loop.call_soon(conn._frame_helper._ready_future.set_result, None)
|
||||
loop.call_soon(conn._frame_helper.ready_future.set_result, None)
|
||||
conn.connection_state = ConnectionState.CONNECTED
|
||||
|
||||
task = asyncio.create_task(conn._connect_hello_login(login=True))
|
||||
|
@ -687,7 +687,7 @@ async def test_noise_frame_helper_empty_hello():
|
|||
log_name="test",
|
||||
)
|
||||
|
||||
handshake_task = asyncio.create_task(helper.perform_handshake(30))
|
||||
handshake_task = asyncio.create_task(helper.ready_future)
|
||||
hello_pkt_with_header = _make_noise_hello_pkt(b"")
|
||||
|
||||
mock_data_received(helper, hello_pkt_with_header)
|
||||
|
@ -708,7 +708,7 @@ async def test_noise_frame_helper_wrong_protocol():
|
|||
log_name="test",
|
||||
)
|
||||
|
||||
handshake_task = asyncio.create_task(helper.perform_handshake(30))
|
||||
handshake_task = asyncio.create_task(helper.ready_future)
|
||||
# wrong protocol 5 instead of 1
|
||||
hello_pkt_with_header = _make_noise_hello_pkt(b"\x05servicetest\0")
|
||||
|
||||
|
|
|
@ -161,7 +161,7 @@ async def test_requires_encryption_propagates(conn: APIConnection):
|
|||
|
||||
conn._socket = MagicMock()
|
||||
await conn._connect_init_frame_helper()
|
||||
loop.call_soon(conn._frame_helper._ready_future.set_result, None)
|
||||
loop.call_soon(conn._frame_helper.ready_future.set_result, None)
|
||||
conn.connection_state = ConnectionState.CONNECTED
|
||||
|
||||
with pytest.raises(RequiresEncryptionAPIError):
|
||||
|
|
Loading…
Reference in New Issue