mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-22 12:05:12 +01:00
Restructure API for prevent race conditions
This commit is contained in:
parent
e0c02ffc5c
commit
fadcdb1501
@ -255,168 +255,186 @@ class ServiceCall:
|
|||||||
variables = attr.ib(type=Dict[str, str], converter=dict)
|
variables = attr.ib(type=Dict[str, str], converter=dict)
|
||||||
|
|
||||||
|
|
||||||
class APIClient:
|
@attr.s
|
||||||
def __init__(self, eventloop, address: str, port: int, password: str):
|
class State:
|
||||||
self._eventloop = eventloop # type: asyncio.events.AbstractEventLoop
|
running = attr.ib(type=bool)
|
||||||
self._address = address # type: str
|
stopped = attr.ib(type=bool)
|
||||||
self._port = port # type: int
|
socket = attr.ib(type=Optional[socket.socket])
|
||||||
self._password = password # type: Optional[str]
|
socket_reader = attr.ib(type=Optional[asyncio.StreamReader])
|
||||||
self._socket = None # type: Optional[socket.socket]
|
socket_writer = attr.ib(type=Optional[asyncio.StreamWriter])
|
||||||
self._connected = False # type: bool
|
socket_open = attr.ib(type=bool)
|
||||||
self._authenticated = False # type: bool
|
connected = attr.ib(type=bool)
|
||||||
self._message_handlers = [] # type: List[Callable[[message], None]]
|
authenticated = attr.ib(type=bool)
|
||||||
self._keepalive = 60 # type: Union[float, int]
|
|
||||||
self._ping_timer = None # type: Optional[asyncio.Future]
|
|
||||||
self.on_disconnect = None
|
|
||||||
self.on_login = None
|
|
||||||
self.running_event = asyncio.Event()
|
|
||||||
self._stop_event = asyncio.Event()
|
|
||||||
self._socket_open_event = asyncio.Event()
|
|
||||||
self._sock_reader = None # type: Optional[asyncio.StreamReader]
|
|
||||||
self._sock_writer = None # type: Optional[asyncio.StreamWriter]
|
|
||||||
|
|
||||||
self._refresh_ping()
|
|
||||||
|
@attr.s
|
||||||
|
class ConnectionParams:
|
||||||
|
eventloop = attr.ib(type=asyncio.events.AbstractEventLoop)
|
||||||
|
address = attr.ib(type=str)
|
||||||
|
port = attr.ib(type=int)
|
||||||
|
password = attr.ib(type=Optional[str])
|
||||||
|
client_info = attr.ib(type=str)
|
||||||
|
keepalive = attr.ib(type=float)
|
||||||
|
|
||||||
|
|
||||||
|
class APIConnection:
|
||||||
|
def __init__(self, params: ConnectionParams, on_stop):
|
||||||
|
self._params = params
|
||||||
|
self.on_stop = on_stop
|
||||||
|
self._stopped = False
|
||||||
|
self._socket = None # type: Optional[socket.socket]
|
||||||
|
self._socket_reader = None # type: Optional[asyncio.StreamReader]
|
||||||
|
self._socket_writer = None # type: Optional[asyncio.StreamWriter]
|
||||||
|
self._write_lock = asyncio.Lock()
|
||||||
|
self._connected = False
|
||||||
|
self._authenticated = False
|
||||||
|
self._socket_connected = False
|
||||||
|
self._state_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
self._ping_timer = None # type: Optional[asyncio.Task]
|
||||||
|
self._message_handlers = [] # type: List[Callable[[message], None]]
|
||||||
|
|
||||||
|
self._running_task = None # type: Optional[asyncio.Task]
|
||||||
|
|
||||||
def _refresh_ping(self) -> None:
|
def _refresh_ping(self) -> None:
|
||||||
if self._ping_timer is not None:
|
self._cancel_ping()
|
||||||
self._ping_timer.cancel()
|
|
||||||
self._ping_timer = None
|
|
||||||
|
|
||||||
async def func() -> None:
|
async def func() -> None:
|
||||||
await asyncio.sleep(self._keepalive)
|
await asyncio.sleep(self._params.keepalive)
|
||||||
self._ping_timer = None
|
|
||||||
|
|
||||||
if self._connected:
|
if self._connected:
|
||||||
try:
|
try:
|
||||||
await self.ping()
|
await self.ping()
|
||||||
except APIConnectionError:
|
except APIConnectionError:
|
||||||
await self._on_error()
|
await self._on_error()
|
||||||
|
else:
|
||||||
|
self._refresh_ping()
|
||||||
|
|
||||||
self._refresh_ping()
|
self._ping_timer = self._params.eventloop.create_task(func())
|
||||||
|
|
||||||
self._ping_timer = asyncio.ensure_future(func(), loop=self._eventloop)
|
|
||||||
|
|
||||||
async def _close_socket(self) -> None:
|
|
||||||
if self._socket is not None:
|
|
||||||
self._socket.close()
|
|
||||||
self._socket = None
|
|
||||||
if self._sock_writer is not None:
|
|
||||||
self._sock_writer.close()
|
|
||||||
if hasattr(self._sock_writer, 'wait_closed'):
|
|
||||||
await self._sock_writer.wait_closed()
|
|
||||||
self._sock_writer = None
|
|
||||||
self._sock_reader = None
|
|
||||||
self._socket_open_event.clear()
|
|
||||||
self._connected = False
|
|
||||||
self._authenticated = False
|
|
||||||
|
|
||||||
def _cancel_ping(self) -> None:
|
def _cancel_ping(self) -> None:
|
||||||
if self._ping_timer is not None:
|
if self._ping_timer is not None:
|
||||||
self._ping_timer.cancel()
|
self._ping_timer.cancel()
|
||||||
self._ping_timer = None
|
self._ping_timer = None
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def _close_socket(self) -> None:
|
||||||
self._eventloop.create_task(self.run_forever())
|
if not self._socket_connected:
|
||||||
await self.running_event.wait()
|
return
|
||||||
|
async with self._write_lock:
|
||||||
|
self._socket_writer.close()
|
||||||
|
if hasattr(self._socket_writer, 'wait_closed'):
|
||||||
|
await self._socket_writer.wait_closed()
|
||||||
|
self._socket_writer = None
|
||||||
|
self._socket_reader = None
|
||||||
|
if self._socket is not None:
|
||||||
|
self._socket.close()
|
||||||
|
self._socket_connected = False
|
||||||
|
self._connected = False
|
||||||
|
self._authenticated = False
|
||||||
|
_LOGGER.debug("Closed socket")
|
||||||
|
|
||||||
async def stop(self, force: bool = False) -> None:
|
async def stop(self, force: bool = False) -> None:
|
||||||
if not self.running_event.is_set():
|
if self._stopped:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
if self._connected and not force:
|
if self._connected and not force:
|
||||||
try:
|
try:
|
||||||
await self.disconnect()
|
await self._disconnect()
|
||||||
except APIConnectionError:
|
except APIConnectionError:
|
||||||
pass
|
pass
|
||||||
await self._close_socket()
|
self._stopped = True
|
||||||
|
if self._running_task is not None:
|
||||||
self._stop_event.set()
|
self._running_task.cancel()
|
||||||
self._cancel_ping()
|
self._cancel_ping()
|
||||||
|
await self._close_socket()
|
||||||
|
await self.on_stop()
|
||||||
|
|
||||||
|
async def _on_error(self) -> None:
|
||||||
|
await self.stop(force=True)
|
||||||
|
|
||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
if not self.running_event.is_set():
|
if self._stopped:
|
||||||
raise APIConnectionError("You need to call start() first!")
|
raise APIConnectionError("Connection is closed!")
|
||||||
|
|
||||||
if self._connected:
|
if self._connected:
|
||||||
raise APIConnectionError("Already connected!")
|
raise APIConnectionError("Already connected!")
|
||||||
|
|
||||||
self._message_handlers = []
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
coro = resolve_ip_address(self._eventloop, self._address, self._port)
|
coro = resolve_ip_address(self._params.eventloop, self._params.address,
|
||||||
|
self._params.port)
|
||||||
sockaddr = await asyncio.wait_for(coro, 15.0)
|
sockaddr = await asyncio.wait_for(coro, 15.0)
|
||||||
except APIConnectionError as err:
|
except APIConnectionError as err:
|
||||||
|
await self._on_error()
|
||||||
raise err
|
raise err
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
await self._on_error()
|
||||||
raise APIConnectionError("Timeout while resolving IP address")
|
raise APIConnectionError("Timeout while resolving IP address")
|
||||||
|
|
||||||
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
self._socket.setblocking(False)
|
self._socket.setblocking(False)
|
||||||
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||||
|
|
||||||
_LOGGER.debug("Connecting to %s:%s (%s)", self._address, self._port, sockaddr)
|
_LOGGER.debug("Connecting to %s:%s (%s)", self._params.address, self._params.port, sockaddr)
|
||||||
try:
|
try:
|
||||||
coro = self._eventloop.sock_connect(self._socket, sockaddr)
|
coro = self._params.eventloop.sock_connect(self._socket, sockaddr)
|
||||||
await asyncio.wait_for(coro, 15.0)
|
await asyncio.wait_for(coro, 15.0)
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
await self._on_error()
|
await self._on_error()
|
||||||
raise APIConnectionError("Error connecting to {}: {}".format(sockaddr, err))
|
raise APIConnectionError("Error connecting to {}: {}".format(sockaddr, err))
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
await self._on_error()
|
||||||
raise APIConnectionError("Timeout while connecting to {}".format(sockaddr))
|
raise APIConnectionError("Timeout while connecting to {}".format(sockaddr))
|
||||||
|
|
||||||
self._sock_reader, self._sock_writer = await asyncio.open_connection(sock=self._socket)
|
_LOGGER.debug("Opened socket")
|
||||||
|
self._socket_reader, self._socket_writer = await asyncio.open_connection(sock=self._socket)
|
||||||
self._socket_open_event.set()
|
self._params.eventloop.create_task(self.run_forever())
|
||||||
|
|
||||||
hello = pb.HelloRequest()
|
hello = pb.HelloRequest()
|
||||||
hello.client_info = 'Home Assistant'
|
hello.client_info = self._params.client_info
|
||||||
try:
|
try:
|
||||||
resp = await self._send_message_await_response(hello, pb.HelloResponse)
|
resp = await self.send_message_await_response(hello, pb.HelloResponse)
|
||||||
except APIConnectionError as err:
|
except APIConnectionError as err:
|
||||||
await self._on_error()
|
await self._on_error()
|
||||||
raise err
|
raise err
|
||||||
_LOGGER.debug("Successfully connected to %s ('%s' API=%s.%s)", self._address,
|
_LOGGER.debug("Successfully connected to %s ('%s' API=%s.%s)", self._params.address,
|
||||||
resp.server_info, resp.api_version_major, resp.api_version_minor)
|
resp.server_info, resp.api_version_major, resp.api_version_minor)
|
||||||
self._connected = True
|
self._connected = True
|
||||||
|
|
||||||
def _check_connected(self) -> None:
|
|
||||||
if not self._connected:
|
|
||||||
raise APIConnectionError("Must be connected!")
|
|
||||||
|
|
||||||
async def login(self) -> None:
|
async def login(self) -> None:
|
||||||
self._check_connected()
|
self._check_connected()
|
||||||
if self._authenticated:
|
if self._authenticated:
|
||||||
raise APIConnectionError("Already logged in!")
|
raise APIConnectionError("Already logged in!")
|
||||||
|
|
||||||
connect = pb.ConnectRequest()
|
connect = pb.ConnectRequest()
|
||||||
if self._password is not None:
|
if self._params.password is not None:
|
||||||
connect.password = self._password
|
connect.password = self._params.password
|
||||||
resp = await self._send_message_await_response(connect, pb.ConnectResponse)
|
resp = await self.send_message_await_response(connect, pb.ConnectResponse)
|
||||||
if resp.invalid_password:
|
if resp.invalid_password:
|
||||||
raise APIConnectionError("Invalid password!")
|
raise APIConnectionError("Invalid password!")
|
||||||
|
|
||||||
self._authenticated = True
|
self._authenticated = True
|
||||||
if self.on_login is not None:
|
|
||||||
await self.on_login()
|
|
||||||
|
|
||||||
async def _on_error(self) -> None:
|
def _check_connected(self) -> None:
|
||||||
was_connected = self._connected
|
if not self._connected:
|
||||||
|
raise APIConnectionError("Must be connected!")
|
||||||
|
|
||||||
await self._close_socket()
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
return self._connected
|
||||||
|
|
||||||
if was_connected and self.on_disconnect is not None:
|
@property
|
||||||
await self.on_disconnect()
|
def is_authenticated(self) -> bool:
|
||||||
|
return self._authenticated
|
||||||
|
|
||||||
async def _write(self, data: bytes) -> None:
|
async def _write(self, data: bytes) -> None:
|
||||||
_LOGGER.debug("Write: %s", ' '.join('{:02X}'.format(x) for x in data))
|
_LOGGER.debug("Write: %s", ' '.join('{:02X}'.format(x) for x in data))
|
||||||
try:
|
try:
|
||||||
self._sock_writer.write(data)
|
async with self._write_lock:
|
||||||
await self._sock_writer.drain()
|
self._socket_writer.write(data)
|
||||||
|
await self._socket_writer.drain()
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
await self._on_error()
|
await self._on_error()
|
||||||
raise APIConnectionError("Error while writing data: {}".format(err))
|
raise APIConnectionError("Error while writing data: {}".format(err))
|
||||||
|
|
||||||
async def _send_message(self, msg: message.Message) -> None:
|
async def send_message(self, msg: message.Message) -> None:
|
||||||
for message_type, klass in MESSAGE_TYPE_TO_PROTO.items():
|
for message_type, klass in MESSAGE_TYPE_TO_PROTO.items():
|
||||||
if isinstance(msg, klass):
|
if isinstance(msg, klass):
|
||||||
break
|
break
|
||||||
@ -432,25 +450,34 @@ class APIClient:
|
|||||||
await self._write(req)
|
await self._write(req)
|
||||||
self._refresh_ping()
|
self._refresh_ping()
|
||||||
|
|
||||||
async def _send_message_await_response_complex(self, send_msg: message.Message,
|
async def send_message_callback_response(self, send_msg: message.Message,
|
||||||
do_append: Callable[[Any], bool],
|
on_message: Callable[[Any], None]) -> None:
|
||||||
do_stop: Callable[[Any], bool],
|
self._message_handlers.append(on_message)
|
||||||
timeout: float = 1.0) -> List[Any]:
|
await self.send_message(send_msg)
|
||||||
fut = self._eventloop.create_future()
|
|
||||||
|
async def send_message_await_response_complex(self, send_msg: message.Message,
|
||||||
|
do_append: Callable[[Any], bool],
|
||||||
|
do_stop: Callable[[Any], bool],
|
||||||
|
timeout: float = 1.0) -> List[Any]:
|
||||||
|
fut = self._params.eventloop.create_future()
|
||||||
responses = []
|
responses = []
|
||||||
|
|
||||||
def on_message(resp):
|
def on_message(resp):
|
||||||
|
if fut.done():
|
||||||
|
return
|
||||||
if do_append(resp):
|
if do_append(resp):
|
||||||
responses.append(resp)
|
responses.append(resp)
|
||||||
if do_stop(resp):
|
if do_stop(resp):
|
||||||
fut.set_result(responses)
|
fut.set_result(responses)
|
||||||
|
|
||||||
self._message_handlers.append(on_message)
|
self._message_handlers.append(on_message)
|
||||||
await self._send_message(send_msg)
|
await self.send_message(send_msg)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(fut, timeout)
|
await asyncio.wait_for(fut, timeout)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
if self._stopped:
|
||||||
|
raise APIConnectionError("Disconnected while waiting for API response!")
|
||||||
raise APIConnectionError("Timeout while waiting for API response!")
|
raise APIConnectionError("Timeout while waiting for API response!")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -460,22 +487,154 @@ class APIClient:
|
|||||||
|
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
async def _send_message_await_response(self,
|
async def send_message_await_response(self,
|
||||||
send_msg: message.Message,
|
send_msg: message.Message,
|
||||||
response_type: Any, timeout: float = 1.0) -> Any:
|
response_type: Any, timeout: float = 1.0) -> Any:
|
||||||
def is_response(msg):
|
def is_response(msg):
|
||||||
return isinstance(msg, response_type)
|
return isinstance(msg, response_type)
|
||||||
|
|
||||||
res = await self._send_message_await_response_complex(
|
res = await self.send_message_await_response_complex(
|
||||||
send_msg, is_response, is_response, timeout=timeout)
|
send_msg, is_response, is_response, timeout=timeout)
|
||||||
if len(res) != 1:
|
if len(res) != 1:
|
||||||
raise APIConnectionError("Expected one result, got {}".format(len(res)))
|
raise APIConnectionError("Expected one result, got {}".format(len(res)))
|
||||||
|
|
||||||
return res[0]
|
return res[0]
|
||||||
|
|
||||||
|
async def _recv(self, amount: int) -> bytes:
|
||||||
|
if amount == 0:
|
||||||
|
return bytes()
|
||||||
|
|
||||||
|
try:
|
||||||
|
ret = await self._socket_reader.readexactly(amount)
|
||||||
|
except (asyncio.IncompleteReadError, OSError) as err:
|
||||||
|
raise APIConnectionError("Error while receiving data: {}".format(err))
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
async def _recv_varint(self) -> int:
|
||||||
|
raw = bytes()
|
||||||
|
while not raw or raw[-1] & 0x80:
|
||||||
|
raw += await self._recv(1)
|
||||||
|
return cast(int, _bytes_to_varuint(raw))
|
||||||
|
|
||||||
|
async def _run_once(self) -> None:
|
||||||
|
preamble = await self._recv(1)
|
||||||
|
if preamble[0] != 0x00:
|
||||||
|
raise APIConnectionError("Invalid preamble")
|
||||||
|
|
||||||
|
length = await self._recv_varint()
|
||||||
|
msg_type = await self._recv_varint()
|
||||||
|
|
||||||
|
raw_msg = await self._recv(length)
|
||||||
|
if msg_type not in MESSAGE_TYPE_TO_PROTO:
|
||||||
|
_LOGGER.debug("Skipping message type %s", msg_type)
|
||||||
|
return
|
||||||
|
|
||||||
|
msg = MESSAGE_TYPE_TO_PROTO[msg_type]()
|
||||||
|
msg.ParseFromString(raw_msg)
|
||||||
|
_LOGGER.debug("Got message of type %s: %s", type(msg), msg)
|
||||||
|
for msg_handler in self._message_handlers[:]:
|
||||||
|
msg_handler(msg)
|
||||||
|
await self._handle_internal_messages(msg)
|
||||||
|
self._refresh_ping()
|
||||||
|
|
||||||
|
async def run_forever(self) -> None:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await self._run_once()
|
||||||
|
except APIConnectionError as err:
|
||||||
|
_LOGGER.info("Error while reading incoming messages for %s: %s", self._params.address,
|
||||||
|
err)
|
||||||
|
await self._on_error()
|
||||||
|
|
||||||
|
async def _handle_internal_messages(self, msg: Any) -> None:
|
||||||
|
if isinstance(msg, pb.DisconnectRequest):
|
||||||
|
await self.send_message(pb.DisconnectResponse())
|
||||||
|
await self.stop(force=True)
|
||||||
|
elif isinstance(msg, pb.PingRequest):
|
||||||
|
await self.send_message(pb.PingResponse())
|
||||||
|
elif isinstance(msg, pb.GetTimeRequest):
|
||||||
|
resp = pb.GetTimeResponse()
|
||||||
|
resp.epoch_seconds = int(time.time())
|
||||||
|
await self.send_message(resp)
|
||||||
|
|
||||||
|
async def ping(self) -> None:
|
||||||
|
self._check_connected()
|
||||||
|
await self.send_message_await_response(pb.PingRequest(), pb.PingResponse)
|
||||||
|
|
||||||
|
async def _disconnect(self) -> None:
|
||||||
|
self._check_connected()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.send_message_await_response(pb.DisconnectRequest(), pb.DisconnectResponse)
|
||||||
|
except APIConnectionError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _check_authenticated(self) -> None:
|
||||||
|
if not self._authenticated:
|
||||||
|
raise APIConnectionError("Must login first!")
|
||||||
|
|
||||||
|
|
||||||
|
class APIClient:
|
||||||
|
def __init__(self, eventloop, address: str, port: int, password: str, *,
|
||||||
|
client_info: str = 'aioesphomeapi', keepalive: float = 15.0):
|
||||||
|
self._params = ConnectionParams(
|
||||||
|
eventloop=eventloop,
|
||||||
|
address=address,
|
||||||
|
port=port,
|
||||||
|
password=password,
|
||||||
|
client_info=client_info,
|
||||||
|
keepalive=keepalive,
|
||||||
|
)
|
||||||
|
self._connection = None # type: Optional[APIConnection]
|
||||||
|
|
||||||
|
async def connect(self, on_stop=None, login=False):
|
||||||
|
if self._connection is not None:
|
||||||
|
raise APIConnectionError("Already connected!")
|
||||||
|
|
||||||
|
connected = False
|
||||||
|
|
||||||
|
async def _on_stop():
|
||||||
|
if self._connection is None:
|
||||||
|
return
|
||||||
|
self._connection = None
|
||||||
|
if connected and on_stop is not None:
|
||||||
|
await on_stop()
|
||||||
|
|
||||||
|
self._connection = APIConnection(self._params, _on_stop)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._connection.connect()
|
||||||
|
if login:
|
||||||
|
await self._connection.login()
|
||||||
|
except APIConnectionError:
|
||||||
|
await _on_stop()
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
await _on_stop()
|
||||||
|
raise APIConnectionError("Unexpected error while connecting: {}".format(e))
|
||||||
|
|
||||||
|
connected = True
|
||||||
|
|
||||||
|
async def disconnect(self, force=False):
|
||||||
|
if self._connection is None:
|
||||||
|
return
|
||||||
|
await self._connection.stop(force=force)
|
||||||
|
|
||||||
|
def _check_connected(self):
|
||||||
|
if self._connection is None:
|
||||||
|
raise APIConnectionError("Not connected!")
|
||||||
|
if not self._connection.is_connected:
|
||||||
|
raise APIConnectionError("Connection not done!")
|
||||||
|
|
||||||
|
def _check_authenticated(self):
|
||||||
|
self._check_connected()
|
||||||
|
if not self._connection.is_authenticated:
|
||||||
|
raise APIConnectionError("Not authenticated!")
|
||||||
|
|
||||||
async def device_info(self) -> DeviceInfo:
|
async def device_info(self) -> DeviceInfo:
|
||||||
self._check_connected()
|
self._check_connected()
|
||||||
resp = await self._send_message_await_response(
|
resp = await self._connection.send_message_await_response(
|
||||||
pb.DeviceInfoRequest(), pb.DeviceInfoResponse)
|
pb.DeviceInfoRequest(), pb.DeviceInfoResponse)
|
||||||
return DeviceInfo(
|
return DeviceInfo(
|
||||||
uses_password=resp.uses_password,
|
uses_password=resp.uses_password,
|
||||||
@ -487,27 +646,6 @@ class APIClient:
|
|||||||
has_deep_sleep=resp.has_deep_sleep,
|
has_deep_sleep=resp.has_deep_sleep,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def ping(self) -> None:
|
|
||||||
self._check_connected()
|
|
||||||
await self._send_message_await_response(pb.PingRequest(), pb.PingResponse)
|
|
||||||
return
|
|
||||||
|
|
||||||
async def disconnect(self) -> None:
|
|
||||||
self._check_connected()
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self._send_message_await_response(pb.DisconnectRequest(), pb.DisconnectResponse)
|
|
||||||
except APIConnectionError:
|
|
||||||
pass
|
|
||||||
await self._close_socket()
|
|
||||||
|
|
||||||
if self.on_disconnect is not None:
|
|
||||||
await self.on_disconnect()
|
|
||||||
|
|
||||||
def _check_authenticated(self) -> None:
|
|
||||||
if not self._authenticated:
|
|
||||||
raise APIConnectionError("Must login first!")
|
|
||||||
|
|
||||||
async def list_entities(self) -> List[Any]:
|
async def list_entities(self) -> List[Any]:
|
||||||
self._check_authenticated()
|
self._check_authenticated()
|
||||||
response_types = {
|
response_types = {
|
||||||
@ -526,7 +664,7 @@ class APIClient:
|
|||||||
def do_stop(msg):
|
def do_stop(msg):
|
||||||
return isinstance(msg, pb.ListEntitiesDoneResponse)
|
return isinstance(msg, pb.ListEntitiesDoneResponse)
|
||||||
|
|
||||||
resp = await self._send_message_await_response_complex(
|
resp = await self._connection.send_message_await_response_complex(
|
||||||
pb.ListEntitiesRequest(), do_append, do_stop, timeout=5)
|
pb.ListEntitiesRequest(), do_append, do_stop, timeout=5)
|
||||||
entities = []
|
entities = []
|
||||||
for msg in resp:
|
for msg in resp:
|
||||||
@ -565,8 +703,7 @@ class APIClient:
|
|||||||
kwargs[key] = getattr(msg, key)
|
kwargs[key] = getattr(msg, key)
|
||||||
on_state(cls(**kwargs))
|
on_state(cls(**kwargs))
|
||||||
|
|
||||||
self._message_handlers.append(on_msg)
|
await self._connection.send_message_callback_response(pb.SubscribeStatesRequest(), on_msg)
|
||||||
await self._send_message(pb.SubscribeStatesRequest())
|
|
||||||
|
|
||||||
async def subscribe_logs(self, on_log: Callable[[pb.SubscribeLogsResponse], None],
|
async def subscribe_logs(self, on_log: Callable[[pb.SubscribeLogsResponse], None],
|
||||||
log_level=None) -> None:
|
log_level=None) -> None:
|
||||||
@ -576,11 +713,10 @@ class APIClient:
|
|||||||
if isinstance(msg, pb.SubscribeLogsResponse):
|
if isinstance(msg, pb.SubscribeLogsResponse):
|
||||||
on_log(msg)
|
on_log(msg)
|
||||||
|
|
||||||
self._message_handlers.append(on_msg)
|
|
||||||
req = pb.SubscribeLogsRequest()
|
req = pb.SubscribeLogsRequest()
|
||||||
if log_level is not None:
|
if log_level is not None:
|
||||||
req.level = log_level
|
req.level = log_level
|
||||||
await self._send_message(req)
|
await self._connection.send_message_callback_response(req, on_msg)
|
||||||
|
|
||||||
async def subscribe_service_calls(self, on_service_call: Callable[[ServiceCall], None]) -> None:
|
async def subscribe_service_calls(self, on_service_call: Callable[[ServiceCall], None]) -> None:
|
||||||
self._check_authenticated()
|
self._check_authenticated()
|
||||||
@ -592,8 +728,8 @@ class APIClient:
|
|||||||
kwargs[key] = getattr(msg, key)
|
kwargs[key] = getattr(msg, key)
|
||||||
on_service_call(ServiceCall(**kwargs))
|
on_service_call(ServiceCall(**kwargs))
|
||||||
|
|
||||||
self._message_handlers.append(on_msg)
|
await self._connection.send_message_callback_response(pb.SubscribeServiceCallsRequest(),
|
||||||
await self._send_message(pb.SubscribeServiceCallsRequest())
|
on_msg)
|
||||||
|
|
||||||
async def subscribe_home_assistant_states(self, on_state_sub: Callable[[str], None]) -> None:
|
async def subscribe_home_assistant_states(self, on_state_sub: Callable[[str], None]) -> None:
|
||||||
self._check_authenticated()
|
self._check_authenticated()
|
||||||
@ -602,13 +738,13 @@ class APIClient:
|
|||||||
if isinstance(msg, pb.SubscribeHomeAssistantStateResponse):
|
if isinstance(msg, pb.SubscribeHomeAssistantStateResponse):
|
||||||
on_state_sub(msg.entity_id)
|
on_state_sub(msg.entity_id)
|
||||||
|
|
||||||
self._message_handlers.append(on_msg)
|
await self._connection.send_message_callback_response(
|
||||||
await self._send_message(pb.SubscribeHomeAssistantStatesRequest())
|
pb.SubscribeHomeAssistantStatesRequest(), on_msg)
|
||||||
|
|
||||||
async def send_home_assistant_state(self, entity_id: str, state: str) -> None:
|
async def send_home_assistant_state(self, entity_id: str, state: str) -> None:
|
||||||
self._check_authenticated()
|
self._check_authenticated()
|
||||||
|
|
||||||
await self._send_message(pb.HomeAssistantStateResponse(
|
await self._connection.send_message(pb.HomeAssistantStateResponse(
|
||||||
entity_id=entity_id,
|
entity_id=entity_id,
|
||||||
state=state,
|
state=state,
|
||||||
))
|
))
|
||||||
@ -625,7 +761,7 @@ class APIClient:
|
|||||||
if command not in COVER_COMMANDS:
|
if command not in COVER_COMMANDS:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
req.command = command
|
req.command = command
|
||||||
await self._send_message(req)
|
await self._connection.send_message(req)
|
||||||
|
|
||||||
async def fan_command(self,
|
async def fan_command(self,
|
||||||
key: int,
|
key: int,
|
||||||
@ -648,7 +784,7 @@ class APIClient:
|
|||||||
if oscillating is not None:
|
if oscillating is not None:
|
||||||
req.has_oscillating = True
|
req.has_oscillating = True
|
||||||
req.oscillating = oscillating
|
req.oscillating = oscillating
|
||||||
await self._send_message(req)
|
await self._connection.send_message(req)
|
||||||
|
|
||||||
async def light_command(self,
|
async def light_command(self,
|
||||||
key: int,
|
key: int,
|
||||||
@ -691,7 +827,7 @@ class APIClient:
|
|||||||
if effect is not None:
|
if effect is not None:
|
||||||
req.has_effect = True
|
req.has_effect = True
|
||||||
req.effect = effect
|
req.effect = effect
|
||||||
await self._send_message(req)
|
await self._connection.send_message(req)
|
||||||
|
|
||||||
async def switch_command(self,
|
async def switch_command(self,
|
||||||
key: int,
|
key: int,
|
||||||
@ -702,73 +838,4 @@ class APIClient:
|
|||||||
req = pb.SwitchCommandRequest()
|
req = pb.SwitchCommandRequest()
|
||||||
req.key = key
|
req.key = key
|
||||||
req.state = state
|
req.state = state
|
||||||
await self._send_message(req)
|
await self._connection.send_message(req)
|
||||||
|
|
||||||
async def _recv(self, amount: int) -> bytes:
|
|
||||||
if amount == 0:
|
|
||||||
return bytes()
|
|
||||||
|
|
||||||
try:
|
|
||||||
ret = await self._sock_reader.readexactly(amount)
|
|
||||||
except (asyncio.IncompleteReadError, OSError) as err:
|
|
||||||
raise APIConnectionError("Error while receiving data: {}".format(err))
|
|
||||||
|
|
||||||
return ret
|
|
||||||
|
|
||||||
async def _recv_varint(self) -> int:
|
|
||||||
raw = bytes()
|
|
||||||
while not raw or raw[-1] & 0x80:
|
|
||||||
raw += await self._recv(1)
|
|
||||||
return cast(int, _bytes_to_varuint(raw))
|
|
||||||
|
|
||||||
async def _run_once(self) -> None:
|
|
||||||
await self._socket_open_event.wait()
|
|
||||||
|
|
||||||
preamble = await self._recv(1)
|
|
||||||
if preamble[0] != 0x00:
|
|
||||||
raise APIConnectionError("Invalid preamble")
|
|
||||||
|
|
||||||
length = await self._recv_varint()
|
|
||||||
msg_type = await self._recv_varint()
|
|
||||||
|
|
||||||
raw_msg = await self._recv(length)
|
|
||||||
if msg_type not in MESSAGE_TYPE_TO_PROTO:
|
|
||||||
_LOGGER.debug("Skipping message type %s", msg_type)
|
|
||||||
return
|
|
||||||
|
|
||||||
msg = MESSAGE_TYPE_TO_PROTO[msg_type]()
|
|
||||||
msg.ParseFromString(raw_msg)
|
|
||||||
_LOGGER.debug("Got message of type %s: %s", type(msg), msg)
|
|
||||||
for msg_handler in self._message_handlers[:]:
|
|
||||||
msg_handler(msg)
|
|
||||||
await self._handle_internal_messages(msg)
|
|
||||||
self._refresh_ping()
|
|
||||||
|
|
||||||
async def run_forever(self) -> None:
|
|
||||||
if self.running_event.is_set():
|
|
||||||
raise ValueError
|
|
||||||
self.running_event.set()
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
await self._run_once()
|
|
||||||
except APIConnectionError as err:
|
|
||||||
if self._connected:
|
|
||||||
_LOGGER.debug("Error while reading incoming messages: %s", err)
|
|
||||||
await self._on_error()
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
self.running_event.clear()
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def _handle_internal_messages(self, msg: Any) -> None:
|
|
||||||
if isinstance(msg, pb.DisconnectRequest):
|
|
||||||
await self._send_message(pb.DisconnectResponse())
|
|
||||||
await self._close_socket()
|
|
||||||
if self.on_disconnect is not None:
|
|
||||||
await self.on_disconnect()
|
|
||||||
elif isinstance(msg, pb.PingRequest):
|
|
||||||
await self._send_message(pb.PingResponse())
|
|
||||||
elif isinstance(msg, pb.GetTimeRequest):
|
|
||||||
resp = pb.GetTimeResponse()
|
|
||||||
resp.epoch_seconds = int(time.time())
|
|
||||||
await self._send_message(resp)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user