Upgrade python code to 3.9 (#496)

This commit is contained in:
J. Nick Koston 2023-07-19 15:33:28 -05:00 committed by GitHub
parent 3a50305cf8
commit d63b9bbf5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 232 additions and 246 deletions

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from .noise import APINoiseFrameHelper
from .plain_text import APIPlaintextFrameHelper

View File

@ -1,8 +1,10 @@
from __future__ import annotations
import asyncio
import logging
from abc import abstractmethod
from functools import partial
from typing import Callable, Optional, Union, cast
from typing import Callable, cast
from ..core import SocketClosedAPIError
@ -45,10 +47,8 @@ class APIFrameHelper(asyncio.Protocol):
"""Initialize the API frame helper."""
self._on_pkt = on_pkt
self._on_error = on_error
self._transport: Optional[asyncio.Transport] = None
self._writer: Optional[
Callable[[Union[bytes, bytearray, memoryview]], None]
] = None
self._transport: asyncio.Transport | None = None
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
self._connected_event = asyncio.Event()
self._buffer = bytearray()
self._buffer_len = 0
@ -57,7 +57,7 @@ class APIFrameHelper(asyncio.Protocol):
self._log_name = log_name
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
def _read_exactly(self, length: int) -> Optional[bytearray]:
def _read_exactly(self, length: int) -> bytearray | None:
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
original_pos = self._pos
new_pos = original_pos + length
@ -87,13 +87,13 @@ class APIFrameHelper(asyncio.Protocol):
def _handle_error(self, exc: Exception) -> None:
self._on_error(exc)
def connection_lost(self, exc: Optional[Exception]) -> None:
def connection_lost(self, exc: Exception | None) -> None:
self._handle_error(
exc or SocketClosedAPIError(f"{self._log_name}: Connection lost")
)
return super().connection_lost(exc)
def eof_received(self) -> Optional[bool]:
def eof_received(self) -> bool | None:
self._handle_error(SocketClosedAPIError(f"{self._log_name}: EOF received"))
return super().eof_received()

View File

@ -1,10 +1,12 @@
from __future__ import annotations
import asyncio
import base64
import logging
from enum import Enum
from functools import partial
from struct import Struct
from typing import TYPE_CHECKING, Any, Callable, Optional, Type
from typing import TYPE_CHECKING, Any, Callable
from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable
from cryptography.exceptions import InvalidTag
@ -34,7 +36,7 @@ class ChaCha20CipherReuseable(ChaCha20Cipher): # type: ignore[misc]
format_nonce = PACK_NONCE
@property
def klass(self) -> Type[ChaCha20Poly1305Reusable]:
def klass(self) -> type[ChaCha20Poly1305Reusable]:
return ChaCha20Poly1305Reusable
@ -81,7 +83,7 @@ class APINoiseFrameHelper(APIFrameHelper):
on_pkt: Callable[[int, bytes], None],
on_error: Callable[[Exception], None],
noise_psk: str,
expected_name: Optional[str],
expected_name: str | None,
client_info: str,
log_name: str,
) -> None:
@ -93,9 +95,9 @@ class APINoiseFrameHelper(APIFrameHelper):
self._noise_psk = noise_psk
self._expected_name = expected_name
self._set_state(NoiseConnectionState.HELLO)
self._server_name: Optional[str] = None
self._decrypt: Optional[Callable[[bytes], bytes]] = None
self._encrypt: Optional[Callable[[bytes], bytes]] = None
self._server_name: str | None = None
self._decrypt: Callable[[bytes], bytes] | None = None
self._encrypt: Callable[[bytes], bytes] | None = None
self._setup_proto()
self._is_ready = False

View File

@ -1,5 +1,7 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
from ..core import ProtocolAPIError, RequiresEncryptionAPIError, SocketAPIError
from ..util import bytes_to_varuint, varuint_to_bytes
@ -45,8 +47,8 @@ class APIPlaintextFrameHelper(APIFrameHelper):
init_bytes = self._read_exactly(3)
if init_bytes is None:
return
msg_type_int: Optional[int] = None
length_int: Optional[int] = None
msg_type_int: int | None = None
length_int: int | None = None
preamble, length_high, maybe_msg_type = init_bytes
if preamble != 0x00:
if preamble == 0x01:

View File

