mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-01-06 19:18:03 +01:00
Fix failure to reconnect when the process task raises an exception during decoding a protobuf message (#339)
This commit is contained in:
parent
4599e75e9d
commit
b34664e44c
@ -1161,7 +1161,7 @@ message BluetoothLEAdvertisementResponse {
|
|||||||
option (no_delay) = true;
|
option (no_delay) = true;
|
||||||
|
|
||||||
uint64 address = 1;
|
uint64 address = 1;
|
||||||
string name = 2;
|
bytes name = 2;
|
||||||
sint32 rssi = 3;
|
sint32 rssi = 3;
|
||||||
|
|
||||||
repeated string service_uuids = 4;
|
repeated string service_uuids = 4;
|
||||||
|
File diff suppressed because one or more lines are too long
@ -245,8 +245,9 @@ class APIClient:
|
|||||||
if on_stop is not None:
|
if on_stop is not None:
|
||||||
await on_stop()
|
await on_stop()
|
||||||
|
|
||||||
self._connection = APIConnection(self._params, _on_stop)
|
self._connection = APIConnection(
|
||||||
self._connection.log_name = self._log_name
|
self._params, _on_stop, log_name=self._log_name
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._connection.connect(login=login)
|
await self._connection.connect(login=login)
|
||||||
@ -271,7 +272,10 @@ class APIClient:
|
|||||||
if self._connection is None:
|
if self._connection is None:
|
||||||
raise APIConnectionError(f"Not connected to {self._log_name}!")
|
raise APIConnectionError(f"Not connected to {self._log_name}!")
|
||||||
if not self._connection.is_connected:
|
if not self._connection.is_connected:
|
||||||
raise APIConnectionError(f"Connection not done for {self._log_name}!")
|
raise APIConnectionError(
|
||||||
|
f"Connection not done for {self._log_name}; "
|
||||||
|
f"current state is {self._connection.connection_state}!"
|
||||||
|
)
|
||||||
|
|
||||||
def _check_authenticated(self) -> None:
|
def _check_authenticated(self) -> None:
|
||||||
self._check_connected()
|
self._check_connected()
|
||||||
@ -287,7 +291,7 @@ class APIClient:
|
|||||||
)
|
)
|
||||||
info = DeviceInfo.from_pb(resp)
|
info = DeviceInfo.from_pb(resp)
|
||||||
self._cached_name = info.name
|
self._cached_name = info.name
|
||||||
self._connection.log_name = self._log_name
|
self._connection.set_log_name(self._log_name)
|
||||||
return info
|
return info
|
||||||
|
|
||||||
async def list_entities_services(
|
async def list_entities_services(
|
||||||
|
@ -72,8 +72,8 @@ class ConnectionState(enum.Enum):
|
|||||||
# Internal state,
|
# Internal state,
|
||||||
SOCKET_OPENED = 1
|
SOCKET_OPENED = 1
|
||||||
# The connection has been established, data can be exchanged
|
# The connection has been established, data can be exchanged
|
||||||
CONNECTED = 1
|
CONNECTED = 2
|
||||||
CLOSED = 2
|
CLOSED = 3
|
||||||
|
|
||||||
|
|
||||||
class APIConnection:
|
class APIConnection:
|
||||||
@ -84,8 +84,11 @@ class APIConnection:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, params: ConnectionParams, on_stop: Callable[[], Coroutine[Any, Any, None]]
|
self,
|
||||||
):
|
params: ConnectionParams,
|
||||||
|
on_stop: Callable[[], Coroutine[Any, Any, None]],
|
||||||
|
log_name: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
self._params = params
|
self._params = params
|
||||||
self.on_stop = on_stop
|
self.on_stop = on_stop
|
||||||
self._on_stop_called = False
|
self._on_stop_called = False
|
||||||
@ -102,7 +105,7 @@ class APIConnection:
|
|||||||
# Message handlers currently subscribed to incoming messages
|
# Message handlers currently subscribed to incoming messages
|
||||||
self._message_handlers: Dict[Any, List[Callable[[message.Message], None]]] = {}
|
self._message_handlers: Dict[Any, List[Callable[[message.Message], None]]] = {}
|
||||||
# The friendly name to show for this connection in the logs
|
# The friendly name to show for this connection in the logs
|
||||||
self.log_name = params.address
|
self.log_name = log_name or params.address
|
||||||
|
|
||||||
# Handlers currently subscribed to exceptions in the read task
|
# Handlers currently subscribed to exceptions in the read task
|
||||||
self._read_exception_handlers: List[Callable[[Exception], None]] = []
|
self._read_exception_handlers: List[Callable[[Exception], None]] = []
|
||||||
@ -116,6 +119,15 @@ class APIConnection:
|
|||||||
self._connect_lock: asyncio.Lock = asyncio.Lock()
|
self._connect_lock: asyncio.Lock = asyncio.Lock()
|
||||||
self._cleanup_task: Optional[asyncio.Task[None]] = None
|
self._cleanup_task: Optional[asyncio.Task[None]] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def connection_state(self) -> ConnectionState:
|
||||||
|
"""Return the current connection state."""
|
||||||
|
return self._connection_state
|
||||||
|
|
||||||
|
def set_log_name(self, name: str) -> None:
|
||||||
|
"""Set the friendly log name for this connection."""
|
||||||
|
self.log_name = name
|
||||||
|
|
||||||
async def _cleanup(self) -> None:
|
async def _cleanup(self) -> None:
|
||||||
"""Clean up all resources that have been allocated.
|
"""Clean up all resources that have been allocated.
|
||||||
|
|
||||||
@ -124,6 +136,8 @@ class APIConnection:
|
|||||||
|
|
||||||
async def _do_cleanup() -> None:
|
async def _do_cleanup() -> None:
|
||||||
async with self._connect_lock:
|
async with self._connect_lock:
|
||||||
|
_LOGGER.debug("Cleaning up connection to %s", self.log_name)
|
||||||
|
|
||||||
# Tell the process loop to stop
|
# Tell the process loop to stop
|
||||||
self._to_process.put_nowait(None)
|
self._to_process.put_nowait(None)
|
||||||
|
|
||||||
@ -133,8 +147,16 @@ class APIConnection:
|
|||||||
|
|
||||||
if self._process_task is not None:
|
if self._process_task is not None:
|
||||||
self._process_task.cancel()
|
self._process_task.cancel()
|
||||||
with suppress(asyncio.CancelledError):
|
try:
|
||||||
await self._process_task
|
await self._process_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
_LOGGER.error(
|
||||||
|
"Unexpected exception in process task: %s",
|
||||||
|
err,
|
||||||
|
exc_info=err,
|
||||||
|
)
|
||||||
self._process_task = None
|
self._process_task = None
|
||||||
|
|
||||||
if self._socket is not None:
|
if self._socket is not None:
|
||||||
@ -148,7 +170,7 @@ class APIConnection:
|
|||||||
|
|
||||||
# Note: we don't explicitly cancel the ping/read task here
|
# Note: we don't explicitly cancel the ping/read task here
|
||||||
# That's because if not written right the ping/read task could cancel
|
# That's because if not written right the ping/read task could cancel
|
||||||
# themself, effectively ending execution after _cleanup which may be unexpected
|
# themselves, effectively ending execution after _cleanup which may be unexpected
|
||||||
self._ping_stop_event.set()
|
self._ping_stop_event.set()
|
||||||
|
|
||||||
if not self._cleanup_task or self._cleanup_task.done():
|
if not self._cleanup_task or self._cleanup_task.done():
|
||||||
@ -269,7 +291,7 @@ class APIConnection:
|
|||||||
async def _connect_start_ping(self) -> None:
|
async def _connect_start_ping(self) -> None:
|
||||||
"""Step 5 in connect process: start the ping loop."""
|
"""Step 5 in connect process: start the ping loop."""
|
||||||
|
|
||||||
async def func() -> None:
|
async def _keep_alive_loop() -> None:
|
||||||
while True:
|
while True:
|
||||||
if not self._is_socket_open:
|
if not self._is_socket_open:
|
||||||
return
|
return
|
||||||
@ -304,7 +326,7 @@ class APIConnection:
|
|||||||
await self._report_fatal_error(err)
|
await self._report_fatal_error(err)
|
||||||
return
|
return
|
||||||
|
|
||||||
asyncio.create_task(func())
|
asyncio.create_task(_keep_alive_loop())
|
||||||
|
|
||||||
async def connect(self, *, login: bool) -> None:
|
async def connect(self, *, login: bool) -> None:
|
||||||
if self._connection_state != ConnectionState.INITIALIZED:
|
if self._connection_state != ConnectionState.INITIALIZED:
|
||||||
@ -354,6 +376,7 @@ class APIConnection:
|
|||||||
# After a timeout for connect the connection can no longer be used
|
# After a timeout for connect the connection can no longer be used
|
||||||
# We don't know what state the device may be in after ConnectRequest
|
# We don't know what state the device may be in after ConnectRequest
|
||||||
# was already sent
|
# was already sent
|
||||||
|
_LOGGER.debug("%s: Login timed out", self.log_name)
|
||||||
await self._report_fatal_error(err)
|
await self._report_fatal_error(err)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@ -407,6 +430,7 @@ class APIConnection:
|
|||||||
except SocketAPIError as err: # pylint: disable=broad-except
|
except SocketAPIError as err: # pylint: disable=broad-except
|
||||||
# If writing packet fails, we don't know what state the frames
|
# If writing packet fails, we don't know what state the frames
|
||||||
# are in anymore and we have to close the connection
|
# are in anymore and we have to close the connection
|
||||||
|
_LOGGER.info("%s: Error writing packet: %s", self.log_name, err)
|
||||||
await self._report_fatal_error(err)
|
await self._report_fatal_error(err)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@ -522,7 +546,7 @@ class APIConnection:
|
|||||||
return res[0]
|
return res[0]
|
||||||
|
|
||||||
async def _report_fatal_error(self, err: Exception) -> None:
|
async def _report_fatal_error(self, err: Exception) -> None:
|
||||||
"""Report a fatal error that occured during an operation.
|
"""Report a fatal error that occurred during an operation.
|
||||||
|
|
||||||
This should only be called for errors that mean the connection
|
This should only be called for errors that mean the connection
|
||||||
can no longer be used.
|
can no longer be used.
|
||||||
@ -558,6 +582,13 @@ class APIConnection:
|
|||||||
try:
|
try:
|
||||||
msg.ParseFromString(pkt.data)
|
msg.ParseFromString(pkt.data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
_LOGGER.info(
|
||||||
|
"%s: Invalid protobuf message: %s: %s",
|
||||||
|
self.log_name,
|
||||||
|
pkt.data,
|
||||||
|
e,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
await self._report_fatal_error(
|
await self._report_fatal_error(
|
||||||
ProtocolAPIError(f"Invalid protobuf message: {e}")
|
ProtocolAPIError(f"Invalid protobuf message: {e}")
|
||||||
)
|
)
|
||||||
@ -590,10 +621,6 @@ class APIConnection:
|
|||||||
to_process.put_nowait(await frame_helper.read_packet_with_lock())
|
to_process.put_nowait(await frame_helper.read_packet_with_lock())
|
||||||
except SocketClosedAPIError as err:
|
except SocketClosedAPIError as err:
|
||||||
# don't log with info, if closed the site that closed the connection should log
|
# don't log with info, if closed the site that closed the connection should log
|
||||||
if not self._is_socket_open:
|
|
||||||
# If we expected the socket to be closed, don't log
|
|
||||||
# the error.
|
|
||||||
return
|
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"%s: Socket closed, stopping read loop",
|
"%s: Socket closed, stopping read loop",
|
||||||
self.log_name,
|
self.log_name,
|
||||||
|
@ -815,13 +815,17 @@ def _convert_bluetooth_le_manufacturer_data(
|
|||||||
return {int(v.uuid, 16): bytes(v.data if v.data else v.legacy_data) for v in value} # type: ignore
|
return {int(v.uuid, 16): bytes(v.data if v.data else v.legacy_data) for v in value} # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_bluetooth_le_name(value: bytes) -> str:
|
||||||
|
return value.decode("utf-8", errors="replace")
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class BluetoothLEAdvertisement(APIModelBase):
|
class BluetoothLEAdvertisement(APIModelBase):
|
||||||
address: int = 0
|
address: int = 0
|
||||||
name: str = ""
|
|
||||||
rssi: int = 0
|
rssi: int = 0
|
||||||
address_type: int = 0
|
address_type: int = 0
|
||||||
|
|
||||||
|
name: str = converter_field(default="", converter=_convert_bluetooth_le_name)
|
||||||
service_uuids: List[str] = converter_field(
|
service_uuids: List[str] = converter_field(
|
||||||
default_factory=list, converter=_convert_bluetooth_le_service_uuids
|
default_factory=list, converter=_convert_bluetooth_le_service_uuids
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user