Hold strong references to voice assistant tasks (#424)

This commit is contained in:
J. Nick Koston 2023-04-20 14:30:28 -10:00 committed by GitHub
parent 032e921cb3
commit 3f29ac92ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -235,6 +235,7 @@ class APIClient:
) )
self._connection: Optional[APIConnection] = None self._connection: Optional[APIConnection] = None
self._cached_name: Optional[str] = None self._cached_name: Optional[str] = None
self._background_tasks: set[asyncio.Task[Any]] = set()
@property @property
def expected_name(self) -> Optional[str]: def expected_name(self) -> Optional[str]:
@ -1254,7 +1255,7 @@ class APIClient:
"""Subscribes to voice assistant messages from the device. """Subscribes to voice assistant messages from the device.
handle_start: called when the devices requests a server to send audio data to. handle_start: called when the devices requests a server to send audio data to.
This callback is asyncronous and returns the port number the server is started on. This callback is asynchronous and returns the port number the server is started on.
handle_stop: called when the device has stopped sending audio data and the pipeline should be closed. handle_stop: called when the device has stopped sending audio data and the pipeline should be closed.
@ -1262,7 +1263,7 @@ class APIClient:
""" """
self._check_authenticated() self._check_authenticated()
t: Optional[asyncio.Task[Optional[int]]] = None start_task: Optional[asyncio.Task[Optional[int]]] = None
def _started(fut: asyncio.Task[Optional[int]]) -> None: def _started(fut: asyncio.Task[Optional[int]]) -> None:
if self._connection is not None and not fut.cancelled(): if self._connection is not None and not fut.cancelled():
@ -1275,12 +1276,15 @@ class APIClient:
def on_msg(msg: VoiceAssistantRequest) -> None: def on_msg(msg: VoiceAssistantRequest) -> None:
command = VoiceAssistantCommand.from_pb(msg) command = VoiceAssistantCommand.from_pb(msg)
loop = asyncio.get_running_loop()
if command.start: if command.start:
t = loop.create_task(handle_start()) start_task = asyncio.create_task(handle_start())
t.add_done_callback(_started) start_task.add_done_callback(_started)
# We hold a reference to the start_task in unsub function
# so we don't need to add it to the background tasks.
else: else:
loop.create_task(handle_stop()) stop_task = asyncio.create_task(handle_stop())
self._background_tasks.add(stop_task)
stop_task.add_done_callback(self._background_tasks.discard)
assert self._connection is not None assert self._connection is not None
@ -1297,8 +1301,8 @@ class APIClient:
SubscribeVoiceAssistantRequest(subscribe=False) SubscribeVoiceAssistantRequest(subscribe=False)
) )
if t is not None and not t.cancelled(): if start_task is not None and not start_task.cancelled():
t.cancel() start_task.cancel()
return unsub return unsub