mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-04 09:19:37 +01:00
Cleanups to connect process (#485)
This commit is contained in:
parent
92ec96469d
commit
0c1f710869
@ -92,6 +92,7 @@ CONNECT_AND_SETUP_TIMEOUT = 120.0
|
||||
# the esp device
|
||||
DISCONNECT_WAIT_CONNECT_TIMEOUT = 5.0
|
||||
|
||||
|
||||
in_do_connect: contextvars.ContextVar[Optional[bool]] = contextvars.ContextVar(
|
||||
"in_do_connect"
|
||||
)
|
||||
@ -119,6 +120,9 @@ class ConnectionState(enum.Enum):
|
||||
CLOSED = 3
|
||||
|
||||
|
||||
OPEN_STATES = {ConnectionState.SOCKET_OPENED, ConnectionState.CONNECTED}
|
||||
|
||||
|
||||
class APIConnection:
|
||||
"""This class represents _one_ connection to a remote native API device.
|
||||
|
||||
@ -132,7 +136,7 @@ class APIConnection:
|
||||
"_on_stop_task",
|
||||
"_socket",
|
||||
"_frame_helper",
|
||||
"_api_version",
|
||||
"api_version",
|
||||
"_connection_state",
|
||||
"_connect_complete",
|
||||
"_message_handlers",
|
||||
@ -149,6 +153,7 @@ class APIConnection:
|
||||
"_send_pending_ping",
|
||||
"is_connected",
|
||||
"is_authenticated",
|
||||
"_is_socket_open",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@ -162,7 +167,7 @@ class APIConnection:
|
||||
self._on_stop_task: Optional[asyncio.Task[None]] = None
|
||||
self._socket: Optional[socket.socket] = None
|
||||
self._frame_helper: Optional[APIFrameHelper] = None
|
||||
self._api_version: Optional[APIVersion] = None
|
||||
self.api_version: Optional[APIVersion] = None
|
||||
|
||||
self._connection_state = ConnectionState.INITIALIZED
|
||||
# Store whether connect() has completed
|
||||
@ -189,6 +194,7 @@ class APIConnection:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
self.is_connected = False
|
||||
self.is_authenticated = False
|
||||
self._is_socket_open = False
|
||||
|
||||
@property
|
||||
def connection_state(self) -> ConnectionState:
|
||||
@ -278,13 +284,14 @@ class APIConnection:
|
||||
err,
|
||||
)
|
||||
|
||||
_LOGGER.debug(
|
||||
"%s: Connecting to %s:%s (%s)",
|
||||
self.log_name,
|
||||
self._params.address,
|
||||
self._params.port,
|
||||
addr,
|
||||
)
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug(
|
||||
"%s: Connecting to %s:%s (%s)",
|
||||
self.log_name,
|
||||
self._params.address,
|
||||
self._params.port,
|
||||
addr,
|
||||
)
|
||||
sockaddr = astuple(addr.sockaddr)
|
||||
|
||||
try:
|
||||
@ -361,15 +368,16 @@ class APIConnection:
|
||||
resp.api_version_major,
|
||||
resp.api_version_minor,
|
||||
)
|
||||
self._api_version = APIVersion(resp.api_version_major, resp.api_version_minor)
|
||||
if self._api_version.major > 2:
|
||||
api_version = APIVersion(resp.api_version_major, resp.api_version_minor)
|
||||
if api_version.major > 2:
|
||||
_LOGGER.error(
|
||||
"%s: Incompatible version %s! Closing connection",
|
||||
self.log_name,
|
||||
self._api_version.major,
|
||||
api_version.major,
|
||||
)
|
||||
raise APIConnectionError("Incompatible API version.")
|
||||
|
||||
self.api_version = api_version
|
||||
if (
|
||||
self._params.expected_name is not None
|
||||
and resp.name != ""
|
||||
@ -440,26 +448,25 @@ class APIConnection:
|
||||
)
|
||||
)
|
||||
|
||||
async def _do_connect(self, login: bool) -> None:
|
||||
"""Do the actual connect process."""
|
||||
in_do_connect.set(True)
|
||||
addr = await self._connect_resolve_host()
|
||||
await self._connect_socket_connect(addr)
|
||||
await self._connect_init_frame_helper()
|
||||
await self._connect_hello()
|
||||
if login:
|
||||
await self.login(check_connected=False)
|
||||
self._async_schedule_keep_alive()
|
||||
|
||||
async def connect(self, *, login: bool) -> None:
|
||||
if self._connection_state != ConnectionState.INITIALIZED:
|
||||
raise ValueError(
|
||||
"Connection can only be used once, connection is not in init state"
|
||||
)
|
||||
|
||||
async def _do_connect() -> None:
|
||||
in_do_connect.set(True)
|
||||
addr = await self._connect_resolve_host()
|
||||
await self._connect_socket_connect(addr)
|
||||
await self._connect_init_frame_helper()
|
||||
await self._connect_hello()
|
||||
self._async_schedule_keep_alive()
|
||||
if login:
|
||||
await self.login(check_connected=False)
|
||||
|
||||
self._connect_task = asyncio.create_task(
|
||||
_do_connect(), name=f"{self.log_name}: aioesphomeapi do_connect"
|
||||
self._do_connect(login), name=f"{self.log_name}: aioesphomeapi do_connect"
|
||||
)
|
||||
|
||||
try:
|
||||
# Allow 2 minutes for connect and setup; this is only as a last measure
|
||||
# to protect from issues if some part of the connect process mistakenly
|
||||
@ -486,10 +493,11 @@ class APIConnection:
|
||||
"""Set the connection state and log the change."""
|
||||
self._connection_state = state
|
||||
self.is_connected = state == ConnectionState.CONNECTED
|
||||
self._is_socket_open = state in OPEN_STATES
|
||||
|
||||
async def login(self, check_connected: bool = True) -> None:
|
||||
"""Send a login (ConnectRequest) and await the response."""
|
||||
if check_connected and self._connection_state != ConnectionState.CONNECTED:
|
||||
if check_connected and self.is_connected:
|
||||
# On first connect, we don't want to check if we're connected
|
||||
# because we don't set the connection state until after login
|
||||
# is complete
|
||||
@ -517,19 +525,9 @@ class APIConnection:
|
||||
|
||||
self.is_authenticated = True
|
||||
|
||||
@property
|
||||
def _is_socket_open(self) -> bool:
|
||||
return self._connection_state in (
|
||||
ConnectionState.SOCKET_OPENED,
|
||||
ConnectionState.CONNECTED,
|
||||
)
|
||||
|
||||
def send_message(self, msg: message.Message) -> None:
|
||||
"""Send a protobuf message to the remote."""
|
||||
if self._connection_state not in (
|
||||
ConnectionState.SOCKET_OPENED,
|
||||
ConnectionState.CONNECTED,
|
||||
):
|
||||
if not self._is_socket_open:
|
||||
if in_do_connect.get(False):
|
||||
# If we are in the do_connect task, we can't raise an error
|
||||
# because it would obscure the original exception (ie encrypt error).
|
||||
@ -720,6 +718,7 @@ class APIConnection:
|
||||
is_enabled_for = _LOGGER.isEnabledFor
|
||||
logging_debug = logging.DEBUG
|
||||
message_handlers = self._message_handlers
|
||||
internal_message_types = INTERNAL_MESSAGE_TYPES
|
||||
|
||||
def _process_packet(msg_type_proto: int, data: bytes) -> None:
|
||||
"""Process a packet from the socket."""
|
||||
@ -785,7 +784,7 @@ class APIConnection:
|
||||
|
||||
# Pre-check the message type to avoid awaiting
|
||||
# since most messages are not internal messages
|
||||
if msg_type not in INTERNAL_MESSAGE_TYPES:
|
||||
if msg_type not in internal_message_types:
|
||||
return
|
||||
|
||||
if msg_type is DisconnectRequest:
|
||||
@ -830,7 +829,3 @@ class APIConnection:
|
||||
self._set_connection_state(ConnectionState.CLOSED)
|
||||
self._expected_disconnect = True
|
||||
self._cleanup()
|
||||
|
||||
@property
|
||||
def api_version(self) -> Optional[APIVersion]:
|
||||
return self._api_version
|
||||
|
Loading…
Reference in New Issue
Block a user