Improve error reporting for authenticated vs non-authenticated requests (#481)

This commit is contained in:
J. Nick Koston 2023-07-15 10:34:46 -10:00 committed by GitHub
parent 24cddc22a8
commit ed0a611994
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 36 deletions

View File

@ -316,30 +316,33 @@ class APIClient:
else: else:
await self._connection.disconnect() await self._connection.disconnect()
def _check_connected(self) -> None:
if self._connection is None:
raise APIConnectionError(f"Not connected to {self._log_name}!")
if not self._connection.is_connected:
raise APIConnectionError(
f"Connection not done for {self._log_name}; "
f"current state is {self._connection.connection_state}!"
)
def _check_authenticated(self) -> None: def _check_authenticated(self) -> None:
self._check_connected() connection = self._connection
assert self._connection is not None if not connection:
if not self._connection.is_authenticated: raise APIConnectionError(f"Not connected to {self._log_name}!")
if not connection.is_connected:
raise APIConnectionError(
f"Authenticated connection not ready yet for {self._log_name}; "
f"current state is {connection.connection_state}!"
)
if not connection.is_authenticated:
raise APIConnectionError(f"Not authenticated for {self._log_name}!") raise APIConnectionError(f"Not authenticated for {self._log_name}!")
async def device_info(self) -> DeviceInfo: async def device_info(self) -> DeviceInfo:
self._check_connected() connection = self._connection
assert self._connection is not None if not connection:
resp = await self._connection.send_message_await_response( raise APIConnectionError(f"Not connected to {self._log_name}!")
if not connection or not connection.is_connected:
raise APIConnectionError(
f"Connection not ready yet for {self._log_name}; "
f"current state is {connection.connection_state}!"
)
resp = await connection.send_message_await_response(
DeviceInfoRequest(), DeviceInfoResponse DeviceInfoRequest(), DeviceInfoResponse
) )
info = DeviceInfo.from_pb(resp) info = DeviceInfo.from_pb(resp)
self._cached_name = info.name self._cached_name = info.name
self._connection.set_log_name(self._log_name) connection.set_log_name(self._log_name)
return info return info
async def list_entities_services( async def list_entities_services(

View File

@ -134,7 +134,6 @@ class APIConnection:
"_frame_helper", "_frame_helper",
"_api_version", "_api_version",
"_connection_state", "_connection_state",
"_is_authenticated",
"_connect_complete", "_connect_complete",
"_message_handlers", "_message_handlers",
"log_name", "log_name",
@ -148,6 +147,8 @@ class APIConnection:
"_expected_disconnect", "_expected_disconnect",
"_loop", "_loop",
"_send_pending_ping", "_send_pending_ping",
"is_connected",
"is_authenticated",
) )
def __init__( def __init__(
@ -164,7 +165,6 @@ class APIConnection:
self._api_version: Optional[APIVersion] = None self._api_version: Optional[APIVersion] = None
self._connection_state = ConnectionState.INITIALIZED self._connection_state = ConnectionState.INITIALIZED
self._is_authenticated = False
# Store whether connect() has completed # Store whether connect() has completed
# Used so that on_stop is _not_ called if an error occurs during connect() # Used so that on_stop is _not_ called if an error occurs during connect()
self._connect_complete = False self._connect_complete = False
@ -187,6 +187,8 @@ class APIConnection:
self._expected_disconnect = False self._expected_disconnect = False
self._send_pending_ping = False self._send_pending_ping = False
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
self.is_connected = False
self.is_authenticated = False
@property @property
def connection_state(self) -> ConnectionState: def connection_state(self) -> ConnectionState:
@ -330,7 +332,7 @@ class APIConnection:
) )
self._frame_helper = fh self._frame_helper = fh
self._connection_state = ConnectionState.SOCKET_OPENED self._set_connection_state(ConnectionState.SOCKET_OPENED)
try: try:
async with async_timeout.timeout(HANDSHAKE_TIMEOUT): async with async_timeout.timeout(HANDSHAKE_TIMEOUT):
await fh.perform_handshake() await fh.perform_handshake()
@ -466,19 +468,24 @@ class APIConnection:
except asyncio.CancelledError: except asyncio.CancelledError:
# If the task was cancelled, we need to clean up the connection # If the task was cancelled, we need to clean up the connection
# and raise the CancelledError # and raise the CancelledError
self._connection_state = ConnectionState.CLOSED self._set_connection_state(ConnectionState.CLOSED)
self._cleanup() self._cleanup()
raise self._fatal_exception or APIConnectionError("Connection cancelled") raise self._fatal_exception or APIConnectionError("Connection cancelled")
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
# Always clean up the connection if an error occurred during connect # Always clean up the connection if an error occurred during connect
self._connection_state = ConnectionState.CLOSED self._set_connection_state(ConnectionState.CLOSED)
self._cleanup() self._cleanup()
raise raise
self._connect_task = None self._connect_task = None
self._connection_state = ConnectionState.CONNECTED self._set_connection_state(ConnectionState.CONNECTED)
self._connect_complete = True self._connect_complete = True
def _set_connection_state(self, state: ConnectionState) -> None:
"""Set the connection state and log the change."""
self._connection_state = state
self.is_connected = state == ConnectionState.CONNECTED
async def login(self, check_connected: bool = True) -> None: async def login(self, check_connected: bool = True) -> None:
"""Send a login (ConnectRequest) and await the response.""" """Send a login (ConnectRequest) and await the response."""
if check_connected and self._connection_state != ConnectionState.CONNECTED: if check_connected and self._connection_state != ConnectionState.CONNECTED:
@ -486,7 +493,7 @@ class APIConnection:
# because we don't set the connection state until after login # because we don't set the connection state until after login
# is complete # is complete
raise APIConnectionError("Must be connected!") raise APIConnectionError("Must be connected!")
if self._is_authenticated: if self.is_authenticated:
raise APIConnectionError("Already logged in!") raise APIConnectionError("Already logged in!")
connect = ConnectRequest() connect = ConnectRequest()
@ -507,7 +514,7 @@ class APIConnection:
if resp.invalid_password: if resp.invalid_password:
raise InvalidAuthAPIError("Invalid password!") raise InvalidAuthAPIError("Invalid password!")
self._is_authenticated = True self.is_authenticated = True
@property @property
def _is_socket_open(self) -> bool: def _is_socket_open(self) -> bool:
@ -516,14 +523,6 @@ class APIConnection:
ConnectionState.CONNECTED, ConnectionState.CONNECTED,
) )
@property
def is_connected(self) -> bool:
return self._connection_state == ConnectionState.CONNECTED
@property
def is_authenticated(self) -> bool:
return self.is_connected and self._is_authenticated
def send_message(self, msg: message.Message) -> None: def send_message(self, msg: message.Message) -> None:
"""Send a protobuf message to the remote.""" """Send a protobuf message to the remote."""
if self._connection_state not in ( if self._connection_state not in (
@ -702,7 +701,7 @@ class APIConnection:
exc_info=not str(err), # Log the full stack on empty error string exc_info=not str(err), # Log the full stack on empty error string
) )
self._fatal_exception = err self._fatal_exception = err
self._connection_state = ConnectionState.CLOSED self._set_connection_state(ConnectionState.CLOSED)
for fut in self._read_exception_futures: for fut in self._read_exception_futures:
if fut.done(): if fut.done():
continue continue
@ -785,7 +784,7 @@ class APIConnection:
if msg_type is DisconnectRequest: if msg_type is DisconnectRequest:
self.send_message(DisconnectResponse()) self.send_message(DisconnectResponse())
self._connection_state = ConnectionState.CLOSED self._set_connection_state(ConnectionState.CLOSED)
self._expected_disconnect = True self._expected_disconnect = True
self._cleanup() self._cleanup()
elif msg_type is PingRequest: elif msg_type is PingRequest:
@ -818,11 +817,11 @@ class APIConnection:
except APIConnectionError: except APIConnectionError:
pass pass
self._connection_state = ConnectionState.CLOSED self._set_connection_state(ConnectionState.CLOSED)
self._cleanup() self._cleanup()
async def force_disconnect(self) -> None: async def force_disconnect(self) -> None:
self._connection_state = ConnectionState.CLOSED self._set_connection_state(ConnectionState.CLOSED)
self._expected_disconnect = True self._expected_disconnect = True
self._cleanup() self._cleanup()