diff --git a/groups/bal/balber/balber_berutil.cpp b/groups/bal/balber/balber_berutil.cpp index 7b5e5a121f..7c3421025a 100644 --- a/groups/bal/balber/balber_berutil.cpp +++ b/groups/bal/balber/balber_berutil.cpp @@ -84,6 +84,45 @@ const u::Uint64 BerUtil_64BitFloatingPointMasks:: const u::Uint64 BerUtil_64BitFloatingPointMasks::k_DOUBLE_SIGN_MASK = 0x8000000000000000ULL; + // ===================== + // class ReadRestFunctor + // ===================== + +class ReadRestFunctor { + // A functor for 'string::resize_and_overwrite'. Appends read bytes to the + // buffer. + + // DATA + bsl::streambuf *d_streamBuf; + int d_oldSize; + public: + // CREATORS + ReadRestFunctor(bsl::streambuf *streamBuf, int oldSize); + + // MODIFIERS + size_t operator()(char *buf, size_t newSize); +}; + + // --------------------- + // class ReadRestFunctor + // --------------------- + +// CREATORS +ReadRestFunctor::ReadRestFunctor(bsl::streambuf *streamBuf, int oldSize) +: d_streamBuf(streamBuf) +, d_oldSize(oldSize) +{ +} + +// MODIFIERS +size_t ReadRestFunctor::operator()(char *buf, size_t newSize) +{ + bsl::streamsize nRead = d_streamBuf->sgetn( + buf + d_oldSize, + static_cast(newSize) - d_oldSize); + return static_cast(d_oldSize + nRead); +} + // FREE FUNCTIONS void warnOnce() // The first time this is called, issue an explanatory warning. @@ -1084,14 +1123,31 @@ int BerUtil_StringImpUtil::getStringValue(bsl::string *value, return -1; // RETURN } - value->resize_and_overwrite( - length, - bdlf::MemFnUtil::memFn(&bsl::streambuf::sgetn, streamBuf)); + static const int maxInitialAllocation = 16 * 1024 * 1024; // 16 MB - if (static_cast(length) != value->size()) { + // 'length' could be corrupt or invalid, so we limit the initial buffer. + // On success the remaining bytes are read via a second pass. + int initialLength = length < maxInitialAllocation ? length + : maxInitialAllocation; + + // Read no more than 'maxInitialAllocation' + value->resize_and_overwrite(initialLength, + bdlf::MemFnUtil::memFn(&bsl::streambuf::sgetn, + streamBuf)); + if (static_cast(initialLength) != value->size()) { return -1; // RETURN } + if (length > initialLength) { + // 'length' > 'maxInitialAllocation'. Read the rest. + value->resize_and_overwrite(length, + u::ReadRestFunctor(streamBuf, + initialLength)); + if (static_cast(length) != value->size()) { + return -1; // RETURN + } + } + return 0; } diff --git a/groups/bal/balber/balber_berutil.t.cpp b/groups/bal/balber/balber_berutil.t.cpp index a0824a4ec9..4814abb359 100644 --- a/groups/bal/balber/balber_berutil.t.cpp +++ b/groups/bal/balber/balber_berutil.t.cpp @@ -16072,6 +16072,54 @@ int main(int argc, char *argv[]) LOOP3_ASSERT(LINE, LEN, numBytesConsumed, LEN == numBytesConsumed); } + + if (verbose) { cout << "\nDecode truncated string" << endl; } + { + const int length = 1024; // the claimed length + bdlsb::MemOutStreamBuf osb; + ASSERT(SUCCESS == Util::putLength(&osb, length)); + // Write 'length - 10' bytes + ASSERT(length > 10); + for (int n = length - 10; n > 0; --n) { + ASSERT(osb.sputc('X') == 'X'); + } + + int numBytesConsumed = 0; + bdlsb::FixedMemInStreamBuf isb(osb.data(), osb.length()); + bsl::string val; + ASSERT(SUCCESS != Util::getValue(&isb, &val, &numBytesConsumed)); + } + + static const int notHugeStringMax = 16 * 1024 * 1024; // 16 MB + + if (verbose) { cout << "\nDecode huge string" << endl; } + { + const bsl::string hugeStr(notHugeStringMax + 256, 'X'); + bdlsb::MemOutStreamBuf osb; + ASSERT(0 == Util::putValue(&osb, hugeStr)); + + int numBytesConsumed = 0; + bdlsb::FixedMemInStreamBuf isb(osb.data(), osb.length()); + bsl::string val; + ASSERT(SUCCESS == Util::getValue(&isb, &val, &numBytesConsumed)); + ASSERT(val.length() == hugeStr.length()); + ASSERT(val == hugeStr); + } + + if (verbose) { cout << "\nDecode truncated huge string" << endl; } + { + bdlsb::MemOutStreamBuf osb; + ASSERT(SUCCESS == Util::putLength(&osb, notHugeStringMax + 20)); + // Write 10 bytes less than claimed + for (int n = notHugeStringMax + 10; n > 0; --n) { + ASSERT(osb.sputc('X') == 'X'); + } + + int numBytesConsumed = 0; + bdlsb::FixedMemInStreamBuf isb(osb.data(), osb.length()); + bsl::string val; + ASSERT(SUCCESS != Util::getValue(&isb, &val, &numBytesConsumed)); + } } break; case 12: { // --------------------------------------------------------------------