From fadcdb150145022f876595085e7762e3edc66be1 Mon Sep 17 00:00:00 2001 From: Otto Winter Date: Fri, 4 Jan 2019 18:35:38 +0100 Subject: [PATCH] Restructure API for prevent race conditions --- aioesphomeapi/client.py | 461 +++++++++++++++++++++++----------------- 1 file changed, 264 insertions(+), 197 deletions(-) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 04c2205..63e3c78 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -255,168 +255,186 @@ class ServiceCall: variables = attr.ib(type=Dict[str, str], converter=dict) -class APIClient: - def __init__(self, eventloop, address: str, port: int, password: str): - self._eventloop = eventloop # type: asyncio.events.AbstractEventLoop - self._address = address # type: str - self._port = port # type: int - self._password = password # type: Optional[str] - self._socket = None # type: Optional[socket.socket] - self._connected = False # type: bool - self._authenticated = False # type: bool - self._message_handlers = [] # type: List[Callable[[message], None]] - self._keepalive = 60 # type: Union[float, int] - self._ping_timer = None # type: Optional[asyncio.Future] - self.on_disconnect = None - self.on_login = None - self.running_event = asyncio.Event() - self._stop_event = asyncio.Event() - self._socket_open_event = asyncio.Event() - self._sock_reader = None # type: Optional[asyncio.StreamReader] - self._sock_writer = None # type: Optional[asyncio.StreamWriter] +@attr.s +class State: + running = attr.ib(type=bool) + stopped = attr.ib(type=bool) + socket = attr.ib(type=Optional[socket.socket]) + socket_reader = attr.ib(type=Optional[asyncio.StreamReader]) + socket_writer = attr.ib(type=Optional[asyncio.StreamWriter]) + socket_open = attr.ib(type=bool) + connected = attr.ib(type=bool) + authenticated = attr.ib(type=bool) - self._refresh_ping() + +@attr.s +class ConnectionParams: + eventloop = attr.ib(type=asyncio.events.AbstractEventLoop) + address = attr.ib(type=str) + port = attr.ib(type=int) + password = attr.ib(type=Optional[str]) + client_info = attr.ib(type=str) + keepalive = attr.ib(type=float) + + +class APIConnection: + def __init__(self, params: ConnectionParams, on_stop): + self._params = params + self.on_stop = on_stop + self._stopped = False + self._socket = None # type: Optional[socket.socket] + self._socket_reader = None # type: Optional[asyncio.StreamReader] + self._socket_writer = None # type: Optional[asyncio.StreamWriter] + self._write_lock = asyncio.Lock() + self._connected = False + self._authenticated = False + self._socket_connected = False + self._state_lock = asyncio.Lock() + + self._ping_timer = None # type: Optional[asyncio.Task] + self._message_handlers = [] # type: List[Callable[[message], None]] + + self._running_task = None # type: Optional[asyncio.Task] def _refresh_ping(self) -> None: - if self._ping_timer is not None: - self._ping_timer.cancel() - self._ping_timer = None + self._cancel_ping() async def func() -> None: - await asyncio.sleep(self._keepalive) - self._ping_timer = None + await asyncio.sleep(self._params.keepalive) if self._connected: try: await self.ping() except APIConnectionError: await self._on_error() + else: + self._refresh_ping() - self._refresh_ping() - - self._ping_timer = asyncio.ensure_future(func(), loop=self._eventloop) - - async def _close_socket(self) -> None: - if self._socket is not None: - self._socket.close() - self._socket = None - if self._sock_writer is not None: - self._sock_writer.close() - if hasattr(self._sock_writer, 'wait_closed'): - await self._sock_writer.wait_closed() - self._sock_writer = None - self._sock_reader = None - self._socket_open_event.clear() - self._connected = False - self._authenticated = False + self._ping_timer = self._params.eventloop.create_task(func()) def _cancel_ping(self) -> None: if self._ping_timer is not None: self._ping_timer.cancel() self._ping_timer = None - async def start(self) -> None: - self._eventloop.create_task(self.run_forever()) - await self.running_event.wait() + async def _close_socket(self) -> None: + if not self._socket_connected: + return + async with self._write_lock: + self._socket_writer.close() + if hasattr(self._socket_writer, 'wait_closed'): + await self._socket_writer.wait_closed() + self._socket_writer = None + self._socket_reader = None + if self._socket is not None: + self._socket.close() + self._socket_connected = False + self._connected = False + self._authenticated = False + _LOGGER.debug("Closed socket") async def stop(self, force: bool = False) -> None: - if not self.running_event.is_set(): + if self._stopped: raise ValueError - if self._connected and not force: try: - await self.disconnect() + await self._disconnect() except APIConnectionError: pass - await self._close_socket() - - self._stop_event.set() + self._stopped = True + if self._running_task is not None: + self._running_task.cancel() self._cancel_ping() + await self._close_socket() + await self.on_stop() + + async def _on_error(self) -> None: + await self.stop(force=True) async def connect(self) -> None: - if not self.running_event.is_set(): - raise APIConnectionError("You need to call start() first!") - + if self._stopped: + raise APIConnectionError("Connection is closed!") if self._connected: raise APIConnectionError("Already connected!") - self._message_handlers = [] - try: - coro = resolve_ip_address(self._eventloop, self._address, self._port) + coro = resolve_ip_address(self._params.eventloop, self._params.address, + self._params.port) sockaddr = await asyncio.wait_for(coro, 15.0) except APIConnectionError as err: + await self._on_error() raise err except asyncio.TimeoutError: + await self._on_error() raise APIConnectionError("Timeout while resolving IP address") self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._socket.setblocking(False) self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - _LOGGER.debug("Connecting to %s:%s (%s)", self._address, self._port, sockaddr) + _LOGGER.debug("Connecting to %s:%s (%s)", self._params.address, self._params.port, sockaddr) try: - coro = self._eventloop.sock_connect(self._socket, sockaddr) + coro = self._params.eventloop.sock_connect(self._socket, sockaddr) await asyncio.wait_for(coro, 15.0) except OSError as err: await self._on_error() raise APIConnectionError("Error connecting to {}: {}".format(sockaddr, err)) except asyncio.TimeoutError: + await self._on_error() raise APIConnectionError("Timeout while connecting to {}".format(sockaddr)) - self._sock_reader, self._sock_writer = await asyncio.open_connection(sock=self._socket) - - self._socket_open_event.set() + _LOGGER.debug("Opened socket") + self._socket_reader, self._socket_writer = await asyncio.open_connection(sock=self._socket) + self._params.eventloop.create_task(self.run_forever()) hello = pb.HelloRequest() - hello.client_info = 'Home Assistant' + hello.client_info = self._params.client_info try: - resp = await self._send_message_await_response(hello, pb.HelloResponse) + resp = await self.send_message_await_response(hello, pb.HelloResponse) except APIConnectionError as err: await self._on_error() raise err - _LOGGER.debug("Successfully connected to %s ('%s' API=%s.%s)", self._address, + _LOGGER.debug("Successfully connected to %s ('%s' API=%s.%s)", self._params.address, resp.server_info, resp.api_version_major, resp.api_version_minor) self._connected = True - def _check_connected(self) -> None: - if not self._connected: - raise APIConnectionError("Must be connected!") - async def login(self) -> None: self._check_connected() if self._authenticated: raise APIConnectionError("Already logged in!") connect = pb.ConnectRequest() - if self._password is not None: - connect.password = self._password - resp = await self._send_message_await_response(connect, pb.ConnectResponse) + if self._params.password is not None: + connect.password = self._params.password + resp = await self.send_message_await_response(connect, pb.ConnectResponse) if resp.invalid_password: raise APIConnectionError("Invalid password!") self._authenticated = True - if self.on_login is not None: - await self.on_login() - async def _on_error(self) -> None: - was_connected = self._connected + def _check_connected(self) -> None: + if not self._connected: + raise APIConnectionError("Must be connected!") - await self._close_socket() + @property + def is_connected(self) -> bool: + return self._connected - if was_connected and self.on_disconnect is not None: - await self.on_disconnect() + @property + def is_authenticated(self) -> bool: + return self._authenticated async def _write(self, data: bytes) -> None: _LOGGER.debug("Write: %s", ' '.join('{:02X}'.format(x) for x in data)) try: - self._sock_writer.write(data) - await self._sock_writer.drain() + async with self._write_lock: + self._socket_writer.write(data) + await self._socket_writer.drain() except OSError as err: await self._on_error() raise APIConnectionError("Error while writing data: {}".format(err)) - async def _send_message(self, msg: message.Message) -> None: + async def send_message(self, msg: message.Message) -> None: for message_type, klass in MESSAGE_TYPE_TO_PROTO.items(): if isinstance(msg, klass): break @@ -432,25 +450,34 @@ class APIClient: await self._write(req) self._refresh_ping() - async def _send_message_await_response_complex(self, send_msg: message.Message, - do_append: Callable[[Any], bool], - do_stop: Callable[[Any], bool], - timeout: float = 1.0) -> List[Any]: - fut = self._eventloop.create_future() + async def send_message_callback_response(self, send_msg: message.Message, + on_message: Callable[[Any], None]) -> None: + self._message_handlers.append(on_message) + await self.send_message(send_msg) + + async def send_message_await_response_complex(self, send_msg: message.Message, + do_append: Callable[[Any], bool], + do_stop: Callable[[Any], bool], + timeout: float = 1.0) -> List[Any]: + fut = self._params.eventloop.create_future() responses = [] def on_message(resp): + if fut.done(): + return if do_append(resp): responses.append(resp) if do_stop(resp): fut.set_result(responses) self._message_handlers.append(on_message) - await self._send_message(send_msg) + await self.send_message(send_msg) try: await asyncio.wait_for(fut, timeout) except asyncio.TimeoutError: + if self._stopped: + raise APIConnectionError("Disconnected while waiting for API response!") raise APIConnectionError("Timeout while waiting for API response!") try: @@ -460,22 +487,154 @@ class APIClient: return responses - async def _send_message_await_response(self, - send_msg: message.Message, - response_type: Any, timeout: float = 1.0) -> Any: + async def send_message_await_response(self, + send_msg: message.Message, + response_type: Any, timeout: float = 1.0) -> Any: def is_response(msg): return isinstance(msg, response_type) - res = await self._send_message_await_response_complex( + res = await self.send_message_await_response_complex( send_msg, is_response, is_response, timeout=timeout) if len(res) != 1: raise APIConnectionError("Expected one result, got {}".format(len(res))) return res[0] + async def _recv(self, amount: int) -> bytes: + if amount == 0: + return bytes() + + try: + ret = await self._socket_reader.readexactly(amount) + except (asyncio.IncompleteReadError, OSError) as err: + raise APIConnectionError("Error while receiving data: {}".format(err)) + + return ret + + async def _recv_varint(self) -> int: + raw = bytes() + while not raw or raw[-1] & 0x80: + raw += await self._recv(1) + return cast(int, _bytes_to_varuint(raw)) + + async def _run_once(self) -> None: + preamble = await self._recv(1) + if preamble[0] != 0x00: + raise APIConnectionError("Invalid preamble") + + length = await self._recv_varint() + msg_type = await self._recv_varint() + + raw_msg = await self._recv(length) + if msg_type not in MESSAGE_TYPE_TO_PROTO: + _LOGGER.debug("Skipping message type %s", msg_type) + return + + msg = MESSAGE_TYPE_TO_PROTO[msg_type]() + msg.ParseFromString(raw_msg) + _LOGGER.debug("Got message of type %s: %s", type(msg), msg) + for msg_handler in self._message_handlers[:]: + msg_handler(msg) + await self._handle_internal_messages(msg) + self._refresh_ping() + + async def run_forever(self) -> None: + while True: + try: + await self._run_once() + except APIConnectionError as err: + _LOGGER.info("Error while reading incoming messages for %s: %s", self._params.address, + err) + await self._on_error() + + async def _handle_internal_messages(self, msg: Any) -> None: + if isinstance(msg, pb.DisconnectRequest): + await self.send_message(pb.DisconnectResponse()) + await self.stop(force=True) + elif isinstance(msg, pb.PingRequest): + await self.send_message(pb.PingResponse()) + elif isinstance(msg, pb.GetTimeRequest): + resp = pb.GetTimeResponse() + resp.epoch_seconds = int(time.time()) + await self.send_message(resp) + + async def ping(self) -> None: + self._check_connected() + await self.send_message_await_response(pb.PingRequest(), pb.PingResponse) + + async def _disconnect(self) -> None: + self._check_connected() + + try: + await self.send_message_await_response(pb.DisconnectRequest(), pb.DisconnectResponse) + except APIConnectionError: + pass + + def _check_authenticated(self) -> None: + if not self._authenticated: + raise APIConnectionError("Must login first!") + + +class APIClient: + def __init__(self, eventloop, address: str, port: int, password: str, *, + client_info: str = 'aioesphomeapi', keepalive: float = 15.0): + self._params = ConnectionParams( + eventloop=eventloop, + address=address, + port=port, + password=password, + client_info=client_info, + keepalive=keepalive, + ) + self._connection = None # type: Optional[APIConnection] + + async def connect(self, on_stop=None, login=False): + if self._connection is not None: + raise APIConnectionError("Already connected!") + + connected = False + + async def _on_stop(): + if self._connection is None: + return + self._connection = None + if connected and on_stop is not None: + await on_stop() + + self._connection = APIConnection(self._params, _on_stop) + + try: + await self._connection.connect() + if login: + await self._connection.login() + except APIConnectionError: + await _on_stop() + raise + except Exception as e: + await _on_stop() + raise APIConnectionError("Unexpected error while connecting: {}".format(e)) + + connected = True + + async def disconnect(self, force=False): + if self._connection is None: + return + await self._connection.stop(force=force) + + def _check_connected(self): + if self._connection is None: + raise APIConnectionError("Not connected!") + if not self._connection.is_connected: + raise APIConnectionError("Connection not done!") + + def _check_authenticated(self): + self._check_connected() + if not self._connection.is_authenticated: + raise APIConnectionError("Not authenticated!") + async def device_info(self) -> DeviceInfo: self._check_connected() - resp = await self._send_message_await_response( + resp = await self._connection.send_message_await_response( pb.DeviceInfoRequest(), pb.DeviceInfoResponse) return DeviceInfo( uses_password=resp.uses_password, @@ -487,27 +646,6 @@ class APIClient: has_deep_sleep=resp.has_deep_sleep, ) - async def ping(self) -> None: - self._check_connected() - await self._send_message_await_response(pb.PingRequest(), pb.PingResponse) - return - - async def disconnect(self) -> None: - self._check_connected() - - try: - await self._send_message_await_response(pb.DisconnectRequest(), pb.DisconnectResponse) - except APIConnectionError: - pass - await self._close_socket() - - if self.on_disconnect is not None: - await self.on_disconnect() - - def _check_authenticated(self) -> None: - if not self._authenticated: - raise APIConnectionError("Must login first!") - async def list_entities(self) -> List[Any]: self._check_authenticated() response_types = { @@ -526,7 +664,7 @@ class APIClient: def do_stop(msg): return isinstance(msg, pb.ListEntitiesDoneResponse) - resp = await self._send_message_await_response_complex( + resp = await self._connection.send_message_await_response_complex( pb.ListEntitiesRequest(), do_append, do_stop, timeout=5) entities = [] for msg in resp: @@ -565,8 +703,7 @@ class APIClient: kwargs[key] = getattr(msg, key) on_state(cls(**kwargs)) - self._message_handlers.append(on_msg) - await self._send_message(pb.SubscribeStatesRequest()) + await self._connection.send_message_callback_response(pb.SubscribeStatesRequest(), on_msg) async def subscribe_logs(self, on_log: Callable[[pb.SubscribeLogsResponse], None], log_level=None) -> None: @@ -576,11 +713,10 @@ class APIClient: if isinstance(msg, pb.SubscribeLogsResponse): on_log(msg) - self._message_handlers.append(on_msg) req = pb.SubscribeLogsRequest() if log_level is not None: req.level = log_level - await self._send_message(req) + await self._connection.send_message_callback_response(req, on_msg) async def subscribe_service_calls(self, on_service_call: Callable[[ServiceCall], None]) -> None: self._check_authenticated() @@ -592,8 +728,8 @@ class APIClient: kwargs[key] = getattr(msg, key) on_service_call(ServiceCall(**kwargs)) - self._message_handlers.append(on_msg) - await self._send_message(pb.SubscribeServiceCallsRequest()) + await self._connection.send_message_callback_response(pb.SubscribeServiceCallsRequest(), + on_msg) async def subscribe_home_assistant_states(self, on_state_sub: Callable[[str], None]) -> None: self._check_authenticated() @@ -602,13 +738,13 @@ class APIClient: if isinstance(msg, pb.SubscribeHomeAssistantStateResponse): on_state_sub(msg.entity_id) - self._message_handlers.append(on_msg) - await self._send_message(pb.SubscribeHomeAssistantStatesRequest()) + await self._connection.send_message_callback_response( + pb.SubscribeHomeAssistantStatesRequest(), on_msg) async def send_home_assistant_state(self, entity_id: str, state: str) -> None: self._check_authenticated() - await self._send_message(pb.HomeAssistantStateResponse( + await self._connection.send_message(pb.HomeAssistantStateResponse( entity_id=entity_id, state=state, )) @@ -625,7 +761,7 @@ class APIClient: if command not in COVER_COMMANDS: raise ValueError req.command = command - await self._send_message(req) + await self._connection.send_message(req) async def fan_command(self, key: int, @@ -648,7 +784,7 @@ class APIClient: if oscillating is not None: req.has_oscillating = True req.oscillating = oscillating - await self._send_message(req) + await self._connection.send_message(req) async def light_command(self, key: int, @@ -691,7 +827,7 @@ class APIClient: if effect is not None: req.has_effect = True req.effect = effect - await self._send_message(req) + await self._connection.send_message(req) async def switch_command(self, key: int, @@ -702,73 +838,4 @@ class APIClient: req = pb.SwitchCommandRequest() req.key = key req.state = state - await self._send_message(req) - - async def _recv(self, amount: int) -> bytes: - if amount == 0: - return bytes() - - try: - ret = await self._sock_reader.readexactly(amount) - except (asyncio.IncompleteReadError, OSError) as err: - raise APIConnectionError("Error while receiving data: {}".format(err)) - - return ret - - async def _recv_varint(self) -> int: - raw = bytes() - while not raw or raw[-1] & 0x80: - raw += await self._recv(1) - return cast(int, _bytes_to_varuint(raw)) - - async def _run_once(self) -> None: - await self._socket_open_event.wait() - - preamble = await self._recv(1) - if preamble[0] != 0x00: - raise APIConnectionError("Invalid preamble") - - length = await self._recv_varint() - msg_type = await self._recv_varint() - - raw_msg = await self._recv(length) - if msg_type not in MESSAGE_TYPE_TO_PROTO: - _LOGGER.debug("Skipping message type %s", msg_type) - return - - msg = MESSAGE_TYPE_TO_PROTO[msg_type]() - msg.ParseFromString(raw_msg) - _LOGGER.debug("Got message of type %s: %s", type(msg), msg) - for msg_handler in self._message_handlers[:]: - msg_handler(msg) - await self._handle_internal_messages(msg) - self._refresh_ping() - - async def run_forever(self) -> None: - if self.running_event.is_set(): - raise ValueError - self.running_event.set() - try: - while True: - try: - await self._run_once() - except APIConnectionError as err: - if self._connected: - _LOGGER.debug("Error while reading incoming messages: %s", err) - await self._on_error() - except asyncio.CancelledError: - self.running_event.clear() - raise - - async def _handle_internal_messages(self, msg: Any) -> None: - if isinstance(msg, pb.DisconnectRequest): - await self._send_message(pb.DisconnectResponse()) - await self._close_socket() - if self.on_disconnect is not None: - await self.on_disconnect() - elif isinstance(msg, pb.PingRequest): - await self._send_message(pb.PingResponse()) - elif isinstance(msg, pb.GetTimeRequest): - resp = pb.GetTimeResponse() - resp.epoch_seconds = int(time.time()) - await self._send_message(resp) + await self._connection.send_message(req)