diff --git a/src/script/descriptor.h b/src/script/descriptor.h --- a/src/script/descriptor.h +++ b/src/script/descriptor.h @@ -117,25 +117,25 @@ * @param[out] output_scripts: The expanded scriptPubKeys. * @param[out] out: Scripts and public keys necessary for solving the * expanded scriptPubKeys (may be equal to `provider`). - * @param[out] cache: Cache data necessary to evaluate the descriptor at - * this point without access to private keys. + * @param[out] write_cache: Cache data necessary to evaluate the descriptor + * at this point without access to private keys. */ virtual bool Expand(int pos, const SigningProvider &provider, std::vector &output_scripts, FlatSigningProvider &out, - std::vector *cache = nullptr) const = 0; + DescriptorCache *write_cache = nullptr) const = 0; /** * Expand a descriptor at a specified position using cached expansion data. * * @param[in] pos: The position at which to expand the descriptor. If * IsRange() is false, this is ignored. - * @param[in] cache: Cached expansion data. + * @param[in] read_cache: Cached expansion data. * @param[out] output_scripts: The expanded scriptPubKeys. * @param[out] out: Scripts and public keys necessary for solving the * expanded scriptPubKeys (may be equal to `provider`). */ - virtual bool ExpandFromCache(int pos, const std::vector &cache, + virtual bool ExpandFromCache(int pos, const DescriptorCache &read_cache, std::vector &output_scripts, FlatSigningProvider &out) const = 0; diff --git a/src/script/descriptor.cpp b/src/script/descriptor.cpp --- a/src/script/descriptor.cpp +++ b/src/script/descriptor.cpp @@ -190,9 +190,17 @@ virtual ~PubkeyProvider() = default; - /** Derive a public key. If key==nullptr, only info is desired. */ - virtual bool GetPubKey(int pos, const SigningProvider &arg, CPubKey *key, - KeyOriginInfo &info) const = 0; + /** + * Derive a public key. + * read_cache is the cache to read keys from (if not nullptr) + * write_cache is the cache to write keys to (if not nullptr) + * Caches are not exclusive but this is not tested. Currently we use them + * exclusively + */ + virtual bool GetPubKey(int pos, const SigningProvider &arg, CPubKey &key, + KeyOriginInfo &info, + const DescriptorCache *read_cache = nullptr, + DescriptorCache *write_cache = nullptr) const = 0; /** Whether this represent multiple public keys at different positions. */ virtual bool IsRange() const = 0; @@ -230,9 +238,12 @@ std::unique_ptr provider) : PubkeyProvider(exp_index), m_origin(std::move(info)), m_provider(std::move(provider)) {} - bool GetPubKey(int pos, const SigningProvider &arg, CPubKey *key, - KeyOriginInfo &info) const override { - if (!m_provider->GetPubKey(pos, arg, key, info)) { + bool GetPubKey(int pos, const SigningProvider &arg, CPubKey &key, + KeyOriginInfo &info, + const DescriptorCache *read_cache = nullptr, + DescriptorCache *write_cache = nullptr) const override { + if (!m_provider->GetPubKey(pos, arg, key, info, read_cache, + write_cache)) { return false; } std::copy(std::begin(m_origin.fingerprint), @@ -268,11 +279,11 @@ public: ConstPubkeyProvider(uint32_t exp_index, const CPubKey &pubkey) : PubkeyProvider(exp_index), m_pubkey(pubkey) {} - bool GetPubKey(int pos, const SigningProvider &arg, CPubKey *key, - KeyOriginInfo &info) const override { - if (key) { - *key = m_pubkey; - } + bool GetPubKey(int pos, const SigningProvider &arg, CPubKey &key, + KeyOriginInfo &info, + const DescriptorCache *read_cache = nullptr, + DescriptorCache *write_cache = nullptr) const override { + key = m_pubkey; info.path.clear(); CKeyID keyid = m_pubkey.GetID(); std::copy(keyid.begin(), keyid.begin() + sizeof(info.fingerprint), @@ -327,6 +338,17 @@ return true; } + // Derives the last xprv + bool GetDerivedExtKey(const SigningProvider &arg, CExtKey &xprv) const { + if (!GetExtKey(arg, xprv)) { + return false; + } + for (auto entry : m_path) { + xprv.Derive(xprv, entry); + } + return true; + } + bool IsHardened() const { if (m_derive == DeriveType::HARDENED) { return true; @@ -346,38 +368,67 @@ m_path(std::move(path)), m_derive(derive) {} bool IsRange() const override { return m_derive != DeriveType::NO; } size_t GetSize() const override { return 33; } - bool GetPubKey(int pos, const SigningProvider &arg, CPubKey *key, - KeyOriginInfo &info) const override { - if (key) { - if (IsHardened()) { - CKey priv_key; - if (!GetPrivKey(pos, arg, priv_key)) { - return false; - } - *key = priv_key.GetPubKey(); - } else { - // TODO: optimize by caching - CExtPubKey extkey = m_root_extkey; - for (auto entry : m_path) { - extkey.Derive(extkey, entry); - } - if (m_derive == DeriveType::UNHARDENED) { - extkey.Derive(extkey, pos); - } - assert(m_derive != DeriveType::HARDENED); - *key = extkey.pubkey; - } - } + bool GetPubKey(int pos, const SigningProvider &arg, CPubKey &key_out, + KeyOriginInfo &final_info_out, + const DescriptorCache *read_cache = nullptr, + DescriptorCache *write_cache = nullptr) const override { + // Info of parent of the to be derived pubkey + KeyOriginInfo parent_info; CKeyID keyid = m_root_extkey.pubkey.GetID(); - std::copy(keyid.begin(), keyid.begin() + sizeof(info.fingerprint), - info.fingerprint); - info.path = m_path; + std::copy(keyid.begin(), + keyid.begin() + sizeof(parent_info.fingerprint), + parent_info.fingerprint); + parent_info.path = m_path; + + // Info of the derived key itself which is copied out upon successful + // completion + KeyOriginInfo final_info_out_tmp = parent_info; if (m_derive == DeriveType::UNHARDENED) { - info.path.push_back(uint32_t(pos)); + final_info_out_tmp.path.push_back((uint32_t)pos); } if (m_derive == DeriveType::HARDENED) { - info.path.push_back(uint32_t(pos) | 0x80000000L); + final_info_out_tmp.path.push_back(((uint32_t)pos) | 0x80000000L); } + + // Derive keys or fetch them from cache + CExtPubKey final_extkey = m_root_extkey; + bool der = true; + if (read_cache) { + if (!read_cache->GetCachedDerivedExtPubKey(m_expr_index, pos, + final_extkey)) { + return false; + } + } else if (IsHardened()) { + CExtKey xprv; + if (!GetDerivedExtKey(arg, xprv)) { + return false; + } + if (m_derive == DeriveType::UNHARDENED) { + der = xprv.Derive(xprv, pos); + } + if (m_derive == DeriveType::HARDENED) { + der = xprv.Derive(xprv, pos | 0x80000000UL); + } + final_extkey = xprv.Neuter(); + } else { + for (auto entry : m_path) { + der = final_extkey.Derive(final_extkey, entry); + assert(der); + } + if (m_derive == DeriveType::UNHARDENED) { + der = final_extkey.Derive(final_extkey, pos); + } + assert(m_derive != DeriveType::HARDENED); + } + assert(der); + + final_info_out = final_info_out_tmp; + key_out = final_extkey.pubkey; + + if (write_cache) { + write_cache->CacheDerivedExtPubKey(m_expr_index, pos, final_extkey); + } + return true; } std::string ToString() const override { @@ -409,12 +460,9 @@ bool GetPrivKey(int pos, const SigningProvider &arg, CKey &key) const override { CExtKey extkey; - if (!GetExtKey(arg, extkey)) { + if (!GetDerivedExtKey(arg, extkey)) { return false; } - for (auto entry : m_path) { - extkey.Derive(extkey, entry); - } if (m_derive == DeriveType::UNHARDENED) { extkey.Derive(extkey, pos); } @@ -540,10 +588,10 @@ } bool ExpandHelper(int pos, const SigningProvider &arg, - Span *cache_read, + const DescriptorCache *read_cache, std::vector &output_scripts, FlatSigningProvider &out, - std::vector *cache_write) const { + DescriptorCache *write_cache) const { std::vector> entries; entries.reserve(m_pubkey_args.size()); @@ -551,43 +599,17 @@ // producing output in case of failure. for (const auto &p : m_pubkey_args) { entries.emplace_back(); - // If we have a cache, we don't need GetPubKey to compute the public - // key. Pass in nullptr to signify only origin info is desired. - if (!p->GetPubKey(pos, arg, - cache_read ? nullptr : &entries.back().first, - entries.back().second)) { + if (!p->GetPubKey(pos, arg, entries.back().first, + entries.back().second, read_cache, write_cache)) { return false; } - if (cache_read) { - // Cached expanded public key exists, use it. - if (cache_read->size() == 0) { - return false; - } - bool compressed = - ((*cache_read)[0] == 0x02 || (*cache_read)[0] == 0x03) && - cache_read->size() >= 33; - bool uncompressed = - ((*cache_read)[0] == 0x04) && cache_read->size() >= 65; - if (!(compressed || uncompressed)) { - return false; - } - CPubKey pubkey(cache_read->begin(), - cache_read->begin() + (compressed ? 33 : 65)); - entries.back().first = pubkey; - *cache_read = cache_read->subspan(compressed ? 33 : 65); - } - if (cache_write) { - cache_write->insert(cache_write->end(), - entries.back().first.begin(), - entries.back().first.end()); - } } std::vector subscripts; if (m_subdescriptor_arg) { FlatSigningProvider subprovider; - if (!m_subdescriptor_arg->ExpandHelper(pos, arg, cache_read, + if (!m_subdescriptor_arg->ExpandHelper(pos, arg, read_cache, subscripts, subprovider, - cache_write)) { + write_cache)) { return false; } out = Merge(out, subprovider); @@ -619,17 +641,16 @@ bool Expand(int pos, const SigningProvider &provider, std::vector &output_scripts, FlatSigningProvider &out, - std::vector *cache = nullptr) const final { - return ExpandHelper(pos, provider, nullptr, output_scripts, out, cache); + DescriptorCache *write_cache = nullptr) const final { + return ExpandHelper(pos, provider, nullptr, output_scripts, out, + write_cache); } - bool ExpandFromCache(int pos, const std::vector &cache, + bool ExpandFromCache(int pos, const DescriptorCache &read_cache, std::vector &output_scripts, FlatSigningProvider &out) const final { - Span span = MakeSpan(cache); - return ExpandHelper(pos, DUMMY_SIGNING_PROVIDER, &span, output_scripts, - out, nullptr) && - span.size() == 0; + return ExpandHelper(pos, DUMMY_SIGNING_PROVIDER, &read_cache, + output_scripts, out, nullptr); } void ExpandPrivate(int pos, const SigningProvider &provider, diff --git a/src/test/descriptor_tests.cpp b/src/test/descriptor_tests.cpp --- a/src/test/descriptor_tests.cpp +++ b/src/test/descriptor_tests.cpp @@ -160,16 +160,16 @@ // Evaluate the descriptor selected by `t` in position `i`. FlatSigningProvider script_provider, script_provider_cached; std::vector spks, spks_cached; - std::vector cache; - BOOST_CHECK( - (t ? parse_priv : parse_pub) - ->Expand(i, key_provider, spks, script_provider, &cache)); + DescriptorCache desc_cache; + BOOST_CHECK((t ? parse_priv : parse_pub) + ->Expand(i, key_provider, spks, script_provider, + &desc_cache)); // Compare the output with the expected result. BOOST_CHECK_EQUAL(spks.size(), ref.size()); // Try to expand again using cached data, and compare. - BOOST_CHECK(parse_pub->ExpandFromCache(i, cache, spks_cached, + BOOST_CHECK(parse_pub->ExpandFromCache(i, desc_cache, spks_cached, script_provider_cached)); BOOST_CHECK(spks == spks_cached); BOOST_CHECK(script_provider.pubkeys ==