diff --git a/src/net.h b/src/net.h
--- a/src/net.h
+++ b/src/net.h
@@ -1000,6 +1000,7 @@
         AvalancheState() {}
 
         avalanche::Delegation delegation;
+        SchnorrSig sig;
     };
 
     // m_avalanche_state == nullptr if we're not using avalanche with this peer
diff --git a/src/net_processing.cpp b/src/net_processing.cpp
--- a/src/net_processing.cpp
+++ b/src/net_processing.cpp
@@ -2105,6 +2105,32 @@
         }
     }
 
+    // Process avalanche proof items. For now, we only respond to requests
+    // for our local proof and ignore all other requests.
+    while (it != pfrom.vRecvGetData.end() && it->type == MSG_AVA_PROOF) {
+        if (interruptMsgProc) {
+            return;
+        }
+
+        const CInv &inv = *it++;
+
+        if (!g_avalanche ||
+            !gArgs.GetBoolArg("-enableavalanche", AVALANCHE_DEFAULT_ENABLED) ||
+            !gArgs.IsArgSet("-avaproof")) {
+            vNotFound.push_back(inv);
+        } else {
+            const avalanche::ProofId proofid{inv.hash};
+            const avalanche::Proof proof = g_avalanche->getProof();
+            if (proofid == proof.getId()) {
+                connman.PushMessage(&pfrom,
+                                    msgMaker.Make(NetMsgType::AVAPROOF, proof));
+            } else {
+                // TODO: relay all proofs that we have.
+                vNotFound.push_back(inv);
+            }
+        }
+    }
+
     // Only process one BLOCK item per call, since they're uncommon and can be
     // expensive to process.
     if (it != pfrom.vRecvGetData.end() && !pfrom.fPauseSend) {
@@ -4023,10 +4049,37 @@
         }
 
         CHashVerifier verifier(&vRecv);
-        avalanche::Delegation &delegation = pfrom.m_avalanche_state->delegation;
-        verifier >> delegation;
+        verifier >> pfrom.m_avalanche_state->delegation;
+        verifier >> pfrom.m_avalanche_state->sig;
+
+        // Request the proof (TODO: do it only if we don't already have it).
+        std::vector<CInv> vGetData;
+        vGetData.emplace_back(CInv(
+            MSG_AVA_PROOF, pfrom.m_avalanche_state->delegation.getProofId()));
+        m_connman.PushMessage(&pfrom,
+                              msgMaker.Make(NetMsgType::GETDATA, vGetData));
+    }
 
