mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-01-12 20:10:42 +01:00
Improve error reporting for authenticated vs non-authenticated requests (#481)
This commit is contained in:
parent
24cddc22a8
commit
ed0a611994
@ -316,30 +316,33 @@ class APIClient:
|
||||
else:
|
||||
await self._connection.disconnect()
|
||||
|
||||
def _check_connected(self) -> None:
|
||||
if self._connection is None:
|
||||
raise APIConnectionError(f"Not connected to {self._log_name}!")
|
||||
if not self._connection.is_connected:
|
||||
raise APIConnectionError(
|
||||
f"Connection not done for {self._log_name}; "
|
||||
f"current state is {self._connection.connection_state}!"
|
||||
)
|
||||
|
||||
def _check_authenticated(self) -> None:
|
||||
self._check_connected()
|
||||
assert self._connection is not None
|
||||
if not self._connection.is_authenticated:
|
||||
connection = self._connection
|
||||
if not connection:
|
||||
raise APIConnectionError(f"Not connected to {self._log_name}!")
|
||||
if not connection.is_connected:
|
||||
raise APIConnectionError(
|
||||
f"Authenticated connection not ready yet for {self._log_name}; "
|
||||
f"current state is {connection.connection_state}!"
|
||||
)
|
||||
if not connection.is_authenticated:
|
||||
raise APIConnectionError(f"Not authenticated for {self._log_name}!")
|
||||
|
||||
async def device_info(self) -> DeviceInfo:
|
||||
self._check_connected()
|
||||
assert self._connection is not None
|
||||
resp = await self._connection.send_message_await_response(
|
||||
connection = self._connection
|
||||
if not connection:
|
||||
raise APIConnectionError(f"Not connected to {self._log_name}!")
|
||||
if not connection or not connection.is_connected:
|
||||
raise APIConnectionError(
|
||||
f"Connection not ready yet for {self._log_name}; "
|
||||
f"current state is {connection.connection_state}!"
|
||||
)
|
||||
resp = await connection.send_message_await_response(
|
||||
DeviceInfoRequest(), DeviceInfoResponse
|
||||
)
|
||||
info = DeviceInfo.from_pb(resp)
|
||||
self._cached_name = info.name
|
||||
self._connection.set_log_name(self._log_name)
|
||||
connection.set_log_name(self._log_name)
|
||||
return info
|
||||
|
||||
async def list_entities_services(
|
||||
|
@ -134,7 +134,6 @@ class APIConnection:
|
||||
"_frame_helper",
|
||||
"_api_version",
|
||||
"_connection_state",
|
||||
"_is_authenticated",
|
||||
"_connect_complete",
|
||||
"_message_handlers",
|
||||
"log_name",
|
||||
@ -148,6 +147,8 @@ class APIConnection:
|
||||
"_expected_disconnect",
|
||||
"_loop",
|
||||
"_send_pending_ping",
|
||||
"is_connected",
|
||||
"is_authenticated",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@ -164,7 +165,6 @@ class APIConnection:
|
||||
self._api_version: Optional[APIVersion] = None
|
||||
|
||||
self._connection_state = ConnectionState.INITIALIZED
|
||||
self._is_authenticated = False
|
||||
# Store whether connect() has completed
|
||||
# Used so that on_stop is _not_ called if an error occurs during connect()
|
||||
self._connect_complete = False
|
||||
@ -187,6 +187,8 @@ class APIConnection:
|
||||
self._expected_disconnect = False
|
||||
self._send_pending_ping = False
|
||||
self._loop = asyncio.get_event_loop()
|
||||
self.is_connected = False
|
||||
self.is_authenticated = False
|
||||
|
||||
@property
|
||||
def connection_state(self) -> ConnectionState:
|
||||
@ -330,7 +332,7 @@ class APIConnection:
|
||||
)
|
||||
|
||||
self._frame_helper = fh
|
||||
self._connection_state = ConnectionState.SOCKET_OPENED
|
||||
self._set_connection_state(ConnectionState.SOCKET_OPENED)
|
||||
try:
|
||||
async with async_timeout.timeout(HANDSHAKE_TIMEOUT):
|
||||
await fh.perform_handshake()
|
||||
@ -466,19 +468,24 @@ class APIConnection:
|
||||
except asyncio.CancelledError:
|
||||
# If the task was cancelled, we need to clean up the connection
|
||||
# and raise the CancelledError
|
||||
self._connection_state = ConnectionState.CLOSED
|
||||
self._set_connection_state(ConnectionState.CLOSED)
|
||||
self._cleanup()
|
||||
raise self._fatal_exception or APIConnectionError("Connection cancelled")
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# Always clean up the connection if an error occurred during connect
|
||||
self._connection_state = ConnectionState.CLOSED
|
||||
self._set_connection_state(ConnectionState.CLOSED)
|
||||
self._cleanup()
|
||||
raise
|
||||
|
||||
self._connect_task = None
|
||||
self._connection_state = ConnectionState.CONNECTED
|
||||
self._set_connection_state(ConnectionState.CONNECTED)
|
||||
self._connect_complete = True
|
||||
|
||||
def _set_connection_state(self, state: ConnectionState) -> None:
|
||||
"""Set the connection state and log the change."""
|
||||
self._connection_state = state
|
||||
self.is_connected = state == ConnectionState.CONNECTED
|
||||
|
||||
async def login(self, check_connected: bool = True) -> None:
|
||||
"""Send a login (ConnectRequest) and await the response."""
|
||||
if check_connected and self._connection_state != ConnectionState.CONNECTED:
|
||||
@ -486,7 +493,7 @@ class APIConnection:
|
||||
# because we don't set the connection state until after login
|
||||
# is complete
|
||||
raise APIConnectionError("Must be connected!")
|
||||
if self._is_authenticated:
|
||||
if self.is_authenticated:
|
||||
raise APIConnectionError("Already logged in!")
|
||||
|
||||
connect = ConnectRequest()
|
||||
@ -507,7 +514,7 @@ class APIConnection:
|
||||
if resp.invalid_password:
|
||||
raise InvalidAuthAPIError("Invalid password!")
|
||||
|
||||
self._is_authenticated = True
|
||||
self.is_authenticated = True
|
||||
|
||||
@property
|
||||
def _is_socket_open(self) -> bool:
|
||||
@ -516,14 +523,6 @@ class APIConnection:
|
||||
ConnectionState.CONNECTED,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._connection_state == ConnectionState.CONNECTED
|
||||
|
||||
@property
|
||||
def is_authenticated(self) -> bool:
|
||||
return self.is_connected and self._is_authenticated
|
||||
|
||||
def send_message(self, msg: message.Message) -> None:
|
||||
"""Send a protobuf message to the remote."""
|
||||
if self._connection_state not in (
|
||||
@ -702,7 +701,7 @@ class APIConnection:
|
||||
exc_info=not str(err), # Log the full stack on empty error string
|
||||
)
|
||||
self._fatal_exception = err
|
||||
self._connection_state = ConnectionState.CLOSED
|
||||
self._set_connection_state(ConnectionState.CLOSED)
|
||||
for fut in self._read_exception_futures:
|
||||
if fut.done():
|
||||
continue
|
||||
@ -785,7 +784,7 @@ class APIConnection:
|
||||
|
||||
if msg_type is DisconnectRequest:
|
||||
self.send_message(DisconnectResponse())
|
||||
self._connection_state = ConnectionState.CLOSED
|
||||
self._set_connection_state(ConnectionState.CLOSED)
|
||||
self._expected_disconnect = True
|
||||
self._cleanup()
|
||||
elif msg_type is PingRequest:
|
||||
@ -818,11 +817,11 @@ class APIConnection:
|
||||
except APIConnectionError:
|
||||
pass
|
||||
|
||||
self._connection_state = ConnectionState.CLOSED
|
||||
self._set_connection_state(ConnectionState.CLOSED)
|
||||
self._cleanup()
|
||||
|
||||
async def force_disconnect(self) -> None:
|
||||
self._connection_state = ConnectionState.CLOSED
|
||||
self._set_connection_state(ConnectionState.CLOSED)
|
||||
self._expected_disconnect = True
|
||||
self._cleanup()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user