diff --git a/src/avalanche/processor.h b/src/avalanche/processor.h
--- a/src/avalanche/processor.h
+++ b/src/avalanche/processor.h
@@ -13,6 +13,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -88,7 +89,9 @@
 struct AvalancheTest;
 }
 
-class Processor {
+// FIXME Implement a proper notification handler for node disconnection instead
+// of implementing the whole NetEventsInterface for a single interesting event.
+class Processor final : public NetEventsInterface {
     CConnman *connman;
     std::chrono::milliseconds queryTimeoutDuration;
 
@@ -213,6 +216,20 @@
 
     bool isQuorumEstablished();
 
+    // Implement NetEventsInterface. Only FinalizeNode is of interest.
+    void InitializeNode(const Config &config, CNode *pnode) override {}
+    bool ProcessMessages(const Config &config, CNode *pnode,
+                         std::atomic<bool> &interrupt) override {
+        return false;
+    }
+    bool SendMessages(const Config &config, CNode *pnode) override {
+        return false;
+    }
+
+    /** Remove the node from the peer manager when CConnman deletes it. */
+    void FinalizeNode(const Config &config, const CNode &node,
+                      bool &update_connection_time) override;
+
 private:
     void runEventLoop();
     void clearTimedoutRequests();
diff --git a/src/avalanche/processor.cpp b/src/avalanche/processor.cpp
--- a/src/avalanche/processor.cpp
+++ b/src/avalanche/processor.cpp
@@ -805,4 +805,9 @@
     return true;
 }
 
+void Processor::FinalizeNode(const Config &config, const CNode &node,
+                             bool &update_connection_time) {
+    WITH_LOCK(cs_peerManager, peerManager->removeNode(node.GetId()));
+}
+
 } // namespace avalanche
diff --git a/src/init.cpp b/src/init.cpp
--- a/src/init.cpp
+++ b/src/init.cpp
@@ -3066,7 +3066,10 @@
     connOptions.nMaxFeeler = MAX_FEELER_CONNECTIONS;
     connOptions.uiInterface = &uiInterface;
     connOptions.m_banman = node.banman.get();
-    connOptions.m_msgproc = node.peerman.get();
+    connOptions.m_msgproc.push_back(node.peerman.get());
+    if (g_avalanche) {
+        connOptions.m_msgproc.push_back(g_avalanche.get());
+    }
     connOptions.nSendBufferMaxSize =
         1000 * args.GetArg("-maxsendbuffer", DEFAULT_MAXSENDBUFFER);
     connOptions.nReceiveFloodSize =
diff --git a/src/net.h b/src/net.h
--- a/src/net.h
+++ b/src/net.h
@@ -950,7 +950,7 @@
         int nMaxAddnode = 0;
         int nMaxFeeler = 0;
         CClientUIInterface *uiInterface = nullptr;
-        NetEventsInterface *m_msgproc = nullptr;
+        std::vector<NetEventsInterface *> m_msgproc;
         BanMan *m_banman = nullptr;
         unsigned int nSendBufferMaxSize = 0;
         unsigned int nReceiveFloodSize = 0;
@@ -1370,7 +1370,9 @@
     int m_max_outbound;
     bool m_use_addrman_outgoing;
     CClientUIInterface *clientInterface;
