Add flake8, black, isort and mypy linting (#39)

This commit is contained in:
Otto Winter 2021-06-18 17:57:02 +02:00 committed by GitHub
parent 41d0d335e9
commit 52cf01e11a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 1677 additions and 985 deletions

View File

@ -18,5 +18,68 @@ jobs:
run: | run: |
pip3 install -e . pip3 install -e .
pip3 install -r requirements_test.txt pip3 install -r requirements_test.txt
- name: Register problem matcher
run: |
echo "::add-matcher::.github/workflows/matchers/pylint.json"
- run: pylint aioesphomeapi - run: pylint aioesphomeapi
flake8:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.7'
- name: Set up Python environment
run: |
pip3 install -e .
pip3 install -r requirements_test.txt
- name: Register problem matcher
run: |
echo "::add-matcher::.github/workflows/matchers/flake8.json"
- run: flake8 aioesphomeapi
black:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.7'
- name: Set up Python environment
run: |
pip3 install -e .
pip3 install -r requirements_test.txt
- run: black --safe --exclude 'api_pb2.py|api_options_pb2.py' aioesphomeapi
isort:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.7'
- name: Set up Python environment
run: |
pip3 install -e .
pip3 install -r requirements_test.txt
- name: Register problem matcher
run: |
echo "::add-matcher::.github/workflows/matchers/isort.json"
- run: isort --check aioesphomeapi
mypy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.7'
- name: Set up Python environment
run: |
pip3 install -e .
pip3 install -r requirements_test.txt
- name: Register problem matcher
run: |
echo "::add-matcher::.github/workflows/matchers/mypy.json"
- run: mypy --strict aioesphomeapi

30
.github/workflows/matchers/flake8.json vendored Normal file
View File

@ -0,0 +1,30 @@
{
"problemMatcher": [
{
"owner": "flake8-error",
"severity": "error",
"pattern": [
{
"regexp": "^(.*):(\\d+):(\\d+):\\s([EF]\\d{3}\\s.*)$",
"file": 1,
"line": 2,
"column": 3,
"message": 4
}
]
},
{
"owner": "flake8-warning",
"severity": "warning",
"pattern": [
{
"regexp": "^(.*):(\\d+):(\\d+):\\s([CDNW]\\d{3}\\s.*)$",
"file": 1,
"line": 2,
"column": 3,
"message": 4
}
]
}
]
}

14
.github/workflows/matchers/isort.json vendored Normal file
View File

@ -0,0 +1,14 @@
{
"problemMatcher": [
{
"owner": "isort",
"pattern": [
{
"regexp": "^ERROR:\\s+(.+)\\s+(.+)$",
"file": 1,
"message": 2
}
]
}
]
}

16
.github/workflows/matchers/mypy.json vendored Normal file
View File

@ -0,0 +1,16 @@
{
"problemMatcher": [
{
"owner": "mypy",
"pattern": [
{
"regexp": "^(.+):(\\d+):\\s(error|warning):\\s(.+)$",
"file": 1,
"line": 2,
"severity": 3,
"message": 4
}
]
}
]
}

32
.github/workflows/matchers/pylint.json vendored Normal file
View File

@ -0,0 +1,32 @@
{
"problemMatcher": [
{
"owner": "pylint-error",
"severity": "error",
"pattern": [
{
"regexp": "^(.+):(\\d+):(\\d+):\\s(([EF]\\d{4}):\\s.+)$",
"file": 1,
"line": 2,
"column": 3,
"message": 4,
"code": 5
}
]
},
{
"owner": "pylint-warning",
"severity": "warning",
"pattern": [
{
"regexp": "^(.+):(\\d+):(\\d+):\\s(([CRW]\\d{4}):\\s.+)$",
"file": 1,
"line": 2,
"column": 3,
"message": 4,
"code": 5
}
]
}
]
}

37
.github/workflows/protoc-update.yml vendored Normal file
View File

@ -0,0 +1,37 @@
name: Update protobuf generated files
on:
push:
branches: [master]
jobs:
protoc-update:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.7'
- name: Install protoc
run: |
sudo apt-get install protobuf-compiler
- name: Set up Python environment
run: |
pip3 install -e .
pip3 install -r requirements_test.txt
- name: Generate protoc
run: |
script/gen-protoc
# github actions email from here: https://github.community/t/github-actions-bot-email-address/17204
- name: Commit changes
run: |
if git diff-index --quiet HEAD --; then
echo "No changes detected, protobuf files are up to date!"
else
git config --global user.name "github-actions[bot]"
git config --global user.email "41898282+github-actions[bot]@users.noreply.github.com"
git commit -am "Update protobuf files"
git push
fi

View File

@ -1,5 +1,6 @@
# flake8: noqa
from .client import APIClient from .client import APIClient
from .connection import ConnectionParams, APIConnection from .connection import APIConnection, ConnectionParams
from .core import APIConnectionError, MESSAGE_TYPE_TO_PROTO from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
from .model import * from .model import *
from .util import resolve_ip_address_getaddrinfo, resolve_ip_address from .util import resolve_ip_address, resolve_ip_address_getaddrinfo

View File

@ -1,8 +1,8 @@
# type: ignore
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT! # Generated by the protocol buffer compiler. DO NOT EDIT!
# source: api_options.proto # source: api_options.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf.internal import enum_type_wrapper from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message from google.protobuf import message as _message
@ -21,7 +21,8 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='', package='',
syntax='proto2', syntax='proto2',
serialized_options=None, serialized_options=None,
serialized_pb=_b('\n\x11\x61pi_options.proto\x1a google/protobuf/descriptor.proto\"\x06\n\x04void*F\n\rAPISourceType\x12\x0f\n\x0bSOURCE_BOTH\x10\x00\x12\x11\n\rSOURCE_SERVER\x10\x01\x12\x11\n\rSOURCE_CLIENT\x10\x02:E\n\x16needs_setup_connection\x12\x1e.google.protobuf.MethodOptions\x18\x8e\x08 \x01(\x08:\x04true:C\n\x14needs_authentication\x12\x1e.google.protobuf.MethodOptions\x18\x8f\x08 \x01(\x08:\x04true:/\n\x02id\x12\x1f.google.protobuf.MessageOptions\x18\x8c\x08 \x01(\r:\x01\x30:M\n\x06source\x12\x1f.google.protobuf.MessageOptions\x18\x8d\x08 \x01(\x0e\x32\x0e.APISourceType:\x0bSOURCE_BOTH:/\n\x05ifdef\x12\x1f.google.protobuf.MessageOptions\x18\x8e\x08 \x01(\t:3\n\x03log\x12\x1f.google.protobuf.MessageOptions\x18\x8f\x08 \x01(\x08:\x04true:9\n\x08no_delay\x12\x1f.google.protobuf.MessageOptions\x18\x90\x08 \x01(\x08:\x05\x66\x61lse') create_key=_descriptor._internal_create_key,
serialized_pb=b'\n\x11\x61pi_options.proto\x1a google/protobuf/descriptor.proto\"\x06\n\x04void*F\n\rAPISourceType\x12\x0f\n\x0bSOURCE_BOTH\x10\x00\x12\x11\n\rSOURCE_SERVER\x10\x01\x12\x11\n\rSOURCE_CLIENT\x10\x02:E\n\x16needs_setup_connection\x12\x1e.google.protobuf.MethodOptions\x18\x8e\x08 \x01(\x08:\x04true:C\n\x14needs_authentication\x12\x1e.google.protobuf.MethodOptions\x18\x8f\x08 \x01(\x08:\x04true:/\n\x02id\x12\x1f.google.protobuf.MessageOptions\x18\x8c\x08 \x01(\r:\x01\x30:M\n\x06source\x12\x1f.google.protobuf.MessageOptions\x18\x8d\x08 \x01(\x0e\x32\x0e.APISourceType:\x0bSOURCE_BOTH:/\n\x05ifdef\x12\x1f.google.protobuf.MessageOptions\x18\x8e\x08 \x01(\t:3\n\x03log\x12\x1f.google.protobuf.MessageOptions\x18\x8f\x08 \x01(\x08:\x04true:9\n\x08no_delay\x12\x1f.google.protobuf.MessageOptions\x18\x90\x08 \x01(\x08:\x05\x66\x61lse'
, ,
dependencies=[google_dot_protobuf_dot_descriptor__pb2.DESCRIPTOR,]) dependencies=[google_dot_protobuf_dot_descriptor__pb2.DESCRIPTOR,])
@ -30,19 +31,23 @@ _APISOURCETYPE = _descriptor.EnumDescriptor(
full_name='APISourceType', full_name='APISourceType',
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
create_key=_descriptor._internal_create_key,
values=[ values=[
_descriptor.EnumValueDescriptor( _descriptor.EnumValueDescriptor(
name='SOURCE_BOTH', index=0, number=0, name='SOURCE_BOTH', index=0, number=0,
serialized_options=None, serialized_options=None,
type=None), type=None,
create_key=_descriptor._internal_create_key),
_descriptor.EnumValueDescriptor( _descriptor.EnumValueDescriptor(
name='SOURCE_SERVER', index=1, number=1, name='SOURCE_SERVER', index=1, number=1,
serialized_options=None, serialized_options=None,
type=None), type=None,
create_key=_descriptor._internal_create_key),
_descriptor.EnumValueDescriptor( _descriptor.EnumValueDescriptor(
name='SOURCE_CLIENT', index=2, number=2, name='SOURCE_CLIENT', index=2, number=2,
serialized_options=None, serialized_options=None,
type=None), type=None,
create_key=_descriptor._internal_create_key),
], ],
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
@ -63,7 +68,7 @@ needs_setup_connection = _descriptor.FieldDescriptor(
has_default_value=True, default_value=True, has_default_value=True, default_value=True,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=True, extension_scope=None, is_extension=True, extension_scope=None,
serialized_options=None, file=DESCRIPTOR) serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key)
NEEDS_AUTHENTICATION_FIELD_NUMBER = 1039 NEEDS_AUTHENTICATION_FIELD_NUMBER = 1039
needs_authentication = _descriptor.FieldDescriptor( needs_authentication = _descriptor.FieldDescriptor(
name='needs_authentication', full_name='needs_authentication', index=1, name='needs_authentication', full_name='needs_authentication', index=1,
@ -71,7 +76,7 @@ needs_authentication = _descriptor.FieldDescriptor(
has_default_value=True, default_value=True, has_default_value=True, default_value=True,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=True, extension_scope=None, is_extension=True, extension_scope=None,
serialized_options=None, file=DESCRIPTOR) serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key)
ID_FIELD_NUMBER = 1036 ID_FIELD_NUMBER = 1036
id = _descriptor.FieldDescriptor( id = _descriptor.FieldDescriptor(
name='id', full_name='id', index=2, name='id', full_name='id', index=2,
@ -79,7 +84,7 @@ id = _descriptor.FieldDescriptor(
has_default_value=True, default_value=0, has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=True, extension_scope=None, is_extension=True, extension_scope=None,
serialized_options=None, file=DESCRIPTOR) serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key)
SOURCE_FIELD_NUMBER = 1037 SOURCE_FIELD_NUMBER = 1037
source = _descriptor.FieldDescriptor( source = _descriptor.FieldDescriptor(
name='source', full_name='source', index=3, name='source', full_name='source', index=3,
@ -87,15 +92,15 @@ source = _descriptor.FieldDescriptor(
has_default_value=True, default_value=0, has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=True, extension_scope=None, is_extension=True, extension_scope=None,
serialized_options=None, file=DESCRIPTOR) serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key)
IFDEF_FIELD_NUMBER = 1038 IFDEF_FIELD_NUMBER = 1038
ifdef = _descriptor.FieldDescriptor( ifdef = _descriptor.FieldDescriptor(
name='ifdef', full_name='ifdef', index=4, name='ifdef', full_name='ifdef', index=4,
number=1038, type=9, cpp_type=9, label=1, number=1038, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'), has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=True, extension_scope=None, is_extension=True, extension_scope=None,
serialized_options=None, file=DESCRIPTOR) serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key)
LOG_FIELD_NUMBER = 1039 LOG_FIELD_NUMBER = 1039
log = _descriptor.FieldDescriptor( log = _descriptor.FieldDescriptor(
name='log', full_name='log', index=5, name='log', full_name='log', index=5,
@ -103,7 +108,7 @@ log = _descriptor.FieldDescriptor(
has_default_value=True, default_value=True, has_default_value=True, default_value=True,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=True, extension_scope=None, is_extension=True, extension_scope=None,
serialized_options=None, file=DESCRIPTOR) serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key)
NO_DELAY_FIELD_NUMBER = 1040 NO_DELAY_FIELD_NUMBER = 1040
no_delay = _descriptor.FieldDescriptor( no_delay = _descriptor.FieldDescriptor(
name='no_delay', full_name='no_delay', index=6, name='no_delay', full_name='no_delay', index=6,
@ -111,7 +116,7 @@ no_delay = _descriptor.FieldDescriptor(
has_default_value=True, default_value=False, has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=True, extension_scope=None, is_extension=True, extension_scope=None,
serialized_options=None, file=DESCRIPTOR) serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key)
_VOID = _descriptor.Descriptor( _VOID = _descriptor.Descriptor(
@ -120,6 +125,7 @@ _VOID = _descriptor.Descriptor(
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[ fields=[
], ],
extensions=[ extensions=[
@ -148,11 +154,11 @@ DESCRIPTOR.extensions_by_name['log'] = log
DESCRIPTOR.extensions_by_name['no_delay'] = no_delay DESCRIPTOR.extensions_by_name['no_delay'] = no_delay
_sym_db.RegisterFileDescriptor(DESCRIPTOR) _sym_db.RegisterFileDescriptor(DESCRIPTOR)
void = _reflection.GeneratedProtocolMessageType('void', (_message.Message,), dict( void = _reflection.GeneratedProtocolMessageType('void', (_message.Message,), {
DESCRIPTOR = _VOID, 'DESCRIPTOR' : _VOID,
__module__ = 'api_options_pb2' '__module__' : 'api_options_pb2'
# @@protoc_insertion_point(class_scope:void) # @@protoc_insertion_point(class_scope:void)
)) })
_sym_db.RegisterMessage(void) _sym_db.RegisterMessage(void)
google_dot_protobuf_dot_descriptor__pb2.MethodOptions.RegisterExtension(needs_setup_connection) google_dot_protobuf_dot_descriptor__pb2.MethodOptions.RegisterExtension(needs_setup_connection)

File diff suppressed because one or more lines are too long

View File

@ -1,19 +1,108 @@
import asyncio
import logging import logging
from typing import Any, Callable, Optional, Tuple from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union, cast
import zeroconf
import aioesphomeapi.api_pb2 as pb import attr
import zeroconf
from google.protobuf import message
from aioesphomeapi.api_pb2 import ( # type: ignore
BinarySensorStateResponse,
CameraImageRequest,
CameraImageResponse,
ClimateCommandRequest,
ClimateStateResponse,
CoverCommandRequest,
CoverStateResponse,
DeviceInfoRequest,
DeviceInfoResponse,
ExecuteServiceArgument,
ExecuteServiceRequest,
FanCommandRequest,
FanStateResponse,
HomeassistantServiceResponse,
HomeAssistantStateResponse,
LightCommandRequest,
LightStateResponse,
ListEntitiesBinarySensorResponse,
ListEntitiesCameraResponse,
ListEntitiesClimateResponse,
ListEntitiesCoverResponse,
ListEntitiesDoneResponse,
ListEntitiesFanResponse,
ListEntitiesLightResponse,
ListEntitiesRequest,
ListEntitiesSensorResponse,
ListEntitiesServicesResponse,
ListEntitiesSwitchResponse,
ListEntitiesTextSensorResponse,
LogLevel,
SensorStateResponse,
SubscribeHomeassistantServicesRequest,
SubscribeHomeAssistantStateResponse,
SubscribeHomeAssistantStatesRequest,
SubscribeLogsRequest,
SubscribeLogsResponse,
SubscribeStatesRequest,
SwitchCommandRequest,
SwitchStateResponse,
TextSensorStateResponse,
)
from aioesphomeapi.connection import APIConnection, ConnectionParams from aioesphomeapi.connection import APIConnection, ConnectionParams
from aioesphomeapi.core import APIConnectionError from aioesphomeapi.core import APIConnectionError
from aioesphomeapi.model import * from aioesphomeapi.model import (
APIVersion,
BinarySensorInfo,
BinarySensorState,
CameraInfo,
CameraState,
ClimateFanMode,
ClimateInfo,
ClimateMode,
ClimateState,
ClimateSwingMode,
CoverInfo,
CoverState,
DeviceInfo,
EntityInfo,
FanDirection,
FanInfo,
FanSpeed,
FanState,
HomeassistantServiceCall,
LegacyCoverCommand,
LightInfo,
LightState,
SensorInfo,
SensorState,
SwitchInfo,
SwitchState,
TextSensorInfo,
TextSensorState,
UserService,
UserServiceArg,
UserServiceArgType,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
ExecuteServiceDataType = Dict[
str, Union[bool, int, float, str, List[bool], List[int], List[float], List[str]]
]
class APIClient: class APIClient:
def __init__(self, eventloop, address: str, port: int, password: str, *, def __init__(
client_info: str = 'aioesphomeapi', keepalive: float = 15.0, self,
zeroconf_instance: zeroconf.Zeroconf = None): eventloop: asyncio.AbstractEventLoop,
address: str,
port: int,
password: str,
*,
client_info: str = "aioesphomeapi",
keepalive: float = 15.0,
zeroconf_instance: Optional[zeroconf.Zeroconf] = None
):
self._params = ConnectionParams( self._params = ConnectionParams(
eventloop=eventloop, eventloop=eventloop,
address=address, address=address,
@ -21,18 +110,22 @@ class APIClient:
password=password, password=password,
client_info=client_info, client_info=client_info,
keepalive=keepalive, keepalive=keepalive,
zeroconf_instance=zeroconf_instance zeroconf_instance=zeroconf_instance,
) )
self._connection = None # type: Optional[APIConnection] self._connection = None # type: Optional[APIConnection]
async def connect(self, on_stop=None, login=False): async def connect(
self,
on_stop: Optional[Callable[[], Awaitable[None]]] = None,
login: bool = False,
) -> None:
if self._connection is not None: if self._connection is not None:
raise APIConnectionError("Already connected!") raise APIConnectionError("Already connected!")
connected = False connected = False
stopped = False stopped = False
async def _on_stop(): async def _on_stop() -> None:
nonlocal stopped nonlocal stopped
if stopped: if stopped:
@ -53,31 +146,33 @@ class APIClient:
raise raise
except Exception as e: except Exception as e:
await _on_stop() await _on_stop()
raise APIConnectionError( raise APIConnectionError("Unexpected error while connecting: {}".format(e))
"Unexpected error while connecting: {}".format(e))
connected = True connected = True
async def disconnect(self, force=False): async def disconnect(self, force: bool = False) -> None:
if self._connection is None: if self._connection is None:
return return
await self._connection.stop(force=force) await self._connection.stop(force=force)
def _check_connected(self): def _check_connected(self) -> None:
if self._connection is None: if self._connection is None:
raise APIConnectionError("Not connected!") raise APIConnectionError("Not connected!")
if not self._connection.is_connected: if not self._connection.is_connected:
raise APIConnectionError("Connection not done!") raise APIConnectionError("Connection not done!")
def _check_authenticated(self): def _check_authenticated(self) -> None:
self._check_connected() self._check_connected()
assert self._connection is not None
if not self._connection.is_authenticated: if not self._connection.is_authenticated:
raise APIConnectionError("Not authenticated!") raise APIConnectionError("Not authenticated!")
async def device_info(self) -> DeviceInfo: async def device_info(self) -> DeviceInfo:
self._check_connected() self._check_connected()
assert self._connection is not None
resp = await self._connection.send_message_await_response( resp = await self._connection.send_message_await_response(
pb.DeviceInfoRequest(), pb.DeviceInfoResponse) DeviceInfoRequest(), DeviceInfoResponse
)
return DeviceInfo( return DeviceInfo(
uses_password=resp.uses_password, uses_password=resp.uses_password,
name=resp.name, name=resp.name,
@ -88,49 +183,60 @@ class APIClient:
has_deep_sleep=resp.has_deep_sleep, has_deep_sleep=resp.has_deep_sleep,
) )
async def list_entities_services(self) -> Tuple[List[Any], List[UserService]]: async def list_entities_services(
self,
) -> Tuple[List[EntityInfo], List[UserService]]:
self._check_authenticated() self._check_authenticated()
response_types = { response_types = {
pb.ListEntitiesBinarySensorResponse: BinarySensorInfo, ListEntitiesBinarySensorResponse: BinarySensorInfo,
pb.ListEntitiesCoverResponse: CoverInfo, ListEntitiesCoverResponse: CoverInfo,
pb.ListEntitiesFanResponse: FanInfo, ListEntitiesFanResponse: FanInfo,
pb.ListEntitiesLightResponse: LightInfo, ListEntitiesLightResponse: LightInfo,
pb.ListEntitiesSensorResponse: SensorInfo, ListEntitiesSensorResponse: SensorInfo,
pb.ListEntitiesSwitchResponse: SwitchInfo, ListEntitiesSwitchResponse: SwitchInfo,
pb.ListEntitiesTextSensorResponse: TextSensorInfo, ListEntitiesTextSensorResponse: TextSensorInfo,
pb.ListEntitiesServicesResponse: None, ListEntitiesServicesResponse: None,
pb.ListEntitiesCameraResponse: CameraInfo, ListEntitiesCameraResponse: CameraInfo,
pb.ListEntitiesClimateResponse: ClimateInfo, ListEntitiesClimateResponse: ClimateInfo,
} }
def do_append(msg): def do_append(msg: message.Message) -> bool:
return isinstance(msg, tuple(response_types.keys())) return isinstance(msg, tuple(response_types.keys()))
def do_stop(msg): def do_stop(msg: message.Message) -> bool:
return isinstance(msg, pb.ListEntitiesDoneResponse) return isinstance(msg, ListEntitiesDoneResponse)
assert self._connection is not None
resp = await self._connection.send_message_await_response_complex( resp = await self._connection.send_message_await_response_complex(
pb.ListEntitiesRequest(), do_append, do_stop, timeout=5) ListEntitiesRequest(), do_append, do_stop, timeout=5
entities = [] )
services = [] entities: List[EntityInfo] = []
services: List[UserService] = []
for msg in resp: for msg in resp:
if isinstance(msg, pb.ListEntitiesServicesResponse): if isinstance(msg, ListEntitiesServicesResponse):
args = [] args = []
for arg in msg.args: for arg in msg.args:
args.append(UserServiceArg( args.append(
name=arg.name, UserServiceArg(
type_=arg.type, name=arg.name,
)) type_=arg.type,
services.append(UserService( )
name=msg.name, )
key=msg.key, services.append(
args=args, UserService(
)) name=msg.name,
key=msg.key,
args=args, # type: ignore
)
)
continue continue
cls = None cls = None
for resp_type, cls in response_types.items(): for resp_type, cls in response_types.items():
if isinstance(msg, resp_type): if isinstance(msg, resp_type):
break break
else:
continue
cls = cast(type, cls)
kwargs = {} kwargs = {}
for key, _ in attr.fields_dict(cls).items(): for key, _ in attr.fields_dict(cls).items():
kwargs[key] = getattr(msg, key) kwargs[key] = getattr(msg, key)
@ -141,20 +247,20 @@ class APIClient:
self._check_authenticated() self._check_authenticated()
response_types = { response_types = {
pb.BinarySensorStateResponse: BinarySensorState, BinarySensorStateResponse: BinarySensorState,
pb.CoverStateResponse: CoverState, CoverStateResponse: CoverState,
pb.FanStateResponse: FanState, FanStateResponse: FanState,
pb.LightStateResponse: LightState, LightStateResponse: LightState,
pb.SensorStateResponse: SensorState, SensorStateResponse: SensorState,
pb.SwitchStateResponse: SwitchState, SwitchStateResponse: SwitchState,
pb.TextSensorStateResponse: TextSensorState, TextSensorStateResponse: TextSensorState,
pb.ClimateStateResponse: ClimateState, ClimateStateResponse: ClimateState,
} }
image_stream = {} image_stream: Dict[int, bytes] = {}
def on_msg(msg): def on_msg(msg: message.Message) -> None:
if isinstance(msg, pb.CameraImageResponse): if isinstance(msg, CameraImageResponse):
data = image_stream.pop(msg.key, bytes()) + msg.data data = image_stream.pop(msg.key, bytes()) + msg.data
if msg.done: if msg.done:
on_state(CameraState(key=msg.key, image=data)) on_state(CameraState(key=msg.key, image=data))
@ -174,33 +280,43 @@ class APIClient:
kwargs[key] = getattr(msg, key) kwargs[key] = getattr(msg, key)
on_state(cls(**kwargs)) on_state(cls(**kwargs))
await self._connection.send_message_callback_response(pb.SubscribeStatesRequest(), on_msg) assert self._connection is not None
await self._connection.send_message_callback_response(
SubscribeStatesRequest(), on_msg
)
async def subscribe_logs(self, on_log: Callable[[pb.SubscribeLogsResponse], None], async def subscribe_logs(
log_level=None) -> None: self,
on_log: Callable[[SubscribeLogsResponse], None],
log_level: Optional[LogLevel] = None,
) -> None:
self._check_authenticated() self._check_authenticated()
def on_msg(msg): def on_msg(msg: message.Message) -> None:
if isinstance(msg, pb.SubscribeLogsResponse): if isinstance(msg, SubscribeLogsResponse):
on_log(msg) on_log(msg)
req = pb.SubscribeLogsRequest() req = SubscribeLogsRequest()
if log_level is not None: if log_level is not None:
req.level = log_level req.level = log_level
assert self._connection is not None
await self._connection.send_message_callback_response(req, on_msg) await self._connection.send_message_callback_response(req, on_msg)
async def subscribe_service_calls(self, on_service_call: Callable[[HomeassistantServiceCall], None]) -> None: async def subscribe_service_calls(
self, on_service_call: Callable[[HomeassistantServiceCall], None]
) -> None:
self._check_authenticated() self._check_authenticated()
def on_msg(msg): def on_msg(msg: message.Message) -> None:
if isinstance(msg, pb.HomeassistantServiceResponse): if isinstance(msg, HomeassistantServiceResponse):
kwargs = {} kwargs = {}
for key, _ in attr.fields_dict(HomeassistantServiceCall).items(): for key, _ in attr.fields_dict(HomeassistantServiceCall).items():
kwargs[key] = getattr(msg, key) kwargs[key] = getattr(msg, key)
on_service_call(HomeassistantServiceCall(**kwargs)) on_service_call(HomeassistantServiceCall(**kwargs))
assert self._connection is not None
await self._connection.send_message_callback_response( await self._connection.send_message_callback_response(
pb.SubscribeHomeassistantServicesRequest(), on_msg SubscribeHomeassistantServicesRequest(), on_msg
) )
async def subscribe_home_assistant_states( async def subscribe_home_assistant_states(
@ -208,12 +324,13 @@ class APIClient:
) -> None: ) -> None:
self._check_authenticated() self._check_authenticated()
def on_msg(msg): def on_msg(msg: message.Message) -> None:
if isinstance(msg, pb.SubscribeHomeAssistantStateResponse): if isinstance(msg, SubscribeHomeAssistantStateResponse):
on_state_sub(msg.entity_id, msg.attribute) on_state_sub(msg.entity_id, msg.attribute)
assert self._connection is not None
await self._connection.send_message_callback_response( await self._connection.send_message_callback_response(
pb.SubscribeHomeAssistantStatesRequest(), on_msg SubscribeHomeAssistantStatesRequest(), on_msg
) )
async def send_home_assistant_state( async def send_home_assistant_state(
@ -221,25 +338,28 @@ class APIClient:
) -> None: ) -> None:
self._check_authenticated() self._check_authenticated()
assert self._connection is not None
await self._connection.send_message( await self._connection.send_message(
pb.HomeAssistantStateResponse( HomeAssistantStateResponse(
entity_id=entity_id, entity_id=entity_id,
state=state, state=state,
attribute=attribute, attribute=attribute,
) )
) )
async def cover_command(self, async def cover_command(
key: int, self,
position: Optional[float] = None, key: int,
tilt: Optional[float] = None, position: Optional[float] = None,
stop: bool = False, tilt: Optional[float] = None,
) -> None: stop: bool = False,
) -> None:
self._check_authenticated() self._check_authenticated()
req = pb.CoverCommandRequest() req = CoverCommandRequest()
req.key = key req.key = key
if self.api_version >= APIVersion(1, 1): apiv = cast(APIVersion, self.api_version)
if apiv >= APIVersion(1, 1):
if position is not None: if position is not None:
req.has_position = True req.has_position = True
req.position = position req.position = position
@ -256,19 +376,21 @@ class APIClient:
req.legacy_command = LegacyCoverCommand.OPEN req.legacy_command = LegacyCoverCommand.OPEN
else: else:
req.legacy_command = LegacyCoverCommand.CLOSE req.legacy_command = LegacyCoverCommand.CLOSE
assert self._connection is not None
await self._connection.send_message(req) await self._connection.send_message(req)
async def fan_command(self, async def fan_command(
key: int, self,
state: Optional[bool] = None, key: int,
speed: Optional[FanSpeed] = None, state: Optional[bool] = None,
speed_level: Optional[int] = None, speed: Optional[FanSpeed] = None,
oscillating: Optional[bool] = None, speed_level: Optional[int] = None,
direction: Optional[FanDirection] = None oscillating: Optional[bool] = None,
) -> None: direction: Optional[FanDirection] = None,
) -> None:
self._check_authenticated() self._check_authenticated()
req = pb.FanCommandRequest() req = FanCommandRequest()
req.key = key req.key = key
if state is not None: if state is not None:
req.has_state = True req.has_state = True
@ -285,22 +407,24 @@ class APIClient:
if direction is not None: if direction is not None:
req.has_direction = True req.has_direction = True
req.direction = direction req.direction = direction
assert self._connection is not None
await self._connection.send_message(req) await self._connection.send_message(req)
async def light_command(self, async def light_command(
key: int, self,
state: Optional[bool] = None, key: int,
brightness: Optional[float] = None, state: Optional[bool] = None,
rgb: Optional[Tuple[float, float, float]] = None, brightness: Optional[float] = None,
white: Optional[float] = None, rgb: Optional[Tuple[float, float, float]] = None,
color_temperature: Optional[float] = None, white: Optional[float] = None,
transition_length: Optional[float] = None, color_temperature: Optional[float] = None,
flash_length: Optional[float] = None, transition_length: Optional[float] = None,
effect: Optional[str] = None, flash_length: Optional[float] = None,
): effect: Optional[str] = None,
) -> None:
self._check_authenticated() self._check_authenticated()
req = pb.LightCommandRequest() req = LightCommandRequest()
req.key = key req.key = key
if state is not None: if state is not None:
req.has_state = True req.has_state = True
@ -328,32 +452,32 @@ class APIClient:
if effect is not None: if effect is not None:
req.has_effect = True req.has_effect = True
req.effect = effect req.effect = effect
assert self._connection is not None
await self._connection.send_message(req) await self._connection.send_message(req)
async def switch_command(self, async def switch_command(self, key: int, state: bool) -> None:
key: int,
state: bool
) -> None:
self._check_authenticated() self._check_authenticated()
req = pb.SwitchCommandRequest() req = SwitchCommandRequest()
req.key = key req.key = key
req.state = state req.state = state
assert self._connection is not None
await self._connection.send_message(req) await self._connection.send_message(req)
async def climate_command(self, async def climate_command(
key: int, self,
mode: Optional[ClimateMode] = None, key: int,
target_temperature: Optional[float] = None, mode: Optional[ClimateMode] = None,
target_temperature_low: Optional[float] = None, target_temperature: Optional[float] = None,
target_temperature_high: Optional[float] = None, target_temperature_low: Optional[float] = None,
away: Optional[bool] = None, target_temperature_high: Optional[float] = None,
fan_mode: Optional[ClimateFanMode] = None, away: Optional[bool] = None,
swing_mode: Optional[ClimateSwingMode] = None, fan_mode: Optional[ClimateFanMode] = None,
) -> None: swing_mode: Optional[ClimateSwingMode] = None,
) -> None:
self._check_authenticated() self._check_authenticated()
req = pb.ClimateCommandRequest() req = ClimateCommandRequest()
req.key = key req.key = key
if mode is not None: if mode is not None:
req.has_mode = True req.has_mode = True
@ -376,30 +500,33 @@ class APIClient:
if swing_mode is not None: if swing_mode is not None:
req.has_swing_mode = True req.has_swing_mode = True
req.swing_mode = swing_mode req.swing_mode = swing_mode
assert self._connection is not None
await self._connection.send_message(req) await self._connection.send_message(req)
async def execute_service(self, service: UserService, data: dict): async def execute_service(
self, service: UserService, data: ExecuteServiceDataType
) -> None:
self._check_authenticated() self._check_authenticated()
req = pb.ExecuteServiceRequest() req = ExecuteServiceRequest()
req.key = service.key req.key = service.key
args = [] args = []
for arg_desc in service.args: for arg_desc in service.args:
arg = pb.ExecuteServiceArgument() arg = ExecuteServiceArgument()
val = data[arg_desc.name] val = data[arg_desc.name]
int_type = 'int_' if self.api_version >= APIVersion( apiv = cast(APIVersion, self.api_version)
1, 3) else 'legacy_int' int_type = "int_" if apiv >= APIVersion(1, 3) else "legacy_int"
map_single = { map_single = {
UserServiceArgType.BOOL: 'bool_', UserServiceArgType.BOOL: "bool_",
UserServiceArgType.INT: int_type, UserServiceArgType.INT: int_type,
UserServiceArgType.FLOAT: 'float_', UserServiceArgType.FLOAT: "float_",
UserServiceArgType.STRING: 'string_', UserServiceArgType.STRING: "string_",
} }
map_array = { map_array = {
UserServiceArgType.BOOL_ARRAY: 'bool_array', UserServiceArgType.BOOL_ARRAY: "bool_array",
UserServiceArgType.INT_ARRAY: 'int_array', UserServiceArgType.INT_ARRAY: "int_array",
UserServiceArgType.FLOAT_ARRAY: 'float_array', UserServiceArgType.FLOAT_ARRAY: "float_array",
UserServiceArgType.STRING_ARRAY: 'string_array', UserServiceArgType.STRING_ARRAY: "string_array",
} }
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
if arg_desc.type_ in map_array: if arg_desc.type_ in map_array:
@ -411,18 +538,22 @@ class APIClient:
args.append(arg) args.append(arg)
# pylint: disable=no-member # pylint: disable=no-member
req.args.extend(args) req.args.extend(args)
assert self._connection is not None
await self._connection.send_message(req) await self._connection.send_message(req)
async def _request_image(self, *, single=False, stream=False): async def _request_image(
req = pb.CameraImageRequest() self, *, single: bool = False, stream: bool = False
) -> None:
req = CameraImageRequest()
req.single = single req.single = single
req.stream = stream req.stream = stream
assert self._connection is not None
await self._connection.send_message(req) await self._connection.send_message(req)
async def request_single_image(self): async def request_single_image(self) -> None:
await self._request_image(single=True) await self._request_image(single=True)
async def request_image_stream(self): async def request_image_stream(self) -> None:
await self._request_image(stream=True) await self._request_image(stream=True)
@property @property

View File

@ -2,14 +2,25 @@ import asyncio
import logging import logging
import socket import socket
import time import time
from typing import Any, Callable, List, Optional, cast from typing import Any, Awaitable, Callable, List, Optional, cast
import attr import attr
import zeroconf import zeroconf
from google.protobuf import message from google.protobuf import message
import aioesphomeapi.api_pb2 as pb from aioesphomeapi.api_pb2 import ( # type: ignore
from aioesphomeapi.core import APIConnectionError, MESSAGE_TYPE_TO_PROTO ConnectRequest,
ConnectResponse,
DisconnectRequest,
DisconnectResponse,
GetTimeRequest,
GetTimeResponse,
HelloRequest,
HelloResponse,
PingRequest,
PingResponse,
)
from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
from aioesphomeapi.model import APIVersion from aioesphomeapi.model import APIVersion
from aioesphomeapi.util import _bytes_to_varuint, _varuint_to_bytes, resolve_ip_address from aioesphomeapi.util import _bytes_to_varuint, _varuint_to_bytes, resolve_ip_address
@ -24,26 +35,27 @@ class ConnectionParams:
password = attr.ib(type=Optional[str]) password = attr.ib(type=Optional[str])
client_info = attr.ib(type=str) client_info = attr.ib(type=str)
keepalive = attr.ib(type=float) keepalive = attr.ib(type=float)
zeroconf_instance = attr.ib(type=zeroconf.Zeroconf) zeroconf_instance = attr.ib(type=Optional[zeroconf.Zeroconf])
class APIConnection: class APIConnection:
def __init__(self, params: ConnectionParams, on_stop): def __init__(
self, params: ConnectionParams, on_stop: Callable[[], Awaitable[None]]
):
self._params = params self._params = params
self.on_stop = on_stop self.on_stop = on_stop
self._stopped = False self._stopped = False
self._socket = None # type: Optional[socket.socket] self._socket: Optional[socket.socket] = None
self._socket_reader = None # type: Optional[asyncio.StreamReader] self._socket_reader: Optional[asyncio.StreamReader] = None
self._socket_writer = None # type: Optional[asyncio.StreamWriter] self._socket_writer: Optional[asyncio.StreamWriter] = None
self._write_lock = asyncio.Lock() self._write_lock = asyncio.Lock()
self._connected = False self._connected = False
self._authenticated = False self._authenticated = False
self._socket_connected = False self._socket_connected = False
self._state_lock = asyncio.Lock() self._state_lock = asyncio.Lock()
self._api_version = None # type: Optional[APIVersion] self._api_version: Optional[APIVersion] = None
self._message_handlers = [] # type: List[Callable[[message], None]] self._message_handlers: List[Callable[[message.Message], None]] = []
self._running_task = None # type: Optional[asyncio.Task]
def _start_ping(self) -> None: def _start_ping(self) -> None:
async def func() -> None: async def func() -> None:
@ -66,7 +78,8 @@ class APIConnection:
if not self._socket_connected: if not self._socket_connected:
return return
async with self._write_lock: async with self._write_lock:
self._socket_writer.close() if self._socket_writer is not None:
self._socket_writer.close()
self._socket_writer = None self._socket_writer = None
self._socket_reader = None self._socket_reader = None
if self._socket is not None: if self._socket is not None:
@ -85,8 +98,6 @@ class APIConnection:
except APIConnectionError: except APIConnectionError:
pass pass
self._stopped = True self._stopped = True
if self._running_task is not None:
self._running_task.cancel()
await self._close_socket() await self._close_socket()
await self.on_stop() await self.on_stop()
@ -100,8 +111,12 @@ class APIConnection:
raise APIConnectionError("Already connected!") raise APIConnectionError("Already connected!")
try: try:
coro = resolve_ip_address(self._params.eventloop, self._params.address, coro = resolve_ip_address(
self._params.port, self._params.zeroconf_instance) self._params.eventloop,
self._params.address,
self._params.port,
self._params.zeroconf_instance,
)
sockaddr = await asyncio.wait_for(coro, 30.0) sockaddr = await asyncio.wait_for(coro, 30.0)
except APIConnectionError as err: except APIConnectionError as err:
await self._on_error() await self._on_error()
@ -114,40 +129,51 @@ class APIConnection:
self._socket.setblocking(False) self._socket.setblocking(False)
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
_LOGGER.debug("%s: Connecting to %s:%s (%s)", self._params.address, _LOGGER.debug(
self._params.address, self._params.port, sockaddr) "%s: Connecting to %s:%s (%s)",
self._params.address,
self._params.address,
self._params.port,
sockaddr,
)
try: try:
coro = self._params.eventloop.sock_connect(self._socket, sockaddr) coro2 = self._params.eventloop.sock_connect(self._socket, sockaddr)
await asyncio.wait_for(coro, 30.0) await asyncio.wait_for(coro2, 30.0)
except OSError as err: except OSError as err:
await self._on_error() await self._on_error()
raise APIConnectionError( raise APIConnectionError("Error connecting to {}: {}".format(sockaddr, err))
"Error connecting to {}: {}".format(sockaddr, err))
except asyncio.TimeoutError: except asyncio.TimeoutError:
await self._on_error() await self._on_error()
raise APIConnectionError( raise APIConnectionError("Timeout while connecting to {}".format(sockaddr))
"Timeout while connecting to {}".format(sockaddr))
_LOGGER.debug("%s: Opened socket for", self._params.address) _LOGGER.debug("%s: Opened socket for", self._params.address)
self._socket_reader, self._socket_writer = await asyncio.open_connection(sock=self._socket) self._socket_reader, self._socket_writer = await asyncio.open_connection(
sock=self._socket
)
self._socket_connected = True self._socket_connected = True
self._params.eventloop.create_task(self.run_forever()) self._params.eventloop.create_task(self.run_forever())
hello = pb.HelloRequest() hello = HelloRequest()
hello.client_info = self._params.client_info hello.client_info = self._params.client_info
try: try:
resp = await self.send_message_await_response(hello, pb.HelloResponse) resp = await self.send_message_await_response(hello, HelloResponse)
except APIConnectionError as err: except APIConnectionError as err:
await self._on_error() await self._on_error()
raise err raise err
_LOGGER.debug("%s: Successfully connected ('%s' API=%s.%s)", _LOGGER.debug(
self._params.address, resp.server_info, resp.api_version_major, "%s: Successfully connected ('%s' API=%s.%s)",
resp.api_version_minor) self._params.address,
self._api_version = APIVersion( resp.server_info,
resp.api_version_major, resp.api_version_minor) resp.api_version_major,
resp.api_version_minor,
)
self._api_version = APIVersion(resp.api_version_major, resp.api_version_minor)
if self._api_version.major > 2: if self._api_version.major > 2:
_LOGGER.error("%s: Incompatible version %s! Closing connection", _LOGGER.error(
self._params.address, self._api_version.major) "%s: Incompatible version %s! Closing connection",
self._params.address,
self._api_version.major,
)
await self._on_error() await self._on_error()
raise APIConnectionError("Incompatible API version.") raise APIConnectionError("Incompatible API version.")
self._connected = True self._connected = True
@ -159,10 +185,10 @@ class APIConnection:
if self._authenticated: if self._authenticated:
raise APIConnectionError("Already logged in!") raise APIConnectionError("Already logged in!")
connect = pb.ConnectRequest() connect = ConnectRequest()
if self._params.password is not None: if self._params.password is not None:
connect.password = self._params.password connect.password = self._params.password
resp = await self.send_message_await_response(connect, pb.ConnectResponse) resp = await self.send_message_await_response(connect, ConnectResponse)
if resp.invalid_password: if resp.invalid_password:
raise APIConnectionError("Invalid password!") raise APIConnectionError("Invalid password!")
@ -187,12 +213,12 @@ class APIConnection:
raise APIConnectionError("Socket is not connected") raise APIConnectionError("Socket is not connected")
try: try:
async with self._write_lock: async with self._write_lock:
self._socket_writer.write(data) if self._socket_writer is not None:
await self._socket_writer.drain() self._socket_writer.write(data)
await self._socket_writer.drain()
except OSError as err: except OSError as err:
await self._on_error() await self._on_error()
raise APIConnectionError( raise APIConnectionError("Error while writing data: {}".format(err))
"Error while writing data: {}".format(err))
async def send_message(self, msg: message.Message) -> None: async def send_message(self, msg: message.Message) -> None:
for message_type, klass in MESSAGE_TYPE_TO_PROTO.items(): for message_type, klass in MESSAGE_TYPE_TO_PROTO.items():
@ -202,8 +228,7 @@ class APIConnection:
raise ValueError raise ValueError
encoded = msg.SerializeToString() encoded = msg.SerializeToString()
_LOGGER.debug("%s: Sending %s: %s", _LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
self._params.address, type(msg), str(msg))
req = bytes([0]) req = bytes([0])
req += _varuint_to_bytes(len(encoded)) req += _varuint_to_bytes(len(encoded))
# pylint: disable=undefined-loop-variable # pylint: disable=undefined-loop-variable
@ -211,19 +236,23 @@ class APIConnection:
req += encoded req += encoded
await self._write(req) await self._write(req)
async def send_message_callback_response(self, send_msg: message.Message, async def send_message_callback_response(
on_message: Callable[[Any], None]) -> None: self, send_msg: message.Message, on_message: Callable[[Any], None]
) -> None:
self._message_handlers.append(on_message) self._message_handlers.append(on_message)
await self.send_message(send_msg) await self.send_message(send_msg)
async def send_message_await_response_complex(self, send_msg: message.Message, async def send_message_await_response_complex(
do_append: Callable[[Any], bool], self,
do_stop: Callable[[Any], bool], send_msg: message.Message,
timeout: float = 5.0) -> List[Any]: do_append: Callable[[Any], bool],
do_stop: Callable[[Any], bool],
timeout: float = 5.0,
) -> List[Any]:
fut = self._params.eventloop.create_future() fut = self._params.eventloop.create_future()
responses = [] responses = []
def on_message(resp): def on_message(resp: message.Message) -> None:
if fut.done(): if fut.done():
return return
if do_append(resp): if do_append(resp):
@ -238,8 +267,7 @@ class APIConnection:
await asyncio.wait_for(fut, timeout) await asyncio.wait_for(fut, timeout)
except asyncio.TimeoutError: except asyncio.TimeoutError:
if self._stopped: if self._stopped:
raise APIConnectionError( raise APIConnectionError("Disconnected while waiting for API response!")
"Disconnected while waiting for API response!")
raise APIConnectionError("Timeout while waiting for API response!") raise APIConnectionError("Timeout while waiting for API response!")
try: try:
@ -249,17 +277,17 @@ class APIConnection:
return responses return responses
async def send_message_await_response(self, async def send_message_await_response(
send_msg: message.Message, self, send_msg: message.Message, response_type: Any, timeout: float = 5.0
response_type: Any, timeout: float = 5.0) -> Any: ) -> Any:
def is_response(msg): def is_response(msg: message.Message) -> bool:
return isinstance(msg, response_type) return isinstance(msg, response_type)
res = await self.send_message_await_response_complex( res = await self.send_message_await_response_complex(
send_msg, is_response, is_response, timeout=timeout) send_msg, is_response, is_response, timeout=timeout
)
if len(res) != 1: if len(res) != 1:
raise APIConnectionError( raise APIConnectionError("Expected one result, got {}".format(len(res)))
"Expected one result, got {}".format(len(res)))
return res[0] return res[0]
@ -268,10 +296,10 @@ class APIConnection:
return bytes() return bytes()
try: try:
assert self._socket_reader is not None
ret = await self._socket_reader.readexactly(amount) ret = await self._socket_reader.readexactly(amount)
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err: except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
raise APIConnectionError( raise APIConnectionError("Error while receiving data: {}".format(err))
"Error while receiving data: {}".format(err))
return ret return ret
@ -291,8 +319,9 @@ class APIConnection:
raw_msg = await self._recv(length) raw_msg = await self._recv(length)
if msg_type not in MESSAGE_TYPE_TO_PROTO: if msg_type not in MESSAGE_TYPE_TO_PROTO:
_LOGGER.debug("%s: Skipping message type %s", _LOGGER.debug(
self._params.address, msg_type) "%s: Skipping message type %s", self._params.address, msg_type
)
return return
msg = MESSAGE_TYPE_TO_PROTO[msg_type]() msg = MESSAGE_TYPE_TO_PROTO[msg_type]()
@ -300,8 +329,9 @@ class APIConnection:
msg.ParseFromString(raw_msg) msg.ParseFromString(raw_msg)
except Exception as e: except Exception as e:
raise APIConnectionError("Invalid protobuf message: {}".format(e)) raise APIConnectionError("Invalid protobuf message: {}".format(e))
_LOGGER.debug("%s: Got message of type %s: %s", _LOGGER.debug(
self._params.address, type(msg), msg) "%s: Got message of type %s: %s", self._params.address, type(msg), msg
)
for msg_handler in self._message_handlers[:]: for msg_handler in self._message_handlers[:]:
msg_handler(msg) msg_handler(msg)
await self._handle_internal_messages(msg) await self._handle_internal_messages(msg)
@ -311,36 +341,44 @@ class APIConnection:
try: try:
await self._run_once() await self._run_once()
except APIConnectionError as err: except APIConnectionError as err:
_LOGGER.info("%s: Error while reading incoming messages: %s", _LOGGER.info(
self._params.address, err) "%s: Error while reading incoming messages: %s",
self._params.address,
err,
)
await self._on_error() await self._on_error()
break break
except Exception as err: # pylint: disable=broad-except except Exception as err: # pylint: disable=broad-except
_LOGGER.info("%s: Unexpected error while reading incoming messages: %s", _LOGGER.info(
self._params.address, err) "%s: Unexpected error while reading incoming messages: %s",
self._params.address,
err,
)
await self._on_error() await self._on_error()
break break
async def _handle_internal_messages(self, msg: Any) -> None: async def _handle_internal_messages(self, msg: Any) -> None:
if isinstance(msg, pb.DisconnectRequest): if isinstance(msg, DisconnectRequest):
await self.send_message(pb.DisconnectResponse()) await self.send_message(DisconnectResponse())
await self.stop(force=True) await self.stop(force=True)
elif isinstance(msg, pb.PingRequest): elif isinstance(msg, PingRequest):
await self.send_message(pb.PingResponse()) await self.send_message(PingResponse())
elif isinstance(msg, pb.GetTimeRequest): elif isinstance(msg, GetTimeRequest):
resp = pb.GetTimeResponse() resp = GetTimeResponse()
resp.epoch_seconds = int(time.time()) resp.epoch_seconds = int(time.time())
await self.send_message(resp) await self.send_message(resp)
async def ping(self) -> None: async def ping(self) -> None:
self._check_connected() self._check_connected()
await self.send_message_await_response(pb.PingRequest(), pb.PingResponse) await self.send_message_await_response(PingRequest(), PingResponse)
async def _disconnect(self) -> None: async def _disconnect(self) -> None:
self._check_connected() self._check_connected()
try: try:
await self.send_message_await_response(pb.DisconnectRequest(), pb.DisconnectResponse) await self.send_message_await_response(
DisconnectRequest(), DisconnectResponse
)
except APIConnectionError: except APIConnectionError:
pass pass

View File

@ -1,4 +1,53 @@
import aioesphomeapi.api_pb2 as pb from aioesphomeapi.api_pb2 import ( # type: ignore
BinarySensorStateResponse,
CameraImageRequest,
CameraImageResponse,
ClimateCommandRequest,
ClimateStateResponse,
ConnectRequest,
ConnectResponse,
CoverCommandRequest,
CoverStateResponse,
DeviceInfoRequest,
DeviceInfoResponse,
DisconnectRequest,
DisconnectResponse,
ExecuteServiceRequest,
FanCommandRequest,
FanStateResponse,
GetTimeRequest,
GetTimeResponse,
HelloRequest,
HelloResponse,
HomeassistantServiceResponse,
HomeAssistantStateResponse,
LightCommandRequest,
LightStateResponse,
ListEntitiesBinarySensorResponse,
ListEntitiesCameraResponse,
ListEntitiesClimateResponse,
ListEntitiesCoverResponse,
ListEntitiesDoneResponse,
ListEntitiesFanResponse,
ListEntitiesLightResponse,
ListEntitiesRequest,
ListEntitiesSensorResponse,
ListEntitiesServicesResponse,
ListEntitiesSwitchResponse,
ListEntitiesTextSensorResponse,
PingRequest,
PingResponse,
SensorStateResponse,
SubscribeHomeassistantServicesRequest,
SubscribeHomeAssistantStateResponse,
SubscribeHomeAssistantStatesRequest,
SubscribeLogsRequest,
SubscribeLogsResponse,
SubscribeStatesRequest,
SwitchCommandRequest,
SwitchStateResponse,
TextSensorStateResponse,
)
class APIConnectionError(Exception): class APIConnectionError(Exception):
@ -6,52 +55,52 @@ class APIConnectionError(Exception):
MESSAGE_TYPE_TO_PROTO = { MESSAGE_TYPE_TO_PROTO = {
1: pb.HelloRequest, 1: HelloRequest,
2: pb.HelloResponse, 2: HelloResponse,
3: pb.ConnectRequest, 3: ConnectRequest,
4: pb.ConnectResponse, 4: ConnectResponse,
5: pb.DisconnectRequest, 5: DisconnectRequest,
6: pb.DisconnectResponse, 6: DisconnectResponse,
7: pb.PingRequest, 7: PingRequest,
8: pb.PingResponse, 8: PingResponse,
9: pb.DeviceInfoRequest, 9: DeviceInfoRequest,
10: pb.DeviceInfoResponse, 10: DeviceInfoResponse,
11: pb.ListEntitiesRequest, 11: ListEntitiesRequest,
12: pb.ListEntitiesBinarySensorResponse, 12: ListEntitiesBinarySensorResponse,
13: pb.ListEntitiesCoverResponse, 13: ListEntitiesCoverResponse,
14: pb.ListEntitiesFanResponse, 14: ListEntitiesFanResponse,
15: pb.ListEntitiesLightResponse, 15: ListEntitiesLightResponse,
16: pb.ListEntitiesSensorResponse, 16: ListEntitiesSensorResponse,
17: pb.ListEntitiesSwitchResponse, 17: ListEntitiesSwitchResponse,
18: pb.ListEntitiesTextSensorResponse, 18: ListEntitiesTextSensorResponse,
19: pb.ListEntitiesDoneResponse, 19: ListEntitiesDoneResponse,
20: pb.SubscribeStatesRequest, 20: SubscribeStatesRequest,
21: pb.BinarySensorStateResponse, 21: BinarySensorStateResponse,
22: pb.CoverStateResponse, 22: CoverStateResponse,
23: pb.FanStateResponse, 23: FanStateResponse,
24: pb.LightStateResponse, 24: LightStateResponse,
25: pb.SensorStateResponse, 25: SensorStateResponse,
26: pb.SwitchStateResponse, 26: SwitchStateResponse,
27: pb.TextSensorStateResponse, 27: TextSensorStateResponse,
28: pb.SubscribeLogsRequest, 28: SubscribeLogsRequest,
29: pb.SubscribeLogsResponse, 29: SubscribeLogsResponse,
30: pb.CoverCommandRequest, 30: CoverCommandRequest,
31: pb.FanCommandRequest, 31: FanCommandRequest,
32: pb.LightCommandRequest, 32: LightCommandRequest,
33: pb.SwitchCommandRequest, 33: SwitchCommandRequest,
34: pb.SubscribeHomeassistantServicesRequest, 34: SubscribeHomeassistantServicesRequest,
35: pb.HomeassistantServiceResponse, 35: HomeassistantServiceResponse,
36: pb.GetTimeRequest, 36: GetTimeRequest,
37: pb.GetTimeResponse, 37: GetTimeResponse,
38: pb.SubscribeHomeAssistantStatesRequest, 38: SubscribeHomeAssistantStatesRequest,
39: pb.SubscribeHomeAssistantStateResponse, 39: SubscribeHomeAssistantStateResponse,
40: pb.HomeAssistantStateResponse, 40: HomeAssistantStateResponse,
41: pb.ListEntitiesServicesResponse, 41: ListEntitiesServicesResponse,
42: pb.ExecuteServiceRequest, 42: ExecuteServiceRequest,
43: pb.ListEntitiesCameraResponse, 43: ListEntitiesCameraResponse,
44: pb.CameraImageResponse, 44: CameraImageResponse,
45: pb.CameraImageRequest, 45: CameraImageRequest,
46: pb.ListEntitiesClimateResponse, 46: ListEntitiesClimateResponse,
47: pb.ClimateStateResponse, 47: ClimateStateResponse,
48: pb.ClimateCommandRequest, 48: ClimateCommandRequest,
} }

View File

@ -1,15 +1,18 @@
import socket import socket
import time import time
from typing import Optional
import zeroconf import zeroconf
class HostResolver(zeroconf.RecordUpdateListener): class HostResolver(zeroconf.RecordUpdateListener):
def __init__(self, name): def __init__(self, name: str):
self.name = name self.name = name
self.address = None self.address: Optional[bytes] = None
def update_record(self, zc, now, record): def update_record(
self, zc: zeroconf.Zeroconf, now: float, record: zeroconf.DNSRecord
) -> None:
if record is None: if record is None:
return return
if record.type == zeroconf._TYPE_A: if record.type == zeroconf._TYPE_A:
@ -17,15 +20,17 @@ class HostResolver(zeroconf.RecordUpdateListener):
if record.name == self.name: if record.name == self.name:
self.address = record.address self.address = record.address
def request(self, zc, timeout): def request(self, zc: zeroconf.Zeroconf, timeout: float) -> bool:
now = time.time() now = time.time()
delay = 0.2 delay = 0.2
next_ = now + delay next_ = now + delay
last = now + timeout last = now + timeout
try: try:
zc.add_listener(self, zeroconf.DNSQuestion(self.name, zeroconf._TYPE_ANY, zc.add_listener(
zeroconf._CLASS_IN)) self,
zeroconf.DNSQuestion(self.name, zeroconf._TYPE_ANY, zeroconf._CLASS_IN),
)
while self.address is None: while self.address is None:
if last <= now: if last <= now:
# Timeout # Timeout
@ -33,10 +38,16 @@ class HostResolver(zeroconf.RecordUpdateListener):
if next_ <= now: if next_ <= now:
out = zeroconf.DNSOutgoing(zeroconf._FLAGS_QR_QUERY) out = zeroconf.DNSOutgoing(zeroconf._FLAGS_QR_QUERY)
out.add_question( out.add_question(
zeroconf.DNSQuestion(self.name, zeroconf._TYPE_A, zeroconf._CLASS_IN)) zeroconf.DNSQuestion(
self.name, zeroconf._TYPE_A, zeroconf._CLASS_IN
)
)
out.add_answer_at_time( out.add_answer_at_time(
zc.cache.get_by_details(self.name, zeroconf._TYPE_A, zc.cache.get_by_details(
zeroconf._CLASS_IN), now) self.name, zeroconf._TYPE_A, zeroconf._CLASS_IN
),
now,
)
zc.send(out) zc.send(out)
next_ = now + delay next_ = now + delay
delay *= 2 delay *= 2
@ -49,27 +60,38 @@ class HostResolver(zeroconf.RecordUpdateListener):
return True return True
def resolve_host(host, timeout=3.0, zeroconf_instance: zeroconf.Zeroconf = None): def resolve_host(
from aioesphomeapi import APIConnectionError host: str,
timeout: float = 3.0,
zeroconf_instance: Optional[zeroconf.Zeroconf] = None,
) -> str:
from aioesphomeapi.core import APIConnectionError
try: try:
zc = zeroconf_instance or zeroconf.Zeroconf() zc = zeroconf_instance or zeroconf.Zeroconf()
except Exception: except Exception:
raise APIConnectionError("Cannot start mDNS sockets, is this a docker container without " raise APIConnectionError(
"host network mode?") "Cannot start mDNS sockets, is this a docker container without "
"host network mode?"
)
try: try:
info = HostResolver(host + '.') info = HostResolver(host + ".")
assert info.address is not None
address = None address = None
if info.request(zc, timeout): if info.request(zc, timeout):
address = socket.inet_ntoa(info.address) address = socket.inet_ntoa(info.address)
except Exception as err: except Exception as err:
raise APIConnectionError("Error resolving mDNS hostname: {}".format(err)) raise APIConnectionError(
"Error resolving mDNS hostname: {}".format(err)
) from err
finally: finally:
if not zeroconf_instance: if not zeroconf_instance:
zc.close() zc.close()
if address is None: if address is None:
raise APIConnectionError("Error resolving address with mDNS: Did not respond. " raise APIConnectionError(
"Maybe the device is offline.") "Error resolving address with mDNS: Did not respond. "
"Maybe the device is offline."
)
return address return address

View File

@ -1,19 +1,23 @@
import enum import enum
from typing import List, Dict, TypeVar, Optional, Type from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Type, TypeVar
import attr import attr
if TYPE_CHECKING:
from .api_pb2 import HomeassistantServiceMap # type: ignore
# All fields in here should have defaults set # All fields in here should have defaults set
# Home Assistant depends on these fields being constructible # Home Assistant depends on these fields being constructible
# with args from a previous version of Home Assistant. # with args from a previous version of Home Assistant.
# The default value should *always* be the Protobuf default value # The default value should *always* be the Protobuf default value
# for a field (False, 0, empty string, enum with value 0, ...) # for a field (False, 0, empty string, enum with value 0, ...)
_T = TypeVar("_T", bound="APIIntEnum")
_T = TypeVar('_T')
class APIIntEnum(enum.IntEnum): class APIIntEnum(enum.IntEnum):
"""Base class for int enum values in API model.""" """Base class for int enum values in API model."""
@classmethod @classmethod
def convert(cls: Type[_T], value: int) -> Optional[_T]: def convert(cls: Type[_T], value: int) -> Optional[_T]:
try: try:
@ -41,20 +45,20 @@ class APIVersion:
@attr.s @attr.s
class DeviceInfo: class DeviceInfo:
uses_password = attr.ib(type=bool, default=False) uses_password = attr.ib(type=bool, default=False)
name = attr.ib(type=str, default='') name = attr.ib(type=str, default="")
mac_address = attr.ib(type=str, default='') mac_address = attr.ib(type=str, default="")
compilation_time = attr.ib(type=str, default='') compilation_time = attr.ib(type=str, default="")
model = attr.ib(type=str, default='') model = attr.ib(type=str, default="")
has_deep_sleep = attr.ib(type=bool, default=False) has_deep_sleep = attr.ib(type=bool, default=False)
esphome_version = attr.ib(type=str, default='') esphome_version = attr.ib(type=str, default="")
@attr.s @attr.s
class EntityInfo: class EntityInfo:
object_id = attr.ib(type=str, default='') object_id = attr.ib(type=str, default="")
key = attr.ib(type=int, default=0) key = attr.ib(type=int, default=0)
name = attr.ib(type=str, default='') name = attr.ib(type=str, default="")
unique_id = attr.ib(type=str, default='') unique_id = attr.ib(type=str, default="")
@attr.s @attr.s
@ -65,7 +69,7 @@ class EntityState:
# ==================== BINARY SENSOR ==================== # ==================== BINARY SENSOR ====================
@attr.s @attr.s
class BinarySensorInfo(EntityInfo): class BinarySensorInfo(EntityInfo):
device_class = attr.ib(type=str, default='') device_class = attr.ib(type=str, default="")
is_status_binary_sensor = attr.ib(type=bool, default=False) is_status_binary_sensor = attr.ib(type=bool, default=False)
@ -81,7 +85,7 @@ class CoverInfo(EntityInfo):
assumed_state = attr.ib(type=bool, default=False) assumed_state = attr.ib(type=bool, default=False)
supports_position = attr.ib(type=bool, default=False) supports_position = attr.ib(type=bool, default=False)
supports_tilt = attr.ib(type=bool, default=False) supports_tilt = attr.ib(type=bool, default=False)
device_class = attr.ib(type=str, default='') device_class = attr.ib(type=str, default="")
class LegacyCoverState(APIIntEnum): class LegacyCoverState(APIIntEnum):
@ -103,14 +107,20 @@ class CoverOperation(APIIntEnum):
@attr.s @attr.s
class CoverState(EntityState): class CoverState(EntityState):
legacy_state = attr.ib(type=Optional[LegacyCoverState], converter=LegacyCoverState.convert, legacy_state = attr.ib(
default=LegacyCoverState.OPEN) type=LegacyCoverState,
converter=LegacyCoverState.convert, # type: ignore
default=LegacyCoverState.OPEN,
)
position = attr.ib(type=float, default=0.0) position = attr.ib(type=float, default=0.0)
tilt = attr.ib(type=float, default=0.0) tilt = attr.ib(type=float, default=0.0)
current_operation = attr.ib(type=Optional[CoverOperation], converter=CoverOperation.convert, current_operation = attr.ib(
default=CoverOperation.IDLE) type=CoverOperation,
converter=CoverOperation.convert, # type: ignore
default=CoverOperation.IDLE,
)
def is_closed(self, api_version: APIVersion): def is_closed(self, api_version: APIVersion) -> bool:
if api_version >= APIVersion(1, 1): if api_version >= APIVersion(1, 1):
return self.position == 0.0 return self.position == 0.0
return self.legacy_state == LegacyCoverState.CLOSED return self.legacy_state == LegacyCoverState.CLOSED
@ -140,9 +150,17 @@ class FanDirection(APIIntEnum):
class FanState(EntityState): class FanState(EntityState):
state = attr.ib(type=bool, default=False) state = attr.ib(type=bool, default=False)
oscillating = attr.ib(type=bool, default=False) oscillating = attr.ib(type=bool, default=False)
speed = attr.ib(type=Optional[FanSpeed], converter=FanSpeed.convert, default=FanSpeed.LOW) speed = attr.ib(
type=Optional[FanSpeed],
converter=FanSpeed.convert, # type: ignore
default=FanSpeed.LOW,
)
speed_level = attr.ib(type=int, default=0) speed_level = attr.ib(type=int, default=0)
direction = attr.ib(type=Optional[FanDirection], converter=FanDirection.convert, default=FanDirection.FORWARD) direction = attr.ib(
type=FanDirection,
converter=FanDirection.convert, # type: ignore
default=FanDirection.FORWARD,
)
# ==================== LIGHT ==================== # ==================== LIGHT ====================
@ -166,7 +184,7 @@ class LightState(EntityState):
blue = attr.ib(type=float, default=0.0) blue = attr.ib(type=float, default=0.0)
white = attr.ib(type=float, default=0.0) white = attr.ib(type=float, default=0.0)
color_temperature = attr.ib(type=float, default=0.0) color_temperature = attr.ib(type=float, default=0.0)
effect = attr.ib(type=str, default='') effect = attr.ib(type=str, default="")
# ==================== SENSOR ==================== # ==================== SENSOR ====================
@ -174,14 +192,19 @@ class SensorStateClass(APIIntEnum):
NONE = 0 NONE = 0
MEASUREMENT = 1 MEASUREMENT = 1
@attr.s @attr.s
class SensorInfo(EntityInfo): class SensorInfo(EntityInfo):
icon = attr.ib(type=str, default='') icon = attr.ib(type=str, default="")
device_class = attr.ib(type=str, default='') device_class = attr.ib(type=str, default="")
unit_of_measurement = attr.ib(type=str, default='') unit_of_measurement = attr.ib(type=str, default="")
accuracy_decimals = attr.ib(type=int, default=0) accuracy_decimals = attr.ib(type=int, default=0)
force_update = attr.ib(type=bool, default=False) force_update = attr.ib(type=bool, default=False)
state_class = attr.ib(type=Optional[SensorStateClass], converter=SensorStateClass.convert, default=SensorStateClass.NONE) state_class = attr.ib(
type=SensorStateClass,
converter=SensorStateClass.convert, # type: ignore
default=SensorStateClass.NONE,
)
@attr.s @attr.s
@ -193,7 +216,7 @@ class SensorState(EntityState):
# ==================== SWITCH ==================== # ==================== SWITCH ====================
@attr.s @attr.s
class SwitchInfo(EntityInfo): class SwitchInfo(EntityInfo):
icon = attr.ib(type=str, default='') icon = attr.ib(type=str, default="")
assumed_state = attr.ib(type=bool, default=False) assumed_state = attr.ib(type=bool, default=False)
@ -205,12 +228,12 @@ class SwitchState(EntityState):
# ==================== TEXT SENSOR ==================== # ==================== TEXT SENSOR ====================
@attr.s @attr.s
class TextSensorInfo(EntityInfo): class TextSensorInfo(EntityInfo):
icon = attr.ib(type=str, default='') icon = attr.ib(type=str, default="")
@attr.s @attr.s
class TextSensorState(EntityState): class TextSensorState(EntityState):
state = attr.ib(type=str, default='') state = attr.ib(type=str, default="")
missing_state = attr.ib(type=bool, default=False) missing_state = attr.ib(type=bool, default=False)
@ -267,68 +290,90 @@ class ClimateAction(APIIntEnum):
class ClimateInfo(EntityInfo): class ClimateInfo(EntityInfo):
supports_current_temperature = attr.ib(type=bool, default=False) supports_current_temperature = attr.ib(type=bool, default=False)
supports_two_point_target_temperature = attr.ib(type=bool, default=False) supports_two_point_target_temperature = attr.ib(type=bool, default=False)
supported_modes = attr.ib(type=List[ClimateMode], converter=ClimateMode.convert_list, supported_modes = attr.ib(
factory=list) type=List[ClimateMode],
converter=ClimateMode.convert_list, # type: ignore
factory=list,
)
visual_min_temperature = attr.ib(type=float, default=0.0) visual_min_temperature = attr.ib(type=float, default=0.0)
visual_max_temperature = attr.ib(type=float, default=0.0) visual_max_temperature = attr.ib(type=float, default=0.0)
visual_temperature_step = attr.ib(type=float, default=0.0) visual_temperature_step = attr.ib(type=float, default=0.0)
supports_away = attr.ib(type=bool, default=False) supports_away = attr.ib(type=bool, default=False)
supports_action = attr.ib(type=bool, default=False) supports_action = attr.ib(type=bool, default=False)
supported_fan_modes = attr.ib( supported_fan_modes = attr.ib(
type=List[ClimateFanMode], converter=ClimateFanMode.convert_list, factory=list type=List[ClimateFanMode],
converter=ClimateFanMode.convert_list, # type: ignore
factory=list,
) )
supported_swing_modes = attr.ib( supported_swing_modes = attr.ib(
type=List[ClimateSwingMode], converter=ClimateSwingMode.convert_list, factory=list type=List[ClimateSwingMode],
converter=ClimateSwingMode.convert_list, # type: ignore
factory=list,
) )
@attr.s @attr.s
class ClimateState(EntityState): class ClimateState(EntityState):
mode = attr.ib(type=Optional[ClimateMode], converter=ClimateMode.convert, mode = attr.ib(
default=ClimateMode.OFF) type=ClimateMode,
action = attr.ib(type=Optional[ClimateAction], converter=ClimateAction.convert, converter=ClimateMode.convert, # type: ignore
default=ClimateAction.OFF) default=ClimateMode.OFF,
)
action = attr.ib(
type=ClimateAction,
converter=ClimateAction.convert, # type: ignore
default=ClimateAction.OFF,
)
current_temperature = attr.ib(type=float, default=0.0) current_temperature = attr.ib(type=float, default=0.0)
target_temperature = attr.ib(type=float, default=0.0) target_temperature = attr.ib(type=float, default=0.0)
target_temperature_low = attr.ib(type=float, default=0.0) target_temperature_low = attr.ib(type=float, default=0.0)
target_temperature_high = attr.ib(type=float, default=0.0) target_temperature_high = attr.ib(type=float, default=0.0)
away = attr.ib(type=bool, default=False) away = attr.ib(type=bool, default=False)
fan_mode = attr.ib( fan_mode = attr.ib(
type=Optional[ClimateFanMode], converter=ClimateFanMode.convert, default=ClimateFanMode.ON type=Optional[ClimateFanMode],
converter=ClimateFanMode.convert, # type: ignore
default=ClimateFanMode.ON,
) )
swing_mode = attr.ib( swing_mode = attr.ib(
type=Optional[ClimateSwingMode], converter=ClimateSwingMode.convert, default=ClimateSwingMode.OFF type=Optional[ClimateSwingMode],
converter=ClimateSwingMode.convert, # type: ignore
default=ClimateSwingMode.OFF,
) )
COMPONENT_TYPE_TO_INFO = { COMPONENT_TYPE_TO_INFO = {
'binary_sensor': BinarySensorInfo, "binary_sensor": BinarySensorInfo,
'cover': CoverInfo, "cover": CoverInfo,
'fan': FanInfo, "fan": FanInfo,
'light': LightInfo, "light": LightInfo,
'sensor': SensorInfo, "sensor": SensorInfo,
'switch': SwitchInfo, "switch": SwitchInfo,
'text_sensor': TextSensorInfo, "text_sensor": TextSensorInfo,
'camera': CameraInfo, "camera": CameraInfo,
'climate': ClimateInfo, "climate": ClimateInfo,
} }
# ==================== USER-DEFINED SERVICES ==================== # ==================== USER-DEFINED SERVICES ====================
def _convert_homeassistant_service_map(value): def _convert_homeassistant_service_map(
value: Iterable["HomeassistantServiceMap"],
) -> Dict[str, str]:
return {v.key: v.value for v in value} return {v.key: v.value for v in value}
@attr.s @attr.s
class HomeassistantServiceCall: class HomeassistantServiceCall:
service = attr.ib(type=str, default='') service = attr.ib(type=str, default="")
is_event = attr.ib(type=bool, default=False) is_event = attr.ib(type=bool, default=False)
data = attr.ib(type=Dict[str, str], converter=_convert_homeassistant_service_map, data = attr.ib(
factory=dict) type=Dict[str, str], converter=_convert_homeassistant_service_map, factory=dict
data_template = attr.ib(type=Dict[str, str], converter=_convert_homeassistant_service_map, )
factory=dict) data_template = attr.ib(
variables = attr.ib(type=Dict[str, str], converter=_convert_homeassistant_service_map, type=Dict[str, str], converter=_convert_homeassistant_service_map, factory=dict
factory=dict) )
variables = attr.ib(
type=Dict[str, str], converter=_convert_homeassistant_service_map, factory=dict
)
class UserServiceArgType(APIIntEnum): class UserServiceArgType(APIIntEnum):
@ -342,37 +387,43 @@ class UserServiceArgType(APIIntEnum):
STRING_ARRAY = 7 STRING_ARRAY = 7
def _attr_obj_from_dict(cls, **kwargs): _K = TypeVar("_K")
return cls(**{key: kwargs[key] for key in attr.fields_dict(cls)})
def _attr_obj_from_dict(cls: Type[_K], **kwargs: Any) -> _K:
return cls(**{key: kwargs[key] for key in attr.fields_dict(cls)}) # type: ignore
@attr.s @attr.s
class UserServiceArg: class UserServiceArg:
name = attr.ib(type=str, default='') name = attr.ib(type=str, default="")
type_ = attr.ib(type=Optional[UserServiceArgType], converter=UserServiceArgType.convert, type_ = attr.ib(
default=UserServiceArgType.BOOL) type=UserServiceArgType,
converter=UserServiceArgType.convert, # type: ignore
default=UserServiceArgType.BOOL,
)
@attr.s @attr.s
class UserService: class UserService:
name = attr.ib(type=str, default='') name = attr.ib(type=str, default="")
key = attr.ib(type=int, default=0) key = attr.ib(type=int, default=0)
args = attr.ib(type=List[UserServiceArg], converter=list, factory=list) args = attr.ib(type=List[UserServiceArg], converter=list, factory=list)
@staticmethod @classmethod
def from_dict(dict_): def from_dict(cls, dict_: Dict[str, Any]) -> "UserService":
args = [] args = []
for arg in dict_.get('args', []): for arg in dict_.get("args", []):
args.append(_attr_obj_from_dict(UserServiceArg, **arg)) args.append(_attr_obj_from_dict(UserServiceArg, **arg))
return UserService( return cls(
name=dict_.get('name', ''), name=dict_.get("name", ""),
key=dict_.get('key', 0), key=dict_.get("key", 0),
args=args args=args, # type: ignore
) )
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
return { return {
'name': self.name, "name": self.name,
'key': self.key, "key": self.key,
'args': [attr.asdict(arg) for arg in self.args], "args": [attr.asdict(arg) for arg in self.args],
} }

View File

@ -1,7 +1,8 @@
import asyncio import asyncio
import functools import functools
import socket import socket
from typing import Optional, Tuple, Any from typing import Any, Optional, Tuple
import zeroconf import zeroconf
# pylint: disable=cyclic-import # pylint: disable=cyclic-import
@ -35,8 +36,9 @@ def _bytes_to_varuint(value: bytes) -> Optional[int]:
return None return None
async def resolve_ip_address_getaddrinfo(eventloop: asyncio.events.AbstractEventLoop, async def resolve_ip_address_getaddrinfo(
host: str, port: int) -> Tuple[Any, ...]: eventloop: asyncio.events.AbstractEventLoop, host: str, port: int
) -> Tuple[Any, ...]:
try: try:
socket.inet_aton(host) socket.inet_aton(host)
@ -46,8 +48,9 @@ async def resolve_ip_address_getaddrinfo(eventloop: asyncio.events.AbstractEvent
return (host, port) return (host, port)
try: try:
res = await eventloop.getaddrinfo(host, port, family=socket.AF_INET, res = await eventloop.getaddrinfo(
proto=socket.IPPROTO_TCP) host, port, family=socket.AF_INET, proto=socket.IPPROTO_TCP
)
except OSError as err: except OSError as err:
raise APIConnectionError("Error resolving IP address: {}".format(err)) raise APIConnectionError("Error resolving IP address: {}".format(err))
@ -59,21 +62,25 @@ async def resolve_ip_address_getaddrinfo(eventloop: asyncio.events.AbstractEvent
return sockaddr return sockaddr
async def resolve_ip_address(eventloop: asyncio.events.AbstractEventLoop, async def resolve_ip_address(
host: str, port: int, eventloop: asyncio.events.AbstractEventLoop,
zeroconf_instance: zeroconf.Zeroconf = None) -> Tuple[Any, ...]: host: str,
if host.endswith('.local'): port: int,
zeroconf_instance: Optional[zeroconf.Zeroconf] = None,
) -> Tuple[Any, ...]:
if host.endswith(".local"):
from aioesphomeapi.host_resolver import resolve_host from aioesphomeapi.host_resolver import resolve_host
try: try:
return await eventloop.run_in_executor( return (
None, await eventloop.run_in_executor(
functools.partial( None,
resolve_host, functools.partial(
host, resolve_host, host, zeroconf_instance=zeroconf_instance
zeroconf_instance=zeroconf_instance ),
) ),
), port port,
)
except APIConnectionError: except APIConnectionError:
pass pass
return await resolve_ip_address_getaddrinfo(eventloop, host, port) return await resolve_ip_address_getaddrinfo(eventloop, host, port)

View File

@ -1,7 +0,0 @@
#!/usr/bin/env bash
# Generate protobuf compiled files
protoc --python_out=aioesphomeapi -I aioesphomeapi aioesphomeapi/*.proto
# https://github.com/protocolbuffers/protobuf/issues/1491
sed -i 's/import api_options_pb2 as api__options__pb2/from . import api_options_pb2 as api__options__pb2/' aioesphomeapi/api_pb2.py

View File

@ -14,3 +14,4 @@ disable=
unused-wildcard-import, unused-wildcard-import,
import-outside-toplevel, import-outside-toplevel,
raise-missing-from, raise-missing-from,
duplicate-code,

4
pyproject.toml Normal file
View File

@ -0,0 +1,4 @@
[tool.isort]
profile = "black"
multi_line_output = 3
extend_skip = ["api_pb2.py", "api_options_pb2.py"]

View File

@ -1 +1,6 @@
pylint==2.8.3 pylint==2.8.3
black==21.6b0
flake8==3.9.2
isort==5.8.0
mypy==0.902
types-protobuf==0.1.13

26
script/gen-protoc Executable file
View File

@ -0,0 +1,26 @@
#!/usr/bin/env python3
from subprocess import check_call
from pathlib import Path
import os
root_dir = Path(__file__).absolute().parent.parent
os.chdir(root_dir)
check_call([
"protoc", "--python_out=aioesphomeapi", "-I", "aioesphomeapi",
"aioesphomeapi/api.proto", "aioesphomeapi/api_options.proto"
])
# https://github.com/protocolbuffers/protobuf/issues/1491
api_file = root_dir / 'aioesphomeapi' / 'api_pb2.py'
content = api_file.read_text().replace(
"import api_options_pb2 as api__options__pb2",
"from . import api_options_pb2 as api__options__pb2"
)
api_file.write_text(content)
for fname in ['api_pb2.py', 'api_options_pb2.py']:
file = root_dir / 'aioesphomeapi' / fname
content = '# type: ignore\n' + file.read_text()
file.write_text(content)

10
script/lint Executable file
View File

@ -0,0 +1,10 @@
#!/bin/bash
cd "$(dirname "$0")/.."
set -euxo pipefail
black --safe --exclude 'api_pb2.py|api_options_pb2.py' aioesphomeapi
pylint aioesphomeapi
flake8 aioesphomeapi
isort aioesphomeapi
mypy --strict aioesphomeapi

18
setup.cfg Normal file
View File

@ -0,0 +1,18 @@
[flake8]
max-line-length = 120
# Following 4 for black compatibility
# E501: line too long
# W503: Line break occurred before a binary operator
# E203: Whitespace before ':'
# D202 No blank lines allowed after function docstring
ignore =
E501,
W503,
E203,
D202,
exclude = api_pb2.py, api_options_pb2.py
[bdist_wheel]
universal = 1