diff --git a/src/avalanche/peermanager.h b/src/avalanche/peermanager.h
--- a/src/avalanche/peermanager.h
+++ b/src/avalanche/peermanager.h
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include <interfaces/handler.h>
 #include
 #include
 #include
@@ -174,6 +175,34 @@
     std::unordered_set m_unbroadcast_proofids;
 
 public:
+    class ConflictingProofHandler {
+    public:
+        /**
+         * Notification hook for conflicting proofs; the default
+         * implementation does nothing. `accepted` is meant to report whether
+         * the proof was accepted (the only call site currently passes false).
+         */
+        virtual void onConflictingProof(const std::shared_ptr<Proof> &proof,
+                                        bool accepted) {}
+        virtual ~ConflictingProofHandler() = default;
+    };
+
+    /**
+     * Handlers registered through handleConflictingProof(). Prefer that
+     * method over mutating this set directly.
+     */
+    std::unordered_set<std::shared_ptr<ConflictingProofHandler>>
+        conflictingProofHandlers;
+
+    /**
+     * Register a handler to be notified of conflicting proofs. It stays
+     * registered until the returned interfaces::Handler is disconnected or
+     * destroyed. That object keeps a raw pointer to this PeerManager, so it
+     * must not outlive it. A null handler is ignored.
+     */
+    std::unique_ptr<interfaces::Handler>
+    handleConflictingProof(std::shared_ptr<ConflictingProofHandler> handler);
+
     /**
      * Node API.
      */
diff --git a/src/avalanche/peermanager.cpp b/src/avalanche/peermanager.cpp
--- a/src/avalanche/peermanager.cpp
+++ b/src/avalanche/peermanager.cpp
@@ -14,6 +14,50 @@
 
 namespace avalanche {
 
+/**
+ * interfaces::Handler returned by PeerManager::handleConflictingProof().
+ * Registers the handler on construction and unregisters it on disconnect()
+ * or destruction. It keeps a raw pointer to the PeerManager, so it must not
+ * outlive it.
+ */
+class ConflictingProofHandlerImpl final : public interfaces::Handler {
+    PeerManager *m_peerManager;
+    std::shared_ptr<PeerManager::ConflictingProofHandler>
+        m_conflictingProofHandler;
+
+public:
+    explicit ConflictingProofHandlerImpl(
+        PeerManager *pm, std::shared_ptr<PeerManager::ConflictingProofHandler>
+                             conflictingProofHandler)
+        : m_peerManager(pm),
+          m_conflictingProofHandler(std::move(conflictingProofHandler)) {
+        // Never register a null handler: the notification loop would
+        // dereference it.
+        if (m_conflictingProofHandler) {
+            m_peerManager->conflictingProofHandlers.emplace(
+                m_conflictingProofHandler);
+        }
+    }
+    ~ConflictingProofHandlerImpl() override { disconnect(); }
+
+    void disconnect() override {
+        // Idempotent: the pointer is cleared after the first call. The set
+        // holds each handler once, so if the same handler object was
+        // registered twice, the first disconnect() unregisters it for both.
+        if (m_conflictingProofHandler) {
+            m_peerManager->conflictingProofHandlers.erase(
+                m_conflictingProofHandler);
+            m_conflictingProofHandler.reset();
+        }
+    }
+};
+
+std::unique_ptr<interfaces::Handler> PeerManager::handleConflictingProof(
+    std::shared_ptr<ConflictingProofHandler> handler) {
+    return std::make_unique<ConflictingProofHandlerImpl>(this,
+                                                         std::move(handler));
+}
+
 bool PeerManager::addNode(NodeId nodeid, const ProofId &proofid) {
     auto &pview = peers.get();
     auto it = pview.find(proofid);
@@ -277,6 +321,15 @@
         }
     }
 
+    // Iterate over a copy of the handler set: a callback may register or
+    // unregister handlers, which would otherwise invalidate the loop
+    // iterator. A handler unregistered mid-loop can still receive this one
+    // notification.
+    const auto handlers = conflictingProofHandlers;
+    for (const auto &handler : handlers) {
+        handler->onConflictingProof(proof, false);
+    }
+
     return peers.end();
 }
 
