diff --git a/test/functional/abc-p2p-compactblocks.py b/test/functional/abc-p2p-compactblocks.py
--- a/test/functional/abc-p2p-compactblocks.py
+++ b/test/functional/abc-p2p-compactblocks.py
@@ -39,17 +39,17 @@
         self.last_headers = None
         super().__init__()
 
-    def on_sendcmpct(self, conn, message):
+    def on_sendcmpct(self, message):
         self.last_sendcmpct = message
 
-    def on_cmpctblock(self, conn, message):
+    def on_cmpctblock(self, message):
         self.last_cmpctblock = message
         self.last_cmpctblock.header_and_shortids.header.calc_sha256()
 
-    def on_getheaders(self, conn, message):
+    def on_getheaders(self, message):
         self.last_getheaders = message
 
-    def on_headers(self, conn, message):
+    def on_headers(self, message):
         self.last_headers = message
         for x in self.last_headers.headers:
             x.calc_sha256()
@@ -273,7 +273,7 @@
         # Check that compact block also work for big blocks
         node = self.nodes[0]
         peer = TestNode()
-        peer.add_connection(NodeConn('127.0.0.1', p2p_port(0), peer))
+        peer.peer_connect('127.0.0.1', p2p_port(0))
 
         # Wait for connection to be etablished
         peer.wait_for_verack()
diff --git a/test/functional/abc-sync-chain.py b/test/functional/abc-sync-chain.py
--- a/test/functional/abc-sync-chain.py
+++ b/test/functional/abc-sync-chain.py
@@ -44,8 +44,7 @@
     def run_test(self):
         node0conn = BaseNode()
-        node0conn.add_connection(
-            NodeConn('127.0.0.1', p2p_port(0), node0conn))
+        node0conn.peer_connect('127.0.0.1', p2p_port(0))
 
         NetworkThread().start()
         node0conn.wait_for_verack()
diff --git a/test/functional/assumevalid.py b/test/functional/assumevalid.py
--- a/test/functional/assumevalid.py
+++ b/test/functional/assumevalid.py
@@ -69,7 +69,7 @@
     def send_blocks_until_disconnected(self, p2p_conn):
         """Keep sending blocks to the node until we're disconnected."""
         for i in range(len(self.blocks)):
-            if not p2p_conn.connection:
+            if p2p_conn.state != "connected":
                 break
             try:
                 p2p_conn.send_message(msg_block(self.blocks[i]))
diff --git a/test/functional/example_test.py b/test/functional/example_test.py
--- a/test/functional/example_test.py
+++ b/test/functional/example_test.py
@@ -51,14 +51,14 @@
         # Stores a dictionary of all blocks received
         self.block_receive_map = defaultdict(int)
 
-    def on_block(self, conn, message):
+    def on_block(self, message):
         """Override the standard on_block callback
 
         Store the hash of a received block in the dictionary."""
         message.block.calc_sha256()
         self.block_receive_map[message.block.sha256] += 1
 
-    def on_inv(self, conn, message):
+    def on_inv(self, message):
         """Override the standard on_inv callback"""
         pass
diff --git a/test/functional/maxuploadtarget.py b/test/functional/maxuploadtarget.py
--- a/test/functional/maxuploadtarget.py
+++ b/test/functional/maxuploadtarget.py
@@ -28,10 +28,10 @@
         super().__init__()
         self.block_receive_map = defaultdict(int)
 
-    def on_inv(self, conn, message):
+    def on_inv(self, message):
         pass
 
-    def on_block(self, conn, message):
+    def on_block(self, message):
         message.block.calc_sha256()
         self.block_receive_map[message.block.sha256] += 1
diff --git a/test/functional/p2p-compactblocks.py b/test/functional/p2p-compactblocks.py
--- a/test/functional/p2p-compactblocks.py
+++ b/test/functional/p2p-compactblocks.py
@@ -30,22 +30,22 @@
         # so we can eg wait until a particular block is announced.
         self.announced_blockhashes = set()
 
-    def on_sendcmpct(self, conn, message):
+    def on_sendcmpct(self, message):
         self.last_sendcmpct.append(message)
 
-    def on_cmpctblock(self, conn, message):
+    def on_cmpctblock(self, message):
         self.block_announced = True
         self.last_message["cmpctblock"].header_and_shortids.header.calc_sha256()
         self.announced_blockhashes.add(
             self.last_message["cmpctblock"].header_and_shortids.header.sha256)
 
-    def on_headers(self, conn, message):
+    def on_headers(self, message):
         self.block_announced = True
         for x in self.last_message["headers"].headers:
             x.calc_sha256()
             self.announced_blockhashes.add(x.sha256)
 
-    def on_inv(self, conn, message):
+    def on_inv(self, message):
         for x in self.last_message["inv"].inv:
             if x.type == 2:
                 self.block_announced = True
@@ -66,7 +66,7 @@
         msg = msg_getheaders()
         msg.locator.vHave = locator
         msg.hashstop = hashstop
-        self.connection.send_message(msg)
+        self.send_message(msg)
 
     def send_header_for_blocks(self, new_blocks):
         headers_message = msg_headers()
@@ -93,7 +93,7 @@
         This is used when we want to send a message into the node that we
         expect will get us disconnected, eg an invalid block."""
         self.send_message(message)
-        wait_until(lambda: not self.connected,
+        wait_until(lambda: self.state != "connected",
                    timeout=timeout, lock=mininode_lock)
diff --git a/test/functional/p2p-feefilter.py b/test/functional/p2p-feefilter.py
--- a/test/functional/p2p-feefilter.py
+++ b/test/functional/p2p-feefilter.py
@@ -34,7 +34,7 @@
         super().__init__()
         self.txinvs = []
 
-    def on_inv(self, conn, message):
+    def on_inv(self, message):
         for i in message.inv:
             if (i.type == 1):
                 self.txinvs.append(hashToHex(i.hash))
diff --git a/test/functional/p2p-leaktests.py b/test/functional/p2p-leaktests.py
--- a/test/functional/p2p-leaktests.py
+++ b/test/functional/p2p-leaktests.py
@@ -27,51 +27,50 @@
         self.unexpected_msg = True
         self.log.info("should not have received message: %s" % message.command)
 
-    def on_open(self, conn):
-        self.connected = True
+    def on_open(self):
         self.ever_connected = True
 
-    def on_version(self, conn, message): self.bad_message(message)
+    def on_version(self, message): self.bad_message(message)
 
-    def on_verack(self, conn, message): self.bad_message(message)
+    def on_verack(self, message): self.bad_message(message)
 
-    def on_reject(self, conn, message): self.bad_message(message)
+    def on_reject(self, message): self.bad_message(message)
 
-    def on_inv(self, conn, message): self.bad_message(message)
+    def on_inv(self, message): self.bad_message(message)
 
-    def on_addr(self, conn, message): self.bad_message(message)
+    def on_addr(self, message): self.bad_message(message)
 
-    def on_getdata(self, conn, message): self.bad_message(message)
+    def on_getdata(self, message): self.bad_message(message)
 
-    def on_getblocks(self, conn, message): self.bad_message(message)
+    def on_getblocks(self, message): self.bad_message(message)
 
-    def on_tx(self, conn, message): self.bad_message(message)
+    def on_tx(self, message): self.bad_message(message)
 
-    def on_block(self, conn, message): self.bad_message(message)
+    def on_block(self, message): self.bad_message(message)
 
-    def on_getaddr(self, conn, message): self.bad_message(message)
+    def on_getaddr(self, message): self.bad_message(message)
 
-    def on_headers(self, conn, message): self.bad_message(message)
+    def on_headers(self, message): self.bad_message(message)
 
-    def on_getheaders(self, conn, message): self.bad_message(message)
+    def on_getheaders(self, message): self.bad_message(message)
 
-    def on_ping(self, conn, message): self.bad_message(message)
+    def on_ping(self, message): self.bad_message(message)
 
-    def on_mempool(self, conn): self.bad_message(message)
+    def on_mempool(self, message): self.bad_message(message)
 
-    def on_pong(self, conn, message): self.bad_message(message)
+    def on_pong(self, message): self.bad_message(message)
 
-    def on_feefilter(self, conn, message): self.bad_message(message)
+    def on_feefilter(self, message): self.bad_message(message)
 
-    def on_sendheaders(self, conn, message): self.bad_message(message)
+    def on_sendheaders(self, message): self.bad_message(message)
 
-    def on_sendcmpct(self, conn, message): self.bad_message(message)
+    def on_sendcmpct(self, message): self.bad_message(message)
 
-    def on_cmpctblock(self, conn, message): self.bad_message(message)
+    def on_cmpctblock(self, message): self.bad_message(message)
 
-    def on_getblocktxn(self, conn, message): self.bad_message(message)
+    def on_getblocktxn(self, message): self.bad_message(message)
 
-    def on_blocktxn(self, conn, message): self.bad_message(message)
+    def on_blocktxn(self, message): self.bad_message(message)
 
 # Node that never sends a version. We'll use this to send a bunch of messages
 # anyway, and eventually get disconnected.
@@ -80,12 +79,12 @@ class CNodeNoVersionBan(CLazyNode):
     # send a bunch of veracks without sending a message. This should get us disconnected.
     # NOTE: implementation-specific check here. Remove if bitcoind ban behavior changes
-    def on_open(self, conn):
-        super().on_open(conn)
+    def on_open(self):
+        super().on_open()
         for i in range(banscore):
             self.send_message(msg_verack())
 
-    def on_reject(self, conn, message): pass
+    def on_reject(self, message): pass
 
 # Node that never sends a version. This one just sits idle and hopes to receive
 # any message (it shouldn't!)
@@ -103,17 +102,17 @@
         self.version_received = False
         super().__init__()
 
-    def on_reject(self, conn, message): pass
+    def on_reject(self, message): pass
 
-    def on_verack(self, conn, message): pass
+    def on_verack(self, message): pass
 
     # When version is received, don't reply with a verack. Instead, see if the
     # node will give us a message that it shouldn't. This is not an exhaustive
     # list!
-    def on_version(self, conn, message):
+    def on_version(self, message):
         self.version_received = True
-        conn.send_message(msg_ping())
-        conn.send_message(msg_getaddr())
+        self.send_message(msg_ping())
+        self.send_message(msg_getaddr())
 
 
 class P2PLeakTest(BitcoinTestFramework):
@@ -145,7 +144,7 @@
         time.sleep(5)
 
         # This node should have been banned
-        assert not no_version_bannode.connected
+        assert no_version_bannode.state != "connected"
 
         self.nodes[0].disconnect_p2ps()
diff --git a/test/functional/p2p-timeouts.py b/test/functional/p2p-timeouts.py
--- a/test/functional/p2p-timeouts.py
+++ b/test/functional/p2p-timeouts.py
@@ -29,7 +29,7 @@ class TestNode(NodeConnCB):
-    def on_version(self, conn, message):
+    def on_version(self, message):
         # Don't send a verack in response
         pass
diff --git a/test/functional/sendheaders.py b/test/functional/sendheaders.py
--- a/test/functional/sendheaders.py
+++ b/test/functional/sendheaders.py
@@ -110,24 +110,24 @@
         msg = msg_getdata()
         for x in block_hashes:
             msg.inv.append(CInv(2, x))
-        self.connection.send_message(msg)
+        self.send_message(msg)
 
     def get_headers(self, locator, hashstop):
         msg = msg_getheaders()
         msg.locator.vHave = locator
         msg.hashstop = hashstop
-        self.connection.send_message(msg)
+        self.send_message(msg)
 
     def send_block_inv(self, blockhash):
         msg = msg_inv()
         msg.inv = [CInv(2, blockhash)]
-        self.connection.send_message(msg)
+        self.send_message(msg)
 
-    def on_inv(self, conn, message):
+    def on_inv(self, message):
         self.block_announced = True
         self.last_blockhash_announced = message.inv[-1].hash
 
-    def on_headers(self, conn, message):
+    def on_headers(self, message):
         if len(message.headers):
             self.block_announced = True
             message.headers[-1].calc_sha256()
diff --git a/test/functional/test_framework/comptool.py b/test/functional/test_framework/comptool.py
--- a/test/functional/test_framework/comptool.py
+++ b/test/functional/test_framework/comptool.py
@@ -48,7 +48,6 @@
     def __init__(self, block_store, tx_store):
         super().__init__()
-        self.conn = None
         self.bestblockhash = None
         self.block_store = block_store
         self.block_request_map = {}
@@ -63,28 +62,25 @@
         self.lastInv = []
         self.closed = False
 
-    def on_close(self, conn):
+    def on_close(self):
         self.closed = True
 
-    def add_connection(self, conn):
-        self.conn = conn
-
-    def on_headers(self, conn, message):
+    def on_headers(self, message):
         if len(message.headers) > 0:
             best_header = message.headers[-1]
             best_header.calc_sha256()
             self.bestblockhash = best_header.sha256
 
-    def on_getheaders(self, conn, message):
+    def on_getheaders(self, message):
         response = self.block_store.headers_for(
             message.locator, message.hashstop)
         if response is not None:
-            conn.send_message(response)
+            self.send_message(response)
 
-    def on_getdata(self, conn, message):
-        [conn.send_message(r)
+    def on_getdata(self, message):
+        [self.send_message(r)
          for r in self.block_store.get_blocks(message.inv)]
-        [conn.send_message(r)
+        [self.send_message(r)
          for r in self.tx_store.get_transactions(message.inv)]
 
         for i in message.inv:
@@ -93,17 +89,17 @@
             elif i.type == 2:
                 self.block_request_map[i.hash] = True
 
-    def on_inv(self, conn, message):
+    def on_inv(self, message):
         self.lastInv = [x.hash for x in message.inv]
 
-    def on_pong(self, conn, message):
+    def on_pong(self, message):
         try:
             del self.pingMap[message.nonce]
         except KeyError:
             raise AssertionError(
                 "Got pong for unknown ping [%s]" % repr(message))
 
-    def on_reject(self, conn, message):
+    def on_reject(self, message):
         if message.message == b'tx':
             self.tx_reject_map[message.data] = RejectResult(
                 message.code, message.reason)
@@ -113,30 +109,30 @@
     def send_inv(self, obj):
         mtype = 2 if isinstance(obj, CBlock) else 1
-        self.conn.send_message(msg_inv([CInv(mtype, obj.sha256)]))
+        self.send_message(msg_inv([CInv(mtype, obj.sha256)]))
 
     def send_getheaders(self):
         # We ask for headers from their last tip.
         m = msg_getheaders()
         m.locator = self.block_store.get_locator(self.bestblockhash)
-        self.conn.send_message(m)
+        self.send_message(m)
 
     def send_header(self, header):
         m = msg_headers()
         m.headers.append(header)
-        self.conn.send_message(m)
+        self.send_message(m)
 
     # This assumes BIP31
     def send_ping(self, nonce):
         self.pingMap[nonce] = True
-        self.conn.send_message(msg_ping(nonce))
+        self.send_message(msg_ping(nonce))
 
     def received_ping_response(self, nonce):
         return nonce not in self.pingMap
 
     def send_mempool(self):
         self.lastInv = []
-        self.conn.send_message(msg_mempool())
+        self.send_message(msg_mempool())
 
 # TestInstance:
 #
@@ -179,8 +175,7 @@
     def __init__(self, testgen, datadir):
         self.test_generator = testgen
-        self.connections = []
-        self.test_nodes = []
+        self.p2p_connections = []
         self.block_store = BlockStore(datadir)
         self.tx_store = TxStore(datadir)
         self.ping_counter = 1
@@ -188,29 +183,24 @@
     def add_all_connections(self, nodes):
         for i in range(len(nodes)):
             # Create a p2p connection to each node
-            test_node = TestNode(self.block_store, self.tx_store)
-            self.test_nodes.append(test_node)
-            self.connections.append(
-                NodeConn('127.0.0.1', p2p_port(i), test_node))
-            # Make sure the TestNode (callback class) has a reference to its
-            # associated NodeConn
-            test_node.add_connection(self.connections[-1])
+            node = TestNode(self.block_store, self.tx_store)
+            node.peer_connect('127.0.0.1', p2p_port(i))
+            self.p2p_connections.append(node)
 
     def clear_all_connections(self):
-        self.connections = []
-        self.test_nodes = []
+        self.p2p_connections = []
 
     def wait_for_disconnections(self):
         def disconnected():
-            return all(node.closed for node in self.test_nodes)
+            return all(node.closed for node in self.p2p_connections)
         wait_until(disconnected, timeout=10, lock=mininode_lock)
 
     def wait_for_verack(self):
-        return all(node.wait_for_verack() for node in self.test_nodes)
+        return all(node.wait_for_verack() for node in self.p2p_connections)
 
     def wait_for_pings(self, counter):
         def received_pongs():
-            return all(node.received_ping_response(counter) for node in self.test_nodes)
+            return all(node.received_ping_response(counter) for node in self.p2p_connections)
         wait_until(received_pongs, lock=mininode_lock)
 
 # sync_blocks: Wait for all connections to request the blockhash given
@@ -220,7 +210,7 @@
         def blocks_requested():
             return all(
                 blockhash in node.block_request_map and node.block_request_map[blockhash]
-                for node in self.test_nodes
+                for node in self.p2p_connections
             )
 
         # --> error if not requested
@@ -228,10 +218,10 @@
                    num_blocks, lock=mininode_lock)
 
         # Send getheaders message
-        [c.cb.send_getheaders() for c in self.connections]
+        [c.send_getheaders() for c in self.p2p_connections]
 
         # Send ping and wait for response -- synchronization hack
-        [c.cb.send_ping(self.ping_counter) for c in self.connections]
+        [c.send_ping(self.ping_counter) for c in self.p2p_connections]
         self.wait_for_pings(self.ping_counter)
         self.ping_counter += 1
 
@@ -241,7 +231,7 @@
        def transaction_requested():
             return all(
                 txhash in node.tx_request_map and node.tx_request_map[txhash]
-                for node in self.test_nodes
+                for node in self.p2p_connections
             )
 
         # --> error if not requested
@@ -249,38 +239,38 @@
                    num_events, lock=mininode_lock)
 
         # Get the mempool
-        [c.cb.send_mempool() for c in self.connections]
+        [c.send_mempool() for c in self.p2p_connections]
 
         # Send ping and wait for response -- synchronization hack
-        [c.cb.send_ping(self.ping_counter) for c in self.connections]
+        [c.send_ping(self.ping_counter) for c in self.p2p_connections]
         self.wait_for_pings(self.ping_counter)
         self.ping_counter += 1
 
         # Sort inv responses from each node
         with mininode_lock:
-            [c.cb.lastInv.sort() for c in self.connections]
+            [c.lastInv.sort() for c in self.p2p_connections]
 
     # Verify that the tip of each connection all agree with each other, and
     # with the expected outcome (if given)
     def check_results(self, blockhash, outcome):
         with mininode_lock:
-            for c in self.connections:
+            for c in self.p2p_connections:
                 if outcome is None:
-                    if c.cb.bestblockhash != self.connections[0].cb.bestblockhash:
+                    if c.bestblockhash != self.p2p_connections[0].bestblockhash:
                         return False
                 # Check that block was rejected w/ code
                 elif isinstance(outcome, RejectResult):
-                    if c.cb.bestblockhash == blockhash:
+                    if c.bestblockhash == blockhash:
                         return False
-                    if blockhash not in c.cb.block_reject_map:
+                    if blockhash not in c.block_reject_map:
                         logger.error(
                             'Block not in reject map: %064x' % (blockhash))
                         return False
-                    if not outcome.match(c.cb.block_reject_map[blockhash]):
+                    if not outcome.match(c.block_reject_map[blockhash]):
                         logger.error('Block rejected with %s instead of expected %s: %064x' % (
-                            c.cb.block_reject_map[blockhash], outcome, blockhash))
+                            c.block_reject_map[blockhash], outcome, blockhash))
                         return False
-                elif ((c.cb.bestblockhash == blockhash) != outcome):
+                elif ((c.bestblockhash == blockhash) != outcome):
                     return False
             return True
 
@@ -292,23 +282,23 @@
     # a particular tx's existence in the mempool is the same across all nodes.
     def check_mempool(self, txhash, outcome):
         with mininode_lock:
-            for c in self.connections:
+            for c in self.p2p_connections:
                 if outcome is None:
                     # Make sure the mempools agree with each other
-                    if c.cb.lastInv != self.connections[0].cb.lastInv:
+                    if c.lastInv != self.p2p_connections[0].lastInv:
                         return False
                 # Check that tx was rejected w/ code
                 elif isinstance(outcome, RejectResult):
-                    if txhash in c.cb.lastInv:
+                    if txhash in c.lastInv:
                         return False
-                    if txhash not in c.cb.tx_reject_map:
+                    if txhash not in c.tx_reject_map:
                         logger.error('Tx not in reject map: %064x' % (txhash))
                         return False
-                    if not outcome.match(c.cb.tx_reject_map[txhash]):
+                    if not outcome.match(c.tx_reject_map[txhash]):
                         logger.error('Tx rejected with %s instead of expected %s: %064x' % (
-                            c.cb.tx_reject_map[txhash], outcome, txhash))
+                            c.tx_reject_map[txhash], outcome, txhash))
                         return False
-                elif ((txhash in c.cb.lastInv) != outcome):
+                elif ((txhash in c.lastInv) != outcome):
                     return False
             return True
 
@@ -350,27 +340,27 @@
                     first_block_with_hash = False
                 with mininode_lock:
                     self.block_store.add_block(block)
-                    for c in self.connections:
-                        if first_block_with_hash and block.sha256 in c.cb.block_request_map and c.cb.block_request_map[block.sha256] == True:
+                    for c in self.p2p_connections:
+                        if first_block_with_hash and block.sha256 in c.block_request_map and c.block_request_map[block.sha256] == True:
                             # There was a previous request for this block hash
                             # Most likely, we delivered a header for this block
                             # but never had the block to respond to the getdata
                             c.send_message(msg_block(block))
                         else:
-                            c.cb.block_request_map[block.sha256] = False
+                            c.block_request_map[block.sha256] = False
 
                 # Either send inv's to each node and sync, or add
                 # to invqueue for later inv'ing.
                 if (test_instance.sync_every_block):
                     # if we expect success, send inv and sync every block
                     # if we expect failure, just push the block and see what happens.
                     if outcome == True:
-                        [c.cb.send_inv(block) for c in self.connections]
+                        [c.send_inv(block) for c in self.p2p_connections]
                         self.sync_blocks(block.sha256, 1)
                     else:
                         [c.send_message(msg_block(block))
-                         for c in self.connections]
-                        [c.cb.send_ping(self.ping_counter)
-                         for c in self.connections]
+                         for c in self.p2p_connections]
+                        [c.send_ping(self.ping_counter)
+                         for c in self.p2p_connections]
                         self.wait_for_pings(self.ping_counter)
                         self.ping_counter += 1
                     if (not self.check_results(tip, outcome)):
@@ -381,7 +371,7 @@
             elif isinstance(b_or_t, CBlockHeader):
                 block_header = b_or_t
                 self.block_store.add_header(block_header)
-                [c.cb.send_header(block_header) for c in self.connections]
+                [c.send_header(block_header) for c in self.p2p_connections]
             else:  # Tx test runner
                 assert(isinstance(b_or_t, CTransaction))
@@ -390,11 +380,11 @@
                 # Add to shared tx store and clear map entry
                 with mininode_lock:
                     self.tx_store.add_transaction(tx)
-                    for c in self.connections:
-                        c.cb.tx_request_map[tx.sha256] = False
+                    for c in self.p2p_connections:
+                        c.tx_request_map[tx.sha256] = False
 
                 # Again, either inv to all nodes or save for later
                 if (test_instance.sync_every_tx):
-                    [c.cb.send_inv(tx) for c in self.connections]
+                    [c.send_inv(tx) for c in self.p2p_connections]
                     self.sync_transaction(tx.sha256, 1)
                     if (not self.check_mempool(tx.sha256, outcome)):
                         raise AssertionError(
@@ -404,14 +394,14 @@
                 # Ensure we're not overflowing the inv queue
                 if len(invqueue) == MAX_INV_SZ:
                     [c.send_message(msg_inv(invqueue))
-                     for c in self.connections]
+                     for c in self.p2p_connections]
                     invqueue = []
 
             # Do final sync if we weren't syncing on every block or every tx.
             if (not test_instance.sync_every_block and block is not None):
                 if len(invqueue) > 0:
                     [c.send_message(msg_inv(invqueue))
-                     for c in self.connections]
+                     for c in self.p2p_connections]
                     invqueue = []
                 self.sync_blocks(block.sha256, len(
                     test_instance.blocks_and_transactions))
@@ -421,7 +411,7 @@
             if (not test_instance.sync_every_tx and tx is not None):
                 if len(invqueue) > 0:
                     [c.send_message(msg_inv(invqueue))
-                     for c in self.connections]
+                     for c in self.p2p_connections]
                     invqueue = []
                 self.sync_transaction(tx.sha256, len(
                     test_instance.blocks_and_transactions))
@@ -432,7 +422,7 @@
             logger.info("Test %d: PASS" % test_number)
             test_number += 1
 
-        [c.disconnect_node() for c in self.connections]
+        [c.disconnect_node() for c in self.p2p_connections]
         self.wait_for_disconnections()
         self.block_store.close()
         self.tx_store.close()
diff --git a/test/functional/test_framework/mininode.py b/test/functional/test_framework/mininode.py
--- a/test/functional/test_framework/mininode.py
+++ b/test/functional/test_framework/mininode.py
@@ -23,6 +23,7 @@
 from threading import RLock, Thread
 
 from test_framework.messages import *
+from test_framework.util import wait_until
 
 logger = logging.getLogger("TestFramework.mininode")
 
@@ -58,12 +59,24 @@
 class NodeConn(asyncore.dispatcher):
-    """The actual NodeConn class
+    """A low-level connection object to a node's P2P interface.
 
-    This class provides an interface for a p2p connection to a specified node."""
+    This class is responsible for:
 
-    def __init__(self, dstaddr, dstport, callback, net="regtest", services=NODE_NETWORK, send_version=True):
-        asyncore.dispatcher.__init__(self, map=mininode_socket_map)
+    - opening and closing the TCP connection to the node
+    - reading bytes from and writing bytes to the socket
+    - deserializing and serializing the P2P message header
+    - logging messages as they are sent and received
+
+    This class contains no logic for handling the P2P message payloads. It must be
+    sub-classed and the on_message() callback overridden.
+
+    TODO: rename this class P2PConnection."""
+
+    def __init__(self):
+        super().__init__(map=mininode_socket_map)
+
+    def peer_connect(self, dstaddr, dstport, net="regtest", services=NODE_NETWORK, send_version=True):
         self.dstaddr = dstaddr
         self.dstport = dstport
         self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -72,9 +85,7 @@
         self.recvbuf = b""
         self.state = "connecting"
         self.network = net
-        self.cb = callback
         self.disconnect = False
-        self.nServices = 0
 
         if send_version:
             # stuff version msg into sendbuf
@@ -94,6 +105,11 @@
         except:
             self.handle_close()
 
+    def peer_disconnect(self):
+        # Connection could have already been closed by other end.
+        if self.state == "connected":
+            self.disconnect_node()
+
     # Connection and disconnection methods
 
     def handle_connect(self):
@@ -102,7 +118,7 @@
             logger.debug("Connected & Listening: %s:%d" %
                          (self.dstaddr, self.dstport))
             self.state = "connected"
-            self.cb.on_open(self)
+            self.on_open()
 
     def handle_close(self):
         """asyncore callback when a connection is closed."""
@@ -115,7 +131,7 @@
             self.close()
         except:
             pass
-        self.cb.on_close(self)
+        self.on_close()
 
     def disconnect_node(self):
         """Disconnect the p2p connection.
@@ -176,8 +192,8 @@
             raise
 
     def on_message(self, message):
-        """Callback for processing a P2P payload. Calls into NodeConnCB."""
-        self.cb.on_message(self, message)
+        """Callback for processing a P2P payload. Must be overridden by derived class."""
+        raise NotImplementedError
 
     # Socket write methods
 
@@ -249,16 +265,20 @@
         logger.debug(log_message)
 
 
-class NodeConnCB():
-    """Callback and helper functions for P2P connection to a bitcoind node.
+class NodeConnCB(NodeConn):
+    """A high-level P2P interface class for communicating with a Bitcoin Cash node.
+
+    This class provides high-level callbacks for processing P2P message
+    payloads, as well as convenience methods for interacting with the
+    node over P2P.
 
     Individual testcases should subclass this and override the on_* methods
-    if they want to alter message handling behaviour."""
+    if they want to alter message handling behaviour.
+
+    TODO: rename this class P2PInterface"""
 
     def __init__(self):
-        # Track whether we have a P2P connection open to the node
-        self.connected = False
-        self.connection = None
+        super().__init__()
 
         # Track number of messages of each type received and the most recent
         # message of each type
@@ -268,9 +288,12 @@
         # A count of the number of ping messages we've sent to the node
         self.ping_counter = 1
 
+        # The network services received from the peer
+        self.nServices = 0
+
     # Message receiving methods
 
-    def on_message(self, conn, message):
+    def on_message(self, message):
         """Receive message and dispatch message to appropriate callback.
 
         We keep a count of how many of each message type has been received
@@ -280,7 +303,7 @@
             command = message.command.decode('ascii')
             self.message_count[command] += 1
             self.last_message[command] = message
-            getattr(self, 'on_' + command)(conn, message)
+            getattr(self, 'on_' + command)(message)
         except:
             print("ERROR delivering %s (%s)" % (repr(message),
                                                 sys.exc_info()[0]))
@@ -289,74 +312,70 @@
     # Callback methods. Can be overridden by subclasses in individual test
     # cases to provide custom message handling behaviour.
 
-    def on_open(self, conn):
-        self.connected = True
+    def on_open(self):
+        pass
 
-    def on_close(self, conn):
-        self.connected = False
-        self.connection = None
+    def on_close(self):
+        pass
 
-    def on_addr(self, conn, message): pass
+    def on_addr(self, message): pass
 
-    def on_block(self, conn, message): pass
+    def on_block(self, message): pass
 
-    def on_blocktxn(self, conn, message): pass
+    def on_blocktxn(self, message): pass
 
-    def on_cmpctblock(self, conn, message): pass
+    def on_cmpctblock(self, message): pass
 
-    def on_feefilter(self, conn, message): pass
+    def on_feefilter(self, message): pass
 
-    def on_getaddr(self, conn, message): pass
+    def on_getaddr(self, message): pass
 
-    def on_getblocks(self, conn, message): pass
+    def on_getblocks(self, message): pass
 
-    def on_getblocktxn(self, conn, message): pass
+    def on_getblocktxn(self, message): pass
 
-    def on_getdata(self, conn, message): pass
+    def on_getdata(self, message): pass
 
-    def on_getheaders(self, conn, message): pass
+    def on_getheaders(self, message): pass
 
-    def on_headers(self, conn, message): pass
+    def on_headers(self, message): pass
 
-    def on_mempool(self, conn): pass
+    def on_mempool(self, message): pass
 
-    def on_pong(self, conn, message): pass
+    def on_pong(self, message): pass
 
-    def on_reject(self, conn, message): pass
+    def on_reject(self, message): pass
 
-    def on_sendcmpct(self, conn, message): pass
+    def on_sendcmpct(self, message): pass
 
-    def on_sendheaders(self, conn, message): pass
+    def on_sendheaders(self, message): pass
 
-    def on_tx(self, conn, message): pass
+    def on_tx(self, message): pass
 
-    def on_inv(self, conn, message):
+    def on_inv(self, message):
         want = msg_getdata()
         for i in message.inv:
             if i.type != 0:
                 want.inv.append(i)
         if len(want.inv):
-            conn.send_message(want)
+            self.send_message(want)
 
-    def on_ping(self, conn, message):
-        conn.send_message(msg_pong(message.nonce))
+    def on_ping(self, message):
+        self.send_message(msg_pong(message.nonce))
 
-    def on_verack(self, conn, message):
+    def on_verack(self, message):
         self.verack_received = True
 
-    def on_version(self, conn, message):
+    def on_version(self, message):
         assert message.nVersion >= MIN_VERSION_SUPPORTED, "Version {} received. Test framework only supports versions greater than {}".format(
             message.nVersion, MIN_VERSION_SUPPORTED)
-        conn.send_message(msg_verack())
-        conn.nServices = message.nServices
+        self.send_message(msg_verack())
+        self.nServices = message.nServices
 
     # Connection helper methods
 
-    def add_connection(self, conn):
-        self.connection = conn
-
     def wait_for_disconnect(self, timeout=60):
-        def test_function(): return not self.connected
+        def test_function(): return self.state != "connected"
         wait_until(test_function, timeout=timeout, lock=mininode_lock)
 
     # Message receiving helper methods
@@ -391,12 +410,6 @@
 
     # Message sending helper functions
 
-    def send_message(self, message):
-        if self.connection:
-            self.connection.send_message(message)
-        else:
-            logger.error("Cannot send message. No connection to node!")
-
     def send_and_ping(self, message):
         self.send_message(message)
         self.sync_with_ping()
diff --git a/test/functional/test_framework/test_node.py b/test/functional/test_framework/test_node.py
--- a/test/functional/test_framework/test_node.py
+++ b/test/functional/test_framework/test_node.py
@@ -14,7 +14,7 @@
 import time
 
 from .authproxy import JSONRPCException
-from .mininode import COIN, ToHex, FromHex, CTransaction, NodeConn
+from .mininode import COIN, ToHex, FromHex, CTransaction
 from .util import (
     assert_equal,
     get_rpc_proxy,
@@ -188,7 +188,7 @@
         ctx = FromHex(CTransaction(), self.getrawtransaction(txid))
         return self.calculate_fee(ctx)
 
-    def add_p2p_connection(self, p2p_conn, **kwargs):
+    def add_p2p_connection(self, p2p_conn, *args, **kwargs):
         """Add a p2p connection to the node.
 
         This method adds the p2p connection to the self.p2ps list and also
@@ -197,9 +197,9 @@
             kwargs['dstport'] = p2p_port(self.index)
         if 'dstaddr' not in kwargs:
             kwargs['dstaddr'] = '127.0.0.1'
+
+        p2p_conn.peer_connect(*args, **kwargs)
         self.p2ps.append(p2p_conn)
-        kwargs.update({'callback': p2p_conn})
-        p2p_conn.add_connection(NodeConn(**kwargs))
 
         return p2p_conn
 
@@ -215,10 +215,8 @@
     def disconnect_p2ps(self):
         """Close all p2p connections to the node."""
         for p in self.p2ps:
-            # Connection could have already been closed by other end.
-            if p.connection is not None:
-                p.connection.disconnect_node()
-        self.p2ps = []
+            p.peer_disconnect()
+        del self.p2ps[:]
 
 
 class TestNodeCLI():
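
Illustrative note (not part of the patch): a minimal sketch of how a test drives the refactored interface, assuming only the post-refactor API shown above. Callbacks now take a single message argument, connection state lives on the peer object itself, and the connection is opened with peer_connect() instead of wrapping a separate NodeConn. The MyTestPeer name is hypothetical; the calls mirror abc-p2p-compactblocks.py and abc-sync-chain.py as updated in this diff.

# Illustrative sketch only -- assumes the refactored test_framework API in this patch.
from test_framework.mininode import NodeConnCB, NetworkThread
from test_framework.util import p2p_port


class MyTestPeer(NodeConnCB):
    """Hypothetical peer that counts blocks delivered by the node."""

    def __init__(self):
        super().__init__()
        self.blocks_received = 0

    def on_block(self, message):
        # Callbacks no longer receive a `conn` argument; helpers such as
        # send_message() are inherited from NodeConn via NodeConnCB.
        self.blocks_received += 1


peer = MyTestPeer()
peer.peer_connect('127.0.0.1', p2p_port(0))  # replaces add_connection(NodeConn(...))
NetworkThread().start()
peer.wait_for_verack()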