Upgrade python code to 3.9 (#496)

This commit is contained in:
J. Nick Koston 2023-07-19 15:33:28 -05:00 committed by GitHub
parent 3a50305cf8
commit d63b9bbf5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 232 additions and 246 deletions

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from .noise import APINoiseFrameHelper from .noise import APINoiseFrameHelper
from .plain_text import APIPlaintextFrameHelper from .plain_text import APIPlaintextFrameHelper

View File

@ -1,8 +1,10 @@
from __future__ import annotations
import asyncio import asyncio
import logging import logging
from abc import abstractmethod from abc import abstractmethod
from functools import partial from functools import partial
from typing import Callable, Optional, Union, cast from typing import Callable, cast
from ..core import SocketClosedAPIError from ..core import SocketClosedAPIError
@ -45,10 +47,8 @@ class APIFrameHelper(asyncio.Protocol):
"""Initialize the API frame helper.""" """Initialize the API frame helper."""
self._on_pkt = on_pkt self._on_pkt = on_pkt
self._on_error = on_error self._on_error = on_error
self._transport: Optional[asyncio.Transport] = None self._transport: asyncio.Transport | None = None
self._writer: Optional[ self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
Callable[[Union[bytes, bytearray, memoryview]], None]
] = None
self._connected_event = asyncio.Event() self._connected_event = asyncio.Event()
self._buffer = bytearray() self._buffer = bytearray()
self._buffer_len = 0 self._buffer_len = 0
@ -57,7 +57,7 @@ class APIFrameHelper(asyncio.Protocol):
self._log_name = log_name self._log_name = log_name
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG) self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
def _read_exactly(self, length: int) -> Optional[bytearray]: def _read_exactly(self, length: int) -> bytearray | None:
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available.""" """Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
original_pos = self._pos original_pos = self._pos
new_pos = original_pos + length new_pos = original_pos + length
@ -87,13 +87,13 @@ class APIFrameHelper(asyncio.Protocol):
def _handle_error(self, exc: Exception) -> None: def _handle_error(self, exc: Exception) -> None:
self._on_error(exc) self._on_error(exc)
def connection_lost(self, exc: Optional[Exception]) -> None: def connection_lost(self, exc: Exception | None) -> None:
self._handle_error( self._handle_error(
exc or SocketClosedAPIError(f"{self._log_name}: Connection lost") exc or SocketClosedAPIError(f"{self._log_name}: Connection lost")
) )
return super().connection_lost(exc) return super().connection_lost(exc)
def eof_received(self) -> Optional[bool]: def eof_received(self) -> bool | None:
self._handle_error(SocketClosedAPIError(f"{self._log_name}: EOF received")) self._handle_error(SocketClosedAPIError(f"{self._log_name}: EOF received"))
return super().eof_received() return super().eof_received()

View File

@ -1,10 +1,12 @@
from __future__ import annotations
import asyncio import asyncio
import base64 import base64
import logging import logging
from enum import Enum from enum import Enum
from functools import partial from functools import partial
from struct import Struct from struct import Struct
from typing import TYPE_CHECKING, Any, Callable, Optional, Type from typing import TYPE_CHECKING, Any, Callable
from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable
from cryptography.exceptions import InvalidTag from cryptography.exceptions import InvalidTag
@ -34,7 +36,7 @@ class ChaCha20CipherReuseable(ChaCha20Cipher): # type: ignore[misc]
format_nonce = PACK_NONCE format_nonce = PACK_NONCE
@property @property
def klass(self) -> Type[ChaCha20Poly1305Reusable]: def klass(self) -> type[ChaCha20Poly1305Reusable]:
return ChaCha20Poly1305Reusable return ChaCha20Poly1305Reusable
@ -81,7 +83,7 @@ class APINoiseFrameHelper(APIFrameHelper):
on_pkt: Callable[[int, bytes], None], on_pkt: Callable[[int, bytes], None],
on_error: Callable[[Exception], None], on_error: Callable[[Exception], None],
noise_psk: str, noise_psk: str,
expected_name: Optional[str], expected_name: str | None,
client_info: str, client_info: str,
log_name: str, log_name: str,
) -> None: ) -> None:
@ -93,9 +95,9 @@ class APINoiseFrameHelper(APIFrameHelper):
self._noise_psk = noise_psk self._noise_psk = noise_psk
self._expected_name = expected_name self._expected_name = expected_name
self._set_state(NoiseConnectionState.HELLO) self._set_state(NoiseConnectionState.HELLO)
self._server_name: Optional[str] = None self._server_name: str | None = None
self._decrypt: Optional[Callable[[bytes], bytes]] = None self._decrypt: Callable[[bytes], bytes] | None = None
self._encrypt: Optional[Callable[[bytes], bytes]] = None self._encrypt: Callable[[bytes], bytes] | None = None
self._setup_proto() self._setup_proto()
self._is_ready = False self._is_ready = False

View File

@ -1,5 +1,7 @@
from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
from ..core import ProtocolAPIError, RequiresEncryptionAPIError, SocketAPIError from ..core import ProtocolAPIError, RequiresEncryptionAPIError, SocketAPIError
from ..util import bytes_to_varuint, varuint_to_bytes from ..util import bytes_to_varuint, varuint_to_bytes
@ -45,8 +47,8 @@ class APIPlaintextFrameHelper(APIFrameHelper):
init_bytes = self._read_exactly(3) init_bytes = self._read_exactly(3)
if init_bytes is None: if init_bytes is None:
return return
msg_type_int: Optional[int] = None msg_type_int: int | None = None
length_int: Optional[int] = None length_int: int | None = None
preamble, length_high, maybe_msg_type = init_bytes preamble, length_high, maybe_msg_type = init_bytes
if preamble != 0x00: if preamble != 0x00:
if preamble == 0x01: if preamble == 0x01:

View File

