Remove in_do_connect contextvar (#652)

This commit is contained in:
J. Nick Koston 2023-11-21 15:36:43 +01:00 committed by GitHub
parent f88b15e33b
commit 0afa8c6832
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 18 deletions

View File

@ -42,7 +42,7 @@ cdef object PingFailedAPIError
cdef object ReadFailedAPIError
cdef object TimeoutAPIError
cdef object in_do_connect, astuple
cdef object astuple
@cython.dataclasses.dataclass

View File

@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
import contextvars
import enum
import logging
import socket
@ -95,11 +94,6 @@ TCP_CONNECT_TIMEOUT = 60.0
DISCONNECT_WAIT_CONNECT_TIMEOUT = 5.0
in_do_connect: contextvars.ContextVar[bool | None] = contextvars.ContextVar(
"in_do_connect"
)
_int = int
_bytes = bytes
_float = float
@ -236,11 +230,19 @@ class APIConnection:
# If we are being called from do_connect we
# need to make sure we don't cancel the task
# that called us
if self._start_connect_task is not None and not in_do_connect.get(False):
current_task = asyncio.current_task()
if (
self._start_connect_task is not None
and self._start_connect_task is not current_task
):
self._start_connect_task.cancel("Connection cleanup")
self._start_connect_task = None
if self._finish_connect_task is not None and not in_do_connect.get(False):
if (
self._finish_connect_task is not None
and self._finish_connect_task is not current_task
):
self._finish_connect_task.cancel("Connection cleanup")
self._finish_connect_task = None
@ -512,7 +514,6 @@ class APIConnection:
async def _do_connect(self) -> None:
"""Do the actual connect process."""
in_do_connect.set(True)
self.resolved_addr_info = await self._connect_resolve_host()
await self._connect_socket_connect(self.resolved_addr_info)
@ -522,7 +523,7 @@ class APIConnection:
This part of the process establishes the socket connection but
does not initialize the frame helper or send the hello message.
"""
if self.connection_state != ConnectionState.INITIALIZED:
if self.connection_state is not ConnectionState.INITIALIZED:
raise ValueError(
"Connection can only be used once, connection is not in init state"
)
@ -567,7 +568,6 @@ class APIConnection:
async def _do_finish_connect(self, login: bool) -> None:
"""Finish the connection process."""
in_do_connect.set(True)
await self._connect_init_frame_helper()
self._register_internal_message_handlers()
await self._connect_hello_login(login)
@ -579,7 +579,7 @@ class APIConnection:
This part of the process initializes the frame helper and sends the hello message
than starts the keep alive process.
"""
if self.connection_state != ConnectionState.SOCKET_OPENED:
if self.connection_state is not ConnectionState.SOCKET_OPENED:
raise ValueError(
"Connection must be in SOCKET_OPENED state to finish connection"
)
@ -619,11 +619,6 @@ class APIConnection:
def send_messages(self, msgs: tuple[message.Message, ...]) -> None:
"""Send a protobuf message to the remote."""
if not self._handshake_complete:
if in_do_connect.get(False):
# If we are in the do_connect task, we can't raise an error
# because it would obscure the original exception (ie encrypt error).
_LOGGER.debug("%s: Connection isn't established yet", self.log_name)
return
raise ConnectionNotEstablishedAPIError(
f"Connection isn't established yet ({self.connection_state})"
)

View File

@ -20,6 +20,7 @@ from aioesphomeapi.api_pb2 import (
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
from aioesphomeapi.core import (
APIConnectionError,
ConnectionNotEstablishedAPIError,
HandshakeAPIError,
InvalidAuthAPIError,
RequiresEncryptionAPIError,
@ -609,3 +610,9 @@ async def test_ping_does_not_disconnect_if_we_get_responses(
# We should disconnect if we are getting ping responses
assert conn.is_connected is True
def test_raise_during_send_messages_when_not_yet_connected(conn: APIConnection) -> None:
"""Test that we raise when sending messages before we are connected."""
with pytest.raises(ConnectionNotEstablishedAPIError):
conn.send_message(PingRequest())