diff --git a/src/avalanche/peermanager.h b/src/avalanche/peermanager.h
--- a/src/avalanche/peermanager.h
+++ b/src/avalanche/peermanager.h
@@ -196,6 +196,18 @@
      * Proof and Peer related API.
      */
     bool registerProof(const ProofRef &proof);
+
+    /**
+     * Reject a proof by removing it from whichever place holds it: the orphan
+     * pool, the conflicting pool, or the peer it is bound to. When a bound
+     * proof is rejected, the conflicting proofs that staked one of its UTXOs
+     * are taken out of the conflicting pool and registered again.
+     *
+     * @return false if the proof is unknown or its peer could not be removed,
+     *         true otherwise.
+     */
+    bool rejectProof(const ProofId &proofid);
+
     bool exists(const ProofId &proofid) const {
         return getProof(proofid) != nullptr;
     }
diff --git a/src/avalanche/peermanager.cpp b/src/avalanche/peermanager.cpp
--- a/src/avalanche/peermanager.cpp
+++ b/src/avalanche/peermanager.cpp
@@ -213,6 +213,49 @@
     return true;
 }
 
+bool PeerManager::rejectProof(const ProofId &proofid) {
+    if (!exists(proofid)) {
+        return false;
+    }
+
+    if (orphanProofPool.removeProof(proofid)) {
+        return true;
+    }
+
+    if (conflictingProofPool.removeProof(proofid)) {
+        return true;
+    }
+
+    // Neither orphan nor conflicting, so the proof must be bound to a peer.
+    auto &pview = peers.get<by_proofid>();
+    auto it = pview.find(proofid);
+    assert(it != pview.end());
+
+    // Copy what is needed out of the peer entry first: removePeer() removes
+    // that entry, which invalidates `it`, so it must not be dereferenced once
+    // the peer is gone. Holding our own ProofRef keeps the proof usable below.
+    const PeerId peerid = it->peerid;
+    const ProofRef proof = it->proof;
+
+    if (!removePeer(peerid)) {
+        return false;
+    }
+
+    // If there were conflicting proofs, attempt to pull them back
+    for (const SignedStake &ss : proof->getStakes()) {
+        const ProofRef conflictingProof =
+            conflictingProofPool.getProof(ss.getStake().getUTXO());
+        if (!conflictingProof) {
+            continue;
+        }
+
+        conflictingProofPool.removeProof(conflictingProof->getId());
+        registerProof(conflictingProof);
+    }
+
+    return true;
+}
+
 NodeId PeerManager::selectNode() {
     for (int retry = 0; retry < SELECT_NODE_MAX_RETRY; retry++) {
         const PeerId p = selectPeer();
diff --git a/src/avalanche/test/peermanager_tests.cpp b/src/avalanche/test/peermanager_tests.cpp
--- a/src/avalanche/test/peermanager_tests.cpp
+++ b/src/avalanche/test/peermanager_tests.cpp
@@ -1062,4 +1062,82 @@
     BOOST_CHECK(!pm.exists(proofSeq10->getId()));
 }
 
+BOOST_AUTO_TEST_CASE(reject_proof) {
+    avalanche::PeerManager pm;
+
+    const CKey key = CKey::MakeCompressedKey();
+
+    const Amount amount(10 * COIN);
+    const uint32_t height = 100;
+    const bool is_coinbase = false;
+    CScript script = GetScriptForDestination(PKHash(key.GetPubKey()));
+
+    const COutPoint conflictingOutpoint(TxId(GetRandHash()), 0);
+    {
+        LOCK(cs_main);
+        CCoinsViewCache &coins = ::ChainstateActive().CoinsTip();
+        coins.AddCoin(conflictingOutpoint,
+                      Coin(CTxOut(amount, script), height, is_coinbase), false);
+    }
+
+    auto buildProofWithSequenceAndOutpoints =
+        [&](uint64_t sequence, const std::vector<COutPoint> &outpoints) {
+            ProofBuilder pb(sequence, 0, key);
+            for (const COutPoint &outpoint : outpoints) {
+                BOOST_CHECK(
+                    pb.addUTXO(outpoint, amount, height, is_coinbase, key));
+            }
+            return pb.build();
+        };
+
+    // The good, the bad and the ugly
+    auto proofSeq10 =
+        buildProofWithSequenceAndOutpoints(10, {conflictingOutpoint});
+    auto proofSeq20 =
+        buildProofWithSequenceAndOutpoints(20, {conflictingOutpoint});
+    auto orphan30 = buildProofWithSequenceAndOutpoints(
+        30, {conflictingOutpoint, {TxId(GetRandHash()), 0}});
+
+    BOOST_CHECK(pm.registerProof(proofSeq20));
+    BOOST_CHECK(!pm.registerProof(proofSeq10));
+    BOOST_CHECK(!pm.registerProof(orphan30));
+
+    BOOST_CHECK(pm.isBoundToPeer(proofSeq20->getId()));
+    BOOST_CHECK(pm.isInConflictingPool(proofSeq10->getId()));
+    BOOST_CHECK(pm.isOrphan(orphan30->getId()));
+
+    // Rejecting a proof that doesn't exist should fail
+    for (size_t i = 0; i < 10; i++) {
+        BOOST_CHECK(!pm.rejectProof(avalanche::ProofId(GetRandHash())));
+    }
+
+    auto checkReject = [&](const ProofId &proofid) {
+        BOOST_CHECK(pm.exists(proofid));
+        BOOST_CHECK(pm.rejectProof(proofid));
+        BOOST_CHECK(!pm.exists(proofid));
+
+        // Rejecting a few more times has no effect and fails
+        for (size_t i = 0; i < 10; i++) {
+            BOOST_CHECK(!pm.rejectProof(proofid));
+        }
+    };
+
+    // Reject from the orphan pool
+    checkReject(orphan30->getId());
+
+    // Reject from the conflicting pool
+    checkReject(proofSeq10->getId());
+
+    // Add again a proof to the conflicting pool
+    BOOST_CHECK(!pm.registerProof(proofSeq10));
+    BOOST_CHECK(pm.isInConflictingPool(proofSeq10->getId()));
+
+    // Reject from the valid pool
+    checkReject(proofSeq20->getId());
+
+    // The conflicting proof should be promoted to a peer
+    BOOST_CHECK(!pm.isInConflictingPool(proofSeq10->getId()));
+    BOOST_CHECK(pm.isBoundToPeer(proofSeq10->getId()));
+}
+
 BOOST_AUTO_TEST_SUITE_END()