diff --git a/src/avalanche/processor.h b/src/avalanche/processor.h --- a/src/avalanche/processor.h +++ b/src/avalanche/processor.h @@ -62,7 +62,7 @@ }; template class VoteItemUpdate { - VoteItem item; + std::remove_reference_t item; VoteStatus status; public: @@ -76,6 +76,7 @@ }; using BlockUpdate = VoteItemUpdate; +using ProofUpdate = VoteItemUpdate &>; using BlockVoteMap = std::map; @@ -181,7 +182,8 @@ // TODO: Refactor the API to remove the dependency on avalanche/protocol.h void sendResponse(CNode *pfrom, Response response) const; bool registerVotes(NodeId nodeid, const Response &response, - std::vector &blockUpdates, int &banscore, + std::vector &blockUpdates, + std::vector &proofUpdates, int &banscore, std::string &error); template auto withPeerManager(Callable &&func) const { diff --git a/src/avalanche/processor.cpp b/src/avalanche/processor.cpp --- a/src/avalanche/processor.cpp +++ b/src/avalanche/processor.cpp @@ -348,6 +348,7 @@ bool Processor::registerVotes(NodeId nodeid, const Response &response, std::vector &blockUpdates, + std::vector &proofUpdates, int &banscore, std::string &error) { { // Save the time at which we can query again. @@ -395,11 +396,14 @@ } std::map responseIndex; + std::map, Vote> responseProof; - { - LOCK(cs_main); - for (const auto &v : votes) { - auto pindex = LookupBlockIndex(BlockHash(v.GetHash())); + // At this stage we are certain that invs[i] matches votes[i], so we can use + // the inv type to retrieve what is being voted on. + for (size_t i = 0; i < size; i++) { + if (invs[i].IsMsgBlk()) { + LOCK(cs_main); + auto pindex = LookupBlockIndex(BlockHash(votes[i].GetHash())); if (!pindex) { // This should not happen, but just in case... continue; @@ -410,7 +414,21 @@ continue; } - responseIndex.insert(std::make_pair(pindex, v)); + responseIndex.insert(std::make_pair(pindex, votes[i])); + } + + if (invs[i].IsMsgProof()) { + const ProofId proofid(votes[i].GetHash()); + + // TODO Use an unordered map or similar to avoid the loop + auto proofVoteRecordsReadView = proofVoteRecords.getReadView(); + for (auto it = proofVoteRecordsReadView.begin(); + it != proofVoteRecordsReadView.end(); it++) { + if (it->first->getId() == proofid) { + responseProof.insert(std::make_pair(it->first, votes[i])); + break; + } + } } } @@ -454,6 +472,8 @@ registerVoteItems(blockVoteRecords.getWriteView(), blockUpdates, responseIndex); + registerVoteItems(proofVoteRecords.getWriteView(), proofUpdates, + responseProof); return true; } @@ -591,7 +611,9 @@ // In flight request accounting. for (const auto &p : timedout_items) { const CInv &inv = p.first; - assert(inv.type == MSG_BLOCK); + if (inv.type != MSG_BLOCK) { + continue; + } CBlockIndex *pindex; diff --git a/src/avalanche/test/processor_tests.cpp b/src/avalanche/test/processor_tests.cpp --- a/src/avalanche/test/processor_tests.cpp +++ b/src/avalanche/test/processor_tests.cpp @@ -173,11 +173,12 @@ uint64_t getRound() const { return AvalancheTest::getRound(*m_processor); } bool registerVotes(NodeId nodeid, const avalanche::Response &response, - std::vector &updates) { + std::vector &blockUpdates) { int banscore; std::string error; - return m_processor->registerVotes(nodeid, response, updates, banscore, - error); + std::vector proofUpdates; + return m_processor->registerVotes(nodeid, response, blockUpdates, + proofUpdates, banscore, error); } }; @@ -205,8 +206,9 @@ bool registerVotes(NodeId nodeid, const avalanche::Response &response, std::string &error) { int banscore; - return fixture->m_processor->registerVotes(nodeid, response, updates, - banscore, error); + std::vector proofUpdates; + return fixture->m_processor->registerVotes( + nodeid, response, updates, proofUpdates, banscore, error); } bool registerVotes(NodeId nodeid, const avalanche::Response &response) { std::string error; @@ -234,12 +236,63 @@ } }; +struct ProofOnlyTestingContext { + AvalancheTestingSetup *fixture; + + std::vector updates; + uint32_t invType; + + ProofOnlyTestingContext(AvalancheTestingSetup *_fixture) + : fixture(_fixture), invType(MSG_AVA_PROOF) {} + + std::shared_ptr buildVoteItem() const { return fixture->GetProof(); } + + uint256 getVoteItemId(const std::shared_ptr &proof) const { + return proof->getId(); + } + + bool registerVotes(NodeId nodeid, const avalanche::Response &response, + std::string &error) { + int banscore; + std::vector blockUpdates; + return fixture->m_processor->registerVotes( + nodeid, response, blockUpdates, updates, banscore, error); + } + bool registerVotes(NodeId nodeid, const avalanche::Response &response) { + std::string error; + return registerVotes(nodeid, response, error); + } + + bool addToReconcile(const std::shared_ptr &proof) { + fixture->m_processor->addProofToReconcile(proof, true); + return true; + } + + std::vector + buildVotesForItems(uint32_t error, + std::vector> &&items) { + size_t numItems = items.size(); + + std::vector votes; + votes.reserve(numItems); + + // Votes are sorted by high score first + std::sort(items.begin(), items.end(), ProofSharedPointerComparator()); + for (auto &item : items) { + votes.emplace_back(error, item->getId()); + } + + return votes; + } +}; + } // namespace BOOST_FIXTURE_TEST_SUITE(processor_tests, AvalancheTestingSetup) // FIXME A std::tuple can be used instead of boost::mpl::list after boost 1.67 -typedef boost::mpl::list voteItemTestingContexts; +typedef boost::mpl::list + voteItemTestingContexts; #define REGISTER_VOTE_AND_CHECK(vr, vote, state, finalized, confidence) \ vr.registerVote(NO_NODE, vote); \ diff --git a/src/net_processing.cpp b/src/net_processing.cpp --- a/src/net_processing.cpp +++ b/src/net_processing.cpp @@ -4379,19 +4379,20 @@ return; } - std::vector updates; + std::vector blockUpdates; + std::vector proofUpdates; int banscore; std::string error; - if (!g_avalanche->registerVotes(pfrom.GetId(), response, updates, - banscore, error)) { + if (!g_avalanche->registerVotes(pfrom.GetId(), response, blockUpdates, + proofUpdates, banscore, error)) { Misbehaving(pfrom, banscore, error); return; } pfrom.m_avalanche_state->invsVoted(response.GetVotes().size()); - if (updates.size()) { - for (avalanche::BlockUpdate &u : updates) { + if (blockUpdates.size()) { + for (avalanche::BlockUpdate &u : blockUpdates) { CBlockIndex *pindex = u.getVoteItem(); switch (u.getStatus()) { case avalanche::VoteStatus::Invalid: