diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 2ada885..4f22254 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -461,34 +461,38 @@ class APIClient: entities.append(cls.from_pb(msg)) return entities, services + def _on_state_msg( + self, + on_state: Callable[[EntityState], None], + image_stream: dict[int, list[bytes]], + msg: message.Message, + ) -> None: + """Handle a state message.""" + msg_type = type(msg) + if cls := SUBSCRIBE_STATES_RESPONSE_TYPES.get(msg_type): + on_state(cls.from_pb(msg)) + elif msg_type is CameraImageResponse: + if TYPE_CHECKING: + assert isinstance(msg, CameraImageResponse) + msg_key = msg.key + data_parts: list[bytes] | None = image_stream.get(msg_key) + if not data_parts: + data_parts = [] + image_stream[msg_key] = data_parts + + data_parts.append(msg.data) + if msg.done: + # Return CameraState with the merged data + image_data = b"".join(data_parts) + del image_stream[msg_key] + on_state(CameraState(key=msg.key, data=image_data)) # type: ignore[call-arg] + async def subscribe_states(self, on_state: Callable[[EntityState], None]) -> None: - image_stream: dict[int, list[bytes]] = {} - response_types = SUBSCRIBE_STATES_RESPONSE_TYPES - msg_types = SUBSCRIBE_STATES_MSG_TYPES - - def _on_state_msg(msg: message.Message) -> None: - msg_type = type(msg) - cls = response_types.get(msg_type) - if cls: - on_state(cls.from_pb(msg)) - elif msg_type is CameraImageResponse: - if TYPE_CHECKING: - assert isinstance(msg, CameraImageResponse) - msg_key = msg.key - data_parts: list[bytes] | None = image_stream.get(msg_key) - if not data_parts: - data_parts = [] - image_stream[msg_key] = data_parts - - data_parts.append(msg.data) - if msg.done: - # Return CameraState with the merged data - image_data = b"".join(data_parts) - del image_stream[msg_key] - on_state(CameraState(key=msg.key, data=image_data)) # type: ignore[call-arg] - + """Subscribe to state updates.""" self._get_connection().send_message_callback_response( - SubscribeStatesRequest(), _on_state_msg, msg_types + SubscribeStatesRequest(), + partial(self._on_state_msg, on_state, {}), + SUBSCRIBE_STATES_MSG_TYPES, ) async def subscribe_logs( @@ -506,17 +510,19 @@ class APIClient: req, on_log, (SubscribeLogsResponse,) ) + def _on_home_assistant_service_response( + self, + on_service_call: Callable[[HomeassistantServiceCall], None], + msg: HomeassistantServiceResponse, + ) -> None: + on_service_call(HomeassistantServiceCall.from_pb(msg)) + async def subscribe_service_calls( self, on_service_call: Callable[[HomeassistantServiceCall], None] ) -> None: - def _on_home_assistant_service_response( - msg: HomeassistantServiceResponse, - ) -> None: - on_service_call(HomeassistantServiceCall.from_pb(msg)) - self._get_connection().send_message_callback_response( SubscribeHomeassistantServicesRequest(), - _on_home_assistant_service_response, + partial(self._on_home_assistant_service_response, on_service_call), (HomeassistantServiceResponse,), )