diff --git a/src/avalanche/test/peermanager_tests.cpp
b/src/avalanche/test/peermanager_tests.cpp
--- a/src/avalanche/test/peermanager_tests.cpp
+++ b/src/avalanche/test/peermanager_tests.cpp
@@ -31,6 +31,23 @@
     }
 };
 } // namespace
+
+/** Records the conflicting proof notifications it receives. */
+class TestConflictingProofHandler
+    : public PeerManager::ConflictingProofHandler {
+public:
+    std::shared_ptr<Proof> lastProof;
+    bool lastAccepted{true};
+    size_t notificationCount{0};
+
+    void onConflictingProof(const std::shared_ptr<Proof> &proof,
+                            bool accepted) override {
+        lastProof = proof;
+        lastAccepted = accepted;
+        ++notificationCount;
+    }
+};
+
 } // namespace avalanche
 
 BOOST_FIXTURE_TEST_SUITE(peermanager_tests, TestingSetup)
@@ -795,4 +812,61 @@
         !pm.registerProof(std::make_shared(std::move(badProof))));
 }
 
+BOOST_AUTO_TEST_CASE(conflicting_proof_handler) {
+    avalanche::PeerManager pm;
+
+    auto conflictingProofHandler =
+        std::make_shared<avalanche::TestConflictingProofHandler>();
+    auto handler = pm.handleConflictingProof(conflictingProofHandler);
+
+    const CKey key = CKey::MakeCompressedKey();
+
+    const COutPoint outpoint(TxId(GetRandHash()), 0);
+    const Amount amount(10 * COIN);
+    const uint32_t height = 100;
+    const bool is_coinbase = false;
+
+    CScript script = GetScriptForDestination(PKHash(key.GetPubKey()));
+
+    {
+        LOCK(cs_main);
+        CCoinsViewCache &coins = ::ChainstateActive().CoinsTip();
+        coins.AddCoin(outpoint,
+                      Coin(CTxOut(amount, script), height, is_coinbase), false);
+    }
+
+    // Every proof stakes the same UTXO, so each proof built after the first
+    // one is expected to conflict with it.
+    auto buildProofWithSequence = [&](uint64_t sequence) {
+        ProofBuilder pb(sequence, GetRandInt(std::numeric_limits<int>::max()),
+                        key);
+        BOOST_CHECK(pb.addUTXO(outpoint, amount, height, is_coinbase, key));
+        return std::make_shared<Proof>(pb.build());
+    };
+
+    auto proof_base = buildProofWithSequence(0);
+    BOOST_CHECK_NE(pm.getPeerId(proof_base), NO_PEER);
+    BOOST_CHECK(!conflictingProofHandler->lastProof);
+
+    for (size_t i = 1; i < 100; i++) {
+        auto proof_conflicting = buildProofWithSequence(i);
+        BOOST_CHECK_EQUAL(pm.getPeerId(proof_conflicting), NO_PEER);
+        // Abort instead of dereferencing a null pointer if the handler was
+        // never notified.
+        BOOST_REQUIRE(conflictingProofHandler->lastProof);
+        BOOST_CHECK_EQUAL(conflictingProofHandler->lastProof->getId(),
+                          proof_conflicting->getId());
+        BOOST_CHECK(!conflictingProofHandler->lastAccepted);
+    }
+
+    // After disconnecting, conflicting proofs no longer reach the handler.
+    handler->disconnect();
+    const size_t notificationsBefore =
+        conflictingProofHandler->notificationCount;
+    auto proof_after_disconnect = buildProofWithSequence(100);
+    BOOST_CHECK_EQUAL(pm.getPeerId(proof_after_disconnect), NO_PEER);
+    BOOST_CHECK_EQUAL(conflictingProofHandler->notificationCount,
+                      notificationsBefore);
+}
+
 BOOST_AUTO_TEST_SUITE_END()