Remove in_do_connect contextvar (#652)
This commit is contained in:
parent
f88b15e33b
commit
0afa8c6832
|
@ -42,7 +42,7 @@ cdef object PingFailedAPIError
|
|||
cdef object ReadFailedAPIError
|
||||
cdef object TimeoutAPIError
|
||||
|
||||
cdef object in_do_connect, astuple
|
||||
cdef object astuple
|
||||
|
||||
|
||||
@cython.dataclasses.dataclass
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import enum
|
||||
import logging
|
||||
import socket
|
||||
|
@ -95,11 +94,6 @@ TCP_CONNECT_TIMEOUT = 60.0
|
|||
DISCONNECT_WAIT_CONNECT_TIMEOUT = 5.0
|
||||
|
||||
|
||||
in_do_connect: contextvars.ContextVar[bool | None] = contextvars.ContextVar(
|
||||
"in_do_connect"
|
||||
)
|
||||
|
||||
|
||||
_int = int
|
||||
_bytes = bytes
|
||||
_float = float
|
||||
|
@ -236,11 +230,19 @@ class APIConnection:
|
|||
# If we are being called from do_connect we
|
||||
# need to make sure we don't cancel the task
|
||||
# that called us
|
||||
if self._start_connect_task is not None and not in_do_connect.get(False):
|
||||
current_task = asyncio.current_task()
|
||||
|
||||
if (
|
||||
self._start_connect_task is not None
|
||||
and self._start_connect_task is not current_task
|
||||
):
|
||||
self._start_connect_task.cancel("Connection cleanup")
|
||||
self._start_connect_task = None
|
||||
|
||||
if self._finish_connect_task is not None and not in_do_connect.get(False):
|
||||
if (
|
||||
self._finish_connect_task is not None
|
||||
and self._finish_connect_task is not current_task
|
||||
):
|
||||
self._finish_connect_task.cancel("Connection cleanup")
|
||||
self._finish_connect_task = None
|
||||
|
||||
|
@ -512,7 +514,6 @@ class APIConnection:
|
|||
|
||||
async def _do_connect(self) -> None:
|
||||
"""Do the actual connect process."""
|
||||
in_do_connect.set(True)
|
||||
self.resolved_addr_info = await self._connect_resolve_host()
|
||||
await self._connect_socket_connect(self.resolved_addr_info)
|
||||
|
||||
|
@ -522,7 +523,7 @@ class APIConnection:
|
|||
This part of the process establishes the socket connection but
|
||||
does not initialize the frame helper or send the hello message.
|
||||
"""
|
||||
if self.connection_state != ConnectionState.INITIALIZED:
|
||||
if self.connection_state is not ConnectionState.INITIALIZED:
|
||||
raise ValueError(
|
||||
"Connection can only be used once, connection is not in init state"
|
||||
)
|
||||
|
@ -567,7 +568,6 @@ class APIConnection:
|
|||
|
||||
async def _do_finish_connect(self, login: bool) -> None:
|
||||
"""Finish the connection process."""
|
||||
in_do_connect.set(True)
|
||||
await self._connect_init_frame_helper()
|
||||
self._register_internal_message_handlers()
|
||||
await self._connect_hello_login(login)
|
||||
|
@ -579,7 +579,7 @@ class APIConnection:
|
|||
This part of the process initializes the frame helper and sends the hello message
|
||||
than starts the keep alive process.
|
||||
"""
|
||||
if self.connection_state != ConnectionState.SOCKET_OPENED:
|
||||
if self.connection_state is not ConnectionState.SOCKET_OPENED:
|
||||
raise ValueError(
|
||||
"Connection must be in SOCKET_OPENED state to finish connection"
|
||||
)
|
||||
|
@ -619,11 +619,6 @@ class APIConnection:
|
|||
def send_messages(self, msgs: tuple[message.Message, ...]) -> None:
|
||||
"""Send a protobuf message to the remote."""
|
||||
if not self._handshake_complete:
|
||||
if in_do_connect.get(False):
|
||||
# If we are in the do_connect task, we can't raise an error
|
||||
# because it would obscure the original exception (ie encrypt error).
|
||||
_LOGGER.debug("%s: Connection isn't established yet", self.log_name)
|
||||
return
|
||||
raise ConnectionNotEstablishedAPIError(
|
||||
f"Connection isn't established yet ({self.connection_state})"
|
||||
)
|
||||
|
|
|
@ -20,6 +20,7 @@ from aioesphomeapi.api_pb2 import (
|
|||
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
|
||||
from aioesphomeapi.core import (
|
||||
APIConnectionError,
|
||||
ConnectionNotEstablishedAPIError,
|
||||
HandshakeAPIError,
|
||||
InvalidAuthAPIError,
|
||||
RequiresEncryptionAPIError,
|
||||
|
@ -609,3 +610,9 @@ async def test_ping_does_not_disconnect_if_we_get_responses(
|
|||
|
||||
# We should disconnect if we are getting ping responses
|
||||
assert conn.is_connected is True
|
||||
|
||||
|
||||
def test_raise_during_send_messages_when_not_yet_connected(conn: APIConnection) -> None:
|
||||
"""Test that we raise when sending messages before we are connected."""
|
||||
with pytest.raises(ConnectionNotEstablishedAPIError):
|
||||
conn.send_message(PingRequest())
|
||||
|
|
Loading…
Reference in New Issue