Skip to content

Commit 4a0e544

Browse files
committed
fix: improve error handling and validation in various components
1 parent e9c83b9 commit 4a0e544

6 files changed

Lines changed: 48 additions & 15 deletions

File tree

deps/rocketmq/src/ClientRemotingProcessor.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,10 @@ RemotingCommand* ClientRemotingProcessor::processRequest(TcpTransportPtr channel
6363

6464
RemotingCommand* ClientRemotingProcessor::checkTransactionState(const std::string& addr, RemotingCommand* request) {
6565
auto* requestHeader = request->decodeCommandCustomHeader<CheckTransactionStateRequestHeader>();
66-
assert(requestHeader != nullptr);
66+
if (requestHeader == nullptr) {
67+
LOG_ERROR_NEW("Failed to decode CheckTransactionStateRequestHeader");
68+
return RemotingCommand::createResponseCommand(ResponseCode::SYSTEM_ERROR, "Invalid request header");
69+
}
6770

6871
auto requestBody = request->body();
6972
if (requestBody != nullptr && requestBody->size() > 0) {

deps/rocketmq/src/message/MessageClientIDSetter.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,22 @@ void MessageClientIDSetter::setStartTime(uint64_t millis) {
7070
// Although not defined, this is almost always an integral value holding the number of seconds
7171
// (not counting leap seconds) since 00:00, Jan 1 1970 UTC, corresponding to POSIX time.
7272
std::time_t tmNow = millis / 1000;
73-
std::tm* ptmNow = std::localtime(&tmNow); // may not be thread-safe
73+
std::tm tmResult;
74+
#ifdef WIN32
75+
if (localtime_s(&tmResult, &tmNow) != 0) {
76+
start_time_ = millis;
77+
next_start_time_ = millis;
78+
return;
79+
}
80+
std::tm* ptmNow = &tmResult;
81+
#else
82+
std::tm* ptmNow = localtime_r(&tmNow, &tmResult);
83+
if (ptmNow == nullptr) {
84+
start_time_ = millis;
85+
next_start_time_ = millis;
86+
return;
87+
}
88+
#endif
7489

7590
std::tm curMonthBegin = {0};
7691
curMonthBegin.tm_year = ptmNow->tm_year; // since 1900

deps/rocketmq/src/message/MessageDecoder.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ MessageExtPtr MessageDecoder::decode(ByteBuffer& byteBuffer, bool readBody, bool
137137

138138
// 12 STOREHOST
139139
int storehostIPLength = (sysFlag & MessageSysFlag::STOREHOST_V6_FLAG) == 0 ? kIPv4AddrSize : kIPv6AddrSize;
140-
ByteArray storeHost(bornHostLength);
140+
ByteArray storeHost(storehostIPLength);
141141
byteBuffer.get(storeHost, 0, storehostIPLength);
142142
int32_t storePort = byteBuffer.getInt();
143143
msgExt->set_store_host(GetSockaddrPtr(IPPortToSockaddr(storeHost, static_cast<uint16_t>(storePort))));
@@ -154,6 +154,15 @@ MessageExtPtr MessageDecoder::decode(ByteBuffer& byteBuffer, bool readBody, bool
154154
int uncompress_failed = false;
155155
int32_t bodyLen = byteBuffer.getInt();
156156
if (bodyLen > 0) {
157+
if (bodyLen > byteBuffer.remaining()) {
158+
LOG_ERROR_NEW("Invalid bodyLen: {} exceeds buffer size: {}", bodyLen, byteBuffer.remaining());
159+
return nullptr;
160+
}
161+
const int32_t MAX_BODY_SIZE = 4 * 1024 * 1024;
162+
if (bodyLen > MAX_BODY_SIZE) {
163+
LOG_ERROR_NEW("bodyLen {} exceeds maximum allowed size {}", bodyLen, MAX_BODY_SIZE);
164+
return nullptr;
165+
}
157166
if (readBody) {
158167
ByteArray body(byteBuffer.array() + byteBuffer.arrayOffset() + byteBuffer.position(), bodyLen);
159168
byteBuffer.position(byteBuffer.position() + bodyLen);

deps/rocketmq/src/transport/EventLoop.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,15 @@ EventLoop::~EventLoop() {
6868
}
6969

7070
void EventLoop::start() {
71-
if (!is_running_) {
72-
is_running_ = true;
71+
bool expected = false;
72+
if (is_running_.compare_exchange_strong(expected, true)) {
7373
loop_thread_.start();
7474
}
7575
}
7676

7777
void EventLoop::stop() {
78-
if (is_running_) {
79-
is_running_ = false;
78+
bool expected = true;
79+
if (is_running_.compare_exchange_strong(expected, false)) {
8080
loop_thread_.join();
8181
}
8282
}

deps/rocketmq/src/transport/EventLoop.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <event2/bufferevent.h>
2222
#include <event2/event.h>
2323

24+
#include <atomic> // std::atomic
2425
#include <functional> // std::function
2526
#include <memory> // std::unique_ptr
2627

@@ -55,7 +56,7 @@ class EventLoop : public noncopyable {
5556
struct event_base* event_base_;
5657
thread loop_thread_;
5758

58-
bool is_running_; // aotmic is unnecessary
59+
std::atomic<bool> is_running_{false};
5960
};
6061

6162
class TcpTransport;

deps/rocketmq/src/transport/SocketUtil.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616
*/
1717
#include "SocketUtil.h"
1818

19-
#include <cstdlib> // std::abort
20-
#include <cstring> // std::memcpy, std::memset
19+
#include <algorithm> // std::all_of
20+
#include <cctype> // std::isdigit
21+
#include <cstdlib> // std::abort
22+
#include <cstring> // std::memcpy, std::memset
2123

2224
#include <iostream>
2325
#include <limits>
24-
#include <memory> // std::unique_ptr
26+
#include <memory> // std::unique_ptr
2527
#include <stdexcept> // std::invalid_argument, std::runtime_error
2628
#include <string>
2729

@@ -46,7 +48,7 @@ void SafeMemcpy(void* dest, size_t dest_size, const void* src, size_t src_size)
4648
if (src_size > dest_size) {
4749
throw std::invalid_argument("source size exceeds destination buffer size");
4850
}
49-
std::memcpy(dest, src, src_size);
51+
std::memmove(dest, src, src_size);
5052
}
5153

5254
std::unique_ptr<sockaddr_storage> SockaddrToStorage(const sockaddr* src) {
@@ -134,15 +136,18 @@ std::unique_ptr<sockaddr_storage> StringToSockaddr(const std::string& addr) {
134136
uint16_t port_num = 0;
135137

136138
if (!port_str.empty()) {
139+
if (!std::all_of(port_str.begin(), port_str.end(), [](unsigned char c) { return std::isdigit(c); })) {
140+
throw std::invalid_argument("port contains non-digit characters: " + port_str);
141+
}
137142
try {
138143
uint32_t n = std::stoul(port_str);
144+
if (n == 0) {
145+
throw std::invalid_argument("port cannot be zero");
146+
}
139147
if (n > std::numeric_limits<uint16_t>::max()) {
140148
throw std::out_of_range("port is too large: " + std::to_string(n) +
141149
" (max: " + std::to_string(std::numeric_limits<uint16_t>::max()) + ")");
142150
}
143-
if (n == 0) {
144-
throw std::invalid_argument("port cannot be zero");
145-
}
146151
port_num = htons(static_cast<uint16_t>(n));
147152
} catch (const std::exception& e) {
148153
throw std::invalid_argument("invalid port: " + port_str + " (" + e.what() + ")");

0 commit comments

Comments
 (0)