Send getavaproofs to avalanche outbound and manual peers.

* Rename the static helper shouldSendGetAvaAddr() to
  isAvalancheOutboundOrManual(): it now gates more than just getavaaddr.
* Where GETAVAADDR is already sent to such a peer, also send GETAVAPROOFS
  when the peer has an m_proof_relay, and record it in the new atomic flag
  m_proof_relay->compactproofs_requested.
* Add functional tests checking that getavaproofs is sent to avalanche
  outbound and manual peers, and not to inbound or non-avalanche peers.

NOTE(review): this patch was recovered from a whitespace-collapsed copy in
which template arguments had been stripped. `std::atomic<bool>` and
`std::vector<NodeId>` are restored below. The net.h context line
`RadixTree sharedProofs;` is left as found because its template arguments
(and possibly its line wrapping) are not recoverable here. Restore it from
the tree before applying; the first hunk header expects 6 old lines, and
only 5 are reconstructed. The manual-connection test uses the canonical
127.0.0.1 address form.

diff --git a/src/net.h b/src/net.h
--- a/src/net.h
+++ b/src/net.h
@@ -657,6 +657,7 @@
 
         RadixTree sharedProofs;
+        std::atomic<bool> compactproofs_requested{false};
     };
 
     // m_proof_relay == nullptr if we're not relaying proofs with this peer
diff --git a/src/net_processing.cpp b/src/net_processing.cpp
--- a/src/net_processing.cpp
+++ b/src/net_processing.cpp
@@ -1673,7 +1673,7 @@
     });
 }
 
-static bool shouldSendGetAvaAddr(const CNode *pnode) {
+static bool isAvalancheOutboundOrManual(const CNode *pnode) {
     return pnode->IsAvalancheOutboundConnection() ||
            (pnode->IsManualConn() && (pnode->nServices & NODE_AVALANCHE));
 }
@@ -1686,7 +1686,7 @@
         }))) {
         std::vector<NodeId> avanode_outbound_ids;
         m_connman.ForEachNode([&](CNode *pnode) {
-            if (shouldSendGetAvaAddr(pnode)) {
+            if (isAvalancheOutboundOrManual(pnode)) {
                 avanode_outbound_ids.push_back(pnode->GetId());
             }
         });
@@ -3685,12 +3685,19 @@
                 ->m_recently_announced_proofs.insert(localProof->getId());
         }
 
-        // Send getavaaddr to our avalanche outbound connections
-        if (shouldSendGetAvaAddr(&pfrom)) {
+        // Send getavaaddr and getavaproofs to our avalanche outbound or
+        // manual connections
+        if (isAvalancheOutboundOrManual(&pfrom)) {
             m_connman.PushMessage(&pfrom,
                                   msgMaker.Make(NetMsgType::GETAVAADDR));
             WITH_LOCK(peer->m_addr_token_bucket_mutex,
                       peer->m_addr_token_bucket += GetMaxAddrToSend());
+
+            if (pfrom.m_proof_relay) {
+                m_connman.PushMessage(
+                    &pfrom, msgMaker.Make(NetMsgType::GETAVAPROOFS));
+                pfrom.m_proof_relay->compactproofs_requested = true;
+            }
         }
     }
 
diff --git a/test/functional/abc_p2p_compactproofs.py b/test/functional/abc_p2p_compactproofs.py
--- a/test/functional/abc_p2p_compactproofs.py
+++ b/test/functional/abc_p2p_compactproofs.py
@@ -12,11 +12,15 @@
     get_proof_ids,
     wait_for_proof,
 )
-from test_framework.messages import msg_getavaproofs
-from test_framework.p2p import p2p_lock
+from test_framework.messages import (
+    NODE_AVALANCHE,
+    NODE_NETWORK,
+    msg_getavaproofs,
+)
+from test_framework.p2p import P2PInterface, p2p_lock
 from test_framework.siphash import siphash256
 from test_framework.test_framework import BitcoinTestFramework
-from test_framework.util import assert_equal
+from test_framework.util import MAX_NODES, assert_equal, p2p_port
 
 
 class CompactProofsTest(BitcoinTestFramework):
@@ -28,6 +32,81 @@
             '-whitelist=noban@127.0.0.1',
         ]] * self.num_nodes
 
+    def test_send_outbound_getavaproofs(self):
+        self.log.info(
+            "Check we send a getavaproofs message to our avalanche outbound peers")
+        node = self.nodes[0]
+
+        non_avapeers = []
+        for i in range(4):
+            peer = P2PInterface()
+            node.add_outbound_p2p_connection(
+                peer,
+                p2p_idx=i,
+                connection_type="outbound-full-relay",
+                services=NODE_NETWORK,
+            )
+            non_avapeers.append(peer)
+
+        inbound_avapeers = [
+            node.add_p2p_connection(
+                AvaP2PInterface()) for _ in range(4)]
+
+        outbound_avapeers = []
+        for i in range(4):
+            peer = P2PInterface()
+            node.add_outbound_p2p_connection(
+                peer,
+                p2p_idx=16 + i,
+                connection_type="avalanche",
+                services=NODE_NETWORK | NODE_AVALANCHE,
+            )
+            outbound_avapeers.append(peer)
+
+        self.wait_until(
+            lambda: all([p.last_message.get("getavaproofs") for p in outbound_avapeers]))
+        assert all([p.message_count.get(
+            "getavaproofs", 0) == 1 for p in outbound_avapeers])
+        assert all([p.message_count.get(
+            "getavaproofs", 0) == 0 for p in non_avapeers])
+        assert all([p.message_count.get(
+            "getavaproofs", 0) == 0 for p in inbound_avapeers])
+
+    def test_send_manual_getavaproofs(self):
+        self.log.info(
+            "Check we send a getavaproofs message to our manually connected peers that support avalanche")
+        node = self.nodes[0]
+
+        # Get rid of previously connected nodes
+        node.disconnect_p2ps()
+
+        def added_node_connected(ip_port):
+            added_node_info = node.getaddednodeinfo(ip_port)
+            return len(
+                added_node_info) == 1 and added_node_info[0]['connected']
+
+        def connect_callback(address, port):
+            self.log.debug("Connecting to {}:{}".format(address, port))
+
+        p = AvaP2PInterface()
+        p2p_idx = 1
+        p.peer_accept_connection(
+            connect_cb=connect_callback,
+            connect_id=p2p_idx,
+            net=node.chain,
+            timeout_factor=node.timeout_factor,
+            services=NODE_NETWORK | NODE_AVALANCHE,
+        )()
+        ip_port = f"127.0.0.1:{p2p_port(MAX_NODES - p2p_idx)}"
+
+        node.addnode(node=ip_port, command="add")
+        self.wait_until(lambda: added_node_connected(ip_port))
+
+        assert_equal(node.getpeerinfo()[-1]['addr'], ip_port)
+        assert_equal(node.getpeerinfo()[-1]['connection_type'], 'manual')
+
+        self.wait_until(lambda: p.last_message.get("getavaproofs"))
+
     def test_respond_getavaproofs(self):
         self.log.info("Check the node responds to getavaproofs messages")
 
@@ -75,6 +154,8 @@
         assert_equal(len(avaproofs.prefilled_proofs), 0)
 
     def run_test(self):
+        self.test_send_outbound_getavaproofs()
+        self.test_send_manual_getavaproofs()
         self.test_respond_getavaproofs()
 
 