Remove in_do_connect contextvar (#652)
This commit is contained in:
parent
f88b15e33b
commit
0afa8c6832
|
@ -42,7 +42,7 @@ cdef object PingFailedAPIError
|
||||||
cdef object ReadFailedAPIError
|
cdef object ReadFailedAPIError
|
||||||
cdef object TimeoutAPIError
|
cdef object TimeoutAPIError
|
||||||
|
|
||||||
cdef object in_do_connect, astuple
|
cdef object astuple
|
||||||
|
|
||||||
|
|
||||||
@cython.dataclasses.dataclass
|
@cython.dataclasses.dataclass
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextvars
|
|
||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
|
@ -95,11 +94,6 @@ TCP_CONNECT_TIMEOUT = 60.0
|
||||||
DISCONNECT_WAIT_CONNECT_TIMEOUT = 5.0
|
DISCONNECT_WAIT_CONNECT_TIMEOUT = 5.0
|
||||||
|
|
||||||
|
|
||||||
in_do_connect: contextvars.ContextVar[bool | None] = contextvars.ContextVar(
|
|
||||||
"in_do_connect"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
_int = int
|
_int = int
|
||||||
_bytes = bytes
|
_bytes = bytes
|
||||||
_float = float
|
_float = float
|
||||||
|
@ -236,11 +230,19 @@ class APIConnection:
|
||||||
# If we are being called from do_connect we
|
# If we are being called from do_connect we
|
||||||
# need to make sure we don't cancel the task
|
# need to make sure we don't cancel the task
|
||||||
# that called us
|
# that called us
|
||||||
if self._start_connect_task is not None and not in_do_connect.get(False):
|
current_task = asyncio.current_task()
|
||||||
|
|
||||||
|
if (
|
||||||
|
self._start_connect_task is not None
|
||||||
|
and self._start_connect_task is not current_task
|
||||||
|
):
|
||||||
self._start_connect_task.cancel("Connection cleanup")
|
self._start_connect_task.cancel("Connection cleanup")
|
||||||
self._start_connect_task = None
|
self._start_connect_task = None
|
||||||
|
|
||||||
if self._finish_connect_task is not None and not in_do_connect.get(False):
|
if (
|
||||||
|
self._finish_connect_task is not None
|
||||||
|
and self._finish_connect_task is not current_task
|
||||||
|
):
|
||||||
self._finish_connect_task.cancel("Connection cleanup")
|
self._finish_connect_task.cancel("Connection cleanup")
|
||||||
self._finish_connect_task = None
|
self._finish_connect_task = None
|
||||||
|
|
||||||
|
@ -512,7 +514,6 @@ class APIConnection:
|
||||||
|
|
||||||
async def _do_connect(self) -> None:
|
async def _do_connect(self) -> None:
|
||||||
"""Do the actual connect process."""
|
"""Do the actual connect process."""
|
||||||
in_do_connect.set(True)
|
|
||||||
self.resolved_addr_info = await self._connect_resolve_host()
|
self.resolved_addr_info = await self._connect_resolve_host()
|
||||||
await self._connect_socket_connect(self.resolved_addr_info)
|
await self._connect_socket_connect(self.resolved_addr_info)
|
||||||
|
|
||||||
|
@ -522,7 +523,7 @@ class APIConnection:
|
||||||
This part of the process establishes the socket connection but
|
This part of the process establishes the socket connection but
|
||||||
does not initialize the frame helper or send the hello message.
|
does not initialize the frame helper or send the hello message.
|
||||||
"""
|
"""
|
||||||
if self.connection_state != ConnectionState.INITIALIZED:
|
if self.connection_state is not ConnectionState.INITIALIZED:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Connection can only be used once, connection is not in init state"
|
"Connection can only be used once, connection is not in init state"
|
||||||
)
|
)
|
||||||
|
@ -567,7 +568,6 @@ class APIConnection:
|
||||||
|
|
||||||
async def _do_finish_connect(self, login: bool) -> None:
|
async def _do_finish_connect(self, login: bool) -> None:
|
||||||
"""Finish the connection process."""
|
"""Finish the connection process."""
|
||||||
in_do_connect.set(True)
|
|
||||||
await self._connect_init_frame_helper()
|
await self._connect_init_frame_helper()
|
||||||
self._register_internal_message_handlers()
|
self._register_internal_message_handlers()
|
||||||
await self._connect_hello_login(login)
|
await self._connect_hello_login(login)
|
||||||
|
@ -579,7 +579,7 @@ class APIConnection:
|
||||||
This part of the process initializes the frame helper and sends the hello message
|
This part of the process initializes the frame helper and sends the hello message
|
||||||
than starts the keep alive process.
|
than starts the keep alive process.
|
||||||
"""
|
"""
|
||||||
if self.connection_state != ConnectionState.SOCKET_OPENED:
|
if self.connection_state is not ConnectionState.SOCKET_OPENED:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Connection must be in SOCKET_OPENED state to finish connection"
|
"Connection must be in SOCKET_OPENED state to finish connection"
|
||||||
)
|
)
|
||||||
|
@ -619,11 +619,6 @@ class APIConnection:
|
||||||
def send_messages(self, msgs: tuple[message.Message, ...]) -> None:
|
def send_messages(self, msgs: tuple[message.Message, ...]) -> None:
|
||||||
"""Send a protobuf message to the remote."""
|
"""Send a protobuf message to the remote."""
|
||||||
if not self._handshake_complete:
|
if not self._handshake_complete:
|
||||||
if in_do_connect.get(False):
|
|
||||||
# If we are in the do_connect task, we can't raise an error
|
|
||||||
# because it would obscure the original exception (ie encrypt error).
|
|
||||||
_LOGGER.debug("%s: Connection isn't established yet", self.log_name)
|
|
||||||
return
|
|
||||||
raise ConnectionNotEstablishedAPIError(
|
raise ConnectionNotEstablishedAPIError(
|
||||||
f"Connection isn't established yet ({self.connection_state})"
|
f"Connection isn't established yet ({self.connection_state})"
|
||||||
)
|
)
|
||||||
|
|
|
@ -20,6 +20,7 @@ from aioesphomeapi.api_pb2 import (
|
||||||
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
|
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
|
||||||
from aioesphomeapi.core import (
|
from aioesphomeapi.core import (
|
||||||
APIConnectionError,
|
APIConnectionError,
|
||||||
|
ConnectionNotEstablishedAPIError,
|
||||||
HandshakeAPIError,
|
HandshakeAPIError,
|
||||||
InvalidAuthAPIError,
|
InvalidAuthAPIError,
|
||||||
RequiresEncryptionAPIError,
|
RequiresEncryptionAPIError,
|
||||||
|
@ -609,3 +610,9 @@ async def test_ping_does_not_disconnect_if_we_get_responses(
|
||||||
|
|
||||||
# We should disconnect if we are getting ping responses
|
# We should disconnect if we are getting ping responses
|
||||||
assert conn.is_connected is True
|
assert conn.is_connected is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_raise_during_send_messages_when_not_yet_connected(conn: APIConnection) -> None:
|
||||||
|
"""Test that we raise when sending messages before we are connected."""
|
||||||
|
with pytest.raises(ConnectionNotEstablishedAPIError):
|
||||||
|
conn.send_message(PingRequest())
|
||||||
|
|
Loading…
Reference in New Issue