This commit is contained in:
Otto winter 2021-08-23 20:23:39 +02:00
parent 7cfc36cb70
commit 44041d2526
No known key found for this signature in database
GPG Key ID: 48ED2DDB96D7682C
5 changed files with 225 additions and 47 deletions

View File

@ -587,6 +587,10 @@ APIError APINoiseFrameHelper::shutdown(int how) {
} // namespace esphome
extern "C" {
// declare how noise generates random bytes (here with a good HWRNG based on the RF system)
void noise_rand_bytes(void *output, size_t len) { esp_fill_random(output, len); }
void noise_rand_bytes(void *output, size_t len) {
esphome::fill_random(reinterpret_cast<uint8_t *>(output), len);
}
}

View File

@ -11,6 +11,7 @@
#include <sys/types.h>
#include "lwip/inet.h"
#include <stdint.h>
#include <errno.h>
/* Address families. */
#define AF_UNSPEC 0

View File

@ -11,9 +11,13 @@
#include "lwip/netif.h"
#include "errno.h"
#include "esphome/core/log.h"
namespace esphome {
namespace socket {
static const char *const TAG = "lwip";
class LWIPRawImpl : public Socket {
public:
LWIPRawImpl(struct tcp_pcb *pcb) : pcb_(pcb) {}
@ -25,22 +29,35 @@ class LWIPRawImpl : public Socket {
}
void init() {
tcp_arg(pcb_, arg);
tcp_accept(pcb_, accept_fn);
tcp_recv(pcb_, recv_fn);
ESP_LOGD(TAG, "init()");
tcp_arg(pcb_, this);
tcp_accept(pcb_, LWIPRawImpl::s_accept_fn);
tcp_recv(pcb_, LWIPRawImpl::s_recv_fn);
tcp_err(pcb_, LWIPRawImpl::s_err_fn);
}
std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) override {
if (accepted_sockets_.empty())
if (pcb_ == nullptr) {
errno = EBADF;
return nullptr;
}
if (accepted_sockets_.empty()) {
errno = EWOULDBLOCK;
return nullptr;
}
std::unique_ptr<LWIPRawImpl> sock = std::move(accepted_sockets_.front());
accepted_sockets_.pop();
if (addr != nullptr) {
sock->getpeername(addr, addrlen);
}
return std::unique_ptr<Socket>(sock);
sock->init();
return std::unique_ptr<Socket>(std::move(sock));
}
int bind(const struct sockaddr *name, socklen_t addrlen) {
int bind(const struct sockaddr *name, socklen_t addrlen) override {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
if (name == nullptr) {
errno = EINVAL;
return 0;
@ -81,70 +98,197 @@ class LWIPRawImpl : public Socket {
port = ntohs(addr4->sin_port);
ip.addr = addr4->sin_addr.s_addr;
#endif
err_t err = tcp_bind(pcb_, &ip, port);
err_t err = tcp_bind(pcb_, IP4_ADDR_ANY, port);
ESP_LOGD(TAG, "bind(ip=%u, port=%u) -> %d", ip.addr, port, err);
if (err == ERR_USE) {
errno = EADDRINUSE;
return -1;
}
if (err == ERR_VAL || err != ERR_OK) {
if (err == ERR_VAL) {
errno = EINVAL;
return -1;
}
if (err != ERR_OK) {
errno = EIO;
return -1;
}
return 0;
}
int close() {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
err_t err = tcp_close(pcb_);
if (err != ERR_OK) {
tcp_abort(pcb_);
pcb_ = nullptr;
errno = err == ERR_MEM ? ENOMEM : EIO;
return -1;
}
pcb_ = nullptr;
return 0;
}
int connect(const std::string &address) override {
// TODO
return -1;
}
int connect(const std::string &address) {
int connect(const struct sockaddr *addr, socklen_t addrlen) override {
// TODO
return -1;
}
int connect(const struct sockaddr *addr, socklen_t addrlen) {
// TODO
return -1;
}
int shutdown(int how) {
// TODO
return -1;
int shutdown(int how) override {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
bool shut_rx = false, shut_tx = false;
if (how == SHUT_RD) {
shut_rx = true;
} else if (how == SHUT_WR) {
shut_tx = true;
} else if (how == SHUT_RDWR) {
shut_rx = shut_tx = true;
} else {
errno = EINVAL;
return -1;
}
err_t err = tcp_shutdown(pcb_, shut_rx, shut_tx);
if (err != ERR_OK) {
errno = err == ERR_MEM ? ENOMEM : EIO;
return -1;
}
return 0;
}
int getpeername(struct sockaddr *addr, socklen_t *addrlen) {
int getpeername(struct sockaddr *name, socklen_t *addrlen) override {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
if (name == nullptr || addrlen == nullptr) {
errno = EINVAL;
return -1;
}
if (*addrlen < sizeof(struct sockaddr_in)) {
errno = EINVAL;
return -1;
}
struct sockaddr_in *addr = reinterpret_cast<struct sockaddr_in *>(name);
addr->sin_family = AF_INET;
*addrlen = addr->sin_len = sizeof(struct sockaddr_in);
addr->sin_port = pcb_->remote_port;
addr->sin_addr.s_addr = pcb_->remote_ip.addr;
return 0;
}
std::string getpeername() override {
if (pcb_ == nullptr) {
errno = EBADF;
return "";
}
char buffer[24];
uint32_t ip4 = pcb_->remote_ip.addr;
snprintf(buffer, sizeof(buffer), "%d.%d.%d.%d", (ip4 >> 24) & 0xFF, (ip4 >> 16) & 0xFF, (ip4 >> 8) & 0xFF, (ip4 >> 0) & 0xFF);
return std::string(buffer);
}
int getsockname(struct sockaddr *name, socklen_t *addrlen) override {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
if (name == nullptr || addrlen == nullptr) {
errno = EINVAL;
return -1;
}
if (*addrlen < sizeof(struct sockaddr_in)) {
errno = EINVAL;
return -1;
}
struct sockaddr_in *addr = reinterpret_cast<struct sockaddr_in *>(name);
addr->sin_family = AF_INET;
*addrlen = addr->sin_len = sizeof(struct sockaddr_in);
addr->sin_port = pcb_->local_port;
addr->sin_addr.s_addr = pcb_->local_ip.addr;
return 0;
}
std::string getsockname() override {
if (pcb_ == nullptr) {
errno = EBADF;
return "";
}
char buffer[24];
uint32_t ip4 = pcb_->local_ip.addr;
snprintf(buffer, sizeof(buffer), "%d.%d.%d.%d", (ip4 >> 24) & 0xFF, (ip4 >> 16) & 0xFF, (ip4 >> 8) & 0xFF, (ip4 >> 0) & 0xFF);
return std::string(buffer);
}
int getsockopt(int level, int optname, void *optval, socklen_t *optlen) override {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
// TODO
return -1;
}
std::string getpeername() {
// TODO
return "";
}
int getsockname(struct sockaddr *addr, socklen_t *addrlen) {
// TODO
int setsockopt(int level, int optname, const void *optval, socklen_t optlen) override {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
if (level == SOL_SOCKET && optname == SO_REUSEADDR) {
if (optlen != 4) {
errno = EINVAL;
return -1;
}
// TODO
return 0;
}
if (level == IPPROTO_TCP && optname == TCP_NODELAY) {
if (optlen != 4) {
errno = EINVAL;
return -1;
}
int val = *reinterpret_cast<const int *>(optval);
if (val != 0) {
tcp_nagle_disable(pcb_);
} else {
tcp_nagle_enable(pcb_);
}
return 0;
}
errno = EINVAL;
return -1;
}
std::string getsockname() {
// TODO
return "";
}
int getsockopt(int level, int optname, void *optval, socklen_t *optlen) {
// TODO
return -1;
}
int setsockopt(int level, int optname, const void *optval, socklen_t optlen) {
// TODO
return -1;
}
int listen(int backlog) {
int listen(int backlog) override {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
struct tcp_pcb *listen_pcb = tcp_listen_with_backlog(pcb_, backlog);
ESP_LOGD(TAG, "listen(%d) -> %p", backlog, listen_pcb);
if (listen_pcb == nullptr) {
tcp_abort(pcb_);
pcb_ = nullptr;
errno = EOPNOTSUPP;
return -1;
}
// tcp_listen reallocates the pcb, replace ours
pcb_ = listen_pcb;
// set callbacks on new pcb
tcp_arg(pcb_, this);
tcp_accept(pcb_, LWIPRawImpl::s_accept_fn);
return 0;
}
ssize_t read(void *buf, size_t len) {
ssize_t read(void *buf, size_t len) override {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
if (rx_closed_ && rx_buf_ == nullptr) {
errno = ECONNRESET; // TODO: is this the right errno?
return -1;
}
if (len == 0) {
return 0;
}
@ -154,13 +298,14 @@ class LWIPRawImpl : public Socket {
}
size_t read = 0;
uint8_t *buf8 = reinterpret_cast<uint8_t *>(buf);
while (len) {
size_t pb_len = rx_buf_->len;
size_t pb_left = pb_len - rx_buf_offset_;
if (pb_left == 0)
break;
size_t copysize = std::min(len, pb_left);
memcpy(buf, rx_buf_->payload + rx_buf_offset_, copysize);
memcpy(buf8, reinterpret_cast<uint8_t *>(rx_buf_->payload) + rx_buf_offset_, copysize);
if (pb_left == copysize) {
// full pb copied, free it
@ -181,7 +326,7 @@ class LWIPRawImpl : public Socket {
}
tcp_recved(pcb_, copysize);
buf += copysize;
buf8 += copysize;
len -= copysize;
read += copysize;
}
@ -190,6 +335,10 @@ class LWIPRawImpl : public Socket {
}
// virtual ssize_t readv(const struct iovec *iov, int iovcnt) = 0;
ssize_t write(const void *buf, size_t len) {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
if (len == 0)
return 0;
if (buf == nullptr) {
@ -220,6 +369,10 @@ class LWIPRawImpl : public Socket {
}
// virtual ssize_t writev(const struct iovec *iov, int iovcnt) = 0;
int setblocking(bool blocking) {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
if (blocking) {
// blocking operation not supported
errno = EINVAL;
@ -231,15 +384,19 @@ class LWIPRawImpl : public Socket {
err_t accept_fn(struct tcp_pcb *newpcb, err_t err) {
// TODO: check err
accepted_sockets_.emplace(new LWIPRawImpl(newpcb));
ESP_LOGD(TAG, "accept_fn newpcb=%p err=%d", newpcb, err);
return ERR_OK;
}
void err_fn(err_t err) {
ESP_LOGD(TAG, "err_fn err=%d", err);
}
err_t recv_fn(struct pbuf *pb, err_t err) {
// TODO: check err
ESP_LOGD(TAG, "recv_fn pb=%p err=%d", pb, err);
if (pb == nullptr) {
// remote host has closed the connection
// TODO
rx_closed_ = true;
return ERR_OK;
}
if (rx_buf_ == nullptr) {
@ -254,30 +411,35 @@ class LWIPRawImpl : public Socket {
static err_t s_accept_fn(void *arg, struct tcp_pcb *newpcb, err_t err) {
LWIPRawImpl *arg_this = reinterpret_cast<LWIPRawImpl *>(arg);
return arg_this->accept_fn(newpcb, err);
}
static void s_err_fn(void *arg, err_t err) {
LWIPRawImpl *arg_this = reinterpret_cast<LWIPRawImpl *>(arg);
return arg_this->err_fn(err);
}
static err_t s_recv_fn(void *arg, struct tcp_pcb *pcb, struct pbuf *pb, err_t err) {
LWIPRawImpl *arg_this = reinterpret_cast<LWIPRawImpl *>(arg);
return arg_this->recv_fn(pb, err);
}
protected:
struct tcp_pcb *pcb_;
std::queue<std::unique_ptr<LWIPRawImpl>> accepted_sockets_;
bool rx_closed_ = false;
pbuf *rx_buf_ = nullptr;
size_t rx_buf_offset_ = 0;
};
std::unique_ptr<Socket> socket(int domain, int type, int protocol) {
auto *pcb = tcp_new();
/*if (ret == nullptr)
if (pcb == nullptr)
return nullptr;
return std::unique_ptr<Socket>{new LWIPRawImpl(ret)};*/
return nullptr;
auto *sock = new LWIPRawImpl(pcb);
sock->init();
return std::unique_ptr<Socket>{sock};
}
} // namespace socket

View File

@ -55,6 +55,15 @@ double random_double() { return random_uint32() / double(UINT32_MAX); }
float random_float() { return float(random_double()); }
void fill_random(uint8_t *data, size_t len) {
#ifdef ARDUINO_ARCH_ESP32
esp_fill_random(data, len);
#else
int err = os_get_random(data, len);
assert(err == 0);
#endif
}
static uint32_t fast_random_seed = 0; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
void fast_random_set_seed(uint32_t seed) { fast_random_seed = seed; }

View File

@ -109,6 +109,8 @@ double random_double();
/// Returns a random float between 0 and 1. Essentially just casts random_double() to a float.
float random_float();
void fill_random(uint8_t *data, size_t len);
void fast_random_set_seed(uint32_t seed);
uint32_t fast_random_32();
uint16_t fast_random_16();