diff --git a/esphome/components/api/__init__.py b/esphome/components/api/__init__.py index 248707469..f149c54ac 100644 --- a/esphome/components/api/__init__.py +++ b/esphome/components/api/__init__.py @@ -120,6 +120,9 @@ async def to_code(config): conf = config[CONF_ENCRYPTION] decoded = base64.b64decode(conf[CONF_KEY]) cg.add(var.set_noise_psk(list(decoded))) + cg.add_define("USE_API_NOISE") + else: + cg.add_define("USE_API_PLAINTEXT") cg.add_define("USE_API") cg.add_global(api_ns.using) diff --git a/esphome/components/api/api_connection.cpp b/esphome/components/api/api_connection.cpp index eb2372930..3c80f1c1d 100644 --- a/esphome/components/api/api_connection.cpp +++ b/esphome/components/api/api_connection.cpp @@ -25,7 +25,14 @@ APIConnection::APIConnection(std::unique_ptr sock, APIServer *pa list_entities_iterator_(parent, this) { this->proto_write_buffer_.reserve(64); +#ifdef USE_API_NOISE helper_ = std::unique_ptr{new APINoiseFrameHelper(std::move(sock), parent->get_noise_ctx())}; +#elif defined(USE_API_PLAINTEXT) + helper_ = std::unique_ptr{new APIPlaintextFrameHelper(std::move(sock))}; +#else +#error "No api frame helper enabled" +#endif + } void APIConnection::start() { this->last_traffic_ = millis(); diff --git a/esphome/components/api/api_frame_helper.cpp b/esphome/components/api/api_frame_helper.cpp index e8055a7be..3463bac40 100644 --- a/esphome/components/api/api_frame_helper.cpp +++ b/esphome/components/api/api_frame_helper.cpp @@ -2,16 +2,12 @@ #include "esphome/core/log.h" #include "esphome/core/helpers.h" +#include "proto.h" namespace esphome { namespace api { static const char *const TAG = "api.socket"; -static const char *const PROLOGUE_INIT = "NoiseAPIInit"; - -// TODO: -// - track errors internally and return if in bad state -// - send error on invalid psk /// Is the given return value (from read/write syscalls) a wouldblock error? bool is_would_block(ssize_t ret) { @@ -21,7 +17,10 @@ bool is_would_block(ssize_t ret) { return ret == 0; } -#define HELPER_LOG(msg, ...) ESP_LOGW(TAG, "%s: " msg, info_.c_str(), ##__VA_ARGS__) +#define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, info_.c_str(), ##__VA_ARGS__) + +#ifdef USE_API_NOISE +static const char *const PROLOGUE_INIT = "NoiseAPIInit"; /// Convert a noise error code to a readable error std::string noise_err_to_str(int err) { @@ -93,7 +92,15 @@ APIError APINoiseFrameHelper::loop() { APIError err = state_action_(); if (err == APIError::WOULD_BLOCK) return APIError::OK; - return err; + if (err != APIError::OK) + return err; + if (!tx_buf_.empty()) { + err = try_send_tx_buf_(); + if (err != APIError::OK) { + return err; + } + } + return APIError::OK; } /** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter @@ -180,7 +187,8 @@ APIError APINoiseFrameHelper::try_read_frame_(ParsedFrame *frame) { } } - // ESP_LOGD(TAG, "Received frame: %s", hexencode(rx_buf_).c_str()); + // uncomment for even more debugging + // ESP_LOGVV(TAG, "Received frame: %s", hexencode(rx_buf_).c_str()); frame->msg = std::move(rx_buf_); // consume msg rx_buf_ = {}; @@ -239,17 +247,33 @@ APIError APINoiseFrameHelper::state_action_() { // waiting for handshake msg ParsedFrame frame; aerr = try_read_frame_(&frame); + if (aerr == APIError::BAD_INDICATOR) { + send_explicit_handshake_reject_("Bad indicator byte"); + return aerr; + } + if (frame.msg.size() < 1 || frame.msg[0] != 0x00) { + aerr = APIError::BAD_HANDSHAKE_PACKET_LEN; + } + if (aerr == APIError::BAD_HANDSHAKE_PACKET_LEN) { + send_explicit_handshake_reject_("Bad handshake packet len"); + return aerr; + } if (aerr != APIError::OK) return aerr; NoiseBuffer mbuf; noise_buffer_init(mbuf); - noise_buffer_set_input(mbuf, frame.msg.data(), frame.msg.size()); + noise_buffer_set_input(mbuf, frame.msg.data() + 1, frame.msg.size() - 1); err = noise_handshakestate_read_message(handshake_, &mbuf, nullptr); if (err != 0) { // TODO: explicit rejection state_ = State::FAILED; HELPER_LOG("noise_handshakestate_read_message failed: %s", noise_err_to_str(err).c_str()); + if (err == NOISE_ERROR_MAC_FAILURE) { + send_explicit_handshake_reject_("Handshake MAC failure"); + } else { + send_explicit_handshake_reject_("Handshake error"); + } return APIError::HANDSHAKESTATE_READ_FAILED; } @@ -257,10 +281,10 @@ APIError APINoiseFrameHelper::state_action_() { if (aerr != APIError::OK) return aerr; } else if (action == NOISE_ACTION_WRITE_MESSAGE) { - uint8_t buffer[64]; + uint8_t buffer[65]; NoiseBuffer mbuf; noise_buffer_init(mbuf); - noise_buffer_set_output(mbuf, buffer, sizeof(buffer)); + noise_buffer_set_output(mbuf, buffer + 1, sizeof(buffer) - 1); err = noise_handshakestate_write_message(handshake_, &mbuf, nullptr); if (err != 0) { @@ -268,7 +292,9 @@ APIError APINoiseFrameHelper::state_action_() { HELPER_LOG("noise_handshakestate_write_message failed: %s", noise_err_to_str(err).c_str()); return APIError::HANDSHAKESTATE_WRITE_FAILED; } - aerr = write_frame_(mbuf.data, mbuf.size); + buffer[0] = 0x00; // success + + aerr = write_frame_(buffer, mbuf.size + 1); if (aerr != APIError::OK) return aerr; aerr = check_handshake_finished_(); @@ -286,6 +312,15 @@ APIError APINoiseFrameHelper::state_action_() { } return APIError::OK; } +void APINoiseFrameHelper::send_explicit_handshake_reject_(const std::string &reason) { + std::vector data; + data.reserve(reason.size() + 1); + data[0] = 0x01; // failure + for (size_t i = 0; i < reason.size(); i++) { + data[i+1] = (uint8_t) reason[i]; + } + write_frame_(data.data(), data.size()); +} APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) { int err; @@ -397,7 +432,7 @@ APIError APINoiseFrameHelper::write_packet(uint16_t type, const uint8_t *payload } return APIError::OK; } -APIError APINoiseFrameHelper::try_send_raw_() { +APIError APINoiseFrameHelper::try_send_tx_buf_() { // try send from tx_buf while (state_ != State::CLOSED && !tx_buf_.empty()) { ssize_t sent = socket_->write(tx_buf_.data(), tx_buf_.size()); @@ -428,11 +463,12 @@ APIError APINoiseFrameHelper::write_raw_(const uint8_t *data, size_t len) { int err; APIError aerr; - // ESP_LOGD(TAG, "Sending raw: %s", hexencode(data, len).c_str()); + // uncomment for even more debugging + // ESP_LOGVV(TAG, "Sending raw: %s", hexencode(data, len).c_str()); if (!tx_buf_.empty()) { // try to empty tx_buf_ first - aerr = try_send_raw_(); + aerr = try_send_tx_buf_(); if (aerr != APIError::OK && aerr != APIError::WOULD_BLOCK) return aerr; } @@ -582,15 +618,290 @@ APIError APINoiseFrameHelper::shutdown(int how) { } return APIError::OK; } - -} // namespace api -} // namespace esphome - extern "C" { - // declare how noise generates random bytes (here with a good HWRNG based on the RF system) void noise_rand_bytes(void *output, size_t len) { esphome::fill_random(reinterpret_cast(output), len); } - } +#endif // USE_API_NOISE + + +#ifdef USE_API_PLAINTEXT + +/// Initialize the frame helper, returns OK if successful. +APIError APIPlaintextFrameHelper::init() { + if (state_ != State::INITIALIZE || socket_ == nullptr) { + HELPER_LOG("Bad state for init %d", (int) state_); + return APIError::BAD_STATE; + } + int err = socket_->setblocking(false); + if (err != 0) { + state_ = State::FAILED; + HELPER_LOG("Setting nonblocking failed with errno %d", errno); + return APIError::TCP_NONBLOCKING_FAILED; + } + int enable = 1; + err = socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int)); + if (err != 0) { + state_ = State::FAILED; + HELPER_LOG("Setting nodelay failed with errno %d", errno); + return APIError::TCP_NODELAY_FAILED; + } + + state_ = State::DATA; + return APIError::OK; +} +/// Not used for plaintext +APIError APIPlaintextFrameHelper::loop() { + if (state_ != State::DATA) { + return APIError::BAD_STATE; + } + // try send pending TX data + if (!tx_buf_.empty()) { + APIError err = try_send_tx_buf_(); + if (err != APIError::OK) { + return err; + } + } + return APIError::OK; +} + +/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter + * + * @param frame: The struct to hold the frame information in. + * msg: store the parsed frame in that struct + * + * @return See APIError + * + * error API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame. + */ +APIError APIPlaintextFrameHelper::try_read_frame_(ParsedFrame *frame) { + int err; + APIError aerr; + + if (frame == nullptr) { + HELPER_LOG("Bad argument for try_read_frame_"); + return APIError::BAD_ARG; + } + + // read header + while (!rx_header_parsed_) { + uint8_t data; + ssize_t received = socket_->read(&data, 1); + if (is_would_block(received)) { + return APIError::WOULD_BLOCK; + } else if (received == -1) { + state_ = State::FAILED; + HELPER_LOG("Socket read failed with errno %d", errno); + return APIError::SOCKET_READ_FAILED; + } + rx_header_buf_.push_back(data); + + // try parse header + if (rx_header_buf_[0] != 0x00) { + state_ = State::FAILED; + HELPER_LOG("Bad indicator byte %u", rx_header_buf_[0]); + return APIError::BAD_INDICATOR; + } + + size_t i = 1; + size_t consumed = 0; + auto msg_size_varint = ProtoVarInt::parse(&rx_header_buf_[i], rx_header_buf_.size() - i, &consumed); + if (!msg_size_varint.has_value()) { + // not enough data there yet + continue; + } + + i += consumed; + rx_header_parsed_len_ = msg_size_varint->as_uint32(); + + auto msg_type_varint = ProtoVarInt::parse(&rx_header_buf_[i], rx_header_buf_.size() - i, &consumed); + if (!msg_type_varint.has_value()) { + // not enough data there yet + continue; + } + rx_header_parsed_type_ = msg_type_varint->as_uint32(); + rx_header_parsed_ = true; + } + // header reading done + + // reserve space for body + if (rx_buf_.size() != rx_header_parsed_len_) { + rx_buf_.resize(rx_header_parsed_len_); + } + + if (rx_buf_len_ < rx_header_parsed_len_) { + // more data to read + size_t to_read = rx_header_parsed_len_ - rx_buf_len_; + ssize_t received = socket_->read(&rx_buf_[rx_buf_len_], to_read); + if (is_would_block(received)) { + return APIError::WOULD_BLOCK; + } else if (received == -1) { + state_ = State::FAILED; + HELPER_LOG("Socket read failed with errno %d", errno); + return APIError::SOCKET_READ_FAILED; + } + rx_buf_len_ += received; + if (received != to_read) { + // not all read + return APIError::WOULD_BLOCK; + } + } + + // uncomment for even more debugging + // ESP_LOGVV(TAG, "Received frame: %s", hexencode(rx_buf_).c_str()); + frame->msg = std::move(rx_buf_); + // consume msg + rx_buf_ = {}; + rx_buf_len_ = 0; + rx_header_buf_.clear(); + rx_header_parsed_ = false; + return APIError::OK; +} + +APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) { + int err; + APIError aerr; + + if (state_ != State::DATA) { + return APIError::WOULD_BLOCK; + } + + ParsedFrame frame; + aerr = try_read_frame_(&frame); + if (aerr != APIError::OK) + return aerr; + + buffer->container = std::move(frame.msg); + buffer->data_offset = 0; + buffer->data_len = rx_header_parsed_len_; + buffer->type = rx_header_parsed_type_; + return APIError::OK; +} +bool APIPlaintextFrameHelper::can_write_without_blocking() { + return state_ == State::DATA && tx_buf_.empty(); +} +APIError APIPlaintextFrameHelper::write_packet(uint16_t type, const uint8_t *payload, size_t payload_len) { + int err; + APIError aerr; + + if (state_ != State::DATA) { + return APIError::BAD_STATE; + } + + std::vector header; + header.push_back(0x00); + ProtoVarInt(payload_len).encode(header); + ProtoVarInt(type).encode(header); + + aerr = write_raw_(&header[0], header.size()); + if (aerr != APIError::OK) { + return aerr; + } + aerr = write_raw_(payload, payload_len); + if (aerr != APIError::OK) { + return aerr; + } + return APIError::OK; +} +APIError APIPlaintextFrameHelper::try_send_tx_buf_() { + // try send from tx_buf + while (state_ != State::CLOSED && !tx_buf_.empty()) { + ssize_t sent = socket_->write(tx_buf_.data(), tx_buf_.size()); + if (sent == -1) { + if (errno == EWOULDBLOCK || errno == EAGAIN) + break; + state_ = State::FAILED; + HELPER_LOG("Socket write failed with errno %d", errno); + return APIError::SOCKET_WRITE_FAILED; + } else if (sent == 0) { + break; + } + // TODO: inefficient if multiple packets in txbuf + // replace with deque of buffers + tx_buf_.erase(tx_buf_.begin(), tx_buf_.begin() + sent); + } + + return APIError::OK; +} +/** Write the data to the socket, or buffer it a write would block + * + * @param data The data to write + * @param len The length of data + */ +APIError APIPlaintextFrameHelper::write_raw_(const uint8_t *data, size_t len) { + if (len == 0) + return APIError::OK; + int err; + APIError aerr; + + // uncomment for even more debugging + // ESP_LOGVV(TAG, "Sending raw: %s", hexencode(data, len).c_str()); + + if (!tx_buf_.empty()) { + // try to empty tx_buf_ first + aerr = try_send_tx_buf_(); + if (aerr != APIError::OK && aerr != APIError::WOULD_BLOCK) + return aerr; + } + + if (!tx_buf_.empty()) { + // tx buf not empty, can't write now because then stream would be inconsistent + tx_buf_.insert(tx_buf_.end(), data, data + len); + return APIError::OK; + } + + ssize_t sent = socket_->write(data, len); + if (is_would_block(sent)) { + // operation would block, add buffer to tx_buf + tx_buf_.insert(tx_buf_.end(), data, data + len); + return APIError::OK; + } else if (sent == -1) { + // an error occured + state_ = State::FAILED; + HELPER_LOG("Socket write failed with errno %d", errno); + return APIError::SOCKET_WRITE_FAILED; + } else if (sent != len) { + // partially sent, add end to tx_buf + tx_buf_.insert(tx_buf_.end(), data + sent, data + len); + return APIError::OK; + } + // fully sent + return APIError::OK; +} +APIError APIPlaintextFrameHelper::write_frame_(const uint8_t *data, size_t len) { + APIError aerr; + + uint8_t header[3]; + header[0] = 0x01; // indicator + header[1] = (uint8_t) (len >> 8); + header[2] = (uint8_t) len; + + aerr = write_raw_(header, 3); + if (aerr != APIError::OK) + return aerr; + aerr = write_raw_(data, len); + return aerr; +} + +APIError APIPlaintextFrameHelper::close() { + state_ = State::CLOSED; + int err = socket_->close(); + if (err == -1) + return APIError::CLOSE_FAILED; + return APIError::OK; +} +APIError APIPlaintextFrameHelper::shutdown(int how) { + int err = socket_->shutdown(how); + if (err == -1) + return APIError::SHUTDOWN_FAILED; + if (how == SHUT_RDWR) { + state_ = State::CLOSED; + } + return APIError::OK; +} +#endif // USE_API_PLAINTEXT + +} // namespace api +} // namespace esphome diff --git a/esphome/components/api/api_frame_helper.h b/esphome/components/api/api_frame_helper.h index f36b9a89b..79ded93e8 100644 --- a/esphome/components/api/api_frame_helper.h +++ b/esphome/components/api/api_frame_helper.h @@ -3,7 +3,11 @@ #include #include +#include "esphome/core/defines.h" + +#ifdef USE_API_NOISE #include "noise/protocol.h" +#endif #include "esphome/components/socket/socket.h" #include "api_noise_context.h" @@ -63,39 +67,39 @@ class APIFrameHelper { virtual void set_log_info(std::string info) = 0; }; +#ifdef USE_API_NOISE class APINoiseFrameHelper : public APIFrameHelper { public: APINoiseFrameHelper(std::unique_ptr socket, std::shared_ptr ctx) : socket_(std::move(socket)), ctx_(ctx) {} ~APINoiseFrameHelper(); - APIError init(); - APIError loop(); - APIError read_packet(ReadPacketBuffer *buffer); - bool can_write_without_blocking(); - APIError write_packet(uint16_t type, const uint8_t *data, size_t len); - std::string getpeername() { + APIError init() override; + APIError loop() override; + APIError read_packet(ReadPacketBuffer *buffer) override; + bool can_write_without_blocking() override; + APIError write_packet(uint16_t type, const uint8_t *data, size_t len) override; + std::string getpeername() override{ return socket_->getpeername(); } - APIError close(); - APIError shutdown(int how); + APIError close() override; + APIError shutdown(int how) override; // Give this helper a name for logging - void set_log_info(std::string info) { + void set_log_info(std::string info) override { info_ = std::move(info); } protected: - APIError reserve_rx_buf_(size_t new_capacity); - struct ParsedFrame { std::vector msg; }; APIError state_action_(); APIError try_read_frame_(ParsedFrame *frame); - APIError try_send_raw_(); + APIError try_send_tx_buf_(); APIError write_frame_(const uint8_t *data, size_t len); APIError write_raw_(const uint8_t *data, size_t len); APIError init_handshake_(); APIError check_handshake_finished_(); + void send_explicit_handshake_reject_(const std::string &reason); std::unique_ptr socket_; @@ -124,6 +128,59 @@ class APINoiseFrameHelper : public APIFrameHelper { FAILED = 7, } state_ = State::INITIALIZE; }; +#endif // USE_API_NOISE + +#ifdef USE_API_PLAINTEXT +class APIPlaintextFrameHelper : public APIFrameHelper { + public: + APIPlaintextFrameHelper(std::unique_ptr socket) : socket_(std::move(socket)) {} + ~APIPlaintextFrameHelper() = default; + APIError init() override; + APIError loop() override; + APIError read_packet(ReadPacketBuffer *buffer) override; + bool can_write_without_blocking() override; + APIError write_packet(uint16_t type, const uint8_t *data, size_t len) override; + std::string getpeername() override { + return socket_->getpeername(); + } + APIError close() override; + APIError shutdown(int how) override; + // Give this helper a name for logging + void set_log_info(std::string info) override { + info_ = std::move(info); + } + + protected: + struct ParsedFrame { + std::vector msg; + }; + + APIError try_read_frame_(ParsedFrame *frame); + APIError try_send_tx_buf_(); + APIError write_frame_(const uint8_t *data, size_t len); + APIError write_raw_(const uint8_t *data, size_t len); + + std::unique_ptr socket_; + + std::string info_; + std::vector rx_header_buf_; + bool rx_header_parsed_ = false; + uint32_t rx_header_parsed_type_ = 0; + uint32_t rx_header_parsed_len_ = 0; + + std::vector rx_buf_; + size_t rx_buf_len_ = 0; + + std::vector tx_buf_; + + enum class State { + INITIALIZE = 1, + DATA = 2, + CLOSED = 3, + FAILED = 4, + } state_ = State::INITIALIZE; +}; +#endif } // namespace api } // namespace esphome diff --git a/esphome/components/api/api_noise_context.h b/esphome/components/api/api_noise_context.h index 055470a84..db4f885d8 100644 --- a/esphome/components/api/api_noise_context.h +++ b/esphome/components/api/api_noise_context.h @@ -1,10 +1,12 @@ #pragma once #include #include +#include "esphome/core/defines.h" namespace esphome { namespace api { +#ifdef USE_API_NOISE using psk_t = std::array; class APINoiseContext { @@ -19,6 +21,7 @@ class APINoiseContext { protected: psk_t psk_; }; +#endif // USE_API_NOISE } // namespace api } // namespace esphome diff --git a/esphome/components/api/api_server.h b/esphome/components/api/api_server.h index 61d6f873b..624674f8f 100644 --- a/esphome/components/api/api_server.h +++ b/esphome/components/api/api_server.h @@ -31,12 +31,14 @@ class APIServer : public Component, public Controller { void set_password(const std::string &password); void set_reboot_timeout(uint32_t reboot_timeout); +#ifdef USE_API_NOISE void set_noise_psk(psk_t psk) { noise_ctx_->set_psk(std::move(psk)); } std::shared_ptr get_noise_ctx() { return noise_ctx_; } +#endif // USE_API_NOISE void handle_disconnect(APIConnection *conn); #ifdef USE_BINARY_SENSOR @@ -97,7 +99,10 @@ class APIServer : public Component, public Controller { std::string password_; std::vector state_subs_; std::vector user_services_; + +#ifdef USE_API_NOISE std::shared_ptr noise_ctx_ = std::make_shared(); +#endif // USE_API_NOISE }; extern APIServer *global_api_server; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) diff --git a/esphome/components/socket/__init__.py b/esphome/components/socket/__init__.py index e72fe12f5..8e9502be6 100644 --- a/esphome/components/socket/__init__.py +++ b/esphome/components/socket/__init__.py @@ -1,7 +1,6 @@ import esphome.config_validation as cv import esphome.codegen as cg -# Dummy package to allow components to depend on network CODEOWNERS = ["@esphome/core"] CONF_IMPLEMENTATION = "implementation" diff --git a/esphome/components/socket/bsd_sockets_impl.cpp b/esphome/components/socket/bsd_sockets_impl.cpp index e1431a385..a11c8eaae 100644 --- a/esphome/components/socket/bsd_sockets_impl.cpp +++ b/esphome/components/socket/bsd_sockets_impl.cpp @@ -47,11 +47,6 @@ class BSDSocketImpl : public Socket { closed_ = true; return ret; } - int connect(const std::string &address) override { - // TODO - return 0; - } - int connect(const struct sockaddr *addr, socklen_t addrlen) override { return ::connect(fd_, addr, addrlen); } int shutdown(int how) override { return ::shutdown(fd_, how); } int getpeername(struct sockaddr *addr, socklen_t *addrlen) override { return ::getpeername(fd_, addr, addrlen); } @@ -79,9 +74,7 @@ class BSDSocketImpl : public Socket { return ::setsockopt(fd_, level, optname, optval, optlen); } int listen(int backlog) override { return ::listen(fd_, backlog); } - // virtual ssize_t readv(const struct iovec *iov, int iovcnt) = 0; ssize_t read(void *buf, size_t len) override { return ::read(fd_, buf, len); } - // virtual ssize_t writev(const struct iovec *iov, int iovcnt) = 0; ssize_t write(const void *buf, size_t len) override { return ::write(fd_, buf, len); } int setblocking(bool blocking) override { int fl = ::fcntl(fd_, F_GETFL, 0); diff --git a/esphome/components/socket/lwip_raw_tcp_impl.cpp b/esphome/components/socket/lwip_raw_tcp_impl.cpp index 40bba21c9..ee65a6d1a 100644 --- a/esphome/components/socket/lwip_raw_tcp_impl.cpp +++ b/esphome/components/socket/lwip_raw_tcp_impl.cpp @@ -29,7 +29,6 @@ class LWIPRawImpl : public Socket { } void init() { - ESP_LOGD(TAG, "init()"); tcp_arg(pcb_, this); tcp_accept(pcb_, LWIPRawImpl::s_accept_fn); tcp_recv(pcb_, LWIPRawImpl::s_recv_fn); @@ -98,8 +97,7 @@ class LWIPRawImpl : public Socket { port = ntohs(addr4->sin_port); ip.addr = addr4->sin_addr.s_addr; #endif - err_t err = tcp_bind(pcb_, IP4_ADDR_ANY, port); - ESP_LOGD(TAG, "bind(ip=%u, port=%u) -> %d", ip.addr, port, err); + err_t err = tcp_bind(pcb_, &ip, port); if (err == ERR_USE) { errno = EADDRINUSE; return -1; @@ -129,14 +127,6 @@ class LWIPRawImpl : public Socket { pcb_ = nullptr; return 0; } - int connect(const std::string &address) override { - // TODO - return -1; - } - int connect(const struct sockaddr *addr, socklen_t addrlen) override { - // TODO - return -1; - } int shutdown(int how) override { if (pcb_ == nullptr) { errno = EBADF; @@ -226,7 +216,29 @@ class LWIPRawImpl : public Socket { errno = EBADF; return -1; } - // TODO + if (level == SOL_SOCKET && optname == SO_REUSEADDR) { + if (optlen < 4) { + errno = EINVAL; + return -1; + } + + // lwip doesn't seem to have this feature. Don't send an error + // to prevent warnings + *reinterpret_cast(optval) = 1; + *optlen = 4; + return 0; + } + if (level == IPPROTO_TCP && optname == TCP_NODELAY) { + if (optlen < 4) { + errno = EINVAL; + return -1; + } + *reinterpret_cast(optval) = tcp_nagle_disabled(pcb_); + *optlen = 4; + return 0; + } + + errno = EINVAL; return -1; } int setsockopt(int level, int optname, const void *optval, socklen_t optlen) override { @@ -240,7 +252,8 @@ class LWIPRawImpl : public Socket { return -1; } - // TODO + // lwip doesn't seem to have this feature. Don't send an error + // to prevent warnings return 0; } if (level == IPPROTO_TCP && optname == TCP_NODELAY) { @@ -266,7 +279,6 @@ class LWIPRawImpl : public Socket { return -1; } struct tcp_pcb *listen_pcb = tcp_listen_with_backlog(pcb_, backlog); - ESP_LOGD(TAG, "listen(%d) -> %p", backlog, listen_pcb); if (listen_pcb == nullptr) { tcp_abort(pcb_); pcb_ = nullptr; @@ -286,7 +298,7 @@ class LWIPRawImpl : public Socket { return -1; } if (rx_closed_ && rx_buf_ == nullptr) { - errno = ECONNRESET; // TODO: is this the right errno? + errno = ECONNRESET; return -1; } if (len == 0) { @@ -333,7 +345,6 @@ class LWIPRawImpl : public Socket { return read; } - // virtual ssize_t readv(const struct iovec *iov, int iovcnt) = 0; ssize_t write(const void *buf, size_t len) { if (pcb_ == nullptr) { errno = EBADF; @@ -367,7 +378,6 @@ class LWIPRawImpl : public Socket { } return to_send; } - // virtual ssize_t writev(const struct iovec *iov, int iovcnt) = 0; int setblocking(bool blocking) { if (pcb_ == nullptr) { errno = EBADF; @@ -382,20 +392,32 @@ class LWIPRawImpl : public Socket { } err_t accept_fn(struct tcp_pcb *newpcb, err_t err) { - // TODO: check err + if (err != ERR_OK || newpcb == 0) { + // "An error code if there has been an error accepting. Only return ERR_ABRT if you have + // called tcp_abort from within the callback function!" + // https://www.nongnu.org/lwip/2_1_x/tcp_8h.html#a00517abce6856d6c82f0efebdafb734d + // nothing to do here, we just don't push it to the queue + return ERR_OK; + } accepted_sockets_.emplace(new LWIPRawImpl(newpcb)); - ESP_LOGD(TAG, "accept_fn newpcb=%p err=%d", newpcb, err); return ERR_OK; } void err_fn(err_t err) { - ESP_LOGD(TAG, "err_fn err=%d", err); + // "If a connection is aborted because of an error, the application is alerted of this event by + // the err callback." + // pcb is already freed when this callback is called + // ERR_RST: connection was reset by remote host + // ERR_ABRT: aborted through tcp_abort or TCP timer + pcb_ = nullptr; } err_t recv_fn(struct pbuf *pb, err_t err) { - // TODO: check err - ESP_LOGD(TAG, "recv_fn pb=%p err=%d", pb, err); + if (err != 0) { + // "An error code if there has been an error receiving Only return ERR_ABRT if you have + // called tcp_abort from within the callback function!" + rx_closed_ = true; + return ERR_OK; + } if (pb == nullptr) { - // remote host has closed the connection - // TODO rx_closed_ = true; return ERR_OK; } diff --git a/esphome/components/socket/socket.h b/esphome/components/socket/socket.h index 721f1225a..7a5ce7916 100644 --- a/esphome/components/socket/socket.h +++ b/esphome/components/socket/socket.h @@ -18,8 +18,9 @@ class Socket { virtual std::unique_ptr accept(struct sockaddr *addr, socklen_t *addrlen) = 0; virtual int bind(const struct sockaddr *addr, socklen_t addrlen) = 0; virtual int close() = 0; - virtual int connect(const std::string &address) = 0; - virtual int connect(const struct sockaddr *addr, socklen_t addrlen) = 0; + // not supported yet: + // virtual int connect(const std::string &address) = 0; + // virtual int connect(const struct sockaddr *addr, socklen_t addrlen) = 0; virtual int shutdown(int how) = 0; virtual int getpeername(struct sockaddr *addr, socklen_t *addrlen) = 0; @@ -30,9 +31,7 @@ class Socket { virtual int setsockopt(int level, int optname, const void *optval, socklen_t optlen) = 0; virtual int listen(int backlog) = 0; virtual ssize_t read(void *buf, size_t len) = 0; - // virtual ssize_t readv(const struct iovec *iov, int iovcnt) = 0; virtual ssize_t write(const void *buf, size_t len) = 0; - // virtual ssize_t writev(const struct iovec *iov, int iovcnt) = 0; virtual int setblocking(bool blocking) = 0; virtual int loop() { return 0; }; }; diff --git a/esphome/components/ssl/mbedtls_impl.cpp b/esphome/components/ssl/mbedtls_impl.cpp index f40fc82ea..1ebfa02cc 100644 --- a/esphome/components/ssl/mbedtls_impl.cpp +++ b/esphome/components/ssl/mbedtls_impl.cpp @@ -130,7 +130,6 @@ class MbedTLSWrappedSocket : public socket::Socket { int ret = mbedtls_ssl_read(&ssl_, reinterpret_cast(buf), len); return this->mbedtls_to_errno_(ret); } - // virtual ssize_t readv(const struct iovec *iov, int iovcnt) = 0; ssize_t write(const void *buf, size_t len) override { loop(); if (do_handshake_) { @@ -140,7 +139,6 @@ class MbedTLSWrappedSocket : public socket::Socket { int ret = mbedtls_ssl_write(&ssl_, reinterpret_cast(buf), len); return this->mbedtls_to_errno_(ret); } - // virtual ssize_t writev(const struct iovec *iov, int iovcnt) = 0; int setblocking(bool blocking) override { // TODO: handle blocking modes return sock_->setblocking(blocking); diff --git a/esphome/core/defines.h b/esphome/core/defines.h index 5a056f494..8e0e814a4 100644 --- a/esphome/core/defines.h +++ b/esphome/core/defines.h @@ -31,3 +31,5 @@ #define USE_MDNS #define USE_SOCKET_IMPL_LWIP_TCP #define USE_SOCKET_IMPL_BSD_SOCKETS +#define USE_API_NOISE +#define USE_API_PLAINTEXT