diff --git a/src/avalanche/processor.h b/src/avalanche/processor.h
--- a/src/avalanche/processor.h
+++ b/src/avalanche/processor.h
@@ -96,7 +96,31 @@
     const AnyVoteItem &getVoteItem() const { return item; }
 };
 
-struct VoteMapComparator {
+/**
+ * Ordering used for the VoteMap keys.
+ *
+ * Items of different variant types are ordered by variant index. Transactions
+ * that are in the mempool come first, ordered by modified fee rate (highest
+ * first), followed by the transactions that are not in the mempool, ordered
+ * by TxId. Without a mempool, all the transactions are ordered by TxId.
+ *
+ * Comparing two transactions takes mempool->cs when a mempool is set.
+ *
+ * NOTE(review): the transaction ordering depends on the mempool state at the
+ * time of the comparison. std::map requires the relative order of its keys
+ * to stay constant, so a tx entering or leaving the mempool, or having its
+ * fee delta changed, while it is a key of such a map may break the map
+ * invariants. Confirm this cannot happen with the current callers.
+ */
+class VoteMapComparator {
+    //! Used to order the transactions. If nullptr, they are ordered by TxId.
+    const CTxMemPool *mempool{nullptr};
+
+public:
+    VoteMapComparator() = default;
+    // Intentionally not explicit: VoteMap(mempool) relies on this conversion.
+    VoteMapComparator(const CTxMemPool *mempoolIn) : mempool(mempoolIn) {}
+
     bool operator()(const AnyVoteItem &lhs, const AnyVoteItem &rhs) const {
         // If the variants are of different types, sort them by variant index
         if (lhs.index() != rhs.index()) {
@@ -112,8 +136,36 @@
                     // Reverse ordering so we get the highest work first
                     return CBlockIndexWorkComparator()(rhs, lhs);
                 },
-                [](const CTransactionRef &lhs, const CTransactionRef &rhs) {
-                    return lhs->GetId() < rhs->GetId();
+                [this](const CTransactionRef &lhs, const CTransactionRef &rhs) {
+                    const TxId &lhsTxId = lhs->GetId();
+                    const TxId &rhsTxId = rhs->GetId();
+
+                    // If there is no mempool, sort by TxId. Note that polling
+                    // for txs is currently not supported if there is no mempool
+                    // so this is only a safety net.
+                    if (!mempool) {
+                        return lhsTxId < rhsTxId;
+                    }
+
+                    LOCK(mempool->cs);
+
+                    auto lhsOptIter = mempool->GetIter(lhsTxId);
+                    auto rhsOptIter = mempool->GetIter(rhsTxId);
+
+                    // If neither tx is in the mempool, order them by TxId
+                    if (!lhsOptIter && !rhsOptIter) {
+                        return lhsTxId < rhsTxId;
+                    }
+
+                    // If only one is in the mempool, it comes first
+                    if (lhsOptIter.has_value() != rhsOptIter.has_value()) {
+                        return lhsOptIter.has_value();
+                    }
+
+                    // Both are in the mempool, select the highest fee rate
+                    // including the fee deltas
+                    return CompareTxMemPoolEntryByModifiedFeeRate{}(
+                        **lhsOptIter, **rhsOptIter);
                 },
                 [](const auto &lhs, const auto &rhs) {
                     // This serves 2 purposes:
diff --git a/src/avalanche/processor.cpp b/src/avalanche/processor.cpp
--- a/src/avalanche/processor.cpp
+++ b/src/avalanche/processor.cpp
@@ -169,9 +169,10 @@
                      uint32_t staleVoteThresholdIn, uint32_t staleVoteFactorIn,
                      Amount stakeUtxoDustThreshold)
     : avaconfig(std::move(avaconfigIn)), connman(connmanIn),
-      chainman(chainmanIn), mempool(mempoolIn), round(0),
-      peerManager(
-          std::make_unique<PeerManager>(stakeUtxoDustThreshold, chainman)),
+      chainman(chainmanIn), mempool(mempoolIn),
+      voteRecords(VoteMap(VoteMapComparator(mempoolIn))),
+      round(0), peerManager(std::make_unique<PeerManager>(
+                    stakeUtxoDustThreshold, chainman)),
       peerData(std::move(peerDataIn)), sessionKey(std::move(sessionKeyIn)),
       minQuorumScore(minQuorumTotalScoreIn),
       minQuorumConnectedScoreRatio(minQuorumConnectedScoreRatioIn),
@@ -515,7 +516,8 @@
         }
     }
 
-    std::map<AnyVoteItem, Vote, VoteMapComparator> responseItems;
+    std::map<AnyVoteItem, Vote, VoteMapComparator> responseItems(
+        (VoteMapComparator(mempool)));
 
     // At this stage we are certain that invs[i] matches votes[i], so we can use
     // the inv type to retrieve what is being voted on.
diff --git a/src/avalanche/test/processor_tests.cpp b/src/avalanche/test/processor_tests.cpp
--- a/src/avalanche/test/processor_tests.cpp
+++ b/src/avalanche/test/processor_tests.cpp
@@ -365,15 +365,17 @@
         : fixture(_fixture), invType(MSG_TX) {}
 
     CTransactionRef buildVoteItem() const {
+        auto rng = FastRandomContext();
         CMutableTransaction mtx;
         mtx.nVersion = 2;
-        mtx.vin.emplace_back(COutPoint{TxId(FastRandomContext().rand256()), 0});
-        mtx.vout.emplace_back(1 * COIN, CScript() << OP_TRUE);
+        mtx.vin.emplace_back(COutPoint{TxId(rng.rand256()), 0});
+        mtx.vout.emplace_back(10 * COIN, CScript() << OP_TRUE);
 
         CTransactionRef tx = MakeTransactionRef(std::move(mtx));
 
         TestMemPoolEntryHelper mempoolEntryHelper;
-        auto entry = mempoolEntryHelper.FromTx(tx);
+        auto entry = mempoolEntryHelper.Fee(int64_t(rng.randrange(10)) * COIN)
+                         .FromTx(tx);
 
         CTxMemPool *mempool = Assert(fixture->m_node.mempool.get());
         {
@@ -396,11 +398,28 @@
         std::vector<Vote> votes;
         votes.reserve(numItems);
 
-        // Transactions are sorted by TxId
-        std::sort(items.begin(), items.end(),
-                  [](const CTransactionRef &lhs, const CTransactionRef &rhs) {
-                      return lhs->GetId() < rhs->GetId();
-                  });
+        CTxMemPool *mempool = Assert(fixture->m_node.mempool.get());
+
+        {
+            LOCK(mempool->cs);
+
+            // Transactions are sorted by modified fee rate as long as they are
+            // in the mempool. Let's keep it simple here and assume it's the
+            // case.
+            std::sort(items.begin(), items.end(),
+                      [mempool](const CTransactionRef &lhs,
+                                const CTransactionRef &rhs)
+                          EXCLUSIVE_LOCKS_REQUIRED(mempool->cs) {
+                              auto lhsIter = mempool->GetIter(lhs->GetId());
+                              auto rhsIter = mempool->GetIter(rhs->GetId());
+                              BOOST_REQUIRE(lhsIter);
+                              BOOST_REQUIRE(rhsIter);
+
+                              return CompareTxMemPoolEntryByModifiedFeeRate{}(
+                                  **lhsIter, **rhsIter);
+                          });
+        }
+
         for (auto &item : items) {
             votes.emplace_back(error, item->GetId());
         }
@@ -1965,10 +1984,32 @@
     }
     Shuffle(indexes.begin(), indexes.end(), rng);
 
-    auto allItems = std::make_tuple(std::move(proofs), std::move(indexes));
+    CTxMemPool *mempool = Assert(m_node.mempool.get());
+    TestMemPoolEntryHelper mempoolEntryHelper;
+    std::vector<CTransactionRef> txs;
+    for (size_t i = 1; i <= numberElementsEachType; i++) {
+        CMutableTransaction mtx;
+        mtx.nVersion = 2;
+        mtx.vin.emplace_back(COutPoint{TxId(rng.rand256()), 0});
+        mtx.vout.emplace_back(1000 * COIN, CScript() << OP_TRUE);
+
+        CTransactionRef tx = MakeTransactionRef(std::move(mtx));
+
+        auto entry = mempoolEntryHelper.Fee(int64_t(i) * COIN).FromTx(tx);
+        {
+            LOCK2(cs_main, mempool->cs);
+            mempool->addUnchecked(entry);
+            BOOST_CHECK(mempool->exists(tx->GetId()));
+        }
+
+        txs.emplace_back(std::move(tx));
+    }
+
+    auto allItems =
+        std::make_tuple(std::move(proofs), std::move(indexes), std::move(txs));
     static const size_t numTypes = std::tuple_size<decltype(allItems)>::value;
 
-    RWCollection<VoteMap> voteMap;
+    RWCollection<VoteMap> voteMap(VoteMap(m_node.mempool.get()));
 
     {
         auto writeView = voteMap.getWriteView();
@@ -1988,6 +2029,11 @@
                         writeView->insert(std::make_pair(
                             &std::get<1>(allItems)[i], VoteRecord(true)));
                         break;
+                    // CTransactionRef
+                    case 2:
+                        writeView->insert(std::make_pair(
+                            std::get<2>(allItems)[i], VoteRecord(true)));
+                        break;
                     default:
                         break;
                 }
@@ -2000,7 +2046,8 @@
         auto readView = voteMap.getReadView();
         auto it = readView.begin();
 
-        // The first batch of items is the proofs ordered by score (descending)
+        // The first batch of items is the proofs ordered by score
+        // (descending)
         uint32_t lastScore = std::numeric_limits<uint32_t>::max();
         for (size_t i = 0; i < numberElementsEachType; i++) {
             BOOST_CHECK(std::holds_alternative<const ProofRef>(it->first));
@@ -2027,10 +2074,108 @@
             it++;
         }
 
+        // The last batch of items is the txs ordered by modified fee rate
+        CFeeRate lastFeeRate{MAX_MONEY};
+        {
+            LOCK(mempool->cs);
+
+            for (size_t i = 0; i < numberElementsEachType; i++) {
+                BOOST_CHECK(
+                    std::holds_alternative<const CTransactionRef>(it->first));
+
+                auto iter = mempool->GetIter(
+                    std::get<const CTransactionRef>(it->first)->GetId());
+                BOOST_REQUIRE(iter.has_value());
+
+                CFeeRate currentFeeRate = (*iter)->GetModifiedFeeRate();
+
+                BOOST_CHECK(currentFeeRate < lastFeeRate);
+                lastFeeRate = currentFeeRate;
+
+                it++;
+            }
+        }
+
         BOOST_CHECK(it == readView.end());
     }
 }
 
+BOOST_AUTO_TEST_CASE(vote_map_tx_comparator) {
+    CTxMemPool *mempool = Assert(m_node.mempool.get());
+    TxProvider provider(this);
+
+    std::vector<CTransactionRef> txs;
+    for (size_t i = 0; i < 5; i++) {
+        txs.emplace_back(provider.buildVoteItem());
+    }
+
+    {
+        // When there is no mempool, the txs are sorted by txid
+        RWCollection<VoteMap> voteMap(VoteMap(nullptr));
+        {
+            auto writeView = voteMap.getWriteView();
+            for (const auto &tx : txs) {
+                writeView->insert(std::make_pair(tx, VoteRecord(true)));
+            }
+        }
+
+        auto readView = voteMap.getReadView();
+        TxId lastTxId{uint256::ZERO};
+        for (const auto &[item, vote] : readView) {
+            auto tx = std::get<const CTransactionRef>(item);
+            BOOST_CHECK_GT(tx->GetId(), lastTxId);
+            lastTxId = tx->GetId();
+        }
+    }
+
+    // Remove the first 5 txs from the mempool, and add 5 more
+    mempool->clear();
+    for (size_t i = 0; i < 5; i++) {
+        txs.emplace_back(provider.buildVoteItem());
+    }
+
+    {
+        RWCollection<VoteMap> voteMap((VoteMap(mempool)));
+
+        {
+            auto writeView = voteMap.getWriteView();
+            for (const auto &tx : txs) {
+                writeView->insert(std::make_pair(tx, VoteRecord(true)));
+            }
+        }
+
+        auto readView = voteMap.getReadView();
+        auto it = readView.begin();
+
+        LOCK(mempool->cs);
+
+        // The first 5 txs are sorted by fee
+        CFeeRate lastFeeRate{MAX_MONEY};
+        for (size_t i = 0; i < 5; i++) {
+            auto tx = std::get<const CTransactionRef>(it->first);
+
+            auto iter = mempool->GetIter(tx->GetId());
+            BOOST_REQUIRE(iter.has_value());
+
+            BOOST_CHECK((*iter)->GetModifiedFeeRate() <= lastFeeRate);
+            lastFeeRate = (*iter)->GetModifiedFeeRate();
+            it++;
+        }
+
+        // The last 5 txs are sorted by txid
+        TxId lastTxId{uint256::ZERO};
+        for (size_t i = 0; i < 5; i++) {
+            auto tx = std::get<const CTransactionRef>(it->first);
+
+            BOOST_CHECK(!mempool->exists(tx->GetId()));
+
+            BOOST_CHECK_GT(tx->GetId(), lastTxId);
+            lastTxId = tx->GetId();
+            it++;
+        }
+    }
+}
+
 BOOST_AUTO_TEST_CASE(block_reconcile_initial_vote) {
     const auto &config = GetConfig();
     auto &chainman = Assert(m_node.chainman);
diff --git a/src/rwcollection.h b/src/rwcollection.h
--- a/src/rwcollection.h
+++ b/src/rwcollection.h
@@ -70,6 +70,10 @@
 
 public:
     RWCollection() : collection() {}
+    //! Take ownership of an existing collection, e.g. a container that was
+    //! built with a non-default comparator.
+    explicit RWCollection(T &&collection_)
+        : collection(std::move(collection_)) {}
 
     using ReadView =
         RWCollectionView<const T, std::shared_lock<std::shared_mutex>>;