diff --git a/src/protocol.cpp b/src/protocol.cpp
index 118a34bcc..b06bba4b7 100644
--- a/src/protocol.cpp
+++ b/src/protocol.cpp
@@ -1,290 +1,291 @@
// Copyright (c) 2009-2010 Satoshi Nakamoto
// Copyright (c) 2009-2016 The Bitcoin Core developers
// Distributed under the MIT software license, see the accompanying
// file COPYING or http://www.opensource.org/licenses/mit-license.php.

#include
#include
#include
#include
#include

#ifndef WIN32
#include
#endif

#include

static std::atomic<bool> g_initial_block_download_completed(false);

namespace NetMsgType {
const char *VERSION = "version";
const char *VERACK = "verack";
const char *ADDR = "addr";
const char *ADDRV2 = "addrv2";
const char *SENDADDRV2 = "sendaddrv2";
const char *INV = "inv";
const char *GETDATA = "getdata";
const char *MERKLEBLOCK = "merkleblock";
const char *GETBLOCKS = "getblocks";
const char *GETHEADERS = "getheaders";
const char *TX = "tx";
const char *HEADERS = "headers";
const char *BLOCK = "block";
const char *GETADDR = "getaddr";
const char *MEMPOOL = "mempool";
const char *PING = "ping";
const char *PONG = "pong";
const char *NOTFOUND = "notfound";
const char *FILTERLOAD = "filterload";
const char *FILTERADD = "filteradd";
const char *FILTERCLEAR = "filterclear";
const char *SENDHEADERS = "sendheaders";
const char *FEEFILTER = "feefilter";
const char *SENDCMPCT = "sendcmpct";
const char *CMPCTBLOCK = "cmpctblock";
const char *GETBLOCKTXN = "getblocktxn";
const char *BLOCKTXN = "blocktxn";
const char *GETCFILTERS = "getcfilters";
const char *CFILTER = "cfilter";
const char *GETCFHEADERS = "getcfheaders";
const char *CFHEADERS = "cfheaders";
const char *GETCFCHECKPT = "getcfcheckpt";
const char *CFCHECKPT = "cfcheckpt";
const char *AVAHELLO = "avahello";
const char *AVAPOLL = "avapoll";
const char *AVARESPONSE = "avaresponse";
+const char *AVAPROOF = "avaproof";

bool IsBlockLike(const std::string &strCommand) {
    return strCommand == NetMsgType::BLOCK ||
           strCommand == NetMsgType::CMPCTBLOCK ||
           strCommand == NetMsgType::BLOCKTXN;
}
}; // namespace NetMsgType
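For cross-checking from the test side, here is a minimal Python mirror of `IsBlockLike` (the helper name is hypothetical; only the three command strings come from the function above):

```python
# Hypothetical helper mirroring NetMsgType::IsBlockLike() for test tooling.
BLOCK_LIKE_COMMANDS = {b"block", b"cmpctblock", b"blocktxn"}


def is_block_like(command: bytes) -> bool:
    """Return True if this command may carry the content of a block."""
    return command in BLOCK_LIKE_COMMANDS


assert is_block_like(b"cmpctblock")
assert not is_block_like(b"avaproof")
```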
/**
 * All known message types. Keep this in the same order as the list of
 * messages above and in protocol.h.
 */
static const std::string allNetMessageTypes[] = {
    NetMsgType::VERSION,     NetMsgType::VERACK,
    NetMsgType::ADDR,        NetMsgType::ADDRV2,
    NetMsgType::SENDADDRV2,  NetMsgType::INV,
    NetMsgType::GETDATA,     NetMsgType::MERKLEBLOCK,
    NetMsgType::GETBLOCKS,   NetMsgType::GETHEADERS,
    NetMsgType::TX,          NetMsgType::HEADERS,
    NetMsgType::BLOCK,       NetMsgType::GETADDR,
    NetMsgType::MEMPOOL,     NetMsgType::PING,
    NetMsgType::PONG,        NetMsgType::NOTFOUND,
    NetMsgType::FILTERLOAD,  NetMsgType::FILTERADD,
    NetMsgType::FILTERCLEAR, NetMsgType::SENDHEADERS,
    NetMsgType::FEEFILTER,   NetMsgType::SENDCMPCT,
    NetMsgType::CMPCTBLOCK,  NetMsgType::GETBLOCKTXN,
    NetMsgType::BLOCKTXN,    NetMsgType::GETCFILTERS,
    NetMsgType::CFILTER,     NetMsgType::GETCFHEADERS,
    NetMsgType::CFHEADERS,   NetMsgType::GETCFCHECKPT,
    NetMsgType::CFCHECKPT,
};
static const std::vector<std::string>
    allNetMessageTypesVec(allNetMessageTypes,
                          allNetMessageTypes + ARRAYLEN(allNetMessageTypes));

CMessageHeader::CMessageHeader(const MessageMagic &pchMessageStartIn) {
    memcpy(std::begin(pchMessageStart), std::begin(pchMessageStartIn),
           MESSAGE_START_SIZE);
    memset(pchCommand.data(), 0, sizeof(pchCommand));
    nMessageSize = -1;
    memset(pchChecksum, 0, CHECKSUM_SIZE);
}

CMessageHeader::CMessageHeader(const MessageMagic &pchMessageStartIn,
                               const char *pszCommand,
                               unsigned int nMessageSizeIn) {
    memcpy(std::begin(pchMessageStart), std::begin(pchMessageStartIn),
           MESSAGE_START_SIZE);

    // Copy the command name, zero-padding to COMMAND_SIZE bytes
    size_t i = 0;
    for (; i < pchCommand.size() && pszCommand[i] != 0; ++i) {
        pchCommand[i] = pszCommand[i];
    }
    // Assert that the command name passed in is not longer than COMMAND_SIZE
    assert(pszCommand[i] == 0);
    for (; i < pchCommand.size(); ++i) {
        pchCommand[i] = 0;
    }

    nMessageSize = nMessageSizeIn;
    memset(pchChecksum, 0, CHECKSUM_SIZE);
}

std::string CMessageHeader::GetCommand() const {
    // return std::string(pchCommand.begin(), pchCommand.end());
    return std::string(pchCommand.data(),
                       pchCommand.data() +
                           strnlen(pchCommand.data(), COMMAND_SIZE));
}

static bool
CheckHeaderMagicAndCommand(const CMessageHeader &header,
                           const CMessageHeader::MessageMagic &magic) {
    // Check start string
    if (memcmp(std::begin(header.pchMessageStart), std::begin(magic),
               CMessageHeader::MESSAGE_START_SIZE) != 0) {
        return false;
    }

    // Check the command string for errors
    for (const char *p1 = header.pchCommand.data();
         p1 < header.pchCommand.data() + CMessageHeader::COMMAND_SIZE; p1++) {
        if (*p1 == 0) {
            // Must be all zeros after the first zero
            for (; p1 < header.pchCommand.data() + CMessageHeader::COMMAND_SIZE;
                 p1++) {
                if (*p1 != 0) {
                    return false;
                }
            }
        } else if (*p1 < ' ' || *p1 > 0x7E) {
            return false;
        }
    }

    return true;
}

bool CMessageHeader::IsValid(const Config &config) const {
    // Check start string
    if (!CheckHeaderMagicAndCommand(*this,
                                    config.GetChainParams().NetMagic())) {
        return false;
    }

    // Message size
    if (IsOversized(config)) {
        LogPrintf("CMessageHeader::IsValid(): (%s, %u bytes) is oversized\n",
                  GetCommand(), nMessageSize);
        return false;
    }

    return true;
}
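The framing rules enforced here (24-byte header: start string, zero-padded command, length, double-SHA256 checksum) can be sketched in a few lines of Python, assuming the Bitcoin Cash mainnet start bytes `e3e1f3e8`; this is an illustration, not the framework's serializer:

```python
import struct
from hashlib import sha256


def build_header(magic: bytes, command: bytes, payload: bytes) -> bytes:
    """4-byte start string, 12-byte zero-padded command,
    4-byte LE payload length, 4-byte double-SHA256 checksum."""
    assert len(magic) == 4 and len(command) <= 12
    checksum = sha256(sha256(payload).digest()).digest()[:4]
    return (magic + command.ljust(12, b"\x00") +
            struct.pack("<I", len(payload)) + checksum)


header = build_header(bytes.fromhex("e3e1f3e8"), b"avaproof", b"")
assert len(header) == 24  # HEADER_SIZE
```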
/**
 * This is a transition method in order to stay compatible with older code
 * that does not use the config. It assumes messages will not get too large.
 * This cannot be used for any piece of code that will download blocks, as
 * blocks may be bigger than the permitted size. Ideally, code that uses this
 * function should be migrated toward using the config.
 */
bool CMessageHeader::IsValidWithoutConfig(const MessageMagic &magic) const {
    // Check start string
    if (!CheckHeaderMagicAndCommand(*this, magic)) {
        return false;
    }

    // Message size
    if (nMessageSize > MAX_PROTOCOL_MESSAGE_LENGTH) {
        LogPrintf("CMessageHeader::IsValidWithoutConfig(): (%s, %u bytes) is "
                  "oversized\n",
                  GetCommand(), nMessageSize);
        return false;
    }

    return true;
}

bool CMessageHeader::IsOversized(const Config &config) const {
    // If the message doesn't contain block content, check it against
    // MAX_PROTOCOL_MESSAGE_LENGTH.
    if (nMessageSize > MAX_PROTOCOL_MESSAGE_LENGTH &&
        !NetMsgType::IsBlockLike(GetCommand())) {
        return true;
    }

    // Scale the maximum accepted size with the block size.
    if (nMessageSize > 2 * config.GetMaxBlockSize()) {
        return true;
    }

    return false;
}

ServiceFlags GetDesirableServiceFlags(ServiceFlags services) {
    if ((services & NODE_NETWORK_LIMITED) &&
        g_initial_block_download_completed) {
        return ServiceFlags(NODE_NETWORK_LIMITED);
    }
    return ServiceFlags(NODE_NETWORK);
}

void SetServiceFlagsIBDCache(bool state) {
    g_initial_block_download_completed = state;
}

std::string CInv::GetCommand() const {
    std::string cmd;
    switch (GetKind()) {
        case MSG_TX:
            return cmd.append(NetMsgType::TX);
        case MSG_BLOCK:
            return cmd.append(NetMsgType::BLOCK);
        case MSG_FILTERED_BLOCK:
            return cmd.append(NetMsgType::MERKLEBLOCK);
        case MSG_CMPCT_BLOCK:
            return cmd.append(NetMsgType::CMPCTBLOCK);
        default:
            throw std::out_of_range(
                strprintf("CInv::GetCommand(): type=%d unknown type", type));
    }
}

std::string CInv::ToString() const {
    try {
        return strprintf("%s %s", GetCommand(), hash.ToString());
    } catch (const std::out_of_range &) {
        return strprintf("0x%08x %s", type, hash.ToString());
    }
}

const std::vector<std::string> &getAllNetMessageTypes() {
    return allNetMessageTypesVec;
}

/**
 * Convert a service flag (NODE_*) to a human readable string.
 * It supports unknown service flags which will be returned as "UNKNOWN[...]".
 * @param[in] bit the service flag is calculated as (1 << bit)
 */
static std::string serviceFlagToStr(const size_t bit) {
    const uint64_t service_flag = 1ULL << bit;
    switch (ServiceFlags(service_flag)) {
        case NODE_NONE:
            // impossible
            abort();
        case NODE_NETWORK:
            return "NETWORK";
        case NODE_GETUTXO:
            return "GETUTXO";
        case NODE_BLOOM:
            return "BLOOM";
        case NODE_NETWORK_LIMITED:
            return "NETWORK_LIMITED";
        case NODE_COMPACT_FILTERS:
            return "COMPACT_FILTERS";
        case NODE_AVALANCHE:
            return "AVALANCHE";
        default:
            std::ostringstream stream;
            stream.imbue(std::locale::classic());
            stream << "UNKNOWN[";
            stream << "2^" << bit;
            stream << "]";
            return stream.str();
    }
}

std::vector<std::string> serviceFlagsToStr(const uint64_t flags) {
    std::vector<std::string> str_flags;

    for (size_t i = 0; i < sizeof(flags) * 8; ++i) {
        if (flags & (1ULL << i)) {
            str_flags.emplace_back(serviceFlagToStr(i));
        }
    }

    return str_flags;
}
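The bit-by-bit decoding done by `serviceFlagToStr`/`serviceFlagsToStr` has a compact Python equivalent; a sketch, with bit positions taken from the ServiceFlags enum in protocol.h (the helper names are illustrative):

```python
SERVICE_FLAG_NAMES = {0: "NETWORK", 1: "GETUTXO", 2: "BLOOM",
                      6: "COMPACT_FILTERS", 10: "NETWORK_LIMITED",
                      24: "AVALANCHE"}


def service_flags_to_str(flags: int) -> list:
    """Decode a 64-bit service mask into names, UNKNOWN[2^bit] otherwise."""
    return [SERVICE_FLAG_NAMES.get(bit, "UNKNOWN[2^{}]".format(bit))
            for bit in range(64) if flags & (1 << bit)]


assert service_flags_to_str((1 << 0) | (1 << 24)) == ["NETWORK", "AVALANCHE"]
assert service_flags_to_str(1 << 7) == ["UNKNOWN[2^7]"]
```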
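Likewise, the two-tier cap applied by `CMessageHeader::IsOversized` above can be summarized as a small sketch, where `max_block_size` stands in for `config.GetMaxBlockSize()` and the 32 MB figure in the assertions is only an example configuration:

```python
MAX_PROTOCOL_MESSAGE_LENGTH = 2 * 1024 * 1024


def is_oversized(command: bytes, size: int, max_block_size: int) -> bool:
    # Non-block messages are capped at MAX_PROTOCOL_MESSAGE_LENGTH ...
    if (size > MAX_PROTOCOL_MESSAGE_LENGTH and
            command not in (b"block", b"cmpctblock", b"blocktxn")):
        return True
    # ... while block-like messages scale with the configured block size.
    return size > 2 * max_block_size


assert not is_oversized(b"block", 8_000_000, 32_000_000)  # big blocks pass
assert is_oversized(b"tx", 3_000_000, 32_000_000)         # big txs do not
```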
diff --git a/src/protocol.h b/src/protocol.h
index a332de7e0..e6f8d0eff 100644
--- a/src/protocol.h
+++ b/src/protocol.h
@@ -1,536 +1,543 @@
// Copyright (c) 2009-2010 Satoshi Nakamoto
// Copyright (c) 2009-2019 The Bitcoin Core developers
// Distributed under the MIT software license, see the accompanying
// file COPYING or http://www.opensource.org/licenses/mit-license.php.

#ifndef __cplusplus
#error This header can only be compiled as C++.
#endif

#ifndef BITCOIN_PROTOCOL_H
#define BITCOIN_PROTOCOL_H

#include
#include
#include
#include
#include
#include
#include

class Config;

/**
 * Maximum length of incoming protocol messages (Currently 2MB).
 * NB: Messages propagating block content are not subject to this limit.
 */
static const unsigned int MAX_PROTOCOL_MESSAGE_LENGTH = 2 * 1024 * 1024;

/**
 * Message header.
 * (4) message start.
 * (12) command.
 * (4) size.
 * (4) checksum.
 */
class CMessageHeader {
public:
    static constexpr size_t MESSAGE_START_SIZE = 4;
    static constexpr size_t COMMAND_SIZE = 12;
    static constexpr size_t MESSAGE_SIZE_SIZE = 4;
    static constexpr size_t CHECKSUM_SIZE = 4;
    static constexpr size_t MESSAGE_SIZE_OFFSET =
        MESSAGE_START_SIZE + COMMAND_SIZE;
    static constexpr size_t CHECKSUM_OFFSET =
        MESSAGE_SIZE_OFFSET + MESSAGE_SIZE_SIZE;
    static constexpr size_t HEADER_SIZE =
        MESSAGE_START_SIZE + COMMAND_SIZE + MESSAGE_SIZE_SIZE + CHECKSUM_SIZE;
    typedef std::array<uint8_t, MESSAGE_START_SIZE> MessageMagic;

    explicit CMessageHeader(const MessageMagic &pchMessageStartIn);

    /**
     * Construct a P2P message header from message-start characters, a command
     * and the size of the message.
     * @note Passing in a `pszCommand` longer than COMMAND_SIZE will result in
     * a run-time assertion error.
     */
    CMessageHeader(const MessageMagic &pchMessageStartIn,
                   const char *pszCommand, unsigned int nMessageSizeIn);

    std::string GetCommand() const;
    bool IsValid(const Config &config) const;
    bool IsValidWithoutConfig(const MessageMagic &magic) const;
    bool IsOversized(const Config &config) const;

    SERIALIZE_METHODS(CMessageHeader, obj) {
        READWRITE(obj.pchMessageStart, obj.pchCommand, obj.nMessageSize,
                  obj.pchChecksum);
    }

    MessageMagic pchMessageStart;
    std::array<char, COMMAND_SIZE> pchCommand;
    uint32_t nMessageSize;
    uint8_t pchChecksum[CHECKSUM_SIZE];
};

/**
 * Bitcoin protocol message types. When adding new message types, don't forget
 * to update allNetMessageTypes in protocol.cpp.
 */
namespace NetMsgType {

/**
 * The version message provides information about the transmitting node to
 * the receiving node at the beginning of a connection.
 * @see https://bitcoin.org/en/developer-reference#version
 */
extern const char *VERSION;
/**
 * The verack message acknowledges a previously-received version message,
 * informing the connecting node that it can begin to send other messages.
 * @see https://bitcoin.org/en/developer-reference#verack
 */
extern const char *VERACK;
/**
 * The addr (IP address) message relays connection information for peers on
 * the network.
 * @see https://bitcoin.org/en/developer-reference#addr
 */
extern const char *ADDR;
/**
 * The addrv2 message relays connection information for peers on the network
 * just like the addr message, but is extended to allow gossiping of longer
 * node addresses (see BIP155).
 */
extern const char *ADDRV2;
/**
 * The sendaddrv2 message signals support for receiving ADDRV2 messages
 * (BIP155). It also implies that its sender can encode as ADDRV2 and would
 * send ADDRV2 instead of ADDR to a peer that has signaled ADDRV2 support by
 * sending SENDADDRV2.
 */
extern const char *SENDADDRV2;
/**
 * The inv message (inventory message) transmits one or more inventories of
 * objects known to the transmitting peer.
 * @see https://bitcoin.org/en/developer-reference#inv
 */
extern const char *INV;
/**
 * The getdata message requests one or more data objects from another node.
 * @see https://bitcoin.org/en/developer-reference#getdata
 */
extern const char *GETDATA;
/**
 * The merkleblock message is a reply to a getdata message which requested a
 * block using the inventory type MSG_MERKLEBLOCK.
 * @since protocol version 70001 as described by BIP37.
* @see https://bitcoin.org/en/developer-reference#merkleblock */ extern const char *MERKLEBLOCK; /** * The getblocks message requests an inv message that provides block header * hashes starting from a particular point in the block chain. * @see https://bitcoin.org/en/developer-reference#getblocks */ extern const char *GETBLOCKS; /** * The getheaders message requests a headers message that provides block * headers starting from a particular point in the block chain. * @since protocol version 31800. * @see https://bitcoin.org/en/developer-reference#getheaders */ extern const char *GETHEADERS; /** * The tx message transmits a single transaction. * @see https://bitcoin.org/en/developer-reference#tx */ extern const char *TX; /** * The headers message sends one or more block headers to a node which * previously requested certain headers with a getheaders message. * @since protocol version 31800. * @see https://bitcoin.org/en/developer-reference#headers */ extern const char *HEADERS; /** * The block message transmits a single serialized block. * @see https://bitcoin.org/en/developer-reference#block */ extern const char *BLOCK; /** * The getaddr message requests an addr message from the receiving node, * preferably one with lots of IP addresses of other receiving nodes. * @see https://bitcoin.org/en/developer-reference#getaddr */ extern const char *GETADDR; /** * The mempool message requests the TXIDs of transactions that the receiving * node has verified as valid but which have not yet appeared in a block. * @since protocol version 60002. * @see https://bitcoin.org/en/developer-reference#mempool */ extern const char *MEMPOOL; /** * The ping message is sent periodically to help confirm that the receiving * peer is still connected. * @see https://bitcoin.org/en/developer-reference#ping */ extern const char *PING; /** * The pong message replies to a ping message, proving to the pinging node that * the ponging node is still alive. * @since protocol version 60001 as described by BIP31. * @see https://bitcoin.org/en/developer-reference#pong */ extern const char *PONG; /** * The notfound message is a reply to a getdata message which requested an * object the receiving node does not have available for relay. * @since protocol version 70001. * @see https://bitcoin.org/en/developer-reference#notfound */ extern const char *NOTFOUND; /** * The filterload message tells the receiving peer to filter all relayed * transactions and requested merkle blocks through the provided filter. * @since protocol version 70001 as described by BIP37. * Only available with service bit NODE_BLOOM since protocol version * 70011 as described by BIP111. * @see https://bitcoin.org/en/developer-reference#filterload */ extern const char *FILTERLOAD; /** * The filteradd message tells the receiving peer to add a single element to a * previously-set bloom filter, such as a new public key. * @since protocol version 70001 as described by BIP37. * Only available with service bit NODE_BLOOM since protocol version * 70011 as described by BIP111. * @see https://bitcoin.org/en/developer-reference#filteradd */ extern const char *FILTERADD; /** * The filterclear message tells the receiving peer to remove a previously-set * bloom filter. * @since protocol version 70001 as described by BIP37. * Only available with service bit NODE_BLOOM since protocol version * 70011 as described by BIP111. 
* @see https://bitcoin.org/en/developer-reference#filterclear */ extern const char *FILTERCLEAR; /** * Indicates that a node prefers to receive new block announcements via a * "headers" message rather than an "inv". * @since protocol version 70012 as described by BIP130. * @see https://bitcoin.org/en/developer-reference#sendheaders */ extern const char *SENDHEADERS; /** * The feefilter message tells the receiving peer not to inv us any txs * which do not meet the specified min fee rate. * @since protocol version 70013 as described by BIP133 */ extern const char *FEEFILTER; /** * Contains a 1-byte bool and 8-byte LE version number. * Indicates that a node is willing to provide blocks via "cmpctblock" messages. * May indicate that a node prefers to receive new block announcements via a * "cmpctblock" message rather than an "inv", depending on message contents. * @since protocol version 70014 as described by BIP 152 */ extern const char *SENDCMPCT; /** * Contains a CBlockHeaderAndShortTxIDs object - providing a header and * list of "short txids". * @since protocol version 70014 as described by BIP 152 */ extern const char *CMPCTBLOCK; /** * Contains a BlockTransactionsRequest * Peer should respond with "blocktxn" message. * @since protocol version 70014 as described by BIP 152 */ extern const char *GETBLOCKTXN; /** * Contains a BlockTransactions. * Sent in response to a "getblocktxn" message. * @since protocol version 70014 as described by BIP 152 */ extern const char *BLOCKTXN; /** * getcfilters requests compact filters for a range of blocks. * Only available with service bit NODE_COMPACT_FILTERS as described by * BIP 157 & 158. */ extern const char *GETCFILTERS; /** * cfilter is a response to a getcfilters request containing a single compact * filter. */ extern const char *CFILTER; /** * getcfheaders requests a compact filter header and the filter hashes for a * range of blocks, which can then be used to reconstruct the filter headers * for those blocks. * Only available with service bit NODE_COMPACT_FILTERS as described by * BIP 157 & 158. */ extern const char *GETCFHEADERS; /** * cfheaders is a response to a getcfheaders request containing a filter header * and a vector of filter hashes for each subsequent block in the requested * range. */ extern const char *CFHEADERS; /** * getcfcheckpt requests evenly spaced compact filter headers, enabling * parallelized download and validation of the headers between them. * Only available with service bit NODE_COMPACT_FILTERS as described by * BIP 157 & 158. */ extern const char *GETCFCHECKPT; /** * cfcheckpt is a response to a getcfcheckpt request containing a vector of * evenly spaced filter headers for blocks on the requested chain. */ extern const char *CFCHECKPT; /** * Contains a delegation and a signature. */ extern const char *AVAHELLO; /** * Contains an avalanche::Poll. * Peer should respond with "avaresponse" message. */ extern const char *AVAPOLL; /** * Contains an avalanche::Response. * Sent in response to a "avapoll" message. */ extern const char *AVARESPONSE; +/** + * Contains an avalanche::Proof. + * Sent in response to a "getdata" message with inventory type + * MSG_AVA_PROOF. + */ +extern const char *AVAPROOF; /** * Indicate if the message is used to transmit the content of a block. * These messages can be significantly larger than usual messages and therefore * may need to be processed differently. 
*/ bool IsBlockLike(const std::string &strCommand); }; // namespace NetMsgType /** Get a vector of all valid message types (see above) */ const std::vector &getAllNetMessageTypes(); /** * nServices flags. */ enum ServiceFlags : uint64_t { // NOTE: When adding here, be sure to update serviceFlagToStr too // Nothing NODE_NONE = 0, // NODE_NETWORK means that the node is capable of serving the complete block // chain. It is currently set by all Bitcoin ABC non pruned nodes, and is // unset by SPV clients or other light clients. NODE_NETWORK = (1 << 0), // NODE_GETUTXO means the node is capable of responding to the getutxo // protocol request. Bitcoin ABC does not support this but a patch set // called Bitcoin XT does. See BIP 64 for details on how this is // implemented. NODE_GETUTXO = (1 << 1), // NODE_BLOOM means the node is capable and willing to handle bloom-filtered // connections. Bitcoin ABC nodes used to support this by default, without // advertising this bit, but no longer do as of protocol version 70011 (= // NO_BLOOM_VERSION) NODE_BLOOM = (1 << 2), // Bit 4 was NODE_XTHIN, removed in v0.22.12 // Bit 5 was NODE_BITCOIN_CASH, removed in v0.22.8 // NODE_COMPACT_FILTERS means the node will service basic block filter // requests. // See BIP157 and BIP158 for details on how this is implemented. NODE_COMPACT_FILTERS = (1 << 6), // NODE_NETWORK_LIMITED means the same as NODE_NETWORK with the limitation // of only serving the last 288 (2 day) blocks // See BIP159 for details on how this is implemented. NODE_NETWORK_LIMITED = (1 << 10), // The last non experimental service bit, helper for looping over the flags NODE_LAST_NON_EXPERIMENTAL_SERVICE_BIT = (1 << 23), // Bits 24-31 are reserved for temporary experiments. Just pick a bit that // isn't getting used, or one not being used much, and notify the // bitcoin-development mailing list. Remember that service bits are just // unauthenticated advertisements, so your code must be robust against // collisions and other cases where nodes may be advertising a service they // do not actually support. Other service bits should be allocated via the // BIP process. // NODE_AVALANCHE means the node supports Bitcoin Cash's avalanche // preconsensus mechanism. NODE_AVALANCHE = (1 << 24), }; /** * Convert service flags (a bitmask of NODE_*) to human readable strings. * It supports unknown service flags which will be returned as "UNKNOWN[...]". * @param[in] flags multiple NODE_* bitwise-OR-ed together */ std::vector serviceFlagsToStr(const uint64_t flags); /** * Gets the set of service flags which are "desirable" for a given peer. * * These are the flags which are required for a peer to support for them * to be "interesting" to us, ie for us to wish to use one of our few * outbound connection slots for or for us to wish to prioritize keeping * their connection around. * * Relevant service flags may be peer- and state-specific in that the * version of the peer may determine which flags are required (eg in the * case of NODE_NETWORK_LIMITED where we seek out NODE_NETWORK peers * unless they set NODE_NETWORK_LIMITED and we are out of IBD, in which * case NODE_NETWORK_LIMITED suffices). * * Thus, generally, avoid calling with peerServices == NODE_NONE, unless * state-specific flags must absolutely be avoided. 
When called with
 * peerServices == NODE_NONE, the returned desirable service flags are
 * guaranteed to not change dependent on state - ie they are suitable for
 * use when describing peers which we know to be desirable, but for which
 * we do not have a confirmed set of service flags.
 *
 * If the NODE_NONE return value is changed, contrib/seeds/makeseeds.py
 * should be updated appropriately to filter for the same nodes.
 */
ServiceFlags GetDesirableServiceFlags(ServiceFlags services);

/**
 * Set the current IBD status in order to figure out the desirable service
 * flags
 */
void SetServiceFlagsIBDCache(bool status);

/**
 * A shortcut for (services & GetDesirableServiceFlags(services)) ==
 * GetDesirableServiceFlags(services), ie determines whether the given set of
 * service flags are sufficient for a peer to be "relevant".
 */
static inline bool HasAllDesirableServiceFlags(ServiceFlags services) {
    return !(GetDesirableServiceFlags(services) & (~services));
}

/**
 * Checks if a peer with the given service flags may be capable of having a
 * robust address-storage DB.
 */
static inline bool MayHaveUsefulAddressDB(ServiceFlags services) {
    return (services & NODE_NETWORK) || (services & NODE_NETWORK_LIMITED);
}

/**
 * A CService with information about it as a peer.
 */
class CAddress : public CService {
    static constexpr uint32_t TIME_INIT{100000000};

public:
    CAddress() : CService{} {};
    CAddress(CService ipIn, ServiceFlags nServicesIn)
        : CService{ipIn}, nServices{nServicesIn} {};
    CAddress(CService ipIn, ServiceFlags nServicesIn, uint32_t nTimeIn)
        : CService{ipIn}, nTime{nTimeIn}, nServices{nServicesIn} {};

    void Init();

    SERIALIZE_METHODS(CAddress, obj) {
        SER_READ(obj, obj.nTime = TIME_INIT);
        int nVersion = s.GetVersion();
        if (s.GetType() & SER_DISK) {
            READWRITE(nVersion);
        }
        if ((s.GetType() & SER_DISK) ||
            (nVersion != INIT_PROTO_VERSION && !(s.GetType() & SER_GETHASH))) {
            // The only time we serialize a CAddress object without nTime is
            // in the initial VERSION messages which contain two CAddress
            // records. At that point, the serialization version is
            // INIT_PROTO_VERSION. After the version handshake, serialization
            // version is >= MIN_PEER_PROTO_VERSION and all ADDR messages are
            // serialized with nTime.
            READWRITE(obj.nTime);
        }
        if (nVersion & ADDRV2_FORMAT) {
            uint64_t services_tmp;
            SER_WRITE(obj, services_tmp = obj.nServices);
            READWRITE(Using>(services_tmp));
            SER_READ(obj,
                     obj.nServices = static_cast<ServiceFlags>(services_tmp));
        } else {
            READWRITE(Using>(obj.nServices));
        }
        READWRITEAS(CService, obj);
    }

    // disk and network only
    uint32_t nTime{TIME_INIT};
    ServiceFlags nServices{NODE_NONE};
};

/** getdata message type flags */
const uint32_t MSG_TYPE_MASK = 0xffffffff >> 2;

/**
 * getdata / inv message types.
 * These numbers are defined by the protocol. When adding a new value, be sure
 * to mention it in the respective BIP.
 */
enum GetDataMsg {
    UNDEFINED = 0,
    MSG_TX = 1,
    MSG_BLOCK = 2,
    // The following can only occur in getdata. Invs always use TX or BLOCK.
    //! Defined in BIP37
    MSG_FILTERED_BLOCK = 3,
    //! Defined in BIP152
    MSG_CMPCT_BLOCK = 4,
+    MSG_AVA_PROOF = 0x1f000001,
};

/**
 * Inv(ventory) message data.
 * Intended as non-ambiguous identifier of objects (eg. transactions, blocks)
 * held by peers.
 */
class CInv {
public:
    uint32_t type;
    uint256 hash;

    CInv() : type(0), hash() {}
    CInv(uint32_t typeIn, const uint256 &hashIn) : type(typeIn), hash(hashIn) {}

    SERIALIZE_METHODS(CInv, obj) { READWRITE(obj.type, obj.hash); }

    friend bool operator<(const CInv &a, const CInv &b) {
        return a.type < b.type || (a.type == b.type && a.hash < b.hash);
    }

    std::string GetCommand() const;
    std::string ToString() const;

    uint32_t GetKind() const { return type & MSG_TYPE_MASK; }

    bool IsTx() const {
        auto k = GetKind();
        return k == MSG_TX;
    }
    bool IsSomeBlock() const {
        auto k = GetKind();
        return k == MSG_BLOCK || k == MSG_FILTERED_BLOCK ||
               k == MSG_CMPCT_BLOCK;
    }
};

#endif // BITCOIN_PROTOCOL_H
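With MSG_AVA_PROOF defined, a test peer can request a proof the same way it requests transactions or blocks, via inv/getdata. A sketch using the framework types from messages.py below; `proofid` is a placeholder, and how the node answers (an "avaproof" or a "notfound") is outside the scope of this diff:

```python
from test_framework.messages import CInv, MSG_AVA_PROOF, msg_getdata


def request_proof(peer, proofid):
    """Ask the node for an avalanche proof by inventory id."""
    want = msg_getdata()
    want.inv.append(CInv(MSG_AVA_PROOF, proofid))
    peer.send_message(want)
```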
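The nTime rule in the CAddress serializer above is easy to see on the wire: an addrv1 record is 30 bytes in an addr message but only 26 inside a VERSION message, where the timestamp is omitted. A self-contained sketch (not the framework's CAddress, which appears below):

```python
import socket
import struct


def ser_addrv1(ip, services, port, time=None):
    r = b"" if time is None else struct.pack("<I", time)  # omitted in VERSION
    r += struct.pack("<Q", services)
    r += b"\x00" * 10 + b"\xff" * 2 + socket.inet_aton(ip)  # IPv4-mapped IPv6
    return r + struct.pack(">H", port)


assert len(ser_addrv1("127.0.0.1", 1, 8333)) == 26          # VERSION-embedded
assert len(ser_addrv1("127.0.0.1", 1, 8333, time=0)) == 30  # addr message
```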
diff --git a/test/functional/test_framework/messages.py b/test/functional/test_framework/messages.py
index 20c8f622a..28f732b2f 100755
--- a/test/functional/test_framework/messages.py
+++ b/test/functional/test_framework/messages.py
@@ -1,1876 +1,1940 @@
#!/usr/bin/env python3
# Copyright (c) 2010 ArtForz -- public domain half-a-node
# Copyright (c) 2012 Jeff Garzik
# Copyright (c) 2010-2019 The Bitcoin Core developers
# Distributed under the MIT software license, see the accompanying
# file COPYING or http://www.opensource.org/licenses/mit-license.php.
"""Bitcoin test framework primitive and message structures

CBlock, CTransaction, CBlockHeader, CTxIn, CTxOut, etc....:
    data structures that should map to corresponding structures in
    bitcoin/primitives

msg_block, msg_tx, msg_headers, etc.:
    data structures that represent network messages

ser_*, deser_*: functions that handle serialization/deserialization.

Classes use __slots__ to ensure extraneous attributes aren't accidentally
added by tests, compromising their intended effect.
"""
from codecs import encode
import copy
import hashlib
from io import BytesIO
import random
import socket
import struct
import time
+import unittest

from typing import List

from test_framework.siphash import siphash256
from test_framework.util import hex_str_to_bytes, assert_equal

MIN_VERSION_SUPPORTED = 60001  # past bip-31 for ping/pong
MY_VERSION = 70014
MY_SUBVERSION = b"/python-mininode-tester:0.0.3/"
# from version 70001 onwards, fRelay should be appended to version
# messages (BIP37)
MY_RELAY = 1

MAX_LOCATOR_SZ = 101
MAX_BLOCK_BASE_SIZE = 1000000
MAX_BLOOM_FILTER_SIZE = 36000
MAX_BLOOM_HASH_FUNCS = 50

# 1 BCH in satoshis
COIN = 100000000
MAX_MONEY = 21000000 * COIN

NODE_NETWORK = (1 << 0)
NODE_GETUTXO = (1 << 1)
NODE_BLOOM = (1 << 2)
# NODE_WITNESS = (1 << 3)
# NODE_XTHIN = (1 << 4) # removed in v0.22.12
NODE_COMPACT_FILTERS = (1 << 6)
NODE_NETWORK_LIMITED = (1 << 10)
NODE_AVALANCHE = (1 << 24)

MSG_TX = 1
MSG_BLOCK = 2
MSG_FILTERED_BLOCK = 3
MSG_CMPCT_BLOCK = 4
+MSG_AVA_PROOF = 0x1f000001
MSG_TYPE_MASK = 0xffffffff >> 2

FILTER_TYPE_BASIC = 0

# Serialization/deserialization tools


def sha256(s):
    return hashlib.new('sha256', s).digest()


def hash256(s):
    return sha256(sha256(s))


def ser_compact_size(size):
    r = b""
    if size < 253:
        r = struct.pack("B", size)
    elif size < 0x10000:
        r = struct.pack("<BH", 253, size)
    elif size < 0x100000000:
        r = struct.pack("<BI", 254, size)
    else:
        r = struct.pack("<BQ", 255, size)
    return r


def deser_compact_size(f):
    nit = struct.unpack("<B", f.read(1))[0]
    if nit == 253:
        nit = struct.unpack("<H", f.read(2))[0]
    elif nit == 254:
        nit = struct.unpack("<I", f.read(4))[0]
    elif nit == 255:
        nit = struct.unpack("<Q", f.read(8))[0]
    return nit


def deser_string(f):
    nit = deser_compact_size(f)
    return f.read(nit)


def ser_string(s):
    return ser_compact_size(len(s)) + s


def deser_uint256(f):
    r = 0
    for i in range(8):
        t = struct.unpack("<I", f.read(4))[0]
        r += t << (i * 32)
    return r


def ser_uint256(u):
    rs = b""
    for i in range(8):
        rs += struct.pack("<I", u & 0xFFFFFFFF)
        u >>= 32
    return rs


def uint256_from_str(s):
    r = 0
    t = struct.unpack("<IIIIIIII", s[:32])
    for i in range(8):
        r += t[i] << (i * 32)
    return r


def uint256_from_compact(c):
    nbytes = (c >> 24) & 0xFF
    v = (c & 0xFFFFFF) << (8 * (nbytes - 3))
    return v


# deser_function_name: Allow for an alternate deserialization function on the
# entries in the vector.
def deser_vector(f, c, deser_function_name=None): nit = deser_compact_size(f) r = [] for i in range(nit): t = c() if deser_function_name: getattr(t, deser_function_name)(f) else: t.deserialize(f) r.append(t) return r # ser_function_name: Allow for an alternate serialization function on the # entries in the vector. def ser_vector(v, ser_function_name=None): r = ser_compact_size(len(v)) for i in v: if ser_function_name: r += getattr(i, ser_function_name)() else: r += i.serialize() return r def deser_uint256_vector(f): nit = deser_compact_size(f) r = [] for i in range(nit): t = deser_uint256(f) r.append(t) return r def ser_uint256_vector(v): r = ser_compact_size(len(v)) for i in v: r += ser_uint256(i) return r def deser_string_vector(f): nit = deser_compact_size(f) r = [] for i in range(nit): t = deser_string(f) r.append(t) return r def ser_string_vector(v): r = ser_compact_size(len(v)) for sv in v: r += ser_string(sv) return r def FromHex(obj, hex_string): """Deserialize from a hex string representation (eg from RPC)""" obj.deserialize(BytesIO(hex_str_to_bytes(hex_string))) return obj def ToHex(obj): """Convert a binary-serializable object to hex (eg for submission via RPC)""" return obj.serialize().hex() # Objects that map to bitcoind objects, which can be serialized/deserialized class CAddress: __slots__ = ("net", "ip", "nServices", "port", "time") # see https://github.com/bitcoin/bips/blob/master/bip-0155.mediawiki NET_IPV4 = 1 ADDRV2_NET_NAME = { NET_IPV4: "IPv4" } ADDRV2_ADDRESS_LENGTH = { NET_IPV4: 4 } def __init__(self): self.time = 0 self.nServices = 1 self.net = self.NET_IPV4 self.ip = "0.0.0.0" self.port = 0 def deserialize(self, f, *, with_time=True): """Deserialize from addrv1 format (pre-BIP155)""" if with_time: # VERSION messages serialize CAddress objects without time self.time = struct.unpack("H", f.read(2))[0] def serialize(self, *, with_time=True): """Serialize in addrv1 format (pre-BIP155)""" assert self.net == self.NET_IPV4 r = b"" if with_time: # VERSION messages serialize CAddress objects without time r += struct.pack("H", self.port) return r def deserialize_v2(self, f): """Deserialize from addrv2 format (BIP155)""" self.time = struct.unpack("H", f.read(2))[0] def serialize_v2(self): """Serialize in addrv2 format (BIP155)""" assert self.net == self.NET_IPV4 r = b"" r += struct.pack("H", self.port) return r def __repr__(self): return ("CAddress(nServices=%i net=%s addr=%s port=%i)" % (self.nServices, self.ADDRV2_NET_NAME[self.net], self.ip, self.port)) class CInv: __slots__ = ("hash", "type") typemap = { 0: "Error", MSG_TX: "TX", MSG_BLOCK: "Block", MSG_FILTERED_BLOCK: "filtered Block", - MSG_CMPCT_BLOCK: "CompactBlock" + MSG_CMPCT_BLOCK: "CompactBlock", + MSG_AVA_PROOF: "avalanche proof", } def __init__(self, t=0, h=0): self.type = t self.hash = h def deserialize(self, f): self.type = struct.unpack(" 21000000 * COIN: return False return True def __repr__(self): return "CTransaction(nVersion={} vin={} vout={} nLockTime={})".format( self.nVersion, repr(self.vin), repr(self.vout), self.nLockTime) class CBlockHeader: __slots__ = ("hash", "hashMerkleRoot", "hashPrevBlock", "nBits", "nNonce", "nTime", "nVersion", "sha256") def __init__(self, header=None): if header is None: self.set_null() else: self.nVersion = header.nVersion self.hashPrevBlock = header.hashPrevBlock self.hashMerkleRoot = header.hashMerkleRoot self.nTime = header.nTime self.nBits = header.nBits self.nNonce = header.nNonce self.sha256 = header.sha256 self.hash = header.hash self.calc_sha256() def 
set_null(self): self.nVersion = 1 self.hashPrevBlock = 0 self.hashMerkleRoot = 0 self.nTime = 0 self.nBits = 0 self.nNonce = 0 self.sha256 = None self.hash = None def deserialize(self, f): self.nVersion = struct.unpack(" 1: newhashes = [] for i in range(0, len(hashes), 2): i2 = min(i + 1, len(hashes) - 1) newhashes.append(hash256(hashes[i] + hashes[i2])) hashes = newhashes return uint256_from_str(hashes[0]) def calc_merkle_root(self): hashes = [] for tx in self.vtx: tx.calc_sha256() hashes.append(ser_uint256(tx.sha256)) return self.get_merkle_root(hashes) def is_valid(self): self.calc_sha256() target = uint256_from_compact(self.nBits) if self.sha256 > target: return False for tx in self.vtx: if not tx.is_valid(): return False if self.calc_merkle_root() != self.hashMerkleRoot: return False return True def solve(self): self.rehash() target = uint256_from_compact(self.nBits) while self.sha256 > target: self.nNonce += 1 self.rehash() def __repr__(self): return "CBlock(nVersion={} hashPrevBlock={:064x} hashMerkleRoot={:064x} nTime={} nBits={:08x} nNonce={:08x} vtx={})".format( self.nVersion, self.hashPrevBlock, self.hashMerkleRoot, self.nTime, self.nBits, self.nNonce, repr(self.vtx)) class PrefilledTransaction: __slots__ = ("index", "tx") def __init__(self, index=0, tx=None): self.index = index self.tx = tx def deserialize(self, f): self.index = deser_compact_size(f) self.tx = CTransaction() self.tx.deserialize(f) def serialize(self): r = b"" r += ser_compact_size(self.index) r += self.tx.serialize() return r def __repr__(self): return "PrefilledTransaction(index={}, tx={})".format( self.index, repr(self.tx)) # This is what we send on the wire, in a cmpctblock message. class P2PHeaderAndShortIDs: __slots__ = ("header", "nonce", "prefilled_txn", "prefilled_txn_length", "shortids", "shortids_length") def __init__(self): self.header = CBlockHeader() self.nonce = 0 self.shortids_length = 0 self.shortids = [] self.prefilled_txn_length = 0 self.prefilled_txn = [] def deserialize(self, f): self.header.deserialize(f) self.nonce = struct.unpack(" class msg_headers: __slots__ = ("headers",) msgtype = b"headers" def __init__(self, headers=None): self.headers = headers if headers is not None else [] def deserialize(self, f): # comment in bitcoind indicates these should be deserialized as blocks blocks = deser_vector(f, CBlock) for x in blocks: self.headers.append(CBlockHeader(x)) def serialize(self): blocks = [CBlock(x) for x in self.headers] return ser_vector(blocks) def __repr__(self): return "msg_headers(headers={})".format(repr(self.headers)) class msg_merkleblock: __slots__ = ("merkleblock",) msgtype = b"merkleblock" def __init__(self, merkleblock=None): if merkleblock is None: self.merkleblock = CMerkleBlock() else: self.merkleblock = merkleblock def deserialize(self, f): self.merkleblock.deserialize(f) def serialize(self): return self.merkleblock.serialize() def __repr__(self): return "msg_merkleblock(merkleblock={})".format(repr(self.merkleblock)) class msg_filterload: __slots__ = ("data", "nHashFuncs", "nTweak", "nFlags") msgtype = b"filterload" def __init__(self, data=b'00', nHashFuncs=0, nTweak=0, nFlags=0): self.data = data self.nHashFuncs = nHashFuncs self.nTweak = nTweak self.nFlags = nFlags def deserialize(self, f): self.data = deser_string(f) self.nHashFuncs = struct.unpack(" 0: self.recvbuf += t while True: msg = self._on_data() if msg is None: break self.on_message(msg) def _on_data(self): """Try to read P2P messages from the recv buffer. This method reads data from the buffer in a loop. 
It deserializes, parses and verifies the P2P header, then passes the P2P payload to the on_message callback for processing.""" try: with mininode_lock: if len(self.recvbuf) < 4: return None if self.recvbuf[:4] != self.magic_bytes: raise ValueError( "magic bytes mismatch: {} != {}".format( repr( self.magic_bytes), repr( self.recvbuf))) if len(self.recvbuf) < 4 + 12 + 4 + 4: return None msgtype = self.recvbuf[4:4 + 12].split(b"\x00", 1)[0] msglen = struct.unpack( " 500: log_message += "... (msg truncated)" logger.debug(log_message) class P2PInterface(P2PConnection): """A high-level P2P interface class for communicating with a Bitcoin Cash node. This class provides high-level callbacks for processing P2P message payloads, as well as convenience methods for interacting with the node over P2P. Individual testcases should subclass this and override the on_* methods if they want to alter message handling behaviour.""" def __init__(self, support_addrv2=False): super().__init__() # Track number of messages of each type received and the most recent # message of each type self.message_count = defaultdict(int) self.last_message = {} # A count of the number of ping messages we've sent to the node self.ping_counter = 1 # The network services received from the peer self.nServices = 0 self.support_addrv2 = support_addrv2 def peer_connect(self, *args, services=NODE_NETWORK, send_version=True, **kwargs): create_conn = super().peer_connect(*args, **kwargs) if send_version: # Send a version msg vt = msg_version() vt.nServices = services vt.addrTo.ip = self.dstaddr vt.addrTo.port = self.dstport vt.addrFrom.ip = "0.0.0.0" vt.addrFrom.port = 0 # Will be sent soon after connection_made self.on_connection_send_msg = vt return create_conn # Message receiving methods def on_message(self, message): """Receive message and dispatch message to appropriate callback. We keep a count of how many of each message type has been received and the most recent message of each type.""" with mininode_lock: try: msgtype = message.msgtype.decode('ascii') self.message_count[msgtype] += 1 self.last_message[msgtype] = message getattr(self, 'on_' + msgtype)(message) except Exception: print("ERROR delivering {} ({})".format( repr(message), sys.exc_info()[0])) raise # Callback methods. Can be overridden by subclasses in individual test # cases to provide custom message handling behaviour. 
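For instance, a hypothetical subclass (illustrative only, not part of this diff) that records every avalanche proof it receives through the new on_avaproof callback:

```python
class ProofCollector(P2PInterface):
    """Example subclass - collects received avaproof messages."""

    def __init__(self):
        super().__init__()
        self.proofs = []

    def on_avaproof(self, message):
        self.proofs.append(message)
```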
def on_open(self): pass def on_close(self): pass def on_addr(self, message): pass def on_addrv2(self, message): pass def on_avapoll(self, message): pass + def on_avaproof(self, message): pass + def on_avaresponse(self, message): pass def on_avahello(self, message): pass def on_block(self, message): pass def on_blocktxn(self, message): pass def on_cfcheckpt(self, message): pass def on_cfheaders(self, message): pass def on_cfilter(self, message): pass def on_cmpctblock(self, message): pass def on_feefilter(self, message): pass def on_filteradd(self, message): pass def on_filterclear(self, message): pass def on_filterload(self, message): pass def on_getaddr(self, message): pass def on_getblocks(self, message): pass def on_getblocktxn(self, message): pass def on_getdata(self, message): pass def on_getheaders(self, message): pass def on_headers(self, message): pass def on_mempool(self, message): pass def on_merkleblock(self, message): pass def on_notfound(self, message): pass def on_pong(self, message): pass def on_sendaddrv2(self, message): pass def on_sendcmpct(self, message): pass def on_sendheaders(self, message): pass def on_tx(self, message): pass def on_inv(self, message): want = msg_getdata() for i in message.inv: if i.type != 0: want.inv.append(i) if len(want.inv): self.send_message(want) def on_ping(self, message): self.send_message(msg_pong(message.nonce)) def on_verack(self, message): pass def on_version(self, message): assert message.nVersion >= MIN_VERSION_SUPPORTED, "Version {} received. Test framework only supports versions greater than {}".format( message.nVersion, MIN_VERSION_SUPPORTED) self.send_message(msg_verack()) if self.support_addrv2: self.send_message(msg_sendaddrv2()) self.nServices = message.nServices # Connection helper methods def wait_until(self, test_function, timeout=60): wait_until(test_function, timeout=timeout, lock=mininode_lock, timeout_factor=self.timeout_factor) def wait_for_disconnect(self, timeout=60): def test_function(): return not self.is_connected self.wait_until(test_function, timeout=timeout) # Message receiving helper methods def wait_for_tx(self, txid, timeout=60): def test_function(): assert self.is_connected if not self.last_message.get('tx'): return False return self.last_message['tx'].tx.rehash() == txid self.wait_until(test_function, timeout=timeout) def wait_for_block(self, blockhash, timeout=60): def test_function(): assert self.is_connected return self.last_message.get( "block") and self.last_message["block"].block.rehash() == blockhash self.wait_until(test_function, timeout=timeout) def wait_for_header(self, blockhash, timeout=60): def test_function(): assert self.is_connected last_headers = self.last_message.get('headers') if not last_headers: return False return last_headers.headers[0].rehash() == int(blockhash, 16) self.wait_until(test_function, timeout=timeout) def wait_for_merkleblock(self, blockhash, timeout=60): def test_function(): assert self.is_connected last_filtered_block = self.last_message.get('merkleblock') if not last_filtered_block: return False return last_filtered_block.merkleblock.header.rehash() == int(blockhash, 16) self.wait_until(test_function, timeout=timeout) def wait_for_getdata(self, hash_list, timeout=60): """Waits for a getdata message. 
The object hashes in the inventory vector must match the provided hash_list.""" def test_function(): assert self.is_connected last_data = self.last_message.get("getdata") if not last_data: return False return [x.hash for x in last_data.inv] == hash_list self.wait_until(test_function, timeout=timeout) def wait_for_getheaders(self, timeout=60): """Waits for a getheaders message. Receiving any getheaders message will satisfy the predicate. the last_message["getheaders"] value must be explicitly cleared before calling this method, or this will return immediately with success. TODO: change this method to take a hash value and only return true if the correct block header has been requested.""" def test_function(): assert self.is_connected return self.last_message.get("getheaders") self.wait_until(test_function, timeout=timeout) def wait_for_inv(self, expected_inv, timeout=60): """Waits for an INV message and checks that the first inv object in the message was as expected.""" if len(expected_inv) > 1: raise NotImplementedError( "wait_for_inv() will only verify the first inv object") def test_function(): assert self.is_connected return self.last_message.get("inv") and \ self.last_message["inv"].inv[0].type == expected_inv[0].type and \ self.last_message["inv"].inv[0].hash == expected_inv[0].hash self.wait_until(test_function, timeout=timeout) def wait_for_verack(self, timeout=60): def test_function(): return self.message_count["verack"] self.wait_until(test_function, timeout=timeout) # Message sending helper functions def send_and_ping(self, message, timeout=60): self.send_message(message) self.sync_with_ping(timeout=timeout) # Sync up with the node def sync_with_ping(self, timeout=60): self.send_message(msg_ping(nonce=self.ping_counter)) def test_function(): assert self.is_connected return self.last_message.get( "pong") and self.last_message["pong"].nonce == self.ping_counter self.wait_until(test_function, timeout=timeout) self.ping_counter += 1 # One lock for synchronizing all data access between the networking thread (see # NetworkThread below) and the thread running the test logic. For simplicity, # P2PConnection acquires this lock whenever delivering a message to a P2PInterface. # This lock should be acquired in the thread running the test logic to synchronize # access to any data shared with the P2PInterface or P2PConnection. mininode_lock = threading.RLock() class NetworkThread(threading.Thread): network_event_loop = None def __init__(self): super().__init__(name="NetworkThread") # There is only one event loop and no more than one thread must be # created assert not self.network_event_loop NetworkThread.network_event_loop = asyncio.new_event_loop() def run(self): """Start the network thread.""" self.network_event_loop.run_forever() def close(self, timeout=10): """Close the connections and network event loop.""" self.network_event_loop.call_soon_threadsafe( self.network_event_loop.stop) wait_until(lambda: not self.network_event_loop.is_running(), timeout=timeout) self.network_event_loop.close() self.join(timeout) # Safe to remove event loop. NetworkThread.network_event_loop = None class P2PDataStore(P2PInterface): """A P2P data store class. Keeps a block and transaction store and responds correctly to getdata and getheaders requests.""" def __init__(self): super().__init__() # store of blocks. key is block hash, value is a CBlock object self.block_store = {} self.last_block_hash = '' # store of txs. 
key is txid, value is a CTransaction object self.tx_store = {} self.getdata_requests = [] def on_getdata(self, message): """Check for the tx/block in our stores and if found, reply with an inv message.""" for inv in message.inv: self.getdata_requests.append(inv.hash) if (inv.type & MSG_TYPE_MASK) == MSG_TX and inv.hash in self.tx_store.keys(): self.send_message(msg_tx(self.tx_store[inv.hash])) elif (inv.type & MSG_TYPE_MASK) == MSG_BLOCK and inv.hash in self.block_store.keys(): self.send_message(msg_block(self.block_store[inv.hash])) else: logger.debug( 'getdata message type {} received.'.format(hex(inv.type))) def on_getheaders(self, message): """Search back through our block store for the locator, and reply with a headers message if found.""" locator, hash_stop = message.locator, message.hashstop # Assume that the most recent block added is the tip if not self.block_store: return headers_list = [self.block_store[self.last_block_hash]] maxheaders = 2000 while headers_list[-1].sha256 not in locator.vHave: # Walk back through the block store, adding headers to headers_list # as we go. prev_block_hash = headers_list[-1].hashPrevBlock if prev_block_hash in self.block_store: prev_block_header = CBlockHeader( self.block_store[prev_block_hash]) headers_list.append(prev_block_header) if prev_block_header.sha256 == hash_stop: # if this is the hashstop header, stop here break else: logger.debug('block hash {} not found in block store'.format( hex(prev_block_hash))) break # Truncate the list if there are too many headers headers_list = headers_list[:-maxheaders - 1:-1] response = msg_headers(headers_list) if response is not None: self.send_message(response) def send_blocks_and_test(self, blocks, node, *, success=True, force_send=False, reject_reason=None, expect_disconnect=False, timeout=60): """Send blocks to test node and test whether the tip advances. - add all blocks to our block_store - send a headers message for the final block - the on_getheaders handler will ensure that any getheaders are responded to - if force_send is False: wait for getdata for each of the blocks. The on_getdata handler will ensure that any getdata messages are responded to. Otherwise send the full block unsolicited. - if success is True: assert that the node's tip advances to the most recent block - if success is False: assert that the node's tip doesn't advance - if reject_reason is set: assert that the correct reject message is logged""" with mininode_lock: for block in blocks: self.block_store[block.sha256] = block self.last_block_hash = block.sha256 def test(): if force_send: for b in blocks: self.send_message(msg_block(block=b)) else: self.send_message( msg_headers([CBlockHeader(block) for block in blocks])) self.wait_until( lambda: blocks[-1].sha256 in self.getdata_requests, timeout=timeout) if expect_disconnect: self.wait_for_disconnect(timeout=timeout) else: self.sync_with_ping(timeout=timeout) if success: self.wait_until(lambda: node.getbestblockhash() == blocks[-1].hash, timeout=timeout) else: assert node.getbestblockhash() != blocks[-1].hash if reject_reason: with node.assert_debug_log(expected_msgs=[reject_reason]): test() else: test() def send_txs_and_test(self, txs, node, *, success=True, expect_disconnect=False, reject_reason=None): """Send txs to test node and test whether they're accepted to the mempool. 
- add all txs to our tx_store - send tx messages for all txs - if success is True/False: assert that the txs are/are not accepted to the mempool - if expect_disconnect is True: Skip the sync with ping - if reject_reason is set: assert that the correct reject message is logged.""" with mininode_lock: for tx in txs: self.tx_store[tx.sha256] = tx def test(): for tx in txs: self.send_message(msg_tx(tx)) if expect_disconnect: self.wait_for_disconnect() else: self.sync_with_ping() raw_mempool = node.getrawmempool() if success: # Check that all txs are now in the mempool for tx in txs: assert tx.hash in raw_mempool, "{} not found in mempool".format( tx.hash) else: # Check that none of the txs are now in the mempool for tx in txs: assert tx.hash not in raw_mempool, "{} tx found in mempool".format( tx.hash) if reject_reason: with node.assert_debug_log(expected_msgs=[reject_reason]): test() else: test() class P2PTxInvStore(P2PInterface): """A P2PInterface which stores a count of how many times each txid has been announced.""" def __init__(self): super().__init__() self.tx_invs_received = defaultdict(int) def on_inv(self, message): # Send getdata in response. super().on_inv(message) # Store how many times invs have been received for each tx. for i in message.inv: if i.type == MSG_TX: # save txid self.tx_invs_received[i.hash] += 1 def get_invs(self): with mininode_lock: return list(self.tx_invs_received.keys()) def wait_for_broadcast(self, txns, timeout=60): """Waits for the txns (list of txids) to complete initial broadcast. The mempool should mark unbroadcast=False for these transactions. """ # Wait until invs have been received (and getdatas sent) for each txid. self.wait_until(lambda: set(self.get_invs()) == set( [int(tx, 16) for tx in txns]), timeout) # Flush messages and wait for the getdatas to be processed self.sync_with_ping() diff --git a/test/functional/test_runner.py b/test/functional/test_runner.py index 744a21db5..7c0072f21 100755 --- a/test/functional/test_runner.py +++ b/test/functional/test_runner.py @@ -1,899 +1,900 @@ #!/usr/bin/env python3 # Copyright (c) 2014-2019 The Bitcoin Core developers # Copyright (c) 2017 The Bitcoin developers # Distributed under the MIT software license, see the accompanying # file COPYING or http://www.opensource.org/licenses/mit-license.php. """Run regression test suite. This module calls down into individual test cases via subprocess. It will forward all unrecognized arguments onto the individual test scripts. For a description of arguments recognized by test scripts, see `test/functional/test_framework/test_framework.py:BitcoinTestFramework.main`. """ import argparse from collections import deque import configparser import datetime import os import time import shutil import sys import subprocess import tempfile import re import logging import xml.etree.ElementTree as ET import json import threading import multiprocessing from queue import Queue, Empty import unittest # Formatting. Default colors to empty strings. 
BOLD, GREEN, RED, GREY = ("", ""), ("", ""), ("", ""), ("", "") try: # Make sure python thinks it can write unicode to its stdout "\u2713".encode("utf_8").decode(sys.stdout.encoding) TICK = "✓ " CROSS = "✖ " CIRCLE = "○ " except UnicodeDecodeError: TICK = "P " CROSS = "x " CIRCLE = "o " if os.name != 'nt' or sys.getwindowsversion() >= (10, 0, 14393): # type: ignore if os.name == 'nt': import ctypes kernel32 = ctypes.windll.kernel32 # type: ignore ENABLE_VIRTUAL_TERMINAL_PROCESSING = 4 STD_OUTPUT_HANDLE = -11 STD_ERROR_HANDLE = -12 # Enable ascii color control to stdout stdout = kernel32.GetStdHandle(STD_OUTPUT_HANDLE) stdout_mode = ctypes.c_int32() kernel32.GetConsoleMode(stdout, ctypes.byref(stdout_mode)) kernel32.SetConsoleMode( stdout, stdout_mode.value | ENABLE_VIRTUAL_TERMINAL_PROCESSING) # Enable ascii color control to stderr stderr = kernel32.GetStdHandle(STD_ERROR_HANDLE) stderr_mode = ctypes.c_int32() kernel32.GetConsoleMode(stderr, ctypes.byref(stderr_mode)) kernel32.SetConsoleMode( stderr, stderr_mode.value | ENABLE_VIRTUAL_TERMINAL_PROCESSING) # primitive formatting on supported # terminal via ANSI escape sequences: BOLD = ('\033[0m', '\033[1m') GREEN = ('\033[0m', '\033[0;32m') RED = ('\033[0m', '\033[0;31m') GREY = ('\033[0m', '\033[1;30m') TEST_EXIT_PASSED = 0 TEST_EXIT_SKIPPED = 77 TEST_FRAMEWORK_MODULES = [ "address", "blocktools", + "messages", "script", ] NON_SCRIPTS = [ # These are python files that live in the functional tests directory, but # are not test scripts. "combine_logs.py", "create_cache.py", "test_runner.py", ] TEST_PARAMS = { # Some test can be run with additional parameters. # When a test is listed here, the it will be run without parameters # as well as with additional parameters listed here. # This: # example "testName" : [["--param1", "--param2"] , ["--param3"]] # will run the test 3 times: # testName # testName --param1 --param2 # testname --param3 "rpc_bind.py": [["--ipv4"], ["--ipv6"], ["--nonloopback"]], "rpc_deriveaddresses.py": [["--usecli"]], "wallet_txn_doublespend.py": [["--mineblock"]], "wallet_txn_clone.py": [["--mineblock"]], "wallet_createwallet.py": [["--usecli"]], "wallet_multiwallet.py": [["--usecli"]], "wallet_watchonly.py": [["--usecli"]], } # Used to limit the number of tests, when list of tests is not provided on command line # When --extended is specified, we run all tests, otherwise # we only run a test if its execution time in seconds does not exceed # EXTENDED_CUTOFF DEFAULT_EXTENDED_CUTOFF = 40 DEFAULT_JOBS = (multiprocessing.cpu_count() // 3) + 1 class TestCase(): """ Data structure to hold and run information necessary to launch a test case. 
""" def __init__(self, test_num, test_case, tests_dir, tmpdir, failfast_event, flags=None): self.tests_dir = tests_dir self.tmpdir = tmpdir self.test_case = test_case self.test_num = test_num self.failfast_event = failfast_event self.flags = flags def run(self, portseed_offset): if self.failfast_event.is_set(): return TestResult(self.test_num, self.test_case, "", "Skipped", 0, "", "") portseed = self.test_num + portseed_offset portseed_arg = ["--portseed={}".format(portseed)] log_stdout = tempfile.SpooledTemporaryFile(max_size=2**16) log_stderr = tempfile.SpooledTemporaryFile(max_size=2**16) test_argv = self.test_case.split() testdir = os.path.join("{}", "{}_{}").format( self.tmpdir, re.sub(".py$", "", test_argv[0]), portseed) tmpdir_arg = ["--tmpdir={}".format(testdir)] start_time = time.time() process = subprocess.Popen([sys.executable, os.path.join(self.tests_dir, test_argv[0])] + test_argv[1:] + self.flags + portseed_arg + tmpdir_arg, universal_newlines=True, stdout=log_stdout, stderr=log_stderr) process.wait() log_stdout.seek(0), log_stderr.seek(0) [stdout, stderr] = [log.read().decode('utf-8') for log in (log_stdout, log_stderr)] log_stdout.close(), log_stderr.close() if process.returncode == TEST_EXIT_PASSED and stderr == "": status = "Passed" elif process.returncode == TEST_EXIT_SKIPPED: status = "Skipped" else: status = "Failed" return TestResult(self.test_num, self.test_case, testdir, status, time.time() - start_time, stdout, stderr) def on_ci(): return os.getenv('TRAVIS') == 'true' or os.getenv( 'TEAMCITY_VERSION') is not None def main(): # Read config generated by configure. config = configparser.ConfigParser() configfile = os.path.join(os.path.abspath( os.path.dirname(__file__)), "..", "config.ini") config.read_file(open(configfile, encoding="utf8")) src_dir = config["environment"]["SRCDIR"] build_dir = config["environment"]["BUILDDIR"] tests_dir = os.path.join(src_dir, 'test', 'functional') # SRCDIR must be set for cdefs.py to find and parse consensus.h os.environ["SRCDIR"] = src_dir # Parse arguments and pass through unrecognised args parser = argparse.ArgumentParser(add_help=False, usage='%(prog)s [test_runner.py options] [script options] [scripts]', description=__doc__, epilog=''' Help text and arguments for individual test script:''', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('--combinedlogslen', '-c', type=int, default=0, metavar='n', help='On failure, print a log (of length n lines) to ' 'the console, combined from the test framework ' 'and all test nodes.') parser.add_argument('--coverage', action='store_true', help='generate a basic coverage report for the RPC interface') parser.add_argument( '--exclude', '-x', help='specify a comma-separated-list of scripts to exclude.') parser.add_argument('--extended', action='store_true', help='run the extended test suite in addition to the basic tests') parser.add_argument('--cutoff', type=int, default=DEFAULT_EXTENDED_CUTOFF, help='set the cutoff runtime for what tests get run') parser.add_argument('--help', '-h', '-?', action='store_true', help='print help text and exit') parser.add_argument('--jobs', '-j', type=int, default=DEFAULT_JOBS, help='how many test scripts to run in parallel.') parser.add_argument('--keepcache', '-k', action='store_true', help='the default behavior is to flush the cache directory on startup. 
def main():
    # Read config generated by configure.
    config = configparser.ConfigParser()
    configfile = os.path.join(os.path.abspath(
        os.path.dirname(__file__)), "..", "config.ini")
    config.read_file(open(configfile, encoding="utf8"))

    src_dir = config["environment"]["SRCDIR"]
    build_dir = config["environment"]["BUILDDIR"]
    tests_dir = os.path.join(src_dir, 'test', 'functional')

    # SRCDIR must be set for cdefs.py to find and parse consensus.h
    os.environ["SRCDIR"] = src_dir

    # Parse arguments and pass through unrecognised args
    parser = argparse.ArgumentParser(add_help=False,
                                     usage='%(prog)s [test_runner.py options] [script options] [scripts]',
                                     description=__doc__,
                                     epilog='''
    Help text and arguments for individual test script:''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--combinedlogslen', '-c', type=int, default=0, metavar='n',
                        help='On failure, print a log (of length n lines) to '
                             'the console, combined from the test framework '
                             'and all test nodes.')
    parser.add_argument('--coverage', action='store_true',
                        help='generate a basic coverage report for the RPC interface')
    parser.add_argument(
        '--exclude', '-x',
        help='specify a comma-separated list of scripts to exclude.')
    parser.add_argument('--extended', action='store_true',
                        help='run the extended test suite in addition to the basic tests')
    parser.add_argument('--cutoff', type=int, default=DEFAULT_EXTENDED_CUTOFF,
                        help='set the cutoff runtime for what tests get run')
    parser.add_argument('--help', '-h', '-?',
                        action='store_true', help='print help text and exit')
    parser.add_argument('--jobs', '-j', type=int, default=DEFAULT_JOBS,
                        help='how many test scripts to run in parallel.')
    parser.add_argument('--keepcache', '-k', action='store_true',
                        help='the default behavior is to flush the cache directory on startup. '
                             '--keepcache retains the cache from the previous testrun.')
    parser.add_argument('--quiet', '-q', action='store_true',
                        help='only print results summary and failure logs')
    parser.add_argument('--tmpdirprefix', '-t',
                        default=os.path.join(build_dir, 'test', 'tmp'),
                        help="Root directory for datadirs")
    parser.add_argument(
        '--failfast',
        action='store_true',
        help='stop execution after the first test failure')
    parser.add_argument('--junitoutput', '-J',
                        help="File that will store JUnit formatted test results. If no absolute path is given it is treated as relative to the temporary directory.")
    parser.add_argument('--testsuitename', '-n', default='Bitcoin ABC functional tests',
                        help="Name of the test suite, as it will appear in the logs and in the JUnit report.")
    args, unknown_args = parser.parse_known_args()

    # args to be passed on always start with two dashes; tests are the
    # remaining unknown args
    tests = [arg for arg in unknown_args if arg[:2] != "--"]
    passon_args = [arg for arg in unknown_args if arg[:2] == "--"]
    passon_args.append("--configfile={}".format(configfile))

    # Set up logging
    logging_level = logging.INFO if args.quiet else logging.DEBUG
    logging.basicConfig(format='%(message)s', level=logging_level)
    logging.info("Starting {}".format(args.testsuitename))

    # Create base test directory
    tmpdir = os.path.join("{}", "test_runner_₿₵_🏃_{:%Y%m%d_%H%M%S}").format(
        args.tmpdirprefix, datetime.datetime.now())

    os.makedirs(tmpdir)

    logging.debug("Temporary test directory at {}".format(tmpdir))

    if args.junitoutput and not os.path.isabs(args.junitoutput):
        args.junitoutput = os.path.join(tmpdir, args.junitoutput)

    enable_bitcoind = config["components"].getboolean("ENABLE_BITCOIND")

    if not enable_bitcoind:
        print("No functional tests to run.")
        print("Rerun ./configure with --with-daemon and then make")
        sys.exit(0)

    # Build list of tests
    all_scripts = get_all_scripts_from_disk(tests_dir, NON_SCRIPTS)

    # Check all tests with parameters actually exist
    for test in TEST_PARAMS:
        if test not in all_scripts:
            print("ERROR: Test with parameter {} does not exist, check it has "
                  "not been renamed or deleted".format(test))
            sys.exit(1)

    if tests:
        # Individual tests have been specified. Run specified tests that exist
        # in the all_scripts list. Accept the name with or without .py
        # extension.
        individual_tests = [
            re.sub(r"\.py$", "", test) + ".py"
            for test in tests
            if not test.endswith('*')]
        test_list = []
        for test in individual_tests:
            if test in all_scripts:
                test_list.append(test)
            else:
                print("{}WARNING!{} Test '{}' not found in full test list.".format(
                    BOLD[1], BOLD[0], test))

        # Allow for wildcard at the end of the name, so a single input can
        # match multiple tests
        for test in tests:
            if test.endswith('*'):
                test_list.extend(
                    [t for t in all_scripts if t.startswith(test[:-1])])

        # do not cut off explicitly specified tests
        cutoff = sys.maxsize
    else:
        # Run base tests only
        test_list = all_scripts
        cutoff = sys.maxsize if args.extended else args.cutoff

    # Remove the test cases that the user has explicitly asked to exclude.
    # Normalize each excluded name to its ".py" form so both "rpc_bind" and
    # "rpc_bind.py" are accepted.
    if args.exclude:
        exclude_tests = [re.sub(r"\.py$", "", test) + ".py"
                         for test in args.exclude.split(',')]
        for exclude_test in exclude_tests:
            if exclude_test in test_list:
                test_list.remove(exclude_test)
            else:
                print("{}WARNING!{} Test '{}' not found in current test list.".format(
                    BOLD[1], BOLD[0], exclude_test))
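
    # Editor's note (illustrative): the selection logic above is prefix-based
    # for wildcards. For example, with all_scripts containing
    # ["wallet_hd.py", "rpc_bind.py"], the argument "wallet_*" selects
    # ["wallet_hd.py"], while "rpc_bind" (with or without ".py") selects
    # ["rpc_bind.py"].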
    # Update timings from build_dir only if a separate build directory is
    # used. We do not want to pollute the source directory.
    build_timings = None
    if (src_dir != build_dir):
        build_timings = Timings(os.path.join(build_dir, 'timing.json'))

    # Always use timings from src_dir if present
    src_timings = Timings(os.path.join(
        src_dir, "test", "functional", 'timing.json'))

    # Add test parameters and remove long running tests if needed
    test_list = get_tests_to_run(
        test_list, TEST_PARAMS, cutoff, src_timings)

    if not test_list:
        print("No valid test scripts specified. Check that your test is in one "
              "of the test lists in test_runner.py, or run test_runner.py with "
              "no arguments to run all tests")
        sys.exit(0)

    if args.help:
        # Print help for test_runner.py, then print help of the first script
        # and exit.
        parser.print_help()
        subprocess.check_call(
            [sys.executable, os.path.join(tests_dir, test_list[0]), '-h'])
        sys.exit(0)

    check_script_prefixes(all_scripts)

    if not args.keepcache:
        shutil.rmtree(os.path.join(build_dir, "test", "cache"),
                      ignore_errors=True)

    run_tests(
        test_list,
        build_dir,
        tests_dir,
        args.junitoutput,
        tmpdir,
        num_jobs=args.jobs,
        test_suite_name=args.testsuitename,
        enable_coverage=args.coverage,
        args=passon_args,
        combined_logs_len=args.combinedlogslen,
        build_timings=build_timings,
        failfast=args.failfast
    )
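
# Editor's note: typical invocations, using only the options defined in
# main() above:
#   test/functional/test_runner.py                      # default test list
#   test/functional/test_runner.py 'wallet_*'           # wildcard selection
#   test/functional/test_runner.py --extended --jobs=8  # full suite, 8 jobs
#   test/functional/test_runner.py -q -x rpc_bind.py    # quiet, with exclusion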
def run_tests(test_list, build_dir, tests_dir, junitoutput, tmpdir, num_jobs,
              test_suite_name, enable_coverage=False, args=None,
              combined_logs_len=0, build_timings=None, failfast=False):
    args = args or []

    # Warn if bitcoind is already running (unix only)
    try:
        pidofOutput = subprocess.check_output(["pidof", "bitcoind"])
        if pidofOutput is not None and pidofOutput != b'':
            print("{}WARNING!{} There is already a bitcoind process running on this "
                  "system. Tests may fail unexpectedly due to resource contention!".format(
                      BOLD[1], BOLD[0]))
    except (OSError, subprocess.SubprocessError):
        pass

    # Warn if there is a cache directory
    cache_dir = os.path.join(build_dir, "test", "cache")
    if os.path.isdir(cache_dir):
        print("{}WARNING!{} There is a cache directory here: {}. If tests fail "
              "unexpectedly, try deleting the cache directory.".format(
                  BOLD[1], BOLD[0], cache_dir))

    # Test Framework Tests
    print("Running Unit Tests for Test Framework Modules")
    test_framework_tests = unittest.TestSuite()
    for module in TEST_FRAMEWORK_MODULES:
        test_framework_tests.addTest(
            unittest.TestLoader().loadTestsFromName(
                "test_framework.{}".format(module)))
    result = unittest.TextTestRunner(
        verbosity=1, failfast=True).run(test_framework_tests)
    if not result.wasSuccessful():
        logging.debug(
            "Early exiting after failure in TestFramework unit tests")
        # Exit with a non-zero status to signal the failure.
        sys.exit(1)

    flags = ['--cachedir={}'.format(cache_dir)] + args

    if enable_coverage:
        coverage = RPCCoverage()
        flags.append(coverage.flag)
        logging.debug(
            "Initializing coverage directory at {}".format(coverage.dir))
    else:
        coverage = None

    if len(test_list) > 1 and num_jobs > 1:
        # Populate cache
        try:
            subprocess.check_output(
                [sys.executable, os.path.join(tests_dir, 'create_cache.py')]
                + flags
                + [os.path.join("--tmpdir={}", "cache").format(tmpdir)])
        except subprocess.CalledProcessError as e:
            sys.stdout.buffer.write(e.output)
            raise

    # Run Tests
    start_time = time.time()
    test_results = execute_test_processes(
        num_jobs, test_list, tests_dir, tmpdir, flags, failfast)
    runtime = time.time() - start_time

    max_len_name = len(max(test_list, key=len))
    print_results(test_results, tests_dir, max_len_name,
                  runtime, combined_logs_len)

    if junitoutput is not None:
        save_results_as_junit(
            test_results, junitoutput, runtime, test_suite_name)

    if (build_timings is not None):
        build_timings.save_timings(test_results)

    if coverage:
        coverage_passed = coverage.report_rpc_coverage()

        logging.debug("Cleaning up coverage data")
        coverage.cleanup()
    else:
        coverage_passed = True

    # Clear up the temp directory if all subdirectories are gone
    if not os.listdir(tmpdir):
        os.rmdir(tmpdir)

    all_passed = all(map(
        lambda test_result: test_result.was_successful,
        test_results)) and coverage_passed

    sys.exit(not all_passed)


def execute_test_processes(
        num_jobs, test_list, tests_dir, tmpdir, flags, failfast=False):
    update_queue = Queue()
    job_queue = Queue()
    failfast_event = threading.Event()
    test_results = []
    poll_timeout = 10  # seconds

    # In case there is a graveyard of zombie bitcoinds, we can apply a
    # pseudorandom offset to hopefully jump over them.
    # (625 is PORT_RANGE/MAX_NODES)
    portseed_offset = int(time.time() * 1000) % 625
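
    # Editor's sketch of the pipeline assembled below:
    #
    #   job_queue -> [num_jobs worker threads: handle_test_cases] -> update_queue
    #   update_queue -> [one collector thread: handle_update_messages] -> console
    #
    # Funneling every TestCase/TestResult message through the single
    # collector thread keeps console output from interleaving across workers.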
    ##
    # Define some helper functions we will need for threading.
    ##

    def handle_message(message, running_jobs):
        """
        handle_message handles a single message from handle_test_cases
        """
        if isinstance(message, TestCase):
            running_jobs.append((message.test_num, message.test_case))
            print("{}{}{} started".format(BOLD[1], message.test_case, BOLD[0]))
            return

        if isinstance(message, TestResult):
            test_result = message

            running_jobs.remove((test_result.num, test_result.name))
            test_results.append(test_result)

            if test_result.status == "Passed":
                print("{}{}{} passed, Duration: {} s".format(
                    BOLD[1], test_result.name, BOLD[0],
                    TimeResolution.seconds(test_result.time)))
            elif test_result.status == "Skipped":
                print("{}{}{} skipped".format(
                    BOLD[1], test_result.name, BOLD[0]))
            else:
                print("{}{}{} failed, Duration: {} s\n".format(
                    BOLD[1], test_result.name, BOLD[0],
                    TimeResolution.seconds(test_result.time)))
                print(BOLD[1] + 'stdout:' + BOLD[0])
                print(test_result.stdout)
                print(BOLD[1] + 'stderr:' + BOLD[0])
                print(test_result.stderr)

                if failfast:
                    logging.debug("Early exiting after test failure")
                    failfast_event.set()
            return

        assert False, "we should not be here"

    def handle_update_messages():
        """
        handle_update_messages waits for messages to be sent from
        handle_test_cases via the update_queue. It serializes the results so
        we can print nice status update messages.
        """
        printed_status = False
        running_jobs = []

        while True:
            message = None
            try:
                message = update_queue.get(True, poll_timeout)
                if message is None:
                    break

                # We printed a status message, need to kick to the next line
                # before printing more.
                if printed_status:
                    print()
                    printed_status = False

                handle_message(message, running_jobs)
                update_queue.task_done()
            except Empty:
                if not on_ci():
                    print("Running jobs: {}".format(
                        ", ".join([j[1] for j in running_jobs])), end="\r")
                    sys.stdout.flush()
                    printed_status = True

    def handle_test_cases():
        """
        handle_test_cases represents a single thread that is part of a worker
        pool. It waits for a test, then executes that test. It also reports
        start and result messages to handle_update_messages.
        """
        while True:
            test = job_queue.get()
            if test is None:
                break

            # Signal that the test is starting to inform the poor waiting
            # programmer
            update_queue.put(test)

            result = test.run(portseed_offset)
            update_queue.put(result)
            job_queue.task_done()

    ##
    # Setup our threads, and start sending tasks
    ##

    # Start our result collection thread.
    resultCollector = threading.Thread(target=handle_update_messages)
    resultCollector.daemon = True
    resultCollector.start()

    # Start some worker threads
    for job in range(num_jobs):
        t = threading.Thread(target=handle_test_cases)
        t.daemon = True
        t.start()
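
    # Editor's note: `None` serves as the shutdown sentinel below -- the
    # collector loop and each worker loop exit when they dequeue None, so one
    # sentinel is pushed to update_queue and one per worker to job_queue once
    # all jobs have been joined.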
    # Push all our test cases into the job queue.
    for i, t in enumerate(test_list):
        job_queue.put(TestCase(i, t, tests_dir, tmpdir, failfast_event, flags))

    # Wait for all the jobs to be completed
    job_queue.join()

    # Wait for all the results to be compiled
    update_queue.join()

    # Flush our queues so the threads exit
    update_queue.put(None)
    for job in range(num_jobs):
        job_queue.put(None)

    return test_results


def print_results(test_results, tests_dir, max_len_name,
                  runtime, combined_logs_len):
    results = "\n" + BOLD[1] + "{} | {} | {}\n\n".format(
        "TEST".ljust(max_len_name), "STATUS ", "DURATION") + BOLD[0]

    test_results.sort(key=TestResult.sort_key)
    all_passed = True
    time_sum = 0

    for test_result in test_results:
        all_passed = all_passed and test_result.was_successful
        time_sum += test_result.time
        test_result.padding = max_len_name
        results += str(test_result)

        testdir = test_result.testdir
        if combined_logs_len and os.path.isdir(testdir):
            # Print the final `combinedlogslen` lines of the combined logs
            print('{}Combine the logs and print the last {} lines ...{}'.format(
                BOLD[1], combined_logs_len, BOLD[0]))
            print('\n============')
            print('{}Combined log for {}:{}'.format(BOLD[1], testdir, BOLD[0]))
            print('============\n')
            combined_logs_args = [
                sys.executable,
                os.path.join(tests_dir, 'combine_logs.py'),
                testdir]
            if BOLD[0]:
                combined_logs_args += ['--color']
            combined_logs, _ = subprocess.Popen(
                combined_logs_args,
                universal_newlines=True,
                stdout=subprocess.PIPE).communicate()
            print("\n".join(deque(
                combined_logs.splitlines(), combined_logs_len)))

    status = TICK + "Passed" if all_passed else CROSS + "Failed"
    if not all_passed:
        results += RED[1]
    results += BOLD[1] + "\n{} | {} | {} s (accumulated) \n".format(
        "ALL".ljust(max_len_name), status.ljust(9),
        TimeResolution.seconds(time_sum)) + BOLD[0]
    if not all_passed:
        results += RED[0]
    results += "Runtime: {} s\n".format(TimeResolution.seconds(runtime))
    print(results)


class TestResult():
    """
    Simple data structure to store test result values and print them properly
    """

    def __init__(self, num, name, testdir, status, time, stdout, stderr):
        self.num = num
        self.name = name
        self.testdir = testdir
        self.status = status
        self.time = time
        self.padding = 0
        self.stdout = stdout
        self.stderr = stderr

    def sort_key(self):
        if self.status == "Passed":
            return 0, self.name.lower()
        elif self.status == "Skipped":
            return 1, self.name.lower()
        elif self.status == "Failed":
            return 2, self.name.lower()

    def __repr__(self):
        if self.status == "Passed":
            color = GREEN
            glyph = TICK
        elif self.status == "Failed":
            color = RED
            glyph = CROSS
        elif self.status == "Skipped":
            color = GREY
            glyph = CIRCLE

        return color[1] + "{} | {}{} | {} s\n".format(
            self.name.ljust(self.padding), glyph, self.status.ljust(7),
            TimeResolution.seconds(self.time)) + color[0]

    @property
    def was_successful(self):
        return self.status != "Failed"


def get_all_scripts_from_disk(test_dir, non_scripts):
    """
    Return all available test scripts from the script directory (excluding
    NON_SCRIPTS).
    """
    python_files = set([t for t in os.listdir(test_dir) if t[-3:] == ".py"])
    return list(python_files - set(non_scripts))
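
# Editor's sketch (illustrative duplicate of the regex used in
# check_script_prefixes() below, not called by the runner):
def _example_is_conforming(script):
    # "feature_block.py" and "abc_rpc_avalanche.py" conform (the "abc_"
    # prefix is optional); "mytest.py" counts as a violation.
    return re.match(
        "(abc_)?(example|feature|interface|mempool|mining|p2p|rpc|wallet|tool)_",
        script) is not None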
def check_script_prefixes(all_scripts):
    """Check that no more than `EXPECTED_VIOLATION_COUNT` of the test scripts
    don't start with one of the allowed name prefixes."""
    EXPECTED_VIOLATION_COUNT = 16

    # LEEWAY is provided as a transition measure, so that pull-requests
    # that introduce new tests that don't conform with the naming
    # convention don't immediately cause the tests to fail.
    LEEWAY = 0

    good_prefixes_re = re.compile(
        "(abc_)?(example|feature|interface|mempool|mining|p2p|rpc|wallet|tool)_")
    bad_script_names = [
        script for script in all_scripts
        if good_prefixes_re.match(script) is None]

    if len(bad_script_names) < EXPECTED_VIOLATION_COUNT:
        print("{}HURRAY!{} Number of functional tests violating naming "
              "convention reduced!".format(BOLD[1], BOLD[0]))
        print("Consider reducing EXPECTED_VIOLATION_COUNT from {} to {}".format(
            EXPECTED_VIOLATION_COUNT, len(bad_script_names)))
    elif len(bad_script_names) > EXPECTED_VIOLATION_COUNT:
        print("INFO: {} tests not meeting naming conventions (expected {}):".format(
            len(bad_script_names), EXPECTED_VIOLATION_COUNT))
        print("  {}".format("\n  ".join(sorted(bad_script_names))))

    assert len(bad_script_names) <= EXPECTED_VIOLATION_COUNT + LEEWAY, (
        "Too many tests not following naming convention! ({} found, "
        "expected: <= {})".format(
            len(bad_script_names), EXPECTED_VIOLATION_COUNT))


def get_tests_to_run(test_list, test_params, cutoff, src_timings):
    """
    Returns only tests that will not run longer than the cutoff.
    Long running tests are returned first to favor running tests in parallel.
    Timings from the build directory override those from the src directory.
    """

    def get_test_time(test):
        # Return 0 if the test is unknown so that it always runs
        return next(
            (x['time'] for x in src_timings.existing_timings
             if x['name'] == test), 0)

    # Some tests must also be run with additional parameters. Add them to the
    # list.
    tests_with_params = []
    for test_name in test_list:
        # always execute a test without parameters
        tests_with_params.append(test_name)
        params = test_params.get(test_name)
        if params is not None:
            tests_with_params.extend(
                [test_name + " " + " ".join(parameter)
                 for parameter in params])

    result = [
        test for test in tests_with_params if get_test_time(test) <= cutoff]
    result.sort(key=lambda x: (-get_test_time(x), x))
    return result


class RPCCoverage():
    """
    Coverage reporting utilities for test_runner.

    Coverage calculation works by having each test script subprocess write
    coverage files into a particular directory. These files contain the RPC
    commands invoked during testing, as well as a complete listing of RPC
    commands per `bitcoin-cli help` (`rpc_interface.txt`).

    After all tests complete, the commands run are combined and diff'd against
    the complete list to calculate uncovered RPC commands.

    See also: test/functional/test_framework/coverage.py
    """

    def __init__(self):
        self.dir = tempfile.mkdtemp(prefix="coverage")
        self.flag = '--coveragedir={}'.format(self.dir)

    def report_rpc_coverage(self):
        """
        Print out RPC commands that were unexercised by tests.
        """
        uncovered = self._get_uncovered_rpc_commands()

        if uncovered:
            print("Uncovered RPC commands:")
            print("".join(("  - {}\n".format(i)) for i in sorted(uncovered)))
            return False
        else:
            print("All RPC commands covered.")
            return True

    def cleanup(self):
        return shutil.rmtree(self.dir)
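
    # Editor's note: _get_uncovered_rpc_commands() below reduces to a plain
    # set difference. For example, if rpc_interface.txt lists
    # {"getblock", "getblockcount", "stop"} and the coverage.* files record
    # {"getblockcount"}, the uncovered set is {"getblock", "stop"}.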
    def _get_uncovered_rpc_commands(self):
        """
        Return a set of currently untested RPC commands.
        """
        # This is shared from `test/functional/test_framework/coverage.py`
        reference_filename = 'rpc_interface.txt'
        coverage_file_prefix = 'coverage.'

        coverage_ref_filename = os.path.join(self.dir, reference_filename)
        coverage_filenames = set()
        all_cmds = set()
        covered_cmds = set()

        if not os.path.isfile(coverage_ref_filename):
            raise RuntimeError("No coverage reference found")

        with open(coverage_ref_filename, 'r', encoding="utf8") as file:
            all_cmds.update([line.strip() for line in file.readlines()])

        for root, _, files in os.walk(self.dir):
            for filename in files:
                if filename.startswith(coverage_file_prefix):
                    coverage_filenames.add(os.path.join(root, filename))

        for filename in coverage_filenames:
            with open(filename, 'r', encoding="utf8") as file:
                covered_cmds.update([line.strip()
                                     for line in file.readlines()])

        return all_cmds - covered_cmds


def save_results_as_junit(test_results, file_name, time, test_suite_name):
    """
    Save test results to a file in JUnit format.
    See http://llg.cubic.org/docs/junit/ for the specification of the format.
    """
    e_test_suite = ET.Element(
        "testsuite",
        {"name": "{}".format(test_suite_name),
         "tests": str(len(test_results)),
         # "errors":
         "failures": str(len([t for t in test_results
                              if t.status == "Failed"])),
         "id": "0",
         "skipped": str(len([t for t in test_results
                             if t.status == "Skipped"])),
         "time": str(TimeResolution.milliseconds(time)),
         "timestamp": datetime.datetime.now().isoformat('T')
         })

    for test_result in test_results:
        e_test_case = ET.SubElement(
            e_test_suite, "testcase",
            {"name": test_result.name,
             "classname": test_result.name,
             "time": str(TimeResolution.milliseconds(test_result.time))
             })
        if test_result.status == "Skipped":
            ET.SubElement(e_test_case, "skipped")
        elif test_result.status == "Failed":
            ET.SubElement(e_test_case, "failure")
        # no special element for passed tests

        ET.SubElement(e_test_case, "system-out").text = test_result.stdout
        ET.SubElement(e_test_case, "system-err").text = test_result.stderr

    ET.ElementTree(e_test_suite).write(
        file_name, "UTF-8", xml_declaration=True)


class Timings():
    """
    Takes care of loading, merging and saving test execution times.
    """

    def __init__(self, timing_file):
        self.timing_file = timing_file
        self.existing_timings = self.load_timings()

    def load_timings(self):
        if os.path.isfile(self.timing_file):
            with open(self.timing_file, encoding="utf8") as file:
                return json.load(file)
        else:
            return []

    def get_merged_timings(self, new_timings):
        """
        Return a new list containing the existing timings updated with the new
        timings. Tests that no longer exist are not removed.
        """
        key = 'name'
        merged = {}
        for item in self.existing_timings + new_timings:
            if item[key] in merged:
                merged[item[key]].update(item)
            else:
                merged[item[key]] = item

        # Sort the result to preserve test ordering in the file
        merged = list(merged.values())
        merged.sort(key=lambda t, key=key: t[key])
        return merged

    def save_timings(self, test_results):
        # we only save tests that have passed - timings for failed tests might
        # be wrong (timeouts or early fails)
        passed_results = [
            test for test in test_results if test.status == 'Passed']
        new_timings = list(map(
            lambda test: {'name': test.name,
                          'time': TimeResolution.seconds(test.time)},
            passed_results))
        merged_timings = self.get_merged_timings(new_timings)

        with open(self.timing_file, 'w', encoding="utf8") as file:
            json.dump(merged_timings, file, indent=True)


class TimeResolution:
    @staticmethod
    def seconds(time_fractional_second):
        return round(time_fractional_second)

    @staticmethod
    def milliseconds(time_fractional_second):
        return round(time_fractional_second, 3)


if __name__ == '__main__':
    main()
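
# Editor's note: timing.json, as written by Timings.save_timings() above, is
# a JSON list of {"name": ..., "time": ...} objects sorted by name, with
# "time" in whole seconds, e.g.:
#   [{"name": "rpc_bind.py", "time": 42}, {"name": "wallet_hd.py", "time": 9}]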