diff --git a/test/functional/p2p_getdata.py b/test/functional/p2p_getdata.py index 56fa08894..aa2a607f9 100755 --- a/test/functional/p2p_getdata.py +++ b/test/functional/p2p_getdata.py @@ -1,58 +1,51 @@ #!/usr/bin/env python3 # Copyright (c) 2020 The Bitcoin Core developers # Distributed under the MIT software license, see the accompanying # file COPYING or http://www.opensource.org/licenses/mit-license.php. """Test GETDATA processing behavior""" from collections import defaultdict from test_framework.messages import ( CInv, msg_getdata, ) -from test_framework.mininode import ( - mininode_lock, - P2PInterface, -) +from test_framework.mininode import P2PInterface from test_framework.test_framework import BitcoinTestFramework -from test_framework.util import wait_until class P2PStoreBlock(P2PInterface): - def __init__(self): super().__init__() self.blocks = defaultdict(int) def on_block(self, message): message.block.calc_sha256() self.blocks[message.block.sha256] += 1 class GetdataTest(BitcoinTestFramework): def set_test_params(self): self.num_nodes = 1 def run_test(self): - self.nodes[0].add_p2p_connection(P2PStoreBlock()) + p2p_block_store = self.nodes[0].add_p2p_connection(P2PStoreBlock()) self.log.info( "test that an invalid GETDATA doesn't prevent processing of future messages") # Send invalid message and verify that node responds to later ping invalid_getdata = msg_getdata() invalid_getdata.inv.append(CInv(t=0, h=0)) # INV type 0 is invalid. - self.nodes[0].p2ps[0].send_and_ping(invalid_getdata) + p2p_block_store.send_and_ping(invalid_getdata) # Check getdata still works by fetching tip block best_block = int(self.nodes[0].getbestblockhash(), 16) good_getdata = msg_getdata() good_getdata.inv.append(CInv(t=2, h=best_block)) - self.nodes[0].p2ps[0].send_and_ping(good_getdata) - wait_until( - lambda: self.nodes[0].p2ps[0].blocks[best_block] == 1, - timeout=30, - lock=mininode_lock) + p2p_block_store.send_and_ping(good_getdata) + p2p_block_store.wait_until( + lambda: self.nodes[0].p2ps[0].blocks[best_block] == 1) if __name__ == '__main__': GetdataTest().main() diff --git a/test/functional/test_framework/mininode.py b/test/functional/test_framework/mininode.py index 81d93dac8..35ab6cff4 100755 --- a/test/functional/test_framework/mininode.py +++ b/test/functional/test_framework/mininode.py @@ -1,777 +1,777 @@ #!/usr/bin/env python3 # Copyright (c) 2010 ArtForz -- public domain half-a-node # Copyright (c) 2012 Jeff Garzik # Copyright (c) 2010-2019 The Bitcoin Core developers # Distributed under the MIT software license, see the accompanying # file COPYING or http://www.opensource.org/licenses/mit-license.php. """Bitcoin P2P network half-a-node. This python code was modified from ArtForz' public domain half-a-node, as found in the mini-node branch of http://github.com/jgarzik/pynode. P2PConnection: A low-level connection object to a node's P2P interface P2PInterface: A high-level interface object for communicating to a node over P2P P2PDataStore: A p2p interface class that keeps a store of transactions and blocks and can respond correctly to getdata and getheaders messages P2PTxInvStore: A p2p interface class that inherits from P2PDataStore, and keeps a count of how many times each txid has been announced.""" import asyncio from collections import defaultdict from io import BytesIO import logging import struct import sys import threading from test_framework.messages import ( CBlockHeader, MIN_VERSION_SUPPORTED, msg_addr, msg_avapoll, msg_tcpavaresponse, msg_avahello, msg_block, MSG_BLOCK, msg_blocktxn, msg_cfcheckpt, msg_cfheaders, msg_cfilter, msg_cmpctblock, msg_feefilter, msg_filteradd, msg_filterclear, msg_filterload, msg_getaddr, msg_getblocks, msg_getblocktxn, msg_getdata, msg_getheaders, msg_headers, msg_inv, msg_mempool, msg_merkleblock, msg_notfound, msg_ping, msg_pong, msg_sendcmpct, msg_sendheaders, msg_tx, MSG_TX, MSG_TYPE_MASK, msg_verack, msg_version, NODE_NETWORK, sha256, ) from test_framework.util import wait_until logger = logging.getLogger("TestFramework.mininode") MESSAGEMAP = { b"addr": msg_addr, b"avapoll": msg_avapoll, b"avaresponse": msg_tcpavaresponse, b"avahello": msg_avahello, b"block": msg_block, b"blocktxn": msg_blocktxn, b"cfcheckpt": msg_cfcheckpt, b"cfheaders": msg_cfheaders, b"cfilter": msg_cfilter, b"cmpctblock": msg_cmpctblock, b"feefilter": msg_feefilter, b"filteradd": msg_filteradd, b"filterclear": msg_filterclear, b"filterload": msg_filterload, b"getaddr": msg_getaddr, b"getblocks": msg_getblocks, b"getblocktxn": msg_getblocktxn, b"getdata": msg_getdata, b"getheaders": msg_getheaders, b"headers": msg_headers, b"inv": msg_inv, b"mempool": msg_mempool, b"merkleblock": msg_merkleblock, b"notfound": msg_notfound, b"ping": msg_ping, b"pong": msg_pong, b"sendcmpct": msg_sendcmpct, b"sendheaders": msg_sendheaders, b"tx": msg_tx, b"verack": msg_verack, b"version": msg_version, } MAGIC_BYTES = { "mainnet": b"\xe3\xe1\xf3\xe8", "testnet3": b"\xf4\xe5\xf3\xf4", "regtest": b"\xda\xb5\xbf\xfa", } class P2PConnection(asyncio.Protocol): """A low-level connection object to a node's P2P interface. This class is responsible for: - opening and closing the TCP connection to the node - reading bytes from and writing bytes to the socket - deserializing and serializing the P2P message header - logging messages as they are sent and received This class contains no logic for handing the P2P message payloads. It must be sub-classed and the on_message() callback overridden.""" def __init__(self): # The underlying transport of the connection. # Should only call methods on this from the NetworkThread, c.f. # call_soon_threadsafe self._transport = None @property def is_connected(self): return self._transport is not None def peer_connect(self, dstaddr, dstport, *, net, timeout_factor): assert not self.is_connected self.timeout_factor = timeout_factor self.dstaddr = dstaddr self.dstport = dstport # The initial message to send after the connection was made: self.on_connection_send_msg = None self.on_connection_send_msg_is_raw = False self.recvbuf = b"" self.magic_bytes = MAGIC_BYTES[net] logger.debug('Connecting to Bitcoin Node: {}:{}'.format( self.dstaddr, self.dstport)) loop = NetworkThread.network_event_loop conn_gen_unsafe = loop.create_connection( lambda: self, host=self.dstaddr, port=self.dstport) def conn_gen(): return loop.call_soon_threadsafe( loop.create_task, conn_gen_unsafe) return conn_gen def peer_disconnect(self): # Connection could have already been closed by other end. NetworkThread.network_event_loop.call_soon_threadsafe( lambda: self._transport and self._transport.abort()) # Connection and disconnection methods def connection_made(self, transport): """asyncio callback when a connection is opened.""" assert not self._transport logger.debug("Connected & Listening: {}:{}".format( self.dstaddr, self.dstport)) self._transport = transport if self.on_connection_send_msg: if self.on_connection_send_msg_is_raw: self.send_raw_message(self.on_connection_send_msg) else: self.send_message(self.on_connection_send_msg) # Never used again self.on_connection_send_msg = None self.on_open() def connection_lost(self, exc): """asyncio callback when a connection is closed.""" if exc: logger.warning("Connection lost to {}:{} due to {}".format( self.dstaddr, self.dstport, exc)) else: logger.debug("Closed connection to: {}:{}".format( self.dstaddr, self.dstport)) self._transport = None self.recvbuf = b"" self.on_close() # Socket read methods def data_received(self, t): """asyncio callback when data is read from the socket.""" with mininode_lock: if len(t) > 0: self.recvbuf += t while True: msg = self._on_data() if msg is None: break self.on_message(msg) def _on_data(self): """Try to read P2P messages from the recv buffer. This method reads data from the buffer in a loop. It deserializes, parses and verifies the P2P header, then passes the P2P payload to the on_message callback for processing.""" try: with mininode_lock: if len(self.recvbuf) < 4: return None if self.recvbuf[:4] != self.magic_bytes: raise ValueError( "magic bytes mismatch: {} != {}".format( repr( self.magic_bytes), repr( self.recvbuf))) if len(self.recvbuf) < 4 + 12 + 4 + 4: return None msgtype = self.recvbuf[4:4 + 12].split(b"\x00", 1)[0] msglen = struct.unpack( " 500: log_message += "... (msg truncated)" logger.debug(log_message) class P2PInterface(P2PConnection): """A high-level P2P interface class for communicating with a Bitcoin Cash node. This class provides high-level callbacks for processing P2P message payloads, as well as convenience methods for interacting with the node over P2P. Individual testcases should subclass this and override the on_* methods if they want to alter message handling behaviour.""" def __init__(self): super().__init__() # Track number of messages of each type received and the most recent # message of each type self.message_count = defaultdict(int) self.last_message = {} # A count of the number of ping messages we've sent to the node self.ping_counter = 1 # The network services received from the peer self.nServices = 0 def peer_connect(self, *args, services=NODE_NETWORK, send_version=True, **kwargs): create_conn = super().peer_connect(*args, **kwargs) if send_version: # Send a version msg vt = msg_version() vt.nServices = services vt.addrTo.ip = self.dstaddr vt.addrTo.port = self.dstport vt.addrFrom.ip = "0.0.0.0" vt.addrFrom.port = 0 # Will be sent soon after connection_made self.on_connection_send_msg = vt return create_conn # Message receiving methods def on_message(self, message): """Receive message and dispatch message to appropriate callback. We keep a count of how many of each message type has been received and the most recent message of each type.""" with mininode_lock: try: msgtype = message.msgtype.decode('ascii') self.message_count[msgtype] += 1 self.last_message[msgtype] = message getattr(self, 'on_' + msgtype)(message) except Exception: print("ERROR delivering {} ({})".format( repr(message), sys.exc_info()[0])) raise # Callback methods. Can be overridden by subclasses in individual test # cases to provide custom message handling behaviour. def on_open(self): pass def on_close(self): pass def on_addr(self, message): pass def on_avapoll(self, message): pass def on_avaresponse(self, message): pass def on_avahello(self, message): pass def on_block(self, message): pass def on_blocktxn(self, message): pass def on_cfcheckpt(self, message): pass def on_cfheaders(self, message): pass def on_cfilter(self, message): pass def on_cmpctblock(self, message): pass def on_feefilter(self, message): pass def on_filteradd(self, message): pass def on_filterclear(self, message): pass def on_filterload(self, message): pass def on_getaddr(self, message): pass def on_getblocks(self, message): pass def on_getblocktxn(self, message): pass def on_getdata(self, message): pass def on_getheaders(self, message): pass def on_headers(self, message): pass def on_mempool(self, message): pass def on_merkleblock(self, message): pass def on_notfound(self, message): pass def on_pong(self, message): pass def on_sendcmpct(self, message): pass def on_sendheaders(self, message): pass def on_tx(self, message): pass def on_inv(self, message): want = msg_getdata() for i in message.inv: if i.type != 0: want.inv.append(i) if len(want.inv): self.send_message(want) def on_ping(self, message): self.send_message(msg_pong(message.nonce)) def on_verack(self, message): pass def on_version(self, message): assert message.nVersion >= MIN_VERSION_SUPPORTED, "Version {} received. Test framework only supports versions greater than {}".format( message.nVersion, MIN_VERSION_SUPPORTED) self.send_message(msg_verack()) self.nServices = message.nServices # Connection helper methods - def wait_until(self, test_function, timeout): + def wait_until(self, test_function, timeout=60): wait_until(test_function, timeout=timeout, lock=mininode_lock, timeout_factor=self.timeout_factor) def wait_for_disconnect(self, timeout=60): def test_function(): return not self.is_connected self.wait_until(test_function, timeout=timeout) # Message receiving helper methods def wait_for_tx(self, txid, timeout=60): def test_function(): assert self.is_connected if not self.last_message.get('tx'): return False return self.last_message['tx'].tx.rehash() == txid self.wait_until(test_function, timeout=timeout) def wait_for_block(self, blockhash, timeout=60): def test_function(): assert self.is_connected return self.last_message.get( "block") and self.last_message["block"].block.rehash() == blockhash self.wait_until(test_function, timeout=timeout) def wait_for_header(self, blockhash, timeout=60): def test_function(): assert self.is_connected last_headers = self.last_message.get('headers') if not last_headers: return False return last_headers.headers[0].rehash() == int(blockhash, 16) self.wait_until(test_function, timeout=timeout) def wait_for_merkleblock(self, blockhash, timeout=60): def test_function(): assert self.is_connected last_filtered_block = self.last_message.get('merkleblock') if not last_filtered_block: return False return last_filtered_block.merkleblock.header.rehash() == int(blockhash, 16) self.wait_until(test_function, timeout=timeout) def wait_for_getdata(self, hash_list, timeout=60): """Waits for a getdata message. The object hashes in the inventory vector must match the provided hash_list.""" def test_function(): assert self.is_connected last_data = self.last_message.get("getdata") if not last_data: return False return [x.hash for x in last_data.inv] == hash_list self.wait_until(test_function, timeout=timeout) def wait_for_getheaders(self, timeout=60): """Waits for a getheaders message. Receiving any getheaders message will satisfy the predicate. the last_message["getheaders"] value must be explicitly cleared before calling this method, or this will return immediately with success. TODO: change this method to take a hash value and only return true if the correct block header has been requested.""" def test_function(): assert self.is_connected return self.last_message.get("getheaders") self.wait_until(test_function, timeout=timeout) def wait_for_inv(self, expected_inv, timeout=60): """Waits for an INV message and checks that the first inv object in the message was as expected.""" if len(expected_inv) > 1: raise NotImplementedError( "wait_for_inv() will only verify the first inv object") def test_function(): assert self.is_connected return self.last_message.get("inv") and \ self.last_message["inv"].inv[0].type == expected_inv[0].type and \ self.last_message["inv"].inv[0].hash == expected_inv[0].hash self.wait_until(test_function, timeout=timeout) def wait_for_verack(self, timeout=60): def test_function(): return self.message_count["verack"] self.wait_until(test_function, timeout=timeout) # Message sending helper functions def send_and_ping(self, message, timeout=60): self.send_message(message) self.sync_with_ping(timeout=timeout) # Sync up with the node def sync_with_ping(self, timeout=60): self.send_message(msg_ping(nonce=self.ping_counter)) def test_function(): assert self.is_connected return self.last_message.get( "pong") and self.last_message["pong"].nonce == self.ping_counter self.wait_until(test_function, timeout=timeout) self.ping_counter += 1 # One lock for synchronizing all data access between the networking thread (see # NetworkThread below) and the thread running the test logic. For simplicity, # P2PConnection acquires this lock whenever delivering a message to a P2PInterface. # This lock should be acquired in the thread running the test logic to synchronize # access to any data shared with the P2PInterface or P2PConnection. mininode_lock = threading.RLock() class NetworkThread(threading.Thread): network_event_loop = None def __init__(self): super().__init__(name="NetworkThread") # There is only one event loop and no more than one thread must be # created assert not self.network_event_loop NetworkThread.network_event_loop = asyncio.new_event_loop() def run(self): """Start the network thread.""" self.network_event_loop.run_forever() def close(self, timeout=10): """Close the connections and network event loop.""" self.network_event_loop.call_soon_threadsafe( self.network_event_loop.stop) wait_until(lambda: not self.network_event_loop.is_running(), timeout=timeout) self.network_event_loop.close() self.join(timeout) # Safe to remove event loop. NetworkThread.network_event_loop = None class P2PDataStore(P2PInterface): """A P2P data store class. Keeps a block and transaction store and responds correctly to getdata and getheaders requests.""" def __init__(self): super().__init__() # store of blocks. key is block hash, value is a CBlock object self.block_store = {} self.last_block_hash = '' # store of txs. key is txid, value is a CTransaction object self.tx_store = {} self.getdata_requests = [] def on_getdata(self, message): """Check for the tx/block in our stores and if found, reply with an inv message.""" for inv in message.inv: self.getdata_requests.append(inv.hash) if (inv.type & MSG_TYPE_MASK) == MSG_TX and inv.hash in self.tx_store.keys(): self.send_message(msg_tx(self.tx_store[inv.hash])) elif (inv.type & MSG_TYPE_MASK) == MSG_BLOCK and inv.hash in self.block_store.keys(): self.send_message(msg_block(self.block_store[inv.hash])) else: logger.debug( 'getdata message type {} received.'.format(hex(inv.type))) def on_getheaders(self, message): """Search back through our block store for the locator, and reply with a headers message if found.""" locator, hash_stop = message.locator, message.hashstop # Assume that the most recent block added is the tip if not self.block_store: return headers_list = [self.block_store[self.last_block_hash]] maxheaders = 2000 while headers_list[-1].sha256 not in locator.vHave: # Walk back through the block store, adding headers to headers_list # as we go. prev_block_hash = headers_list[-1].hashPrevBlock if prev_block_hash in self.block_store: prev_block_header = CBlockHeader( self.block_store[prev_block_hash]) headers_list.append(prev_block_header) if prev_block_header.sha256 == hash_stop: # if this is the hashstop header, stop here break else: logger.debug('block hash {} not found in block store'.format( hex(prev_block_hash))) break # Truncate the list if there are too many headers headers_list = headers_list[:-maxheaders - 1:-1] response = msg_headers(headers_list) if response is not None: self.send_message(response) def send_blocks_and_test(self, blocks, node, *, success=True, force_send=False, reject_reason=None, expect_disconnect=False, timeout=60): """Send blocks to test node and test whether the tip advances. - add all blocks to our block_store - send a headers message for the final block - the on_getheaders handler will ensure that any getheaders are responded to - if force_send is False: wait for getdata for each of the blocks. The on_getdata handler will ensure that any getdata messages are responded to. Otherwise send the full block unsolicited. - if success is True: assert that the node's tip advances to the most recent block - if success is False: assert that the node's tip doesn't advance - if reject_reason is set: assert that the correct reject message is logged""" with mininode_lock: for block in blocks: self.block_store[block.sha256] = block self.last_block_hash = block.sha256 def test(): if force_send: for b in blocks: self.send_message(msg_block(block=b)) else: self.send_message( msg_headers([CBlockHeader(block) for block in blocks])) self.wait_until( lambda: blocks[-1].sha256 in self.getdata_requests, timeout=timeout) if expect_disconnect: self.wait_for_disconnect(timeout=timeout) else: self.sync_with_ping(timeout=timeout) if success: self.wait_until(lambda: node.getbestblockhash() == blocks[-1].hash, timeout=timeout) else: assert node.getbestblockhash() != blocks[-1].hash if reject_reason: with node.assert_debug_log(expected_msgs=[reject_reason]): test() else: test() def send_txs_and_test(self, txs, node, *, success=True, expect_disconnect=False, reject_reason=None): """Send txs to test node and test whether they're accepted to the mempool. - add all txs to our tx_store - send tx messages for all txs - if success is True/False: assert that the txs are/are not accepted to the mempool - if expect_disconnect is True: Skip the sync with ping - if reject_reason is set: assert that the correct reject message is logged.""" with mininode_lock: for tx in txs: self.tx_store[tx.sha256] = tx def test(): for tx in txs: self.send_message(msg_tx(tx)) if expect_disconnect: self.wait_for_disconnect() else: self.sync_with_ping() raw_mempool = node.getrawmempool() if success: # Check that all txs are now in the mempool for tx in txs: assert tx.hash in raw_mempool, "{} not found in mempool".format( tx.hash) else: # Check that none of the txs are now in the mempool for tx in txs: assert tx.hash not in raw_mempool, "{} tx found in mempool".format( tx.hash) if reject_reason: with node.assert_debug_log(expected_msgs=[reject_reason]): test() else: test() class P2PTxInvStore(P2PInterface): """A P2PInterface which stores a count of how many times each txid has been announced.""" def __init__(self): super().__init__() self.tx_invs_received = defaultdict(int) def on_inv(self, message): # Send getdata in response. super().on_inv(message) # Store how many times invs have been received for each tx. for i in message.inv: if i.type == MSG_TX: # save txid self.tx_invs_received[i.hash] += 1 def get_invs(self): with mininode_lock: return list(self.tx_invs_received.keys()) def wait_for_broadcast(self, txns, timeout=60): """Waits for the txns (list of txids) to complete initial broadcast. The mempool should mark unbroadcast=False for these transactions. """ # Wait until invs have been received (and getdatas sent) for each txid. self.wait_until(lambda: set(self.get_invs()) == set( [int(tx, 16) for tx in txns]), timeout) # Flush messages and wait for the getdatas to be processed self.sync_with_ping()