diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 365de5a..720dafa 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -135,6 +135,8 @@ class APIConnection: "_connect_task", "_fatal_exception", "_expected_disconnect", + "_loop", + "_send_pending_ping", ) def __init__( @@ -172,6 +174,8 @@ class APIConnection: self._connect_task: Optional[asyncio.Task[None]] = None self._fatal_exception: Optional[Exception] = None self._expected_disconnect = False + self._send_pending_ping = False + self._loop = asyncio.get_event_loop() @property def connection_state(self) -> ConnectionState: @@ -271,7 +275,7 @@ class APIConnection: sockaddr = astuple(addr.sockaddr) try: - coro = asyncio.get_event_loop().sock_connect(self._socket, sockaddr) + coro = self._loop.sock_connect(self._socket, sockaddr) async with async_timeout.timeout(TCP_CONNECT_TIMEOUT): await coro except OSError as err: @@ -290,7 +294,7 @@ class APIConnection: async def _connect_init_frame_helper(self) -> None: """Step 3 in connect process: initialize the frame helper and init read loop.""" fh: Union[APIPlaintextFrameHelper, APINoiseFrameHelper] - loop = asyncio.get_event_loop() + loop = self._loop if self._params.noise_psk is None: _, fh = await loop.create_connection( @@ -357,13 +361,10 @@ class APIConnection: f"Server sent a different name '{resp.name}'", resp.name ) - async def _connect_start_ping(self) -> None: - """Step 5 in connect process: start the ping loop.""" - self._async_schedule_keep_alive(asyncio.get_running_loop()) - - def _async_schedule_keep_alive(self, loop: asyncio.AbstractEventLoop) -> None: + def _async_schedule_keep_alive(self) -> None: """Start the keep alive task.""" - self._ping_timer = loop.call_later( + self._send_pending_ping = True + self._ping_timer = self._loop.call_later( self._keep_alive_interval, self._async_send_keep_alive ) @@ -372,14 +373,14 @@ class APIConnection: if not self._is_socket_open: return - loop = asyncio.get_running_loop() - self.send_message(PING_REQUEST_MESSAGE) + if self._send_pending_ping: + self.send_message(PING_REQUEST_MESSAGE) if self._pong_timer is None: # Do not reset the timer if it's already set # since the only thing we want to reset the timer # is if we receive a pong. - self._pong_timer = loop.call_later( + self._pong_timer = self._loop.call_later( self._keep_alive_timeout, self._async_pong_not_received ) else: @@ -399,7 +400,7 @@ class APIConnection: self._keep_alive_interval, ) - self._async_schedule_keep_alive(loop) + self._async_schedule_keep_alive() def _async_cancel_pong_timer(self) -> None: """Cancel the pong timer.""" @@ -434,7 +435,7 @@ class APIConnection: await self._connect_socket_connect(addr) await self._connect_init_frame_helper() await self._connect_hello() - await self._connect_start_ping() + self._async_schedule_keep_alive() if login: await self.login(check_connected=False) @@ -599,7 +600,7 @@ class APIConnection: :raises TimeoutAPIError: if a timeout occured """ - fut = asyncio.get_event_loop().create_future() + fut = self._loop.create_future() responses = [] def on_message(resp: message.Message) -> None: @@ -725,6 +726,11 @@ class APIConnection: # as we know the connection is still alive self._async_cancel_pong_timer() + if self._send_pending_ping: + # Any valid message from the remove cancels the pending ping + # since we know the connection is still alive + self._send_pending_ping = False + for handler in self._message_handlers.get(msg_type, [])[:]: handler(msg)