-    NetEventsInterface *m_msgproc;
+    // Handlers notified of node events, invoked in insertion order.
+    // FIXME m_msgproc is a terrible name
+    std::vector<NetEventsInterface *> m_msgproc;
     /**
      * Pointer to this node's banman. May be nullptr - check existence before
      * dereferencing.
diff --git a/src/net.cpp b/src/net.cpp
--- a/src/net.cpp
+++ b/src/net.cpp
@@ -1407,7 +1407,9 @@
     // it as whitelisted (backward compatibility)
     pnode->m_legacyWhitelisted = legacyWhitelisted;
     pnode->m_prefer_evict = discouraged;
-    m_msgproc->InitializeNode(*config, pnode);
+    for (NetEventsInterface *handler : m_msgproc) {
+        handler->InitializeNode(*config, pnode);
+    }
 
     LogPrint(BCLog::NET, "connection from %s accepted\n", addr.ToString());
 
@@ -2706,7 +2708,10 @@
         grantOutbound->MoveTo(pnode->grantOutbound);
     }
 
-    m_msgproc->InitializeNode(*config, pnode);
+    for (NetEventsInterface *handler : m_msgproc) {
+        handler->InitializeNode(*config, pnode);
+    }
+
     {
         LOCK(cs_vNodes);
         vNodes.push_back(pnode);
@@ -2731,9 +2736,16 @@
                 continue;
             }
 
+            bool fMoreNodeWork = false;
             // Receive messages
-            bool fMoreNodeWork = m_msgproc->ProcessMessages(
-                *config, pnode, flagInterruptMsgProc);
+            for (NetEventsInterface *handler : m_msgproc) {
+                fMoreNodeWork |= handler->ProcessMessages(
+                    *config, pnode, flagInterruptMsgProc);
+                if (flagInterruptMsgProc) {
+                    // Stop dispatching to the remaining handlers.
+                    break;
+                }
+            }
             fMoreWork |= (fMoreNodeWork && !pnode->fPauseSend);
             if (flagInterruptMsgProc) {
                 return;
@@ -2742,7 +2754,9 @@
             // Send messages
             {
                 LOCK(pnode->cs_sendProcessing);
-                m_msgproc->SendMessages(*config, pnode);
+                for (NetEventsInterface *handler : m_msgproc) {
+                    handler->SendMessages(*config, pnode);
+                }
             }
 
             if (flagInterruptMsgProc) {
@@ -3084,7 +3098,7 @@
     //
     // Start threads
     //
-    assert(m_msgproc);
+    assert(!m_msgproc.empty());
     InterruptSocks5(false);
     interruptNet.reset();
     flagInterruptMsgProc = false;
@@ -3263,7 +3277,9 @@
 void CConnman::DeleteNode(CNode *pnode) {
     assert(pnode);
     bool fUpdateConnectionTime = false;
-    m_msgproc->FinalizeNode(*config, *pnode, fUpdateConnectionTime);
+    for (NetEventsInterface *handler : m_msgproc) {
+        handler->FinalizeNode(*config, *pnode, fUpdateConnectionTime);
+    }
     if (fUpdateConnectionTime) {
         addrman.Connected(pnode->addr);
     }
diff --git a/src/test/util/net.h b/src/test/util/net.h
--- a/src/test/util/net.h
+++ b/src/test/util/net.h
@@ -35,7 +35,9 @@
     }
 
     void ProcessMessagesOnce(CNode &node) {
-        m_msgproc->ProcessMessages(*config, &node, flagInterruptMsgProc);
+        for (NetEventsInterface *handler : m_msgproc) {
+            handler->ProcessMessages(*config, &node, flagInterruptMsgProc);
+        }
     }
 
     void NodeReceiveMsgBytes(CNode &node, Span msg_bytes,
diff --git a/src/test/util/setup_common.cpp b/src/test/util/setup_common.cpp
--- a/src/test/util/setup_common.cpp
+++ b/src/test/util/setup_common.cpp
@@ -238,7 +238,7 @@
         *m_node.chainman, *m_node.mempool, false);
     {
         CConnman::Options options;
-        options.m_msgproc = m_node.peerman.get();
+        options.m_msgproc.push_back(m_node.peerman.get());
         m_node.connman->Init(options);
     }
 }
diff --git a/test/functional/abc_rpc_getavalancheinfo.py b/test/functional/abc_rpc_getavalancheinfo.py
--- a/test/functional/abc_rpc_getavalancheinfo.py
+++ b/test/functional/abc_rpc_getavalancheinfo.py
@@ -216,6 +216,37 @@
             }
         })
 
+        self.log.info("Disconnect all the nodes")
+
+        for n in node.p2ps:
+            n.peer_disconnect()
+            n.wait_for_disconnect()
+
+        # The node processes the disconnections asynchronously: wait until
+        # they have been removed from the peer manager before checking.
+        self.wait_until(
+            lambda: node.getavalancheinfo()["network"]["node_count"] == 0)
+
+        assert_avalancheinfo({
+            "active": True,
+            "local": {
+                "live": True,
+                "proofid": f"{proof.proofid:0{64}x}",
+                "limited_proofid": f"{proof.limited_proofid:0{64}x}",
+                "master": privkey.get_pubkey().get_bytes().hex(),
+                "stake_amount": coinbase_amount,
+            },
+            "network": {
+                "proof_count": N,
+                "connected_proof_count": 0,
+                "total_stake_amount": coinbase_amount * N,
+                "connected_stake_amount": 0,
+                "node_count": 0,
+                "connected_node_count": 0,
+                "pending_node_count": 0,
+            }
+        })
+
 
 if __name__ == '__main__':
     GetAvalancheInfoTest().main()