diff --git a/src/net_processing.cpp b/src/net_processing.cpp --- a/src/net_processing.cpp +++ b/src/net_processing.cpp @@ -2749,7 +2749,8 @@ bool IsAvalancheMessageType(const std::string &msg_type) { return msg_type == NetMsgType::AVAHELLO || msg_type == NetMsgType::AVAPOLL || - msg_type == NetMsgType::AVARESPONSE; + msg_type == NetMsgType::AVARESPONSE || + msg_type == NetMsgType::AVAPROOF; } void PeerManager::ProcessMessage(const Config &config, CNode &pfrom, @@ -4297,6 +4298,45 @@ return; } + if (msg_type == NetMsgType::AVAPROOF) { + auto proof = std::make_shared<avalanche::Proof>(); + vRecv >> *proof; + const avalanche::ProofId &proofid = proof->getId(); + + pfrom.AddKnownProof(proofid); + + const NodeId nodeid = pfrom.GetId(); + + { + LOCK(cs_proofrequest); + m_proofrequest.ReceivedResponse(nodeid, proofid); + + if (AlreadyHaveProof(proofid)) { + m_proofrequest.ForgetInvId(proofid); + return; + } + } + + // Don't call addProof while holding cs_proofrequest: it takes cs_main + // and that creates a potential deadlock during shutdown + if (g_avalanche->addProof(proof)) { + WITH_LOCK(cs_proofrequest, m_proofrequest.ForgetInvId(proofid)); + RelayProof(proofid, m_connman); + + LogPrint(BCLog::NET, "New avalanche proof: peer=%d, proofid %s\n", + nodeid, proofid.ToString()); + } else { + // If the proof couldn't be added, it is either an orphan or + // invalid. In the latter case we should increase the ban score.
+ // TODO improve the ban reason by printing the validation state + if (!g_avalanche->getOrphan(proofid)) { + Misbehaving(nodeid, 100, "invalid-avaproof"); + } + } + + return; + } + if (msg_type == NetMsgType::GETADDR) { // This asymmetric behavior for inbound and outbound connections was // introduced to prevent a fingerprinting attack: an attacker can send diff --git a/test/functional/abc_p2p_proof_inventory.py b/test/functional/abc_p2p_proof_inventory.py --- a/test/functional/abc_p2p_proof_inventory.py +++ b/test/functional/abc_p2p_proof_inventory.py @@ -8,6 +8,8 @@ from test_framework.avatools import ( create_coinbase_stakes, + get_proof_ids, + wait_for_proof, ) from test_framework.key import ECKey, bytes_to_wif from test_framework.messages import ( @@ -15,6 +17,7 @@ FromHex, MSG_AVA_PROOF, MSG_TYPE_MASK, + msg_avaproof, ) from test_framework.p2p import ( P2PInterface, @@ -22,6 +25,9 @@ ) from test_framework.test_framework import BitcoinTestFramework from test_framework.util import ( + assert_equal, + assert_greater_than, + connect_nodes, wait_until, ) @@ -39,7 +45,7 @@ class ProofInventoryTest(BitcoinTestFramework): def set_test_params(self): - self.num_nodes = 1 + self.num_nodes = 5 self.extra_args = [['-enableavalanche=1', '-avacooldown=0']] * self.num_nodes @@ -95,8 +101,98 @@ assert all(p.proof_invs_counter == 1 for p in node.p2ps) + def test_receive_proof(self): + self.log.info("Test a peer is created on proof reception") + + node = self.nodes[0] + _, proof = self.gen_proof(node) + + peer = node.add_p2p_connection(P2PInterface()) + + msg = msg_avaproof() + msg.proof = proof + peer.send_message(msg) + + wait_until(lambda: proof.proofid in get_proof_ids(node)) + + self.log.info("Test receiving a proof with missing utxo is orphaned") + + privkey = ECKey() + privkey.generate() + orphan_hex = node.buildavalancheproof( + 42, 2000000000, privkey.get_pubkey().get_bytes().hex(), [{ + 'txid': '0' * 64, + 'vout': 0, + 'amount': 10, + 'height': 42, + 'iscoinbase': 
+ False, + 'privatekey': bytes_to_wif(privkey.get_bytes()), + }] + ) + + orphan = FromHex(AvalancheProof(), orphan_hex) + orphan_proofid = "{:064x}".format(orphan.proofid) + + msg = msg_avaproof() + msg.proof = orphan + peer.send_message(msg) + + wait_for_proof(node, orphan_proofid) + raw_proof = node.getrawavalancheproof(orphan_proofid) + assert_equal(raw_proof["proof"], orphan_hex) + assert_equal(raw_proof["orphan"], True) + + def test_ban_invalid_proof(self): + node = self.nodes[0] + _, bad_proof = self.gen_proof(node) + bad_proof.stakes = [] + + peer = node.add_p2p_connection(P2PInterface()) + + msg = msg_avaproof() + msg.proof = bad_proof + with node.assert_debug_log([ + 'Misbehaving', + 'invalid-avaproof', + ]): + peer.send_message(msg) + peer.wait_for_disconnect() + + def test_proof_relay(self): + # This test makes no sense with a single node ! + assert_greater_than(self.num_nodes, 1) + + def restart_nodes_with_proof(nodes=self.nodes): + proofids = set() + for i, node in enumerate(nodes): + privkey, proof = self.gen_proof(node) + proofids.add(proof.proofid) + + self.restart_node(node.index, self.extra_args[node.index] + [ + "-avaproof={}".format(proof.serialize().hex()), + "-avamasterkey={}".format(privkey) + ]) + + # Connect a block to make the proof be added to our pool + node.generate(1) + wait_until(lambda: proof.proofid in get_proof_ids(node)) + + [connect_nodes(node, n) for n in nodes[:i]] + + return proofids + + proofids = restart_nodes_with_proof(self.nodes) + + self.log.info("Nodes should eventually get the proof from their peer") + self.sync_proofs() + for node in self.nodes: + assert_equal(set(get_proof_ids(node)), proofids) + def run_test(self): self.test_send_proof_inv() + self.test_receive_proof() + self.test_ban_invalid_proof() + self.test_proof_relay() if __name__ == '__main__': diff --git a/test/functional/test_framework/test_framework.py b/test/functional/test_framework/test_framework.py --- a/test/functional/test_framework/test_framework.py
+++ b/test/functional/test_framework/test_framework.py @@ -18,6 +18,7 @@ from typing import Optional from .authproxy import JSONRPCException +from .avatools import get_proof_ids from . import coverage from .p2p import NetworkThread from .test_node import TestNode @@ -578,6 +579,25 @@ "".join("\n {!r}".format(m) for m in pool), )) + def sync_proofs(self, nodes=None, wait=1, timeout=60): + """ + Wait until everybody has the same proofs in their proof pools + """ + rpc_connections = nodes or self.nodes + timeout = int(timeout * self.options.timeout_factor) + stop_time = time.time() + timeout + while time.time() <= stop_time: + nodes_proofs = [set(get_proof_ids(r)) for r in rpc_connections] + if nodes_proofs.count(nodes_proofs[0]) == len(rpc_connections): + return + # Check that each peer has at least one connection + assert (all([len(x.getpeerinfo()) for x in rpc_connections])) + time.sleep(wait) + raise AssertionError("Proofs sync timed out after {}s:{}".format( + timeout, + "".join("\n {!r}".format(m) for m in nodes_proofs), + )) + def sync_all(self, nodes=None): self.sync_blocks(nodes) self.sync_mempools(nodes)