@ -1,20 +1,10 @@
from __future__ import annotations
import asyncio import asyncio
import logging import logging
from collections.abc import Awaitable, Coroutine
from functools import partial from functools import partial
from typing import ( from typing import TYPE_CHECKING, Any, Callable, Union, cast
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Coroutine,
Dict,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
from google.protobuf import message from google.protobuf import message
@ -200,8 +190,8 @@ DEFAULT_BLE_DISCONNECT_TIMEOUT = 5.0
# connection is poor. # connection is poor.
KEEP_ALIVE_FREQUENCY = 20.0 KEEP_ALIVE_FREQUENCY = 20.0
ExecuteServiceDataType = Dict[ ExecuteServiceDataType = dict[
str, Union[bool, int, float, str, List[bool], List[int], List[float], List[str]] str, Union[bool, int, float, str, list[bool], list[int], list[float], list[str]]
] ]
@ -213,13 +203,13 @@ class APIClient:
self, self,
address: str, address: str,
port: int, port: int,
password: Optional[str], password: str | None,
*, *,
client_info: str = "aioesphomeapi", client_info: str = "aioesphomeapi",
keepalive: float = KEEP_ALIVE_FREQUENCY, keepalive: float = KEEP_ALIVE_FREQUENCY,
zeroconf_instance: ZeroconfInstanceType = None, zeroconf_instance: ZeroconfInstanceType = None,
noise_psk: Optional[str] = None, noise_psk: str | None = None,
expected_name: Optional[str] = None, expected_name: str | None = None,
): ):
"""Create a client, this object is shared across sessions. """Create a client, this object is shared across sessions.
@ -247,17 +237,17 @@ class APIClient:
noise_psk=noise_psk or None, noise_psk=noise_psk or None,
expected_name=expected_name, expected_name=expected_name,
) )
self._connection: Optional[APIConnection] = None self._connection: APIConnection | None = None
self._cached_name: Optional[str] = None self._cached_name: str | None = None
self._background_tasks: set[asyncio.Task[Any]] = set() self._background_tasks: set[asyncio.Task[Any]] = set()
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
@property @property
def expected_name(self) -> Optional[str]: def expected_name(self) -> str | None:
return self._params.expected_name return self._params.expected_name
@expected_name.setter @expected_name.setter
def expected_name(self, value: Optional[str]) -> None: def expected_name(self, value: str | None) -> None:
self._params.expected_name = value self._params.expected_name = value
@property @property
@ -277,7 +267,7 @@ class APIClient:
async def connect( async def connect(
self, self,
on_stop: Optional[Callable[[bool], Awaitable[None]]] = None, on_stop: Callable[[bool], Awaitable[None]] | None = None,
login: bool = False, login: bool = False,
) -> None: ) -> None:
if self._connection is not None: if self._connection is not None:
@ -343,9 +333,9 @@ class APIClient:
async def list_entities_services( async def list_entities_services(
self, self,
) -> Tuple[List[EntityInfo], List[UserService]]: ) -> tuple[list[EntityInfo], list[UserService]]:
self._check_authenticated() self._check_authenticated()
response_types: Dict[Any, Optional[Type[EntityInfo]]] = { response_types: dict[Any, type[EntityInfo] | None] = {
ListEntitiesBinarySensorResponse: BinarySensorInfo, ListEntitiesBinarySensorResponse: BinarySensorInfo,
ListEntitiesButtonResponse: ButtonInfo, ListEntitiesButtonResponse: ButtonInfo,
ListEntitiesCoverResponse: CoverInfo, ListEntitiesCoverResponse: CoverInfo,
@ -376,8 +366,8 @@ class APIClient:
resp = await self._connection.send_message_await_response_complex( resp = await self._connection.send_message_await_response_complex(
ListEntitiesRequest(), do_append, do_stop, msg_types, timeout=60 ListEntitiesRequest(), do_append, do_stop, msg_types, timeout=60
) )
entities: List[EntityInfo] = [] entities: list[EntityInfo] = []
services: List[UserService] = [] services: list[UserService] = []
for msg in resp: for msg in resp:
if isinstance(msg, ListEntitiesServicesResponse): if isinstance(msg, ListEntitiesServicesResponse):
services.append(UserService.from_pb(msg)) services.append(UserService.from_pb(msg))
@ -389,8 +379,8 @@ class APIClient:
async def subscribe_states(self, on_state: Callable[[EntityState], None]) -> None: async def subscribe_states(self, on_state: Callable[[EntityState], None]) -> None:
self._check_authenticated() self._check_authenticated()
image_stream: Dict[int, list[bytes]] = {} image_stream: dict[int, list[bytes]] = {}
response_types: Dict[Any, Type[EntityState]] = { response_types: dict[Any, type[EntityState]] = {
BinarySensorStateResponse: BinarySensorState, BinarySensorStateResponse: BinarySensorState,
CoverStateResponse: CoverState, CoverStateResponse: CoverState,
FanStateResponse: FanState, FanStateResponse: FanState,
@ -417,7 +407,7 @@ class APIClient:
if TYPE_CHECKING: if TYPE_CHECKING:
assert isinstance(msg, CameraImageResponse) assert isinstance(msg, CameraImageResponse)
msg_key = msg.key msg_key = msg.key
data_parts: Optional[List[bytes]] = image_stream.get(msg_key) data_parts: list[bytes] | None = image_stream.get(msg_key)
if not data_parts: if not data_parts:
data_parts = [] data_parts = []
image_stream[msg_key] = data_parts image_stream[msg_key] = data_parts
@ -425,7 +415,7 @@ class APIClient:
data_parts.append(msg.data) data_parts.append(msg.data)
if msg.done: if msg.done:
# Return CameraState with the merged data # Return CameraState with the merged data
image_data = bytes().join(data_parts) image_data = b"".join(data_parts)
del image_stream[msg_key] del image_stream[msg_key]
on_state(CameraState(key=msg.key, data=image_data)) # type: ignore[call-arg] on_state(CameraState(key=msg.key, data=image_data)) # type: ignore[call-arg]
@ -437,8 +427,8 @@ class APIClient:
async def subscribe_logs( async def subscribe_logs(
self, self,
on_log: Callable[[SubscribeLogsResponse], None], on_log: Callable[[SubscribeLogsResponse], None],
log_level: Optional[LogLevel] = None, log_level: LogLevel | None = None,
dump_config: Optional[bool] = None, dump_config: bool | None = None,
) -> None: ) -> None:
self._check_authenticated() self._check_authenticated()
req = SubscribeLogsRequest() req = SubscribeLogsRequest()
@ -492,11 +482,11 @@ class APIClient:
address: int, address: int,
handle: int, handle: int,
request: message.Message, request: message.Message,
response_type: Union[ response_type: (
Type[BluetoothGATTNotifyResponse], type[BluetoothGATTNotifyResponse]
Type[BluetoothGATTReadResponse], | type[BluetoothGATTReadResponse]
Type[BluetoothGATTWriteResponse], | type[BluetoothGATTWriteResponse]
], ),
timeout: float = 10.0, timeout: float = 10.0,
) -> message.Message: ) -> message.Message:
self._check_authenticated() self._check_authenticated()
@ -541,7 +531,7 @@ class APIClient:
return unsub return unsub
async def subscribe_bluetooth_le_raw_advertisements( async def subscribe_bluetooth_le_raw_advertisements(
self, on_advertisements: Callable[[List[BluetoothLERawAdvertisement]], None] self, on_advertisements: Callable[[list[BluetoothLERawAdvertisement]], None]
) -> Callable[[], None]: ) -> Callable[[], None]:
self._check_authenticated() self._check_authenticated()
msg_types = (BluetoothLERawAdvertisementsResponse,) msg_types = (BluetoothLERawAdvertisementsResponse,)
@ -615,7 +605,7 @@ class APIClient:
disconnect_timeout: float = DEFAULT_BLE_DISCONNECT_TIMEOUT, disconnect_timeout: float = DEFAULT_BLE_DISCONNECT_TIMEOUT,
feature_flags: int = 0, feature_flags: int = 0,
has_cache: bool = False, has_cache: bool = False,
address_type: Optional[int] = None, address_type: int | None = None,
) -> Callable[[], None]: ) -> Callable[[], None]:
self._check_authenticated() self._check_authenticated()
msg_types = (BluetoothDeviceConnectionResponse,) msg_types = (BluetoothDeviceConnectionResponse,)
@ -955,7 +945,7 @@ class APIClient:
address: int, address: int,
handle: int, handle: int,
on_bluetooth_gatt_notify: Callable[[int, bytearray], None], on_bluetooth_gatt_notify: Callable[[int, bytearray], None],
) -> Tuple[Callable[[], Coroutine[Any, Any, None]], Callable[[], None]]: ) -> tuple[Callable[[], Coroutine[Any, Any, None]], Callable[[], None]]:
"""Start a notify session for a GATT characteristic. """Start a notify session for a GATT characteristic.
Returns two functions that can be used to stop the notify. Returns two functions that can be used to stop the notify.
@ -1001,7 +991,7 @@ class APIClient:
return stop_notify, remove_callback return stop_notify, remove_callback
async def subscribe_home_assistant_states( async def subscribe_home_assistant_states(
self, on_state_sub: Callable[[str, Optional[str]], None] self, on_state_sub: Callable[[str, str | None], None]
) -> None: ) -> None:
self._check_authenticated() self._check_authenticated()
@ -1018,7 +1008,7 @@ class APIClient:
) )
async def send_home_assistant_state( async def send_home_assistant_state(
self, entity_id: str, attribute: Optional[str], state: str self, entity_id: str, attribute: str | None, state: str
) -> None: ) -> None:
self._check_authenticated() self._check_authenticated()
@ -1034,8 +1024,8 @@ class APIClient:
async def cover_command( async def cover_command(
self, self,
key: int, key: int,
position: Optional[float] = None, position: float | None = None,
tilt: Optional[float] = None, tilt: float | None = None,
stop: bool = False, stop: bool = False,
) -> None: ) -> None:
self._check_authenticated() self._check_authenticated()
@ -1068,11 +1058,11 @@ class APIClient:
async def fan_command( async def fan_command(
self, self,
key: int, key: int,
state: Optional[bool] = None, state: bool | None = None,
speed: Optional[FanSpeed] = None, speed: FanSpeed | None = None,
speed_level: Optional[int] = None, speed_level: int | None = None,
oscillating: Optional[bool] = None, oscillating: bool | None = None,
direction: Optional[FanDirection] = None, direction: FanDirection | None = None,
) -> None: ) -> None:
self._check_authenticated() self._check_authenticated()
@ -1099,18 +1089,18 @@ class APIClient:
async def light_command( async def light_command(
self, self,
key: int, key: int,
state: Optional[bool] = None, state: bool | None = None,
brightness: Optional[float] = None, brightness: float | None = None,
color_mode: Optional[int] = None, color_mode: int | None = None,
color_brightness: Optional[float] = None, color_brightness: float | None = None,
rgb: Optional[Tuple[float, float, float]] = None, rgb: tuple[float, float, float] | None = None,
white: Optional[float] = None, white: float | None = None,
color_temperature: Optional[float] = None, color_temperature: float | None = None,
cold_white: Optional[float] = None, cold_white: float | None = None,
warm_white: Optional[float] = None, warm_white: float | None = None,
transition_length: Optional[float] = None, transition_length: float | None = None,
flash_length: Optional[float] = None, flash_length: float | None = None,
effect: Optional[str] = None, effect: str | None = None,
) -> None: ) -> None:
self._check_authenticated() self._check_authenticated()
@ -1169,15 +1159,15 @@ class APIClient:
async def climate_command( async def climate_command(
self, self,
key: int, key: int,
mode: Optional[ClimateMode] = None, mode: ClimateMode | None = None,
target_temperature: Optional[float] = None, target_temperature: float | None = None,
target_temperature_low: Optional[float] = None, target_temperature_low: float | None = None,
target_temperature_high: Optional[float] = None, target_temperature_high: float | None = None,
fan_mode: Optional[ClimateFanMode] = None, fan_mode: ClimateFanMode | None = None,
swing_mode: Optional[ClimateSwingMode] = None, swing_mode: ClimateSwingMode | None = None,
custom_fan_mode: Optional[str] = None, custom_fan_mode: str | None = None,
preset: Optional[ClimatePreset] = None, preset: ClimatePreset | None = None,
custom_preset: Optional[str] = None, custom_preset: str | None = None,
) -> None: ) -> None:
self._check_authenticated() self._check_authenticated()
@ -1239,10 +1229,10 @@ class APIClient:
async def siren_command( async def siren_command(
self, self,
key: int, key: int,
state: Optional[bool] = None, state: bool | None = None,
tone: Optional[str] = None, tone: str | None = None,
volume: Optional[float] = None, volume: float | None = None,
duration: Optional[int] = None, duration: int | None = None,
) -> None: ) -> None:
self._check_authenticated() self._check_authenticated()
@ -1275,7 +1265,7 @@ class APIClient:
self, self,
key: int, key: int,
command: LockCommand, command: LockCommand,
code: Optional[str] = None, code: str | None = None,
) -> None: ) -> None:
self._check_authenticated() self._check_authenticated()
@ -1291,9 +1281,9 @@ class APIClient:
self, self,
key: int, key: int,
*, *,
command: Optional[MediaPlayerCommand] = None, command: MediaPlayerCommand | None = None,
volume: Optional[float] = None, volume: float | None = None,
media_url: Optional[str] = None, media_url: str | None = None,
) -> None: ) -> None:
self._check_authenticated() self._check_authenticated()
@ -1365,14 +1355,14 @@ class APIClient:
await self._request_image(stream=True) await self._request_image(stream=True)
@property @property
def api_version(self) -> Optional[APIVersion]: def api_version(self) -> APIVersion | None:
if self._connection is None: if self._connection is None:
return None return None
return self._connection.api_version return self._connection.api_version
async def subscribe_voice_assistant( async def subscribe_voice_assistant(
self, self,
handle_start: Callable[[str, bool], Coroutine[Any, Any, Optional[int]]], handle_start: Callable[[str, bool], Coroutine[Any, Any, int | None]],
handle_stop: Callable[[], Coroutine[Any, Any, None]], handle_stop: Callable[[], Coroutine[Any, Any, None]],
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Subscribes to voice assistant messages from the device. """Subscribes to voice assistant messages from the device.
@ -1386,9 +1376,9 @@ class APIClient:
""" """
self._check_authenticated() self._check_authenticated()
start_task: Optional[asyncio.Task[Optional[int]]] = None start_task: asyncio.Task[int | None] | None = None
def _started(fut: asyncio.Task[Optional[int]]) -> None: def _started(fut: asyncio.Task[int | None]) -> None:
if self._connection is not None and not fut.cancelled(): if self._connection is not None and not fut.cancelled():
port = fut.result() port = fut.result()
if port is not None: if port is not None:
@ -1432,7 +1422,7 @@ class APIClient:
return unsub return unsub
def send_voice_assistant_event( def send_voice_assistant_event(
self, event_type: VoiceAssistantEventType, data: Optional[dict[str, str]] self, event_type: VoiceAssistantEventType, data: dict[str, str] | None
) -> None: ) -> None:
self._check_authenticated() self._check_authenticated()
@ -1457,7 +1447,7 @@ class APIClient:
self, self,
key: int, key: int,
command: AlarmControlPanelCommand, command: AlarmControlPanelCommand,
code: Optional[str] = None, code: str | None = None,
) -> None: ) -> None:
self._check_authenticated() self._check_authenticated()

View File

@ -1,24 +1,15 @@
from __future__ import annotations
import asyncio import asyncio
import contextvars import contextvars
import enum import enum
import logging import logging
import socket import socket
import time import time
from collections.abc import Coroutine, Iterable
from dataclasses import astuple, dataclass from dataclasses import astuple, dataclass
from functools import partial from functools import partial
from typing import ( from typing import TYPE_CHECKING, Any, Callable
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Dict,
Iterable,
List,
Optional,
Set,
Type,
Union,
)
import async_timeout import async_timeout
from google.protobuf import message from google.protobuf import message
@ -94,7 +85,7 @@ CONNECT_AND_SETUP_TIMEOUT = 120.0
DISCONNECT_WAIT_CONNECT_TIMEOUT = 5.0 DISCONNECT_WAIT_CONNECT_TIMEOUT = 5.0
in_do_connect: contextvars.ContextVar[Optional[bool]] = contextvars.ContextVar( in_do_connect: contextvars.ContextVar[bool | None] = contextvars.ContextVar(
"in_do_connect" "in_do_connect"
) )
@ -103,12 +94,12 @@ in_do_connect: contextvars.ContextVar[Optional[bool]] = contextvars.ContextVar(
class ConnectionParams: class ConnectionParams:
address: str address: str
port: int port: int
password: Optional[str] password: str | None
client_info: str client_info: str
keepalive: float keepalive: float
zeroconf_instance: hr.ZeroconfInstanceType zeroconf_instance: hr.ZeroconfInstanceType
noise_psk: Optional[str] noise_psk: str | None
expected_name: Optional[str] expected_name: str | None
class ConnectionState(enum.Enum): class ConnectionState(enum.Enum):
@ -162,16 +153,16 @@ class APIConnection:
self, self,
params: ConnectionParams, params: ConnectionParams,
on_stop: Callable[[bool], Coroutine[Any, Any, None]], on_stop: Callable[[bool], Coroutine[Any, Any, None]],
log_name: Optional[str] = None, log_name: str | None = None,
) -> None: ) -> None:
self._params = params self._params = params
self.on_stop: Optional[Callable[[bool], Coroutine[Any, Any, None]]] = on_stop self.on_stop: Callable[[bool], Coroutine[Any, Any, None]] | None = on_stop
self._on_stop_task: Optional[asyncio.Task[None]] = None self._on_stop_task: asyncio.Task[None] | None = None
self._socket: Optional[socket.socket] = None self._socket: socket.socket | None = None
self._frame_helper: Optional[ self._frame_helper: None | (
Union[APINoiseFrameHelper, APIPlaintextFrameHelper] APINoiseFrameHelper | APIPlaintextFrameHelper
] = None ) = None
self.api_version: Optional[APIVersion] = None self.api_version: APIVersion | None = None
self._connection_state = ConnectionState.INITIALIZED self._connection_state = ConnectionState.INITIALIZED
# Store whether connect() has completed # Store whether connect() has completed
@ -179,20 +170,20 @@ class APIConnection:
self._connect_complete = False self._connect_complete = False
# Message handlers currently subscribed to incoming messages # Message handlers currently subscribed to incoming messages
self._message_handlers: Dict[Any, Set[Callable[[message.Message], None]]] = {} self._message_handlers: dict[Any, set[Callable[[message.Message], None]]] = {}
# The friendly name to show for this connection in the logs # The friendly name to show for this connection in the logs
self.log_name = log_name or params.address self.log_name = log_name or params.address
# futures currently subscribed to exceptions in the read task # futures currently subscribed to exceptions in the read task
self._read_exception_futures: Set[asyncio.Future[None]] = set() self._read_exception_futures: set[asyncio.Future[None]] = set()
self._ping_timer: Optional[asyncio.TimerHandle] = None self._ping_timer: asyncio.TimerHandle | None = None
self._pong_timer: Optional[asyncio.TimerHandle] = None self._pong_timer: asyncio.TimerHandle | None = None
self._keep_alive_interval = params.keepalive self._keep_alive_interval = params.keepalive
self._keep_alive_timeout = params.keepalive * KEEP_ALIVE_TIMEOUT_RATIO self._keep_alive_timeout = params.keepalive * KEEP_ALIVE_TIMEOUT_RATIO
self._connect_task: Optional[asyncio.Task[None]] = None self._connect_task: asyncio.Task[None] | None = None
self._fatal_exception: Optional[Exception] = None self._fatal_exception: Exception | None = None
self._expected_disconnect = False self._expected_disconnect = False
self._send_pending_ping = False self._send_pending_ping = False
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
@ -318,7 +309,7 @@ class APIConnection:
async def _connect_init_frame_helper(self) -> None: async def _connect_init_frame_helper(self) -> None:
"""Step 3 in connect process: initialize the frame helper and init read loop.""" """Step 3 in connect process: initialize the frame helper and init read loop."""
fh: Union[APIPlaintextFrameHelper, APINoiseFrameHelper] fh: APIPlaintextFrameHelper | APINoiseFrameHelper
loop = self._loop loop = self._loop
process_packet = self._process_packet_factory() process_packet = self._process_packet_factory()
@ -563,7 +554,7 @@ class APIConnection:
raise raise
def add_message_callback( def add_message_callback(
self, on_message: Callable[[Any], None], msg_types: Iterable[Type[Any]] self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]]
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Add a message callback.""" """Add a message callback."""
message_handlers = self._message_handlers message_handlers = self._message_handlers
@ -572,7 +563,7 @@ class APIConnection:
return partial(self._remove_message_callback, on_message, msg_types) return partial(self._remove_message_callback, on_message, msg_types)
def _remove_message_callback( def _remove_message_callback(
self, on_message: Callable[[Any], None], msg_types: Iterable[Type[Any]] self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]]
) -> None: ) -> None:
"""Remove a message callback.""" """Remove a message callback."""
message_handlers = self._message_handlers message_handlers = self._message_handlers
@ -583,7 +574,7 @@ class APIConnection:
self, self,
send_msg: message.Message, send_msg: message.Message,
on_message: Callable[[Any], None], on_message: Callable[[Any], None],
msg_types: Iterable[Type[Any]], msg_types: Iterable[type[Any]],
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Send a message to the remote and register the given message handler.""" """Send a message to the remote and register the given message handler."""
self.send_message(send_msg) self.send_message(send_msg)
@ -604,9 +595,9 @@ class APIConnection:
def _handle_complex_message( def _handle_complex_message(
self, self,
fut: asyncio.Future[None], fut: asyncio.Future[None],
responses: List[message.Message], responses: list[message.Message],
do_append: Optional[Callable[[message.Message], bool]], do_append: Callable[[message.Message], bool] | None,
do_stop: Optional[Callable[[message.Message], bool]], do_stop: Callable[[message.Message], bool] | None,
resp: message.Message, resp: message.Message,
) -> None: ) -> None:
"""Handle a message that is part of a response.""" """Handle a message that is part of a response."""
@ -620,11 +611,11 @@ class APIConnection:
async def send_message_await_response_complex( async def send_message_await_response_complex(
self, self,
send_msg: message.Message, send_msg: message.Message,
do_append: Optional[Callable[[message.Message], bool]], do_append: Callable[[message.Message], bool] | None,
do_stop: Optional[Callable[[message.Message], bool]], do_stop: Callable[[message.Message], bool] | None,
msg_types: Iterable[Type[Any]], msg_types: Iterable[type[Any]],
timeout: float = 10.0, timeout: float = 10.0,
) -> List[message.Message]: ) -> list[message.Message]:
"""Send a message to the remote and build up a list response. """Send a message to the remote and build up a list response.
:param send_msg: The message (request) to send. :param send_msg: The message (request) to send.
@ -641,7 +632,7 @@ class APIConnection:
self.send_message(send_msg) self.send_message(send_msg)
# Unsafe to await between sending the message and registering the handler # Unsafe to await between sending the message and registering the handler
fut: asyncio.Future[None] = self._loop.create_future() fut: asyncio.Future[None] = self._loop.create_future()
responses: List[message.Message] = [] responses: list[message.Message] = []
on_message = partial( on_message = partial(
self._handle_complex_message, fut, responses, do_append, do_stop self._handle_complex_message, fut, responses, do_append, do_stop
) )

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import re import re
from typing import Optional
from aioesphomeapi.model import BluetoothGATTError from aioesphomeapi.model import BluetoothGATTError
@ -198,7 +199,7 @@ class BadNameAPIError(APIConnectionError):
class InvalidEncryptionKeyAPIError(HandshakeAPIError): class InvalidEncryptionKeyAPIError(HandshakeAPIError):
def __init__( def __init__(
self, msg: Optional[str] = None, received_name: Optional[str] = None self, msg: str | None = None, received_name: str | None = None
) -> None: ) -> None:
super().__init__(f"{msg}: received_name={received_name}") super().__init__(f"{msg}: received_name={received_name}")
self.received_name = received_name self.received_name = received_name

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import asyncio import asyncio
import contextlib import contextlib
import socket import socket
from dataclasses import dataclass from dataclasses import dataclass
from ipaddress import IPv4Address, IPv6Address from ipaddress import IPv4Address, IPv6Address
from typing import List, Optional, Tuple, Union, cast from typing import Union, cast
import zeroconf import zeroconf
import zeroconf.asyncio import zeroconf.asyncio
@ -45,7 +47,7 @@ async def _async_zeroconf_get_service_info(
service_type: str, service_type: str,
service_name: str, service_name: str,
timeout: float, timeout: float,
) -> Optional["zeroconf.ServiceInfo"]: ) -> "zeroconf.ServiceInfo" | None:
# Use or create zeroconf instance, ensure it's an AsyncZeroconf # Use or create zeroconf instance, ensure it's an AsyncZeroconf
if zeroconf_instance is None: if zeroconf_instance is None:
try: try:
@ -87,7 +89,7 @@ async def _async_resolve_host_zeroconf(
*, *,
timeout: float = 3.0, timeout: float = 3.0,
zeroconf_instance: ZeroconfInstanceType = None, zeroconf_instance: ZeroconfInstanceType = None,
) -> List[AddrInfo]: ) -> list[AddrInfo]:
service_type = "_esphomelib._tcp.local." service_type = "_esphomelib._tcp.local."
service_name = f"{host}.{service_type}" service_name = f"{host}.{service_type}"
@ -98,7 +100,7 @@ async def _async_resolve_host_zeroconf(
if info is None: if info is None:
return [] return []
addrs: List[AddrInfo] = [] addrs: list[AddrInfo] = []
for raw in info.addresses_by_version(zeroconf.IPVersion.All): for raw in info.addresses_by_version(zeroconf.IPVersion.All):
is_ipv6 = len(raw) == 16 is_ipv6 = len(raw) == 16
sockaddr: Sockaddr sockaddr: Sockaddr
@ -126,7 +128,7 @@ async def _async_resolve_host_zeroconf(
return addrs return addrs
async def _async_resolve_host_getaddrinfo(host: str, port: int) -> List[AddrInfo]: async def _async_resolve_host_getaddrinfo(host: str, port: int) -> list[AddrInfo]:
try: try:
# Limit to TCP IP protocol and SOCK_STREAM # Limit to TCP IP protocol and SOCK_STREAM
res = await asyncio.get_event_loop().getaddrinfo( res = await asyncio.get_event_loop().getaddrinfo(
@ -135,15 +137,15 @@ async def _async_resolve_host_getaddrinfo(host: str, port: int) -> List[AddrInfo
except OSError as err: except OSError as err:
raise APIConnectionError(f"Error resolving IP address: {err}") raise APIConnectionError(f"Error resolving IP address: {err}")
addrs: List[AddrInfo] = [] addrs: list[AddrInfo] = []
for family, type_, proto, _, raw in res: for family, type_, proto, _, raw in res:
sockaddr: Sockaddr sockaddr: Sockaddr
if family == socket.AF_INET: if family == socket.AF_INET:
raw = cast(Tuple[str, int], raw) raw = cast(tuple[str, int], raw)
address, port = raw address, port = raw
sockaddr = IPv4Sockaddr(address=address, port=port) sockaddr = IPv4Sockaddr(address=address, port=port)
elif family == socket.AF_INET6: elif family == socket.AF_INET6:
raw = cast(Tuple[str, int, int, int], raw) raw = cast(tuple[str, int, int, int], raw)
address, port, flowinfo, scope_id = raw address, port, flowinfo, scope_id = raw
sockaddr = IPv6Sockaddr( sockaddr = IPv6Sockaddr(
address=address, port=port, flowinfo=flowinfo, scope_id=scope_id address=address, port=port, flowinfo=flowinfo, scope_id=scope_id
@ -158,7 +160,7 @@ async def _async_resolve_host_getaddrinfo(host: str, port: int) -> List[AddrInfo
return addrs return addrs
def _async_ip_address_to_addrs(host: str, port: int) -> List[AddrInfo]: def _async_ip_address_to_addrs(host: str, port: int) -> list[AddrInfo]:
"""Convert an ipaddress to AddrInfo.""" """Convert an ipaddress to AddrInfo."""
with contextlib.suppress(ValueError): with contextlib.suppress(ValueError):
return [ return [
@ -193,7 +195,7 @@ async def async_resolve_host(
port: int, port: int,
zeroconf_instance: ZeroconfInstanceType = None, zeroconf_instance: ZeroconfInstanceType = None,
) -> AddrInfo: ) -> AddrInfo:
addrs: List[AddrInfo] = [] addrs: list[AddrInfo] = []
zc_error = None zc_error = None
if host.endswith(".local"): if host.endswith(".local"):

View File

@ -1,10 +1,11 @@
from __future__ import annotations
# Helper script and aioesphomeapi to view logs from an esphome device # Helper script and aioesphomeapi to view logs from an esphome device
import argparse import argparse
import asyncio import asyncio
import logging import logging
import sys import sys
from datetime import datetime from datetime import datetime
from typing import List
import zeroconf import zeroconf
@ -17,7 +18,7 @@ from aioesphomeapi.reconnect_logic import ReconnectLogic
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
async def main(argv: List[str]) -> None: async def main(argv: list[str]) -> None:
parser = argparse.ArgumentParser("aioesphomeapi-logs") parser = argparse.ArgumentParser("aioesphomeapi-logs")
parser.add_argument("--port", type=int, default=6053) parser.add_argument("--port", type=int, default=6053)
parser.add_argument("--password", type=str) parser.add_argument("--password", type=str)

View File

@ -1,20 +1,11 @@
from __future__ import annotations
import enum import enum
import sys import sys
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field, fields from dataclasses import asdict, dataclass, field, fields
from functools import cache, lru_cache, partial from functools import cache, lru_cache, partial
from typing import ( from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Type,
TypeVar,
Union,
cast,
)
from uuid import UUID from uuid import UUID
from .util import fix_float_single_double_conversion from .util import fix_float_single_double_conversion
@ -48,14 +39,14 @@ class APIIntEnum(enum.IntEnum):
"""Base class for int enum values in API model.""" """Base class for int enum values in API model."""
@classmethod @classmethod
def convert(cls: Type[_T], value: int) -> Optional[_T]: def convert(cls: type[_T], value: int) -> _T | None:
try: try:
return cls(value) return cls(value)
except ValueError: except ValueError:
return None return None
@classmethod @classmethod
def convert_list(cls: Type[_T], value: List[int]) -> List[_T]: def convert_list(cls: type[_T], value: list[int]) -> list[_T]:
ret = [] ret = []
for x in value: for x in value:
try: try:
@ -81,12 +72,12 @@ class APIModelBase:
# use this setattr to prevent FrozenInstanceError # use this setattr to prevent FrozenInstanceError
object.__setattr__(self, field_.name, convert(val)) object.__setattr__(self, field_.name, convert(val))
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) # type: ignore[no-any-return, call-overload] return asdict(self) # type: ignore[no-any-return, call-overload]
@classmethod @classmethod
def from_dict( def from_dict(
cls: Type[_V], data: Dict[str, Any], *, ignore_missing: bool = True cls: type[_V], data: dict[str, Any], *, ignore_missing: bool = True
) -> _V: ) -> _V:
init_args = { init_args = {
f.name: data[f.name] f.name: data[f.name]
@ -96,7 +87,7 @@ class APIModelBase:
return cls(**init_args) return cls(**init_args)
@classmethod @classmethod
def from_pb(cls: Type[_V], data: Any) -> _V: def from_pb(cls: type[_V], data: Any) -> _V:
return cls(**{f.name: getattr(data, f.name) for f in cached_fields(cls)}) # type: ignore[arg-type] return cls(**{f.name: getattr(data, f.name) for f in cached_fields(cls)}) # type: ignore[arg-type]
@ -174,7 +165,7 @@ class EntityInfo(APIModelBase):
unique_id: str = "" unique_id: str = ""
disabled_by_default: bool = False disabled_by_default: bool = False
icon: str = "" icon: str = ""
entity_category: Optional[EntityCategory] = converter_field( entity_category: EntityCategory | None = converter_field(
default=EntityCategory.NONE, converter=EntityCategory.convert default=EntityCategory.NONE, converter=EntityCategory.convert
) )
@ -226,7 +217,7 @@ class CoverOperation(APIIntEnum):
@_frozen_dataclass_decorator @_frozen_dataclass_decorator
class CoverState(EntityState): class CoverState(EntityState):
legacy_state: Optional[LegacyCoverState] = converter_field( legacy_state: LegacyCoverState | None = converter_field(
default=LegacyCoverState.OPEN, converter=LegacyCoverState.convert default=LegacyCoverState.OPEN, converter=LegacyCoverState.convert
) )
position: float = converter_field( position: float = converter_field(
@ -235,7 +226,7 @@ class CoverState(EntityState):
tilt: float = converter_field( tilt: float = converter_field(
default=0.0, converter=fix_float_single_double_conversion default=0.0, converter=fix_float_single_double_conversion
) )
current_operation: Optional[CoverOperation] = converter_field( current_operation: CoverOperation | None = converter_field(
default=CoverOperation.IDLE, converter=CoverOperation.convert default=CoverOperation.IDLE, converter=CoverOperation.convert
) )
@ -269,11 +260,11 @@ class FanDirection(APIIntEnum):
class FanState(EntityState): class FanState(EntityState):
state: bool = False state: bool = False
oscillating: bool = False oscillating: bool = False
speed: Optional[FanSpeed] = converter_field( speed: FanSpeed | None = converter_field(
default=FanSpeed.LOW, converter=FanSpeed.convert default=FanSpeed.LOW, converter=FanSpeed.convert
) )
speed_level: int = 0 speed_level: int = 0
direction: Optional[FanDirection] = converter_field( direction: FanDirection | None = converter_field(
default=FanDirection.FORWARD, converter=FanDirection.convert default=FanDirection.FORWARD, converter=FanDirection.convert
) )
@ -290,7 +281,7 @@ class LightColorCapability(enum.IntFlag):
@_frozen_dataclass_decorator @_frozen_dataclass_decorator
class LightInfo(EntityInfo): class LightInfo(EntityInfo):
supported_color_modes: List[int] = converter_field( supported_color_modes: list[int] = converter_field(
default_factory=list, converter=list default_factory=list, converter=list
) )
min_mireds: float = converter_field( min_mireds: float = converter_field(
@ -299,7 +290,7 @@ class LightInfo(EntityInfo):
max_mireds: float = converter_field( max_mireds: float = converter_field(
default=0.0, converter=fix_float_single_double_conversion default=0.0, converter=fix_float_single_double_conversion
) )
effects: List[str] = converter_field(default_factory=list, converter=list) effects: list[str] = converter_field(default_factory=list, converter=list)
# deprecated, do not use # deprecated, do not use
legacy_supports_brightness: bool = False legacy_supports_brightness: bool = False
@ -307,7 +298,7 @@ class LightInfo(EntityInfo):
legacy_supports_white_value: bool = False legacy_supports_white_value: bool = False
legacy_supports_color_temperature: bool = False legacy_supports_color_temperature: bool = False
def supported_color_modes_compat(self, api_version: APIVersion) -> List[int]: def supported_color_modes_compat(self, api_version: APIVersion) -> list[int]:
if api_version < APIVersion(1, 6): if api_version < APIVersion(1, 6):
key = ( key = (
self.legacy_supports_brightness, self.legacy_supports_brightness,
@ -353,7 +344,7 @@ class LightInfo(EntityInfo):
], ],
} }
return cast(List[int], modes_map[key]) if key in modes_map else [] return cast(list[int], modes_map[key]) if key in modes_map else []
return self.supported_color_modes return self.supported_color_modes
@ -412,10 +403,10 @@ class SensorInfo(EntityInfo):
unit_of_measurement: str = "" unit_of_measurement: str = ""
accuracy_decimals: int = 0 accuracy_decimals: int = 0
force_update: bool = False force_update: bool = False
state_class: Optional[SensorStateClass] = converter_field( state_class: SensorStateClass | None = converter_field(
default=SensorStateClass.NONE, converter=SensorStateClass.convert default=SensorStateClass.NONE, converter=SensorStateClass.convert
) )
last_reset_type: Optional[LastResetType] = converter_field( last_reset_type: LastResetType | None = converter_field(
default=LastResetType.NONE, converter=LastResetType.convert default=LastResetType.NONE, converter=LastResetType.convert
) )
@ -516,7 +507,7 @@ class ClimatePreset(APIIntEnum):
class ClimateInfo(EntityInfo): class ClimateInfo(EntityInfo):
supports_current_temperature: bool = False supports_current_temperature: bool = False
supports_two_point_target_temperature: bool = False supports_two_point_target_temperature: bool = False
supported_modes: List[ClimateMode] = converter_field( supported_modes: list[ClimateMode] = converter_field(
default_factory=list, converter=ClimateMode.convert_list default_factory=list, converter=ClimateMode.convert_list
) )
visual_min_temperature: float = converter_field( visual_min_temperature: float = converter_field(
@ -533,23 +524,23 @@ class ClimateInfo(EntityInfo):
) )
legacy_supports_away: bool = False legacy_supports_away: bool = False
supports_action: bool = False supports_action: bool = False
supported_fan_modes: List[ClimateFanMode] = converter_field( supported_fan_modes: list[ClimateFanMode] = converter_field(
default_factory=list, converter=ClimateFanMode.convert_list default_factory=list, converter=ClimateFanMode.convert_list
) )
supported_swing_modes: List[ClimateSwingMode] = converter_field( supported_swing_modes: list[ClimateSwingMode] = converter_field(
default_factory=list, converter=ClimateSwingMode.convert_list default_factory=list, converter=ClimateSwingMode.convert_list
) )
supported_custom_fan_modes: List[str] = converter_field( supported_custom_fan_modes: list[str] = converter_field(
default_factory=list, converter=list default_factory=list, converter=list
) )
supported_presets: List[ClimatePreset] = converter_field( supported_presets: list[ClimatePreset] = converter_field(
default_factory=list, converter=ClimatePreset.convert_list default_factory=list, converter=ClimatePreset.convert_list
) )
supported_custom_presets: List[str] = converter_field( supported_custom_presets: list[str] = converter_field(
default_factory=list, converter=list default_factory=list, converter=list
) )
def supported_presets_compat(self, api_version: APIVersion) -> List[ClimatePreset]: def supported_presets_compat(self, api_version: APIVersion) -> list[ClimatePreset]:
if api_version < APIVersion(1, 5): if api_version < APIVersion(1, 5):
return ( return (
[ClimatePreset.HOME, ClimatePreset.AWAY] [ClimatePreset.HOME, ClimatePreset.AWAY]
@ -561,10 +552,10 @@ class ClimateInfo(EntityInfo):
@_frozen_dataclass_decorator @_frozen_dataclass_decorator
class ClimateState(EntityState): class ClimateState(EntityState):
mode: Optional[ClimateMode] = converter_field( mode: ClimateMode | None = converter_field(
default=ClimateMode.OFF, converter=ClimateMode.convert default=ClimateMode.OFF, converter=ClimateMode.convert
) )
action: Optional[ClimateAction] = converter_field( action: ClimateAction | None = converter_field(
default=ClimateAction.OFF, converter=ClimateAction.convert default=ClimateAction.OFF, converter=ClimateAction.convert
) )
current_temperature: float = converter_field( current_temperature: float = converter_field(
@ -580,19 +571,19 @@ class ClimateState(EntityState):
default=0.0, converter=fix_float_single_double_conversion default=0.0, converter=fix_float_single_double_conversion
) )
legacy_away: bool = False legacy_away: bool = False
fan_mode: Optional[ClimateFanMode] = converter_field( fan_mode: ClimateFanMode | None = converter_field(
default=ClimateFanMode.ON, converter=ClimateFanMode.convert default=ClimateFanMode.ON, converter=ClimateFanMode.convert
) )
swing_mode: Optional[ClimateSwingMode] = converter_field( swing_mode: ClimateSwingMode | None = converter_field(
default=ClimateSwingMode.OFF, converter=ClimateSwingMode.convert default=ClimateSwingMode.OFF, converter=ClimateSwingMode.convert
) )
custom_fan_mode: str = "" custom_fan_mode: str = ""
preset: Optional[ClimatePreset] = converter_field( preset: ClimatePreset | None = converter_field(
default=ClimatePreset.NONE, converter=ClimatePreset.convert default=ClimatePreset.NONE, converter=ClimatePreset.convert
) )
custom_preset: str = "" custom_preset: str = ""
def preset_compat(self, api_version: APIVersion) -> Optional[ClimatePreset]: def preset_compat(self, api_version: APIVersion) -> ClimatePreset | None:
if api_version < APIVersion(1, 5): if api_version < APIVersion(1, 5):
return ClimatePreset.AWAY if self.legacy_away else ClimatePreset.HOME return ClimatePreset.AWAY if self.legacy_away else ClimatePreset.HOME
return self.preset return self.preset
@ -617,7 +608,7 @@ class NumberInfo(EntityInfo):
default=0.0, converter=fix_float_single_double_conversion default=0.0, converter=fix_float_single_double_conversion
) )
unit_of_measurement: str = "" unit_of_measurement: str = ""
mode: Optional[NumberMode] = converter_field( mode: NumberMode | None = converter_field(
default=NumberMode.AUTO, converter=NumberMode.convert default=NumberMode.AUTO, converter=NumberMode.convert
) )
device_class: str = "" device_class: str = ""
@ -634,7 +625,7 @@ class NumberState(EntityState):
# ==================== SELECT ==================== # ==================== SELECT ====================
@_frozen_dataclass_decorator @_frozen_dataclass_decorator
class SelectInfo(EntityInfo): class SelectInfo(EntityInfo):
options: List[str] = converter_field(default_factory=list, converter=list) options: list[str] = converter_field(default_factory=list, converter=list)
@_frozen_dataclass_decorator @_frozen_dataclass_decorator
@ -646,7 +637,7 @@ class SelectState(EntityState):
# ==================== SIREN ==================== # ==================== SIREN ====================
@_frozen_dataclass_decorator @_frozen_dataclass_decorator
class SirenInfo(EntityInfo): class SirenInfo(EntityInfo):
tones: List[str] = converter_field(default_factory=list, converter=list) tones: list[str] = converter_field(default_factory=list, converter=list)
supports_volume: bool = False supports_volume: bool = False
supports_duration: bool = False supports_duration: bool = False
@ -689,7 +680,7 @@ class LockInfo(EntityInfo):
@_frozen_dataclass_decorator @_frozen_dataclass_decorator
class LockEntityState(EntityState): class LockEntityState(EntityState):
state: Optional[LockState] = converter_field( state: LockState | None = converter_field(
default=LockState.NONE, converter=LockState.convert default=LockState.NONE, converter=LockState.convert
) )
@ -717,7 +708,7 @@ class MediaPlayerInfo(EntityInfo):
@_frozen_dataclass_decorator @_frozen_dataclass_decorator
class MediaPlayerEntityState(EntityState): class MediaPlayerEntityState(EntityState):
state: Optional[MediaPlayerState] = converter_field( state: MediaPlayerState | None = converter_field(
default=MediaPlayerState.NONE, converter=MediaPlayerState.convert default=MediaPlayerState.NONE, converter=MediaPlayerState.convert
) )
volume: float = converter_field( volume: float = converter_field(
@ -759,7 +750,7 @@ class AlarmControlPanelInfo(EntityInfo):
@_frozen_dataclass_decorator @_frozen_dataclass_decorator
class AlarmControlPanelEntityState(EntityState): class AlarmControlPanelEntityState(EntityState):
state: Optional[AlarmControlPanelState] = converter_field( state: AlarmControlPanelState | None = converter_field(
default=AlarmControlPanelState.DISARMED, default=AlarmControlPanelState.DISARMED,
converter=AlarmControlPanelState.convert, converter=AlarmControlPanelState.convert,
) )
@ -767,7 +758,7 @@ class AlarmControlPanelEntityState(EntityState):
# ==================== INFO MAP ==================== # ==================== INFO MAP ====================
COMPONENT_TYPE_TO_INFO: Dict[str, Type[EntityInfo]] = { COMPONENT_TYPE_TO_INFO: dict[str, type[EntityInfo]] = {
"binary_sensor": BinarySensorInfo, "binary_sensor": BinarySensorInfo,
"cover": CoverInfo, "cover": CoverInfo,
"fan": FanInfo, "fan": FanInfo,
@ -789,8 +780,8 @@ COMPONENT_TYPE_TO_INFO: Dict[str, Type[EntityInfo]] = {
# ==================== USER-DEFINED SERVICES ==================== # ==================== USER-DEFINED SERVICES ====================
def _convert_homeassistant_service_map( def _convert_homeassistant_service_map(
value: Union[Dict[str, str], Iterable["HomeassistantServiceMap"]], value: dict[str, str] | Iterable[HomeassistantServiceMap],
) -> Dict[str, str]: ) -> dict[str, str]:
if isinstance(value, dict): if isinstance(value, dict):
# already a dict, don't convert # already a dict, don't convert
return value return value
@ -801,13 +792,13 @@ def _convert_homeassistant_service_map(
class HomeassistantServiceCall(APIModelBase): class HomeassistantServiceCall(APIModelBase):
service: str = "" service: str = ""
is_event: bool = False is_event: bool = False
data: Dict[str, str] = converter_field( data: dict[str, str] = converter_field(
default_factory=dict, converter=_convert_homeassistant_service_map default_factory=dict, converter=_convert_homeassistant_service_map
) )
data_template: Dict[str, str] = converter_field( data_template: dict[str, str] = converter_field(
default_factory=dict, converter=_convert_homeassistant_service_map default_factory=dict, converter=_convert_homeassistant_service_map
) )
variables: Dict[str, str] = converter_field( variables: dict[str, str] = converter_field(
default_factory=dict, converter=_convert_homeassistant_service_map default_factory=dict, converter=_convert_homeassistant_service_map
) )
@ -826,12 +817,12 @@ class UserServiceArgType(APIIntEnum):
@_frozen_dataclass_decorator @_frozen_dataclass_decorator
class UserServiceArg(APIModelBase): class UserServiceArg(APIModelBase):
name: str = "" name: str = ""
type: Optional[UserServiceArgType] = converter_field( type: UserServiceArgType | None = converter_field(
default=UserServiceArgType.BOOL, converter=UserServiceArgType.convert default=UserServiceArgType.BOOL, converter=UserServiceArgType.convert
) )
@classmethod @classmethod
def convert_list(cls, value: List[Any]) -> List["UserServiceArg"]: def convert_list(cls, value: list[Any]) -> list[UserServiceArg]:
ret = [] ret = []
for x in value: for x in value:
if isinstance(x, dict): if isinstance(x, dict):
@ -847,7 +838,7 @@ class UserServiceArg(APIModelBase):
class UserService(APIModelBase): class UserService(APIModelBase):
name: str = "" name: str = ""
key: int = 0 key: int = 0
args: List[UserServiceArg] = converter_field( args: list[UserServiceArg] = converter_field(
default_factory=list, converter=UserServiceArg.convert_list default_factory=list, converter=UserServiceArg.convert_list
) )
@ -855,7 +846,7 @@ class UserService(APIModelBase):
# ==================== BLUETOOTH ==================== # ==================== BLUETOOTH ====================
def _join_split_uuid(value: List[int]) -> str: def _join_split_uuid(value: list[int]) -> str:
"""Convert a high/low uuid into a single string.""" """Convert a high/low uuid into a single string."""
return str(UUID(int=(value[0] << 64) | value[1])) return str(UUID(int=(value[0] << 64) | value[1]))
@ -877,14 +868,14 @@ class BluetoothLEAdvertisement:
rssi: int rssi: int
address_type: int address_type: int
name: str name: str
service_uuids: List[str] service_uuids: list[str]
service_data: Dict[str, bytes] service_data: dict[str, bytes]
manufacturer_data: Dict[int, bytes] manufacturer_data: dict[int, bytes]
@classmethod @classmethod
def from_pb( # type: ignore[misc] def from_pb( # type: ignore[misc]
cls: "BluetoothLEAdvertisement", data: "BluetoothLEAdvertisementResponse" cls: BluetoothLEAdvertisement, data: BluetoothLEAdvertisementResponse
) -> "BluetoothLEAdvertisement": ) -> BluetoothLEAdvertisement:
_uuid_convert = _cached_uuid_converter _uuid_convert = _cached_uuid_converter
if raw_manufacturer_data := data.manufacturer_data: if raw_manufacturer_data := data.manufacturer_data:
@ -937,12 +928,12 @@ class BluetoothLERawAdvertisement:
def make_ble_raw_advertisement_processor( def make_ble_raw_advertisement_processor(
on_advertisements: Callable[[List[BluetoothLERawAdvertisement]], None] on_advertisements: Callable[[list[BluetoothLERawAdvertisement]], None]
) -> Callable[["BluetoothLERawAdvertisementsResponse"], None]: ) -> Callable[[BluetoothLERawAdvertisementsResponse], None]:
"""Make a processor for BluetoothLERawAdvertisementResponse.""" """Make a processor for BluetoothLERawAdvertisementResponse."""
def _on_ble_raw_advertisement_response( def _on_ble_raw_advertisement_response(
data: "BluetoothLERawAdvertisementsResponse", data: BluetoothLERawAdvertisementsResponse,
) -> None: ) -> None:
on_advertisements( on_advertisements(
[ [
@ -999,7 +990,7 @@ class BluetoothGATTDescriptor(APIModelBase):
handle: int = 0 handle: int = 0
@classmethod @classmethod
def convert_list(cls, value: List[Any]) -> List["BluetoothGATTDescriptor"]: def convert_list(cls, value: list[Any]) -> list[BluetoothGATTDescriptor]:
ret = [] ret = []
for x in value: for x in value:
if isinstance(x, dict): if isinstance(x, dict):
@ -1015,12 +1006,12 @@ class BluetoothGATTCharacteristic(APIModelBase):
handle: int = 0 handle: int = 0
properties: int = 0 properties: int = 0
descriptors: List[BluetoothGATTDescriptor] = converter_field( descriptors: list[BluetoothGATTDescriptor] = converter_field(
default_factory=list, converter=BluetoothGATTDescriptor.convert_list default_factory=list, converter=BluetoothGATTDescriptor.convert_list
) )
@classmethod @classmethod
def convert_list(cls, value: List[Any]) -> List["BluetoothGATTCharacteristic"]: def convert_list(cls, value: list[Any]) -> list[BluetoothGATTCharacteristic]:
ret = [] ret = []
for x in value: for x in value:
if isinstance(x, dict): if isinstance(x, dict):
@ -1034,12 +1025,12 @@ class BluetoothGATTCharacteristic(APIModelBase):
class BluetoothGATTService(APIModelBase): class BluetoothGATTService(APIModelBase):
uuid: str = converter_field(default="", converter=_join_split_uuid) uuid: str = converter_field(default="", converter=_join_split_uuid)
handle: int = 0 handle: int = 0
characteristics: List[BluetoothGATTCharacteristic] = converter_field( characteristics: list[BluetoothGATTCharacteristic] = converter_field(
default_factory=list, converter=BluetoothGATTCharacteristic.convert_list default_factory=list, converter=BluetoothGATTCharacteristic.convert_list
) )
@classmethod @classmethod
def convert_list(cls, value: List[Any]) -> List["BluetoothGATTService"]: def convert_list(cls, value: list[Any]) -> list[BluetoothGATTService]:
ret = [] ret = []
for x in value: for x in value:
if isinstance(x, dict): if isinstance(x, dict):
@ -1052,7 +1043,7 @@ class BluetoothGATTService(APIModelBase):
@_frozen_dataclass_decorator @_frozen_dataclass_decorator
class BluetoothGATTServices(APIModelBase): class BluetoothGATTServices(APIModelBase):
address: int = 0 address: int = 0
services: List[BluetoothGATTService] = converter_field( services: list[BluetoothGATTService] = converter_field(
default_factory=list, converter=BluetoothGATTService.convert_list default_factory=list, converter=BluetoothGATTService.convert_list
) )
@ -1060,7 +1051,7 @@ class BluetoothGATTServices(APIModelBase):
@_frozen_dataclass_decorator @_frozen_dataclass_decorator
class ESPHomeBluetoothGATTServices: class ESPHomeBluetoothGATTServices:
address: int = 0 address: int = 0
services: List[BluetoothGATTService] = field(default_factory=list) services: list[BluetoothGATTService] = field(default_factory=list)
@_frozen_dataclass_decorator @_frozen_dataclass_decorator

View File

@ -1,6 +1,9 @@
from __future__ import annotations
import asyncio import asyncio
import logging import logging
from typing import Awaitable, Callable, List, Optional from collections.abc import Awaitable
from typing import Callable
import zeroconf import zeroconf
@ -34,9 +37,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
client: APIClient, client: APIClient,
on_connect: Callable[[], Awaitable[None]], on_connect: Callable[[], Awaitable[None]],
on_disconnect: Callable[[bool], Awaitable[None]], on_disconnect: Callable[[bool], Awaitable[None]],
zeroconf_instance: "zeroconf.Zeroconf", zeroconf_instance: zeroconf.Zeroconf,
name: Optional[str] = None, name: str | None = None,
on_connect_error: Optional[Callable[[Exception], Awaitable[None]]] = None, on_connect_error: Callable[[Exception], Awaitable[None]] | None = None,
) -> None: ) -> None:
"""Initialize ReconnectingLogic. """Initialize ReconnectingLogic.
@ -51,7 +54,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
self._on_disconnect_cb = on_disconnect self._on_disconnect_cb = on_disconnect
self._on_connect_error_cb = on_connect_error self._on_connect_error_cb = on_connect_error
self._zc = zeroconf_instance self._zc = zeroconf_instance
self._filter_alias: Optional[str] = None self._filter_alias: str | None = None
# Flag to check if the device is connected # Flag to check if the device is connected
self._connected = False self._connected = False
self._connected_lock = asyncio.Lock() self._connected_lock = asyncio.Lock()
@ -60,9 +63,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
# How many connect attempts have there been already, used for exponential wait time # How many connect attempts have there been already, used for exponential wait time
self._tries = 0 self._tries = 0
# Event for tracking when logic should stop # Event for tracking when logic should stop
self._connect_task: Optional[asyncio.Task[None]] = None self._connect_task: asyncio.Task[None] | None = None
self._connect_timer: Optional[asyncio.TimerHandle] = None self._connect_timer: asyncio.TimerHandle | None = None
self._stop_task: Optional[asyncio.Task[None]] = None self._stop_task: asyncio.Task[None] | None = None
@property @property
def _log_name(self) -> str: def _log_name(self) -> str:
@ -244,9 +247,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
def async_update_records( def async_update_records(
self, self,
zc: "zeroconf.Zeroconf", # pylint: disable=unused-argument zc: zeroconf.Zeroconf, # pylint: disable=unused-argument
now: float, # pylint: disable=unused-argument now: float, # pylint: disable=unused-argument
records: List["zeroconf.RecordUpdate"], records: list[zeroconf.RecordUpdate],
) -> None: ) -> None:
"""Listen to zeroconf updated mDNS records. This must be called from the eventloop. """Listen to zeroconf updated mDNS records. This must be called from the eventloop.

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import math import math
from functools import lru_cache from functools import lru_cache
from typing import Optional
@lru_cache(maxsize=1024) @lru_cache(maxsize=1024)
@ -8,7 +9,7 @@ def varuint_to_bytes(value: int) -> bytes:
if value <= 0x7F: if value <= 0x7F:
return bytes([value]) return bytes([value])
ret = bytes() ret = b""
while value: while value:
temp = value & 0x7F temp = value & 0x7F
value >>= 7 value >>= 7
@ -21,7 +22,7 @@ def varuint_to_bytes(value: int) -> bytes:
@lru_cache(maxsize=1024) @lru_cache(maxsize=1024)
def bytes_to_varuint(value: bytes) -> Optional[int]: def bytes_to_varuint(value: bytes) -> int | None:
result = 0 result = 0
bitpos = 0 bitpos = 0
for val in value: for val in value: