Cleanups to connect process (#485)

This commit is contained in:
J. Nick Koston 2023-07-17 10:11:34 -10:00 committed by GitHub
parent 92ec96469d
commit 0c1f710869
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -92,6 +92,7 @@ CONNECT_AND_SETUP_TIMEOUT = 120.0
# the esp device
DISCONNECT_WAIT_CONNECT_TIMEOUT = 5.0
in_do_connect: contextvars.ContextVar[Optional[bool]] = contextvars.ContextVar(
"in_do_connect"
)
@ -119,6 +120,9 @@ class ConnectionState(enum.Enum):
CLOSED = 3
OPEN_STATES = {ConnectionState.SOCKET_OPENED, ConnectionState.CONNECTED}
class APIConnection:
"""This class represents _one_ connection to a remote native API device.
@ -132,7 +136,7 @@ class APIConnection:
"_on_stop_task",
"_socket",
"_frame_helper",
"_api_version",
"api_version",
"_connection_state",
"_connect_complete",
"_message_handlers",
@ -149,6 +153,7 @@ class APIConnection:
"_send_pending_ping",
"is_connected",
"is_authenticated",
"_is_socket_open",
)
def __init__(
@ -162,7 +167,7 @@ class APIConnection:
self._on_stop_task: Optional[asyncio.Task[None]] = None
self._socket: Optional[socket.socket] = None
self._frame_helper: Optional[APIFrameHelper] = None
self._api_version: Optional[APIVersion] = None
self.api_version: Optional[APIVersion] = None
self._connection_state = ConnectionState.INITIALIZED
# Store whether connect() has completed
@ -189,6 +194,7 @@ class APIConnection:
self._loop = asyncio.get_event_loop()
self.is_connected = False
self.is_authenticated = False
self._is_socket_open = False
@property
def connection_state(self) -> ConnectionState:
@ -278,13 +284,14 @@ class APIConnection:
err,
)
_LOGGER.debug(
"%s: Connecting to %s:%s (%s)",
self.log_name,
self._params.address,
self._params.port,
addr,
)
if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug(
"%s: Connecting to %s:%s (%s)",
self.log_name,
self._params.address,
self._params.port,
addr,
)
sockaddr = astuple(addr.sockaddr)
try:
@ -361,15 +368,16 @@ class APIConnection:
resp.api_version_major,
resp.api_version_minor,
)
self._api_version = APIVersion(resp.api_version_major, resp.api_version_minor)
if self._api_version.major > 2:
api_version = APIVersion(resp.api_version_major, resp.api_version_minor)
if api_version.major > 2:
_LOGGER.error(
"%s: Incompatible version %s! Closing connection",
self.log_name,
self._api_version.major,
api_version.major,
)
raise APIConnectionError("Incompatible API version.")
self.api_version = api_version
if (
self._params.expected_name is not None
and resp.name != ""
@ -440,26 +448,25 @@ class APIConnection:
)
)
async def _do_connect(self, login: bool) -> None:
"""Do the actual connect process."""
in_do_connect.set(True)
addr = await self._connect_resolve_host()
await self._connect_socket_connect(addr)
await self._connect_init_frame_helper()
await self._connect_hello()
if login:
await self.login(check_connected=False)
self._async_schedule_keep_alive()
async def connect(self, *, login: bool) -> None:
if self._connection_state != ConnectionState.INITIALIZED:
raise ValueError(
"Connection can only be used once, connection is not in init state"
)
async def _do_connect() -> None:
in_do_connect.set(True)
addr = await self._connect_resolve_host()
await self._connect_socket_connect(addr)
await self._connect_init_frame_helper()
await self._connect_hello()
self._async_schedule_keep_alive()
if login:
await self.login(check_connected=False)
self._connect_task = asyncio.create_task(
_do_connect(), name=f"{self.log_name}: aioesphomeapi do_connect"
self._do_connect(login), name=f"{self.log_name}: aioesphomeapi do_connect"
)
try:
# Allow 2 minutes for connect and setup; this is only as a last measure
# to protect from issues if some part of the connect process mistakenly
@ -486,10 +493,11 @@ class APIConnection:
"""Set the connection state and log the change."""
self._connection_state = state
self.is_connected = state == ConnectionState.CONNECTED
self._is_socket_open = state in OPEN_STATES
async def login(self, check_connected: bool = True) -> None:
"""Send a login (ConnectRequest) and await the response."""
if check_connected and self._connection_state != ConnectionState.CONNECTED:
if check_connected and self.is_connected:
# On first connect, we don't want to check if we're connected
# because we don't set the connection state until after login
# is complete
@ -517,19 +525,9 @@ class APIConnection:
self.is_authenticated = True
@property
def _is_socket_open(self) -> bool:
return self._connection_state in (
ConnectionState.SOCKET_OPENED,
ConnectionState.CONNECTED,
)
def send_message(self, msg: message.Message) -> None:
"""Send a protobuf message to the remote."""
if self._connection_state not in (
ConnectionState.SOCKET_OPENED,
ConnectionState.CONNECTED,
):
if not self._is_socket_open:
if in_do_connect.get(False):
# If we are in the do_connect task, we can't raise an error
# because it would obscure the original exception (ie encrypt error).
@ -720,6 +718,7 @@ class APIConnection:
is_enabled_for = _LOGGER.isEnabledFor
logging_debug = logging.DEBUG
message_handlers = self._message_handlers
internal_message_types = INTERNAL_MESSAGE_TYPES
def _process_packet(msg_type_proto: int, data: bytes) -> None:
"""Process a packet from the socket."""
@ -785,7 +784,7 @@ class APIConnection:
# Pre-check the message type to avoid awaiting
# since most messages are not internal messages
if msg_type not in INTERNAL_MESSAGE_TYPES:
if msg_type not in internal_message_types:
return
if msg_type is DisconnectRequest:
@ -830,7 +829,3 @@ class APIConnection:
self._set_connection_state(ConnectionState.CLOSED)
self._expected_disconnect = True
self._cleanup()
@property
def api_version(self) -> Optional[APIVersion]:
return self._api_version