@ -1,20 +1,10 @@
from __future__ import annotations
import asyncio
import logging
from collections.abc import Awaitable, Coroutine
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Coroutine,
Dict,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
from typing import TYPE_CHECKING, Any, Callable, Union, cast
from google.protobuf import message
@ -200,8 +190,8 @@ DEFAULT_BLE_DISCONNECT_TIMEOUT = 5.0
# connection is poor.
KEEP_ALIVE_FREQUENCY = 20.0
ExecuteServiceDataType = Dict[
str, Union[bool, int, float, str, List[bool], List[int], List[float], List[str]]
ExecuteServiceDataType = dict[
str, Union[bool, int, float, str, list[bool], list[int], list[float], list[str]]
]
@ -213,13 +203,13 @@ class APIClient:
self,
address: str,
port: int,
password: Optional[str],
password: str | None,
*,
client_info: str = "aioesphomeapi",
keepalive: float = KEEP_ALIVE_FREQUENCY,
zeroconf_instance: ZeroconfInstanceType = None,
noise_psk: Optional[str] = None,
expected_name: Optional[str] = None,
noise_psk: str | None = None,
expected_name: str | None = None,
):
"""Create a client, this object is shared across sessions.
@ -247,17 +237,17 @@ class APIClient:
noise_psk=noise_psk or None,
expected_name=expected_name,
)
self._connection: Optional[APIConnection] = None
self._cached_name: Optional[str] = None
self._connection: APIConnection | None = None
self._cached_name: str | None = None
self._background_tasks: set[asyncio.Task[Any]] = set()
self._loop = asyncio.get_event_loop()
@property
def expected_name(self) -> Optional[str]:
def expected_name(self) -> str | None:
return self._params.expected_name
@expected_name.setter
def expected_name(self, value: Optional[str]) -> None:
def expected_name(self, value: str | None) -> None:
self._params.expected_name = value
@property
@ -277,7 +267,7 @@ class APIClient:
async def connect(
self,
on_stop: Optional[Callable[[bool], Awaitable[None]]] = None,
on_stop: Callable[[bool], Awaitable[None]] | None = None,
login: bool = False,
) -> None:
if self._connection is not None:
@ -343,9 +333,9 @@ class APIClient:
async def list_entities_services(
self,
) -> Tuple[List[EntityInfo], List[UserService]]:
) -> tuple[list[EntityInfo], list[UserService]]:
self._check_authenticated()
response_types: Dict[Any, Optional[Type[EntityInfo]]] = {
response_types: dict[Any, type[EntityInfo] | None] = {
ListEntitiesBinarySensorResponse: BinarySensorInfo,
ListEntitiesButtonResponse: ButtonInfo,
ListEntitiesCoverResponse: CoverInfo,
@ -376,8 +366,8 @@ class APIClient:
resp = await self._connection.send_message_await_response_complex(
ListEntitiesRequest(), do_append, do_stop, msg_types, timeout=60
)
entities: List[EntityInfo] = []
services: List[UserService] = []
entities: list[EntityInfo] = []
services: list[UserService] = []
for msg in resp:
if isinstance(msg, ListEntitiesServicesResponse):
services.append(UserService.from_pb(msg))
@ -389,8 +379,8 @@ class APIClient:
async def subscribe_states(self, on_state: Callable[[EntityState], None]) -> None:
self._check_authenticated()
image_stream: Dict[int, list[bytes]] = {}
response_types: Dict[Any, Type[EntityState]] = {
image_stream: dict[int, list[bytes]] = {}
response_types: dict[Any, type[EntityState]] = {
BinarySensorStateResponse: BinarySensorState,
CoverStateResponse: CoverState,
FanStateResponse: FanState,
@ -417,7 +407,7 @@ class APIClient:
if TYPE_CHECKING:
assert isinstance(msg, CameraImageResponse)
msg_key = msg.key
data_parts: Optional[List[bytes]] = image_stream.get(msg_key)
data_parts: list[bytes] | None = image_stream.get(msg_key)
if not data_parts:
data_parts = []
image_stream[msg_key] = data_parts
@ -425,7 +415,7 @@ class APIClient:
data_parts.append(msg.data)
if msg.done:
# Return CameraState with the merged data
image_data = bytes().join(data_parts)
image_data = b"".join(data_parts)
del image_stream[msg_key]
on_state(CameraState(key=msg.key, data=image_data)) # type: ignore[call-arg]
@ -437,8 +427,8 @@ class APIClient:
async def subscribe_logs(
self,
on_log: Callable[[SubscribeLogsResponse], None],
log_level: Optional[LogLevel] = None,
dump_config: Optional[bool] = None,
log_level: LogLevel | None = None,
dump_config: bool | None = None,
) -> None:
self._check_authenticated()
req = SubscribeLogsRequest()
@ -492,11 +482,11 @@ class APIClient:
address: int,
handle: int,
request: message.Message,
response_type: Union[
Type[BluetoothGATTNotifyResponse],
Type[BluetoothGATTReadResponse],
Type[BluetoothGATTWriteResponse],
],
response_type: (
type[BluetoothGATTNotifyResponse]
| type[BluetoothGATTReadResponse]
| type[BluetoothGATTWriteResponse]
),
timeout: float = 10.0,
) -> message.Message:
self._check_authenticated()
@ -541,7 +531,7 @@ class APIClient:
return unsub
async def subscribe_bluetooth_le_raw_advertisements(
self, on_advertisements: Callable[[List[BluetoothLERawAdvertisement]], None]
self, on_advertisements: Callable[[list[BluetoothLERawAdvertisement]], None]
) -> Callable[[], None]:
self._check_authenticated()
msg_types = (BluetoothLERawAdvertisementsResponse,)
@ -615,7 +605,7 @@ class APIClient:
disconnect_timeout: float = DEFAULT_BLE_DISCONNECT_TIMEOUT,
feature_flags: int = 0,
has_cache: bool = False,
address_type: Optional[int] = None,
address_type: int | None = None,
) -> Callable[[], None]:
self._check_authenticated()
msg_types = (BluetoothDeviceConnectionResponse,)
@ -955,7 +945,7 @@ class APIClient:
address: int,
handle: int,
on_bluetooth_gatt_notify: Callable[[int, bytearray], None],
) -> Tuple[Callable[[], Coroutine[Any, Any, None]], Callable[[], None]]:
) -> tuple[Callable[[], Coroutine[Any, Any, None]], Callable[[], None]]:
"""Start a notify session for a GATT characteristic.
Returns two functions that can be used to stop the notify.
@ -1001,7 +991,7 @@ class APIClient:
return stop_notify, remove_callback
async def subscribe_home_assistant_states(
self, on_state_sub: Callable[[str, Optional[str]], None]
self, on_state_sub: Callable[[str, str | None], None]
) -> None:
self._check_authenticated()
@ -1018,7 +1008,7 @@ class APIClient:
)
async def send_home_assistant_state(
self, entity_id: str, attribute: Optional[str], state: str
self, entity_id: str, attribute: str | None, state: str
) -> None:
self._check_authenticated()
@ -1034,8 +1024,8 @@ class APIClient:
async def cover_command(
self,
key: int,
position: Optional[float] = None,
tilt: Optional[float] = None,
position: float | None = None,
tilt: float | None = None,
stop: bool = False,
) -> None:
self._check_authenticated()
@ -1068,11 +1058,11 @@ class APIClient:
async def fan_command(
self,
key: int,
state: Optional[bool] = None,
speed: Optional[FanSpeed] = None,
speed_level: Optional[int] = None,
oscillating: Optional[bool] = None,
direction: Optional[FanDirection] = None,
state: bool | None = None,
speed: FanSpeed | None = None,
speed_level: int | None = None,
oscillating: bool | None = None,
direction: FanDirection | None = None,
) -> None:
self._check_authenticated()
@ -1099,18 +1089,18 @@ class APIClient:
async def light_command(
self,
key: int,
state: Optional[bool] = None,
brightness: Optional[float] = None,
color_mode: Optional[int] = None,
color_brightness: Optional[float] = None,
rgb: Optional[Tuple[float, float, float]] = None,
white: Optional[float] = None,
color_temperature: Optional[float] = None,
cold_white: Optional[float] = None,
warm_white: Optional[float] = None,
transition_length: Optional[float] = None,
flash_length: Optional[float] = None,
effect: Optional[str] = None,
state: bool | None = None,
brightness: float | None = None,
color_mode: int | None = None,
color_brightness: float | None = None,
rgb: tuple[float, float, float] | None = None,
white: float | None = None,
color_temperature: float | None = None,
cold_white: float | None = None,
warm_white: float | None = None,
transition_length: float | None = None,
flash_length: float | None = None,
effect: str | None = None,
) -> None:
self._check_authenticated()
@ -1169,15 +1159,15 @@ class APIClient:
async def climate_command(
self,
key: int,
mode: Optional[ClimateMode] = None,
target_temperature: Optional[float] = None,
target_temperature_low: Optional[float] = None,
target_temperature_high: Optional[float] = None,
fan_mode: Optional[ClimateFanMode] = None,
swing_mode: Optional[ClimateSwingMode] = None,
custom_fan_mode: Optional[str] = None,
preset: Optional[ClimatePreset] = None,
custom_preset: Optional[str] = None,
mode: ClimateMode | None = None,
target_temperature: float | None = None,
target_temperature_low: float | None = None,
target_temperature_high: float | None = None,
fan_mode: ClimateFanMode | None = None,
swing_mode: ClimateSwingMode | None = None,
custom_fan_mode: str | None = None,
preset: ClimatePreset | None = None,
custom_preset: str | None = None,
) -> None:
self._check_authenticated()
@ -1239,10 +1229,10 @@ class APIClient:
async def siren_command(
self,
key: int,
state: Optional[bool] = None,
tone: Optional[str] = None,
volume: Optional[float] = None,
duration: Optional[int] = None,
state: bool | None = None,
tone: str | None = None,
volume: float | None = None,
duration: int | None = None,
) -> None:
self._check_authenticated()
@ -1275,7 +1265,7 @@ class APIClient:
self,
key: int,
command: LockCommand,
code: Optional[str] = None,
code: str | None = None,
) -> None:
self._check_authenticated()
@ -1291,9 +1281,9 @@ class APIClient:
self,
key: int,
*,
command: Optional[MediaPlayerCommand] = None,
volume: Optional[float] = None,
media_url: Optional[str] = None,
command: MediaPlayerCommand | None = None,
volume: float | None = None,
media_url: str | None = None,
) -> None:
self._check_authenticated()
@ -1365,14 +1355,14 @@ class APIClient:
await self._request_image(stream=True)
@property
def api_version(self) -> Optional[APIVersion]:
def api_version(self) -> APIVersion | None:
if self._connection is None:
return None
return self._connection.api_version
async def subscribe_voice_assistant(
self,
handle_start: Callable[[str, bool], Coroutine[Any, Any, Optional[int]]],
handle_start: Callable[[str, bool], Coroutine[Any, Any, int | None]],
handle_stop: Callable[[], Coroutine[Any, Any, None]],
) -> Callable[[], None]:
"""Subscribes to voice assistant messages from the device.
@ -1386,9 +1376,9 @@ class APIClient:
"""
self._check_authenticated()
start_task: Optional[asyncio.Task[Optional[int]]] = None
start_task: asyncio.Task[int | None] | None = None
def _started(fut: asyncio.Task[Optional[int]]) -> None:
def _started(fut: asyncio.Task[int | None]) -> None:
if self._connection is not None and not fut.cancelled():
port = fut.result()
if port is not None:
@ -1432,7 +1422,7 @@ class APIClient:
return unsub
def send_voice_assistant_event(
self, event_type: VoiceAssistantEventType, data: Optional[dict[str, str]]
self, event_type: VoiceAssistantEventType, data: dict[str, str] | None
) -> None:
self._check_authenticated()
@ -1457,7 +1447,7 @@ class APIClient:
self,
key: int,
command: AlarmControlPanelCommand,
code: Optional[str] = None,
code: str | None = None,
) -> None:
self._check_authenticated()

