diff --git a/src/avalanche/processor.h b/src/avalanche/processor.h --- a/src/avalanche/processor.h +++ b/src/avalanche/processor.h @@ -96,7 +96,13 @@ const AnyVoteItem &getVoteItem() const { return item; } }; -struct VoteMapComparator { +class VoteMapComparator { + const CTxMemPool *mempool{nullptr}; // Non-owning; null means txs sort by TxId + +public: + VoteMapComparator() {} + VoteMapComparator(const CTxMemPool *mempoolIn) : mempool(mempoolIn) {} + bool operator()(const AnyVoteItem &lhs, const AnyVoteItem &rhs) const { // If the variants are of different types, sort them by variant index if (lhs.index() != rhs.index()) { @@ -112,8 +118,36 @@ // Reverse ordering so we get the highest work first return CBlockIndexWorkComparator()(rhs, lhs); }, - [](const CTransactionRef &lhs, const CTransactionRef &rhs) { - return lhs->GetId() < rhs->GetId(); + [this](const CTransactionRef &lhs, const CTransactionRef &rhs) { + const TxId &lhsTxId = lhs->GetId(); + const TxId &rhsTxId = rhs->GetId(); + + // If there is no mempool, sort by TxId. Note that polling + // for txs is currently not supported if there is no mempool, + // so this is only a safety net.
+ if (!mempool) { + return lhsTxId < rhsTxId; + } + + LOCK(mempool->cs); // NOTE(review): taken on every comparison; confirm lock order with callers holding the vote map lock + + auto lhsOptIter = mempool->GetIter(lhsTxId); // NOTE(review): result depends on live mempool state, yet std::map needs a fixed key order; confirm + auto rhsOptIter = mempool->GetIter(rhsTxId); + + // If neither transaction is in the mempool, order them by TxId + if (!lhsOptIter && !rhsOptIter) { + return lhsTxId < rhsTxId; + } + + // If only one of them is in the mempool, it sorts first + if (lhsOptIter.has_value() != rhsOptIter.has_value()) { + return !!lhsOptIter; + } + + // Both are in the mempool: the highest fee rate, + // including the fee deltas, sorts first + return CompareTxMemPoolEntryByModifiedFeeRate{}( + **lhsOptIter, **rhsOptIter); }, [](const auto &lhs, const auto &rhs) { // This serves 2 purposes: diff --git a/src/avalanche/processor.cpp b/src/avalanche/processor.cpp --- a/src/avalanche/processor.cpp +++ b/src/avalanche/processor.cpp @@ -169,9 +169,10 @@ uint32_t staleVoteThresholdIn, uint32_t staleVoteFactorIn, Amount stakeUtxoDustThreshold) : avaconfig(std::move(avaconfigIn)), connman(connmanIn), - chainman(chainmanIn), mempool(mempoolIn), round(0), - peerManager( - std::make_unique(stakeUtxoDustThreshold, chainman)), + chainman(chainmanIn), mempool(mempoolIn), + voteRecords(RWCollection(VoteMap(VoteMapComparator(mempool)))), // NOTE(review): reads the `mempool` member, so it must be declared before `voteRecords` + round(0), peerManager(std::make_unique( + stakeUtxoDustThreshold, chainman)), peerData(std::move(peerDataIn)), sessionKey(std::move(sessionKeyIn)), minQuorumScore(minQuorumTotalScoreIn), minQuorumConnectedScoreRatio(minQuorumConnectedScoreRatioIn), @@ -515,7 +516,8 @@ } } - std::map responseItems; + std::map responseItems( + (VoteMapComparator(mempool))); // Extra parens avoid the most vexing parse // At this stage we are certain that invs[i] matches votes[i], so we can use // the inv type to retrieve what is being voted on.
diff --git a/src/avalanche/test/processor_tests.cpp b/src/avalanche/test/processor_tests.cpp --- a/src/avalanche/test/processor_tests.cpp +++ b/src/avalanche/test/processor_tests.cpp @@ -365,15 +365,17 @@ : fixture(_fixture), invType(MSG_TX) {} CTransactionRef buildVoteItem() const { + auto rng = FastRandomContext(); // Shared by the outpoint and the fee below CMutableTransaction mtx; mtx.nVersion = 2; - mtx.vin.emplace_back(COutPoint{TxId(FastRandomContext().rand256()), 0}); - mtx.vout.emplace_back(1 * COIN, CScript() << OP_TRUE); + mtx.vin.emplace_back(COutPoint{TxId(rng.rand256()), 0}); + mtx.vout.emplace_back(10 * COIN, CScript() << OP_TRUE); CTransactionRef tx = MakeTransactionRef(std::move(mtx)); TestMemPoolEntryHelper mempoolEntryHelper; - auto entry = mempoolEntryHelper.FromTx(tx); + auto entry = mempoolEntryHelper.Fee(int64_t(rng.randrange(10)) * COIN) // Random fee so items get varied fee rates + .FromTx(tx); CTxMemPool *mempool = Assert(fixture->m_node.mempool.get()); { @@ -396,11 +398,28 @@ std::vector votes; votes.reserve(numItems); - // Transactions are sorted by TxId - std::sort(items.begin(), items.end(), - [](const CTransactionRef &lhs, const CTransactionRef &rhs) { - return lhs->GetId() < rhs->GetId(); - }); + CTxMemPool *mempool = Assert(fixture->m_node.mempool.get()); + + { + LOCK(mempool->cs); + + // Transactions in the mempool are sorted by modified fee rate, + // highest first. Keep it simple here and assume every item is in + // the mempool (the BOOST_CHECKs in the comparator verify it).
+ std::sort(items.begin(), items.end(), + [mempool](const CTransactionRef &lhs, + const CTransactionRef &rhs) + EXCLUSIVE_LOCKS_REQUIRED(mempool->cs) { + auto lhsIter = mempool->GetIter(lhs->GetId()); + auto rhsIter = mempool->GetIter(rhs->GetId()); + BOOST_CHECK(lhsIter); + BOOST_CHECK(rhsIter); + + return CompareTxMemPoolEntryByModifiedFeeRate{}( + **lhsIter, **rhsIter); + }); + } + for (auto &item : items) { votes.emplace_back(error, item->GetId()); } @@ -1965,69 +1984,197 @@ } Shuffle(indexes.begin(), indexes.end(), rng); - auto allItems = std::make_tuple(std::move(proofs), std::move(indexes)); - static const size_t numTypes = std::tuple_size::value; + CTxMemPool *mempool = Assert(m_node.mempool.get()); + TestMemPoolEntryHelper mempoolEntryHelper; + std::vector txs; + for (size_t i = 1; i <= numberElementsEachType; i++) { // A fee of i COIN gives each tx a distinct fee rate + CMutableTransaction mtx; + mtx.nVersion = 2; + mtx.vin.emplace_back(COutPoint{TxId(rng.rand256()), 0}); + mtx.vout.emplace_back(1000 * COIN, CScript() << OP_TRUE); + + CTransactionRef tx = MakeTransactionRef(std::move(mtx)); - RWCollection voteMap; + auto entry = mempoolEntryHelper.Fee(int64_t(i) * COIN).FromTx(tx); + { + LOCK2(cs_main, mempool->cs); + mempool->addUnchecked(entry); + BOOST_CHECK(mempool->exists(tx->GetId())); + } + + txs.emplace_back(std::move(tx)); + } + + auto allItems = + std::make_tuple(std::move(proofs), std::move(indexes), std::move(txs)); + static const size_t numTypes = std::tuple_size::value; { - auto writeView = voteMap.getWriteView(); - for (size_t i = 0; i < numberElementsEachType; i++) { - // Randomize the insert order at each loop increment - const size_t firstType = rng.randrange(numTypes); - - for (size_t j = 0; j < numTypes; j++) { - switch ((firstType + j) % numTypes) { - // ProofRef - case 0: - writeView->insert(std::make_pair( - std::get<0>(allItems)[i], VoteRecord(true))); - break; - // CBlockIndex * - case 1: - writeView->insert(std::make_pair( - &std::get<1>(allItems)[i], VoteRecord(true))); - break; -
default: - break; + RWCollection voteMap(VoteMap(m_node.mempool.get())); // Mempool-aware comparator + + { + auto writeView = voteMap.getWriteView(); + for (size_t i = 0; i < numberElementsEachType; i++) { + // Randomize the insert order at each loop increment + const size_t firstType = rng.randrange(numTypes); + + for (size_t j = 0; j < numTypes; j++) { + switch ((firstType + j) % numTypes) { + // ProofRef + case 0: + writeView->insert(std::make_pair( + std::get<0>(allItems)[i], VoteRecord(true))); + break; + // CBlockIndex * + case 1: + writeView->insert(std::make_pair( + &std::get<1>(allItems)[i], VoteRecord(true))); + break; + // CTransactionRef + case 2: + writeView->insert(std::make_pair( + std::get<2>(allItems)[i], VoteRecord(true))); + break; + default: + break; + } } } } + + { + // Check ordering + auto readView = voteMap.getReadView(); + auto it = readView.begin(); + + // The first batch of items is the proofs ordered by score + // (descending) + uint32_t lastScore = std::numeric_limits::max(); + for (size_t i = 0; i < numberElementsEachType; i++) { + BOOST_CHECK(std::holds_alternative(it->first)); + + uint32_t currentScore = + std::get(it->first)->getScore(); + BOOST_CHECK_LT(currentScore, lastScore); + lastScore = currentScore; + + it++; + } + + // The next batch of items is the block indexes ordered by work + // (descending) + arith_uint256 lastWork = -1; + for (size_t i = 0; i < numberElementsEachType; i++) { + BOOST_CHECK( + std::holds_alternative(it->first)); + + arith_uint256 currentWork = + std::get(it->first)->nChainWork; + BOOST_CHECK(currentWork < lastWork); + lastWork = currentWork; + + it++; + } + + // The last batch is the txs, ordered by modified fee rate (descending) + CFeeRate lastFeeRate{MAX_MONEY}; + { + LOCK(mempool->cs); + + for (size_t i = 0; i < numberElementsEachType; i++) { + BOOST_CHECK(std::holds_alternative( + it->first)); + + auto iter = mempool->GetIter( + std::get(it->first)->GetId()); + BOOST_CHECK(iter.has_value()); + + CFeeRate currentFeeRate =
(*iter)->GetModifiedFeeRate(); + + BOOST_CHECK(currentFeeRate < lastFeeRate); + lastFeeRate = currentFeeRate; + + it++; + } + } + + BOOST_CHECK(it == readView.end()); + } + } +} + +BOOST_AUTO_TEST_CASE(vote_map_tx_comparator) { + CTxMemPool *mempool = Assert(m_node.mempool.get()); + TestMemPoolEntryHelper mempoolEntryHelper; // NOTE(review): unused in this test + TxProvider provider(this); + + std::vector txs; + for (size_t i = 0; i < 5; i++) { + txs.emplace_back(provider.buildVoteItem()); } { - // Check ordering + // Without a mempool, the txs are sorted by TxId only + RWCollection voteMap(VoteMap(nullptr)); + { + auto writeView = voteMap.getWriteView(); + for (const auto &tx : txs) { + writeView->insert(std::make_pair(tx, VoteRecord(true))); + } + } + auto readView = voteMap.getReadView(); - auto it = readView.begin(); + TxId lastTxId{uint256::ZERO}; + for (const auto &[item, vote] : readView) { + auto tx = std::get(item); + BOOST_CHECK_GT(tx->GetId(), lastTxId); + lastTxId = tx->GetId(); + } + } - // The first batch of items is the proofs ordered by score (descending) - uint32_t lastScore = std::numeric_limits::max(); - for (size_t i = 0; i < numberElementsEachType; i++) { - BOOST_CHECK(std::holds_alternative(it->first)); + // Remove the first 5 txs from the mempool, then add 5 new ones + mempool->clear(); + for (size_t i = 0; i < 5; i++) { + txs.emplace_back(provider.buildVoteItem()); + } - uint32_t currentScore = - std::get(it->first)->getScore(); - BOOST_CHECK_LT(currentScore, lastScore); - lastScore = currentScore; + RWCollection voteMap((VoteMap(mempool))); // Extra parens avoid the most vexing parse - it++; + { + auto writeView = voteMap.getWriteView(); + for (const auto &tx : txs) { + writeView->insert(std::make_pair(tx, VoteRecord(true))); } + } - // The next batch of items is the block indexes ordered by work - // (descending) - arith_uint256 lastWork = -1; - for (size_t i = 0; i < numberElementsEachType; i++) { - BOOST_CHECK(std::holds_alternative(it->first)); + auto readView = voteMap.getReadView(); + auto it = readView.begin(); -
arith_uint256 currentWork = - std::get(it->first)->nChainWork; - BOOST_CHECK(currentWork < lastWork); - lastWork = currentWork; + LOCK(mempool->cs); - it++; - } + // The 5 txs still in the mempool come first, sorted by fee rate + CFeeRate lastFeeRate{MAX_MONEY}; + for (size_t i = 0; i < 5; i++) { + auto tx = std::get(it->first); + + auto iter = mempool->GetIter(tx->GetId()); + BOOST_CHECK(iter.has_value()); + + BOOST_CHECK((*iter)->GetModifiedFeeRate() <= lastFeeRate); + lastFeeRate = (*iter)->GetModifiedFeeRate(); + it++; + } + + // The 5 txs no longer in the mempool come last, sorted by TxId + TxId lastTxId{uint256::ZERO}; + for (size_t i = 0; i < 5; i++) { + auto tx = std::get(it->first); + + BOOST_CHECK(!mempool->exists(tx->GetId())); - BOOST_CHECK(it == readView.end()); + BOOST_CHECK_GT(tx->GetId(), lastTxId); + lastTxId = tx->GetId(); + it++; } } diff --git a/src/rwcollection.h b/src/rwcollection.h --- a/src/rwcollection.h +++ b/src/rwcollection.h @@ -70,6 +70,8 @@ public: RWCollection() : collection() {} + explicit RWCollection(T &&collection_) // Takes over a pre-built collection, e.g. a map with a custom comparator + : collection(std::move(collection_)) {} using ReadView = RWCollectionView>;