diff --git a/src/common/network/NetRPC.cpp b/src/common/network/NetRPC.cpp index 6005eccc..3368d3ea 100644 --- a/src/common/network/NetRPC.cpp +++ b/src/common/network/NetRPC.cpp @@ -25,7 +25,7 @@ using namespace network::frame; #include // --------------------------------------------------------------------------- -// Public Class Members +// Constants // --------------------------------------------------------------------------- #define REPLY_WAIT 200 // 200ms diff --git a/src/common/network/PacketBuffer.cpp b/src/common/network/PacketBuffer.cpp index 48a4bb78..95f5e35b 100644 --- a/src/common/network/PacketBuffer.cpp +++ b/src/common/network/PacketBuffer.cpp @@ -17,6 +17,13 @@ using namespace compress; #include +// --------------------------------------------------------------------------- +// Constants +// --------------------------------------------------------------------------- + +#define MAX_FRAGMENT_SIZE 8192 * 1024 // 8MB max + + // --------------------------------------------------------------------------- // Public Class Members // --------------------------------------------------------------------------- @@ -59,6 +66,14 @@ bool PacketBuffer::decode(const uint8_t* data, uint8_t** message, uint32_t* outL uint32_t size = GET_UINT32(data, 0U); uint32_t compressedSize = GET_UINT32(data, 4U); + // make sure we can't exceed max fragment size -- prevent potential DOS attack by sending + // enormous fragment sizes + if (size > MAX_FRAGMENT_SIZE || compressedSize > MAX_FRAGMENT_SIZE) { + LogError(LOG_NET, "%s, fragment size exceeds maximum. BUGBUG.", m_name); + delete frag; + return false; + } + frag->size = size; frag->compressedSize = compressedSize; } diff --git a/src/common/network/udp/Socket.cpp b/src/common/network/udp/Socket.cpp index 748f1226..f607de18 100644 --- a/src/common/network/udp/Socket.cpp +++ b/src/common/network/udp/Socket.cpp @@ -325,6 +325,12 @@ ssize_t Socket::read(uint8_t* buffer, uint32_t length, sockaddr_storage& address // does the network packet contain the appropriate magic leader? uint16_t magic = GET_UINT16(buffer, 0U); if (magic == AES_WRAPPED_PCKT_MAGIC) { + // prevent malicious packets that are too short + if (len < 2U + crypto::AES::BLOCK_BYTES_LEN) { + LogError(LOG_NET, "Encrypted packet too short"); + return -1; + } + uint32_t cryptedLen = (len - 2U) * sizeof(uint8_t); uint8_t* cryptoBuffer = buffer + 2U;