+    if (msg_type == NetMsgType::AVAPROOF) {
+        // m_avalanche_state is null for peers we are not using avalanche
+        // with, and g_avalanche may be null; both are dereferenced below.
+        if (!g_avalanche || !pfrom.m_avalanche_state) {
+            return;
+        }
+
         avalanche::Proof proof;
+        vRecv >> proof;
+
+        // Get the delegated pubkey.
+        const avalanche::Delegation &delegation =
+            pfrom.m_avalanche_state->delegation;
+
+        if (proof.getId() != delegation.getProofId()) {
+            // For now we don't support reception of proofs that do not
+            // belong to the sender.
+            return;
+        }
+
         avalanche::DelegationState state;
         CPubKey pubkey;
         if (!delegation.verify(state, proof, pubkey)) {
@@ -4034,8 +4087,21 @@
             return;
         }
 
-        SchnorrSig sig;
-        verifier >> sig;
+        // Use the delegated pubkey to verify the AVAHELLO signature.
+        const uint256 hash = g_avalanche->buildRemoteSighash(&pfrom);
+        if (!pubkey.VerifySchnorr(hash, pfrom.m_avalanche_state->sig)) {
+            Misbehaving(pfrom, 100, "invalid-avalanche-signature");
+            return;
+        }
+
+        // Add the node to the avalanche peers.
+        if (g_avalanche->addNode(pfrom.GetId(), proof, delegation)) {
+            LogPrint(BCLog::NET, "added avalanche node=%d\n", pfrom.GetId());
+        } else {
+            // TODO: figure out if the proof is bad, and if so add it to a
+            // recentRejects filter
+        }
+        return;
     }
 
     if (msg_type == NetMsgType::AVAPOLL) {
diff --git a/test/functional/abc_p2p_avalanche.py b/test/functional/abc_p2p_avalanche.py
--- a/test/functional/abc_p2p_avalanche.py
+++ b/test/functional/abc_p2p_avalanche.py
@@ -4,18 +4,29 @@
 # file COPYING or http://www.opensource.org/licenses/mit-license.php.
"""Test the resolution of forks via avalanche.""" import random +import struct +from typing import List from test_framework.avatools import create_coinbase_stakes from test_framework.key import ( + bytes_to_wif, ECKey, ECPubKey, ) from test_framework.mininode import P2PInterface, mininode_lock from test_framework.messages import ( + AvalancheDelegation, + AvalancheProof, AvalancheResponse, AvalancheVote, CInv, + FromHex, + hash256, + MSG_AVA_PROOF, msg_avapoll, + msg_avahello, + msg_avaproof, + msg_getdata, msg_tcpavaresponse, NODE_AVALANCHE, NODE_NETWORK, @@ -36,6 +47,7 @@ BLOCK_PENDING = -3 QUORUM_NODE_COUNT = 16 +DUMMY_PROOFID = 1337 class TestNode(P2PInterface): @@ -45,6 +57,7 @@ self.avahello = None self.avaresponses = [] self.avapolls = [] + self.avaproof = None super().__init__() def peer_connect(self, *args, **kwargs): @@ -110,6 +123,44 @@ with mininode_lock: return self.avahello + def send_avahello(self, delegation_hex: str, privkey: ECKey): + msg = msg_avahello() + msg.hello.delegation = FromHex(AvalancheDelegation(), delegation_hex) + + def get_sighash(): + b = msg.hello.delegation.getid() + b += struct.pack(" P2PInterface)") + interface = get_node() - avahello = poll_node.wait_for_avahello().hello + avahello = interface.wait_for_avahello().hello avakey.set(bytes.fromhex(node.getavalanchekey())) assert avakey.verify_schnorr( - avahello.sig, avahello.get_sighash(poll_node)) + avahello.sig, avahello.get_sighash(interface)) + + self.log.info("Ask for the proof") + interface.send_getdata( + [CInv(MSG_AVA_PROOF, avahello.delegation.proofid)]) + avaproof = interface.wait_for_avaproof() + assert avaproof.proof == FromHex(AvalancheProof(), proof) + + self.log.info("Test the avalanche handshake (P2PInterface -> node)") + # Create a different valid proof + stakes = create_coinbase_stakes(node, [blockhashes[1]], addrkey0.key) + interface_proof_hex = node.buildavalancheproof( + proof_sequence, proof_expiration, pubkey.get_bytes().hex(), + stakes) + # delegate + 
delegated_privkey = ECKey()
+        delegated_privkey.generate()
+        interface_delegation_hex = node.delegateavalancheproof(
+            interface_proof_hex,
+            bytes_to_wif(privkey.get_bytes()),
+            delegated_privkey.get_pubkey().get_bytes().hex(),
+            None
+        )
+
+        interface.send_avahello(interface_delegation_hex, delegated_privkey)
+        expected_proofid = FromHex(
+            AvalancheProof(),
+            interface_proof_hex).proofid
+        interface.wait_for_getdata([expected_proofid])
+
+        self.log.info("Test that node adds an avalanche peer")
+        interface.send_avaproof(interface_proof_hex)
+
+        wait_until(
+            lambda: len(node.getavalanchepeerinfo()) > 0,
+            timeout=5,
+            lock=mininode_lock)
 
 
 if __name__ == '__main__':
diff --git a/test/functional/test_framework/messages.py b/test/functional/test_framework/messages.py
--- a/test/functional/test_framework/messages.py
+++ b/test/functional/test_framework/messages.py
@@ -373,6 +373,16 @@
     def __repr__(self):
         return "COutPoint(hash={:064x} n={})".format(self.hash, self.n)
 
+    def __eq__(self, other):
+        if not isinstance(other, type(self)):
+            return NotImplemented
+        return self.hash == other.hash and self.n == other.n
+
+    def __hash__(self):
+        # Defining __eq__ alone sets __hash__ to None; keep COutPoint
+        # hashable, consistently with the value-based equality above.
+        return hash((self.hash, self.n))
+
 
 class CTxIn:
     __slots__ = ("nSequence", "prevout", "scriptSig")
@@ -874,6 +884,16 @@
                f" height={self.height}, " \
                f"pubkey={self.pubkey.hex()})"
 
+    def __eq__(self, other):
+        if not isinstance(other, type(self)):
+            return NotImplemented
+        return (
+            self.utxo == other.utxo and
+            self.amount == other.amount and
+            self.height == other.height and
+            self.pubkey == other.pubkey
+        )
+
 
 class AvalancheSignedStake:
     def __init__(self, stake=None, sig=b""):
@@ -889,6 +909,11 @@
     def serialize(self) -> bytes:
         return self.stake.serialize() + self.sig
 
+    def __eq__(self, other):
+        if not isinstance(other, type(self)):
+            return NotImplemented
+        return self.stake == other.stake and self.sig == other.sig
+
 
 class AvalancheProof:
     __slots__ = ("sequence", "expiration", "master", "stakes", "proofid")
@@ -939,6 +964,18 @@
                f"master={self.master.hex()}, " \
                f"stakes={self.stakes})"
 
+    def __eq__(self, other):
+        if not isinstance(other, type(self)):
+            return NotImplemented
+        return (
+            self.sequence == other.sequence and
+            self.expiration == other.expiration and
+            self.master == other.master and
+            self.proofid == other.proofid and
+            len(self.stakes) == len(other.stakes) and
+            all(s0 == s1 for s0, s1 in zip(self.stakes, other.stakes))
+        )
+
 
 class AvalanchePoll():
     __slots__ = ("round", "invs")