diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 01e4c3c..2069601 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -92,6 +92,7 @@ CONNECT_AND_SETUP_TIMEOUT = 120.0 # the esp device DISCONNECT_WAIT_CONNECT_TIMEOUT = 5.0 + in_do_connect: contextvars.ContextVar[Optional[bool]] = contextvars.ContextVar( "in_do_connect" ) @@ -119,6 +120,9 @@ class ConnectionState(enum.Enum): CLOSED = 3 +OPEN_STATES = {ConnectionState.SOCKET_OPENED, ConnectionState.CONNECTED} + + class APIConnection: """This class represents _one_ connection to a remote native API device. @@ -132,7 +136,7 @@ class APIConnection: "_on_stop_task", "_socket", "_frame_helper", - "_api_version", + "api_version", "_connection_state", "_connect_complete", "_message_handlers", @@ -149,6 +153,7 @@ class APIConnection: "_send_pending_ping", "is_connected", "is_authenticated", + "_is_socket_open", ) def __init__( @@ -162,7 +167,7 @@ class APIConnection: self._on_stop_task: Optional[asyncio.Task[None]] = None self._socket: Optional[socket.socket] = None self._frame_helper: Optional[APIFrameHelper] = None - self._api_version: Optional[APIVersion] = None + self.api_version: Optional[APIVersion] = None self._connection_state = ConnectionState.INITIALIZED # Store whether connect() has completed @@ -189,6 +194,7 @@ class APIConnection: self._loop = asyncio.get_event_loop() self.is_connected = False self.is_authenticated = False + self._is_socket_open = False @property def connection_state(self) -> ConnectionState: @@ -278,13 +284,14 @@ class APIConnection: err, ) - _LOGGER.debug( - "%s: Connecting to %s:%s (%s)", - self.log_name, - self._params.address, - self._params.port, - addr, - ) + if _LOGGER.isEnabledFor(logging.DEBUG): + _LOGGER.debug( + "%s: Connecting to %s:%s (%s)", + self.log_name, + self._params.address, + self._params.port, + addr, + ) sockaddr = astuple(addr.sockaddr) try: @@ -361,15 +368,16 @@ class APIConnection: resp.api_version_major, resp.api_version_minor, ) - self._api_version = APIVersion(resp.api_version_major, resp.api_version_minor) - if self._api_version.major > 2: + api_version = APIVersion(resp.api_version_major, resp.api_version_minor) + if api_version.major > 2: _LOGGER.error( "%s: Incompatible version %s! Closing connection", self.log_name, - self._api_version.major, + api_version.major, ) raise APIConnectionError("Incompatible API version.") + self.api_version = api_version if ( self._params.expected_name is not None and resp.name != "" @@ -440,26 +448,25 @@ class APIConnection: ) ) + async def _do_connect(self, login: bool) -> None: + """Do the actual connect process.""" + in_do_connect.set(True) + addr = await self._connect_resolve_host() + await self._connect_socket_connect(addr) + await self._connect_init_frame_helper() + await self._connect_hello() + if login: + await self.login(check_connected=False) + self._async_schedule_keep_alive() + async def connect(self, *, login: bool) -> None: if self._connection_state != ConnectionState.INITIALIZED: raise ValueError( "Connection can only be used once, connection is not in init state" ) - - async def _do_connect() -> None: - in_do_connect.set(True) - addr = await self._connect_resolve_host() - await self._connect_socket_connect(addr) - await self._connect_init_frame_helper() - await self._connect_hello() - self._async_schedule_keep_alive() - if login: - await self.login(check_connected=False) - self._connect_task = asyncio.create_task( - _do_connect(), name=f"{self.log_name}: aioesphomeapi do_connect" + self._do_connect(login), name=f"{self.log_name}: aioesphomeapi do_connect" ) - try: # Allow 2 minutes for connect and setup; this is only as a last measure # to protect from issues if some part of the connect process mistakenly @@ -486,10 +493,11 @@ class APIConnection: """Set the connection state and log the change.""" self._connection_state = state self.is_connected = state == ConnectionState.CONNECTED + self._is_socket_open = state in OPEN_STATES async def login(self, check_connected: bool = True) -> None: """Send a login (ConnectRequest) and await the response.""" - if check_connected and self._connection_state != ConnectionState.CONNECTED: + if check_connected and self.is_connected: # On first connect, we don't want to check if we're connected # because we don't set the connection state until after login # is complete @@ -517,19 +525,9 @@ class APIConnection: self.is_authenticated = True - @property - def _is_socket_open(self) -> bool: - return self._connection_state in ( - ConnectionState.SOCKET_OPENED, - ConnectionState.CONNECTED, - ) - def send_message(self, msg: message.Message) -> None: """Send a protobuf message to the remote.""" - if self._connection_state not in ( - ConnectionState.SOCKET_OPENED, - ConnectionState.CONNECTED, - ): + if not self._is_socket_open: if in_do_connect.get(False): # If we are in the do_connect task, we can't raise an error # because it would obscure the original exception (ie encrypt error). @@ -720,6 +718,7 @@ class APIConnection: is_enabled_for = _LOGGER.isEnabledFor logging_debug = logging.DEBUG message_handlers = self._message_handlers + internal_message_types = INTERNAL_MESSAGE_TYPES def _process_packet(msg_type_proto: int, data: bytes) -> None: """Process a packet from the socket.""" @@ -785,7 +784,7 @@ class APIConnection: # Pre-check the message type to avoid awaiting # since most messages are not internal messages - if msg_type not in INTERNAL_MESSAGE_TYPES: + if msg_type not in internal_message_types: return if msg_type is DisconnectRequest: @@ -830,7 +829,3 @@ class APIConnection: self._set_connection_state(ConnectionState.CLOSED) self._expected_disconnect = True self._cleanup() - - @property - def api_version(self) -> Optional[APIVersion]: - return self._api_version