diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index 08ecf954..c3d26ac8 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -44,6 +44,8 @@ file(GLOB common_SRC "src/common/network/*.cpp" "src/common/network/rest/*.cpp" "src/common/network/rest/http/*.cpp" + "src/common/network/tcp/*.cpp" + "src/common/network/udp/*.cpp" "src/common/yaml/*.cpp" "src/common/*.cpp" ) @@ -85,6 +87,7 @@ file(GLOB common_INCLUDE "src/common/network/rest/*.h" "src/common/network/rest/http/*.h" "src/common/network/tcp/*.h" + "src/common/network/udp/*.h" "src/common/yaml/*.h" "src/common/*.h" ) diff --git a/src/common/network/BaseNetwork.cpp b/src/common/network/BaseNetwork.cpp index 0e3d3224..64f03c4f 100644 --- a/src/common/network/BaseNetwork.cpp +++ b/src/common/network/BaseNetwork.cpp @@ -65,7 +65,7 @@ BaseNetwork::BaseNetwork(uint32_t peerId, bool duplex, bool debug, bool slot1, b { assert(peerId < 999999999U); - m_socket = new UDPSocket(localPort); + m_socket = new udp::Socket(localPort); m_frameQueue = new FrameQueue(m_socket, peerId, debug); std::random_device rd; diff --git a/src/common/network/BaseNetwork.h b/src/common/network/BaseNetwork.h index 1f4f3500..96cb47de 100644 --- a/src/common/network/BaseNetwork.h +++ b/src/common/network/BaseNetwork.h @@ -23,7 +23,7 @@ #include "common/p25/Audio.h" #include "common/nxdn/lc/RTCH.h" #include "common/network/FrameQueue.h" -#include "common/network/UDPSocket.h" +#include "common/network/udp/Socket.h" #include "common/RingBuffer.h" #include "common/Utils.h" @@ -253,7 +253,7 @@ namespace network bool m_debug; - UDPSocket* m_socket; + udp::Socket* m_socket; FrameQueue* m_frameQueue; RingBuffer m_rxDMRData; diff --git a/src/common/network/FrameQueue.cpp b/src/common/network/FrameQueue.cpp index a0c8558a..9fc78782 100644 --- a/src/common/network/FrameQueue.cpp +++ b/src/common/network/FrameQueue.cpp @@ -35,7 +35,7 @@ using namespace network::frame; /// /// Local port used to listen for incoming data. /// Unique ID of this modem on the network. -FrameQueue::FrameQueue(UDPSocket* socket, uint32_t peerId, bool debug) : RawFrameQueue(socket, debug), +FrameQueue::FrameQueue(udp::Socket* socket, uint32_t peerId, bool debug) : RawFrameQueue(socket, debug), m_peerId(peerId), m_streamTimestamps() { @@ -227,7 +227,7 @@ void FrameQueue::enqueueMessage(const uint8_t* message, uint32_t length, uint32_ if (m_debug) Utils::dump(1U, "FrameQueue::enqueueMessage() Buffered Message", buffer, bufferLen); - UDPDatagram *dgram = new UDPDatagram; + udp::UDPDatagram *dgram = new udp::UDPDatagram; dgram->buffer = buffer; dgram->length = bufferLen; dgram->address = addr; diff --git a/src/common/network/FrameQueue.h b/src/common/network/FrameQueue.h index e2f5ceb8..ea7fa34c 100644 --- a/src/common/network/FrameQueue.h +++ b/src/common/network/FrameQueue.h @@ -36,8 +36,12 @@ namespace network class HOST_SW_API FrameQueue : public RawFrameQueue { public: typedef std::pair OpcodePair; public: + auto operator=(FrameQueue&) -> FrameQueue& = delete; + auto operator=(FrameQueue&&) -> FrameQueue& = delete; + FrameQueue(FrameQueue&) = delete; + /// Initializes a new instance of the FrameQueue class. - FrameQueue(UDPSocket* socket, uint32_t peerId, bool debug); + FrameQueue(udp::Socket* socket, uint32_t peerId, bool debug); /// Read message from the received UDP packet. UInt8Array read(int& messageLength, sockaddr_storage& address, uint32_t& addrLen, diff --git a/src/common/network/RawFrameQueue.cpp b/src/common/network/RawFrameQueue.cpp index 91c376ab..5757d25b 100644 --- a/src/common/network/RawFrameQueue.cpp +++ b/src/common/network/RawFrameQueue.cpp @@ -12,7 +12,7 @@ */ #include "Defines.h" #include "network/RawFrameQueue.h" -#include "network/UDPSocket.h" +#include "network/udp/Socket.h" #include "Log.h" #include "Utils.h" @@ -30,7 +30,7 @@ using namespace network; /// /// Local port used to listen for incoming data. /// -RawFrameQueue::RawFrameQueue(UDPSocket* socket, bool debug) : +RawFrameQueue::RawFrameQueue(udp::Socket* socket, bool debug) : m_socket(socket), m_buffers(), m_debug(debug) @@ -101,7 +101,7 @@ void RawFrameQueue::enqueueMessage(const uint8_t* message, uint32_t length, sock if (m_debug) Utils::dump(1U, "RawFrameQueue::enqueueMessage() Buffered Message", buffer, length); - UDPDatagram* dgram = new UDPDatagram; + udp::UDPDatagram* dgram = new udp::UDPDatagram; dgram->buffer = buffer; dgram->length = length; dgram->address = addr; diff --git a/src/common/network/RawFrameQueue.h b/src/common/network/RawFrameQueue.h index 654b9946..093d74eb 100644 --- a/src/common/network/RawFrameQueue.h +++ b/src/common/network/RawFrameQueue.h @@ -14,7 +14,7 @@ #define __RAW_FRAME_QUEUE_H__ #include "common/Defines.h" -#include "common/network/UDPSocket.h" +#include "common/network/udp/Socket.h" #include "common/Utils.h" namespace network @@ -32,8 +32,12 @@ namespace network class HOST_SW_API RawFrameQueue { public: + auto operator=(RawFrameQueue&) -> RawFrameQueue& = delete; + auto operator=(RawFrameQueue&&) -> RawFrameQueue& = delete; + RawFrameQueue(RawFrameQueue&) = delete; + /// Initializes a new instance of the RawFrameQueue class. - RawFrameQueue(UDPSocket* socket, bool debug); + RawFrameQueue(udp::Socket* socket, bool debug); /// Finalizes a instance of the RawFrameQueue class. virtual ~RawFrameQueue(); @@ -49,9 +53,9 @@ namespace network protected: sockaddr_storage m_addr; uint32_t m_addrLen; - UDPSocket* m_socket; + udp::Socket* m_socket; - BufferVector m_buffers; + udp::BufferVector m_buffers; bool m_debug; diff --git a/src/common/network/UDPSocket.h b/src/common/network/UDPSocket.h deleted file mode 100644 index f0227689..00000000 --- a/src/common/network/UDPSocket.h +++ /dev/null @@ -1,167 +0,0 @@ -// SPDX-License-Identifier: GPL-2.0-only -/** -* Digital Voice Modem - Common Library -* GPLv2 Open Source. Use is subject to license terms. -* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. -* -* @package DVM / Common Library -* @derivedfrom MMDVMHost (https://github.com/g4klx/MMDVMHost) -* @license GPLv2 License (https://opensource.org/licenses/GPL-2.0) -* -* Copyright (C) 2006-2016,2020 Jonathan Naylor, G4KLX -* Copyright (C) 2017-2024 Bryan Biedenkapp, N2PLL -* -*/ -#if !defined(__UDP_SOCKET_H__) -#define __UDP_SOCKET_H__ - -#include "common/Defines.h" -#include "common/AESCrypto.h" - -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if !defined(UDP_SOCKET_MAX) -#define UDP_SOCKET_MAX 1 -#endif - -#define AES_WRAPPED_PCKT_MAGIC 0xC0FEU -#define AES_WRAPPED_PCKT_KEY_LEN 32 - -enum IPMATCHTYPE { - IMT_ADDRESS_AND_PORT, - IMT_ADDRESS_ONLY -}; - -namespace network -{ -#if defined(HAVE_SENDMSG) && !defined(HAVE_SENDMMSG) - /* For `sendmmsg'. */ - struct mmsghdr { - struct msghdr msg_hdr; /* Actual message header. */ - unsigned int msg_len; /* Number of received or sent bytes for the entry. */ - }; - - /* Send a VLEN messages as described by VMESSAGES to socket FD. - Returns the number of datagrams successfully written or -1 for errors. - - This function is a cancellation point and therefore not marked with - __THROW. */ - static inline int sendmmsg(int sockfd, struct mmsghdr* msgvec, unsigned int vlen, int flags) - { - ssize_t n = 0; - for (unsigned int i = 0; i < vlen; i++) { - ssize_t ret = sendmsg(sockfd, &msgvec[i].msg_hdr, flags); - if (ret < 0) - break; - n += ret; - } - - if (n == 0) - return -1; - - return int(n); - } -#endif - - // --------------------------------------------------------------------------- - // Structure Declaration - // This structure represents a container for a network buffer. - // --------------------------------------------------------------------------- - - struct UDPDatagram { - uint8_t* buffer; - size_t length; - - sockaddr_storage address; - uint32_t addrLen; - }; - - /* Vector of buffers that contain a full frames */ - typedef std::vector BufferVector; - - // --------------------------------------------------------------------------- - // Class Declaration - // This class implements low-level routines to communicate over a UDP - // network socket. - // --------------------------------------------------------------------------- - - class HOST_SW_API UDPSocket { - public: - /// Initializes a new instance of the UDPSocket class. - UDPSocket(const std::string& address, uint16_t port = 0U); - /// Initializes a new instance of the UDPSocket class. - UDPSocket(uint16_t port = 0U); - /// Finalizes a instance of the UDPSocket class. - ~UDPSocket(); - - /// Opens UDP socket connection. - bool open(uint32_t af = AF_UNSPEC); - /// Opens UDP socket connection. - bool open(const sockaddr_storage& address); - /// Opens UDP socket connection. - bool open(const uint32_t index, const uint32_t af, const std::string& address, const uint16_t port); - - /// Read data from the UDP socket. - int read(uint8_t* buffer, uint32_t length, sockaddr_storage& address, uint32_t& addrLen); - /// Write data to the UDP socket. - bool write(const uint8_t* buffer, uint32_t length, const sockaddr_storage& address, uint32_t addrLen, int* lenWritten = nullptr); - /// Write data to the UDP socket. - bool write(BufferVector& buffers, int* lenWritten = nullptr); - - /// Closes the UDP socket connection. - void close(); - /// Closes the UDP socket connection. - void close(const uint32_t index); - - /// Sets the preshared encryption key. - void setPresharedKey(const uint8_t* presharedKey); - - /// Flag indicating the UDP socket(s) are open. - bool isOpen() const { return m_isOpen; } - - /// Helper to lookup a hostname and resolve it to an IP address. - static int lookup(const std::string& hostName, uint16_t port, sockaddr_storage& address, uint32_t& addrLen); - /// Helper to lookup a hostname and resolve it to an IP address. - static int lookup(const std::string& hostName, uint16_t port, sockaddr_storage& address, uint32_t& addrLen, struct addrinfo& hints); - - /// - static bool match(const sockaddr_storage& addr1, const sockaddr_storage& addr2, IPMATCHTYPE type = IMT_ADDRESS_AND_PORT); - /// - static std::string address(const sockaddr_storage& addr); - /// - static uint16_t port(const sockaddr_storage& addr); - - /// - static bool isNone(const sockaddr_storage& addr); - - private: - std::string m_address_save; - uint16_t m_port_save; - std::string m_address[UDP_SOCKET_MAX]; - uint16_t m_port[UDP_SOCKET_MAX]; - - bool m_isOpen; - - uint32_t m_af[UDP_SOCKET_MAX]; - int m_fd[UDP_SOCKET_MAX]; - - crypto::AES* m_aes; - bool m_isCryptoWrapped; - uint8_t* m_presharedKey; - - uint32_t m_counter; - }; -} // namespace network - -#endif // __UDP_SOCKET_H__ diff --git a/src/common/network/tcp/Socket.cpp b/src/common/network/tcp/Socket.cpp new file mode 100644 index 00000000..afef1c9f --- /dev/null +++ b/src/common/network/tcp/Socket.cpp @@ -0,0 +1,402 @@ +// SPDX-License-Identifier: GPL-2.0-only +/** +* Digital Voice Modem - Common Library +* GPLv2 Open Source. Use is subject to license terms. +* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. +* +* @package DVM / Common Library +* @license GPLv2 License (https://opensource.org/licenses/GPL-2.0) +* +* Copyright (C) 2024 Bryan Biedenkapp, N2PLL +* +*/ +#include "Defines.h" +#include "network/tcp/Socket.h" +#include "Log.h" +#include "Utils.h" + +using namespace network; +using namespace network::tcp; + +#include +#include +#include + +// --------------------------------------------------------------------------- +// Public Class Members +// --------------------------------------------------------------------------- + +/// +/// Initializes a new instance of the Socket class. +/// +Socket::Socket() : + m_localAddress(), + m_localPort(0U), + m_fd(-1), + m_counter(0U) +{ + /* stub */ +} + +/// +/// Initializes a new instance of the Socket class. +/// +/// +Socket::Socket(const int fd) noexcept : + m_localAddress(), + m_localPort(0U), + m_fd(fd), + m_counter(0U) +{ + /* stub */ +} + +/// +/// Initializes a new instance of the Socket class. +/// +/// +/// +/// +Socket::Socket(const int domain, const int type, const int protocol) : Socket() +{ + initSocket(domain, type, protocol); +} + +/// +/// Finalizes a instance of the Socket class. +/// +Socket::~Socket() +{ + static_cast(::shutdown(m_fd, SHUT_RDWR)); + static_cast(::close(m_fd)); +} + +/// +/// Accepts a pending connection request. +/// +/// +/// +/// +int Socket::accept(sockaddr* address, socklen_t* addrlen) noexcept +{ + // check that the accept() won't block + int i, n; + struct pollfd pfd[TCP_SOCKET_MAX]; + for (i = n = 0; i < TCP_SOCKET_MAX; i++) { + if (m_fd >= 0) { + pfd[n].fd = m_fd; + pfd[n].events = POLLIN; + n++; + } + } + + // no socket descriptor to receive + if (n == 0) + return 0; + + // Return immediately + int ret = ::poll(pfd, n, 0); + if (ret < 0) { + LogError(LOG_NET, "Error returned from TCP poll, err: %d", errno); + return -1; + } + + int index; + for (i = 0; i < n; i++) { + // round robin + index = (i + m_counter) % n; + if (pfd[index].revents & POLLIN) + break; + } + if (i == n) + return -1; + + return ::accept(pfd[index].fd, address, addrlen); +} + +/// +/// Connects the client to a remote TCP host using the specified host name and port number. +/// +/// +/// +/// +bool Socket::connect(const std::string& ipAddr, const uint16_t port) +{ + sockaddr_in addr = {}; + initAddr(ipAddr, port, addr); + + m_localAddress = ipAddr; + m_localPort = port; + + socklen_t length = sizeof(addr); + bool retval = true; + if (::connect(m_fd, reinterpret_cast(&addr), length) < 0) + retval = false; + + return retval; +} + +/// +/// Starts listening for incoming connection requests with a maximum number of pending connection. +/// +/// +/// +/// +/// +ssize_t Socket::listen(const std::string& ipAddr, const uint16_t port, int backlog) noexcept +{ + m_localAddress = ipAddr; + m_localPort = port; + + if (!bind(m_localAddress, m_localPort)) { + return -1; + } + + LogInfoEx(LOG_NET, "Listening TCP port on %u", m_localPort); + return ::listen(m_fd, backlog); +} + +/// +/// Read data from the socket. +/// +/// Buffer to read data into. +/// Length of data to read. +/// +[[nodiscard]] ssize_t Socket::read(uint8_t* buffer, size_t length) noexcept +{ + assert(buffer != nullptr); + assert(length > 0U); + + if (m_fd < 0) + return -1; + + // check that the read() won't block + int i, n; + struct pollfd pfd[TCP_SOCKET_MAX]; + for (i = n = 0; i < TCP_SOCKET_MAX; i++) { + if (m_fd >= 0) { + pfd[n].fd = m_fd; + pfd[n].events = POLLIN; + n++; + } + } + + // no socket descriptor to receive + if (n == 0) + return 0; + + // Return immediately + int ret = ::poll(pfd, n, 0); + if (ret < 0) { + LogError(LOG_NET, "Error returned from TCP poll, err: %d", errno); + return -1; + } + + int index; + for (i = 0; i < n; i++) { + // round robin + index = (i + m_counter) % n; + if (pfd[index].revents & POLLIN) + break; + } + if (i == n) + return 0; + + m_counter++; + return ::read(pfd[index].fd, (char*)buffer, length); +} + +/// +/// Write data to the socket. +/// +/// Buffer containing data to write to socket. +/// Length of data to write. +/// +ssize_t Socket::write(const uint8_t* buffer, size_t length) noexcept +{ + assert(buffer != nullptr); + assert(length > 0U); + + if (m_fd < 0) + return -1; + + return ::send(m_fd, buffer, length, 0); +} + +/// +/// +/// +/// +/// +uint32_t Socket::addr(const sockaddr_storage& addr) +{ + switch (addr.ss_family) { + case AF_INET: + { + struct sockaddr_in* in; + in = (struct sockaddr_in*)& addr; + return in->sin_addr.s_addr; + } + break; + case AF_INET6: + default: + return -1; + break; + } +} + +/// +/// +/// +/// +/// +std::string Socket::address(const sockaddr_storage& addr) +{ + std::string address = std::string(); + char str[INET_ADDRSTRLEN]; + + switch (addr.ss_family) { + case AF_INET: + { + struct sockaddr_in* in; + in = (struct sockaddr_in*)& addr; + inet_ntop(AF_INET, &(in->sin_addr), str, INET_ADDRSTRLEN); + address = std::string(str); + } + break; + case AF_INET6: + { + struct sockaddr_in6* in6; + in6 = (struct sockaddr_in6*)& addr; + inet_ntop(AF_INET6, &(in6->sin6_addr), str, INET_ADDRSTRLEN); + address = std::string(str); + } + break; + default: + break; + } + + return address; +} + +/// +/// +/// +/// +/// +uint16_t Socket::port(const sockaddr_storage& addr) +{ + uint16_t port = 0U; + + switch (addr.ss_family) { + case AF_INET: + { + struct sockaddr_in* in; + in = (struct sockaddr_in*)& addr; + port = ntohs(in->sin_port); + } + break; + case AF_INET6: + { + struct sockaddr_in6* in6; + in6 = (struct sockaddr_in6*)& addr; + port = ntohs(in6->sin6_port); + } + break; + default: + break; + } + + return port; +} + +/// +/// +/// +/// +/// +bool Socket::isNone(const sockaddr_storage& addr) +{ + struct sockaddr_in* in = (struct sockaddr_in*)& addr; + + return ((addr.ss_family == AF_INET) && (in->sin_addr.s_addr == htonl(INADDR_NONE))); +} + +// --------------------------------------------------------------------------- +// Protected Class Members +// --------------------------------------------------------------------------- + +/// +/// +/// +/// +/// +/// +bool Socket::initSocket(const int domain, const int type, const int protocol) +{ + m_fd = ::socket(domain, type, protocol); + if (m_fd < 0) { + LogError(LOG_NET, "Cannot create the TCP socket, err: %d", errno); + return false; + } + + return true; +} + +/// +/// +/// +/// +/// +/// +bool Socket::bind(const std::string& ipAddr, const uint16_t port) +{ + m_localAddress = std::string(ipAddr); + m_localPort = port; + + sockaddr_in addr = {}; + initAddr(ipAddr, port, addr); + + socklen_t length = sizeof(addr); + bool retval = true; + if (::bind(m_fd, reinterpret_cast(&addr), length) < 0) { + LogError(LOG_NET, "Cannot bind the TCP address, err: %d", errno); + retval = false; + } + + return retval; +} + +/// +/// Helper to lookup a hostname and resolve it to an IP address. +/// +/// String containing hostname to resolve. +/// +[[nodiscard]] std::string Socket::getIpAddress(const in_addr inaddr) +{ + char* receivedAddr = ::inet_ntoa(inaddr); + if (receivedAddr == reinterpret_cast(INADDR_NONE)) + throw std::runtime_error("Invalid IP address received on readfrom."); + + return { receivedAddr }; +} + +/// +/// Initialize the sockaddr_in structure with the provided IP and port +/// +/// IP address. +/// IP address. +/// +void Socket::initAddr(const std::string& ipAddr, const int port, sockaddr_in& addr) noexcept(false) +{ + addr.sin_family = AF_INET; + if (ipAddr.empty() || ipAddr == "0.0.0.0") + addr.sin_addr.s_addr = INADDR_ANY; + else + { + if (::inet_pton(AF_INET, ipAddr.c_str(), &addr.sin_addr) <= 0) + throw std::runtime_error("Failed to parse IP address"); + } + + addr.sin_port = ::htons(port); +} \ No newline at end of file diff --git a/src/common/network/tcp/Socket.h b/src/common/network/tcp/Socket.h index f09e258a..d57c29a9 100644 --- a/src/common/network/tcp/Socket.h +++ b/src/common/network/tcp/Socket.h @@ -10,8 +10,8 @@ * Copyright (C) 2024 Bryan Biedenkapp, N2PLL * */ -#if !defined(__SOCKET_H__) -#define __SOCKET_H__ +#if !defined(__TCP_SOCKET_H__) +#define __TCP_SOCKET_H__ #include "Defines.h" #include "common/Log.h" @@ -47,321 +47,56 @@ namespace network Socket(Socket&) = delete; /// Initializes a new instance of the Socket class. - Socket() noexcept(false) : - m_fd(-1), - m_address(), - m_port(0U), - m_counter(0U) - { - /* stub */ - } + Socket(); /// Initializes a new instance of the Socket class. - /// - Socket(const int fd) noexcept : - m_fd(fd), - m_address(), - m_port(0U), - m_counter(0U) - { - /* stub */ - } + Socket(const int fd) noexcept; /// Initializes a new instance of the Socket class. - /// - /// - /// - Socket(const int domain, const int type, const int protocol) noexcept(false) : Socket() - { - initSocket(domain, type, protocol); - } + Socket(const int domain, const int type, const int protocol); /// Finalizes a instance of the Socket class. - virtual ~Socket() - { - static_cast(::shutdown(m_fd, SHUT_RDWR)); - static_cast(::close(m_fd)); - } + virtual ~Socket(); - /// - /// - /// - /// - /// - /// - void initSocket(const int domain, const int type, const int protocol) noexcept(false) - { - m_fd = ::socket(domain, type, protocol); - if (m_fd < 0) - throw std::runtime_error("Failed to create Socket"); - } + /// Accepts a pending connection request. + int accept(sockaddr* address, socklen_t* addrlen) noexcept; + /// Connects the client to a remote TCP host using the specified host name and port number. + virtual bool connect(const std::string& ipAddr, const uint16_t port); + /// Starts listening for incoming connection requests with a maximum number of pending connection. + ssize_t listen(const std::string& ipAddr, const uint16_t port, int backlog) noexcept; - /// - /// - /// - /// - /// - /// - int accept(sockaddr* address, socklen_t* addrlen) const noexcept - { - // check that the accept() won't block - int i, n; - struct pollfd pfd[TCP_SOCKET_MAX]; - for (i = n = 0; i < TCP_SOCKET_MAX; i++) { - if (m_fd >= 0) { - pfd[n].fd = m_fd; - pfd[n].events = POLLIN; - n++; - } - } - - // no socket descriptor to receive - if (n == 0) - return 0; - - // Return immediately - int ret = ::poll(pfd, n, 0); - if (ret < 0) { - LogError(LOG_NET, "Error returned from TCP poll, err: %d", errno); - return -1; - } - - int index; - for (i = 0; i < n; i++) { - // round robin - index = (i + m_counter) % n; - if (pfd[index].revents & POLLIN) - break; - } - if (i == n) - return -1; - - return ::accept(pfd[index].fd, address, addrlen); - } - - /// - /// - /// - /// - /// - /// - bool bind(const std::string& ipAddr, const uint16_t port) noexcept(false) - { - m_address = std::string(ipAddr); - m_port = port; - - sockaddr_in addr = {}; - initAddr(ipAddr, port, addr); - socklen_t length = sizeof(addr); - bool retval = true; - if (::bind(m_fd, reinterpret_cast(&addr), length) < 0) - retval = false; - - return retval; - } - - /// - /// - /// - /// - /// - /// - virtual bool connect(const std::string& ipAddr, const uint16_t port) noexcept(false) - { - sockaddr_in addr = {}; - initAddr(ipAddr, port, addr); - socklen_t length = sizeof(addr); - bool retval = true; - if (::connect(m_fd, reinterpret_cast(&addr), length) < 0) - retval = false; - return retval; - } - - /// - /// - /// - /// - /// - ssize_t listen(int backlog) const noexcept - { - LogInfoEx(LOG_NET, "Listening TCP port on %u", m_port); - return ::listen(m_fd, backlog); - } - - /// - /// Read data from the socket. - /// - /// Buffer to read data into. - /// Length of data to read. - /// - [[nodiscard]] ssize_t read(uint8_t* buffer, size_t length) const noexcept - { - // check that the read() won't block - int i, n; - struct pollfd pfd[TCP_SOCKET_MAX]; - for (i = n = 0; i < TCP_SOCKET_MAX; i++) { - if (m_fd >= 0) { - pfd[n].fd = m_fd; - pfd[n].events = POLLIN; - n++; - } - } - - // no socket descriptor to receive - if (n == 0) - return 0; - - // Return immediately - int ret = ::poll(pfd, n, 0); - if (ret < 0) { - LogError(LOG_NET, "Error returned from TCP poll, err: %d", errno); - return -1; - } - - int index; - for (i = 0; i < n; i++) { - // round robin - index = (i + m_counter) % n; - if (pfd[index].revents & POLLIN) - break; - } - if (i == n) - return 0; - - return ::read(pfd[index].fd, (char*)buffer, length); - } - - /// - /// Write data to the socket. - /// - /// Buffer containing data to write to socket. - /// Length of data to write. - /// - ssize_t write(const uint8_t* buffer, size_t length) const noexcept - { - return ::send(m_fd, buffer, length, 0); - } + /// Read data from the socket. + [[nodiscard]] virtual ssize_t read(uint8_t* buffer, size_t length) noexcept; + /// Write data to the socket. + virtual ssize_t write(const uint8_t* buffer, size_t length) noexcept; /// - static uint32_t addr(const sockaddr_storage& addr) - { - switch (addr.ss_family) { - case AF_INET: - { - struct sockaddr_in* in; - in = (struct sockaddr_in*)& addr; - return in->sin_addr.s_addr; - } - break; - case AF_INET6: - default: - return -1; - break; - } - } + static uint32_t addr(const sockaddr_storage& addr); /// - static std::string address(const sockaddr_storage& addr) - { - std::string address = std::string(); - char str[INET_ADDRSTRLEN]; - - switch (addr.ss_family) { - case AF_INET: - { - struct sockaddr_in* in; - in = (struct sockaddr_in*)& addr; - inet_ntop(AF_INET, &(in->sin_addr), str, INET_ADDRSTRLEN); - address = std::string(str); - } - break; - case AF_INET6: - { - struct sockaddr_in6* in6; - in6 = (struct sockaddr_in6*)& addr; - inet_ntop(AF_INET6, &(in6->sin6_addr), str, INET_ADDRSTRLEN); - address = std::string(str); - } - break; - default: - break; - } - - return address; - } + static std::string address(const sockaddr_storage& addr); /// - static uint16_t port(const sockaddr_storage& addr) - { - uint16_t port = 0U; - - switch (addr.ss_family) { - case AF_INET: - { - struct sockaddr_in* in; - in = (struct sockaddr_in*)& addr; - port = ntohs(in->sin_port); - } - break; - case AF_INET6: - { - struct sockaddr_in6* in6; - in6 = (struct sockaddr_in6*)& addr; - port = ntohs(in6->sin6_port); - } - break; - default: - break; - } - - return port; - } + static uint16_t port(const sockaddr_storage& addr); /// - static bool isNone(const sockaddr_storage& addr) - { - struct sockaddr_in* in = (struct sockaddr_in*)& addr; - - return ((addr.ss_family == AF_INET) && (in->sin_addr.s_addr == htonl(INADDR_NONE))); - } + static bool isNone(const sockaddr_storage& addr); protected: + std::string m_localAddress; + uint16_t m_localPort; + int m_fd; - std::string m_address; - uint16_t m_port; uint32_t m_counter; - /// - /// Helper to lookup a hostname and resolve it to an IP address. - /// - /// String containing hostname to resolve. - /// - [[nodiscard]] static std::string getIpAddress(const in_addr inaddr) noexcept(false) - { - char* receivedAddr = ::inet_ntoa(inaddr); - if (receivedAddr == reinterpret_cast(INADDR_NONE)) - throw std::runtime_error("Invalid IP address received on readfrom."); - - return { receivedAddr }; - } + /// + bool initSocket(const int domain, const int type, const int protocol); + /// + bool bind(const std::string& ipAddr, const uint16_t port); - /// - /// Initialize the sockaddr_in structure with the provided IP and port - /// - /// IP address. - /// IP address. - /// - static void initAddr(const std::string& ipAddr, const int port, sockaddr_in& addr) noexcept(false) - { - addr.sin_family = AF_INET; - if (ipAddr.empty() || ipAddr == "0.0.0.0") - addr.sin_addr.s_addr = INADDR_ANY; - else - { - if (::inet_pton(AF_INET, ipAddr.c_str(), &addr.sin_addr) <= 0) - throw std::runtime_error("Failed to parse IP address"); - } + /// Helper to lookup a hostname and resolve it to an IP address. + [[nodiscard]] static std::string getIpAddress(const in_addr inaddr); - addr.sin_port = ::htons(port); - } + /// Initialize the sockaddr_in structure with the provided IP and port. + static void initAddr(const std::string& ipAddr, const int port, sockaddr_in& addr); }; } // namespace tcp } // namespace network -#endif // __SOCKET_H__ +#endif // __TCP_SOCKET_H__ diff --git a/src/common/network/UDPSocket.cpp b/src/common/network/udp/Socket.cpp similarity index 75% rename from src/common/network/UDPSocket.cpp rename to src/common/network/udp/Socket.cpp index e485430f..c64f6c42 100644 --- a/src/common/network/UDPSocket.cpp +++ b/src/common/network/udp/Socket.cpp @@ -13,11 +13,12 @@ * */ #include "Defines.h" -#include "network/UDPSocket.h" +#include "network/udp/Socket.h" #include "Log.h" #include "Utils.h" using namespace network; +using namespace network::udp; #include #include @@ -32,15 +33,17 @@ using namespace network; // --------------------------------------------------------------------------- // Public Class Members // --------------------------------------------------------------------------- + /// -/// Initializes a new instance of the UDPSocket class. +/// Initializes a new instance of the Socket class. /// /// Hostname/IP address to connect to. /// Port number. -UDPSocket::UDPSocket(const std::string& address, uint16_t port) : - m_address_save(address), - m_port_save(port), - m_isOpen(false), +Socket::Socket(const std::string& address, uint16_t port) : + m_localAddress(address), + m_localPort(port), + m_af(AF_UNSPEC), + m_fd(-1), m_aes(nullptr), m_isCryptoWrapped(false), m_presharedKey(nullptr), @@ -48,22 +51,17 @@ UDPSocket::UDPSocket(const std::string& address, uint16_t port) : { m_aes = new crypto::AES(crypto::AESKeyLength::AES_256); m_presharedKey = new uint8_t[AES_WRAPPED_PCKT_KEY_LEN]; - for (int i = 0; i < UDP_SOCKET_MAX; i++) { - m_address[i] = ""; - m_port[i] = 0U; - m_af[i] = 0U; - m_fd[i] = -1; - } } /// -/// Initializes a new instance of the UDPSocket class. +/// Initializes a new instance of the Socket class. /// /// Port number. -UDPSocket::UDPSocket(uint16_t port) : - m_address_save(), - m_port_save(port), - m_isOpen(false), +Socket::Socket(uint16_t port) : + m_localAddress(), + m_localPort(port), + m_af(AF_UNSPEC), + m_fd(-1), m_aes(nullptr), m_isCryptoWrapped(false), m_presharedKey(nullptr), @@ -71,18 +69,12 @@ UDPSocket::UDPSocket(uint16_t port) : { m_aes = new crypto::AES(crypto::AESKeyLength::AES_256); m_presharedKey = new uint8_t[AES_WRAPPED_PCKT_KEY_LEN]; - for (int i = 0; i < UDP_SOCKET_MAX; i++) { - m_address[i] = ""; - m_port[i] = 0U; - m_af[i] = 0U; - m_fd[i] = -1; - } } /// -/// Finalizes a instance of the UDPSocket class. +/// Finalizes a instance of the Socket class. /// -UDPSocket::~UDPSocket() +Socket::~Socket() { if (m_aes != nullptr) delete m_aes; @@ -95,7 +87,7 @@ UDPSocket::~UDPSocket() /// /// /// True, if UDP socket is opened, otherwise false. -bool UDPSocket::open(const sockaddr_storage& address) +bool Socket::open(const sockaddr_storage& address) noexcept { return open(address.ss_family); } @@ -105,20 +97,19 @@ bool UDPSocket::open(const sockaddr_storage& address) /// /// /// True, if UDP socket is opened, otherwise false. -bool UDPSocket::open(uint32_t af) +bool Socket::open(uint32_t af) noexcept { - return open(0, af, m_address_save, m_port_save); + return open(af, m_localAddress, m_localPort); } /// /// Opens UDP socket connection. /// -/// /// /// /// /// True, if UDP socket is opened, otherwise false. -bool UDPSocket::open(const uint32_t index, const uint32_t af, const std::string& address, const uint16_t port) +bool Socket::open(const uint32_t af, const std::string& address, const uint16_t port) noexcept { sockaddr_storage addr; uint32_t addrlen; @@ -132,45 +123,42 @@ bool UDPSocket::open(const uint32_t index, const uint32_t af, const std::string& int err = lookup(address, port, addr, addrlen, hints); if (err != 0) { LogError(LOG_NET, "The local address is invalid - %s", address.c_str()); - m_isOpen = false; return false; } - close(index); + close(); - int fd = ::socket(addr.ss_family, SOCK_DGRAM, 0); - if (fd < 0) { - LogError(LOG_NET, "Cannot create the UDP socket, err: %d", errno); - m_isOpen = false; + if (!initSocket(addr.ss_family, SOCK_DGRAM, 0)) return false; - } - - m_address[index] = address; - m_port[index] = port; - m_af[index] = addr.ss_family; - m_fd[index] = fd; if (port > 0U) { int reuse = 1; - if (::setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, (char*)& reuse, sizeof(reuse)) == -1) { + if (::setsockopt(m_fd, SOL_SOCKET, SO_REUSEADDR, (char*)& reuse, sizeof(reuse)) == -1) { LogError(LOG_NET, "Cannot set the UDP socket option, err: %d", errno); - m_isOpen = false; return false; } - if (::bind(fd, (sockaddr*)& addr, addrlen) == -1) { - LogError(LOG_NET, "Cannot bind the UDP address, err: %d", errno); - m_isOpen = false; + if (!bind(address, port)) { return false; } LogInfoEx(LOG_NET, "Opening UDP port on %u", port); } - m_isOpen = true; return true; } +/// +/// Closes the UDP socket connection. +/// +void Socket::close() +{ + if (m_fd >= 0) { + ::close(m_fd); + m_fd = -1; + } +} + /// /// Read data from the UDP socket. /// @@ -179,17 +167,20 @@ bool UDPSocket::open(const uint32_t index, const uint32_t af, const std::string& /// IP address data read from. /// /// Actual length of data read from remote UDP socket. -int UDPSocket::read(uint8_t* buffer, uint32_t length, sockaddr_storage& address, uint32_t& addrLen) +ssize_t Socket::read(uint8_t* buffer, uint32_t length, sockaddr_storage& address, uint32_t& addrLen) noexcept { assert(buffer != nullptr); assert(length > 0U); + if (m_fd < 0) + return -1; + // Check that the readfrom() won't block int i, n; struct pollfd pfd[UDP_SOCKET_MAX]; for (i = n = 0; i < UDP_SOCKET_MAX; i++) { - if (m_fd[i] >= 0) { - pfd[n].fd = m_fd[i]; + if (m_fd >= 0) { + pfd[n].fd = m_fd; pfd[n].events = POLLIN; n++; } @@ -222,7 +213,7 @@ int UDPSocket::read(uint8_t* buffer, uint32_t length, sockaddr_storage& address, LogError(LOG_NET, "Error returned from recvfrom, err: %d", errno); if (len == -1 && errno == ENOTSOCK) { - LogMessage(LOG_NET, "Re-opening UDP port on %u", m_port[index]); + LogMessage(LOG_NET, "Re-opening UDP port on %u", m_localPort); close(); open(); } @@ -241,12 +232,12 @@ int UDPSocket::read(uint8_t* buffer, uint32_t length, sockaddr_storage& address, uint16_t magic = __GET_UINT16B(buffer, 0U); if (magic == AES_WRAPPED_PCKT_MAGIC) { uint32_t cryptedLen = (len - 2U) * sizeof(uint8_t); - // Utils::dump(1U, "UDPSocket::read() crypted", buffer + 2U, cryptedLen); + // Utils::dump(1U, "Socket::read() crypted", buffer + 2U, cryptedLen); // decrypt uint8_t* decrypted = m_aes->decryptECB(buffer + 2U, cryptedLen, m_presharedKey); - // Utils::dump(1U, "UDPSocket::read() decrypted", decrypted, cryptedLen); + // Utils::dump(1U, "Socket::read() decrypted", decrypted, cryptedLen); // finalize, cleanup buffers and replace with new if (decrypted != nullptr) { @@ -279,14 +270,22 @@ int UDPSocket::read(uint8_t* buffer, uint32_t length, sockaddr_storage& address, /// /// Total number of bytes written. /// True if message was sent, otherwise false. -bool UDPSocket::write(const uint8_t* buffer, uint32_t length, const sockaddr_storage& address, uint32_t addrLen, int* lenWritten) +bool Socket::write(const uint8_t* buffer, uint32_t length, const sockaddr_storage& address, uint32_t addrLen, ssize_t* lenWritten) noexcept { assert(buffer != nullptr); assert(length > 0U); - bool result = false; + if (m_fd < 0) { + if (lenWritten != nullptr) { + *lenWritten = -1; + } + + return false; + } + bool result = false; UInt8Array out = nullptr; + // are we crypto wrapped? if (m_isCryptoWrapped) { if (m_presharedKey == nullptr) { @@ -312,7 +311,7 @@ bool UDPSocket::write(const uint8_t* buffer, uint32_t length, const sockaddr_sto // encrypt uint8_t* crypted = m_aes->encryptECB(cryptoBuffer, cryptedLen, m_presharedKey); - // Utils::dump(1U, "UDPSocket::write() crypted", crypted, cryptedLen); + // Utils::dump(1U, "Socket::write() crypted", crypted, cryptedLen); // finalize, cleanup buffers and replace with new out = std::unique_ptr(new uint8_t[cryptedLen + 2U]); @@ -334,25 +333,20 @@ bool UDPSocket::write(const uint8_t* buffer, uint32_t length, const sockaddr_sto ::memcpy(out.get(), buffer, length); } - for (int i = 0; i < UDP_SOCKET_MAX; i++) { - if (m_fd[i] < 0 || m_af[i] != address.ss_family) - continue; + ssize_t sent = ::sendto(m_fd, (char*)out.get(), length, 0, (sockaddr*)& address, addrLen); + if (sent < 0) { + LogError(LOG_NET, "Error returned from sendto, err: %d", errno); - ssize_t sent = ::sendto(m_fd[i], (char*)out.get(), length, 0, (sockaddr*)& address, addrLen); - if (sent < 0) { - LogError(LOG_NET, "Error returned from sendto, err: %d", errno); - - if (lenWritten != nullptr) { - *lenWritten = -1; - } + if (lenWritten != nullptr) { + *lenWritten = -1; } - else { - if (sent == ssize_t(length)) - result = true; + } + else { + if (sent == ssize_t(length)) + result = true; - if (lenWritten != nullptr) { - *lenWritten = sent; - } + if (lenWritten != nullptr) { + *lenWritten = sent; } } @@ -367,16 +361,32 @@ bool UDPSocket::write(const uint8_t* buffer, uint32_t length, const sockaddr_sto /// /// Total number of bytes written. /// True if messages were sent, otherwise false. -bool UDPSocket::write(BufferVector& buffers, int* lenWritten) +bool Socket::write(BufferVector& buffers, ssize_t* lenWritten) noexcept { bool result = false; + if (m_fd < 0) { + if (lenWritten != nullptr) { + *lenWritten = -1; + } + + return false; + } + if (buffers.empty()) { + if (lenWritten != nullptr) { + *lenWritten = -1; + } + return false; } // bryanb: this is the same as above -- but for some assinine reason prevents // weirdness if (buffers.size() == 0U) { + if (lenWritten != nullptr) { + *lenWritten = -1; + } + return false; } @@ -384,15 +394,25 @@ bool UDPSocket::write(BufferVector& buffers, int* lenWritten) if (buffers.size() > UINT16_MAX) { LogError(LOG_NET, "Trying to send too many buffers?"); + + if (lenWritten != nullptr) { + *lenWritten = -1; + } + return false; } - // LogDebug(LOG_NET, "Sending message(s) (to %s:%u) addrLen %u", UDPSocket::address(address).c_str(), UDPSocket::port(address), addrLen); + // LogDebug(LOG_NET, "Sending message(s) (to %s:%u) addrLen %u", Socket::address(address).c_str(), Socket::port(address), addrLen); // are we crypto wrapped? if (m_isCryptoWrapped) { if (m_presharedKey == nullptr) { LogError(LOG_NET, "tried to write datagram encrypted with no key? this shouldn't happen BUGBUG"); + + if (lenWritten != nullptr) { + *lenWritten = -1; + } + return false; } } @@ -441,7 +461,7 @@ bool UDPSocket::write(BufferVector& buffers, int* lenWritten) continue; } - // Utils::dump(1U, "UDPSocket::write() crypted", crypted, cryptedLen); + // Utils::dump(1U, "Socket::write() crypted", crypted, cryptedLen); // finalize uint8_t out[cryptedLen + 2U]; @@ -475,71 +495,50 @@ bool UDPSocket::write(BufferVector& buffers, int* lenWritten) headers[i].msg_hdr.msg_controllen = 0; } - for (int i = 0; i < UDP_SOCKET_MAX; i++) { - if (m_fd[i] < 0) - continue; + bool skip = false; + for (auto& buffer : buffers) { + if (m_af != buffer->address.ss_family) { + skip = true; + break; + } + } - bool skip = false; - for (auto& buffer : buffers) { - if (m_af[i] != buffer->address.ss_family) { - skip = true; - break; - } + if (skip) { + if (lenWritten != nullptr) { + *lenWritten = -1; } - if (skip) - continue; - if (sendmmsg(m_fd[i], headers, size, 0) < 0) { - LogError(LOG_NET, "Error returned from sendmmsg, err: %d", errno); - if (lenWritten != nullptr) { - *lenWritten = -1; - } + return false; + } + + if (sendmmsg(m_fd, headers, size, 0) < 0) { + LogError(LOG_NET, "Error returned from sendmmsg, err: %d", errno); + if (lenWritten != nullptr) { + *lenWritten = -1; } + } - if (sent < 0) { - LogError(LOG_NET, "Error returned from sendmmsg, err: %d", errno); - if (lenWritten != nullptr) { - *lenWritten = -1; - } + if (sent < 0) { + LogError(LOG_NET, "Error returned from sendmmsg, err: %d", errno); + if (lenWritten != nullptr) { + *lenWritten = -1; } - else { - result = true; - if (lenWritten != nullptr) { - *lenWritten = sent; - } + } + else { + result = true; + if (lenWritten != nullptr) { + *lenWritten = sent; } } return result; } -/// -/// Closes the UDP socket connection. -/// -void UDPSocket::close() -{ - for (int i = 0; i < UDP_SOCKET_MAX; i++) - close(i); - m_isOpen = false; -} - -/// -/// Closes the UDP socket connection. -/// -/// -void UDPSocket::close(const uint32_t index) -{ - if ((index < UDP_SOCKET_MAX) && (m_fd[index] >= 0)) { - ::close(m_fd[index]); - m_fd[index] = -1; - } -} - /// /// Sets the preshared encryption key. /// /// -void UDPSocket::setPresharedKey(const uint8_t* presharedKey) +void Socket::setPresharedKey(const uint8_t* presharedKey) { if (presharedKey != nullptr) { ::memset(m_presharedKey, 0x00U, AES_WRAPPED_PCKT_KEY_LEN); @@ -559,7 +558,7 @@ void UDPSocket::setPresharedKey(const uint8_t* presharedKey) /// Socket address structure. /// /// Zero if no error during lookup, otherwise error. -int UDPSocket::lookup(const std::string& hostname, uint16_t port, sockaddr_storage& addr, uint32_t& addrLen) +int Socket::lookup(const std::string& hostname, uint16_t port, sockaddr_storage& addr, uint32_t& addrLen) { struct addrinfo hints; ::memset(&hints, 0, sizeof(hints)); @@ -576,7 +575,7 @@ int UDPSocket::lookup(const std::string& hostname, uint16_t port, sockaddr_stora /// /// /// Zero if no error during lookup, otherwise error. -int UDPSocket::lookup(const std::string& hostname, uint16_t port, sockaddr_storage& addr, uint32_t& addrLen, struct addrinfo& hints) +int Socket::lookup(const std::string& hostname, uint16_t port, sockaddr_storage& addr, uint32_t& addrLen, struct addrinfo& hints) { std::string portstr = std::to_string(port); struct addrinfo* res; @@ -609,7 +608,7 @@ int UDPSocket::lookup(const std::string& hostname, uint16_t port, sockaddr_stora /// /// /// -bool UDPSocket::match(const sockaddr_storage& addr1, const sockaddr_storage& addr2, IPMATCHTYPE type) +bool Socket::match(const sockaddr_storage& addr1, const sockaddr_storage& addr2, IPMATCHTYPE type) { if (addr1.ss_family != addr2.ss_family) return false; @@ -656,7 +655,7 @@ bool UDPSocket::match(const sockaddr_storage& addr1, const sockaddr_storage& add /// /// /// -std::string UDPSocket::address(const sockaddr_storage& addr) +std::string Socket::address(const sockaddr_storage& addr) { std::string address = std::string(); char str[INET_ADDRSTRLEN]; @@ -690,7 +689,7 @@ std::string UDPSocket::address(const sockaddr_storage& addr) /// /// /// -uint16_t UDPSocket::port(const sockaddr_storage& addr) +uint16_t Socket::port(const sockaddr_storage& addr) { uint16_t port = 0U; @@ -721,9 +720,75 @@ uint16_t UDPSocket::port(const sockaddr_storage& addr) /// /// /// -bool UDPSocket::isNone(const sockaddr_storage& addr) +bool Socket::isNone(const sockaddr_storage& addr) { struct sockaddr_in* in = (struct sockaddr_in*)& addr; return ((addr.ss_family == AF_INET) && (in->sin_addr.s_addr == htonl(INADDR_NONE))); } + +// --------------------------------------------------------------------------- +// Protected Class Members +// --------------------------------------------------------------------------- + +/// +/// +/// +/// +/// +/// +bool Socket::initSocket(const int domain, const int type, const int protocol) noexcept(false) +{ + m_fd = ::socket(domain, type, protocol); + if (m_fd < 0) { + LogError(LOG_NET, "Cannot create the UDP socket, err: %d", errno); + return false; + } + + m_af = domain; + return true; +} + +/// +/// +/// +/// +/// +/// +bool Socket::bind(const std::string& ipAddr, const uint16_t port) noexcept(false) +{ + m_localAddress = std::string(ipAddr); + m_localPort = port; + + sockaddr_in addr = {}; + initAddr(ipAddr, port, addr); + + socklen_t length = sizeof(addr); + bool retval = true; + if (::bind(m_fd, reinterpret_cast(&addr), length) < 0) { + LogError(LOG_NET, "Cannot bind the UDP address, err: %d", errno); + retval = false; + } + + return retval; +} + +/// +/// Initialize the sockaddr_in structure with the provided IP and port +/// +/// IP address. +/// IP address. +/// +void Socket::initAddr(const std::string& ipAddr, const int port, sockaddr_in& addr) noexcept(false) +{ + addr.sin_family = AF_INET; + if (ipAddr.empty() || ipAddr == "0.0.0.0") + addr.sin_addr.s_addr = INADDR_ANY; + else + { + if (::inet_pton(AF_INET, ipAddr.c_str(), &addr.sin_addr) <= 0) + throw std::runtime_error("Failed to parse IP address"); + } + + addr.sin_port = ::htons(port); +} \ No newline at end of file diff --git a/src/common/network/udp/Socket.h b/src/common/network/udp/Socket.h new file mode 100644 index 00000000..51286732 --- /dev/null +++ b/src/common/network/udp/Socket.h @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: GPL-2.0-only +/** +* Digital Voice Modem - Common Library +* GPLv2 Open Source. Use is subject to license terms. +* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. +* +* @package DVM / Common Library +* @derivedfrom MMDVMHost (https://github.com/g4klx/MMDVMHost) +* @license GPLv2 License (https://opensource.org/licenses/GPL-2.0) +* +* Copyright (C) 2006-2016,2020 Jonathan Naylor, G4KLX +* Copyright (C) 2017-2024 Bryan Biedenkapp, N2PLL +* +*/ +#if !defined(__UDP_SOCKET_H__) +#define __UDP_SOCKET_H__ + +#include "common/Defines.h" +#include "common/AESCrypto.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if !defined(UDP_SOCKET_MAX) +#define UDP_SOCKET_MAX 1 +#endif + +#define AES_WRAPPED_PCKT_MAGIC 0xC0FEU +#define AES_WRAPPED_PCKT_KEY_LEN 32 + +enum IPMATCHTYPE { + IMT_ADDRESS_AND_PORT, + IMT_ADDRESS_ONLY +}; + +namespace network +{ + namespace udp + { +#if defined(HAVE_SENDMSG) && !defined(HAVE_SENDMMSG) + /* For `sendmmsg'. */ + struct mmsghdr { + struct msghdr msg_hdr; /* Actual message header. */ + unsigned int msg_len; /* Number of received or sent bytes for the entry. */ + }; + + /* Send a VLEN messages as described by VMESSAGES to socket FD. + Returns the number of datagrams successfully written or -1 for errors. + + This function is a cancellation point and therefore not marked with + __THROW. */ + static inline int sendmmsg(int sockfd, struct mmsghdr* msgvec, unsigned int vlen, int flags) + { + ssize_t n = 0; + for (unsigned int i = 0; i < vlen; i++) { + ssize_t ret = sendmsg(sockfd, &msgvec[i].msg_hdr, flags); + if (ret < 0) + break; + n += ret; + } + + if (n == 0) + return -1; + + return int(n); + } +#endif + + // --------------------------------------------------------------------------- + // Structure Declaration + // This structure represents a container for a network buffer. + // --------------------------------------------------------------------------- + + struct UDPDatagram { + uint8_t* buffer; + size_t length; + + sockaddr_storage address; + uint32_t addrLen; + }; + + /* Vector of buffers that contain a full frames */ + typedef std::vector BufferVector; + + // --------------------------------------------------------------------------- + // Class Declaration + // This class implements low-level routines to communicate over a UDP + // network socket. + // --------------------------------------------------------------------------- + + class HOST_SW_API Socket { + public: + auto operator=(Socket&) -> Socket& = delete; + auto operator=(Socket&&) -> Socket& = delete; + Socket(Socket&) = delete; + + /// Initializes a new instance of the Socket class. + Socket(const std::string& address, uint16_t port = 0U); + /// Initializes a new instance of the Socket class. + Socket(uint16_t port = 0U); + /// Finalizes a instance of the Socket class. + virtual ~Socket(); + + /// Opens UDP socket connection. + bool open(const sockaddr_storage& address) noexcept; + /// Opens UDP socket connection. + bool open(uint32_t af = AF_UNSPEC) noexcept; + /// Opens UDP socket connection. + bool open(const uint32_t af, const std::string& address, const uint16_t port) noexcept; + + /// Closes the UDP socket connection. + void close(); + + /// Read data from the UDP socket. + virtual ssize_t read(uint8_t* buffer, uint32_t length, sockaddr_storage& address, uint32_t& addrLen) noexcept; + /// Write data to the UDP socket. + virtual bool write(const uint8_t* buffer, uint32_t length, const sockaddr_storage& address, uint32_t addrLen, ssize_t* lenWritten = nullptr) noexcept; + /// Write data to the UDP socket. + virtual bool write(BufferVector& buffers, ssize_t* lenWritten = nullptr) noexcept; + + /// Sets the preshared encryption key. + void setPresharedKey(const uint8_t* presharedKey); + + /// Helper to lookup a hostname and resolve it to an IP address. + static int lookup(const std::string& hostName, uint16_t port, sockaddr_storage& address, uint32_t& addrLen); + /// Helper to lookup a hostname and resolve it to an IP address. + static int lookup(const std::string& hostName, uint16_t port, sockaddr_storage& address, uint32_t& addrLen, struct addrinfo& hints); + + /// + static bool match(const sockaddr_storage& addr1, const sockaddr_storage& addr2, IPMATCHTYPE type = IMT_ADDRESS_AND_PORT); + + /// + static uint32_t addr(const sockaddr_storage& addr); + /// + static std::string address(const sockaddr_storage& addr); + /// + static uint16_t port(const sockaddr_storage& addr); + + /// + static bool isNone(const sockaddr_storage& addr); + + private: + std::string m_localAddress; + uint16_t m_localPort; + + uint32_t m_af; + int m_fd; + + crypto::AES* m_aes; + bool m_isCryptoWrapped; + uint8_t* m_presharedKey; + + uint32_t m_counter; + + /// + bool initSocket(const int domain, const int type, const int protocol); + /// + bool bind(const std::string& ipAddr, const uint16_t port); + + /// Initialize the sockaddr_in structure with the provided IP and port. + static void initAddr(const std::string& ipAddr, const int port, sockaddr_in& addr); + }; + } // namespace udp +} // namespace network + +#endif // __UDP_SOCKET_H__ diff --git a/src/fne/HostFNE.cpp b/src/fne/HostFNE.cpp index 80d2be4d..8669b291 100644 --- a/src/fne/HostFNE.cpp +++ b/src/fne/HostFNE.cpp @@ -11,7 +11,7 @@ * */ #include "Defines.h" -#include "common/network/UDPSocket.h" +#include "common/network/udp/Socket.h" #include "common/Log.h" #include "common/StopWatch.h" #include "common/Thread.h" diff --git a/src/fne/network/FNENetwork.cpp b/src/fne/network/FNENetwork.cpp index b7843c02..355094e4 100644 --- a/src/fne/network/FNENetwork.cpp +++ b/src/fne/network/FNENetwork.cpp @@ -237,7 +237,7 @@ void FNENetwork::clock(uint32_t ms) if (peerId > 0 && (m_peers.find(peerId) != m_peers.end())) { FNEPeerConnection* connection = m_peers[peerId]; if (connection != nullptr) { - std::string ip = UDPSocket::address(address); + std::string ip = udp::Socket::address(address); // validate peer (simple validation really) if (connection->connected() && connection->address() == ip) { @@ -254,7 +254,7 @@ void FNENetwork::clock(uint32_t ms) if (peerId > 0 && (m_peers.find(peerId) != m_peers.end())) { FNEPeerConnection* connection = m_peers[peerId]; if (connection != nullptr) { - std::string ip = UDPSocket::address(address); + std::string ip = udp::Socket::address(address); // validate peer (simple validation really) if (connection->connected() && connection->address() == ip) { @@ -271,7 +271,7 @@ void FNENetwork::clock(uint32_t ms) if (peerId > 0 && (m_peers.find(peerId) != m_peers.end())) { FNEPeerConnection* connection = m_peers[peerId]; if (connection != nullptr) { - std::string ip = UDPSocket::address(address); + std::string ip = udp::Socket::address(address); // validate peer (simple validation really) if (connection->connected() && connection->address() == ip) { @@ -477,7 +477,7 @@ void FNENetwork::clock(uint32_t ms) if (peerId > 0 && (m_peers.find(peerId) != m_peers.end())) { FNEPeerConnection* connection = m_peers[peerId]; if (connection != nullptr) { - std::string ip = UDPSocket::address(address); + std::string ip = udp::Socket::address(address); // validate peer (simple validation really) if (connection->connected() && connection->address() == ip) { @@ -496,7 +496,7 @@ void FNENetwork::clock(uint32_t ms) if (peerId > 0 && (m_peers.find(peerId) != m_peers.end())) { FNEPeerConnection* connection = m_peers[peerId]; if (connection != nullptr) { - std::string ip = UDPSocket::address(address); + std::string ip = udp::Socket::address(address); // validate peer (simple validation really) if (connection->connected() && connection->address() == ip) { @@ -527,7 +527,7 @@ void FNENetwork::clock(uint32_t ms) if (peerId > 0 && (m_peers.find(peerId) != m_peers.end())) { FNEPeerConnection* connection = m_peers[peerId]; if (connection != nullptr) { - std::string ip = UDPSocket::address(address); + std::string ip = udp::Socket::address(address); // validate peer (simple validation really) if (connection->connected() && connection->address() == ip) { @@ -548,7 +548,7 @@ void FNENetwork::clock(uint32_t ms) if (peerId > 0 && (m_peers.find(peerId) != m_peers.end())) { FNEPeerConnection* connection = m_peers[peerId]; if (connection != nullptr) { - std::string ip = UDPSocket::address(address); + std::string ip = udp::Socket::address(address); // validate peer (simple validation really) if (connection->connected() && connection->address() == ip) { @@ -571,7 +571,7 @@ void FNENetwork::clock(uint32_t ms) if (peerId > 0 && (m_peers.find(peerId) != m_peers.end())) { FNEPeerConnection* connection = m_peers[peerId]; if (connection != nullptr) { - std::string ip = UDPSocket::address(address); + std::string ip = udp::Socket::address(address); // validate peer (simple validation really) if (connection->connected() && connection->address() == ip) { @@ -604,7 +604,7 @@ void FNENetwork::clock(uint32_t ms) if (peerId > 0 && (m_peers.find(peerId) != m_peers.end())) { FNEPeerConnection* connection = m_peers[peerId]; if (connection != nullptr) { - std::string ip = UDPSocket::address(address); + std::string ip = udp::Socket::address(address); lookups::AffiliationLookup* aff = m_peerAffiliations[peerId]; // validate peer (simple validation really) @@ -624,7 +624,7 @@ void FNENetwork::clock(uint32_t ms) if (peerId > 0 && (m_peers.find(peerId) != m_peers.end())) { FNEPeerConnection* connection = m_peers[peerId]; if (connection != nullptr) { - std::string ip = UDPSocket::address(address); + std::string ip = udp::Socket::address(address); lookups::AffiliationLookup* aff = m_peerAffiliations[peerId]; // validate peer (simple validation really) @@ -642,7 +642,7 @@ void FNENetwork::clock(uint32_t ms) if (peerId > 0 && (m_peers.find(peerId) != m_peers.end())) { FNEPeerConnection* connection = m_peers[peerId]; if (connection != nullptr) { - std::string ip = UDPSocket::address(address); + std::string ip = udp::Socket::address(address); lookups::AffiliationLookup* aff = m_peerAffiliations[peerId]; // validate peer (simple validation really) @@ -660,7 +660,7 @@ void FNENetwork::clock(uint32_t ms) if (peerId > 0 && (m_peers.find(peerId) != m_peers.end())) { FNEPeerConnection* connection = m_peers[peerId]; if (connection != nullptr) { - std::string ip = UDPSocket::address(address); + std::string ip = udp::Socket::address(address); // validate peer (simple validation really) if (connection->connected() && connection->address() == ip) { @@ -728,7 +728,7 @@ bool FNENetwork::open() m_status = NET_STAT_MST_RUNNING; m_maintainenceTimer.start(); - m_socket = new UDPSocket(m_address, m_port); + m_socket = new udp::Socket(m_address, m_port); // reinitialize the frame queue if (m_frameQueue != nullptr) { diff --git a/src/fne/network/FNENetwork.h b/src/fne/network/FNENetwork.h index 93dd312b..85f8eb86 100644 --- a/src/fne/network/FNENetwork.h +++ b/src/fne/network/FNENetwork.h @@ -76,8 +76,8 @@ namespace network m_currStreamId(0U), m_socketStorage(socketStorage), m_sockStorageLen(sockStorageLen), - m_address(UDPSocket::address(socketStorage)), - m_port(UDPSocket::port(socketStorage)), + m_address(udp::Socket::address(socketStorage)), + m_port(udp::Socket::port(socketStorage)), m_salt(0U), m_connected(false), m_connectionState(NET_STAT_INVALID), diff --git a/src/fne/network/RESTAPI.h b/src/fne/network/RESTAPI.h index b09cd821..8ba496a4 100644 --- a/src/fne/network/RESTAPI.h +++ b/src/fne/network/RESTAPI.h @@ -14,7 +14,6 @@ #define __REST_API_H__ #include "fne/Defines.h" -#include "common/network/UDPSocket.h" #include "common/network/rest/RequestDispatcher.h" #include "common/network/rest/http/HTTPServer.h" #include "common/lookups/RadioIdLookup.h" diff --git a/src/host/Host.Config.cpp b/src/host/Host.Config.cpp index 0a188a22..4571ca79 100644 --- a/src/host/Host.Config.cpp +++ b/src/host/Host.Config.cpp @@ -11,7 +11,7 @@ * */ #include "Defines.h" -#include "common/network/UDPSocket.h" +#include "common/network/udp/Socket.h" #include "modem/port/ModemNullPort.h" #include "modem/port/UARTPort.h" #include "modem/port/PseudoPTYPort.h" diff --git a/src/host/modem/port/UDPPort.cpp b/src/host/modem/port/UDPPort.cpp index 818108d1..2710614f 100644 --- a/src/host/modem/port/UDPPort.cpp +++ b/src/host/modem/port/UDPPort.cpp @@ -47,11 +47,11 @@ UDPPort::UDPPort(const std::string& address, uint16_t modemPort) : assert(!address.empty()); assert(modemPort > 0U); - if (UDPSocket::lookup(address, modemPort, m_addr, m_addrLen) != 0) + if (udp::Socket::lookup(address, modemPort, m_addr, m_addrLen) != 0) m_addrLen = 0U; if (m_addrLen > 0U) { - std::string addrStr = UDPSocket::address(m_addr); + std::string addrStr = udp::Socket::address(m_addr); LogWarning(LOG_HOST, "SECURITY: Remote modem expects IP address; %s for remote modem control", addrStr.c_str()); } } @@ -99,11 +99,11 @@ int UDPPort::read(uint8_t* buffer, uint32_t length) // Add new data to the ring buffer if (ret > 0) { - if (UDPSocket::match(addr, m_addr)) { + if (udp::Socket::match(addr, m_addr)) { m_buffer.addData(data, ret); } else { - std::string addrStr = UDPSocket::address(addr); + std::string addrStr = udp::Socket::address(addr); LogWarning(LOG_HOST, "SECURITY: Remote modem mode encountered invalid IP address; %s", addrStr.c_str()); } } diff --git a/src/host/modem/port/UDPPort.h b/src/host/modem/port/UDPPort.h index e8bf4c8f..4fffee8d 100644 --- a/src/host/modem/port/UDPPort.h +++ b/src/host/modem/port/UDPPort.h @@ -16,7 +16,7 @@ #define __UDP_PORT_H__ #include "Defines.h" -#include "common/network/UDPSocket.h" +#include "common/network/udp/Socket.h" #include "common/RingBuffer.h" #include "modem/port/IModemPort.h" @@ -50,7 +50,7 @@ namespace modem void close() override; protected: - network::UDPSocket m_socket; + network::udp::Socket m_socket; sockaddr_storage m_addr; uint32_t m_addrLen; diff --git a/src/host/network/Network.cpp b/src/host/network/Network.cpp index 8060af4e..fa97e011 100644 --- a/src/host/network/Network.cpp +++ b/src/host/network/Network.cpp @@ -255,7 +255,7 @@ void Network::clock(uint32_t ms) // read message UInt8Array buffer = m_frameQueue->read(length, address, addrLen, &rtpHeader, &fneHeader); if (length > 0) { - if (!UDPSocket::match(m_addr, address)) { + if (!udp::Socket::match(m_addr, address)) { LogError(LOG_NET, "Packet received from an invalid source"); return; } @@ -580,7 +580,7 @@ bool Network::open() if (m_debug) LogMessage(LOG_NET, "Opening Network"); - if (UDPSocket::lookup(m_address, m_port, m_addr, m_addrLen) != 0) { + if (udp::Socket::lookup(m_address, m_port, m_addr, m_addrLen) != 0) { LogMessage(LOG_NET, "Could not lookup the address of the master"); return false; } diff --git a/src/host/network/RESTAPI.h b/src/host/network/RESTAPI.h index 120b62e4..a4cf31d2 100644 --- a/src/host/network/RESTAPI.h +++ b/src/host/network/RESTAPI.h @@ -14,7 +14,6 @@ #define __REST_API_H__ #include "Defines.h" -#include "common/network/UDPSocket.h" #include "common/network/rest/RequestDispatcher.h" #include "common/network/rest/http/HTTPServer.h" #include "common/lookups/RadioIdLookup.h"