Add flake8, black, isort and mypy linting (#39)

This commit is contained in:
Otto Winter 2021-06-18 17:57:02 +02:00 committed by GitHub
parent 41d0d335e9
commit 52cf01e11a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 1677 additions and 985 deletions

View File

@ -18,5 +18,68 @@ jobs:
run: |
pip3 install -e .
pip3 install -r requirements_test.txt
- name: Register problem matcher
run: |
echo "::add-matcher::.github/workflows/matchers/pylint.json"
- run: pylint aioesphomeapi
flake8:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.7'
- name: Set up Python environment
run: |
pip3 install -e .
pip3 install -r requirements_test.txt
- name: Register problem matcher
run: |
echo "::add-matcher::.github/workflows/matchers/flake8.json"
- run: flake8 aioesphomeapi
black:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.7'
- name: Set up Python environment
run: |
pip3 install -e .
pip3 install -r requirements_test.txt
- run: black --safe --exclude 'api_pb2.py|api_options_pb2.py' aioesphomeapi
isort:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.7'
- name: Set up Python environment
run: |
pip3 install -e .
pip3 install -r requirements_test.txt
- name: Register problem matcher
run: |
echo "::add-matcher::.github/workflows/matchers/isort.json"
- run: isort --check aioesphomeapi
mypy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.7'
- name: Set up Python environment
run: |
pip3 install -e .
pip3 install -r requirements_test.txt
- name: Register problem matcher
run: |
echo "::add-matcher::.github/workflows/matchers/mypy.json"
- run: mypy --strict aioesphomeapi

30
.github/workflows/matchers/flake8.json vendored Normal file
View File

@ -0,0 +1,30 @@
{
"problemMatcher": [
{
"owner": "flake8-error",
"severity": "error",
"pattern": [
{
"regexp": "^(.*):(\\d+):(\\d+):\\s([EF]\\d{3}\\s.*)$",
"file": 1,
"line": 2,
"column": 3,
"message": 4
}
]
},
{
"owner": "flake8-warning",
"severity": "warning",
"pattern": [
{
"regexp": "^(.*):(\\d+):(\\d+):\\s([CDNW]\\d{3}\\s.*)$",
"file": 1,
"line": 2,
"column": 3,
"message": 4
}
]
}
]
}

14
.github/workflows/matchers/isort.json vendored Normal file
View File

@ -0,0 +1,14 @@
{
"problemMatcher": [
{
"owner": "isort",
"pattern": [
{
"regexp": "^ERROR:\\s+(.+)\\s+(.+)$",
"file": 1,
"message": 2
}
]
}
]
}

16
.github/workflows/matchers/mypy.json vendored Normal file
View File

@ -0,0 +1,16 @@
{
"problemMatcher": [
{
"owner": "mypy",
"pattern": [
{
"regexp": "^(.+):(\\d+):\\s(error|warning):\\s(.+)$",
"file": 1,
"line": 2,
"severity": 3,
"message": 4
}
]
}
]
}

32
.github/workflows/matchers/pylint.json vendored Normal file
View File

@ -0,0 +1,32 @@
{
"problemMatcher": [
{
"owner": "pylint-error",
"severity": "error",
"pattern": [
{
"regexp": "^(.+):(\\d+):(\\d+):\\s(([EF]\\d{4}):\\s.+)$",
"file": 1,
"line": 2,
"column": 3,
"message": 4,
"code": 5
}
]
},
{
"owner": "pylint-warning",
"severity": "warning",
"pattern": [
{
"regexp": "^(.+):(\\d+):(\\d+):\\s(([CRW]\\d{4}):\\s.+)$",
"file": 1,
"line": 2,
"column": 3,
"message": 4,
"code": 5
}
]
}
]
}

37
.github/workflows/protoc-update.yml vendored Normal file
View File

@ -0,0 +1,37 @@
name: Update protobuf generated files
on:
push:
branches: [master]
jobs:
protoc-update:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.7'
- name: Install protoc
run: |
sudo apt-get install protobuf-compiler
- name: Set up Python environment
run: |
pip3 install -e .
pip3 install -r requirements_test.txt
- name: Generate protoc
run: |
script/gen-protoc
# github actions email from here: https://github.community/t/github-actions-bot-email-address/17204
- name: Commit changes
run: |
if git diff-index --quiet HEAD --; then
echo "No changes detected, protobuf files are up to date!"
else
git config --global user.name "github-actions[bot]"
git config --global user.email "41898282+github-actions[bot]@users.noreply.github.com"
git commit -am "Update protobuf files"
git push
fi

View File

@ -1,5 +1,6 @@
# flake8: noqa
from .client import APIClient
from .connection import ConnectionParams, APIConnection
from .core import APIConnectionError, MESSAGE_TYPE_TO_PROTO
from .connection import APIConnection, ConnectionParams
from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
from .model import *
from .util import resolve_ip_address_getaddrinfo, resolve_ip_address
from .util import resolve_ip_address, resolve_ip_address_getaddrinfo

View File

@ -1,8 +1,8 @@
# type: ignore
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: api_options.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
@ -21,7 +21,8 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='',
syntax='proto2',
serialized_options=None,
serialized_pb=_b('\n\x11\x61pi_options.proto\x1a google/protobuf/descriptor.proto\"\x06\n\x04void*F\n\rAPISourceType\x12\x0f\n\x0bSOURCE_BOTH\x10\x00\x12\x11\n\rSOURCE_SERVER\x10\x01\x12\x11\n\rSOURCE_CLIENT\x10\x02:E\n\x16needs_setup_connection\x12\x1e.google.protobuf.MethodOptions\x18\x8e\x08 \x01(\x08:\x04true:C\n\x14needs_authentication\x12\x1e.google.protobuf.MethodOptions\x18\x8f\x08 \x01(\x08:\x04true:/\n\x02id\x12\x1f.google.protobuf.MessageOptions\x18\x8c\x08 \x01(\r:\x01\x30:M\n\x06source\x12\x1f.google.protobuf.MessageOptions\x18\x8d\x08 \x01(\x0e\x32\x0e.APISourceType:\x0bSOURCE_BOTH:/\n\x05ifdef\x12\x1f.google.protobuf.MessageOptions\x18\x8e\x08 \x01(\t:3\n\x03log\x12\x1f.google.protobuf.MessageOptions\x18\x8f\x08 \x01(\x08:\x04true:9\n\x08no_delay\x12\x1f.google.protobuf.MessageOptions\x18\x90\x08 \x01(\x08:\x05\x66\x61lse')
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n\x11\x61pi_options.proto\x1a google/protobuf/descriptor.proto\"\x06\n\x04void*F\n\rAPISourceType\x12\x0f\n\x0bSOURCE_BOTH\x10\x00\x12\x11\n\rSOURCE_SERVER\x10\x01\x12\x11\n\rSOURCE_CLIENT\x10\x02:E\n\x16needs_setup_connection\x12\x1e.google.protobuf.MethodOptions\x18\x8e\x08 \x01(\x08:\x04true:C\n\x14needs_authentication\x12\x1e.google.protobuf.MethodOptions\x18\x8f\x08 \x01(\x08:\x04true:/\n\x02id\x12\x1f.google.protobuf.MessageOptions\x18\x8c\x08 \x01(\r:\x01\x30:M\n\x06source\x12\x1f.google.protobuf.MessageOptions\x18\x8d\x08 \x01(\x0e\x32\x0e.APISourceType:\x0bSOURCE_BOTH:/\n\x05ifdef\x12\x1f.google.protobuf.MessageOptions\x18\x8e\x08 \x01(\t:3\n\x03log\x12\x1f.google.protobuf.MessageOptions\x18\x8f\x08 \x01(\x08:\x04true:9\n\x08no_delay\x12\x1f.google.protobuf.MessageOptions\x18\x90\x08 \x01(\x08:\x05\x66\x61lse'
,
dependencies=[google_dot_protobuf_dot_descriptor__pb2.DESCRIPTOR,])
@ -30,19 +31,23 @@ _APISOURCETYPE = _descriptor.EnumDescriptor(
full_name='APISourceType',
filename=None,
file=DESCRIPTOR,
create_key=_descriptor._internal_create_key,
values=[
_descriptor.EnumValueDescriptor(
name='SOURCE_BOTH', index=0, number=0,
serialized_options=None,
type=None),
type=None,
create_key=_descriptor._internal_create_key),
_descriptor.EnumValueDescriptor(
name='SOURCE_SERVER', index=1, number=1,
serialized_options=None,
type=None),
type=None,
create_key=_descriptor._internal_create_key),
_descriptor.EnumValueDescriptor(
name='SOURCE_CLIENT', index=2, number=2,
serialized_options=None,
type=None),
type=None,
create_key=_descriptor._internal_create_key),
],
containing_type=None,
serialized_options=None,
@ -63,7 +68,7 @@ needs_setup_connection = _descriptor.FieldDescriptor(
has_default_value=True, default_value=True,
message_type=None, enum_type=None, containing_type=None,
is_extension=True, extension_scope=None,
serialized_options=None, file=DESCRIPTOR)
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key)
NEEDS_AUTHENTICATION_FIELD_NUMBER = 1039
needs_authentication = _descriptor.FieldDescriptor(
name='needs_authentication', full_name='needs_authentication', index=1,
@ -71,7 +76,7 @@ needs_authentication = _descriptor.FieldDescriptor(
has_default_value=True, default_value=True,
message_type=None, enum_type=None, containing_type=None,
is_extension=True, extension_scope=None,
serialized_options=None, file=DESCRIPTOR)
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key)
ID_FIELD_NUMBER = 1036
id = _descriptor.FieldDescriptor(
name='id', full_name='id', index=2,
@ -79,7 +84,7 @@ id = _descriptor.FieldDescriptor(
has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=True, extension_scope=None,
serialized_options=None, file=DESCRIPTOR)
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key)
SOURCE_FIELD_NUMBER = 1037
source = _descriptor.FieldDescriptor(
name='source', full_name='source', index=3,
@ -87,15 +92,15 @@ source = _descriptor.FieldDescriptor(
has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=True, extension_scope=None,
serialized_options=None, file=DESCRIPTOR)
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key)
IFDEF_FIELD_NUMBER = 1038
ifdef = _descriptor.FieldDescriptor(
name='ifdef', full_name='ifdef', index=4,
number=1038, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=True, extension_scope=None,
serialized_options=None, file=DESCRIPTOR)
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key)
LOG_FIELD_NUMBER = 1039
log = _descriptor.FieldDescriptor(
name='log', full_name='log', index=5,
@ -103,7 +108,7 @@ log = _descriptor.FieldDescriptor(
has_default_value=True, default_value=True,
message_type=None, enum_type=None, containing_type=None,
is_extension=True, extension_scope=None,
serialized_options=None, file=DESCRIPTOR)
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key)
NO_DELAY_FIELD_NUMBER = 1040
no_delay = _descriptor.FieldDescriptor(
name='no_delay', full_name='no_delay', index=6,
@ -111,7 +116,7 @@ no_delay = _descriptor.FieldDescriptor(
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=True, extension_scope=None,
serialized_options=None, file=DESCRIPTOR)
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key)
_VOID = _descriptor.Descriptor(
@ -120,6 +125,7 @@ _VOID = _descriptor.Descriptor(
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
],
extensions=[
@ -148,11 +154,11 @@ DESCRIPTOR.extensions_by_name['log'] = log
DESCRIPTOR.extensions_by_name['no_delay'] = no_delay
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
void = _reflection.GeneratedProtocolMessageType('void', (_message.Message,), dict(
DESCRIPTOR = _VOID,
__module__ = 'api_options_pb2'
void = _reflection.GeneratedProtocolMessageType('void', (_message.Message,), {
'DESCRIPTOR' : _VOID,
'__module__' : 'api_options_pb2'
# @@protoc_insertion_point(class_scope:void)
))
})
_sym_db.RegisterMessage(void)
google_dot_protobuf_dot_descriptor__pb2.MethodOptions.RegisterExtension(needs_setup_connection)

File diff suppressed because one or more lines are too long

View File

@ -1,19 +1,108 @@
import asyncio
import logging
from typing import Any, Callable, Optional, Tuple
import zeroconf
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union, cast
import aioesphomeapi.api_pb2 as pb
import attr
import zeroconf
from google.protobuf import message
from aioesphomeapi.api_pb2 import ( # type: ignore
BinarySensorStateResponse,
CameraImageRequest,
CameraImageResponse,
ClimateCommandRequest,
ClimateStateResponse,
CoverCommandRequest,
CoverStateResponse,
DeviceInfoRequest,
DeviceInfoResponse,
ExecuteServiceArgument,
ExecuteServiceRequest,
FanCommandRequest,
FanStateResponse,
HomeassistantServiceResponse,
HomeAssistantStateResponse,
LightCommandRequest,
LightStateResponse,
ListEntitiesBinarySensorResponse,
ListEntitiesCameraResponse,
ListEntitiesClimateResponse,
ListEntitiesCoverResponse,
ListEntitiesDoneResponse,
ListEntitiesFanResponse,
ListEntitiesLightResponse,
ListEntitiesRequest,
ListEntitiesSensorResponse,
ListEntitiesServicesResponse,
ListEntitiesSwitchResponse,
ListEntitiesTextSensorResponse,
LogLevel,
SensorStateResponse,
SubscribeHomeassistantServicesRequest,
SubscribeHomeAssistantStateResponse,
SubscribeHomeAssistantStatesRequest,
SubscribeLogsRequest,
SubscribeLogsResponse,
SubscribeStatesRequest,
SwitchCommandRequest,
SwitchStateResponse,
TextSensorStateResponse,
)
from aioesphomeapi.connection import APIConnection, ConnectionParams
from aioesphomeapi.core import APIConnectionError
from aioesphomeapi.model import *
from aioesphomeapi.model import (
APIVersion,
BinarySensorInfo,
BinarySensorState,
CameraInfo,
CameraState,
ClimateFanMode,
ClimateInfo,
ClimateMode,
ClimateState,
ClimateSwingMode,
CoverInfo,
CoverState,
DeviceInfo,
EntityInfo,
FanDirection,
FanInfo,
FanSpeed,
FanState,
HomeassistantServiceCall,
LegacyCoverCommand,
LightInfo,
LightState,
SensorInfo,
SensorState,
SwitchInfo,
SwitchState,
TextSensorInfo,
TextSensorState,
UserService,
UserServiceArg,
UserServiceArgType,
)
_LOGGER = logging.getLogger(__name__)
ExecuteServiceDataType = Dict[
str, Union[bool, int, float, str, List[bool], List[int], List[float], List[str]]
]
class APIClient:
def __init__(self, eventloop, address: str, port: int, password: str, *,
client_info: str = 'aioesphomeapi', keepalive: float = 15.0,
zeroconf_instance: zeroconf.Zeroconf = None):
def __init__(
self,
eventloop: asyncio.AbstractEventLoop,
address: str,
port: int,
password: str,
*,
client_info: str = "aioesphomeapi",
keepalive: float = 15.0,
zeroconf_instance: Optional[zeroconf.Zeroconf] = None
):
self._params = ConnectionParams(
eventloop=eventloop,
address=address,
@ -21,18 +110,22 @@ class APIClient:
password=password,
client_info=client_info,
keepalive=keepalive,
zeroconf_instance=zeroconf_instance
zeroconf_instance=zeroconf_instance,
)
self._connection = None # type: Optional[APIConnection]
async def connect(self, on_stop=None, login=False):
async def connect(
self,
on_stop: Optional[Callable[[], Awaitable[None]]] = None,
login: bool = False,
) -> None:
if self._connection is not None:
raise APIConnectionError("Already connected!")
connected = False
stopped = False
async def _on_stop():
async def _on_stop() -> None:
nonlocal stopped
if stopped:
@ -53,31 +146,33 @@ class APIClient:
raise
except Exception as e:
await _on_stop()
raise APIConnectionError(
"Unexpected error while connecting: {}".format(e))
raise APIConnectionError("Unexpected error while connecting: {}".format(e))
connected = True
async def disconnect(self, force=False):
async def disconnect(self, force: bool = False) -> None:
if self._connection is None:
return
await self._connection.stop(force=force)
def _check_connected(self):
def _check_connected(self) -> None:
if self._connection is None:
raise APIConnectionError("Not connected!")
if not self._connection.is_connected:
raise APIConnectionError("Connection not done!")
def _check_authenticated(self):
def _check_authenticated(self) -> None:
self._check_connected()
assert self._connection is not None
if not self._connection.is_authenticated:
raise APIConnectionError("Not authenticated!")
async def device_info(self) -> DeviceInfo:
self._check_connected()
assert self._connection is not None
resp = await self._connection.send_message_await_response(
pb.DeviceInfoRequest(), pb.DeviceInfoResponse)
DeviceInfoRequest(), DeviceInfoResponse
)
return DeviceInfo(
uses_password=resp.uses_password,
name=resp.name,
@ -88,49 +183,60 @@ class APIClient:
has_deep_sleep=resp.has_deep_sleep,
)
async def list_entities_services(self) -> Tuple[List[Any], List[UserService]]:
async def list_entities_services(
self,
) -> Tuple[List[EntityInfo], List[UserService]]:
self._check_authenticated()
response_types = {
pb.ListEntitiesBinarySensorResponse: BinarySensorInfo,
pb.ListEntitiesCoverResponse: CoverInfo,
pb.ListEntitiesFanResponse: FanInfo,
pb.ListEntitiesLightResponse: LightInfo,
pb.ListEntitiesSensorResponse: SensorInfo,
pb.ListEntitiesSwitchResponse: SwitchInfo,
pb.ListEntitiesTextSensorResponse: TextSensorInfo,
pb.ListEntitiesServicesResponse: None,
pb.ListEntitiesCameraResponse: CameraInfo,
pb.ListEntitiesClimateResponse: ClimateInfo,
ListEntitiesBinarySensorResponse: BinarySensorInfo,
ListEntitiesCoverResponse: CoverInfo,
ListEntitiesFanResponse: FanInfo,
ListEntitiesLightResponse: LightInfo,
ListEntitiesSensorResponse: SensorInfo,
ListEntitiesSwitchResponse: SwitchInfo,
ListEntitiesTextSensorResponse: TextSensorInfo,
ListEntitiesServicesResponse: None,
ListEntitiesCameraResponse: CameraInfo,
ListEntitiesClimateResponse: ClimateInfo,
}
def do_append(msg):
def do_append(msg: message.Message) -> bool:
return isinstance(msg, tuple(response_types.keys()))
def do_stop(msg):
return isinstance(msg, pb.ListEntitiesDoneResponse)
def do_stop(msg: message.Message) -> bool:
return isinstance(msg, ListEntitiesDoneResponse)
assert self._connection is not None
resp = await self._connection.send_message_await_response_complex(
pb.ListEntitiesRequest(), do_append, do_stop, timeout=5)
entities = []
services = []
ListEntitiesRequest(), do_append, do_stop, timeout=5
)
entities: List[EntityInfo] = []
services: List[UserService] = []
for msg in resp:
if isinstance(msg, pb.ListEntitiesServicesResponse):
if isinstance(msg, ListEntitiesServicesResponse):
args = []
for arg in msg.args:
args.append(UserServiceArg(
args.append(
UserServiceArg(
name=arg.name,
type_=arg.type,
))
services.append(UserService(
)
)
services.append(
UserService(
name=msg.name,
key=msg.key,
args=args,
))
args=args, # type: ignore
)
)
continue
cls = None
for resp_type, cls in response_types.items():
if isinstance(msg, resp_type):
break
else:
continue
cls = cast(type, cls)
kwargs = {}
for key, _ in attr.fields_dict(cls).items():
kwargs[key] = getattr(msg, key)
@ -141,20 +247,20 @@ class APIClient:
self._check_authenticated()
response_types = {
pb.BinarySensorStateResponse: BinarySensorState,
pb.CoverStateResponse: CoverState,
pb.FanStateResponse: FanState,
pb.LightStateResponse: LightState,
pb.SensorStateResponse: SensorState,
pb.SwitchStateResponse: SwitchState,
pb.TextSensorStateResponse: TextSensorState,
pb.ClimateStateResponse: ClimateState,
BinarySensorStateResponse: BinarySensorState,
CoverStateResponse: CoverState,
FanStateResponse: FanState,
LightStateResponse: LightState,
SensorStateResponse: SensorState,
SwitchStateResponse: SwitchState,
TextSensorStateResponse: TextSensorState,
ClimateStateResponse: ClimateState,
}
image_stream = {}
image_stream: Dict[int, bytes] = {}
def on_msg(msg):
if isinstance(msg, pb.CameraImageResponse):
def on_msg(msg: message.Message) -> None:
if isinstance(msg, CameraImageResponse):
data = image_stream.pop(msg.key, bytes()) + msg.data
if msg.done:
on_state(CameraState(key=msg.key, image=data))
@ -174,33 +280,43 @@ class APIClient:
kwargs[key] = getattr(msg, key)
on_state(cls(**kwargs))
await self._connection.send_message_callback_response(pb.SubscribeStatesRequest(), on_msg)
assert self._connection is not None
await self._connection.send_message_callback_response(
SubscribeStatesRequest(), on_msg
)
async def subscribe_logs(self, on_log: Callable[[pb.SubscribeLogsResponse], None],
log_level=None) -> None:
async def subscribe_logs(
self,
on_log: Callable[[SubscribeLogsResponse], None],
log_level: Optional[LogLevel] = None,
) -> None:
self._check_authenticated()
def on_msg(msg):
if isinstance(msg, pb.SubscribeLogsResponse):
def on_msg(msg: message.Message) -> None:
if isinstance(msg, SubscribeLogsResponse):
on_log(msg)
req = pb.SubscribeLogsRequest()
req = SubscribeLogsRequest()
if log_level is not None:
req.level = log_level
assert self._connection is not None
await self._connection.send_message_callback_response(req, on_msg)
async def subscribe_service_calls(self, on_service_call: Callable[[HomeassistantServiceCall], None]) -> None:
async def subscribe_service_calls(
self, on_service_call: Callable[[HomeassistantServiceCall], None]
) -> None:
self._check_authenticated()
def on_msg(msg):
if isinstance(msg, pb.HomeassistantServiceResponse):
def on_msg(msg: message.Message) -> None:
if isinstance(msg, HomeassistantServiceResponse):
kwargs = {}
for key, _ in attr.fields_dict(HomeassistantServiceCall).items():
kwargs[key] = getattr(msg, key)
on_service_call(HomeassistantServiceCall(**kwargs))
assert self._connection is not None
await self._connection.send_message_callback_response(
pb.SubscribeHomeassistantServicesRequest(), on_msg
SubscribeHomeassistantServicesRequest(), on_msg
)
async def subscribe_home_assistant_states(
@ -208,12 +324,13 @@ class APIClient:
) -> None:
self._check_authenticated()
def on_msg(msg):
if isinstance(msg, pb.SubscribeHomeAssistantStateResponse):
def on_msg(msg: message.Message) -> None:
if isinstance(msg, SubscribeHomeAssistantStateResponse):
on_state_sub(msg.entity_id, msg.attribute)
assert self._connection is not None
await self._connection.send_message_callback_response(
pb.SubscribeHomeAssistantStatesRequest(), on_msg
SubscribeHomeAssistantStatesRequest(), on_msg
)
async def send_home_assistant_state(
@ -221,15 +338,17 @@ class APIClient:
) -> None:
self._check_authenticated()
assert self._connection is not None
await self._connection.send_message(
pb.HomeAssistantStateResponse(
HomeAssistantStateResponse(
entity_id=entity_id,
state=state,
attribute=attribute,
)
)
async def cover_command(self,
async def cover_command(
self,
key: int,
position: Optional[float] = None,
tilt: Optional[float] = None,
@ -237,9 +356,10 @@ class APIClient:
) -> None:
self._check_authenticated()
req = pb.CoverCommandRequest()
req = CoverCommandRequest()
req.key = key
if self.api_version >= APIVersion(1, 1):
apiv = cast(APIVersion, self.api_version)
if apiv >= APIVersion(1, 1):
if position is not None:
req.has_position = True
req.position = position
@ -256,19 +376,21 @@ class APIClient:
req.legacy_command = LegacyCoverCommand.OPEN
else:
req.legacy_command = LegacyCoverCommand.CLOSE
assert self._connection is not None
await self._connection.send_message(req)
async def fan_command(self,
async def fan_command(
self,
key: int,
state: Optional[bool] = None,
speed: Optional[FanSpeed] = None,
speed_level: Optional[int] = None,
oscillating: Optional[bool] = None,
direction: Optional[FanDirection] = None
direction: Optional[FanDirection] = None,
) -> None:
self._check_authenticated()
req = pb.FanCommandRequest()
req = FanCommandRequest()
req.key = key
if state is not None:
req.has_state = True
@ -285,9 +407,11 @@ class APIClient:
if direction is not None:
req.has_direction = True
req.direction = direction
assert self._connection is not None
await self._connection.send_message(req)
async def light_command(self,
async def light_command(
self,
key: int,
state: Optional[bool] = None,
brightness: Optional[float] = None,
@ -297,10 +421,10 @@ class APIClient:
transition_length: Optional[float] = None,
flash_length: Optional[float] = None,
effect: Optional[str] = None,
):
) -> None:
self._check_authenticated()
req = pb.LightCommandRequest()
req = LightCommandRequest()
req.key = key
if state is not None:
req.has_state = True
@ -328,20 +452,20 @@ class APIClient:
if effect is not None:
req.has_effect = True
req.effect = effect
assert self._connection is not None
await self._connection.send_message(req)
async def switch_command(self,
key: int,
state: bool
) -> None:
async def switch_command(self, key: int, state: bool) -> None:
self._check_authenticated()
req = pb.SwitchCommandRequest()
req = SwitchCommandRequest()
req.key = key
req.state = state
assert self._connection is not None
await self._connection.send_message(req)
async def climate_command(self,
async def climate_command(
self,
key: int,
mode: Optional[ClimateMode] = None,
target_temperature: Optional[float] = None,
@ -353,7 +477,7 @@ class APIClient:
) -> None:
self._check_authenticated()
req = pb.ClimateCommandRequest()
req = ClimateCommandRequest()
req.key = key
if mode is not None:
req.has_mode = True
@ -376,30 +500,33 @@ class APIClient:
if swing_mode is not None:
req.has_swing_mode = True
req.swing_mode = swing_mode
assert self._connection is not None
await self._connection.send_message(req)
async def execute_service(self, service: UserService, data: dict):
async def execute_service(
self, service: UserService, data: ExecuteServiceDataType
) -> None:
self._check_authenticated()
req = pb.ExecuteServiceRequest()
req = ExecuteServiceRequest()
req.key = service.key
args = []
for arg_desc in service.args:
arg = pb.ExecuteServiceArgument()
arg = ExecuteServiceArgument()
val = data[arg_desc.name]
int_type = 'int_' if self.api_version >= APIVersion(
1, 3) else 'legacy_int'
apiv = cast(APIVersion, self.api_version)
int_type = "int_" if apiv >= APIVersion(1, 3) else "legacy_int"
map_single = {
UserServiceArgType.BOOL: 'bool_',
UserServiceArgType.BOOL: "bool_",
UserServiceArgType.INT: int_type,
UserServiceArgType.FLOAT: 'float_',
UserServiceArgType.STRING: 'string_',
UserServiceArgType.FLOAT: "float_",
UserServiceArgType.STRING: "string_",
}
map_array = {
UserServiceArgType.BOOL_ARRAY: 'bool_array',
UserServiceArgType.INT_ARRAY: 'int_array',
UserServiceArgType.FLOAT_ARRAY: 'float_array',
UserServiceArgType.STRING_ARRAY: 'string_array',
UserServiceArgType.BOOL_ARRAY: "bool_array",
UserServiceArgType.INT_ARRAY: "int_array",
UserServiceArgType.FLOAT_ARRAY: "float_array",
UserServiceArgType.STRING_ARRAY: "string_array",
}
# pylint: disable=redefined-outer-name
if arg_desc.type_ in map_array:
@ -411,18 +538,22 @@ class APIClient:
args.append(arg)
# pylint: disable=no-member
req.args.extend(args)
assert self._connection is not None
await self._connection.send_message(req)
async def _request_image(self, *, single=False, stream=False):
req = pb.CameraImageRequest()
async def _request_image(
self, *, single: bool = False, stream: bool = False
) -> None:
req = CameraImageRequest()
req.single = single
req.stream = stream
assert self._connection is not None
await self._connection.send_message(req)
async def request_single_image(self):
async def request_single_image(self) -> None:
await self._request_image(single=True)
async def request_image_stream(self):
async def request_image_stream(self) -> None:
await self._request_image(stream=True)
@property

View File

@ -2,14 +2,25 @@ import asyncio
import logging
import socket
import time
from typing import Any, Callable, List, Optional, cast
from typing import Any, Awaitable, Callable, List, Optional, cast
import attr
import zeroconf
from google.protobuf import message
import aioesphomeapi.api_pb2 as pb
from aioesphomeapi.core import APIConnectionError, MESSAGE_TYPE_TO_PROTO
from aioesphomeapi.api_pb2 import ( # type: ignore
ConnectRequest,
ConnectResponse,
DisconnectRequest,
DisconnectResponse,
GetTimeRequest,
GetTimeResponse,
HelloRequest,
HelloResponse,
PingRequest,
PingResponse,
)
from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
from aioesphomeapi.model import APIVersion
from aioesphomeapi.util import _bytes_to_varuint, _varuint_to_bytes, resolve_ip_address
@ -24,26 +35,27 @@ class ConnectionParams:
password = attr.ib(type=Optional[str])
client_info = attr.ib(type=str)
keepalive = attr.ib(type=float)
zeroconf_instance = attr.ib(type=zeroconf.Zeroconf)
zeroconf_instance = attr.ib(type=Optional[zeroconf.Zeroconf])
class APIConnection:
def __init__(self, params: ConnectionParams, on_stop):
def __init__(
self, params: ConnectionParams, on_stop: Callable[[], Awaitable[None]]
):
self._params = params
self.on_stop = on_stop
self._stopped = False
self._socket = None # type: Optional[socket.socket]
self._socket_reader = None # type: Optional[asyncio.StreamReader]
self._socket_writer = None # type: Optional[asyncio.StreamWriter]
self._socket: Optional[socket.socket] = None
self._socket_reader: Optional[asyncio.StreamReader] = None
self._socket_writer: Optional[asyncio.StreamWriter] = None
self._write_lock = asyncio.Lock()
self._connected = False
self._authenticated = False
self._socket_connected = False
self._state_lock = asyncio.Lock()
self._api_version = None # type: Optional[APIVersion]
self._api_version: Optional[APIVersion] = None
self._message_handlers = [] # type: List[Callable[[message], None]]
self._running_task = None # type: Optional[asyncio.Task]
self._message_handlers: List[Callable[[message.Message], None]] = []
def _start_ping(self) -> None:
async def func() -> None:
@ -66,6 +78,7 @@ class APIConnection:
if not self._socket_connected:
return
async with self._write_lock:
if self._socket_writer is not None:
self._socket_writer.close()
self._socket_writer = None
self._socket_reader = None
@ -85,8 +98,6 @@ class APIConnection:
except APIConnectionError:
pass
self._stopped = True
if self._running_task is not None:
self._running_task.cancel()
await self._close_socket()
await self.on_stop()
@ -100,8 +111,12 @@ class APIConnection:
raise APIConnectionError("Already connected!")
try:
coro = resolve_ip_address(self._params.eventloop, self._params.address,
self._params.port, self._params.zeroconf_instance)
coro = resolve_ip_address(
self._params.eventloop,
self._params.address,
self._params.port,
self._params.zeroconf_instance,
)
sockaddr = await asyncio.wait_for(coro, 30.0)
except APIConnectionError as err:
await self._on_error()
@ -114,40 +129,51 @@ class APIConnection:
self._socket.setblocking(False)
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
_LOGGER.debug("%s: Connecting to %s:%s (%s)", self._params.address,
self._params.address, self._params.port, sockaddr)
_LOGGER.debug(
"%s: Connecting to %s:%s (%s)",
self._params.address,
self._params.address,
self._params.port,
sockaddr,
)
try:
coro = self._params.eventloop.sock_connect(self._socket, sockaddr)
await asyncio.wait_for(coro, 30.0)
coro2 = self._params.eventloop.sock_connect(self._socket, sockaddr)
await asyncio.wait_for(coro2, 30.0)
except OSError as err:
await self._on_error()
raise APIConnectionError(
"Error connecting to {}: {}".format(sockaddr, err))
raise APIConnectionError("Error connecting to {}: {}".format(sockaddr, err))
except asyncio.TimeoutError:
await self._on_error()
raise APIConnectionError(
"Timeout while connecting to {}".format(sockaddr))
raise APIConnectionError("Timeout while connecting to {}".format(sockaddr))
_LOGGER.debug("%s: Opened socket for", self._params.address)
self._socket_reader, self._socket_writer = await asyncio.open_connection(sock=self._socket)
self._socket_reader, self._socket_writer = await asyncio.open_connection(
sock=self._socket
)
self._socket_connected = True
self._params.eventloop.create_task(self.run_forever())
hello = pb.HelloRequest()
hello = HelloRequest()
hello.client_info = self._params.client_info
try:
resp = await self.send_message_await_response(hello, pb.HelloResponse)
resp = await self.send_message_await_response(hello, HelloResponse)
except APIConnectionError as err:
await self._on_error()
raise err
_LOGGER.debug("%s: Successfully connected ('%s' API=%s.%s)",
self._params.address, resp.server_info, resp.api_version_major,
resp.api_version_minor)
self._api_version = APIVersion(
resp.api_version_major, resp.api_version_minor)
_LOGGER.debug(
"%s: Successfully connected ('%s' API=%s.%s)",
self._params.address,
resp.server_info,
resp.api_version_major,
resp.api_version_minor,
)
self._api_version = APIVersion(resp.api_version_major, resp.api_version_minor)
if self._api_version.major > 2:
_LOGGER.error("%s: Incompatible version %s! Closing connection",
self._params.address, self._api_version.major)
_LOGGER.error(
"%s: Incompatible version %s! Closing connection",
self._params.address,
self._api_version.major,
)
await self._on_error()
raise APIConnectionError("Incompatible API version.")
self._connected = True
@ -159,10 +185,10 @@ class APIConnection:
if self._authenticated:
raise APIConnectionError("Already logged in!")
connect = pb.ConnectRequest()
connect = ConnectRequest()
if self._params.password is not None:
connect.password = self._params.password
resp = await self.send_message_await_response(connect, pb.ConnectResponse)
resp = await self.send_message_await_response(connect, ConnectResponse)
if resp.invalid_password:
raise APIConnectionError("Invalid password!")
@ -187,12 +213,12 @@ class APIConnection:
raise APIConnectionError("Socket is not connected")
try:
async with self._write_lock:
if self._socket_writer is not None:
self._socket_writer.write(data)
await self._socket_writer.drain()
except OSError as err:
await self._on_error()
raise APIConnectionError(
"Error while writing data: {}".format(err))
raise APIConnectionError("Error while writing data: {}".format(err))
async def send_message(self, msg: message.Message) -> None:
for message_type, klass in MESSAGE_TYPE_TO_PROTO.items():
@ -202,8 +228,7 @@ class APIConnection:
raise ValueError
encoded = msg.SerializeToString()
_LOGGER.debug("%s: Sending %s: %s",
self._params.address, type(msg), str(msg))
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
req = bytes([0])
req += _varuint_to_bytes(len(encoded))
# pylint: disable=undefined-loop-variable
@ -211,19 +236,23 @@ class APIConnection:
req += encoded
await self._write(req)
async def send_message_callback_response(self, send_msg: message.Message,
on_message: Callable[[Any], None]) -> None:
async def send_message_callback_response(
self, send_msg: message.Message, on_message: Callable[[Any], None]
) -> None:
self._message_handlers.append(on_message)
await self.send_message(send_msg)
async def send_message_await_response_complex(self, send_msg: message.Message,
async def send_message_await_response_complex(
self,
send_msg: message.Message,
do_append: Callable[[Any], bool],
do_stop: Callable[[Any], bool],
timeout: float = 5.0) -> List[Any]:
timeout: float = 5.0,
) -> List[Any]:
fut = self._params.eventloop.create_future()
responses = []
def on_message(resp):
def on_message(resp: message.Message) -> None:
if fut.done():
return
if do_append(resp):
@ -238,8 +267,7 @@ class APIConnection:
await asyncio.wait_for(fut, timeout)
except asyncio.TimeoutError:
if self._stopped:
raise APIConnectionError(
"Disconnected while waiting for API response!")
raise APIConnectionError("Disconnected while waiting for API response!")
raise APIConnectionError("Timeout while waiting for API response!")
try:
@ -249,17 +277,17 @@ class APIConnection:
return responses
async def send_message_await_response(self,
send_msg: message.Message,
response_type: Any, timeout: float = 5.0) -> Any:
def is_response(msg):
async def send_message_await_response(
self, send_msg: message.Message, response_type: Any, timeout: float = 5.0
) -> Any:
def is_response(msg: message.Message) -> bool:
return isinstance(msg, response_type)
res = await self.send_message_await_response_complex(
send_msg, is_response, is_response, timeout=timeout)
send_msg, is_response, is_response, timeout=timeout
)
if len(res) != 1:
raise APIConnectionError(
"Expected one result, got {}".format(len(res)))
raise APIConnectionError("Expected one result, got {}".format(len(res)))
return res[0]
@ -268,10 +296,10 @@ class APIConnection:
return bytes()
try:
assert self._socket_reader is not None
ret = await self._socket_reader.readexactly(amount)
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
raise APIConnectionError(
"Error while receiving data: {}".format(err))
raise APIConnectionError("Error while receiving data: {}".format(err))
return ret
@ -291,8 +319,9 @@ class APIConnection:
raw_msg = await self._recv(length)
if msg_type not in MESSAGE_TYPE_TO_PROTO:
_LOGGER.debug("%s: Skipping message type %s",
self._params.address, msg_type)
_LOGGER.debug(
"%s: Skipping message type %s", self._params.address, msg_type
)
return
msg = MESSAGE_TYPE_TO_PROTO[msg_type]()
@ -300,8 +329,9 @@ class APIConnection:
msg.ParseFromString(raw_msg)
except Exception as e:
raise APIConnectionError("Invalid protobuf message: {}".format(e))
_LOGGER.debug("%s: Got message of type %s: %s",
self._params.address, type(msg), msg)
_LOGGER.debug(
"%s: Got message of type %s: %s", self._params.address, type(msg), msg
)
for msg_handler in self._message_handlers[:]:
msg_handler(msg)
await self._handle_internal_messages(msg)
@ -311,36 +341,44 @@ class APIConnection:
try:
await self._run_once()
except APIConnectionError as err:
_LOGGER.info("%s: Error while reading incoming messages: %s",
self._params.address, err)
_LOGGER.info(
"%s: Error while reading incoming messages: %s",
self._params.address,
err,
)
await self._on_error()
break
except Exception as err: # pylint: disable=broad-except
_LOGGER.info("%s: Unexpected error while reading incoming messages: %s",
self._params.address, err)
_LOGGER.info(
"%s: Unexpected error while reading incoming messages: %s",
self._params.address,
err,
)
await self._on_error()
break
async def _handle_internal_messages(self, msg: Any) -> None:
if isinstance(msg, pb.DisconnectRequest):
await self.send_message(pb.DisconnectResponse())
if isinstance(msg, DisconnectRequest):
await self.send_message(DisconnectResponse())
await self.stop(force=True)
elif isinstance(msg, pb.PingRequest):
await self.send_message(pb.PingResponse())
elif isinstance(msg, pb.GetTimeRequest):
resp = pb.GetTimeResponse()
elif isinstance(msg, PingRequest):
await self.send_message(PingResponse())
elif isinstance(msg, GetTimeRequest):
resp = GetTimeResponse()
resp.epoch_seconds = int(time.time())
await self.send_message(resp)
async def ping(self) -> None:
self._check_connected()
await self.send_message_await_response(pb.PingRequest(), pb.PingResponse)
await self.send_message_await_response(PingRequest(), PingResponse)
async def _disconnect(self) -> None:
self._check_connected()
try:
await self.send_message_await_response(pb.DisconnectRequest(), pb.DisconnectResponse)
await self.send_message_await_response(
DisconnectRequest(), DisconnectResponse
)
except APIConnectionError:
pass

View File

@ -1,4 +1,53 @@
import aioesphomeapi.api_pb2 as pb
from aioesphomeapi.api_pb2 import ( # type: ignore
BinarySensorStateResponse,
CameraImageRequest,
CameraImageResponse,
ClimateCommandRequest,
ClimateStateResponse,
ConnectRequest,
ConnectResponse,
CoverCommandRequest,
CoverStateResponse,
DeviceInfoRequest,
DeviceInfoResponse,
DisconnectRequest,
DisconnectResponse,
ExecuteServiceRequest,
FanCommandRequest,
FanStateResponse,
GetTimeRequest,
GetTimeResponse,
HelloRequest,
HelloResponse,
HomeassistantServiceResponse,
HomeAssistantStateResponse,
LightCommandRequest,
LightStateResponse,
ListEntitiesBinarySensorResponse,
ListEntitiesCameraResponse,
ListEntitiesClimateResponse,
ListEntitiesCoverResponse,
ListEntitiesDoneResponse,
ListEntitiesFanResponse,
ListEntitiesLightResponse,
ListEntitiesRequest,
ListEntitiesSensorResponse,
ListEntitiesServicesResponse,
ListEntitiesSwitchResponse,
ListEntitiesTextSensorResponse,
PingRequest,
PingResponse,
SensorStateResponse,
SubscribeHomeassistantServicesRequest,
SubscribeHomeAssistantStateResponse,
SubscribeHomeAssistantStatesRequest,
SubscribeLogsRequest,
SubscribeLogsResponse,
SubscribeStatesRequest,
SwitchCommandRequest,
SwitchStateResponse,
TextSensorStateResponse,
)
class APIConnectionError(Exception):
@ -6,52 +55,52 @@ class APIConnectionError(Exception):
MESSAGE_TYPE_TO_PROTO = {
1: pb.HelloRequest,
2: pb.HelloResponse,
3: pb.ConnectRequest,
4: pb.ConnectResponse,
5: pb.DisconnectRequest,
6: pb.DisconnectResponse,
7: pb.PingRequest,
8: pb.PingResponse,
9: pb.DeviceInfoRequest,
10: pb.DeviceInfoResponse,
11: pb.ListEntitiesRequest,
12: pb.ListEntitiesBinarySensorResponse,
13: pb.ListEntitiesCoverResponse,
14: pb.ListEntitiesFanResponse,
15: pb.ListEntitiesLightResponse,
16: pb.ListEntitiesSensorResponse,
17: pb.ListEntitiesSwitchResponse,
18: pb.ListEntitiesTextSensorResponse,
19: pb.ListEntitiesDoneResponse,
20: pb.SubscribeStatesRequest,
21: pb.BinarySensorStateResponse,
22: pb.CoverStateResponse,
23: pb.FanStateResponse,
24: pb.LightStateResponse,
25: pb.SensorStateResponse,
26: pb.SwitchStateResponse,
27: pb.TextSensorStateResponse,
28: pb.SubscribeLogsRequest,
29: pb.SubscribeLogsResponse,
30: pb.CoverCommandRequest,
31: pb.FanCommandRequest,
32: pb.LightCommandRequest,
33: pb.SwitchCommandRequest,
34: pb.SubscribeHomeassistantServicesRequest,
35: pb.HomeassistantServiceResponse,
36: pb.GetTimeRequest,
37: pb.GetTimeResponse,
38: pb.SubscribeHomeAssistantStatesRequest,
39: pb.SubscribeHomeAssistantStateResponse,
40: pb.HomeAssistantStateResponse,
41: pb.ListEntitiesServicesResponse,
42: pb.ExecuteServiceRequest,
43: pb.ListEntitiesCameraResponse,
44: pb.CameraImageResponse,
45: pb.CameraImageRequest,
46: pb.ListEntitiesClimateResponse,
47: pb.ClimateStateResponse,
48: pb.ClimateCommandRequest,
1: HelloRequest,
2: HelloResponse,
3: ConnectRequest,
4: ConnectResponse,
5: DisconnectRequest,
6: DisconnectResponse,
7: PingRequest,
8: PingResponse,
9: DeviceInfoRequest,
10: DeviceInfoResponse,
11: ListEntitiesRequest,
12: ListEntitiesBinarySensorResponse,
13: ListEntitiesCoverResponse,
14: ListEntitiesFanResponse,
15: ListEntitiesLightResponse,
16: ListEntitiesSensorResponse,
17: ListEntitiesSwitchResponse,
18: ListEntitiesTextSensorResponse,
19: ListEntitiesDoneResponse,
20: SubscribeStatesRequest,
21: BinarySensorStateResponse,
22: CoverStateResponse,
23: FanStateResponse,
24: LightStateResponse,
25: SensorStateResponse,
26: SwitchStateResponse,
27: TextSensorStateResponse,
28: SubscribeLogsRequest,
29: SubscribeLogsResponse,
30: CoverCommandRequest,
31: FanCommandRequest,
32: LightCommandRequest,
33: SwitchCommandRequest,
34: SubscribeHomeassistantServicesRequest,
35: HomeassistantServiceResponse,
36: GetTimeRequest,
37: GetTimeResponse,
38: SubscribeHomeAssistantStatesRequest,
39: SubscribeHomeAssistantStateResponse,
40: HomeAssistantStateResponse,
41: ListEntitiesServicesResponse,
42: ExecuteServiceRequest,
43: ListEntitiesCameraResponse,
44: CameraImageResponse,
45: CameraImageRequest,
46: ListEntitiesClimateResponse,
47: ClimateStateResponse,
48: ClimateCommandRequest,
}

View File

@ -1,15 +1,18 @@
import socket
import time
from typing import Optional
import zeroconf
class HostResolver(zeroconf.RecordUpdateListener):
def __init__(self, name):
def __init__(self, name: str):
self.name = name
self.address = None
self.address: Optional[bytes] = None
def update_record(self, zc, now, record):
def update_record(
self, zc: zeroconf.Zeroconf, now: float, record: zeroconf.DNSRecord
) -> None:
if record is None:
return
if record.type == zeroconf._TYPE_A:
@ -17,15 +20,17 @@ class HostResolver(zeroconf.RecordUpdateListener):
if record.name == self.name:
self.address = record.address
def request(self, zc, timeout):
def request(self, zc: zeroconf.Zeroconf, timeout: float) -> bool:
now = time.time()
delay = 0.2
next_ = now + delay
last = now + timeout
try:
zc.add_listener(self, zeroconf.DNSQuestion(self.name, zeroconf._TYPE_ANY,
zeroconf._CLASS_IN))
zc.add_listener(
self,
zeroconf.DNSQuestion(self.name, zeroconf._TYPE_ANY, zeroconf._CLASS_IN),
)
while self.address is None:
if last <= now:
# Timeout
@ -33,10 +38,16 @@ class HostResolver(zeroconf.RecordUpdateListener):
if next_ <= now:
out = zeroconf.DNSOutgoing(zeroconf._FLAGS_QR_QUERY)
out.add_question(
zeroconf.DNSQuestion(self.name, zeroconf._TYPE_A, zeroconf._CLASS_IN))
zeroconf.DNSQuestion(
self.name, zeroconf._TYPE_A, zeroconf._CLASS_IN
)
)
out.add_answer_at_time(
zc.cache.get_by_details(self.name, zeroconf._TYPE_A,
zeroconf._CLASS_IN), now)
zc.cache.get_by_details(
self.name, zeroconf._TYPE_A, zeroconf._CLASS_IN
),
now,
)
zc.send(out)
next_ = now + delay
delay *= 2
@ -49,27 +60,38 @@ class HostResolver(zeroconf.RecordUpdateListener):
return True
def resolve_host(host, timeout=3.0, zeroconf_instance: zeroconf.Zeroconf = None):
from aioesphomeapi import APIConnectionError
def resolve_host(
host: str,
timeout: float = 3.0,
zeroconf_instance: Optional[zeroconf.Zeroconf] = None,
) -> str:
from aioesphomeapi.core import APIConnectionError
try:
zc = zeroconf_instance or zeroconf.Zeroconf()
except Exception:
raise APIConnectionError("Cannot start mDNS sockets, is this a docker container without "
"host network mode?")
raise APIConnectionError(
"Cannot start mDNS sockets, is this a docker container without "
"host network mode?"
)
try:
info = HostResolver(host + '.')
info = HostResolver(host + ".")
assert info.address is not None
address = None
if info.request(zc, timeout):
address = socket.inet_ntoa(info.address)
except Exception as err:
raise APIConnectionError("Error resolving mDNS hostname: {}".format(err))
raise APIConnectionError(
"Error resolving mDNS hostname: {}".format(err)
) from err
finally:
if not zeroconf_instance:
zc.close()
if address is None:
raise APIConnectionError("Error resolving address with mDNS: Did not respond. "
"Maybe the device is offline.")
raise APIConnectionError(
"Error resolving address with mDNS: Did not respond. "
"Maybe the device is offline."
)
return address

View File

@ -1,19 +1,23 @@
import enum
from typing import List, Dict, TypeVar, Optional, Type
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Type, TypeVar
import attr
if TYPE_CHECKING:
from .api_pb2 import HomeassistantServiceMap # type: ignore
# All fields in here should have defaults set
# Home Assistant depends on these fields being constructible
# with args from a previous version of Home Assistant.
# The default value should *always* be the Protobuf default value
# for a field (False, 0, empty string, enum with value 0, ...)
_T = TypeVar("_T", bound="APIIntEnum")
_T = TypeVar('_T')
class APIIntEnum(enum.IntEnum):
"""Base class for int enum values in API model."""
@classmethod
def convert(cls: Type[_T], value: int) -> Optional[_T]:
try:
@ -41,20 +45,20 @@ class APIVersion:
@attr.s
class DeviceInfo:
uses_password = attr.ib(type=bool, default=False)
name = attr.ib(type=str, default='')
mac_address = attr.ib(type=str, default='')
compilation_time = attr.ib(type=str, default='')
model = attr.ib(type=str, default='')
name = attr.ib(type=str, default="")
mac_address = attr.ib(type=str, default="")
compilation_time = attr.ib(type=str, default="")
model = attr.ib(type=str, default="")
has_deep_sleep = attr.ib(type=bool, default=False)
esphome_version = attr.ib(type=str, default='')
esphome_version = attr.ib(type=str, default="")
@attr.s
class EntityInfo:
object_id = attr.ib(type=str, default='')
object_id = attr.ib(type=str, default="")
key = attr.ib(type=int, default=0)
name = attr.ib(type=str, default='')
unique_id = attr.ib(type=str, default='')
name = attr.ib(type=str, default="")
unique_id = attr.ib(type=str, default="")
@attr.s
@ -65,7 +69,7 @@ class EntityState:
# ==================== BINARY SENSOR ====================
@attr.s
class BinarySensorInfo(EntityInfo):
device_class = attr.ib(type=str, default='')
device_class = attr.ib(type=str, default="")
is_status_binary_sensor = attr.ib(type=bool, default=False)
@ -81,7 +85,7 @@ class CoverInfo(EntityInfo):
assumed_state = attr.ib(type=bool, default=False)
supports_position = attr.ib(type=bool, default=False)
supports_tilt = attr.ib(type=bool, default=False)
device_class = attr.ib(type=str, default='')
device_class = attr.ib(type=str, default="")
class LegacyCoverState(APIIntEnum):
@ -103,14 +107,20 @@ class CoverOperation(APIIntEnum):
@attr.s
class CoverState(EntityState):
legacy_state = attr.ib(type=Optional[LegacyCoverState], converter=LegacyCoverState.convert,
default=LegacyCoverState.OPEN)
legacy_state = attr.ib(
type=LegacyCoverState,
converter=LegacyCoverState.convert, # type: ignore
default=LegacyCoverState.OPEN,
)
position = attr.ib(type=float, default=0.0)
tilt = attr.ib(type=float, default=0.0)
current_operation = attr.ib(type=Optional[CoverOperation], converter=CoverOperation.convert,
default=CoverOperation.IDLE)
current_operation = attr.ib(
type=CoverOperation,
converter=CoverOperation.convert, # type: ignore
default=CoverOperation.IDLE,
)
def is_closed(self, api_version: APIVersion):
def is_closed(self, api_version: APIVersion) -> bool:
if api_version >= APIVersion(1, 1):
return self.position == 0.0
return self.legacy_state == LegacyCoverState.CLOSED
@ -140,9 +150,17 @@ class FanDirection(APIIntEnum):
class FanState(EntityState):
state = attr.ib(type=bool, default=False)
oscillating = attr.ib(type=bool, default=False)
speed = attr.ib(type=Optional[FanSpeed], converter=FanSpeed.convert, default=FanSpeed.LOW)
speed = attr.ib(
type=Optional[FanSpeed],
converter=FanSpeed.convert, # type: ignore
default=FanSpeed.LOW,
)
speed_level = attr.ib(type=int, default=0)
direction = attr.ib(type=Optional[FanDirection], converter=FanDirection.convert, default=FanDirection.FORWARD)
direction = attr.ib(
type=FanDirection,
converter=FanDirection.convert, # type: ignore
default=FanDirection.FORWARD,
)
# ==================== LIGHT ====================
@ -166,7 +184,7 @@ class LightState(EntityState):
blue = attr.ib(type=float, default=0.0)
white = attr.ib(type=float, default=0.0)
color_temperature = attr.ib(type=float, default=0.0)
effect = attr.ib(type=str, default='')
effect = attr.ib(type=str, default="")
# ==================== SENSOR ====================
@ -174,14 +192,19 @@ class SensorStateClass(APIIntEnum):
NONE = 0
MEASUREMENT = 1
@attr.s
class SensorInfo(EntityInfo):
icon = attr.ib(type=str, default='')
device_class = attr.ib(type=str, default='')
unit_of_measurement = attr.ib(type=str, default='')
icon = attr.ib(type=str, default="")
device_class = attr.ib(type=str, default="")
unit_of_measurement = attr.ib(type=str, default="")
accuracy_decimals = attr.ib(type=int, default=0)
force_update = attr.ib(type=bool, default=False)
state_class = attr.ib(type=Optional[SensorStateClass], converter=SensorStateClass.convert, default=SensorStateClass.NONE)
state_class = attr.ib(
type=SensorStateClass,
converter=SensorStateClass.convert, # type: ignore
default=SensorStateClass.NONE,
)
@attr.s
@ -193,7 +216,7 @@ class SensorState(EntityState):
# ==================== SWITCH ====================
@attr.s
class SwitchInfo(EntityInfo):
icon = attr.ib(type=str, default='')
icon = attr.ib(type=str, default="")
assumed_state = attr.ib(type=bool, default=False)
@ -205,12 +228,12 @@ class SwitchState(EntityState):
# ==================== TEXT SENSOR ====================
@attr.s
class TextSensorInfo(EntityInfo):
icon = attr.ib(type=str, default='')
icon = attr.ib(type=str, default="")
@attr.s
class TextSensorState(EntityState):
state = attr.ib(type=str, default='')
state = attr.ib(type=str, default="")
missing_state = attr.ib(type=bool, default=False)
@ -267,68 +290,90 @@ class ClimateAction(APIIntEnum):
class ClimateInfo(EntityInfo):
supports_current_temperature = attr.ib(type=bool, default=False)
supports_two_point_target_temperature = attr.ib(type=bool, default=False)
supported_modes = attr.ib(type=List[ClimateMode], converter=ClimateMode.convert_list,
factory=list)
supported_modes = attr.ib(
type=List[ClimateMode],
converter=ClimateMode.convert_list, # type: ignore
factory=list,
)
visual_min_temperature = attr.ib(type=float, default=0.0)
visual_max_temperature = attr.ib(type=float, default=0.0)
visual_temperature_step = attr.ib(type=float, default=0.0)
supports_away = attr.ib(type=bool, default=False)
supports_action = attr.ib(type=bool, default=False)
supported_fan_modes = attr.ib(
type=List[ClimateFanMode], converter=ClimateFanMode.convert_list, factory=list
type=List[ClimateFanMode],
converter=ClimateFanMode.convert_list, # type: ignore
factory=list,
)
supported_swing_modes = attr.ib(
type=List[ClimateSwingMode], converter=ClimateSwingMode.convert_list, factory=list
type=List[ClimateSwingMode],
converter=ClimateSwingMode.convert_list, # type: ignore
factory=list,
)
@attr.s
class ClimateState(EntityState):
mode = attr.ib(type=Optional[ClimateMode], converter=ClimateMode.convert,
default=ClimateMode.OFF)
action = attr.ib(type=Optional[ClimateAction], converter=ClimateAction.convert,
default=ClimateAction.OFF)
mode = attr.ib(
type=ClimateMode,
converter=ClimateMode.convert, # type: ignore
default=ClimateMode.OFF,
)
action = attr.ib(
type=ClimateAction,
converter=ClimateAction.convert, # type: ignore
default=ClimateAction.OFF,
)
current_temperature = attr.ib(type=float, default=0.0)
target_temperature = attr.ib(type=float, default=0.0)
target_temperature_low = attr.ib(type=float, default=0.0)
target_temperature_high = attr.ib(type=float, default=0.0)
away = attr.ib(type=bool, default=False)
fan_mode = attr.ib(
type=Optional[ClimateFanMode], converter=ClimateFanMode.convert, default=ClimateFanMode.ON
type=Optional[ClimateFanMode],
converter=ClimateFanMode.convert, # type: ignore
default=ClimateFanMode.ON,
)
swing_mode = attr.ib(
type=Optional[ClimateSwingMode], converter=ClimateSwingMode.convert, default=ClimateSwingMode.OFF
type=Optional[ClimateSwingMode],
converter=ClimateSwingMode.convert, # type: ignore
default=ClimateSwingMode.OFF,
)
COMPONENT_TYPE_TO_INFO = {
'binary_sensor': BinarySensorInfo,
'cover': CoverInfo,
'fan': FanInfo,
'light': LightInfo,
'sensor': SensorInfo,
'switch': SwitchInfo,
'text_sensor': TextSensorInfo,
'camera': CameraInfo,
'climate': ClimateInfo,
"binary_sensor": BinarySensorInfo,
"cover": CoverInfo,
"fan": FanInfo,
"light": LightInfo,
"sensor": SensorInfo,
"switch": SwitchInfo,
"text_sensor": TextSensorInfo,
"camera": CameraInfo,
"climate": ClimateInfo,
}
# ==================== USER-DEFINED SERVICES ====================
def _convert_homeassistant_service_map(value):
def _convert_homeassistant_service_map(
value: Iterable["HomeassistantServiceMap"],
) -> Dict[str, str]:
return {v.key: v.value for v in value}
@attr.s
class HomeassistantServiceCall:
service = attr.ib(type=str, default='')
service = attr.ib(type=str, default="")
is_event = attr.ib(type=bool, default=False)
data = attr.ib(type=Dict[str, str], converter=_convert_homeassistant_service_map,
factory=dict)
data_template = attr.ib(type=Dict[str, str], converter=_convert_homeassistant_service_map,
factory=dict)
variables = attr.ib(type=Dict[str, str], converter=_convert_homeassistant_service_map,
factory=dict)
data = attr.ib(
type=Dict[str, str], converter=_convert_homeassistant_service_map, factory=dict
)
data_template = attr.ib(
type=Dict[str, str], converter=_convert_homeassistant_service_map, factory=dict
)
variables = attr.ib(
type=Dict[str, str], converter=_convert_homeassistant_service_map, factory=dict
)
class UserServiceArgType(APIIntEnum):
@ -342,37 +387,43 @@ class UserServiceArgType(APIIntEnum):
STRING_ARRAY = 7
def _attr_obj_from_dict(cls, **kwargs):
return cls(**{key: kwargs[key] for key in attr.fields_dict(cls)})
_K = TypeVar("_K")
def _attr_obj_from_dict(cls: Type[_K], **kwargs: Any) -> _K:
return cls(**{key: kwargs[key] for key in attr.fields_dict(cls)}) # type: ignore
@attr.s
class UserServiceArg:
name = attr.ib(type=str, default='')
type_ = attr.ib(type=Optional[UserServiceArgType], converter=UserServiceArgType.convert,
default=UserServiceArgType.BOOL)
name = attr.ib(type=str, default="")
type_ = attr.ib(
type=UserServiceArgType,
converter=UserServiceArgType.convert, # type: ignore
default=UserServiceArgType.BOOL,
)
@attr.s
class UserService:
name = attr.ib(type=str, default='')
name = attr.ib(type=str, default="")
key = attr.ib(type=int, default=0)
args = attr.ib(type=List[UserServiceArg], converter=list, factory=list)
@staticmethod
def from_dict(dict_):
@classmethod
def from_dict(cls, dict_: Dict[str, Any]) -> "UserService":
args = []
for arg in dict_.get('args', []):
for arg in dict_.get("args", []):
args.append(_attr_obj_from_dict(UserServiceArg, **arg))
return UserService(
name=dict_.get('name', ''),
key=dict_.get('key', 0),
args=args
return cls(
name=dict_.get("name", ""),
key=dict_.get("key", 0),
args=args, # type: ignore
)
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
return {
'name': self.name,
'key': self.key,
'args': [attr.asdict(arg) for arg in self.args],
"name": self.name,
"key": self.key,
"args": [attr.asdict(arg) for arg in self.args],
}

View File

@ -1,7 +1,8 @@
import asyncio
import functools
import socket
from typing import Optional, Tuple, Any
from typing import Any, Optional, Tuple
import zeroconf
# pylint: disable=cyclic-import
@ -35,8 +36,9 @@ def _bytes_to_varuint(value: bytes) -> Optional[int]:
return None
async def resolve_ip_address_getaddrinfo(eventloop: asyncio.events.AbstractEventLoop,
host: str, port: int) -> Tuple[Any, ...]:
async def resolve_ip_address_getaddrinfo(
eventloop: asyncio.events.AbstractEventLoop, host: str, port: int
) -> Tuple[Any, ...]:
try:
socket.inet_aton(host)
@ -46,8 +48,9 @@ async def resolve_ip_address_getaddrinfo(eventloop: asyncio.events.AbstractEvent
return (host, port)
try:
res = await eventloop.getaddrinfo(host, port, family=socket.AF_INET,
proto=socket.IPPROTO_TCP)
res = await eventloop.getaddrinfo(
host, port, family=socket.AF_INET, proto=socket.IPPROTO_TCP
)
except OSError as err:
raise APIConnectionError("Error resolving IP address: {}".format(err))
@ -59,21 +62,25 @@ async def resolve_ip_address_getaddrinfo(eventloop: asyncio.events.AbstractEvent
return sockaddr
async def resolve_ip_address(eventloop: asyncio.events.AbstractEventLoop,
host: str, port: int,
zeroconf_instance: zeroconf.Zeroconf = None) -> Tuple[Any, ...]:
if host.endswith('.local'):
async def resolve_ip_address(
eventloop: asyncio.events.AbstractEventLoop,
host: str,
port: int,
zeroconf_instance: Optional[zeroconf.Zeroconf] = None,
) -> Tuple[Any, ...]:
if host.endswith(".local"):
from aioesphomeapi.host_resolver import resolve_host
try:
return await eventloop.run_in_executor(
return (
await eventloop.run_in_executor(
None,
functools.partial(
resolve_host,
host,
zeroconf_instance=zeroconf_instance
resolve_host, host, zeroconf_instance=zeroconf_instance
),
),
port,
)
), port
except APIConnectionError:
pass
return await resolve_ip_address_getaddrinfo(eventloop, host, port)

View File

@ -1,7 +0,0 @@
#!/usr/bin/env bash
# Generate protobuf compiled files
protoc --python_out=aioesphomeapi -I aioesphomeapi aioesphomeapi/*.proto
# https://github.com/protocolbuffers/protobuf/issues/1491
sed -i 's/import api_options_pb2 as api__options__pb2/from . import api_options_pb2 as api__options__pb2/' aioesphomeapi/api_pb2.py

View File

@ -14,3 +14,4 @@ disable=
unused-wildcard-import,
import-outside-toplevel,
raise-missing-from,
duplicate-code,

4
pyproject.toml Normal file
View File

@ -0,0 +1,4 @@
[tool.isort]
profile = "black"
multi_line_output = 3
extend_skip = ["api_pb2.py", "api_options_pb2.py"]

View File

@ -1 +1,6 @@
pylint==2.8.3
black==21.6b0
flake8==3.9.2
isort==5.8.0
mypy==0.902
types-protobuf==0.1.13

26
script/gen-protoc Executable file
View File

@ -0,0 +1,26 @@
#!/usr/bin/env python3
from subprocess import check_call
from pathlib import Path
import os
root_dir = Path(__file__).absolute().parent.parent
os.chdir(root_dir)
check_call([
"protoc", "--python_out=aioesphomeapi", "-I", "aioesphomeapi",
"aioesphomeapi/api.proto", "aioesphomeapi/api_options.proto"
])
# https://github.com/protocolbuffers/protobuf/issues/1491
api_file = root_dir / 'aioesphomeapi' / 'api_pb2.py'
content = api_file.read_text().replace(
"import api_options_pb2 as api__options__pb2",
"from . import api_options_pb2 as api__options__pb2"
)
api_file.write_text(content)
for fname in ['api_pb2.py', 'api_options_pb2.py']:
file = root_dir / 'aioesphomeapi' / fname
content = '# type: ignore\n' + file.read_text()
file.write_text(content)

10
script/lint Executable file
View File

@ -0,0 +1,10 @@
#!/bin/bash
cd "$(dirname "$0")/.."
set -euxo pipefail
black --safe --exclude 'api_pb2.py|api_options_pb2.py' aioesphomeapi
pylint aioesphomeapi
flake8 aioesphomeapi
isort aioesphomeapi
mypy --strict aioesphomeapi

18
setup.cfg Normal file
View File

@ -0,0 +1,18 @@
[flake8]
max-line-length = 120
# Following 4 for black compatibility
# E501: line too long
# W503: Line break occurred before a binary operator
# E203: Whitespace before ':'
# D202 No blank lines allowed after function docstring
ignore =
E501,
W503,
E203,
D202,
exclude = api_pb2.py, api_options_pb2.py
[bdist_wheel]
universal = 1