Skip to content

Commit 3ccb6e1

Browse files
committed
recursive_mutex -> mutex
1 parent ab96235 commit 3ccb6e1

File tree

2 files changed

+99
-49
lines changed

2 files changed

+99
-49
lines changed

src/herder/TransactionQueue.cpp

Lines changed: 56 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ isDuplicateTx(TransactionFrameBasePtr oldTx, TransactionFrameBasePtr newTx)
270270
bool
271271
TransactionQueue::sourceAccountPending(AccountID const& accountID) const
272272
{
273-
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
273+
std::lock_guard<std::mutex> guard(mTxQueueMutex);
274274
return mAccountStates.find(accountID) != mAccountStates.end();
275275
}
276276

@@ -334,7 +334,7 @@ TransactionQueue::canAdd(
334334
std::vector<std::pair<TransactionFrameBasePtr, bool>>& txsToEvict)
335335
{
336336
ZoneScoped;
337-
if (isBanned(tx->getFullHash()))
337+
if (isBannedInternal(tx->getFullHash()))
338338
{
339339
return AddResult(
340340
TransactionQueue::AddResultCode::ADD_STATUS_TRY_AGAIN_LATER);
@@ -436,7 +436,7 @@ TransactionQueue::canAdd(
436436
mTxQueueLimiter.canAddTx(tx, currentTx, txsToEvict, ledgerVersion);
437437
if (!canAddRes.first)
438438
{
439-
ban({tx});
439+
banInternal({tx});
440440
if (canAddRes.second != 0)
441441
{
442442
AddResult result(TransactionQueue::AddResultCode::ADD_STATUS_ERROR,
@@ -454,10 +454,6 @@ TransactionQueue::canAdd(
454454
// This is done so minSeqLedgerGap is validated against the next
455455
// ledgerSeq, which is what will be used at apply time
456456
++ls.getLedgerHeader().currentToModify().ledgerSeq;
457-
// TODO: ^^ I think this is the right thing to do. Was previously the
458-
// commented out line below.
459-
// ls.getLedgerHeader().currentToModify().ledgerSeq =
460-
// mApp.getLedgerManager().getLastClosedLedgerNum() + 1;
461457
}
462458

463459
auto txResult =
@@ -645,7 +641,7 @@ TransactionQueue::AddResult
645641
TransactionQueue::tryAdd(TransactionFrameBasePtr tx, bool submittedFromSelf)
646642
{
647643
ZoneScoped;
648-
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
644+
std::lock_guard<std::mutex> guard(mTxQueueMutex);
649645

650646
auto c1 =
651647
tx->getEnvelope().type() == ENVELOPE_TYPE_TX_FEE_BUMP &&
@@ -701,8 +697,9 @@ TransactionQueue::tryAdd(TransactionFrameBasePtr tx, bool submittedFromSelf)
701697
// make space so that we can add this transaction
702698
// this will succeed as `canAdd` ensures that this is the case
703699
mTxQueueLimiter.evictTransactions(
704-
txsToEvict, *tx,
705-
[&](TransactionFrameBasePtr const& txToEvict) { ban({txToEvict}); });
700+
txsToEvict, *tx, [&](TransactionFrameBasePtr const& txToEvict) {
701+
banInternal({txToEvict});
702+
});
706703
mTxQueueLimiter.addTransaction(tx);
707704
mKnownTxHashes[tx->getFullHash()] = tx;
708705

@@ -806,7 +803,14 @@ void
806803
TransactionQueue::ban(Transactions const& banTxs)
807804
{
808805
ZoneScoped;
809-
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
806+
std::lock_guard<std::mutex> guard(mTxQueueMutex);
807+
banInternal(banTxs);
808+
}
809+
810+
void
811+
TransactionQueue::banInternal(Transactions const& banTxs)
812+
{
813+
ZoneScoped;
810814
auto& bannedFront = mBannedTransactions.front();
811815

812816
// Group the transactions by source account and ban all the transactions
@@ -852,7 +856,7 @@ TransactionQueue::AccountState
852856
TransactionQueue::getAccountTransactionQueueInfo(
853857
AccountID const& accountID) const
854858
{
855-
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
859+
std::lock_guard<std::mutex> guard(mTxQueueMutex);
856860
auto i = mAccountStates.find(accountID);
857861
if (i == std::end(mAccountStates))
858862
{
@@ -864,7 +868,7 @@ TransactionQueue::getAccountTransactionQueueInfo(
864868
size_t
865869
TransactionQueue::countBanned(int index) const
866870
{
867-
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
871+
std::lock_guard<std::mutex> guard(mTxQueueMutex);
868872
return mBannedTransactions[index].size();
869873
}
870874
#endif
@@ -939,7 +943,13 @@ TransactionQueue::shift()
939943
bool
940944
TransactionQueue::isBanned(Hash const& hash) const
941945
{
942-
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
946+
std::lock_guard<std::mutex> guard(mTxQueueMutex);
947+
return isBannedInternal(hash);
948+
}
949+
950+
bool
951+
TransactionQueue::isBannedInternal(Hash const& hash) const
952+
{
943953
return std::any_of(
944954
std::begin(mBannedTransactions), std::end(mBannedTransactions),
945955
[&](UnorderedSet<Hash> const& transactions) {
@@ -951,7 +961,14 @@ TxFrameList
951961
TransactionQueue::getTransactions(LedgerHeader const& lcl) const
952962
{
953963
ZoneScoped;
954-
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
964+
std::lock_guard<std::mutex> guard(mTxQueueMutex);
965+
return getTransactionsInternal(lcl);
966+
}
967+
968+
TxFrameList
969+
TransactionQueue::getTransactionsInternal(LedgerHeader const& lcl) const
970+
{
971+
ZoneScoped;
955972
TxFrameList txs;
956973

957974
uint32_t const nextLedgerSeq = lcl.ledgerSeq + 1;
@@ -972,7 +989,7 @@ TransactionFrameBaseConstPtr
972989
TransactionQueue::getTx(Hash const& hash) const
973990
{
974991
ZoneScoped;
975-
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
992+
std::lock_guard<std::mutex> guard(mTxQueueMutex);
976993
auto it = mKnownTxHashes.find(hash);
977994
if (it != mKnownTxHashes.end())
978995
{
@@ -1184,6 +1201,8 @@ SorobanTransactionQueue::broadcastSome()
11841201
size_t
11851202
SorobanTransactionQueue::getMaxQueueSizeOps() const
11861203
{
1204+
ZoneScoped;
1205+
std::lock_guard<std::mutex> guard(mTxQueueMutex);
11871206
if (protocolVersionStartsFrom(
11881207
mBucketSnapshot->getLedgerHeader().ledgerVersion,
11891208
SOROBAN_PROTOCOL_VERSION))
@@ -1264,7 +1283,7 @@ ClassicTransactionQueue::broadcastSome()
12641283
std::make_shared<DexLimitingLaneConfig>(opsToFlood, dexOpsToFlood),
12651284
mBroadcastSeed);
12661285
queue.visitTopTxs(txsToBroadcast, visitor, mBroadcastOpCarryover);
1267-
ban(banningTxs);
1286+
banInternal(banningTxs);
12681287
// carry over remainder, up to MAX_OPS_PER_TX ops
12691288
// reason is that if we add 1 next round, we can flood a "worst case fee
12701289
// bump" tx
@@ -1277,15 +1296,12 @@ ClassicTransactionQueue::broadcastSome()
12771296
}
12781297

12791298
void
1280-
TransactionQueue::broadcast(bool fromCallback)
1299+
TransactionQueue::broadcast(bool fromCallback,
1300+
std::lock_guard<std::mutex> const& guard)
12811301
{
12821302
// Must be called from the main thread due to the use of `mBroadcastTimer`
12831303
releaseAssert(threadIsMain());
12841304

1285-
// NOTE: Although this is not a public function, it can be called from
1286-
// `mBroadcastTimer` and so it needs to be synchronized.
1287-
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
1288-
12891305
if (mShutdown || (!fromCallback && mWaiting))
12901306
{
12911307
return;
@@ -1317,7 +1333,14 @@ TransactionQueue::broadcast(bool fromCallback)
13171333
}
13181334

13191335
void
1320-
TransactionQueue::rebroadcast()
1336+
TransactionQueue::broadcast(bool fromCallback)
1337+
{
1338+
std::lock_guard<std::mutex> guard(mTxQueueMutex);
1339+
broadcast(fromCallback, guard);
1340+
}
1341+
1342+
void
1343+
TransactionQueue::rebroadcast(std::lock_guard<std::mutex> const& guard)
13211344
{
13221345
// For `broadcast` call
13231346
releaseAssert(threadIsMain());
@@ -1331,14 +1354,14 @@ TransactionQueue::rebroadcast()
13311354
as.mTransaction->mBroadcasted = false;
13321355
}
13331356
}
1334-
broadcast(false);
1357+
broadcast(false, guard);
13351358
}
13361359

13371360
void
13381361
TransactionQueue::shutdown()
13391362
{
13401363
releaseAssert(threadIsMain());
1341-
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
1364+
std::lock_guard<std::mutex> guard(mTxQueueMutex);
13421365
mShutdown = true;
13431366
mBroadcastTimer.cancel();
13441367
}
@@ -1351,7 +1374,7 @@ TransactionQueue::update(
13511374
{
13521375
ZoneScoped;
13531376
releaseAssert(threadIsMain());
1354-
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
1377+
std::lock_guard<std::mutex> guard(mTxQueueMutex);
13551378

13561379
mValidationSnapshot =
13571380
std::make_shared<ImmutableValidationSnapshot>(mAppConn);
@@ -1361,11 +1384,11 @@ TransactionQueue::update(
13611384
removeApplied(applied);
13621385
shift();
13631386

1364-
auto txs = getTransactions(lcl);
1387+
auto txs = getTransactionsInternal(lcl);
13651388
auto invalidTxs = filterInvalidTxs(txs);
1366-
ban(invalidTxs);
1389+
banInternal(invalidTxs);
13671390

1368-
rebroadcast();
1391+
rebroadcast(guard);
13691392
}
13701393

13711394
static bool
@@ -1409,14 +1432,14 @@ TransactionQueue::isFiltered(TransactionFrameBasePtr tx) const
14091432
size_t
14101433
TransactionQueue::getQueueSizeOps() const
14111434
{
1412-
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
1435+
std::lock_guard<std::mutex> guard(mTxQueueMutex);
14131436
return mTxQueueLimiter.size();
14141437
}
14151438

14161439
std::optional<int64_t>
14171440
TransactionQueue::getInQueueSeqNum(AccountID const& account) const
14181441
{
1419-
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
1442+
std::lock_guard<std::mutex> guard(mTxQueueMutex);
14201443
auto stateIter = mAccountStates.find(account);
14211444
if (stateIter == mAccountStates.end())
14221445
{
@@ -1433,7 +1456,8 @@ TransactionQueue::getInQueueSeqNum(AccountID const& account) const
14331456
size_t
14341457
ClassicTransactionQueue::getMaxQueueSizeOps() const
14351458
{
1436-
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
1459+
ZoneScoped;
1460+
std::lock_guard<std::mutex> guard(mTxQueueMutex);
14371461
auto res = mTxQueueLimiter.maxScaledLedgerResources(false);
14381462
releaseAssert(res.size() == NUM_CLASSIC_TX_RESOURCES);
14391463
return res.getVal(Resource::Type::OPERATIONS);

src/herder/TransactionQueue.h

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -210,10 +210,15 @@ class TransactionQueue
210210
virtual std::pair<Resource, std::optional<Resource>>
211211
getMaxResourcesToFloodThisPeriod() const = 0;
212212
virtual bool broadcastSome() = 0;
213-
virtual int getFloodPeriod() const = 0;
214213
virtual bool allowTxBroadcast(TimestampedTx const& tx) = 0;
215214

215+
// TODO: Explain that there's an overload that takes a guard because this
216+
// function is called internally, and also scheduled on a timer. Any async
217+
// call should call the first overload (which grabs a lock), and any
218+
// internal call should call the second overload (which enforces that the
219+
// lock is already held).
216220
void broadcast(bool fromCallback);
221+
void broadcast(bool fromCallback, std::lock_guard<std::mutex> const& guard);
217222
// broadcasts a single transaction
218223
enum class BroadcastStatus
219224
{
@@ -234,6 +239,12 @@ class TransactionQueue
234239

235240
bool isFiltered(TransactionFrameBasePtr tx) const;
236241

242+
// TODO: Docs
243+
// Protected versions of public functions that contain the actual
244+
// implementation so they can be called internally when the lock is already
245+
// held.
246+
void banInternal(Transactions const& banTxs);
247+
237248
// Snapshots to use for transaction validation
238249
ImmutableValidationSnapshotPtr mValidationSnapshot;
239250
SearchableSnapshotConstPtr mBucketSnapshot;
@@ -245,7 +256,7 @@ class TransactionQueue
245256

246257
size_t mBroadcastSeed;
247258

248-
mutable std::recursive_mutex mTxQueueMutex;
259+
mutable std::mutex mTxQueueMutex;
249260

250261
private:
251262
AppConnector& mAppConn;
@@ -259,10 +270,24 @@ class TransactionQueue
259270
*/
260271
void shift();
261272

262-
void rebroadcast();
273+
// TODO: Explain that this takes a lock guard due to the `broadcast` call
274+
// that it makes.
275+
void rebroadcast(std::lock_guard<std::mutex> const& guard);
276+
277+
// TODO: Docs
278+
// Private versions of public functions that contain the actual
279+
// implementation so they can be called internally when the lock is already
280+
// held.
281+
bool isBannedInternal(Hash const& hash) const;
282+
TxFrameList getTransactionsInternal(LedgerHeader const& lcl) const;
283+
284+
virtual int getFloodPeriod() const = 0;
263285

264286
#ifdef BUILD_TESTS
265287
public:
288+
// TODO: These tests invoke protected/private functions directly that assume
289+
// things are properly locked. I need to make sure these tests operate in a
290+
// thread-safe manner or change them to not require private member access.
266291
friend class TransactionQueueTest;
267292

268293
size_t getQueueSizeOps() const;
@@ -278,19 +303,13 @@ class SorobanTransactionQueue : public TransactionQueue
278303
SearchableSnapshotConstPtr bucketSnapshot,
279304
uint32 pendingDepth, uint32 banDepth,
280305
uint32 poolLedgerMultiplier);
281-
int
282-
getFloodPeriod() const override
283-
{
284-
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
285-
return mValidationSnapshot->getConfig().FLOOD_SOROBAN_TX_PERIOD_MS;
286-
}
287306

288307
size_t getMaxQueueSizeOps() const override;
289308
#ifdef BUILD_TESTS
290309
void
291310
clearBroadcastCarryover()
292311
{
293-
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
312+
std::lock_guard<std::mutex> guard(mTxQueueMutex);
294313
mBroadcastOpCarryover.clear();
295314
mBroadcastOpCarryover.resize(1, Resource::makeEmptySoroban());
296315
}
@@ -307,6 +326,13 @@ class SorobanTransactionQueue : public TransactionQueue
307326
{
308327
return true;
309328
}
329+
330+
int
331+
getFloodPeriod() const override
332+
{
333+
return mValidationSnapshot->getConfig().FLOOD_SOROBAN_TX_PERIOD_MS;
334+
}
335+
310336
};
311337

312338
class ClassicTransactionQueue : public TransactionQueue
@@ -317,13 +343,6 @@ class ClassicTransactionQueue : public TransactionQueue
317343
uint32 pendingDepth, uint32 banDepth,
318344
uint32 poolLedgerMultiplier);
319345

320-
int
321-
getFloodPeriod() const override
322-
{
323-
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
324-
return mValidationSnapshot->getConfig().FLOOD_TX_PERIOD_MS;
325-
}
326-
327346
size_t getMaxQueueSizeOps() const override;
328347

329348
private:
@@ -335,6 +354,13 @@ class ClassicTransactionQueue : public TransactionQueue
335354
virtual bool broadcastSome() override;
336355
std::vector<Resource> mBroadcastOpCarryover;
337356
virtual bool allowTxBroadcast(TimestampedTx const& tx) override;
357+
358+
int
359+
getFloodPeriod() const override
360+
{
361+
return mValidationSnapshot->getConfig().FLOOD_TX_PERIOD_MS;
362+
}
363+
338364
};
339365

340366
extern std::array<const char*,

0 commit comments

Comments
 (0)