Changeset View
Changeset View
Standalone View
Standalone View
test/functional/test_framework/p2p.py
Show First 20 Lines • Show All 172 Lines • ▼ Show 20 Lines | def peer_connect_helper(self, dstaddr, dstport, net, timeout_factor): | ||||
self.on_connection_send_msg_is_raw = False | self.on_connection_send_msg_is_raw = False | ||||
self.recvbuf = b"" | self.recvbuf = b"" | ||||
self.magic_bytes = MAGIC_BYTES[net] | self.magic_bytes = MAGIC_BYTES[net] | ||||
def peer_connect(self, dstaddr, dstport, *, net, timeout_factor): | def peer_connect(self, dstaddr, dstport, *, net, timeout_factor): | ||||
self.peer_connect_helper(dstaddr, dstport, net, timeout_factor) | self.peer_connect_helper(dstaddr, dstport, net, timeout_factor) | ||||
loop = NetworkThread.network_event_loop | loop = NetworkThread.network_event_loop | ||||
logger.debug( | logger.debug(f"Connecting to Bitcoin ABC Node: {self.dstaddr}:{self.dstport}") | ||||
f'Connecting to Bitcoin ABC Node: {self.dstaddr}:{self.dstport}') | |||||
coroutine = loop.create_connection( | coroutine = loop.create_connection( | ||||
lambda: self, host=self.dstaddr, port=self.dstport) | lambda: self, host=self.dstaddr, port=self.dstport | ||||
) | |||||
return lambda: loop.call_soon_threadsafe(loop.create_task, coroutine) | return lambda: loop.call_soon_threadsafe(loop.create_task, coroutine) | ||||
def peer_accept_connection( | def peer_accept_connection( | ||||
self, connect_id, connect_cb=lambda: None, *, net, timeout_factor): | self, connect_id, connect_cb=lambda: None, *, net, timeout_factor | ||||
self.peer_connect_helper('0', 0, net, timeout_factor) | ): | ||||
self.peer_connect_helper("0", 0, net, timeout_factor) | |||||
logger.debug( | logger.debug(f"Listening for Bitcoin ABC Node with id: {connect_id}") | ||||
f'Listening for Bitcoin ABC Node with id: {connect_id}') | |||||
return lambda: NetworkThread.listen(self, connect_cb, idx=connect_id) | return lambda: NetworkThread.listen(self, connect_cb, idx=connect_id) | ||||
def peer_disconnect(self): | def peer_disconnect(self): | ||||
# Connection could have already been closed by other end. | # Connection could have already been closed by other end. | ||||
NetworkThread.network_event_loop.call_soon_threadsafe( | NetworkThread.network_event_loop.call_soon_threadsafe( | ||||
lambda: self._transport and self._transport.abort()) | lambda: self._transport and self._transport.abort() | ||||
) | |||||
# Connection and disconnection methods | # Connection and disconnection methods | ||||
def connection_made(self, transport): | def connection_made(self, transport): | ||||
"""asyncio callback when a connection is opened.""" | """asyncio callback when a connection is opened.""" | ||||
assert not self._transport | assert not self._transport | ||||
logger.debug(f"Connected & Listening: {self.dstaddr}:{self.dstport}") | logger.debug(f"Connected & Listening: {self.dstaddr}:{self.dstport}") | ||||
self._transport = transport | self._transport = transport | ||||
if self.on_connection_send_msg: | if self.on_connection_send_msg: | ||||
if self.on_connection_send_msg_is_raw: | if self.on_connection_send_msg_is_raw: | ||||
self.send_raw_message(self.on_connection_send_msg) | self.send_raw_message(self.on_connection_send_msg) | ||||
else: | else: | ||||
self.send_message(self.on_connection_send_msg) | self.send_message(self.on_connection_send_msg) | ||||
# Never used again | # Never used again | ||||
self.on_connection_send_msg = None | self.on_connection_send_msg = None | ||||
self.on_open() | self.on_open() | ||||
def connection_lost(self, exc): | def connection_lost(self, exc): | ||||
"""asyncio callback when a connection is closed.""" | """asyncio callback when a connection is closed.""" | ||||
if exc: | if exc: | ||||
logger.warning( | logger.warning( | ||||
f"Connection lost to {self.dstaddr}:{self.dstport} due to {exc}") | f"Connection lost to {self.dstaddr}:{self.dstport} due to {exc}" | ||||
) | |||||
else: | else: | ||||
logger.debug(f"Closed connection to: {self.dstaddr}:{self.dstport}") | logger.debug(f"Closed connection to: {self.dstaddr}:{self.dstport}") | ||||
self._transport = None | self._transport = None | ||||
self.recvbuf = b"" | self.recvbuf = b"" | ||||
self.on_close() | self.on_close() | ||||
# Socket read methods | # Socket read methods | ||||
Show All 16 Lines | def _on_data(self): | ||||
parses and verifies the P2P header, then passes the P2P payload to | parses and verifies the P2P header, then passes the P2P payload to | ||||
the on_message callback for processing.""" | the on_message callback for processing.""" | ||||
try: | try: | ||||
with p2p_lock: | with p2p_lock: | ||||
if len(self.recvbuf) < 4: | if len(self.recvbuf) < 4: | ||||
return None | return None | ||||
if self.recvbuf[:4] != self.magic_bytes: | if self.recvbuf[:4] != self.magic_bytes: | ||||
raise ValueError( | raise ValueError( | ||||
f"magic bytes mismatch: " | "magic bytes mismatch: " | ||||
f"{self.magic_bytes!r} != {self.recvbuf!r}") | f"{self.magic_bytes!r} != {self.recvbuf!r}" | ||||
) | |||||
if len(self.recvbuf) < 4 + 12 + 4 + 4: | if len(self.recvbuf) < 4 + 12 + 4 + 4: | ||||
return None | return None | ||||
msgtype = self.recvbuf[4:4 + 12].split(b"\x00", 1)[0] | msgtype = self.recvbuf[4 : 4 + 12].split(b"\x00", 1)[0] | ||||
msglen = struct.unpack( | msglen = struct.unpack("<i", self.recvbuf[4 + 12 : 4 + 12 + 4])[0] | ||||
"<i", self.recvbuf[4 + 12:4 + 12 + 4])[0] | |||||
checksum = self.recvbuf[4 + 12 + 4:4 + 12 + 4 + 4] | checksum = self.recvbuf[4 + 12 + 4 : 4 + 12 + 4 + 4] | ||||
if len(self.recvbuf) < 4 + 12 + 4 + 4 + msglen: | if len(self.recvbuf) < 4 + 12 + 4 + 4 + msglen: | ||||
return None | return None | ||||
msg = self.recvbuf[4 + 12 + 4 + 4:4 + 12 + 4 + 4 + msglen] | msg = self.recvbuf[4 + 12 + 4 + 4 : 4 + 12 + 4 + 4 + msglen] | ||||
h = sha256(sha256(msg)) | h = sha256(sha256(msg)) | ||||
if checksum != h[:4]: | if checksum != h[:4]: | ||||
raise ValueError(f"got bad checksum {repr(self.recvbuf)}") | raise ValueError(f"got bad checksum {repr(self.recvbuf)}") | ||||
self.recvbuf = self.recvbuf[4 + 12 + 4 + 4 + msglen:] | self.recvbuf = self.recvbuf[4 + 12 + 4 + 4 + msglen :] | ||||
if msgtype not in MESSAGEMAP: | if msgtype not in MESSAGEMAP: | ||||
raise ValueError( | raise ValueError( | ||||
f"Received unknown msgtype from {self.dstaddr}:{self.dstport}:" | f"Received unknown msgtype from {self.dstaddr}:{self.dstport}:" | ||||
f" '{msgtype}' {msg!r}") | f" '{msgtype}' {msg!r}" | ||||
) | |||||
f = BytesIO(msg) | f = BytesIO(msg) | ||||
m = MESSAGEMAP[msgtype]() | m = MESSAGEMAP[msgtype]() | ||||
m.deserialize(f) | m.deserialize(f) | ||||
self._log_message("receive", m) | self._log_message("receive", m) | ||||
return m | return m | ||||
except Exception as e: | except Exception as e: | ||||
logger.exception('Error reading message:', repr(e)) | logger.exception("Error reading message:", repr(e)) | ||||
raise | raise | ||||
def on_message(self, message): | def on_message(self, message): | ||||
"""Callback for processing a P2P payload. Must be overridden by derived class.""" | """Callback for processing a P2P payload. Must be overridden by derived class.""" | ||||
raise NotImplementedError | raise NotImplementedError | ||||
# Socket write methods | # Socket write methods | ||||
def send_message(self, message): | def send_message(self, message): | ||||
"""Send a P2P message over the socket. | """Send a P2P message over the socket. | ||||
This method takes a P2P payload, builds the P2P header and adds | This method takes a P2P payload, builds the P2P header and adds | ||||
the message to the send buffer to be sent over the socket.""" | the message to the send buffer to be sent over the socket.""" | ||||
if not self.is_connected: | if not self.is_connected: | ||||
raise IOError('Not connected') | raise IOError("Not connected") | ||||
tmsg = self.build_message(message) | tmsg = self.build_message(message) | ||||
self._log_message("send", message) | self._log_message("send", message) | ||||
return self.send_raw_message(tmsg) | return self.send_raw_message(tmsg) | ||||
def send_raw_message(self, raw_message_bytes): | def send_raw_message(self, raw_message_bytes): | ||||
"""Send any raw message over the socket. | """Send any raw message over the socket. | ||||
This method adds a raw message to the send buffer to be sent over the | This method adds a raw message to the send buffer to be sent over the | ||||
socket.""" | socket.""" | ||||
if not self.is_connected: | if not self.is_connected: | ||||
raise IOError('Not connected') | raise IOError("Not connected") | ||||
def maybe_write(): | def maybe_write(): | ||||
if not self._transport: | if not self._transport: | ||||
return | return | ||||
if self._transport.is_closing(): | if self._transport.is_closing(): | ||||
return | return | ||||
self._transport.write(raw_message_bytes) | self._transport.write(raw_message_bytes) | ||||
NetworkThread.network_event_loop.call_soon_threadsafe(maybe_write) | NetworkThread.network_event_loop.call_soon_threadsafe(maybe_write) | ||||
# Class utility methods | # Class utility methods | ||||
def build_message(self, message): | def build_message(self, message): | ||||
"""Build a serialized P2P message""" | """Build a serialized P2P message""" | ||||
msgtype = message.msgtype | msgtype = message.msgtype | ||||
data = message.serialize() | data = message.serialize() | ||||
▲ Show 20 Lines • Show All 59 Lines • ▼ Show 20 Lines | def peer_connect_send_version(self, services): | ||||
vt.addrTo.ip = self.dstaddr | vt.addrTo.ip = self.dstaddr | ||||
vt.addrTo.port = self.dstport | vt.addrTo.port = self.dstport | ||||
vt.addrFrom.ip = "0.0.0.0" | vt.addrFrom.ip = "0.0.0.0" | ||||
vt.addrFrom.port = 0 | vt.addrFrom.port = 0 | ||||
# Will be sent in connection_made callback | # Will be sent in connection_made callback | ||||
self.on_connection_send_msg = vt | self.on_connection_send_msg = vt | ||||
def peer_connect(self, *args, services=P2P_SERVICES, | def peer_connect(self, *args, services=P2P_SERVICES, send_version=True, **kwargs): | ||||
send_version=True, **kwargs): | |||||
create_conn = super().peer_connect(*args, **kwargs) | create_conn = super().peer_connect(*args, **kwargs) | ||||
if send_version: | if send_version: | ||||
self.peer_connect_send_version(services) | self.peer_connect_send_version(services) | ||||
return create_conn | return create_conn | ||||
def peer_accept_connection(self, *args, services=NODE_NETWORK, **kwargs): | def peer_accept_connection(self, *args, services=NODE_NETWORK, **kwargs): | ||||
create_conn = super().peer_accept_connection(*args, **kwargs) | create_conn = super().peer_accept_connection(*args, **kwargs) | ||||
self.peer_connect_send_version(services) | self.peer_connect_send_version(services) | ||||
return create_conn | return create_conn | ||||
# Message receiving methods | # Message receiving methods | ||||
def on_message(self, message): | def on_message(self, message): | ||||
"""Receive message and dispatch message to appropriate callback. | """Receive message and dispatch message to appropriate callback. | ||||
We keep a count of how many of each message type has been received | We keep a count of how many of each message type has been received | ||||
and the most recent message of each type.""" | and the most recent message of each type.""" | ||||
with p2p_lock: | with p2p_lock: | ||||
try: | try: | ||||
msgtype = message.msgtype.decode('ascii') | msgtype = message.msgtype.decode("ascii") | ||||
self.message_count[msgtype] += 1 | self.message_count[msgtype] += 1 | ||||
self.last_message[msgtype] = message | self.last_message[msgtype] = message | ||||
getattr(self, f"on_{msgtype}")(message) | getattr(self, f"on_{msgtype}")(message) | ||||
except Exception: | except Exception: | ||||
print(f"ERROR delivering {repr(message)} ({sys.exc_info()[0]})") | print(f"ERROR delivering {repr(message)} ({sys.exc_info()[0]})") | ||||
raise | raise | ||||
# Callback methods. Can be overridden by subclasses in individual test | # Callback methods. Can be overridden by subclasses in individual test | ||||
# cases to provide custom message handling behaviour. | # cases to provide custom message handling behaviour. | ||||
def on_open(self): | def on_open(self): | ||||
pass | pass | ||||
def on_close(self): | def on_close(self): | ||||
pass | pass | ||||
def on_addr(self, message): pass | def on_addr(self, message): | ||||
pass | |||||
def on_addrv2(self, message): pass | def on_addrv2(self, message): | ||||
pass | |||||
def on_avapoll(self, message): pass | def on_avapoll(self, message): | ||||
pass | |||||
def on_avaproof(self, message): pass | def on_avaproof(self, message): | ||||
pass | |||||
def on_avaproofs(self, message): pass | def on_avaproofs(self, message): | ||||
pass | |||||
def on_avaproofsreq(self, message): pass | def on_avaproofsreq(self, message): | ||||
pass | |||||
def on_avaresponse(self, message): pass | def on_avaresponse(self, message): | ||||
pass | |||||
def on_avahello(self, message): pass | def on_avahello(self, message): | ||||
pass | |||||
def on_block(self, message): pass | def on_block(self, message): | ||||
pass | |||||
def on_blocktxn(self, message): pass | def on_blocktxn(self, message): | ||||
pass | |||||
def on_cfcheckpt(self, message): pass | def on_cfcheckpt(self, message): | ||||
pass | |||||
def on_cfheaders(self, message): pass | def on_cfheaders(self, message): | ||||
pass | |||||
def on_cfilter(self, message): pass | def on_cfilter(self, message): | ||||
pass | |||||
def on_cmpctblock(self, message): pass | def on_cmpctblock(self, message): | ||||
pass | |||||
def on_feefilter(self, message): pass | def on_feefilter(self, message): | ||||
pass | |||||
def on_filteradd(self, message): pass | def on_filteradd(self, message): | ||||
pass | |||||
def on_filterclear(self, message): pass | def on_filterclear(self, message): | ||||
pass | |||||
def on_filterload(self, message): pass | def on_filterload(self, message): | ||||
pass | |||||
def on_getaddr(self, message): pass | def on_getaddr(self, message): | ||||
pass | |||||
def on_getavaaddr(self, message): pass | def on_getavaaddr(self, message): | ||||
pass | |||||
def on_getavaproofs(self, message): pass | def on_getavaproofs(self, message): | ||||
pass | |||||
def on_getblocks(self, message): pass | def on_getblocks(self, message): | ||||
pass | |||||
def on_getblocktxn(self, message): pass | def on_getblocktxn(self, message): | ||||
pass | |||||
def on_getdata(self, message): pass | def on_getdata(self, message): | ||||
pass | |||||
def on_getheaders(self, message): pass | def on_getheaders(self, message): | ||||
pass | |||||
def on_headers(self, message): pass | def on_headers(self, message): | ||||
pass | |||||
def on_mempool(self, message): pass | def on_mempool(self, message): | ||||
pass | |||||
def on_merkleblock(self, message): pass | def on_merkleblock(self, message): | ||||
pass | |||||
def on_notfound(self, message): pass | def on_notfound(self, message): | ||||
pass | |||||
def on_pong(self, message): pass | def on_pong(self, message): | ||||
pass | |||||
def on_sendaddrv2(self, message): pass | def on_sendaddrv2(self, message): | ||||
pass | |||||
def on_sendcmpct(self, message): pass | def on_sendcmpct(self, message): | ||||
pass | |||||
def on_sendheaders(self, message): pass | def on_sendheaders(self, message): | ||||
pass | |||||
def on_tx(self, message): pass | def on_tx(self, message): | ||||
pass | |||||
def on_inv(self, message): | def on_inv(self, message): | ||||
want = msg_getdata() | want = msg_getdata() | ||||
for i in message.inv: | for i in message.inv: | ||||
if i.type != 0: | if i.type != 0: | ||||
want.inv.append(i) | want.inv.append(i) | ||||
if len(want.inv): | if len(want.inv): | ||||
self.send_message(want) | self.send_message(want) | ||||
def on_ping(self, message): | def on_ping(self, message): | ||||
self.send_message(msg_pong(message.nonce)) | self.send_message(msg_pong(message.nonce)) | ||||
def on_verack(self, message): | def on_verack(self, message): | ||||
pass | pass | ||||
def on_version(self, message): | def on_version(self, message): | ||||
assert message.nVersion >= MIN_P2P_VERSION_SUPPORTED, \ | assert message.nVersion >= MIN_P2P_VERSION_SUPPORTED, ( | ||||
f"Version {message.nVersion} received. Test framework only supports " \ | f"Version {message.nVersion} received. Test framework only supports " | ||||
f"versions greater than {MIN_P2P_VERSION_SUPPORTED}" | f"versions greater than {MIN_P2P_VERSION_SUPPORTED}" | ||||
) | |||||
self.send_message(msg_verack()) | self.send_message(msg_verack()) | ||||
if self.support_addrv2: | if self.support_addrv2: | ||||
self.send_message(msg_sendaddrv2()) | self.send_message(msg_sendaddrv2()) | ||||
self.nServices = message.nServices | self.nServices = message.nServices | ||||
self.send_message(msg_getaddr()) | self.send_message(msg_getaddr()) | ||||
# Connection helper methods | # Connection helper methods | ||||
def wait_until(self, test_function_in, *, timeout=60, | def wait_until(self, test_function_in, *, timeout=60, check_connected=True): | ||||
check_connected=True): | |||||
def test_function(): | def test_function(): | ||||
if check_connected: | if check_connected: | ||||
assert self.is_connected | assert self.is_connected | ||||
return test_function_in() | return test_function_in() | ||||
wait_until_helper(test_function, timeout=timeout, lock=p2p_lock, | wait_until_helper( | ||||
timeout_factor=self.timeout_factor) | test_function, | ||||
timeout=timeout, | |||||
lock=p2p_lock, | |||||
timeout_factor=self.timeout_factor, | |||||
) | |||||
def wait_for_connect(self, timeout=60): | def wait_for_connect(self, timeout=60): | ||||
def test_function(): return self.is_connected | def test_function(): | ||||
return self.is_connected | |||||
self.wait_until(test_function, timeout=timeout, check_connected=False) | self.wait_until(test_function, timeout=timeout, check_connected=False) | ||||
def wait_for_disconnect(self, timeout=60): | def wait_for_disconnect(self, timeout=60): | ||||
def test_function(): return not self.is_connected | def test_function(): | ||||
return not self.is_connected | |||||
self.wait_until(test_function, timeout=timeout, check_connected=False) | self.wait_until(test_function, timeout=timeout, check_connected=False) | ||||
# Message receiving helper methods | # Message receiving helper methods | ||||
def wait_for_tx(self, txid, timeout=60): | def wait_for_tx(self, txid, timeout=60): | ||||
def test_function(): | def test_function(): | ||||
if not self.last_message.get('tx'): | if not self.last_message.get("tx"): | ||||
return False | return False | ||||
return self.last_message['tx'].tx.rehash() == txid | return self.last_message["tx"].tx.rehash() == txid | ||||
self.wait_until(test_function, timeout=timeout) | self.wait_until(test_function, timeout=timeout) | ||||
def wait_for_block(self, blockhash, timeout=60): | def wait_for_block(self, blockhash, timeout=60): | ||||
def test_function(): | def test_function(): | ||||
return self.last_message.get( | return ( | ||||
"block") and self.last_message["block"].block.rehash() == blockhash | self.last_message.get("block") | ||||
and self.last_message["block"].block.rehash() == blockhash | |||||
) | |||||
self.wait_until(test_function, timeout=timeout) | self.wait_until(test_function, timeout=timeout) | ||||
def wait_for_header(self, blockhash, timeout=60): | def wait_for_header(self, blockhash, timeout=60): | ||||
def test_function(): | def test_function(): | ||||
last_headers = self.last_message.get('headers') | last_headers = self.last_message.get("headers") | ||||
if not last_headers: | if not last_headers: | ||||
return False | return False | ||||
return last_headers.headers[0].rehash() == int(blockhash, 16) | return last_headers.headers[0].rehash() == int(blockhash, 16) | ||||
self.wait_until(test_function, timeout=timeout) | self.wait_until(test_function, timeout=timeout) | ||||
def wait_for_merkleblock(self, blockhash, timeout=60): | def wait_for_merkleblock(self, blockhash, timeout=60): | ||||
def test_function(): | def test_function(): | ||||
last_filtered_block = self.last_message.get('merkleblock') | last_filtered_block = self.last_message.get("merkleblock") | ||||
if not last_filtered_block: | if not last_filtered_block: | ||||
return False | return False | ||||
return last_filtered_block.merkleblock.header.rehash() == int(blockhash, 16) | return last_filtered_block.merkleblock.header.rehash() == int(blockhash, 16) | ||||
self.wait_until(test_function, timeout=timeout) | self.wait_until(test_function, timeout=timeout) | ||||
def wait_for_getdata(self, hash_list, timeout=60): | def wait_for_getdata(self, hash_list, timeout=60): | ||||
"""Waits for a getdata message. | """Waits for a getdata message. | ||||
The object hashes in the inventory vector must match the provided hash_list.""" | The object hashes in the inventory vector must match the provided hash_list.""" | ||||
def test_function(): | def test_function(): | ||||
last_data = self.last_message.get("getdata") | last_data = self.last_message.get("getdata") | ||||
if not last_data: | if not last_data: | ||||
return False | return False | ||||
return [x.hash for x in last_data.inv] == hash_list | return [x.hash for x in last_data.inv] == hash_list | ||||
self.wait_until(test_function, timeout=timeout) | self.wait_until(test_function, timeout=timeout) | ||||
def wait_for_getheaders(self, timeout=60): | def wait_for_getheaders(self, timeout=60): | ||||
"""Waits for a getheaders message. | """Waits for a getheaders message. | ||||
Receiving any getheaders message will satisfy the predicate. the last_message["getheaders"] | Receiving any getheaders message will satisfy the predicate. the last_message["getheaders"] | ||||
value must be explicitly cleared before calling this method, or this will return | value must be explicitly cleared before calling this method, or this will return | ||||
immediately with success. TODO: change this method to take a hash value and only | immediately with success. TODO: change this method to take a hash value and only | ||||
return true if the correct block header has been requested.""" | return true if the correct block header has been requested.""" | ||||
def test_function(): | def test_function(): | ||||
return self.last_message.get("getheaders") | return self.last_message.get("getheaders") | ||||
self.wait_until(test_function, timeout=timeout) | self.wait_until(test_function, timeout=timeout) | ||||
def wait_for_inv(self, expected_inv, timeout=60): | def wait_for_inv(self, expected_inv, timeout=60): | ||||
"""Waits for an INV message and checks that the first inv object in the message was as expected.""" | """Waits for an INV message and checks that the first inv object in the message was as expected.""" | ||||
if len(expected_inv) > 1: | if len(expected_inv) > 1: | ||||
raise NotImplementedError( | raise NotImplementedError( | ||||
"wait_for_inv() will only verify the first inv object") | "wait_for_inv() will only verify the first inv object" | ||||
) | |||||
def test_function(): | def test_function(): | ||||
return self.last_message.get("inv") and \ | return ( | ||||
self.last_message["inv"].inv[0].type == expected_inv[0].type and \ | self.last_message.get("inv") | ||||
self.last_message["inv"].inv[0].hash == expected_inv[0].hash | and self.last_message["inv"].inv[0].type == expected_inv[0].type | ||||
and self.last_message["inv"].inv[0].hash == expected_inv[0].hash | |||||
) | |||||
self.wait_until(test_function, timeout=timeout) | self.wait_until(test_function, timeout=timeout) | ||||
def wait_for_verack(self, timeout=60): | def wait_for_verack(self, timeout=60): | ||||
def test_function(): | def test_function(): | ||||
return "verack" in self.last_message | return "verack" in self.last_message | ||||
self.wait_until(test_function, timeout=timeout) | self.wait_until(test_function, timeout=timeout) | ||||
Show All 12 Lines | def sync_send_with_ping(self, timeout=60): | ||||
self.sync_with_ping() | self.sync_with_ping() | ||||
self.sync_with_ping() | self.sync_with_ping() | ||||
def sync_with_ping(self, timeout=60): | def sync_with_ping(self, timeout=60): | ||||
"""Ensure ProcessMessages is called on this connection""" | """Ensure ProcessMessages is called on this connection""" | ||||
self.send_message(msg_ping(nonce=self.ping_counter)) | self.send_message(msg_ping(nonce=self.ping_counter)) | ||||
def test_function(): | def test_function(): | ||||
return self.last_message.get( | return ( | ||||
"pong") and self.last_message["pong"].nonce == self.ping_counter | self.last_message.get("pong") | ||||
and self.last_message["pong"].nonce == self.ping_counter | |||||
) | |||||
self.wait_until(test_function, timeout=timeout) | self.wait_until(test_function, timeout=timeout) | ||||
self.ping_counter += 1 | self.ping_counter += 1 | ||||
# One lock for synchronizing all data access between the networking thread (see | # One lock for synchronizing all data access between the networking thread (see | ||||
# NetworkThread below) and the thread running the test logic. For simplicity, | # NetworkThread below) and the thread running the test logic. For simplicity, | ||||
# P2PConnection acquires this lock whenever delivering a message to a P2PInterface. | # P2PConnection acquires this lock whenever delivering a message to a P2PInterface. | ||||
Show All 16 Lines | def __init__(self): | ||||
NetworkThread.network_event_loop = asyncio.new_event_loop() | NetworkThread.network_event_loop = asyncio.new_event_loop() | ||||
def run(self): | def run(self): | ||||
"""Start the network thread.""" | """Start the network thread.""" | ||||
self.network_event_loop.run_forever() | self.network_event_loop.run_forever() | ||||
def close(self, timeout=10): | def close(self, timeout=10): | ||||
"""Close the connections and network event loop.""" | """Close the connections and network event loop.""" | ||||
self.network_event_loop.call_soon_threadsafe( | self.network_event_loop.call_soon_threadsafe(self.network_event_loop.stop) | ||||
self.network_event_loop.stop) | wait_until_helper( | ||||
wait_until_helper(lambda: not self.network_event_loop.is_running(), | lambda: not self.network_event_loop.is_running(), timeout=timeout | ||||
timeout=timeout) | ) | ||||
self.network_event_loop.close() | self.network_event_loop.close() | ||||
self.join(timeout) | self.join(timeout) | ||||
# Safe to remove event loop. | # Safe to remove event loop. | ||||
NetworkThread.network_event_loop = None | NetworkThread.network_event_loop = None | ||||
@classmethod | @classmethod | ||||
def listen(cls, p2p, callback, port=None, addr=None, idx=1): | def listen(cls, p2p, callback, port=None, addr=None, idx=1): | ||||
""" Ensure a listening server is running on the given port, and run the | """Ensure a listening server is running on the given port, and run the | ||||
protocol specified by `p2p` on the next connection to it. Once ready | protocol specified by `p2p` on the next connection to it. Once ready | ||||
for connections, call `callback`.""" | for connections, call `callback`.""" | ||||
if port is None: | if port is None: | ||||
assert 0 < idx <= MAX_NODES | assert 0 < idx <= MAX_NODES | ||||
port = p2p_port(MAX_NODES - idx) | port = p2p_port(MAX_NODES - idx) | ||||
if addr is None: | if addr is None: | ||||
addr = '127.0.0.1' | addr = "127.0.0.1" | ||||
coroutine = cls.create_listen_server(addr, port, callback, p2p) | coroutine = cls.create_listen_server(addr, port, callback, p2p) | ||||
cls.network_event_loop.call_soon_threadsafe( | cls.network_event_loop.call_soon_threadsafe( | ||||
cls.network_event_loop.create_task, coroutine) | cls.network_event_loop.create_task, coroutine | ||||
) | |||||
@classmethod | @classmethod | ||||
async def create_listen_server(cls, addr, port, callback, proto): | async def create_listen_server(cls, addr, port, callback, proto): | ||||
def peer_protocol(): | def peer_protocol(): | ||||
"""Returns a function that does the protocol handling for a new | """Returns a function that does the protocol handling for a new | ||||
connection. To allow different connections to have different | connection. To allow different connections to have different | ||||
behaviors, the protocol function is first put in the cls.protos | behaviors, the protocol function is first put in the cls.protos | ||||
dict. When the connection is made, the function removes the | dict. When the connection is made, the function removes the | ||||
protocol function from that dict, and returns it so the event loop | protocol function from that dict, and returns it so the event loop | ||||
can start executing it.""" | can start executing it.""" | ||||
response = cls.protos.get((addr, port)) | response = cls.protos.get((addr, port)) | ||||
cls.protos[(addr, port)] = None | cls.protos[(addr, port)] = None | ||||
return response | return response | ||||
if (addr, port) not in cls.listeners: | if (addr, port) not in cls.listeners: | ||||
# When creating a listener on a given (addr, port) we only need to | # When creating a listener on a given (addr, port) we only need to | ||||
# do it once. If we want different behaviors for different | # do it once. If we want different behaviors for different | ||||
# connections, we can accomplish this by providing different | # connections, we can accomplish this by providing different | ||||
# `proto` functions | # `proto` functions | ||||
listener = await cls.network_event_loop.create_server(peer_protocol, addr, port) | listener = await cls.network_event_loop.create_server( | ||||
logger.debug( | peer_protocol, addr, port | ||||
f"Listening server on {addr}:{port} should be started") | ) | ||||
logger.debug(f"Listening server on {addr}:{port} should be started") | |||||
cls.listeners[(addr, port)] = listener | cls.listeners[(addr, port)] = listener | ||||
cls.protos[(addr, port)] = proto | cls.protos[(addr, port)] = proto | ||||
callback(addr, port) | callback(addr, port) | ||||
class P2PDataStore(P2PInterface): | class P2PDataStore(P2PInterface): | ||||
"""A P2P data store class. | """A P2P data store class. | ||||
Keeps a block and transaction store and responds correctly to getdata and getheaders requests.""" | Keeps a block and transaction store and responds correctly to getdata and getheaders requests. | ||||
""" | |||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
# store of blocks. key is block hash, value is a CBlock object | # store of blocks. key is block hash, value is a CBlock object | ||||
self.block_store = {} | self.block_store = {} | ||||
self.last_block_hash = '' | self.last_block_hash = "" | ||||
# store of txs. key is txid, value is a CTransaction object | # store of txs. key is txid, value is a CTransaction object | ||||
self.tx_store = {} | self.tx_store = {} | ||||
self.getdata_requests = [] | self.getdata_requests = [] | ||||
def on_getdata(self, message): | def on_getdata(self, message): | ||||
"""Check for the tx/block in our stores and if found, reply with an inv message.""" | """Check for the tx/block in our stores and if found, reply with an inv message.""" | ||||
for inv in message.inv: | for inv in message.inv: | ||||
self.getdata_requests.append(inv.hash) | self.getdata_requests.append(inv.hash) | ||||
if (inv.type & MSG_TYPE_MASK) == MSG_TX and inv.hash in self.tx_store.keys(): | if ( | ||||
inv.type & MSG_TYPE_MASK | |||||
) == MSG_TX and inv.hash in self.tx_store.keys(): | |||||
self.send_message(msg_tx(self.tx_store[inv.hash])) | self.send_message(msg_tx(self.tx_store[inv.hash])) | ||||
elif (inv.type & MSG_TYPE_MASK) == MSG_BLOCK and inv.hash in self.block_store.keys(): | elif ( | ||||
inv.type & MSG_TYPE_MASK | |||||
) == MSG_BLOCK and inv.hash in self.block_store.keys(): | |||||
self.send_message(msg_block(self.block_store[inv.hash])) | self.send_message(msg_block(self.block_store[inv.hash])) | ||||
else: | else: | ||||
logger.debug( | logger.debug(f"getdata message type {hex(inv.type)} received.") | ||||
f'getdata message type {hex(inv.type)} received.') | |||||
def on_getheaders(self, message): | def on_getheaders(self, message): | ||||
"""Search back through our block store for the locator, and reply with a headers message if found.""" | """Search back through our block store for the locator, and reply with a headers message if found.""" | ||||
locator, hash_stop = message.locator, message.hashstop | locator, hash_stop = message.locator, message.hashstop | ||||
# Assume that the most recent block added is the tip | # Assume that the most recent block added is the tip | ||||
if not self.block_store: | if not self.block_store: | ||||
return | return | ||||
headers_list = [self.block_store[self.last_block_hash]] | headers_list = [self.block_store[self.last_block_hash]] | ||||
while headers_list[-1].sha256 not in locator.vHave: | while headers_list[-1].sha256 not in locator.vHave: | ||||
# Walk back through the block store, adding headers to headers_list | # Walk back through the block store, adding headers to headers_list | ||||
# as we go. | # as we go. | ||||
prev_block_hash = headers_list[-1].hashPrevBlock | prev_block_hash = headers_list[-1].hashPrevBlock | ||||
if prev_block_hash in self.block_store: | if prev_block_hash in self.block_store: | ||||
prev_block_header = CBlockHeader( | prev_block_header = CBlockHeader(self.block_store[prev_block_hash]) | ||||
self.block_store[prev_block_hash]) | |||||
headers_list.append(prev_block_header) | headers_list.append(prev_block_header) | ||||
if prev_block_header.sha256 == hash_stop: | if prev_block_header.sha256 == hash_stop: | ||||
# if this is the hashstop header, stop here | # if this is the hashstop header, stop here | ||||
break | break | ||||
else: | else: | ||||
logger.debug( | logger.debug( | ||||
f'block hash {hex(prev_block_hash)} not found in block store') | f"block hash {hex(prev_block_hash)} not found in block store" | ||||
) | |||||
break | break | ||||
# Truncate the list if there are too many headers | # Truncate the list if there are too many headers | ||||
headers_list = headers_list[:-MAX_HEADERS_RESULTS - 1:-1] | headers_list = headers_list[: -MAX_HEADERS_RESULTS - 1 : -1] | ||||
response = msg_headers(headers_list) | response = msg_headers(headers_list) | ||||
if response is not None: | if response is not None: | ||||
self.send_message(response) | self.send_message(response) | ||||
def send_blocks_and_test(self, blocks, node, *, success=True, force_send=False, | def send_blocks_and_test( | ||||
reject_reason=None, expect_disconnect=False, timeout=60): | self, | ||||
blocks, | |||||
node, | |||||
*, | |||||
success=True, | |||||
force_send=False, | |||||
reject_reason=None, | |||||
expect_disconnect=False, | |||||
timeout=60, | |||||
): | |||||
"""Send blocks to test node and test whether the tip advances. | """Send blocks to test node and test whether the tip advances. | ||||
- add all blocks to our block_store | - add all blocks to our block_store | ||||
- send a headers message for the final block | - send a headers message for the final block | ||||
- the on_getheaders handler will ensure that any getheaders are responded to | - the on_getheaders handler will ensure that any getheaders are responded to | ||||
- if force_send is False: wait for getdata for each of the blocks. The on_getdata handler will | - if force_send is False: wait for getdata for each of the blocks. The on_getdata handler will | ||||
ensure that any getdata messages are responded to. Otherwise send the full block unsolicited. | ensure that any getdata messages are responded to. Otherwise send the full block unsolicited. | ||||
- if success is True: assert that the node's tip advances to the most recent block | - if success is True: assert that the node's tip advances to the most recent block | ||||
- if success is False: assert that the node's tip doesn't advance | - if success is False: assert that the node's tip doesn't advance | ||||
- if reject_reason is set: assert that the correct reject message is logged""" | - if reject_reason is set: assert that the correct reject message is logged""" | ||||
with p2p_lock: | with p2p_lock: | ||||
for block in blocks: | for block in blocks: | ||||
self.block_store[block.sha256] = block | self.block_store[block.sha256] = block | ||||
self.last_block_hash = block.sha256 | self.last_block_hash = block.sha256 | ||||
def test(): | def test(): | ||||
if force_send: | if force_send: | ||||
for b in blocks: | for b in blocks: | ||||
self.send_message(msg_block(block=b)) | self.send_message(msg_block(block=b)) | ||||
else: | else: | ||||
self.send_message( | self.send_message( | ||||
msg_headers([CBlockHeader(block) for block in blocks])) | msg_headers([CBlockHeader(block) for block in blocks]) | ||||
) | |||||
self.wait_until( | self.wait_until( | ||||
lambda: blocks[-1].sha256 in self.getdata_requests, | lambda: blocks[-1].sha256 in self.getdata_requests, | ||||
timeout=timeout, | timeout=timeout, | ||||
check_connected=success, | check_connected=success, | ||||
) | ) | ||||
if expect_disconnect: | if expect_disconnect: | ||||
self.wait_for_disconnect(timeout=timeout) | self.wait_for_disconnect(timeout=timeout) | ||||
else: | else: | ||||
self.sync_with_ping(timeout=timeout) | self.sync_with_ping(timeout=timeout) | ||||
if success: | if success: | ||||
self.wait_until(lambda: node.getbestblockhash() == | self.wait_until( | ||||
blocks[-1].hash, timeout=timeout) | lambda: node.getbestblockhash() == blocks[-1].hash, timeout=timeout | ||||
) | |||||
else: | else: | ||||
assert node.getbestblockhash() != blocks[-1].hash | assert node.getbestblockhash() != blocks[-1].hash | ||||
if reject_reason: | if reject_reason: | ||||
with node.assert_debug_log(expected_msgs=[reject_reason]): | with node.assert_debug_log(expected_msgs=[reject_reason]): | ||||
test() | test() | ||||
else: | else: | ||||
test() | test() | ||||
def send_txs_and_test(self, txs, node, *, success=True, | def send_txs_and_test( | ||||
expect_disconnect=False, reject_reason=None): | self, txs, node, *, success=True, expect_disconnect=False, reject_reason=None | ||||
): | |||||
"""Send txs to test node and test whether they're accepted to the mempool. | """Send txs to test node and test whether they're accepted to the mempool. | ||||
- add all txs to our tx_store | - add all txs to our tx_store | ||||
- send tx messages for all txs | - send tx messages for all txs | ||||
- if success is True/False: assert that the txs are/are not accepted to the mempool | - if success is True/False: assert that the txs are/are not accepted to the mempool | ||||
- if expect_disconnect is True: Skip the sync with ping | - if expect_disconnect is True: Skip the sync with ping | ||||
- if reject_reason is set: assert that the correct reject message is logged.""" | - if reject_reason is set: assert that the correct reject message is logged.""" | ||||
with p2p_lock: | with p2p_lock: | ||||
for tx in txs: | for tx in txs: | ||||
self.tx_store[tx.sha256] = tx | self.tx_store[tx.sha256] = tx | ||||
def test(): | def test(): | ||||
for tx in txs: | for tx in txs: | ||||
self.send_message(msg_tx(tx)) | self.send_message(msg_tx(tx)) | ||||
▲ Show 20 Lines • Show All 42 Lines • ▼ Show 20 Lines | class P2PTxInvStore(P2PInterface): | ||||
def wait_for_broadcast(self, txns, timeout=60): | def wait_for_broadcast(self, txns, timeout=60): | ||||
"""Waits for the txns (list of txids) to complete initial broadcast. | """Waits for the txns (list of txids) to complete initial broadcast. | ||||
The mempool should mark unbroadcast=False for these transactions. | The mempool should mark unbroadcast=False for these transactions. | ||||
""" | """ | ||||
# Wait until invs have been received (and getdatas sent) for each txid. | # Wait until invs have been received (and getdatas sent) for each txid. | ||||
self.wait_until( | self.wait_until( | ||||
lambda: set(self.tx_invs_received.keys()) == {int(tx, 16) for tx in txns}, | lambda: set(self.tx_invs_received.keys()) == {int(tx, 16) for tx in txns}, | ||||
timeout=timeout) | timeout=timeout, | ||||
) | |||||
# Flush messages and wait for the getdatas to be processed | # Flush messages and wait for the getdatas to be processed | ||||
self.sync_with_ping() | self.sync_with_ping() |