import asyncio import logging import socket from typing import Any, Callable, List, Optional, Tuple, Union, cast import attr from google.protobuf import message import aioesphomeapi.api_pb2 as pb _LOGGER = logging.getLogger(__name__) class APIConnectionError(Exception): pass MESSAGE_TYPE_TO_PROTO = { 1: pb.HelloRequest, 2: pb.HelloResponse, 3: pb.ConnectRequest, 4: pb.ConnectResponse, 5: pb.DisconnectRequest, 6: pb.DisconnectResponse, 7: pb.PingRequest, 8: pb.PingResponse, 9: pb.DeviceInfoRequest, 10: pb.DeviceInfoResponse, 11: pb.ListEntitiesRequest, 12: pb.ListEntitiesBinarySensorResponse, 13: pb.ListEntitiesCoverResponse, 14: pb.ListEntitiesFanResponse, 15: pb.ListEntitiesLightResponse, 16: pb.ListEntitiesSensorResponse, 17: pb.ListEntitiesSwitchResponse, 18: pb.ListEntitiesTextSensorResponse, 19: pb.ListEntitiesDoneResponse, 20: pb.SubscribeStatesRequest, 21: pb.BinarySensorStateResponse, 22: pb.CoverStateResponse, 23: pb.FanStateResponse, 24: pb.LightStateResponse, 25: pb.SensorStateResponse, 26: pb.SwitchStateResponse, 27: pb.TextSensorStateResponse, 28: pb.SubscribeLogsRequest, 29: pb.SubscribeLogsResponse, 30: pb.CoverCommandRequest, 31: pb.FanCommandRequest, 32: pb.LightCommandRequest, 33: pb.SwitchCommandRequest, } def _varuint_to_bytes(value: int) -> bytes: if value <= 0x7F: return bytes([value]) ret = bytes() while value: temp = value & 0x7F value >>= 7 if value: ret += bytes([temp | 0x80]) else: ret += bytes([temp]) return ret def _bytes_to_varuint(value: bytes) -> Optional[int]: result = 0 bitpos = 0 for val in value: result |= (val & 0x7F) << bitpos bitpos += 7 if (val & 0x80) == 0: return result return None async def resolve_ip_address(eventloop: asyncio.events.AbstractEventLoop, host: str, port: int) -> Tuple[Any, ...]: try: res = await eventloop.getaddrinfo(host, port, family=socket.AF_INET, proto=socket.IPPROTO_TCP) except OSError as err: raise APIConnectionError("Error resolving IP address: {}".format(err)) if not res: raise APIConnectionError("Error resolving IP address: No matches!") _, _, _, _, sockaddr = res[0] return sockaddr # Wrap some types in attr classes to make them serializable @attr.s class DeviceInfo: uses_password = attr.ib(type=bool) name = attr.ib(type=str) mac_address = attr.ib(type=str) esphomelib_version = attr.ib(type=str) compilation_time = attr.ib(type=str) model = attr.ib(type=str) has_deep_sleep = attr.ib(type=bool) @attr.s class EntityInfo: object_id = attr.ib(type=str) key = attr.ib(type=int) name = attr.ib(type=str) unique_id = attr.ib(type=str) @attr.s class EntityState: key = attr.ib(type=int) @attr.s class BinarySensorInfo(EntityInfo): device_class = attr.ib(type=str) is_status_binary_sensor = attr.ib(type=bool) @attr.s class BinarySensorState(EntityState): state = attr.ib(type=bool) @attr.s class CoverInfo(EntityInfo): is_optimistic = attr.ib(type=bool) COVER_STATE_OPEN = 0 COVER_SATE_CLOSED = 1 COVER_STATES = [COVER_STATE_OPEN, COVER_SATE_CLOSED] COVER_COMMAND_OPEN = 0 COVER_COMMAND_CLOSE = 1 COVER_COMMAND_STOP = 2 COVER_COMMANDS = [COVER_COMMAND_OPEN, COVER_COMMAND_CLOSE, COVER_COMMAND_STOP] @attr.s class CoverState(EntityState): state = attr.ib(type=int, converter=int, validator=attr.validators.in_(COVER_STATES)) @attr.s class FanInfo(EntityInfo): supports_oscillation = attr.ib(type=bool) supports_speed = attr.ib(type=bool) FAN_SPEED_LOW = 0 FAN_SPEED_MEDIUM = 1 FAN_SPEED_HIGH = 2 FAN_SPEEDS = [FAN_SPEED_LOW, FAN_SPEED_MEDIUM, FAN_SPEED_HIGH] @attr.s class FanState(EntityState): state = attr.ib(type=bool) oscillating = attr.ib(type=bool) speed = attr.ib(type=int, converter=int, validator=attr.validators.in_(FAN_SPEEDS)) @attr.s class LightInfo(EntityInfo): supports_brightness = attr.ib(type=bool) supports_rgb = attr.ib(type=bool) supports_white_value = attr.ib(type=bool) supports_color_temperature = attr.ib(type=bool) min_mireds = attr.ib(type=float) max_mireds = attr.ib(type=float) effects = attr.ib(type=List[str], converter=list) @attr.s class LightState(EntityState): state = attr.ib(type=bool) brightness = attr.ib(type=float) red = attr.ib(type=float) green = attr.ib(type=float) blue = attr.ib(type=float) white = attr.ib(type=float) color_temperature = attr.ib(type=float) effect = attr.ib(type=str) @attr.s class SensorInfo(EntityInfo): icon = attr.ib(type=str) unit_of_measurement = attr.ib(type=str) accuracy_decimals = attr.ib(type=int) @attr.s class SensorState(EntityState): state = attr.ib(type=float) @attr.s class SwitchInfo(EntityInfo): icon = attr.ib(type=str) @attr.s class SwitchState(EntityState): state = attr.ib(type=bool) @attr.s class TextSensorInfo(EntityInfo): icon = attr.ib(type=str) @attr.s class TextSensorState(EntityState): state = attr.ib(type=str) COMPONENT_TYPE_TO_INFO = { 'binary_sensor': BinarySensorInfo, 'cover': CoverInfo, 'fan': FanInfo, 'light': LightInfo, 'sensor': SensorInfo, 'switch': SwitchInfo, 'text_sensor': TextSensorInfo, } class APIClient: def __init__(self, eventloop, address: str, port: int, password: str): self._eventloop = eventloop # type: asyncio.events.AbstractEventLoop self._address = address # type: str self._port = port # type: int self._password = password # type: Optional[str] self._socket = None # type: Optional[socket.socket] self._connected = False # type: bool self._authenticated = False # type: bool self._message_handlers = [] # type: List[Callable[[message], None]] self._keepalive = 60 # type: Union[float, int] self._ping_timer = None # type: Optional[asyncio.Future] self.on_disconnect = None self.on_login = None self.running_event = asyncio.Event() self._stop_event = asyncio.Event() self._socket_open_event = asyncio.Event() self._sock_reader = None # type: Optional[asyncio.StreamReader] self._sock_writer = None # type: Optional[asyncio.StreamWriter] self._refresh_ping() def _refresh_ping(self) -> None: if self._ping_timer is not None: self._ping_timer.cancel() self._ping_timer = None async def func() -> None: await asyncio.sleep(self._keepalive) self._ping_timer = None if self._connected: try: await self.ping() except APIConnectionError: await self._on_error() self._refresh_ping() self._ping_timer = asyncio.ensure_future(func(), loop=self._eventloop) async def _close_socket(self) -> None: if self._socket is not None: self._socket.close() self._socket = None if self._sock_writer is not None: self._sock_writer.close() if hasattr(self._sock_writer, 'wait_closed'): await self._sock_writer.wait_closed() self._sock_writer = None self._sock_reader = None self._socket_open_event.clear() self._connected = False self._authenticated = False def _cancel_ping(self) -> None: if self._ping_timer is not None: self._ping_timer.cancel() self._ping_timer = None async def start(self): self._eventloop.create_task(self.run_forever()) await self.running_event.wait() async def stop(self, force: bool = False) -> None: if not self.running_event.is_set(): raise ValueError if self._connected and not force: try: await self.disconnect() except APIConnectionError: pass await self._close_socket() self._stop_event.set() self._cancel_ping() async def connect(self) -> None: if not self.running_event.is_set(): raise APIConnectionError("You need to call start() first!") if self._connected: raise APIConnectionError("Already connected!") self._message_handlers = [] try: coro = resolve_ip_address(self._eventloop, self._address, self._port) sockaddr = await asyncio.wait_for(coro, 15.0) except APIConnectionError as err: raise err except asyncio.TimeoutError: raise APIConnectionError("Timeout while resolving IP address") self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._socket.setblocking(False) self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) _LOGGER.debug("Connecting to %s:%s (%s)", self._address, self._port, sockaddr) try: coro = self._eventloop.sock_connect(self._socket, sockaddr) await asyncio.wait_for(coro, 15.0) except OSError as err: await self._on_error() raise APIConnectionError("Error connecting to {}: {}".format(sockaddr, err)) except asyncio.TimeoutError: raise APIConnectionError("Timeout while connecting to {}".format(sockaddr)) self._sock_reader, self._sock_writer = await asyncio.open_connection(sock=self._socket) self._socket_open_event.set() hello = pb.HelloRequest() hello.client_info = 'Home Assistant' try: resp = await self._send_message_await_response(hello, pb.HelloResponse) except APIConnectionError as err: await self._on_error() raise err _LOGGER.debug("Successfully connected to %s ('%s' API=%s.%s)", self._address, resp.server_info, resp.api_version_major, resp.api_version_minor) self._connected = True def _check_connected(self) -> None: if not self._connected: raise APIConnectionError("Must be connected!") async def login(self) -> None: self._check_connected() if self._authenticated: raise APIConnectionError("Already logged in!") connect = pb.ConnectRequest() if self._password is not None: connect.password = self._password resp = await self._send_message_await_response(connect, pb.ConnectResponse) if resp.invalid_password: raise APIConnectionError("Invalid password!") self._authenticated = True if self.on_login is not None: await self.on_login() async def _on_error(self) -> None: was_connected = self._connected await self._close_socket() if was_connected and self.on_disconnect is not None: await self.on_disconnect() async def _write(self, data: bytes) -> None: _LOGGER.debug("Write: %s", ' '.join('{:02X}'.format(x) for x in data)) try: self._sock_writer.write(data) await self._sock_writer.drain() except OSError as err: await self._on_error() raise APIConnectionError("Error while writing data: {}".format(err)) async def _send_message(self, msg: message.Message) -> None: for message_type, klass in MESSAGE_TYPE_TO_PROTO.items(): if isinstance(msg, klass): break else: raise ValueError encoded = msg.SerializeToString() _LOGGER.debug("Sending %s: %s", type(msg), str(msg)) req = bytes([0]) req += _varuint_to_bytes(len(encoded)) req += _varuint_to_bytes(message_type) req += encoded await self._write(req) self._refresh_ping() async def _send_message_await_response_complex(self, send_msg: message.Message, do_append: Callable[[Any], bool], do_stop: Callable[[Any], bool], timeout: float = 1.0) -> List[Any]: fut = self._eventloop.create_future() responses = [] def on_message(resp): if do_append(resp): responses.append(resp) if do_stop(resp): fut.set_result(responses) self._message_handlers.append(on_message) await self._send_message(send_msg) try: await asyncio.wait_for(fut, timeout) except asyncio.TimeoutError: raise APIConnectionError("Timeout while waiting for API response!") try: self._message_handlers.remove(on_message) except ValueError: pass return responses async def _send_message_await_response(self, send_msg: message.Message, response_type: Any, timeout: float = 1.0) -> Any: def is_response(msg): return isinstance(msg, response_type) res = await self._send_message_await_response_complex( send_msg, is_response, is_response, timeout=timeout) if len(res) != 1: raise APIConnectionError("Expected one result, got {}".format(len(res))) return res[0] async def device_info(self) -> DeviceInfo: self._check_connected() resp = await self._send_message_await_response( pb.DeviceInfoRequest(), pb.DeviceInfoResponse) return DeviceInfo( uses_password=resp.uses_password, name=resp.name, mac_address=resp.mac_address, esphomelib_version=resp.esphomelib_version, compilation_time=resp.compilation_time, model=resp.model, has_deep_sleep=resp.has_deep_sleep, ) async def ping(self) -> None: self._check_connected() await self._send_message_await_response(pb.PingRequest(), pb.PingResponse) return async def disconnect(self) -> None: self._check_connected() try: await self._send_message_await_response(pb.DisconnectRequest(), pb.DisconnectResponse) except APIConnectionError: pass await self._close_socket() if self.on_disconnect is not None: await self.on_disconnect() def _check_authenticated(self) -> None: if not self._authenticated: raise APIConnectionError("Must login first!") async def list_entities(self) -> List[Any]: self._check_authenticated() response_types = { pb.ListEntitiesBinarySensorResponse: BinarySensorInfo, pb.ListEntitiesCoverResponse: CoverInfo, pb.ListEntitiesFanResponse: FanInfo, pb.ListEntitiesLightResponse: LightInfo, pb.ListEntitiesSensorResponse: SensorInfo, pb.ListEntitiesSwitchResponse: SwitchInfo, pb.ListEntitiesTextSensorResponse: TextSensorInfo, } def do_append(msg): return isinstance(msg, tuple(response_types.keys())) def do_stop(msg): return isinstance(msg, pb.ListEntitiesDoneResponse) resp = await self._send_message_await_response_complex( pb.ListEntitiesRequest(), do_append, do_stop, timeout=5) entities = [] for msg in resp: cls = None for resp_type, cls in response_types.items(): if isinstance(msg, resp_type): break kwargs = {} for key, _ in attr.fields_dict(cls).items(): kwargs[key] = getattr(msg, key) entities.append(cls(**kwargs)) return entities async def subscribe_states(self, on_state: Callable[[Any], None]) -> None: self._check_authenticated() response_types = { pb.BinarySensorStateResponse: BinarySensorState, pb.CoverStateResponse: CoverState, pb.FanStateResponse: FanState, pb.LightStateResponse: LightState, pb.SensorStateResponse: SensorState, pb.SwitchStateResponse: SwitchState, pb.TextSensorStateResponse: TextSensorState, } def on_msg(msg): for resp_type, cls in response_types.items(): if isinstance(msg, resp_type): break else: return kwargs = {} for key, _ in attr.fields_dict(cls).items(): kwargs[key] = getattr(msg, key) on_state(cls(**kwargs)) self._message_handlers.append(on_msg) await self._send_message(pb.SubscribeStatesRequest()) async def subscribe_logs(self, on_log: Callable[[pb.SubscribeLogsResponse], None], log_level=None) -> None: self._check_authenticated() def on_msg(msg): if isinstance(msg, pb.SubscribeLogsResponse): on_log(msg) self._message_handlers.append(on_msg) req = pb.SubscribeLogsRequest() if log_level is not None: req.level = log_level await self._send_message(req) async def cover_command(self, key: int, command: int ) -> None: self._check_authenticated() req = pb.CoverCommandRequest() req.key = key req.has_state = True if command not in COVER_COMMANDS: raise ValueError req.command = command await self._send_message(req) async def fan_command(self, key: int, state: Optional[bool] = None, speed: Optional[int] = None, oscillating: Optional[bool] = None ) -> None: self._check_authenticated() req = pb.FanCommandRequest() req.key = key if state is not None: req.has_state = True req.state = state if speed is not None: req.has_speed = True if speed not in FAN_SPEEDS: raise ValueError req.speed = speed if oscillating is not None: req.has_oscillating = True req.oscillating = oscillating await self._send_message(req) async def light_command(self, key: int, state: Optional[bool] = None, brightness: Optional[float] = None, rgb: Optional[Tuple[float, float, float]] = None, white: Optional[float] = None, color_temperature: Optional[float] = None, transition_length: Optional[float] = None, flash_length: Optional[float] = None, effect: Optional[str] = None, ): self._check_authenticated() req = pb.LightCommandRequest() req.key = key if state is not None: req.has_state = True req.state = state if brightness is not None: req.has_brightness = True req.brightness = brightness if rgb is not None: req.has_rgb = True req.red = rgb[0] req.green = rgb[1] req.blue = rgb[2] if white is not None: req.has_white = True req.white = white if color_temperature is not None: req.has_color_temperature = True req.color_temperature = color_temperature if transition_length is not None: req.has_transition_length = True req.transition_length = int(round(transition_length / 1000)) if flash_length is not None: req.has_flash_length = True req.flash_length = int(round(flash_length / 1000)) if effect is not None: req.has_effect = True req.effect = effect await self._send_message(req) async def switch_command(self, key: int, state: bool ) -> None: self._check_authenticated() req = pb.SwitchCommandRequest() req.key = key req.state = state await self._send_message(req) async def _recv(self, amount: int) -> bytes: if amount == 0: return bytes() try: ret = await self._sock_reader.readexactly(amount) except (asyncio.IncompleteReadError, OSError) as err: raise APIConnectionError("Error while receiving data: {}".format(err)) return ret async def _recv_varint(self) -> int: raw = bytes() while not raw or raw[-1] & 0x80: raw += await self._recv(1) return cast(int, _bytes_to_varuint(raw)) async def _run_once(self) -> None: await self._socket_open_event.wait() preamble = await self._recv(1) if preamble[0] != 0x00: raise APIConnectionError("Invalid preamble") length = await self._recv_varint() msg_type = await self._recv_varint() raw_msg = await self._recv(length) if msg_type not in MESSAGE_TYPE_TO_PROTO: _LOGGER.debug("Skipping message type %s", msg_type) return msg = MESSAGE_TYPE_TO_PROTO[msg_type]() msg.ParseFromString(raw_msg) _LOGGER.debug("Got message of type %s: %s", type(msg), msg) for msg_handler in self._message_handlers[:]: msg_handler(msg) await self._handle_internal_messages(msg) self._refresh_ping() async def run_forever(self) -> None: if self.running_event.is_set(): raise ValueError self.running_event.set() try: while True: try: await self._run_once() except APIConnectionError as err: if self._connected: _LOGGER.debug("Error while reading incoming messages: %s", err) await self._on_error() except asyncio.CancelledError: self.running_event.clear() raise async def _handle_internal_messages(self, msg: Any) -> None: if isinstance(msg, pb.DisconnectRequest): await self._send_message(pb.DisconnectResponse()) await self._close_socket() if self.on_disconnect is not None: await self.on_disconnect() elif isinstance(msg, pb.PingRequest): await self._send_message(pb.PingResponse())