diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 1706be8..ebdda1c 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -65,7 +65,6 @@ from .api_pb2 import ( # type: ignore SwitchCommandRequest, TextCommandRequest, UnsubscribeBluetoothLEAdvertisementsRequest, - VoiceAssistantAudioSettings, VoiceAssistantEventData, VoiceAssistantEventResponse, VoiceAssistantRequest, @@ -118,9 +117,9 @@ from .model import ( MediaPlayerCommand, UserService, UserServiceArgType, - VoiceAssistantCommand, - VoiceAssistantEventType, ) +from .model import VoiceAssistantAudioSettings as VoiceAssistantAudioSettingsModel +from .model import VoiceAssistantCommand, VoiceAssistantEventType from .model_conversions import ( LIST_ENTITIES_SERVICES_RESPONSE_TYPES, SUBSCRIBE_STATES_RESPONSE_TYPES, @@ -1240,7 +1239,8 @@ class APIClient: async def subscribe_voice_assistant( self, handle_start: Callable[ - [str, int, VoiceAssistantAudioSettings], Coroutine[Any, Any, int | None] + [str, int, VoiceAssistantAudioSettingsModel], + Coroutine[Any, Any, int | None], ], handle_stop: Callable[[], Coroutine[Any, Any, None]], ) -> Callable[[], None]: diff --git a/tests/test_client.py b/tests/test_client.py index 141a498..3943083 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -55,10 +55,14 @@ from aioesphomeapi.api_pb2 import ( SirenCommandRequest, SubscribeHomeAssistantStateResponse, SubscribeLogsResponse, + SubscribeVoiceAssistantRequest, SwitchCommandRequest, TextCommandRequest, + VoiceAssistantAudioSettings, VoiceAssistantEventData, VoiceAssistantEventResponse, + VoiceAssistantRequest, + VoiceAssistantResponse, ) from aioesphomeapi.client import APIClient from aioesphomeapi.connection import APIConnection @@ -96,6 +100,9 @@ from aioesphomeapi.model import ( UserServiceArg, UserServiceArgType, ) +from aioesphomeapi.model import ( + VoiceAssistantAudioSettings as VoiceAssistantAudioSettingsModel, +) from aioesphomeapi.model import VoiceAssistantEventType as VoiceAssistantEventModelType from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState @@ -1699,3 +1706,126 @@ async def test_send_voice_assistant_event(auth_client: APIClient) -> None: data=[], ) ) + + +@pytest.mark.asyncio +async def test_subscribe_voice_assistant( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test subscribe_voice_assistant.""" + client, connection, transport, protocol = api_client + send = patch_send(client) + starts = [] + stops = [] + + async def handle_start( + conversation_id: str, flags: int, audio_settings: VoiceAssistantAudioSettings + ) -> int | None: + starts.append((conversation_id, flags, audio_settings)) + return 42 + + async def handle_stop() -> None: + stops.append(True) + + await client.subscribe_voice_assistant(handle_start, handle_stop) + send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=True)) + send.reset_mock() + audio_settings = VoiceAssistantAudioSettings( + noise_suppression_level=42, + auto_gain=42, + volume_multiplier=42, + ) + response: message.Message = VoiceAssistantRequest( + conversation_id="theone", + start=True, + flags=42, + audio_settings=audio_settings, + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + await asyncio.sleep(0) + await asyncio.sleep(0) + assert starts == [ + ( + "theone", + 42, + VoiceAssistantAudioSettingsModel( + noise_suppression_level=42, + auto_gain=42, + volume_multiplier=42, + ), + ) + ] + assert stops == [] + send.assert_called_once_with(VoiceAssistantResponse(port=42)) + send.reset_mock() + response: message.Message = VoiceAssistantRequest( + conversation_id="theone", + start=False, + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + await asyncio.sleep(0) + assert stops == [True] + + +@pytest.mark.asyncio +async def test_subscribe_voice_assistant_failure( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test subscribe_voice_assistant.""" + client, connection, transport, protocol = api_client + send = patch_send(client) + starts = [] + stops = [] + + async def handle_start( + conversation_id: str, flags: int, audio_settings: VoiceAssistantAudioSettings + ) -> int | None: + starts.append((conversation_id, flags, audio_settings)) + # Return None to indicate failure + return None + + async def handle_stop() -> None: + stops.append(True) + + await client.subscribe_voice_assistant(handle_start, handle_stop) + send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=True)) + send.reset_mock() + audio_settings = VoiceAssistantAudioSettings( + noise_suppression_level=42, + auto_gain=42, + volume_multiplier=42, + ) + response: message.Message = VoiceAssistantRequest( + conversation_id="theone", + start=True, + flags=42, + audio_settings=audio_settings, + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + await asyncio.sleep(0) + await asyncio.sleep(0) + assert starts == [ + ( + "theone", + 42, + VoiceAssistantAudioSettingsModel( + noise_suppression_level=42, + auto_gain=42, + volume_multiplier=42, + ), + ) + ] + assert stops == [] + send.assert_called_once_with(VoiceAssistantResponse(error=True)) + send.reset_mock() + response: message.Message = VoiceAssistantRequest( + conversation_id="theone", + start=False, + ) + mock_data_received(protocol, generate_plaintext_packet(response)) + await asyncio.sleep(0) + assert stops == [True]