#!/usr/bin/env python3
# Copyright (c) 2019 The Bitcoin developers
# Distributed under the MIT software license, see the accompanying
# file COPYING or http://www.opensource.org/licenses/mit-license.php.
"""Test requesting invalid blocks behaves safely."""

from test_framework.messages import (
    CInv,
    msg_getdata,
    msg_getheaders,
)
from test_framework.mininode import mininode_lock, P2PInterface
from test_framework.test_framework import BitcoinTestFramework
from test_framework.util import assert_equal

# Inventory types placed in a getdata CInv: full block and compact block.
MSG_BLOCK = 2
MSG_CMPCT_BLOCK = 4


# This test is modelled off of p2p_fingerprint. Please apply similar
# improvements made to that test here as well.
class GetInvalidBlockTest(BitcoinTestFramework):
    def set_test_params(self):
        self.num_nodes = 1

    def send_block_request(self, block_hash, node_p2p, is_compact_block):
        """Send a getdata request for a single block.

        block_hash is the block hash as an integer. When is_compact_block is
        true the request uses MSG_CMPCT_BLOCK, otherwise MSG_BLOCK.
        """
        msg = msg_getdata()
        inv_type = MSG_CMPCT_BLOCK if is_compact_block else MSG_BLOCK
        msg.inv.append(CInv(inv_type, block_hash))
        node_p2p.send_message(msg)

    def send_header_request(self, block_hash, node_p2p):
        """Send a getheaders request whose hashstop is block_hash.

        Only hashstop is set; the locator is left as msg_getheaders builds it.
        """
        msg = msg_getheaders()
        msg.hashstop = block_hash
        node_p2p.send_message(msg)

    def last_block_equals(self, expected_hash, node_p2p):
        """Return True if the last "block" message has expected_hash."""
        # last_message is written by the network thread; read it under lock.
        with mininode_lock:
            block_msg = node_p2p.last_message.get("block")
            return (block_msg is not None and
                    block_msg.block.rehash() == expected_hash)

    def last_compact_block_equals(self, expected_hash, node_p2p):
        """Return True if the last "cmpctblock" message has expected_hash."""
        with mininode_lock:
            compact_block_msg = node_p2p.last_message.get("cmpctblock")
            return bool(
                compact_block_msg is not None and
                compact_block_msg.header_and_shortids and
                compact_block_msg.header_and_shortids.header.rehash() ==
                expected_hash)

    def last_header_equals(self, expected_hash, node_p2p):
        """Return True if the first header of the last "headers" message
        has expected_hash."""
        with mininode_lock:
            headers_msg = node_p2p.last_message.get("headers")
            return bool(
                headers_msg is not None and
                headers_msg.headers and
                headers_msg.headers[0].rehash() == expected_hash)

    def _clear_messages(self, node_p2p):
        """Forget any block/cmpctblock/headers messages received so far."""
        with mininode_lock:
            node_p2p.last_message.pop("block", None)
            node_p2p.last_message.pop("cmpctblock", None)
            node_p2p.last_message.pop("headers", None)

    def run_test(self):
        node = self.nodes[0]
        node_p2p = node.add_p2p_connection(P2PInterface())
        chaintip = node.getbestblockhash()

        # Mine some blocks and invalidate them
        blocks = node.generate(nblocks=3)
        node.invalidateblock(blocks[0])
        assert_equal(chaintip, node.getbestblockhash())

        # Requests for the invalidated block and its descendants should fail.
        # Not doing so is a potential DoS vector.
        #
        # Each request is followed by a ping/pong round trip rather than a
        # fixed sleep: the node handles a peer's messages in order, so once
        # the pong arrives any reply to the preceding request has arrived as
        # well. Messages are cleared before each request so every assertion
        # only sees the reply to that request.
        for block_hex in blocks:
            block_hash = int(block_hex, 16)

            self._clear_messages(node_p2p)
            self.send_block_request(block_hash, node_p2p, False)
            node_p2p.sync_with_ping()
            assert not self.last_block_equals(block_hash, node_p2p)

            self._clear_messages(node_p2p)
            self.send_block_request(block_hash, node_p2p, True)
            node_p2p.sync_with_ping()
            assert not self.last_compact_block_equals(block_hash, node_p2p)
            # A compact-block request may be answered with a full block.
            assert not self.last_block_equals(block_hash, node_p2p)

            self._clear_messages(node_p2p)
            self.send_header_request(block_hash, node_p2p)
            node_p2p.sync_with_ping()
            assert not self.last_header_equals(block_hash, node_p2p)


if __name__ == '__main__':
    GetInvalidBlockTest().main()