diff --git a/src/net_processing.cpp b/src/net_processing.cpp
--- a/src/net_processing.cpp
+++ b/src/net_processing.cpp
@@ -3262,7 +3262,8 @@
            msg_type == NetMsgType::AVAPROOF ||
            msg_type == NetMsgType::GETAVAADDR ||
            msg_type == NetMsgType::GETAVAPROOFS ||
-           msg_type == NetMsgType::AVAPROOFS;
+           msg_type == NetMsgType::AVAPROOFS ||
+           msg_type == NetMsgType::AVAPROOFSREQ;
 }
 
 uint32_t PeerManagerImpl::GetAvalancheVoteForBlock(const BlockHash &hash) {
@@ -5270,6 +5271,35 @@
         return;
     }
 
+    if (msg_type == NetMsgType::AVAPROOFSREQ) {
+        if (pfrom.m_proof_relay == nullptr) {
+            return;
+        }
+
+        avalanche::ProofsRequest proofreq;
+        vRecv >> proofreq;
+
+        auto requestedIndiceIt = proofreq.indices.begin();
+        uint32_t treeIndice = 0;
+        pfrom.m_proof_relay->sharedProofs.forEachLeaf([&](const auto &proof) {
+            if (requestedIndiceIt == proofreq.indices.end()) {
+                // Every requested indice has been served, stop the walk
+                return false;
+            }
+
+            if (treeIndice++ == *requestedIndiceIt) {
+                m_connman.PushMessage(
+                    &pfrom, msgMaker.Make(NetMsgType::AVAPROOF, *proof));
+                ++requestedIndiceIt;
+            }
+
+            return true;
+        });
+
+        pfrom.m_proof_relay->sharedProofs = {};
+        return;
+    }
+
     if (msg_type == NetMsgType::GETADDR) {
         // This asymmetric behavior for inbound and outbound connections was
         // introduced to prevent a fingerprinting attack: an attacker can send
diff --git a/test/functional/abc_p2p_compactproofs.py b/test/functional/abc_p2p_compactproofs.py
--- a/test/functional/abc_p2p_compactproofs.py
+++ b/test/functional/abc_p2p_compactproofs.py
@@ -21,6 +21,7 @@
     AvalanchePrefilledProof,
     calculate_shortid,
     msg_avaproofs,
+    msg_avaproofsreq,
     msg_getavaproofs,
 )
 from test_framework.p2p import P2PInterface, p2p_lock
@@ -28,14 +29,36 @@
 from test_framework.util import MAX_NODES, assert_equal, p2p_port
 
 
+class ProofStoreP2PInterface(AvaP2PInterface):
+    def __init__(self):
+        self.proofs = []
+        super().__init__()
+
+    def on_avaproof(self, message):
+        self.proofs.append(message.proof)
+
+    def get_proofs(self):
+        with p2p_lock:
+            return list(self.proofs)
+
+
 class CompactProofsTest(BitcoinTestFramework):
     def set_test_params(self):
-        self.num_nodes = 1
+        self.num_nodes = 2
         self.extra_args = [[
             '-enableavalanche=1',
             '-avacooldown=0',
         ]] * self.num_nodes
 
+    def setup_network(self):
+        # Don't connect the nodes, the tests connect them when required
+        self.setup_nodes()
+
+    @staticmethod
+    def received_avaproofs(peer):
+        with p2p_lock:
+            return peer.last_message.get("avaproofs")
+
     def test_send_outbound_getavaproofs(self):
         self.log.info(
             "Check we send a getavaproofs message to our avalanche outbound peers")
@@ -116,15 +139,11 @@
 
         node = self.nodes[0]
 
-        def received_avaproofs(peer):
-            with p2p_lock:
-                return peer.last_message.get("avaproofs")
-
         def send_getavaproof_check_shortid_len(peer, expected_len):
             peer.send_message(msg_getavaproofs())
-            self.wait_until(lambda: received_avaproofs(peer))
+            self.wait_until(lambda: self.received_avaproofs(peer))
 
-            avaproofs = received_avaproofs(peer)
+            avaproofs = self.received_avaproofs(peer)
             assert_equal(len(avaproofs.shortids), expected_len)
 
         # Initially the node has 0 peer
@@ -146,7 +165,7 @@
         receiving_peer = node.add_p2p_connection(AvaP2PInterface())
         send_getavaproof_check_shortid_len(receiving_peer, len(proofids))
 
