Fix failure to reconnect when the process task raises an exception during decoding a protobuf message (#339)

This commit is contained in:
J. Nick Koston 2022-12-13 10:31:12 -10:00 committed by GitHub
parent 4599e75e9d
commit b34664e44c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 56 additions and 21 deletions

View File

@ -1161,7 +1161,7 @@ message BluetoothLEAdvertisementResponse {
option (no_delay) = true; option (no_delay) = true;
uint64 address = 1; uint64 address = 1;
string name = 2; bytes name = 2;
sint32 rssi = 3; sint32 rssi = 3;
repeated string service_uuids = 4; repeated string service_uuids = 4;

File diff suppressed because one or more lines are too long

View File

@ -245,8 +245,9 @@ class APIClient:
if on_stop is not None: if on_stop is not None:
await on_stop() await on_stop()
self._connection = APIConnection(self._params, _on_stop) self._connection = APIConnection(
self._connection.log_name = self._log_name self._params, _on_stop, log_name=self._log_name
)
try: try:
await self._connection.connect(login=login) await self._connection.connect(login=login)
@ -271,7 +272,10 @@ class APIClient:
if self._connection is None: if self._connection is None:
raise APIConnectionError(f"Not connected to {self._log_name}!") raise APIConnectionError(f"Not connected to {self._log_name}!")
if not self._connection.is_connected: if not self._connection.is_connected:
raise APIConnectionError(f"Connection not done for {self._log_name}!") raise APIConnectionError(
f"Connection not done for {self._log_name}; "
f"current state is {self._connection.connection_state}!"
)
def _check_authenticated(self) -> None: def _check_authenticated(self) -> None:
self._check_connected() self._check_connected()
@ -287,7 +291,7 @@ class APIClient:
) )
info = DeviceInfo.from_pb(resp) info = DeviceInfo.from_pb(resp)
self._cached_name = info.name self._cached_name = info.name
self._connection.log_name = self._log_name self._connection.set_log_name(self._log_name)
return info return info
async def list_entities_services( async def list_entities_services(

View File

@ -72,8 +72,8 @@ class ConnectionState(enum.Enum):
# Internal state, # Internal state,
SOCKET_OPENED = 1 SOCKET_OPENED = 1
# The connection has been established, data can be exchanged # The connection has been established, data can be exchanged
CONNECTED = 1 CONNECTED = 2
CLOSED = 2 CLOSED = 3
class APIConnection: class APIConnection:
@ -84,8 +84,11 @@ class APIConnection:
""" """
def __init__( def __init__(
self, params: ConnectionParams, on_stop: Callable[[], Coroutine[Any, Any, None]] self,
): params: ConnectionParams,
on_stop: Callable[[], Coroutine[Any, Any, None]],
log_name: Optional[str] = None,
) -> None:
self._params = params self._params = params
self.on_stop = on_stop self.on_stop = on_stop
self._on_stop_called = False self._on_stop_called = False
@ -102,7 +105,7 @@ class APIConnection:
# Message handlers currently subscribed to incoming messages # Message handlers currently subscribed to incoming messages
self._message_handlers: Dict[Any, List[Callable[[message.Message], None]]] = {} self._message_handlers: Dict[Any, List[Callable[[message.Message], None]]] = {}
# The friendly name to show for this connection in the logs # The friendly name to show for this connection in the logs
self.log_name = params.address self.log_name = log_name or params.address
# Handlers currently subscribed to exceptions in the read task # Handlers currently subscribed to exceptions in the read task
self._read_exception_handlers: List[Callable[[Exception], None]] = [] self._read_exception_handlers: List[Callable[[Exception], None]] = []
@ -116,6 +119,15 @@ class APIConnection:
self._connect_lock: asyncio.Lock = asyncio.Lock() self._connect_lock: asyncio.Lock = asyncio.Lock()
self._cleanup_task: Optional[asyncio.Task[None]] = None self._cleanup_task: Optional[asyncio.Task[None]] = None
@property
def connection_state(self) -> ConnectionState:
"""Return the current connection state."""
return self._connection_state
def set_log_name(self, name: str) -> None:
"""Set the friendly log name for this connection."""
self.log_name = name
async def _cleanup(self) -> None: async def _cleanup(self) -> None:
"""Clean up all resources that have been allocated. """Clean up all resources that have been allocated.
@ -124,6 +136,8 @@ class APIConnection:
async def _do_cleanup() -> None: async def _do_cleanup() -> None:
async with self._connect_lock: async with self._connect_lock:
_LOGGER.debug("Cleaning up connection to %s", self.log_name)
# Tell the process loop to stop # Tell the process loop to stop
self._to_process.put_nowait(None) self._to_process.put_nowait(None)
@ -133,8 +147,16 @@ class APIConnection:
if self._process_task is not None: if self._process_task is not None:
self._process_task.cancel() self._process_task.cancel()
with suppress(asyncio.CancelledError): try:
await self._process_task await self._process_task
except asyncio.CancelledError:
pass
except Exception as err: # pylint: disable=broad-except
_LOGGER.error(
"Unexpected exception in process task: %s",
err,
exc_info=err,
)
self._process_task = None self._process_task = None
if self._socket is not None: if self._socket is not None:
@ -148,7 +170,7 @@ class APIConnection:
# Note: we don't explicitly cancel the ping/read task here # Note: we don't explicitly cancel the ping/read task here
# That's because if not written right the ping/read task could cancel # That's because if not written right the ping/read task could cancel
# themself, effectively ending execution after _cleanup which may be unexpected # themselves, effectively ending execution after _cleanup which may be unexpected
self._ping_stop_event.set() self._ping_stop_event.set()
if not self._cleanup_task or self._cleanup_task.done(): if not self._cleanup_task or self._cleanup_task.done():
@ -269,7 +291,7 @@ class APIConnection:
async def _connect_start_ping(self) -> None: async def _connect_start_ping(self) -> None:
"""Step 5 in connect process: start the ping loop.""" """Step 5 in connect process: start the ping loop."""
async def func() -> None: async def _keep_alive_loop() -> None:
while True: while True:
if not self._is_socket_open: if not self._is_socket_open:
return return
@ -304,7 +326,7 @@ class APIConnection:
await self._report_fatal_error(err) await self._report_fatal_error(err)
return return
asyncio.create_task(func()) asyncio.create_task(_keep_alive_loop())
async def connect(self, *, login: bool) -> None: async def connect(self, *, login: bool) -> None:
if self._connection_state != ConnectionState.INITIALIZED: if self._connection_state != ConnectionState.INITIALIZED:
@ -354,6 +376,7 @@ class APIConnection:
# After a timeout for connect the connection can no longer be used # After a timeout for connect the connection can no longer be used
# We don't know what state the device may be in after ConnectRequest # We don't know what state the device may be in after ConnectRequest
# was already sent # was already sent
_LOGGER.debug("%s: Login timed out", self.log_name)
await self._report_fatal_error(err) await self._report_fatal_error(err)
raise raise
@ -407,6 +430,7 @@ class APIConnection:
except SocketAPIError as err: # pylint: disable=broad-except except SocketAPIError as err: # pylint: disable=broad-except
# If writing packet fails, we don't know what state the frames # If writing packet fails, we don't know what state the frames
# are in anymore and we have to close the connection # are in anymore and we have to close the connection
_LOGGER.info("%s: Error writing packet: %s", self.log_name, err)
await self._report_fatal_error(err) await self._report_fatal_error(err)
raise raise
@ -522,7 +546,7 @@ class APIConnection:
return res[0] return res[0]
async def _report_fatal_error(self, err: Exception) -> None: async def _report_fatal_error(self, err: Exception) -> None:
"""Report a fatal error that occured during an operation. """Report a fatal error that occurred during an operation.
This should only be called for errors that mean the connection This should only be called for errors that mean the connection
can no longer be used. can no longer be used.
@ -558,6 +582,13 @@ class APIConnection:
try: try:
msg.ParseFromString(pkt.data) msg.ParseFromString(pkt.data)
except Exception as e: except Exception as e:
_LOGGER.info(
"%s: Invalid protobuf message: %s: %s",
self.log_name,
pkt.data,
e,
exc_info=True,
)
await self._report_fatal_error( await self._report_fatal_error(
ProtocolAPIError(f"Invalid protobuf message: {e}") ProtocolAPIError(f"Invalid protobuf message: {e}")
) )
@ -590,10 +621,6 @@ class APIConnection:
to_process.put_nowait(await frame_helper.read_packet_with_lock()) to_process.put_nowait(await frame_helper.read_packet_with_lock())
except SocketClosedAPIError as err: except SocketClosedAPIError as err:
# don't log with info, if closed the site that closed the connection should log # don't log with info, if closed the site that closed the connection should log
if not self._is_socket_open:
# If we expected the socket to be closed, don't log
# the error.
return
_LOGGER.debug( _LOGGER.debug(
"%s: Socket closed, stopping read loop", "%s: Socket closed, stopping read loop",
self.log_name, self.log_name,

View File

@ -815,13 +815,17 @@ def _convert_bluetooth_le_manufacturer_data(
return {int(v.uuid, 16): bytes(v.data if v.data else v.legacy_data) for v in value} # type: ignore return {int(v.uuid, 16): bytes(v.data if v.data else v.legacy_data) for v in value} # type: ignore
def _convert_bluetooth_le_name(value: bytes) -> str:
return value.decode("utf-8", errors="replace")
@dataclass(frozen=True) @dataclass(frozen=True)
class BluetoothLEAdvertisement(APIModelBase): class BluetoothLEAdvertisement(APIModelBase):
address: int = 0 address: int = 0
name: str = ""
rssi: int = 0 rssi: int = 0
address_type: int = 0 address_type: int = 0
name: str = converter_field(default="", converter=_convert_bluetooth_le_name)
service_uuids: List[str] = converter_field( service_uuids: List[str] = converter_field(
default_factory=list, converter=_convert_bluetooth_le_service_uuids default_factory=list, converter=_convert_bluetooth_le_service_uuids
) )