Expect a name for connections (#122)

Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com>
This commit is contained in:
Otto Winter 2022-01-20 12:03:36 +01:00 committed by GitHub
parent fe298c1f27
commit 9964034f18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 241 additions and 168 deletions

View File

@ -8,6 +8,7 @@ from typing import Optional
from noise.connection import NoiseConnection # type: ignore
from .core import (
BadNameAPIError,
HandshakeAPIError,
InvalidEncryptionKeyAPIError,
ProtocolAPIError,
@ -178,18 +179,35 @@ class APINoiseFrameHelper(APIFrameHelper):
_LOGGER.debug("Received frame %s", frame.hex())
return frame
async def perform_handshake(self) -> None:
async def perform_handshake(self, expected_name: Optional[str]) -> None:
await self._write_frame(b"") # ClientHello
prologue = b"NoiseAPIInit" + b"\x00\x00"
server_hello = await self._read_frame() # ServerHello
if not server_hello:
raise HandshakeAPIError("ServerHello is empty")
# First byte of server hello is the protocol the server chose
# for this session. Currently only 0x01 (Noise_NNpsk0_25519_ChaChaPoly_SHA256)
# exists.
chosen_proto = server_hello[0]
if chosen_proto != 0x01:
raise HandshakeAPIError(
f"Unknown protocol selected by client {chosen_proto}"
)
# Check name matches expected name (for noise sessions, this is done
# during hello phase before a connection is set up)
# Server name is encoded as a string followed by a zero byte after the chosen proto byte
server_name_i = server_hello.find(b"\0", 1)
if server_name_i != -1:
# server name found, this extension was added in 2022.2
server_name = server_hello[1:server_name_i].decode()
if expected_name is not None and expected_name != server_name:
raise BadNameAPIError(
f"Server sent a different name '{server_name}'", server_name
)
self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
self._proto.set_as_initiator()
self._proto.set_psks(_decode_noise_psk(self._noise_psk))

View File

@ -97,6 +97,9 @@ message HelloResponse {
// and only exists for debugging/logging purposes.
// For example "ESPHome v1.10.0 on ESP8266"
string server_info = 3;
// The name of the server (App.get_name())
string name = 4;
}
// Message sent at the beginning of each connection to authenticate the client

File diff suppressed because one or more lines are too long

View File

@ -137,7 +137,23 @@ class APIClient:
keepalive: float = 15.0,
zeroconf_instance: ZeroconfInstanceType = None,
noise_psk: Optional[str] = None,
expected_name: Optional[str] = None,
):
"""Create a client, this object is shared across sessions.
:param address: The address to connect to; for example an IP address
or .local name for mDNS lookup.
:param port: The port to connect to
:param password: Optional password to send to the device for authentication
:param client_info: User Agent string to send.
:param keepalive: The keepalive time in seconds (ping interval) for detecting stale connections.
Every keepalive seconds a ping is sent, if no pong is received the connection is closed.
:param zeroconf_instance: Pass a zeroconf instance to use if an mDNS lookup is necessary.
:param noise_psk: Encryption preshared key for noise transport encrypted sessions.
:param expected_name: Require the devices name to match the given expected name.
Can be used to prevent accidentally connecting to a different device if
IP passed as address but DHCP reassigned IP.
"""
self._params = ConnectionParams(
address=address,
port=port,
@ -147,10 +163,19 @@ class APIClient:
zeroconf_instance=zeroconf_instance,
# treat empty psk string as missing (like password)
noise_psk=noise_psk or None,
expected_name=expected_name,
)
self._connection: Optional[APIConnection] = None
self._cached_name: Optional[str] = None
@property
def expected_name(self) -> Optional[str]:
return self._params.expected_name
@expected_name.setter
def expected_name(self, value: Optional[str]) -> None:
self._params.expected_name = value
@property
def address(self) -> str:
return self._params.address

View File

@ -32,6 +32,7 @@ from .api_pb2 import ( # type: ignore
from .core import (
MESSAGE_TYPE_TO_PROTO,
APIConnectionError,
BadNameAPIError,
InvalidAuthAPIError,
PingFailedAPIError,
ProtocolAPIError,
@ -55,6 +56,7 @@ class ConnectionParams:
keepalive: float
zeroconf_instance: hr.ZeroconfInstanceType
noise_psk: Optional[str]
expected_name: Optional[str]
class ConnectionState(enum.Enum):
@ -174,7 +176,7 @@ class APIConnection:
fh = self._frame_helper = APINoiseFrameHelper(
reader, writer, self._params.noise_psk
)
await fh.perform_handshake()
await fh.perform_handshake(self._params.expected_name)
self._connection_state = ConnectionState.SOCKET_OPENED
@ -206,6 +208,15 @@ class APIConnection:
)
raise APIConnectionError("Incompatible API version.")
if (
self._params.expected_name is not None
and resp.name != ""
and resp.name != self._params.expected_name
):
raise BadNameAPIError(
f"Server sent a different name '{resp.name}'", resp.name
)
self._connection_state = ConnectionState.CONNECTED
async def _connect_start_ping(self) -> None:

View File

@ -96,6 +96,14 @@ class HandshakeAPIError(APIConnectionError):
pass
class BadNameAPIError(APIConnectionError):
"""Raised when a name received from the remote but does not much the expected name."""
def __init__(self, msg: str, received_name: str) -> None:
super().__init__(msg)
self.received_name = received_name
class InvalidEncryptionKeyAPIError(HandshakeAPIError):
pass

View File

@ -20,6 +20,7 @@ def connection_params() -> ConnectionParams:
keepalive=15.0,
zeroconf_instance=None,
noise_psk=None,
expected_name=None,
)