Check ancestor and descendant chain limits separately.

What this patch does:
  * Adds CTxMemPool::CalculateDescendantMaximum(entry). Starting at `entry`,
    it follows GetMemPoolParents() upwards, visiting each entry once, and
    returns the largest GetCountWithDescendants() among the visited entries
    that have no parents. If `entry` itself has no parents, that is simply
    its own GetCountWithDescendants().
  * Changes CTxMemPool::TransactionWithinChainLimit() to take ancestor_limit
    and descendant_limit instead of a single chainLimit. It still returns
    true when txid is not in mapTx; otherwise it requires
    GetCountWithAncestors() < ancestor_limit and
    CalculateDescendantMaximum() < descendant_limit.
  * Adds CoinEligibilityFilter::max_descendants. The existing three-argument
    constructor initialises it from max_ancestors_, so existing callers keep
    a single shared limit; a new four-argument constructor sets it explicitly.
  * Updates the call site in src/wallet/wallet.cpp to pass both
    eligibility_filter.max_ancestors and eligibility_filter.max_descendants.

Review notes:
  * `candidates` is declared as std::vector<txiter>. A bare `std::vector`
    with no initializer gives the compiler nothing to deduce the element
    type from; the element type follows from push_back(entry) and
    `txiter candidate = candidates.back()`.
  * Emptiness checks use empty() rather than comparing size() with zero.
  * NOTE(review): the context line `typedef std::map cacheMap;` in
    src/txmempool.h looks like it lost its template arguments too. Context
    lines must match the target file, so confirm it before applying.
  * NOTE(review): leading whitespace in this patch was reconstructed; apply
    with whitespace-tolerant options or re-check against the tree.

diff --git a/src/txmempool.h b/src/txmempool.h
--- a/src/txmempool.h
+++ b/src/txmempool.h
@@ -535,6 +535,8 @@
         EXCLUSIVE_LOCKS_REQUIRED(cs);
     const setEntries &GetMemPoolChildren(txiter entry) const
         EXCLUSIVE_LOCKS_REQUIRED(cs);
+    uint64_t CalculateDescendantMaximum(txiter entry) const
+        EXCLUSIVE_LOCKS_REQUIRED(cs);
 
 private:
     typedef std::map cacheMap;
@@ -702,8 +704,8 @@
      * Returns false if the transaction is in the mempool and not within the
      * chain limit specified.
      */
-    bool TransactionWithinChainLimit(const uint256 &txid,
-                                     size_t chainLimit) const;
+    bool TransactionWithinChainLimit(const uint256 &txid, size_t ancestor_limit,
+                                     size_t descendant_limit) const;
 
     unsigned long size() {
         LOCK(cs);
diff --git a/src/txmempool.cpp b/src/txmempool.cpp
--- a/src/txmempool.cpp
+++ b/src/txmempool.cpp
@@ -1217,12 +1217,38 @@
     }
 }
 
+uint64_t CTxMemPool::CalculateDescendantMaximum(txiter entry) const {
+    // Walk up to each parentless entry; keep the largest descendant count.
+    std::vector<txiter> candidates;
+    setEntries counted;
+    candidates.push_back(entry);
+    uint64_t maximum = 0;
+    while (!candidates.empty()) {
+        txiter candidate = candidates.back();
+        candidates.pop_back();
+        if (!counted.insert(candidate).second) {
+            continue;
+        }
+        const setEntries &parents = GetMemPoolParents(candidate);
+        if (parents.empty()) {
+            maximum = std::max(maximum, candidate->GetCountWithDescendants());
+        } else {
+            for (txiter i : parents) {
+                candidates.push_back(i);
+            }
+        }
+    }
+    return maximum;
+}
+
 bool CTxMemPool::TransactionWithinChainLimit(const uint256 &txid,
-                                             size_t chainLimit) const {
+                                             size_t ancestor_limit,
+                                             size_t descendant_limit) const {
     LOCK(cs);
     auto it = mapTx.find(txid);
-    return it == mapTx.end() || (it->GetCountWithAncestors() < chainLimit &&
-                                 it->GetCountWithDescendants() < chainLimit);
+    return it == mapTx.end() ||
+           (it->GetCountWithAncestors() < ancestor_limit &&
+            CalculateDescendantMaximum(it) < descendant_limit);
 }
 
 SaltedTxidHasher::SaltedTxidHasher()
diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h
--- a/src/wallet/wallet.h
+++ b/src/wallet/wallet.h
@@ -638,11 +638,16 @@
     const int conf_mine;
     const int conf_theirs;
     const uint64_t max_ancestors;
+    const uint64_t max_descendants;
 
     CoinEligibilityFilter(int conf_mine_, int conf_theirs_,
                           uint64_t max_ancestors_)
         : conf_mine(conf_mine_), conf_theirs(conf_theirs_),
-          max_ancestors(max_ancestors_) {}
+          max_ancestors(max_ancestors_), max_descendants(max_ancestors_) {}
+    CoinEligibilityFilter(int conf_mine_, int conf_theirs_,
+                          uint64_t max_ancestors_, uint64_t max_descendants_)
+        : conf_mine(conf_mine_), conf_theirs(conf_theirs_),
+          max_ancestors(max_ancestors_), max_descendants(max_descendants_) {}
 };
 
 // forward declarations for ScanForWalletTransactions/RescanFromTime
diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp
--- a/src/wallet/wallet.cpp
+++ b/src/wallet/wallet.cpp
@@ -2595,7 +2595,8 @@
     }
 
     if (!g_mempool.TransactionWithinChainLimit(
-            output.tx->GetId(), eligibility_filter.max_ancestors)) {
+            output.tx->GetId(), eligibility_filter.max_ancestors,
+            eligibility_filter.max_descendants)) {
         return false;
     }
 