Remove in_do_connect contextvar (#652)

This commit is contained in:
J. Nick Koston 2023-11-21 15:36:43 +01:00 committed by GitHub
parent f88b15e33b
commit 0afa8c6832
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 18 deletions

View File

@ -42,7 +42,7 @@ cdef object PingFailedAPIError
cdef object ReadFailedAPIError cdef object ReadFailedAPIError
cdef object TimeoutAPIError cdef object TimeoutAPIError
cdef object in_do_connect, astuple cdef object astuple
@cython.dataclasses.dataclass @cython.dataclasses.dataclass

View File

@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import contextvars
import enum import enum
import logging import logging
import socket import socket
@ -95,11 +94,6 @@ TCP_CONNECT_TIMEOUT = 60.0
DISCONNECT_WAIT_CONNECT_TIMEOUT = 5.0 DISCONNECT_WAIT_CONNECT_TIMEOUT = 5.0
in_do_connect: contextvars.ContextVar[bool | None] = contextvars.ContextVar(
"in_do_connect"
)
_int = int _int = int
_bytes = bytes _bytes = bytes
_float = float _float = float
@ -236,11 +230,19 @@ class APIConnection:
# If we are being called from do_connect we # If we are being called from do_connect we
# need to make sure we don't cancel the task # need to make sure we don't cancel the task
# that called us # that called us
if self._start_connect_task is not None and not in_do_connect.get(False): current_task = asyncio.current_task()
if (
self._start_connect_task is not None
and self._start_connect_task is not current_task
):
self._start_connect_task.cancel("Connection cleanup") self._start_connect_task.cancel("Connection cleanup")
self._start_connect_task = None self._start_connect_task = None
if self._finish_connect_task is not None and not in_do_connect.get(False): if (
self._finish_connect_task is not None
and self._finish_connect_task is not current_task
):
self._finish_connect_task.cancel("Connection cleanup") self._finish_connect_task.cancel("Connection cleanup")
self._finish_connect_task = None self._finish_connect_task = None
@ -512,7 +514,6 @@ class APIConnection:
async def _do_connect(self) -> None: async def _do_connect(self) -> None:
"""Do the actual connect process.""" """Do the actual connect process."""
in_do_connect.set(True)
self.resolved_addr_info = await self._connect_resolve_host() self.resolved_addr_info = await self._connect_resolve_host()
await self._connect_socket_connect(self.resolved_addr_info) await self._connect_socket_connect(self.resolved_addr_info)
@ -522,7 +523,7 @@ class APIConnection:
This part of the process establishes the socket connection but This part of the process establishes the socket connection but
does not initialize the frame helper or send the hello message. does not initialize the frame helper or send the hello message.
""" """
if self.connection_state != ConnectionState.INITIALIZED: if self.connection_state is not ConnectionState.INITIALIZED:
raise ValueError( raise ValueError(
"Connection can only be used once, connection is not in init state" "Connection can only be used once, connection is not in init state"
) )
@ -567,7 +568,6 @@ class APIConnection:
async def _do_finish_connect(self, login: bool) -> None: async def _do_finish_connect(self, login: bool) -> None:
"""Finish the connection process.""" """Finish the connection process."""
in_do_connect.set(True)
await self._connect_init_frame_helper() await self._connect_init_frame_helper()
self._register_internal_message_handlers() self._register_internal_message_handlers()
await self._connect_hello_login(login) await self._connect_hello_login(login)
@ -579,7 +579,7 @@ class APIConnection:
This part of the process initializes the frame helper and sends the hello message This part of the process initializes the frame helper and sends the hello message
than starts the keep alive process. than starts the keep alive process.
""" """
if self.connection_state != ConnectionState.SOCKET_OPENED: if self.connection_state is not ConnectionState.SOCKET_OPENED:
raise ValueError( raise ValueError(
"Connection must be in SOCKET_OPENED state to finish connection" "Connection must be in SOCKET_OPENED state to finish connection"
) )
@ -619,11 +619,6 @@ class APIConnection:
def send_messages(self, msgs: tuple[message.Message, ...]) -> None: def send_messages(self, msgs: tuple[message.Message, ...]) -> None:
"""Send a protobuf message to the remote.""" """Send a protobuf message to the remote."""
if not self._handshake_complete: if not self._handshake_complete:
if in_do_connect.get(False):
# If we are in the do_connect task, we can't raise an error
# because it would obscure the original exception (ie encrypt error).
_LOGGER.debug("%s: Connection isn't established yet", self.log_name)
return
raise ConnectionNotEstablishedAPIError( raise ConnectionNotEstablishedAPIError(
f"Connection isn't established yet ({self.connection_state})" f"Connection isn't established yet ({self.connection_state})"
) )

View File

@ -20,6 +20,7 @@ from aioesphomeapi.api_pb2 import (
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
from aioesphomeapi.core import ( from aioesphomeapi.core import (
APIConnectionError, APIConnectionError,
ConnectionNotEstablishedAPIError,
HandshakeAPIError, HandshakeAPIError,
InvalidAuthAPIError, InvalidAuthAPIError,
RequiresEncryptionAPIError, RequiresEncryptionAPIError,
@ -609,3 +610,9 @@ async def test_ping_does_not_disconnect_if_we_get_responses(
# We should disconnect if we are getting ping responses # We should disconnect if we are getting ping responses
assert conn.is_connected is True assert conn.is_connected is True
def test_raise_during_send_messages_when_not_yet_connected(conn: APIConnection) -> None:
"""Test that we raise when sending messages before we are connected."""
with pytest.raises(ConnectionNotEstablishedAPIError):
conn.send_message(PingRequest())