-        avaproofs = received_avaproofs(receiving_peer)
+        avaproofs = self.received_avaproofs(receiving_peer)
         expected_shortids = [
             calculate_shortid(
                 avaproofs.key0,
@@ -319,11 +338,117 @@
         bad_peer.send_message(msg)
         bad_peer.wait_for_disconnect()
 
+    def test_send_missing_proofs(self):
+        self.log.info("Check the node responds to missing proofs requests")
+
+        node = self.nodes[0]
+
+        self.restart_node(0)
+
+        numof_proof = 10
+        proofs = [gen_proof(node)[1] for _ in range(numof_proof)]
+
+        for proof in proofs:
+            node.sendavalancheproof(proof.serialize().hex())
+        proofids = get_proof_ids(node)
+        assert all(proof.proofid in proofids for proof in proofs)
+
+        self.log.info("Unsolicited requests are ignored")
+
+        peer = node.add_p2p_connection(ProofStoreP2PInterface())
+        peer.send_and_ping(msg_avaproofsreq())
+        assert_equal(len(peer.get_proofs()), 0)
+
+        def request_proofs(peer):
+            peer.send_message(msg_getavaproofs())
+            self.wait_until(lambda: self.received_avaproofs(peer))
+
+            avaproofs = self.received_avaproofs(peer)
+            assert_equal(len(avaproofs.shortids), numof_proof)
+
+            return avaproofs
+
+        _ = request_proofs(peer)
+
+        self.log.info("Sending an empty request has no effect")
+
+        peer.send_and_ping(msg_avaproofsreq())
+        assert_equal(len(peer.get_proofs()), 0)
+
+        self.log.info("Check the requested proofs are sent by the node")
+
+        def check_received_proofs(indices):
+            requester = node.add_p2p_connection(ProofStoreP2PInterface())
+            avaproofs = request_proofs(requester)
+
+            req = msg_avaproofsreq()
+            req.indices = indices
+            requester.send_message(req)
+
+            # Check we got the expected number of proofs
+            self.wait_until(
+                lambda: len(
+                    requester.get_proofs()) == len(indices))
+
+            # Check we got the expected proofs
+            received_shortids = [
+                calculate_shortid(
+                    avaproofs.key0,
+                    avaproofs.key1,
+                    proof.proofid) for proof in requester.get_proofs()]
+            assert_equal(set(received_shortids),
+                         set([avaproofs.shortids[i] for i in indices]))
+
+        # Only the first proof
+        check_received_proofs([0])
+        # Only the last proof
+        check_received_proofs([numof_proof - 1])
+        # First half
+        check_received_proofs(range(0, numof_proof // 2))
+        # Second half
+        check_received_proofs(range(numof_proof // 2, numof_proof))
+        # Even
+        check_received_proofs([i for i in range(numof_proof) if i % 2 == 0])
+        # Odd
+        check_received_proofs([i for i in range(numof_proof) if i % 2 == 1])
+        # All
+        check_received_proofs(range(numof_proof))
+
+    def test_compact_proofs_download_on_connect(self):
+        self.log.info(
+            "Check the node gets compact proofs upon avalanche outbound discovery")
+
+        requestee = self.nodes[0]
+        requester = self.nodes[1]
+
+        self.restart_node(0)
+
+        numof_proof = 10
+        proofs = [gen_proof(requestee)[1] for _ in range(numof_proof)]
+
+        for proof in proofs:
+            requestee.sendavalancheproof(proof.serialize().hex())
+        proofids = get_proof_ids(requestee)
+        assert all(proof.proofid in proofids for proof in proofs)
+
+        # Start the requester and check it gets all the proofs
+        self.start_node(1)
+        self.connect_nodes(0, 1)
+        self.wait_until(
+            lambda: all(
+                proof.proofid in get_proof_ids(requester) for proof in proofs))
+
     def run_test(self):
+        # Most of the tests only need a single node, let the other ones start
+        # the node when required
+        self.stop_node(1)
+
         self.test_send_outbound_getavaproofs()
         self.test_send_manual_getavaproofs()
         self.test_respond_getavaproofs()
         self.test_request_missing_proofs()
+        self.test_send_missing_proofs()
+        self.test_compact_proofs_download_on_connect()
 
 
 if __name__ == '__main__':