mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-22 12:05:12 +01:00
Expect a name for connections (#122)
Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com>
This commit is contained in:
parent
fe298c1f27
commit
9964034f18
@ -8,6 +8,7 @@ from typing import Optional
|
||||
from noise.connection import NoiseConnection # type: ignore
|
||||
|
||||
from .core import (
|
||||
BadNameAPIError,
|
||||
HandshakeAPIError,
|
||||
InvalidEncryptionKeyAPIError,
|
||||
ProtocolAPIError,
|
||||
@ -178,18 +179,35 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
_LOGGER.debug("Received frame %s", frame.hex())
|
||||
return frame
|
||||
|
||||
async def perform_handshake(self) -> None:
|
||||
async def perform_handshake(self, expected_name: Optional[str]) -> None:
|
||||
await self._write_frame(b"") # ClientHello
|
||||
prologue = b"NoiseAPIInit" + b"\x00\x00"
|
||||
|
||||
server_hello = await self._read_frame() # ServerHello
|
||||
if not server_hello:
|
||||
raise HandshakeAPIError("ServerHello is empty")
|
||||
|
||||
# First byte of server hello is the protocol the server chose
|
||||
# for this session. Currently only 0x01 (Noise_NNpsk0_25519_ChaChaPoly_SHA256)
|
||||
# exists.
|
||||
chosen_proto = server_hello[0]
|
||||
if chosen_proto != 0x01:
|
||||
raise HandshakeAPIError(
|
||||
f"Unknown protocol selected by client {chosen_proto}"
|
||||
)
|
||||
|
||||
# Check name matches expected name (for noise sessions, this is done
|
||||
# during hello phase before a connection is set up)
|
||||
# Server name is encoded as a string followed by a zero byte after the chosen proto byte
|
||||
server_name_i = server_hello.find(b"\0", 1)
|
||||
if server_name_i != -1:
|
||||
# server name found, this extension was added in 2022.2
|
||||
server_name = server_hello[1:server_name_i].decode()
|
||||
if expected_name is not None and expected_name != server_name:
|
||||
raise BadNameAPIError(
|
||||
f"Server sent a different name '{server_name}'", server_name
|
||||
)
|
||||
|
||||
self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
|
||||
self._proto.set_as_initiator()
|
||||
self._proto.set_psks(_decode_noise_psk(self._noise_psk))
|
||||
|
@ -97,6 +97,9 @@ message HelloResponse {
|
||||
// and only exists for debugging/logging purposes.
|
||||
// For example "ESPHome v1.10.0 on ESP8266"
|
||||
string server_info = 3;
|
||||
|
||||
// The name of the server (App.get_name())
|
||||
string name = 4;
|
||||
}
|
||||
|
||||
// Message sent at the beginning of each connection to authenticate the client
|
||||
|
File diff suppressed because one or more lines are too long
@ -137,7 +137,23 @@ class APIClient:
|
||||
keepalive: float = 15.0,
|
||||
zeroconf_instance: ZeroconfInstanceType = None,
|
||||
noise_psk: Optional[str] = None,
|
||||
expected_name: Optional[str] = None,
|
||||
):
|
||||
"""Create a client, this object is shared across sessions.
|
||||
|
||||
:param address: The address to connect to; for example an IP address
|
||||
or .local name for mDNS lookup.
|
||||
:param port: The port to connect to
|
||||
:param password: Optional password to send to the device for authentication
|
||||
:param client_info: User Agent string to send.
|
||||
:param keepalive: The keepalive time in seconds (ping interval) for detecting stale connections.
|
||||
Every keepalive seconds a ping is sent, if no pong is received the connection is closed.
|
||||
:param zeroconf_instance: Pass a zeroconf instance to use if an mDNS lookup is necessary.
|
||||
:param noise_psk: Encryption preshared key for noise transport encrypted sessions.
|
||||
:param expected_name: Require the devices name to match the given expected name.
|
||||
Can be used to prevent accidentally connecting to a different device if
|
||||
IP passed as address but DHCP reassigned IP.
|
||||
"""
|
||||
self._params = ConnectionParams(
|
||||
address=address,
|
||||
port=port,
|
||||
@ -147,10 +163,19 @@ class APIClient:
|
||||
zeroconf_instance=zeroconf_instance,
|
||||
# treat empty psk string as missing (like password)
|
||||
noise_psk=noise_psk or None,
|
||||
expected_name=expected_name,
|
||||
)
|
||||
self._connection: Optional[APIConnection] = None
|
||||
self._cached_name: Optional[str] = None
|
||||
|
||||
@property
|
||||
def expected_name(self) -> Optional[str]:
|
||||
return self._params.expected_name
|
||||
|
||||
@expected_name.setter
|
||||
def expected_name(self, value: Optional[str]) -> None:
|
||||
self._params.expected_name = value
|
||||
|
||||
@property
|
||||
def address(self) -> str:
|
||||
return self._params.address
|
||||
|
@ -32,6 +32,7 @@ from .api_pb2 import ( # type: ignore
|
||||
from .core import (
|
||||
MESSAGE_TYPE_TO_PROTO,
|
||||
APIConnectionError,
|
||||
BadNameAPIError,
|
||||
InvalidAuthAPIError,
|
||||
PingFailedAPIError,
|
||||
ProtocolAPIError,
|
||||
@ -55,6 +56,7 @@ class ConnectionParams:
|
||||
keepalive: float
|
||||
zeroconf_instance: hr.ZeroconfInstanceType
|
||||
noise_psk: Optional[str]
|
||||
expected_name: Optional[str]
|
||||
|
||||
|
||||
class ConnectionState(enum.Enum):
|
||||
@ -174,7 +176,7 @@ class APIConnection:
|
||||
fh = self._frame_helper = APINoiseFrameHelper(
|
||||
reader, writer, self._params.noise_psk
|
||||
)
|
||||
await fh.perform_handshake()
|
||||
await fh.perform_handshake(self._params.expected_name)
|
||||
|
||||
self._connection_state = ConnectionState.SOCKET_OPENED
|
||||
|
||||
@ -206,6 +208,15 @@ class APIConnection:
|
||||
)
|
||||
raise APIConnectionError("Incompatible API version.")
|
||||
|
||||
if (
|
||||
self._params.expected_name is not None
|
||||
and resp.name != ""
|
||||
and resp.name != self._params.expected_name
|
||||
):
|
||||
raise BadNameAPIError(
|
||||
f"Server sent a different name '{resp.name}'", resp.name
|
||||
)
|
||||
|
||||
self._connection_state = ConnectionState.CONNECTED
|
||||
|
||||
async def _connect_start_ping(self) -> None:
|
||||
|
@ -96,6 +96,14 @@ class HandshakeAPIError(APIConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
class BadNameAPIError(APIConnectionError):
|
||||
"""Raised when a name received from the remote but does not much the expected name."""
|
||||
|
||||
def __init__(self, msg: str, received_name: str) -> None:
|
||||
super().__init__(msg)
|
||||
self.received_name = received_name
|
||||
|
||||
|
||||
class InvalidEncryptionKeyAPIError(HandshakeAPIError):
|
||||
pass
|
||||
|
||||
|
@ -20,6 +20,7 @@ def connection_params() -> ConnectionParams:
|
||||
keepalive=15.0,
|
||||
zeroconf_instance=None,
|
||||
noise_psk=None,
|
||||
expected_name=None,
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user