View File

@ -1,24 +1,15 @@
from __future__ import annotations
import asyncio
import contextvars
import enum
import logging
import socket
import time
from collections.abc import Coroutine, Iterable
from dataclasses import astuple, dataclass
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Dict,
Iterable,
List,
Optional,
Set,
Type,
Union,
)
from typing import TYPE_CHECKING, Any, Callable
import async_timeout
from google.protobuf import message
@ -94,7 +85,7 @@ CONNECT_AND_SETUP_TIMEOUT = 120.0
DISCONNECT_WAIT_CONNECT_TIMEOUT = 5.0
in_do_connect: contextvars.ContextVar[Optional[bool]] = contextvars.ContextVar(
in_do_connect: contextvars.ContextVar[bool | None] = contextvars.ContextVar(
"in_do_connect"
)
@ -103,12 +94,12 @@ in_do_connect: contextvars.ContextVar[Optional[bool]] = contextvars.ContextVar(
class ConnectionParams:
address: str
port: int
password: Optional[str]
password: str | None
client_info: str
keepalive: float
zeroconf_instance: hr.ZeroconfInstanceType
noise_psk: Optional[str]
expected_name: Optional[str]
noise_psk: str | None
expected_name: str | None
class ConnectionState(enum.Enum):
@ -162,16 +153,16 @@ class APIConnection:
self,
params: ConnectionParams,
on_stop: Callable[[bool], Coroutine[Any, Any, None]],
log_name: Optional[str] = None,
log_name: str | None = None,
) -> None:
self._params = params
self.on_stop: Optional[Callable[[bool], Coroutine[Any, Any, None]]] = on_stop
self._on_stop_task: Optional[asyncio.Task[None]] = None
self._socket: Optional[socket.socket] = None
self._frame_helper: Optional[
Union[APINoiseFrameHelper, APIPlaintextFrameHelper]
] = None
self.api_version: Optional[APIVersion] = None
self.on_stop: Callable[[bool], Coroutine[Any, Any, None]] | None = on_stop
self._on_stop_task: asyncio.Task[None] | None = None
self._socket: socket.socket | None = None
self._frame_helper: None | (
APINoiseFrameHelper | APIPlaintextFrameHelper
) = None
self.api_version: APIVersion | None = None
self._connection_state = ConnectionState.INITIALIZED
# Store whether connect() has completed
@ -179,20 +170,20 @@ class APIConnection:
self._connect_complete = False
# Message handlers currently subscribed to incoming messages
self._message_handlers: Dict[Any, Set[Callable[[message.Message], None]]] = {}
self._message_handlers: dict[Any, set[Callable[[message.Message], None]]] = {}
# The friendly name to show for this connection in the logs
self.log_name = log_name or params.address
# futures currently subscribed to exceptions in the read task
self._read_exception_futures: Set[asyncio.Future[None]] = set()
self._read_exception_futures: set[asyncio.Future[None]] = set()
self._ping_timer: Optional[asyncio.TimerHandle] = None
self._pong_timer: Optional[asyncio.TimerHandle] = None
self._ping_timer: asyncio.TimerHandle | None = None
self._pong_timer: asyncio.TimerHandle | None = None
self._keep_alive_interval = params.keepalive
self._keep_alive_timeout = params.keepalive * KEEP_ALIVE_TIMEOUT_RATIO
self._connect_task: Optional[asyncio.Task[None]] = None
self._fatal_exception: Optional[Exception] = None
self._connect_task: asyncio.Task[None] | None = None
self._fatal_exception: Exception | None = None
self._expected_disconnect = False
self._send_pending_ping = False
self._loop = asyncio.get_event_loop()
@ -318,7 +309,7 @@ class APIConnection:
async def _connect_init_frame_helper(self) -> None:
"""Step 3 in connect process: initialize the frame helper and init read loop."""
fh: Union[APIPlaintextFrameHelper, APINoiseFrameHelper]
fh: APIPlaintextFrameHelper | APINoiseFrameHelper
loop = self._loop
process_packet = self._process_packet_factory()
@ -563,7 +554,7 @@ class APIConnection:
raise
def add_message_callback(
self, on_message: Callable[[Any], None], msg_types: Iterable[Type[Any]]
self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]]
) -> Callable[[], None]:
"""Add a message callback."""
message_handlers = self._message_handlers
@ -572,7 +563,7 @@ class APIConnection:
return partial(self._remove_message_callback, on_message, msg_types)
def _remove_message_callback(
self, on_message: Callable[[Any], None], msg_types: Iterable[Type[Any]]
self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]]
) -> None:
"""Remove a message callback."""
message_handlers = self._message_handlers
@ -583,7 +574,7 @@ class APIConnection:
self,
send_msg: message.Message,
on_message: Callable[[Any], None],
msg_types: Iterable[Type[Any]],
msg_types: Iterable[type[Any]],
) -> Callable[[], None]:
"""Send a message to the remote and register the given message handler."""
self.send_message(send_msg)
@ -604,9 +595,9 @@ class APIConnection:
def _handle_complex_message(
self,
fut: asyncio.Future[None],
responses: List[message.Message],
do_append: Optional[Callable[[message.Message], bool]],
do_stop: Optional[Callable[[message.Message], bool]],
responses: list[message.Message],
do_append: Callable[[message.Message], bool] | None,
do_stop: Callable[[message.Message], bool] | None,
resp: message.Message,
) -> None:
"""Handle a message that is part of a response."""
@ -620,11 +611,11 @@ class APIConnection:
async def send_message_await_response_complex(
self,
send_msg: message.Message,
do_append: Optional[Callable[[message.Message], bool]],
do_stop: Optional[Callable[[message.Message], bool]],
msg_types: Iterable[Type[Any]],
do_append: Callable[[message.Message], bool] | None,
do_stop: Callable[[message.Message], bool] | None,
msg_types: Iterable[type[Any]],
timeout: float = 10.0,
) -> List[message.Message]:
) -> list[message.Message]:
"""Send a message to the remote and build up a list response.
:param send_msg: The message (request) to send.
@ -641,7 +632,7 @@ class APIConnection:
self.send_message(send_msg)
# Unsafe to await between sending the message and registering the handler
fut: asyncio.Future[None] = self._loop.create_future()
responses: List[message.Message] = []
responses: list[message.Message] = []
on_message = partial(
self._handle_complex_message, fut, responses, do_append, do_stop
)

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import re
from typing import Optional
from aioesphomeapi.model import BluetoothGATTError
@ -198,7 +199,7 @@ class BadNameAPIError(APIConnectionError):
class InvalidEncryptionKeyAPIError(HandshakeAPIError):
def __init__(
self, msg: Optional[str] = None, received_name: Optional[str] = None
self, msg: str | None = None, received_name: str | None = None
) -> None:
super().__init__(f"{msg}: received_name={received_name}")
self.received_name = received_name

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import asyncio
import contextlib
import socket
from dataclasses import dataclass
from ipaddress import IPv4Address, IPv6Address
from typing import List, Optional, Tuple, Union, cast
from typing import Union, cast
import zeroconf
import zeroconf.asyncio
@ -45,7 +47,7 @@ async def _async_zeroconf_get_service_info(
service_type: str,
service_name: str,
timeout: float,
) -> Optional["zeroconf.ServiceInfo"]:
) -> "zeroconf.ServiceInfo" | None:
# Use or create zeroconf instance, ensure it's an AsyncZeroconf
if zeroconf_instance is None:
try:
@ -87,7 +89,7 @@ async def _async_resolve_host_zeroconf(
*,
timeout: float = 3.0,
zeroconf_instance: ZeroconfInstanceType = None,
) -> List[AddrInfo]:
) -> list[AddrInfo]:
service_type = "_esphomelib._tcp.local."
service_name = f"{host}.{service_type}"
@ -98,7 +100,7 @@ async def _async_resolve_host_zeroconf(
if info is None:
return []
addrs: List[AddrInfo] = []
addrs: list[AddrInfo] = []
for raw in info.addresses_by_version(zeroconf.IPVersion.All):
is_ipv6 = len(raw) == 16
sockaddr: Sockaddr
@ -126,7 +128,7 @@ async def _async_resolve_host_zeroconf(
return addrs
async def _async_resolve_host_getaddrinfo(host: str, port: int) -> List[AddrInfo]:
async def _async_resolve_host_getaddrinfo(host: str, port: int) -> list[AddrInfo]:
try:
# Limit to TCP IP protocol and SOCK_STREAM
res = await asyncio.get_event_loop().getaddrinfo(
@ -135,15 +137,15 @@ async def _async_resolve_host_getaddrinfo(host: str, port: int) -> List[AddrInfo
except OSError as err:
raise APIConnectionError(f"Error resolving IP address: {err}")
addrs: List[AddrInfo] = []
addrs: list[AddrInfo] = []
for family, type_, proto, _, raw in res:
sockaddr: Sockaddr
if family == socket.AF_INET:
raw = cast(Tuple[str, int], raw)
raw = cast(tuple[str, int], raw)
address, port = raw
sockaddr = IPv4Sockaddr(address=address, port=port)
elif family == socket.AF_INET6:
raw = cast(Tuple[str, int, int, int], raw)
raw = cast(tuple[str, int, int, int], raw)
address, port, flowinfo, scope_id = raw
sockaddr = IPv6Sockaddr(
address=address, port=port, flowinfo=flowinfo, scope_id=scope_id
@ -158,7 +160,7 @@ async def _async_resolve_host_getaddrinfo(host: str, port: int) -> List[AddrInfo
return addrs
def _async_ip_address_to_addrs(host: str, port: int) -> List[AddrInfo]:
def _async_ip_address_to_addrs(host: str, port: int) -> list[AddrInfo]:
"""Convert an ipaddress to AddrInfo."""
with contextlib.suppress(ValueError):
return [
@ -193,7 +195,7 @@ async def async_resolve_host(
port: int,
zeroconf_instance: ZeroconfInstanceType = None,
) -> AddrInfo:
addrs: List[AddrInfo] = []
addrs: list[AddrInfo] = []
zc_error = None
if host.endswith(".local"):

View File

@ -1,10 +1,11 @@
from __future__ import annotations
# Helper script and aioesphomeapi to view logs from an esphome device
import argparse
import asyncio
import logging
import sys
from datetime import datetime
from typing import List
import zeroconf
@ -17,7 +18,7 @@ from aioesphomeapi.reconnect_logic import ReconnectLogic
_LOGGER = logging.getLogger(__name__)
async def main(argv: List[str]) -> None:
async def main(argv: list[str]) -> None:
parser = argparse.ArgumentParser("aioesphomeapi-logs")
parser.add_argument("--port", type=int, default=6053)
parser.add_argument("--password", type=str)

View File

@ -1,20 +1,11 @@
from __future__ import annotations
import enum
import sys
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field, fields
from functools import cache, lru_cache, partial
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Type,
TypeVar,
Union,
cast,
)
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
from uuid import UUID
from .util import fix_float_single_double_conversion
@ -48,14 +39,14 @@ class APIIntEnum(enum.IntEnum):
"""Base class for int enum values in API model."""
@classmethod
def convert(cls: Type[_T], value: int) -> Optional[_T]:
def convert(cls: type[_T], value: int) -> _T | None:
try:
return cls(value)
except ValueError:
return None
@classmethod
def convert_list(cls: Type[_T], value: List[int]) -> List[_T]:
def convert_list(cls: type[_T], value: list[int]) -> list[_T]:
ret = []
for x in value:
try:
@ -81,12 +72,12 @@ class APIModelBase:
# use this setattr to prevent FrozenInstanceError
object.__setattr__(self, field_.name, convert(val))
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
return asdict(self) # type: ignore[no-any-return, call-overload]
@classmethod
def from_dict(
cls: Type[_V], data: Dict[str, Any], *, ignore_missing: bool = True
cls: type[_V], data: dict[str, Any], *, ignore_missing: bool = True
) -> _V:
init_args = {
f.name: data[f.name]
@ -96,7 +87,7 @@ class APIModelBase:
return cls(**init_args)
@classmethod
def from_pb(cls: Type[_V], data: Any) -> _V:
def from_pb(cls: type[_V], data: Any) -> _V:
return cls(**{f.name: getattr(data, f.name) for f in cached_fields(cls)}) # type: ignore[arg-type]
@ -174,7 +165,7 @@ class EntityInfo(APIModelBase):
unique_id: str = ""
disabled_by_default: bool = False
icon: str = ""
entity_category: Optional[EntityCategory] = converter_field(
entity_category: EntityCategory | None = converter_field(
default=EntityCategory.NONE, converter=EntityCategory.convert
)
@ -226,7 +217,7 @@ class CoverOperation(APIIntEnum):
@_frozen_dataclass_decorator
class CoverState(EntityState):
legacy_state: Optional[LegacyCoverState] = converter_field(
legacy_state: LegacyCoverState | None = converter_field(
default=LegacyCoverState.OPEN, converter=LegacyCoverState.convert
)
position: float = converter_field(
@ -235,7 +226,7 @@ class CoverState(EntityState):
tilt: float = converter_field(
default=0.0, converter=fix_float_single_double_conversion
)
current_operation: Optional[CoverOperation] = converter_field(
current_operation: CoverOperation | None = converter_field(
default=CoverOperation.IDLE, converter=CoverOperation.convert
)
@ -269,11 +260,11 @@ class FanDirection(APIIntEnum):
class FanState(EntityState):
state: bool = False
oscillating: bool = False
speed: Optional[FanSpeed] = converter_field(
speed: FanSpeed | None = converter_field(
default=FanSpeed.LOW, converter=FanSpeed.convert
)
speed_level: int = 0
direction: Optional[FanDirection] = converter_field(
direction: FanDirection | None = converter_field(
default=FanDirection.FORWARD, converter=FanDirection.convert
)
@ -290,7 +281,7 @@ class LightColorCapability(enum.IntFlag):
@_frozen_dataclass_decorator
class LightInfo(EntityInfo):
supported_color_modes: List[int] = converter_field(
supported_color_modes: list[int] = converter_field(
default_factory=list, converter=list
)
min_mireds: float = converter_field(
@ -299,7 +290,7 @@ class LightInfo(EntityInfo):
max_mireds: float = converter_field(
default=0.0, converter=fix_float_single_double_conversion
)
effects: List[str] = converter_field(default_factory=list, converter=list)
effects: list[str] = converter_field(default_factory=list, converter=list)
# deprecated, do not use
legacy_supports_brightness: bool = False
@ -307,7 +298,7 @@ class LightInfo(EntityInfo):
legacy_supports_white_value: bool = False
legacy_supports_color_temperature: bool = False
def supported_color_modes_compat(self, api_version: APIVersion) -> List[int]:
def supported_color_modes_compat(self, api_version: APIVersion) -> list[int]:
if api_version < APIVersion(1, 6):
key = (
self.legacy_supports_brightness,
@ -353,7 +344,7 @@ class LightInfo(EntityInfo):
],
}
return cast(List[int], modes_map[key]) if key in modes_map else []
return cast(list[int], modes_map[key]) if key in modes_map else []
return self.supported_color_modes
@ -412,10 +403,10 @@ class SensorInfo(EntityInfo):
unit_of_measurement: str = ""
accuracy_decimals: int = 0
force_update: bool = False
state_class: Optional[SensorStateClass] = converter_field(
state_class: SensorStateClass | None = converter_field(
default=SensorStateClass.NONE, converter=SensorStateClass.convert
)
last_reset_type: Optional[LastResetType] = converter_field(
last_reset_type: LastResetType | None = converter_field(
default=LastResetType.NONE, converter=LastResetType.convert
)
@ -516,7 +507,7 @@ class ClimatePreset(APIIntEnum):
class ClimateInfo(EntityInfo):
supports_current_temperature: bool = False
supports_two_point_target_temperature: bool = False
supported_modes: List[ClimateMode] = converter_field(
supported_modes: list[ClimateMode] = converter_field(
default_factory=list, converter=ClimateMode.convert_list
)
visual_min_temperature: float = converter_field(
@ -533,23 +524,23 @@ class ClimateInfo(EntityInfo):
)
legacy_supports_away: bool = False
supports_action: bool = False
supported_fan_modes: List[ClimateFanMode] = converter_field(
supported_fan_modes: list[ClimateFanMode] = converter_field(
default_factory=list, converter=ClimateFanMode.convert_list
)
supported_swing_modes: List[ClimateSwingMode] = converter_field(
supported_swing_modes: list[ClimateSwingMode] = converter_field(
default_factory=list, converter=ClimateSwingMode.convert_list
)
supported_custom_fan_modes: List[str] = converter_field(
supported_custom_fan_modes: list[str] = converter_field(
default_factory=list, converter=list
)
supported_presets: List[ClimatePreset] = converter_field(
supported_presets: list[ClimatePreset] = converter_field(
default_factory=list, converter=ClimatePreset.convert_list
)
supported_custom_presets: List[str] = converter_field(
supported_custom_presets: list[str] = converter_field(
default_factory=list, converter=list
)
def supported_presets_compat(self, api_version: APIVersion) -> List[ClimatePreset]:
def supported_presets_compat(self, api_version: APIVersion) -> list[ClimatePreset]:
if api_version < APIVersion(1, 5):
return (
[ClimatePreset.HOME, ClimatePreset.AWAY]
@ -561,10 +552,10 @@ class ClimateInfo(EntityInfo):
@_frozen_dataclass_decorator
class ClimateState(EntityState):
mode: Optional[ClimateMode] = converter_field(
mode: ClimateMode | None = converter_field(
default=ClimateMode.OFF, converter=ClimateMode.convert
)
action: Optional[ClimateAction] = converter_field(
action: ClimateAction | None = converter_field(
default=ClimateAction.OFF, converter=ClimateAction.convert
)
current_temperature: float = converter_field(
@ -580,19 +571,19 @@ class ClimateState(EntityState):
default=0.0, converter=fix_float_single_double_conversion
)
legacy_away: bool = False
fan_mode: Optional[ClimateFanMode] = converter_field(
fan_mode: ClimateFanMode | None = converter_field(
default=ClimateFanMode.ON, converter=ClimateFanMode.convert
)
swing_mode: Optional[ClimateSwingMode] = converter_field(
swing_mode: ClimateSwingMode | None = converter_field(
default=ClimateSwingMode.OFF, converter=ClimateSwingMode.convert
)
custom_fan_mode: str = ""
preset: Optional[ClimatePreset] = converter_field(
preset: ClimatePreset | None = converter_field(
default=ClimatePreset.NONE, converter=ClimatePreset.convert
)
custom_preset: str = ""
def preset_compat(self, api_version: APIVersion) -> Optional[ClimatePreset]:
def preset_compat(self, api_version: APIVersion) -> ClimatePreset | None:
if api_version < APIVersion(1, 5):
return ClimatePreset.AWAY if self.legacy_away else ClimatePreset.HOME
return self.preset
@ -617,7 +608,7 @@ class NumberInfo(EntityInfo):
default=0.0, converter=fix_float_single_double_conversion
)
unit_of_measurement: str = ""
mode: Optional[NumberMode] = converter_field(
mode: NumberMode | None = converter_field(
default=NumberMode.AUTO, converter=NumberMode.convert
)
device_class: str = ""
@ -634,7 +625,7 @@ class NumberState(EntityState):
# ==================== SELECT ====================
@_frozen_dataclass_decorator
class SelectInfo(EntityInfo):
options: List[str] = converter_field(default_factory=list, converter=list)
options: list[str] = converter_field(default_factory=list, converter=list)
@_frozen_dataclass_decorator
@ -646,7 +637,7 @@ class SelectState(EntityState):
# ==================== SIREN ====================
@_frozen_dataclass_decorator
class SirenInfo(EntityInfo):
tones: List[str] = converter_field(default_factory=list, converter=list)
tones: list[str] = converter_field(default_factory=list, converter=list)
supports_volume: bool = False
supports_duration: bool = False
@ -689,7 +680,7 @@ class LockInfo(EntityInfo):
@_frozen_dataclass_decorator
class LockEntityState(EntityState):
state: Optional[LockState] = converter_field(
state: LockState | None = converter_field(
default=LockState.NONE, converter=LockState.convert
)
@ -717,7 +708,7 @@ class MediaPlayerInfo(EntityInfo):
@_frozen_dataclass_decorator
class MediaPlayerEntityState(EntityState):
state: Optional[MediaPlayerState] = converter_field(
state: MediaPlayerState | None = converter_field(
default=MediaPlayerState.NONE, converter=MediaPlayerState.convert
)
volume: float = converter_field(
@ -759,7 +750,7 @@ class AlarmControlPanelInfo(EntityInfo):
@_frozen_dataclass_decorator
class AlarmControlPanelEntityState(EntityState):
state: Optional[AlarmControlPanelState] = converter_field(
state: AlarmControlPanelState | None = converter_field(
default=AlarmControlPanelState.DISARMED,
converter=AlarmControlPanelState.convert,
)
@ -767,7 +758,7 @@ class AlarmControlPanelEntityState(EntityState):
# ==================== INFO MAP ====================
COMPONENT_TYPE_TO_INFO: Dict[str, Type[EntityInfo]] = {
COMPONENT_TYPE_TO_INFO: dict[str, type[EntityInfo]] = {
"binary_sensor": BinarySensorInfo,
"cover": CoverInfo,
"fan": FanInfo,
@ -789,8 +780,8 @@ COMPONENT_TYPE_TO_INFO: Dict[str, Type[EntityInfo]] = {
# ==================== USER-DEFINED SERVICES ====================
def _convert_homeassistant_service_map(
value: Union[Dict[str, str], Iterable["HomeassistantServiceMap"]],
) -> Dict[str, str]:
value: dict[str, str] | Iterable[HomeassistantServiceMap],
) -> dict[str, str]:
if isinstance(value, dict):
# already a dict, don't convert
return value
@ -801,13 +792,13 @@ def _convert_homeassistant_service_map(
class HomeassistantServiceCall(APIModelBase):
service: str = ""
is_event: bool = False
data: Dict[str, str] = converter_field(
data: dict[str, str] = converter_field(
default_factory=dict, converter=_convert_homeassistant_service_map
)
data_template: Dict[str, str] = converter_field(
data_template: dict[str, str] = converter_field(
default_factory=dict, converter=_convert_homeassistant_service_map
)
variables: Dict[str, str] = converter_field(
variables: dict[str, str] = converter_field(
default_factory=dict, converter=_convert_homeassistant_service_map
)
@ -826,12 +817,12 @@ class UserServiceArgType(APIIntEnum):
@_frozen_dataclass_decorator
class UserServiceArg(APIModelBase):
name: str = ""
type: Optional[UserServiceArgType] = converter_field(
type: UserServiceArgType | None = converter_field(
default=UserServiceArgType.BOOL, converter=UserServiceArgType.convert
)
@classmethod
def convert_list(cls, value: List[Any]) -> List["UserServiceArg"]:
def convert_list(cls, value: list[Any]) -> list[UserServiceArg]:
ret = []
for x in value:
if isinstance(x, dict):
@ -847,7 +838,7 @@ class UserServiceArg(APIModelBase):
class UserService(APIModelBase):
name: str = ""
key: int = 0
args: List[UserServiceArg] = converter_field(
args: list[UserServiceArg] = converter_field(
default_factory=list, converter=UserServiceArg.convert_list
)
@ -855,7 +846,7 @@ class UserService(APIModelBase):
# ==================== BLUETOOTH ====================
def _join_split_uuid(value: List[int]) -> str:
def _join_split_uuid(value: list[int]) -> str:
"""Convert a high/low uuid into a single string."""
return str(UUID(int=(value[0] << 64) | value[1]))
@ -877,14 +868,14 @@ class BluetoothLEAdvertisement:
rssi: int
address_type: int
name: str
service_uuids: List[str]
service_data: Dict[str, bytes]
manufacturer_data: Dict[int, bytes]
service_uuids: list[str]
service_data: dict[str, bytes]
manufacturer_data: dict[int, bytes]
@classmethod
def from_pb( # type: ignore[misc]
cls: "BluetoothLEAdvertisement", data: "BluetoothLEAdvertisementResponse"
) -> "BluetoothLEAdvertisement":
cls: BluetoothLEAdvertisement, data: BluetoothLEAdvertisementResponse
) -> BluetoothLEAdvertisement:
_uuid_convert = _cached_uuid_converter
if raw_manufacturer_data := data.manufacturer_data:
@ -937,12 +928,12 @@ class BluetoothLERawAdvertisement:
def make_ble_raw_advertisement_processor(
on_advertisements: Callable[[List[BluetoothLERawAdvertisement]], None]
) -> Callable[["BluetoothLERawAdvertisementsResponse"], None]:
on_advertisements: Callable[[list[BluetoothLERawAdvertisement]], None]
) -> Callable[[BluetoothLERawAdvertisementsResponse], None]:
"""Make a processor for BluetoothLERawAdvertisementResponse."""
def _on_ble_raw_advertisement_response(
data: "BluetoothLERawAdvertisementsResponse",
data: BluetoothLERawAdvertisementsResponse,
) -> None:
on_advertisements(
[
@ -999,7 +990,7 @@ class BluetoothGATTDescriptor(APIModelBase):
handle: int = 0
@classmethod
def convert_list(cls, value: List[Any]) -> List["BluetoothGATTDescriptor"]:
def convert_list(cls, value: list[Any]) -> list[BluetoothGATTDescriptor]:
ret = []
for x in value:
if isinstance(x, dict):
@ -1015,12 +1006,12 @@ class BluetoothGATTCharacteristic(APIModelBase):
handle: int = 0
properties: int = 0
descriptors: List[BluetoothGATTDescriptor] = converter_field(
descriptors: list[BluetoothGATTDescriptor] = converter_field(
default_factory=list, converter=BluetoothGATTDescriptor.convert_list
)
@classmethod
def convert_list(cls, value: List[Any]) -> List["BluetoothGATTCharacteristic"]:
def convert_list(cls, value: list[Any]) -> list[BluetoothGATTCharacteristic]:
ret = []
for x in value:
if isinstance(x, dict):
@ -1034,12 +1025,12 @@ class BluetoothGATTCharacteristic(APIModelBase):
class BluetoothGATTService(APIModelBase):
uuid: str = converter_field(default="", converter=_join_split_uuid)
handle: int = 0
characteristics: List[BluetoothGATTCharacteristic] = converter_field(
characteristics: list[BluetoothGATTCharacteristic] = converter_field(
default_factory=list, converter=BluetoothGATTCharacteristic.convert_list
)
@classmethod
def convert_list(cls, value: List[Any]) -> List["BluetoothGATTService"]:
def convert_list(cls, value: list[Any]) -> list[BluetoothGATTService]:
ret = []
for x in value:
if isinstance(x, dict):
@ -1052,7 +1043,7 @@ class BluetoothGATTService(APIModelBase):
@_frozen_dataclass_decorator
class BluetoothGATTServices(APIModelBase):
address: int = 0
services: List[BluetoothGATTService] = converter_field(
services: list[BluetoothGATTService] = converter_field(
default_factory=list, converter=BluetoothGATTService.convert_list
)
@ -1060,7 +1051,7 @@ class BluetoothGATTServices(APIModelBase):
@_frozen_dataclass_decorator
class ESPHomeBluetoothGATTServices:
address: int = 0
services: List[BluetoothGATTService] = field(default_factory=list)
services: list[BluetoothGATTService] = field(default_factory=list)
@_frozen_dataclass_decorator

View File

@ -1,6 +1,9 @@
from __future__ import annotations
import asyncio
import logging
from typing import Awaitable, Callable, List, Optional
from collections.abc import Awaitable
from typing import Callable
import zeroconf
@ -34,9 +37,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
client: APIClient,
on_connect: Callable[[], Awaitable[None]],
on_disconnect: Callable[[bool], Awaitable[None]],
zeroconf_instance: "zeroconf.Zeroconf",
name: Optional[str] = None,
on_connect_error: Optional[Callable[[Exception], Awaitable[None]]] = None,
zeroconf_instance: zeroconf.Zeroconf,
name: str | None = None,
on_connect_error: Callable[[Exception], Awaitable[None]] | None = None,
) -> None:
"""Initialize ReconnectingLogic.
@ -51,7 +54,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
self._on_disconnect_cb = on_disconnect
self._on_connect_error_cb = on_connect_error
self._zc = zeroconf_instance
self._filter_alias: Optional[str] = None
self._filter_alias: str | None = None
# Flag to check if the device is connected
self._connected = False
self._connected_lock = asyncio.Lock()
@ -60,9 +63,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
# How many connect attempts have there been already, used for exponential wait time
self._tries = 0
# Event for tracking when logic should stop
self._connect_task: Optional[asyncio.Task[None]] = None
self._connect_timer: Optional[asyncio.TimerHandle] = None
self._stop_task: Optional[asyncio.Task[None]] = None
self._connect_task: asyncio.Task[None] | None = None
self._connect_timer: asyncio.TimerHandle | None = None
self._stop_task: asyncio.Task[None] | None = None
@property
def _log_name(self) -> str:
@ -244,9 +247,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
def async_update_records(
self,
zc: "zeroconf.Zeroconf", # pylint: disable=unused-argument
zc: zeroconf.Zeroconf, # pylint: disable=unused-argument
now: float, # pylint: disable=unused-argument
records: List["zeroconf.RecordUpdate"],
records: list[zeroconf.RecordUpdate],
) -> None:
"""Listen to zeroconf updated mDNS records. This must be called from the eventloop.

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import math
from functools import lru_cache
from typing import Optional
@lru_cache(maxsize=1024)
@ -8,7 +9,7 @@ def varuint_to_bytes(value: int) -> bytes:
if value <= 0x7F:
return bytes([value])
ret = bytes()
ret = b""
while value:
temp = value & 0x7F
value >>= 7
@ -21,7 +22,7 @@ def varuint_to_bytes(value: int) -> bytes:
@lru_cache(maxsize=1024)
def bytes_to_varuint(value: bytes) -> Optional[int]:
def bytes_to_varuint(value: bytes) -> int | None:
result = 0
bitpos = 0
for val in value: