Changeset View
Changeset View
Standalone View
Standalone View
test/functional/wallet_import_rescan.py
Show All 20 Lines | |||||
import collections | import collections | ||||
import enum | import enum | ||||
import itertools | import itertools | ||||
from test_framework.test_framework import BitcoinTestFramework | from test_framework.test_framework import BitcoinTestFramework | ||||
from test_framework.util import ( | from test_framework.util import ( | ||||
assert_equal, | assert_equal, | ||||
assert_raises_rpc_error, | |||||
connect_nodes, | connect_nodes, | ||||
set_node_times, | set_node_times, | ||||
) | ) | ||||
Call = enum.Enum("Call", "single multiaddress multiscript") | Call = enum.Enum("Call", "single multiaddress multiscript") | ||||
Data = enum.Enum("Data", "address pub priv") | Data = enum.Enum("Data", "address pub priv") | ||||
Rescan = enum.Enum("Rescan", "no yes late_timestamp") | Rescan = enum.Enum("Rescan", "no yes late_timestamp") | ||||
class Variant(collections.namedtuple("Variant", "call data rescan prune")): | class Variant(collections.namedtuple("Variant", "call data rescan prune")): | ||||
"""Helper for importing one key and verifying scanned transactions.""" | """Helper for importing one key and verifying scanned transactions.""" | ||||
def try_rpc(self, func, *args, **kwargs): | |||||
if self.expect_disabled: | |||||
assert_raises_rpc_error(-4, "Rescan is disabled in pruned mode", | |||||
func, *args, **kwargs) | |||||
else: | |||||
return func(*args, **kwargs) | |||||
def do_import(self, timestamp): | def do_import(self, timestamp): | ||||
"""Call one key import RPC.""" | """Call one key import RPC.""" | ||||
rescan = self.rescan == Rescan.yes | rescan = self.rescan == Rescan.yes | ||||
if self.call == Call.single: | if self.call == Call.single: | ||||
if self.data == Data.address: | if self.data == Data.address: | ||||
response = self.try_rpc( | response = self.node.importaddress( | ||||
self.node.importaddress, | address=self.address["address"], label=self.label, rescan=rescan) | ||||
address=self.address["address"], | |||||
label=self.label, | |||||
rescan=rescan) | |||||
elif self.data == Data.pub: | elif self.data == Data.pub: | ||||
response = self.try_rpc( | response = self.node.importpubkey( | ||||
self.node.importpubkey, | pubkey=self.address["pubkey"], label=self.label, rescan=rescan) | ||||
pubkey=self.address["pubkey"], | |||||
label=self.label, | |||||
rescan=rescan) | |||||
elif self.data == Data.priv: | elif self.data == Data.priv: | ||||
response = self.try_rpc( | response = self.node.importprivkey( | ||||
self.node.importprivkey, | privkey=self.key, label=self.label, rescan=rescan) | ||||
privkey=self.key, | |||||
label=self.label, | |||||
rescan=rescan) | |||||
assert_equal(response, None) | assert_equal(response, None) | ||||
elif self.call in (Call.multiaddress, Call.multiscript): | elif self.call in (Call.multiaddress, Call.multiscript): | ||||
response = self.node.importmulti([{ | response = self.node.importmulti([{ | ||||
"scriptPubKey": { | "scriptPubKey": { | ||||
"address": self.address["address"] | "address": self.address["address"] | ||||
} if self.call == Call.multiaddress else self.address["scriptPubKey"], | } if self.call == Call.multiaddress else self.address["scriptPubKey"], | ||||
"timestamp": timestamp + TIMESTAMP_WINDOW + (1 if self.rescan == Rescan.late_timestamp else 0), | "timestamp": timestamp + TIMESTAMP_WINDOW + (1 if self.rescan == Rescan.late_timestamp else 0), | ||||
▲ Show 20 Lines • Show All 109 Lines • ▼ Show 20 Lines | def run_test(self): | ||||
set_node_times(self.nodes, timestamp + TIMESTAMP_WINDOW + 1) | set_node_times(self.nodes, timestamp + TIMESTAMP_WINDOW + 1) | ||||
self.nodes[0].generate(1) | self.nodes[0].generate(1) | ||||
self.sync_all() | self.sync_all() | ||||
# For each variation of wallet key import, invoke the import RPC and | # For each variation of wallet key import, invoke the import RPC and | ||||
# check the results from getbalance and listtransactions. | # check the results from getbalance and listtransactions. | ||||
for variant in IMPORT_VARIANTS: | for variant in IMPORT_VARIANTS: | ||||
self.log.info('Run import for variant {}'.format(variant)) | self.log.info('Run import for variant {}'.format(variant)) | ||||
variant.expect_disabled = variant.rescan == Rescan.yes and variant.prune and variant.call == Call.single | expect_rescan = variant.rescan == Rescan.yes | ||||
expect_rescan = variant.rescan == Rescan.yes and not variant.expect_disabled | variant.node = self.nodes[2 + IMPORT_NODES.index( | ||||
variant.node = self.nodes[ | ImportNode(variant.prune, expect_rescan))] | ||||
2 + IMPORT_NODES.index(ImportNode(variant.prune, expect_rescan))] | |||||
variant.do_import(timestamp) | variant.do_import(timestamp) | ||||
if expect_rescan: | if expect_rescan: | ||||
variant.expected_balance = variant.initial_amount | variant.expected_balance = variant.initial_amount | ||||
variant.expected_txs = 1 | variant.expected_txs = 1 | ||||
variant.check(variant.initial_txid, variant.initial_amount, 2) | variant.check(variant.initial_txid, variant.initial_amount, 2) | ||||
else: | else: | ||||
variant.expected_balance = 0 | variant.expected_balance = 0 | ||||
variant.expected_txs = 0 | variant.expected_txs = 0 | ||||
variant.check() | variant.check() | ||||
# Create new transactions sending to each address. | # Create new transactions sending to each address. | ||||
for i, variant in enumerate(IMPORT_VARIANTS): | for i, variant in enumerate(IMPORT_VARIANTS): | ||||
variant.sent_amount = 1 - (2 * i + 1) / 128 | variant.sent_amount = 1 - (2 * i + 1) / 128 | ||||
variant.sent_txid = self.nodes[0].sendtoaddress( | variant.sent_txid = self.nodes[0].sendtoaddress( | ||||
variant.address["address"], variant.sent_amount) | variant.address["address"], variant.sent_amount) | ||||
# Generate a block containing the new transactions. | # Generate a block containing the new transactions. | ||||
self.nodes[0].generate(1) | self.nodes[0].generate(1) | ||||
assert_equal(self.nodes[0].getrawmempool(), []) | assert_equal(self.nodes[0].getrawmempool(), []) | ||||
self.sync_all() | self.sync_all() | ||||
# Check the latest results from getbalance and listtransactions. | # Check the latest results from getbalance and listtransactions. | ||||
for variant in IMPORT_VARIANTS: | for variant in IMPORT_VARIANTS: | ||||
self.log.info('Run check for variant {}'.format(variant)) | self.log.info('Run check for variant {}'.format(variant)) | ||||
if not variant.expect_disabled: | |||||
variant.expected_balance += variant.sent_amount | variant.expected_balance += variant.sent_amount | ||||
variant.expected_txs += 1 | variant.expected_txs += 1 | ||||
variant.check(variant.sent_txid, variant.sent_amount, 1) | variant.check(variant.sent_txid, variant.sent_amount, 1) | ||||
else: | |||||
variant.check() | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
ImportRescanTest().main() | ImportRescanTest().main() |