From 0afa8c683222758d676acbd699c3fe6cd1e4da9e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 21 Nov 2023 15:36:43 +0100 Subject: [PATCH] Remove in_do_connect contextvar (#652) --- aioesphomeapi/connection.pxd | 2 +- aioesphomeapi/connection.py | 29 ++++++++++++----------------- tests/test_connection.py | 7 +++++++ 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/aioesphomeapi/connection.pxd b/aioesphomeapi/connection.pxd index dfb8694..c1a1005 100644 --- a/aioesphomeapi/connection.pxd +++ b/aioesphomeapi/connection.pxd @@ -42,7 +42,7 @@ cdef object PingFailedAPIError cdef object ReadFailedAPIError cdef object TimeoutAPIError -cdef object in_do_connect, astuple +cdef object astuple @cython.dataclasses.dataclass diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 232550c..f62b60e 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import contextvars import enum import logging import socket @@ -95,11 +94,6 @@ TCP_CONNECT_TIMEOUT = 60.0 DISCONNECT_WAIT_CONNECT_TIMEOUT = 5.0 -in_do_connect: contextvars.ContextVar[bool | None] = contextvars.ContextVar( - "in_do_connect" -) - - _int = int _bytes = bytes _float = float @@ -236,11 +230,19 @@ class APIConnection: # If we are being called from do_connect we # need to make sure we don't cancel the task # that called us - if self._start_connect_task is not None and not in_do_connect.get(False): + current_task = asyncio.current_task() + + if ( + self._start_connect_task is not None + and self._start_connect_task is not current_task + ): self._start_connect_task.cancel("Connection cleanup") self._start_connect_task = None - if self._finish_connect_task is not None and not in_do_connect.get(False): + if ( + self._finish_connect_task is not None + and self._finish_connect_task is not current_task + ): self._finish_connect_task.cancel("Connection cleanup") self._finish_connect_task = None @@ -512,7 +514,6 @@ class APIConnection: async def _do_connect(self) -> None: """Do the actual connect process.""" - in_do_connect.set(True) self.resolved_addr_info = await self._connect_resolve_host() await self._connect_socket_connect(self.resolved_addr_info) @@ -522,7 +523,7 @@ class APIConnection: This part of the process establishes the socket connection but does not initialize the frame helper or send the hello message. """ - if self.connection_state != ConnectionState.INITIALIZED: + if self.connection_state is not ConnectionState.INITIALIZED: raise ValueError( "Connection can only be used once, connection is not in init state" ) @@ -567,7 +568,6 @@ class APIConnection: async def _do_finish_connect(self, login: bool) -> None: """Finish the connection process.""" - in_do_connect.set(True) await self._connect_init_frame_helper() self._register_internal_message_handlers() await self._connect_hello_login(login) @@ -579,7 +579,7 @@ class APIConnection: This part of the process initializes the frame helper and sends the hello message than starts the keep alive process. """ - if self.connection_state != ConnectionState.SOCKET_OPENED: + if self.connection_state is not ConnectionState.SOCKET_OPENED: raise ValueError( "Connection must be in SOCKET_OPENED state to finish connection" ) @@ -619,11 +619,6 @@ class APIConnection: def send_messages(self, msgs: tuple[message.Message, ...]) -> None: """Send a protobuf message to the remote.""" if not self._handshake_complete: - if in_do_connect.get(False): - # If we are in the do_connect task, we can't raise an error - # because it would obscure the original exception (ie encrypt error). - _LOGGER.debug("%s: Connection isn't established yet", self.log_name) - return raise ConnectionNotEstablishedAPIError( f"Connection isn't established yet ({self.connection_state})" ) diff --git a/tests/test_connection.py b/tests/test_connection.py index 6fe816e..8475115 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -20,6 +20,7 @@ from aioesphomeapi.api_pb2 import ( from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState from aioesphomeapi.core import ( APIConnectionError, + ConnectionNotEstablishedAPIError, HandshakeAPIError, InvalidAuthAPIError, RequiresEncryptionAPIError, @@ -609,3 +610,9 @@ async def test_ping_does_not_disconnect_if_we_get_responses( # We should disconnect if we are getting ping responses assert conn.is_connected is True + + +def test_raise_during_send_messages_when_not_yet_connected(conn: APIConnection) -> None: + """Test that we raise when sending messages before we are connected.""" + with pytest.raises(ConnectionNotEstablishedAPIError): + conn.send_message(PingRequest())