From ac4374719605aefc8d57f7e97787792b4994cbed Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 22 Nov 2023 22:22:10 +0100 Subject: [PATCH] Refactor connection checks to return APIConnection to avoid many asserts (#660) --- aioesphomeapi/client.py | 234 ++++++++++++---------------------------- 1 file changed, 67 insertions(+), 167 deletions(-) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index fd30f70..ead3211 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -411,7 +411,7 @@ class APIClient: else: await self._connection.disconnect() - def _check_authenticated(self) -> None: + def _get_authenticated_connection(self) -> APIConnection: connection = self._connection if not connection: raise APIConnectionError(f"Not connected to {self.log_name}!") @@ -420,13 +420,10 @@ class APIClient: f"Authenticated connection not ready yet for {self.log_name}; " f"current state is {connection.connection_state}!" ) + return connection async def device_info(self) -> DeviceInfo: - self._check_authenticated() - connection = self._connection - if TYPE_CHECKING: - assert connection is not None - resp = await connection.send_message_await_response( + resp = await self._get_authenticated_connection().send_message_await_response( DeviceInfoRequest(), DeviceInfoResponse ) info = DeviceInfo.from_pb(resp) @@ -441,7 +438,6 @@ class APIClient: async def list_entities_services( self, ) -> tuple[list[EntityInfo], list[UserService]]: - self._check_authenticated() response_types = LIST_ENTITIES_SERVICES_RESPONSE_TYPES msg_types = LIST_ENTITIES_MSG_TYPES @@ -451,9 +447,7 @@ class APIClient: def do_stop(msg: message.Message) -> bool: return isinstance(msg, ListEntitiesDoneResponse) - if TYPE_CHECKING: - assert self._connection is not None - resp = await self._connection.send_messages_await_response_complex( + resp = await self._get_authenticated_connection().send_messages_await_response_complex( (ListEntitiesRequest(),), do_append, do_stop, msg_types, 60 ) entities: list[EntityInfo] = [] @@ -468,7 +462,6 @@ class APIClient: return entities, services async def subscribe_states(self, on_state: Callable[[EntityState], None]) -> None: - self._check_authenticated() image_stream: dict[int, list[bytes]] = {} response_types = SUBSCRIBE_STATES_RESPONSE_TYPES msg_types = SUBSCRIBE_STATES_MSG_TYPES @@ -494,9 +487,7 @@ class APIClient: del image_stream[msg_key] on_state(CameraState(key=msg.key, data=image_data)) # type: ignore[call-arg] - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message_callback_response( + self._get_authenticated_connection().send_message_callback_response( SubscribeStatesRequest(), _on_state_msg, msg_types ) @@ -506,31 +497,24 @@ class APIClient: log_level: LogLevel | None = None, dump_config: bool | None = None, ) -> None: - self._check_authenticated() req = SubscribeLogsRequest() if log_level is not None: req.level = log_level if dump_config is not None: req.dump_config = dump_config - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message_callback_response( + self._get_authenticated_connection().send_message_callback_response( req, on_log, (SubscribeLogsResponse,) ) async def subscribe_service_calls( self, on_service_call: Callable[[HomeassistantServiceCall], None] ) -> None: - self._check_authenticated() - def _on_home_assistant_service_response( msg: HomeassistantServiceResponse, ) -> None: on_service_call(HomeassistantServiceCall.from_pb(msg)) - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message_callback_response( + self._get_authenticated_connection().send_message_callback_response( SubscribeHomeassistantServicesRequest(), _on_home_assistant_service_response, (HomeassistantServiceResponse,), @@ -567,13 +551,10 @@ class APIClient: ), timeout: float = 10.0, ) -> message.Message: - self._check_authenticated() msg_types = (response_type, BluetoothGATTErrorResponse) - if TYPE_CHECKING: - assert self._connection is not None message_filter = partial(self._filter_bluetooth_message, address, handle) - resp = await self._connection.send_messages_await_response_complex( + resp = await self._get_authenticated_connection().send_messages_await_response_complex( (request,), message_filter, message_filter, msg_types, timeout ) @@ -585,7 +566,6 @@ class APIClient: async def subscribe_bluetooth_le_advertisements( self, on_bluetooth_le_advertisement: Callable[[BluetoothLEAdvertisement], None] ) -> Callable[[], None]: - self._check_authenticated() msg_types = (BluetoothLEAdvertisementResponse,) def _on_bluetooth_le_advertising_response( @@ -593,12 +573,12 @@ class APIClient: ) -> None: on_bluetooth_le_advertisement(BluetoothLEAdvertisement.from_pb(msg)) # type: ignore[misc] - if TYPE_CHECKING: - assert self._connection is not None - unsub_callback = self._connection.send_message_callback_response( - SubscribeBluetoothLEAdvertisementsRequest(flags=0), - _on_bluetooth_le_advertising_response, - msg_types, + unsub_callback = ( + self._get_authenticated_connection().send_message_callback_response( + SubscribeBluetoothLEAdvertisementsRequest(flags=0), + _on_bluetooth_le_advertising_response, + msg_types, + ) ) def unsub() -> None: @@ -613,23 +593,21 @@ class APIClient: async def subscribe_bluetooth_le_raw_advertisements( self, on_advertisements: Callable[[list[BluetoothLERawAdvertisement]], None] ) -> Callable[[], None]: - self._check_authenticated() msg_types = (BluetoothLERawAdvertisementsResponse,) - if TYPE_CHECKING: - assert self._connection is not None - def _on_ble_raw_advertisement_response( data: BluetoothLERawAdvertisementsResponse, ) -> None: on_advertisements(data.advertisements) - unsub_callback = self._connection.send_message_callback_response( - SubscribeBluetoothLEAdvertisementsRequest( - flags=BluetoothProxySubscriptionFlag.RAW_ADVERTISEMENTS - ), - _on_ble_raw_advertisement_response, - msg_types, + unsub_callback = ( + self._get_authenticated_connection().send_message_callback_response( + SubscribeBluetoothLEAdvertisementsRequest( + flags=BluetoothProxySubscriptionFlag.RAW_ADVERTISEMENTS + ), + _on_ble_raw_advertisement_response, + msg_types, + ) ) def unsub() -> None: @@ -644,7 +622,6 @@ class APIClient: async def subscribe_bluetooth_connections_free( self, on_bluetooth_connections_free_update: Callable[[int, int], None] ) -> Callable[[], None]: - self._check_authenticated() msg_types = (BluetoothConnectionsFreeResponse,) def _on_bluetooth_connections_free_response( @@ -652,9 +629,7 @@ class APIClient: ) -> None: on_bluetooth_connections_free_update(msg.free, msg.limit) - if TYPE_CHECKING: - assert self._connection is not None - return self._connection.send_message_callback_response( + return self._get_authenticated_connection().send_message_callback_response( SubscribeBluetoothConnectionsFreeRequest(), _on_bluetooth_connections_free_response, msg_types, @@ -692,14 +667,10 @@ class APIClient: has_cache: bool = False, address_type: int | None = None, ) -> Callable[[], None]: - self._check_authenticated() msg_types = (BluetoothDeviceConnectionResponse,) debug = _LOGGER.isEnabledFor(logging.DEBUG) connect_future: asyncio.Future[None] = self._loop.create_future() - if TYPE_CHECKING: - assert self._connection is not None - if has_cache: # REMOTE_CACHING feature with cache: requestor has services and mtu cached request_type = BluetoothDeviceRequestType.CONNECT_V3_WITH_CACHE @@ -714,7 +685,7 @@ class APIClient: if debug: _LOGGER.debug("%s: Using connection version %s", address, request_type) - unsub = self._connection.send_message_callback_response( + unsub = self._get_authenticated_connection().send_message_callback_response( BluetoothDeviceRequest( address=address, request_type=request_type, @@ -865,10 +836,9 @@ class APIClient: msg_types: tuple[type[message.Message], ...], timeout: float, ) -> message.Message: - self._check_authenticated() - if TYPE_CHECKING: - assert self._connection is not None - [response] = await self._connection.send_messages_await_response_complex( + [ + response + ] = await self._get_authenticated_connection().send_messages_await_response_complex( ( BluetoothDeviceRequest( address=address, @@ -885,7 +855,6 @@ class APIClient: async def bluetooth_gatt_get_services( self, address: int ) -> ESPHomeBluetoothGATTServices: - self._check_authenticated() msg_types = ( BluetoothGATTGetServicesResponse, BluetoothGATTGetServicesDoneResponse, @@ -900,9 +869,7 @@ class APIClient: def do_stop(msg: message.Message) -> bool: return isinstance(msg, stop_types) and msg.address == address - if TYPE_CHECKING: - assert self._connection is not None - resp = await self._connection.send_messages_await_response_complex( + resp = await self._get_authenticated_connection().send_messages_await_response_complex( (BluetoothGATTGetServicesRequest(address=address),), do_append, do_stop, @@ -946,9 +913,7 @@ class APIClient: req.data = data if not response: - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message(req) + self._get_authenticated_connection().send_message(req) return await self._send_bluetooth_message_await_response( @@ -1008,9 +973,7 @@ class APIClient: req.data = data if not wait_for_response: - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message(req) + self._get_authenticated_connection().send_message(req) return await self._send_bluetooth_message_await_response( @@ -1037,7 +1000,6 @@ class APIClient: callbacks without stopping the notify session on the remote device, which should be used when the connection is lost. """ - await self._send_bluetooth_message_await_response( address, handle, @@ -1051,9 +1013,7 @@ class APIClient: if address == msg.address and handle == msg.handle: on_bluetooth_gatt_notify(handle, bytearray(msg.data)) - if TYPE_CHECKING: - assert self._connection is not None - remove_callback = self._connection.add_message_callback( + remove_callback = self._get_authenticated_connection().add_message_callback( _on_bluetooth_gatt_notify_data_response, (BluetoothGATTNotifyDataResponse,) ) @@ -1063,8 +1023,6 @@ class APIClient: remove_callback() - self._check_authenticated() - self._connection.send_message( BluetoothGATTNotifyRequest(address=address, handle=handle, enable=False) ) @@ -1074,16 +1032,12 @@ class APIClient: async def subscribe_home_assistant_states( self, on_state_sub: Callable[[str, str | None], None] ) -> None: - self._check_authenticated() - def _on_subscribe_home_assistant_state_response( msg: SubscribeHomeAssistantStateResponse, ) -> None: on_state_sub(msg.entity_id, msg.attribute) - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message_callback_response( + self._get_authenticated_connection().send_message_callback_response( SubscribeHomeAssistantStatesRequest(), _on_subscribe_home_assistant_state_response, (SubscribeHomeAssistantStateResponse,), @@ -1092,11 +1046,7 @@ class APIClient: async def send_home_assistant_state( self, entity_id: str, attribute: str | None, state: str ) -> None: - self._check_authenticated() - - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message( + self._get_authenticated_connection().send_message( HomeAssistantStateResponse( entity_id=entity_id, state=state, @@ -1111,8 +1061,6 @@ class APIClient: tilt: float | None = None, stop: bool = False, ) -> None: - self._check_authenticated() - req = CoverCommandRequest() req.key = key apiv = cast(APIVersion, self.api_version) @@ -1135,9 +1083,8 @@ class APIClient: elif position == 0.0: req.legacy_command = LegacyCoverCommand.CLOSE req.has_legacy_command = True - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message(req) + + self._get_authenticated_connection().send_message(req) async def fan_command( self, @@ -1148,8 +1095,6 @@ class APIClient: oscillating: bool | None = None, direction: FanDirection | None = None, ) -> None: - self._check_authenticated() - req = FanCommandRequest() req.key = key if state is not None: @@ -1167,9 +1112,8 @@ class APIClient: if direction is not None: req.has_direction = True req.direction = direction - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message(req) + + self._get_authenticated_connection().send_message(req) async def light_command( # pylint: disable=too-many-branches self, @@ -1187,8 +1131,6 @@ class APIClient: flash_length: float | None = None, effect: str | None = None, ) -> None: - self._check_authenticated() - req = LightCommandRequest() req.key = key if state is not None: @@ -1229,19 +1171,15 @@ class APIClient: if effect is not None: req.has_effect = True req.effect = effect - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message(req) + + self._get_authenticated_connection().send_message(req) async def switch_command(self, key: int, state: bool) -> None: - self._check_authenticated() - req = SwitchCommandRequest() req.key = key req.state = state - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message(req) + + self._get_authenticated_connection().send_message(req) async def climate_command( self, @@ -1256,8 +1194,6 @@ class APIClient: preset: ClimatePreset | None = None, custom_preset: str | None = None, ) -> None: - self._check_authenticated() - req = ClimateCommandRequest() req.key = key if mode is not None: @@ -1292,29 +1228,22 @@ class APIClient: if custom_preset is not None: req.has_custom_preset = True req.custom_preset = custom_preset - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message(req) + + self._get_authenticated_connection().send_message(req) async def number_command(self, key: int, state: float) -> None: - self._check_authenticated() - req = NumberCommandRequest() req.key = key req.state = state - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message(req) + + self._get_authenticated_connection().send_message(req) async def select_command(self, key: int, state: str) -> None: - self._check_authenticated() - req = SelectCommandRequest() req.key = key req.state = state - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message(req) + + self._get_authenticated_connection().send_message(req) async def siren_command( self, @@ -1324,8 +1253,6 @@ class APIClient: volume: float | None = None, duration: int | None = None, ) -> None: - self._check_authenticated() - req = SirenCommandRequest() req.key = key if state is not None: @@ -1340,18 +1267,14 @@ class APIClient: if duration is not None: req.duration = duration req.has_duration = True - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message(req) + + self._get_authenticated_connection().send_message(req) async def button_command(self, key: int) -> None: - self._check_authenticated() - req = ButtonCommandRequest() req.key = key - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message(req) + + self._get_authenticated_connection().send_message(req) async def lock_command( self, @@ -1359,16 +1282,13 @@ class APIClient: command: LockCommand, code: str | None = None, ) -> None: - self._check_authenticated() - req = LockCommandRequest() req.key = key req.command = command if code is not None: req.code = code - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message(req) + + self._get_authenticated_connection().send_message(req) async def media_player_command( self, @@ -1378,8 +1298,6 @@ class APIClient: volume: float | None = None, media_url: str | None = None, ) -> None: - self._check_authenticated() - req = MediaPlayerCommandRequest() req.key = key if command is not None: @@ -1391,25 +1309,19 @@ class APIClient: if media_url is not None: req.media_url = media_url req.has_media_url = True - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message(req) + + self._get_authenticated_connection().send_message(req) async def text_command(self, key: int, state: str) -> None: - self._check_authenticated() - req = TextCommandRequest() req.key = key req.state = state - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message(req) + + self._get_authenticated_connection().send_message(req) async def execute_service( self, service: UserService, data: ExecuteServiceDataType ) -> None: - self._check_authenticated() - req = ExecuteServiceRequest() req.key = service.key args = [] @@ -1440,9 +1352,8 @@ class APIClient: args.append(arg) # pylint: disable=no-member req.args.extend(args) - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message(req) + + self._get_authenticated_connection().send_message(req) async def _request_image( self, *, single: bool = False, stream: bool = False @@ -1450,9 +1361,8 @@ class APIClient: req = CameraImageRequest() req.single = single req.stream = stream - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message(req) + + self._get_authenticated_connection().send_message(req) async def request_single_image(self) -> None: await self._request_image(single=True) @@ -1482,7 +1392,7 @@ class APIClient: Returns a callback to unsubscribe. """ - self._check_authenticated() + connection = self._get_authenticated_connection() start_task: asyncio.Task[int | None] | None = None @@ -1511,12 +1421,9 @@ class APIClient: self._background_tasks.add(stop_task) stop_task.add_done_callback(self._background_tasks.discard) - if TYPE_CHECKING: - assert self._connection is not None + connection.send_message(SubscribeVoiceAssistantRequest(subscribe=True)) - self._connection.send_message(SubscribeVoiceAssistantRequest(subscribe=True)) - - remove_callback = self._connection.add_message_callback( + remove_callback = connection.add_message_callback( _on_voice_assistant_request, (VoiceAssistantRequest,) ) @@ -1535,8 +1442,6 @@ class APIClient: def send_voice_assistant_event( self, event_type: VoiceAssistantEventType, data: dict[str, str] | None ) -> None: - self._check_authenticated() - req = VoiceAssistantEventResponse() req.event_type = event_type @@ -1551,9 +1456,7 @@ class APIClient: # pylint: disable=no-member req.data.extend(data_args) - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message(req) + self._get_authenticated_connection().send_message(req) async def alarm_control_panel_command( self, @@ -1561,13 +1464,10 @@ class APIClient: command: AlarmControlPanelCommand, code: str | None = None, ) -> None: - self._check_authenticated() - req = AlarmControlPanelCommandRequest() req.key = key req.command = command if code is not None: req.code = code - if TYPE_CHECKING: - assert self._connection is not None - self._connection.send_message(req) + + self._get_authenticated_connection().send_message(req)