diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 23a1620..62a1a02 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ repos: args: - --fix - repo: https://github.com/psf/black-pre-commit-mirror - rev: 23.11.0 + rev: 24.1.1 hooks: - id: black args: diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index c5e8660..19491fe 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -776,8 +776,9 @@ class APIClient: async def _bluetooth_gatt_read( self, - req_type: type[BluetoothGATTReadDescriptorRequest] - | type[BluetoothGATTReadRequest], + req_type: ( + type[BluetoothGATTReadDescriptorRequest] | type[BluetoothGATTReadRequest] + ), address: int, handle: int, timeout: float, diff --git a/aioesphomeapi/client_callbacks.py b/aioesphomeapi/client_callbacks.py index 210696d..18e7632 100644 --- a/aioesphomeapi/client_callbacks.py +++ b/aioesphomeapi/client_callbacks.py @@ -114,11 +114,13 @@ def on_bluetooth_device_connection_response( def on_bluetooth_handle_message( address: int, handle: int, - msg: BluetoothGATTErrorResponse - | BluetoothGATTNotifyResponse - | BluetoothGATTReadResponse - | BluetoothGATTWriteResponse - | BluetoothDeviceConnectionResponse, + msg: ( + BluetoothGATTErrorResponse + | BluetoothGATTNotifyResponse + | BluetoothGATTReadResponse + | BluetoothGATTWriteResponse + | BluetoothDeviceConnectionResponse + ), ) -> bool: """Filter a Bluetooth message for an address and handle.""" if type(msg) is BluetoothDeviceConnectionResponse: @@ -129,14 +131,16 @@ def on_bluetooth_handle_message( def on_bluetooth_message_types( address: int, msg_types: tuple[type[message.Message]], - msg: BluetoothGATTErrorResponse - | BluetoothGATTNotifyResponse - | BluetoothGATTReadResponse - | BluetoothGATTWriteResponse - | BluetoothDeviceConnectionResponse - | BluetoothGATTGetServicesResponse - | BluetoothGATTGetServicesDoneResponse - | BluetoothGATTErrorResponse, + msg: ( + BluetoothGATTErrorResponse + | BluetoothGATTNotifyResponse + | BluetoothGATTReadResponse + | BluetoothGATTWriteResponse + | BluetoothDeviceConnectionResponse + | BluetoothGATTGetServicesResponse + | BluetoothGATTGetServicesDoneResponse + | BluetoothGATTErrorResponse + ), ) -> bool: """Filter Bluetooth messages of a specific type and address.""" return type(msg) in msg_types and bool(msg.address == address) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 14403b7..9f08f53 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -221,9 +221,9 @@ class APIConnection: self._params = params self.on_stop: Callable[[bool], None] | None = on_stop self._socket: socket.socket | None = None - self._frame_helper: None | ( - APINoiseFrameHelper | APIPlaintextFrameHelper - ) = None + self._frame_helper: None | (APINoiseFrameHelper | APIPlaintextFrameHelper) = ( + None + ) self.api_version: APIVersion | None = None self.connection_state = CONNECTION_STATE_INITIALIZED diff --git a/requirements_test.txt b/requirements_test.txt index 9ad5b52..333c7f0 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -1,5 +1,5 @@ pylint==3.0.3 -black==23.12.1 +black==24.1.1 flake8==7.0.0 isort==5.13.2 mypy==1.8.0 diff --git a/tests/__init__.py b/tests/__init__.py index 2c99c93..9d8d653 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,5 @@ """Init tests.""" + from __future__ import annotations import logging diff --git a/tests/conftest.py b/tests/conftest.py index 33815da..47220e9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ """Test fixtures.""" + from __future__ import annotations import asyncio @@ -217,11 +218,14 @@ async def api_client( password=None, ) - with patch.object( - event_loop, - "create_connection", - side_effect=partial(_create_mock_transport_protocol, transport, connected), - ), patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection): + with ( + patch.object( + event_loop, + "create_connection", + side_effect=partial(_create_mock_transport_protocol, transport, connected), + ), + patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection), + ): connect_task = asyncio.create_task(connect_client(client, login=False)) await connected.wait() conn = client._connection diff --git a/tests/test_client.py b/tests/test_client.py index cbd86c6..fff3d08 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -186,9 +186,10 @@ async def test_connect_backwards_compat() -> None: pass cli = PatchableApiClient("host", 1234, None) - with patch.object(cli, "start_connection") as mock_start_connection, patch.object( - cli, "finish_connection" - ) as mock_finish_connection: + with ( + patch.object(cli, "start_connection") as mock_start_connection, + patch.object(cli, "finish_connection") as mock_finish_connection, + ): await cli.connect() assert mock_start_connection.mock_calls == [call(None)] @@ -244,9 +245,12 @@ async def test_connection_released_if_connecting_is_cancelled() -> None: mock_socket.getpeername.return_value = ("4.3.3.3", 323) return mock_socket - with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection), patch( - "aioesphomeapi.connection.aiohappyeyeballs.start_connection", - _start_connection_without_delay, + with ( + patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection), + patch( + "aioesphomeapi.connection.aiohappyeyeballs.start_connection", + _start_connection_without_delay, + ), ): await cli.start_connection() await asyncio.sleep(0) @@ -268,10 +272,13 @@ async def test_request_while_handshaking() -> None: pass cli = PatchableApiClient("host", 1234, None) - with patch( - "aioesphomeapi.connection.aiohappyeyeballs.start_connection", - side_effect=partial(asyncio.sleep, 1), - ), patch.object(cli, "finish_connection"): + with ( + patch( + "aioesphomeapi.connection.aiohappyeyeballs.start_connection", + side_effect=partial(asyncio.sleep, 1), + ), + patch.object(cli, "finish_connection"), + ): connect_task = asyncio.create_task(cli.connect()) await asyncio.sleep(0) @@ -1500,11 +1507,14 @@ async def test_bluetooth_gatt_start_notify_fails( handlers_before = len(list(itertools.chain(*connection._message_handlers.values()))) - with patch.object( - connection, - "send_messages_await_response_complex", - side_effect=APIConnectionError, - ), pytest.raises(APIConnectionError): + with ( + patch.object( + connection, + "send_messages_await_response_complex", + side_effect=APIConnectionError, + ), + pytest.raises(APIConnectionError), + ): await client.bluetooth_gatt_start_notify(1234, 1, on_bluetooth_gatt_notify) assert ( diff --git a/tests/test_connection.py b/tests/test_connection.py index 8c31ab7..2b84748 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -135,8 +135,9 @@ async def test_disconnect_when_not_fully_connected( await asyncio.sleep(0) transport.reset_mock() - with patch("aioesphomeapi.connection.DISCONNECT_CONNECT_TIMEOUT", 0.0), patch( - "aioesphomeapi.connection.DISCONNECT_RESPONSE_TIMEOUT", 0.0 + with ( + patch("aioesphomeapi.connection.DISCONNECT_CONNECT_TIMEOUT", 0.0), + patch("aioesphomeapi.connection.DISCONNECT_RESPONSE_TIMEOUT", 0.0), ): await conn.disconnect() @@ -344,10 +345,13 @@ async def test_start_connection_times_out( async def _mock_socket_connect(*args, **kwargs): await asyncio.sleep(500) - with patch( - "aioesphomeapi.connection.aiohappyeyeballs.start_connection", - side_effect=_mock_socket_connect, - ), patch("aioesphomeapi.connection.TCP_CONNECT_TIMEOUT", 0.0): + with ( + patch( + "aioesphomeapi.connection.aiohappyeyeballs.start_connection", + side_effect=_mock_socket_connect, + ), + patch("aioesphomeapi.connection.TCP_CONNECT_TIMEOUT", 0.0), + ): connect_task = asyncio.create_task(connect(conn, login=False)) await asyncio.sleep(0) @@ -497,14 +501,17 @@ async def test_plaintext_connection_fails_handshake( remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse)) transport = MagicMock() - with patch( - "aioesphomeapi.connection.APIPlaintextFrameHelper", - APIPlaintextFrameHelperHandshakeException, - ), patch.object( - loop, - "create_connection", - side_effect=partial( - _create_failing_mock_transport_protocol, transport, connected + with ( + patch( + "aioesphomeapi.connection.APIPlaintextFrameHelper", + APIPlaintextFrameHelperHandshakeException, + ), + patch.object( + loop, + "create_connection", + side_effect=partial( + _create_failing_mock_transport_protocol, transport, connected + ), ), ): connect_task = asyncio.create_task(connect(conn, login=False)) @@ -537,12 +544,10 @@ async def test_plaintext_connection_fails_handshake( def _frame_helper_close_call(): call_order.append("frame_helper_close") - with patch.object( - conn._socket, "close", side_effect=_socket_close_call - ), patch.object( - conn._frame_helper, "close", side_effect=_frame_helper_close_call - ), pytest.raises( - raised_exception + with ( + patch.object(conn._socket, "close", side_effect=_socket_close_call), + patch.object(conn._frame_helper, "close", side_effect=_frame_helper_close_call), + pytest.raises(raised_exception), ): await asyncio.sleep(0) await connect_task @@ -658,16 +663,20 @@ async def test_connect_resolver_times_out( connected = asyncio.Event() event_loop = asyncio.get_running_loop() - with patch( - "aioesphomeapi.host_resolver.async_resolve_host", - side_effect=asyncio.TimeoutError, - ), patch.object( - event_loop, - "create_connection", - side_effect=partial(_create_mock_transport_protocol, transport, connected), - ), pytest.raises( - ResolveAPIError, - match="Timeout while resolving IP address for fake.address", + with ( + patch( + "aioesphomeapi.host_resolver.async_resolve_host", + side_effect=asyncio.TimeoutError, + ), + patch.object( + event_loop, + "create_connection", + side_effect=partial(_create_mock_transport_protocol, transport, connected), + ), + pytest.raises( + ResolveAPIError, + match="Timeout while resolving IP address for fake.address", + ), ): await connect(conn, login=False) diff --git a/tests/test_host_resolver.py b/tests/test_host_resolver.py index 5b333e2..30a7882 100644 --- a/tests/test_host_resolver.py +++ b/tests/test_host_resolver.py @@ -45,9 +45,10 @@ async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos): [ipv6], ] info.async_request = AsyncMock(return_value=True) - with patch( - "aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info - ), patch("aioesphomeapi.zeroconf.AsyncZeroconf", return_value=async_zeroconf): + with ( + patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info), + patch("aioesphomeapi.zeroconf.AsyncZeroconf", return_value=async_zeroconf), + ): ret = await hr._async_resolve_host_zeroconf("asdf", 6052) info.async_request.assert_called_once() @@ -86,20 +87,26 @@ async def test_resolve_host_zeroconf_empty(async_zeroconf: AsyncZeroconf): @pytest.mark.asyncio async def test_resolve_host_zeroconf_fails(async_zeroconf: AsyncZeroconf): - with patch( - "aioesphomeapi.host_resolver.AsyncServiceInfo.async_request", - side_effect=Exception("no buffers"), - ), pytest.raises(ResolveAPIError, match="no buffers"): + with ( + patch( + "aioesphomeapi.host_resolver.AsyncServiceInfo.async_request", + side_effect=Exception("no buffers"), + ), + pytest.raises(ResolveAPIError, match="no buffers"), + ): await hr._async_resolve_host_zeroconf("asdf.local", 6052) @pytest.mark.asyncio @patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo", return_value=[]) async def test_resolve_host_zeroconf_fails_end_to_end(async_zeroconf: AsyncZeroconf): - with patch( - "aioesphomeapi.host_resolver.AsyncServiceInfo.async_request", - side_effect=Exception("no buffers"), - ), pytest.raises(ResolveAPIError, match="no buffers"): + with ( + patch( + "aioesphomeapi.host_resolver.AsyncServiceInfo.async_request", + side_effect=Exception("no buffers"), + ), + pytest.raises(ResolveAPIError, match="no buffers"), + ): await hr.async_resolve_host(["asdf.local"], 6052) @@ -226,13 +233,13 @@ async def test_resolve_host_zeroconf_service_info_oserror( ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"), ] info.async_request = AsyncMock(return_value=True) - with patch( - "aioesphomeapi.host_resolver.AsyncServiceInfo.async_request", - side_effect=OSError("out of buffers"), - ), patch( - "aioesphomeapi.zeroconf.AsyncZeroconf", return_value=async_zeroconf - ), pytest.raises( - ResolveAPIError, match="out of buffers" + with ( + patch( + "aioesphomeapi.host_resolver.AsyncServiceInfo.async_request", + side_effect=OSError("out of buffers"), + ), + patch("aioesphomeapi.zeroconf.AsyncZeroconf", return_value=async_zeroconf), + pytest.raises(ResolveAPIError, match="out of buffers"), ): await hr._async_resolve_host_zeroconf("asdf", 6052) @@ -247,9 +254,13 @@ async def test_resolve_host_create_zeroconf_oserror( ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"), ] info.async_request = AsyncMock(return_value=True) - with patch( - "aioesphomeapi.zeroconf.AsyncZeroconf", side_effect=OSError("out of buffers") - ), pytest.raises(ResolveAPIError, match="out of buffers"): + with ( + patch( + "aioesphomeapi.zeroconf.AsyncZeroconf", + side_effect=OSError("out of buffers"), + ), + pytest.raises(ResolveAPIError, match="out of buffers"), + ): await hr._async_resolve_host_zeroconf("asdf", 6052) diff --git a/tests/test_log_runner.py b/tests/test_log_runner.py index c6725b4..e95f5c5 100644 --- a/tests/test_log_runner.py +++ b/tests/test_log_runner.py @@ -72,9 +72,12 @@ async def test_log_runner( await original_subscribe_logs(*args, **kwargs) subscribed.set() - with patch.object( - loop, "create_connection", side_effect=_create_mock_transport_protocol - ), patch.object(cli, "subscribe_logs", _wait_subscribe_cli): + with ( + patch.object( + loop, "create_connection", side_effect=_create_mock_transport_protocol + ), + patch.object(cli, "subscribe_logs", _wait_subscribe_cli), + ): stop = await async_run(cli, on_log, aio_zeroconf_instance=async_zeroconf) await connected.wait() protocol = cli._connection._frame_helper @@ -138,9 +141,12 @@ async def test_log_runner_reconnects_on_disconnect( await original_subscribe_logs(*args, **kwargs) subscribed.set() - with patch.object( - loop, "create_connection", side_effect=_create_mock_transport_protocol - ), patch.object(cli, "subscribe_logs", _wait_subscribe_cli): + with ( + patch.object( + loop, "create_connection", side_effect=_create_mock_transport_protocol + ), + patch.object(cli, "subscribe_logs", _wait_subscribe_cli), + ): stop = await async_run(cli, on_log, aio_zeroconf_instance=async_zeroconf) await connected.wait() protocol = cli._connection._frame_helper @@ -214,9 +220,10 @@ async def test_log_runner_reconnects_on_subscribe_failure( subscribed.set() raise APIConnectionError("subscribed force to fail") - with patch.object( - cli, "disconnect", partial(cli.disconnect, force=True) - ), patch.object(cli, "subscribe_logs", _wait_and_fail_subscribe_cli): + with ( + patch.object(cli, "disconnect", partial(cli.disconnect, force=True)), + patch.object(cli, "subscribe_logs", _wait_and_fail_subscribe_cli), + ): with patch.object( loop, "create_connection", side_effect=_create_mock_transport_protocol ): @@ -230,9 +237,12 @@ async def test_log_runner_reconnects_on_subscribe_failure( assert cli._connection is None - with patch.object( - loop, "create_connection", side_effect=_create_mock_transport_protocol - ), patch.object(cli, "subscribe_logs"): + with ( + patch.object( + loop, "create_connection", side_effect=_create_mock_transport_protocol + ), + patch.object(cli, "subscribe_logs"), + ): connected.clear() await asyncio.sleep(0) async_fire_time_changed( diff --git a/tests/test_reconnect_logic.py b/tests/test_reconnect_logic.py index 29b39b5..4554563 100644 --- a/tests/test_reconnect_logic.py +++ b/tests/test_reconnect_logic.py @@ -215,8 +215,9 @@ async def test_reconnect_logic_state(patchable_api_client: APIClient): assert rl._connection_state is ReconnectLogicState.DISCONNECTED assert rl._tries == 1 - with patch.object(cli, "start_connection"), patch.object( - cli, "finish_connection", side_effect=RequiresEncryptionAPIError + with ( + patch.object(cli, "start_connection"), + patch.object(cli, "finish_connection", side_effect=RequiresEncryptionAPIError), ): await rl.start() await asyncio.sleep(0) @@ -428,8 +429,9 @@ async def test_reconnect_zeroconf( assert not rl._is_stopped caplog.clear() - with patch.object(cli, "start_connection") as mock_start_connection, patch.object( - cli, "finish_connection" + with ( + patch.object(cli, "start_connection") as mock_start_connection, + patch.object(cli, "finish_connection"), ): assert rl._zc_listening is True rl.async_update_records( @@ -482,9 +484,12 @@ async def test_reconnect_zeroconf_not_while_handshaking( assert mock_start_connection.call_count == 1 - with patch.object(cli, "start_connection") as mock_start_connection, patch.object( - cli, "finish_connection", side_effect=slow_connect_fail - ) as mock_finish_connection: + with ( + patch.object(cli, "start_connection") as mock_start_connection, + patch.object( + cli, "finish_connection", side_effect=slow_connect_fail + ) as mock_finish_connection, + ): assert rl._connection_state is ReconnectLogicState.DISCONNECTED assert rl._accept_zeroconf_records is True assert not rl._is_stopped @@ -536,9 +541,12 @@ async def test_connect_task_not_cancelled_while_handshaking( assert mock_start_connection.call_count == 1 - with patch.object(cli, "start_connection") as mock_start_connection, patch.object( - cli, "finish_connection", side_effect=slow_connect_fail - ) as mock_finish_connection: + with ( + patch.object(cli, "start_connection") as mock_start_connection, + patch.object( + cli, "finish_connection", side_effect=slow_connect_fail + ) as mock_finish_connection, + ): assert rl._connection_state is ReconnectLogicState.DISCONNECTED assert rl._accept_zeroconf_records is True assert not rl._is_stopped @@ -647,8 +655,9 @@ async def test_reconnect_logic_stop_callback_waits_for_handshake( ) assert rl._connection_state is ReconnectLogicState.DISCONNECTED - with patch.object(cli, "start_connection"), patch.object( - cli, "finish_connection", side_effect=slow_connect_fail + with ( + patch.object(cli, "start_connection"), + patch.object(cli, "finish_connection", side_effect=slow_connect_fail), ): await rl.start() for _ in range(3):