Improve WebSocket mask handling (#875)

This commit is contained in:
Martin Chang 2021-05-29 15:11:41 +08:00 committed by GitHub
parent 1bddbb117a
commit ffc410a66e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 105 additions and 28 deletions

View File

@ -64,7 +64,7 @@ if (BUILD_DROGON_SHARED)
SOVERSION ${DROGON_MAJOR_VERSION})
target_link_libraries(${PROJECT_NAME} PUBLIC Threads::Threads)
if (WIN32)
target_link_libraries(${PROJECT_NAME} PUBLIC Rpcrt4 ws2_32)
target_link_libraries(${PROJECT_NAME} PUBLIC Rpcrt4 ws2_32 crypt32 Advapi32)
if (CMAKE_CXX_COMPILER_ID MATCHES MSVC)
# Ignore MSVC C4251 and C4275 warning of exporting std objects with no dll export
# We export class to facilitate maintenance, thus if you compile

View File

@ -163,5 +163,18 @@ DROGON_EXPORT void replaceAll(std::string &s,
const std::string &from,
const std::string &to);
/**
* @brief Generates cryptographically secure random bytes.
*
* @param ptr the pointer which the random bytes are stored to
* @param size number of bytes to generate
*
* @return true if generation is successfull. False otherwise
*
* @note DO NOT abuse this function. Especially if Drogon is built without
* OpenSSL. Entropy running low is a real issue.
*/
DROGON_EXPORT bool secureRandomBytes(void *ptr, size_t size);
} // namespace utils
} // namespace drogon

View File

@ -17,6 +17,7 @@
#include <drogon/config.h>
#ifdef OpenSSL_FOUND
#include <openssl/md5.h>
#include <openssl/rand.h>
#else
#include "ssl_funcs/Md5.h"
#endif
@ -28,8 +29,10 @@
#include <Rpc.h>
#include <direct.h>
#include <io.h>
#include <ntsecapi.h>
#else
#include <uuid.h>
#include <unistd.h>
#endif
#include <zlib.h>
#include <iomanip>
@ -44,9 +47,6 @@
#include <cstdlib>
#include <stdio.h>
#include <string.h>
#ifndef _WIN32
#include <unistd.h>
#endif
#include <sys/stat.h>
#include <fcntl.h>
#include <stdarg.h>
@ -1199,5 +1199,48 @@ void replaceAll(std::string &s, const std::string &from, const std::string &to)
}
}
/**
* @brief Generates `size` random bytes from the systems random source and
* stores them into `ptr`.
*/
static bool systemRandomBytes(void *ptr, size_t size)
{
#if defined(__BSD__) || defined(__APPLE__)
arc4random_buf(ptr, size);
return true;
#elif defined(__linux__) && \
((defined(__GLIBC__) && \
(__GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 25))))
return getentropy(ptr, size) != -1;
#elif defined(_WIN32) // Windows
return RtlGenRandom(ptr, size);
#elif defined(__unix__) // fallback to /dev/urandom for other/old UNIX
static std::unique_ptr<FILE, std::function<void(FILE *)> > fptr(
fopen("/dev/urandom", "rb"), [](FILE *ptr) {
if (ptr != nullptr)
fclose(ptr);
});
if (fptr == nullptr)
{
LOG_FATAL << "Failed to open /dev/urandom for randomness";
abort();
}
if (fread(ptr, 1, size, fptr.get()) != 0)
return true;
#endif
return false;
}
bool secureRandomBytes(void *ptr, size_t size)
{
#ifdef OpenSSL_FOUND
if (RAND_bytes((unsigned char *)ptr, size) == 0)
return true;
#endif
if (systemRandomBytes(ptr, size))
return true;
return false;
}
} // namespace utils
} // namespace drogon

View File

@ -23,7 +23,8 @@ WebSocketConnectionImpl::WebSocketConnectionImpl(
: tcpConnectionPtr_(conn),
localAddr_(conn->localAddr()),
peerAddr_(conn->peerAddr()),
isServer_(isServer)
isServer_(isServer),
usingMask_(false)
{
}
WebSocketConnectionImpl::~WebSocketConnectionImpl()
@ -105,16 +106,42 @@ void WebSocketConnectionImpl::sendWsData(const char *msg,
}
if (!isServer_)
{
// Add masking key;
static std::once_flag once;
std::call_once(once, []() {
std::srand(static_cast<unsigned int>(time(nullptr)));
});
int random = std::rand();
int random;
// Use the cached randomness if no one else is also using it. Otherwise
// generate one from scratch.
if (!usingMask_.exchange(true))
{
if (masks_.empty())
{
masks_.resize(16);
bool status =
utils::secureRandomBytes(masks_.data(),
masks_.size() * sizeof(uint32_t));
if (status == false)
{
LOG_ERROR << "Failed to generate random numbers for "
"WebSocket mask";
abort();
}
}
random = masks_.back();
masks_.pop_back();
usingMask_ = false;
}
else
{
bool status = utils::secureRandomBytes(&random, sizeof(random));
if (status == false)
{
LOG_ERROR
<< "Failed to generate random numbers for WebSocket mask";
abort();
}
}
bytesFormatted[1] = (bytesFormatted[1] | 0x80);
bytesFormatted.resize(indexStartRawData + 4 + len);
*((int *)&bytesFormatted[indexStartRawData]) = random;
memcpy(&bytesFormatted[indexStartRawData], &random, sizeof(random));
for (size_t i = 0; i < len; ++i)
{
bytesFormatted[indexStartRawData + 4 + i] =

View File

@ -108,6 +108,8 @@ class WebSocketConnectionImpl final
bool isServer_{true};
WebSocketMessageParser parser_;
trantor::TimerId pingTimerId_{trantor::InvalidTimerId};
std::vector<uint32_t> masks_;
std::atomic<bool> usingMask_;
std::function<void(std::string &&,
const WebSocketConnectionImplPtr &,

View File

@ -308,8 +308,14 @@ void Sqlite3Connection::disconnect()
std::promise<int> pro;
auto f = pro.get_future();
auto thisPtr = shared_from_this();
loopThread_.getLoop()->runInLoop([thisPtr, &pro]() {
thisPtr->connectionPtr_.reset();
std::weak_ptr<Sqlite3Connection> weakPtr = thisPtr;
loopThread_.getLoop()->runInLoop([weakPtr, &pro]() {
{
auto thisPtr = weakPtr.lock();
if (!thisPtr)
return;
thisPtr->connectionPtr_.reset();
}
pro.set_value(1);
});
f.get();

View File

@ -1928,20 +1928,6 @@ int main(int argc, char **argv)
#if USE_SQLITE3
sqlite3Client = DbClient::newSqlite3Client("filename=:memory:", 1);
#endif
int testStatus = test::run(argc, argv);
std::this_thread::sleep_for(0.008s);
// Destruct the clients before event loop shutdown
#if USE_MYSQL
mysqlClient.reset();
#endif
#if USE_POSTGRESQL
postgreClient.reset();
#endif
#if USE_SQLITE3
sqlite3Client.reset();
#endif
return testStatus;
}