Fix failure to reconnect when the process task raises an exception during decoding a protobuf message (#339)

This commit is contained in:
J. Nick Koston 2022-12-13 10:31:12 -10:00 committed by GitHub
parent 4599e75e9d
commit b34664e44c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 56 additions and 21 deletions

View File

@ -1161,7 +1161,7 @@ message BluetoothLEAdvertisementResponse {
option (no_delay) = true;
uint64 address = 1;
string name = 2;
bytes name = 2;
sint32 rssi = 3;
repeated string service_uuids = 4;

File diff suppressed because one or more lines are too long

View File

@ -245,8 +245,9 @@ class APIClient:
if on_stop is not None:
await on_stop()
self._connection = APIConnection(self._params, _on_stop)
self._connection.log_name = self._log_name
self._connection = APIConnection(
self._params, _on_stop, log_name=self._log_name
)
try:
await self._connection.connect(login=login)
@ -271,7 +272,10 @@ class APIClient:
if self._connection is None:
raise APIConnectionError(f"Not connected to {self._log_name}!")
if not self._connection.is_connected:
raise APIConnectionError(f"Connection not done for {self._log_name}!")
raise APIConnectionError(
f"Connection not done for {self._log_name}; "
f"current state is {self._connection.connection_state}!"
)
def _check_authenticated(self) -> None:
self._check_connected()
@ -287,7 +291,7 @@ class APIClient:
)
info = DeviceInfo.from_pb(resp)
self._cached_name = info.name
self._connection.log_name = self._log_name
self._connection.set_log_name(self._log_name)
return info
async def list_entities_services(

View File

@ -72,8 +72,8 @@ class ConnectionState(enum.Enum):
# Internal state,
SOCKET_OPENED = 1
# The connection has been established, data can be exchanged
CONNECTED = 1
CLOSED = 2
CONNECTED = 2
CLOSED = 3
class APIConnection:
@ -84,8 +84,11 @@ class APIConnection:
"""
def __init__(
self, params: ConnectionParams, on_stop: Callable[[], Coroutine[Any, Any, None]]
):
self,
params: ConnectionParams,
on_stop: Callable[[], Coroutine[Any, Any, None]],
log_name: Optional[str] = None,
) -> None:
self._params = params
self.on_stop = on_stop
self._on_stop_called = False
@ -102,7 +105,7 @@ class APIConnection:
# Message handlers currently subscribed to incoming messages
self._message_handlers: Dict[Any, List[Callable[[message.Message], None]]] = {}
# The friendly name to show for this connection in the logs
self.log_name = params.address
self.log_name = log_name or params.address
# Handlers currently subscribed to exceptions in the read task
self._read_exception_handlers: List[Callable[[Exception], None]] = []
@ -116,6 +119,15 @@ class APIConnection:
self._connect_lock: asyncio.Lock = asyncio.Lock()
self._cleanup_task: Optional[asyncio.Task[None]] = None
@property
def connection_state(self) -> ConnectionState:
"""Return the current connection state."""
return self._connection_state
def set_log_name(self, name: str) -> None:
"""Set the friendly log name for this connection."""
self.log_name = name
async def _cleanup(self) -> None:
"""Clean up all resources that have been allocated.
@ -124,6 +136,8 @@ class APIConnection:
async def _do_cleanup() -> None:
async with self._connect_lock:
_LOGGER.debug("Cleaning up connection to %s", self.log_name)
# Tell the process loop to stop
self._to_process.put_nowait(None)
@ -133,8 +147,16 @@ class APIConnection:
if self._process_task is not None:
self._process_task.cancel()
with suppress(asyncio.CancelledError):
try:
await self._process_task
except asyncio.CancelledError:
pass
except Exception as err: # pylint: disable=broad-except
_LOGGER.error(
"Unexpected exception in process task: %s",
err,
exc_info=err,
)
self._process_task = None
if self._socket is not None:
@ -148,7 +170,7 @@ class APIConnection:
# Note: we don't explicitly cancel the ping/read task here
# That's because if not written right the ping/read task could cancel
# themself, effectively ending execution after _cleanup which may be unexpected
# themselves, effectively ending execution after _cleanup which may be unexpected
self._ping_stop_event.set()
if not self._cleanup_task or self._cleanup_task.done():
@ -269,7 +291,7 @@ class APIConnection:
async def _connect_start_ping(self) -> None:
"""Step 5 in connect process: start the ping loop."""
async def func() -> None:
async def _keep_alive_loop() -> None:
while True:
if not self._is_socket_open:
return
@ -304,7 +326,7 @@ class APIConnection:
await self._report_fatal_error(err)
return
asyncio.create_task(func())
asyncio.create_task(_keep_alive_loop())
async def connect(self, *, login: bool) -> None:
if self._connection_state != ConnectionState.INITIALIZED:
@ -354,6 +376,7 @@ class APIConnection:
# After a timeout for connect the connection can no longer be used
# We don't know what state the device may be in after ConnectRequest
# was already sent
_LOGGER.debug("%s: Login timed out", self.log_name)
await self._report_fatal_error(err)
raise
@ -407,6 +430,7 @@ class APIConnection:
except SocketAPIError as err: # pylint: disable=broad-except
# If writing packet fails, we don't know what state the frames
# are in anymore and we have to close the connection
_LOGGER.info("%s: Error writing packet: %s", self.log_name, err)
await self._report_fatal_error(err)
raise
@ -522,7 +546,7 @@ class APIConnection:
return res[0]
async def _report_fatal_error(self, err: Exception) -> None:
"""Report a fatal error that occured during an operation.
"""Report a fatal error that occurred during an operation.
This should only be called for errors that mean the connection
can no longer be used.
@ -558,6 +582,13 @@ class APIConnection:
try:
msg.ParseFromString(pkt.data)
except Exception as e:
_LOGGER.info(
"%s: Invalid protobuf message: %s: %s",
self.log_name,
pkt.data,
e,
exc_info=True,
)
await self._report_fatal_error(
ProtocolAPIError(f"Invalid protobuf message: {e}")
)
@ -590,10 +621,6 @@ class APIConnection:
to_process.put_nowait(await frame_helper.read_packet_with_lock())
except SocketClosedAPIError as err:
# don't log with info, if closed the site that closed the connection should log
if not self._is_socket_open:
# If we expected the socket to be closed, don't log
# the error.
return
_LOGGER.debug(
"%s: Socket closed, stopping read loop",
self.log_name,

View File

@ -815,13 +815,17 @@ def _convert_bluetooth_le_manufacturer_data(
return {int(v.uuid, 16): bytes(v.data if v.data else v.legacy_data) for v in value} # type: ignore
def _convert_bluetooth_le_name(value: bytes) -> str:
return value.decode("utf-8", errors="replace")
@dataclass(frozen=True)
class BluetoothLEAdvertisement(APIModelBase):
address: int = 0
name: str = ""
rssi: int = 0
address_type: int = 0
name: str = converter_field(default="", converter=_convert_bluetooth_le_name)
service_uuids: List[str] = converter_field(
default_factory=list, converter=_convert_bluetooth_le_service_uuids
)