This commit is contained in:
J. Nick Koston 2023-11-27 23:42:34 -06:00
parent 783a52e95b
commit 8702ff1578
No known key found for this signature in database
2 changed files with 17 additions and 18 deletions

View File

@ -14,9 +14,9 @@ cdef object HANDSHAKE_TIMEOUT
cdef bint TYPE_CHECKING
cdef object DISCONNECT_REQUEST_MESSAGE
cdef object DISCONNECT_RESPONSE_MESSAGE
cdef object PING_REQUEST_MESSAGE
cdef object PING_RESPONSE_MESSAGE
cdef tuple DISCONNECT_RESPONSE_MESSAGES
cdef tuple PING_REQUEST_MESSAGES
cdef tuple PING_RESPONSE_MESSAGES
cdef object NO_PASSWORD_CONNECT_REQUEST
cdef object asyncio_timeout

View File

@ -63,9 +63,9 @@ _LOGGER = logging.getLogger(__name__)
BUFFER_SIZE = 1024 * 1024 * 2 # Set buffer limit to 2MB
DISCONNECT_REQUEST_MESSAGE = DisconnectRequest()
DISCONNECT_RESPONSE_MESSAGE = DisconnectResponse()
PING_REQUEST_MESSAGE = PingRequest()
PING_RESPONSE_MESSAGE = PingResponse()
DISCONNECT_RESPONSE_MESSAGES = (DisconnectResponse(),)
PING_REQUEST_MESSAGES = (PingRequest(),)
PING_RESPONSE_MESSAGES = (PingResponse(),)
NO_PASSWORD_CONNECT_REQUEST = ConnectRequest()
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
@ -370,12 +370,11 @@ class APIConnection:
async def _connect_init_frame_helper(self) -> None:
"""Step 3 in connect process: initialize the frame helper and init read loop."""
fh: APIPlaintextFrameHelper | APINoiseFrameHelper
loop = self._loop
if TYPE_CHECKING:
assert self._socket is not None
if (noise_psk := self._params.noise_psk) is None:
_, fh = await loop.create_connection( # type: ignore[type-var]
_, fh = await self._loop.create_connection( # type: ignore[type-var]
lambda: APIPlaintextFrameHelper(
connection=self,
client_info=self._params.client_info,
@ -384,7 +383,7 @@ class APIConnection:
sock=self._socket,
)
else:
_, fh = await loop.create_connection( # type: ignore[type-var]
_, fh = await self._loop.create_connection( # type: ignore[type-var]
lambda: APINoiseFrameHelper(
noise_psk=noise_psk,
expected_name=self._params.expected_name,
@ -486,16 +485,15 @@ class APIConnection:
def _async_send_keep_alive(self) -> None:
"""Send a keep alive message."""
loop = self._loop
now = loop.time()
now = self._loop.time()
if self._send_pending_ping:
self.send_messages((PING_REQUEST_MESSAGE,))
self.send_messages(PING_REQUEST_MESSAGES)
if self._pong_timer is None:
# Do not reset the timer if it's already set
# since the only thing we want to reset the timer
# is if we receive a pong.
self._pong_timer = loop.call_at(
self._pong_timer = self._loop.call_at(
now + self._keep_alive_timeout, self._async_pong_not_received
)
elif self._debug_enabled:
@ -738,9 +736,8 @@ class APIConnection:
# This is safe because we are not awaiting between
# sending the message and registering the handler
self.send_messages(messages)
loop = self._loop
# Unsafe to await between sending the message and registering the handler
fut: asyncio.Future[None] = loop.create_future()
fut: asyncio.Future[None] = self._loop.create_future()
responses: list[message.Message] = []
on_message = partial(
_handle_complex_message, fut, responses, do_append, do_stop
@ -753,7 +750,9 @@ class APIConnection:
# We must not await without a finally or
# the message could fail to be removed if the
# the await is cancelled
timeout_handle = loop.call_at(loop.time() + timeout, _handle_timeout, fut)
timeout_handle = self._loop.call_at(
self._loop.time() + timeout, _handle_timeout, fut
)
timeout_expired = False
try:
await fut
@ -893,14 +892,14 @@ class APIConnection:
# the response if for some reason sending the response
# fails we will still mark the disconnect as expected
self._expected_disconnect = True
self.send_messages((DISCONNECT_RESPONSE_MESSAGE,))
self.send_messages(DISCONNECT_RESPONSE_MESSAGES)
self._cleanup()
def _handle_ping_request_internal( # pylint: disable=unused-argument
self, _msg: PingRequest
) -> None:
"""Handle a PingRequest."""
self.send_messages((PING_RESPONSE_MESSAGE,))
self.send_messages(PING_RESPONSE_MESSAGES)
def _handle_get_time_request_internal( # pylint: disable=unused-argument
self, _msg: GetTimeRequest