blockfilter: Refactor GCS params into struct; add BlockFilter helpers

Replace GCSFilter's loose (siphash_k0, siphash_k1, P, M) constructor
arguments with a GCSFilter::Params struct, move the per-filter-type
parameter selection into BlockFilter::BuildParams(), and add a default
BlockFilter constructor, a constructor from (filter_type, block_hash,
encoded filter) and BlockFilter::GetBlockHash(). Adds tests for the
default-constructed GCSFilter and for a BlockFilter serialization
round-trip.

diff --git a/src/bench/gcs_filter.cpp b/src/bench/gcs_filter.cpp
--- a/src/bench/gcs_filter.cpp
+++ b/src/bench/gcs_filter.cpp
@@ -16,7 +16,7 @@
 
     uint64_t siphash_k0 = 0;
     while (state.KeepRunning()) {
-        GCSFilter filter(siphash_k0, 0, 20, 1 << 20, elements);
+        GCSFilter filter({siphash_k0, 0, 20, 1 << 20}, elements);
 
         siphash_k0++;
     }
@@ -30,7 +30,7 @@
         element[1] = static_cast<uint8_t>(i >> 8);
         elements.insert(std::move(element));
     }
-    GCSFilter filter(0, 0, 20, 1 << 20, elements);
+    GCSFilter filter({0, 0, 20, 1 << 20}, elements);
 
     while (state.KeepRunning()) {
         filter.Match(GCSFilter::Element());
diff --git a/src/blockfilter.h b/src/blockfilter.h
--- a/src/blockfilter.h
+++ b/src/blockfilter.h
@@ -23,11 +23,20 @@
     typedef std::vector<uint8_t> Element;
     typedef std::set<Element> ElementSet;
 
+    struct Params {
+        uint64_t m_siphash_k0;
+        uint64_t m_siphash_k1;
+        uint8_t m_P;  //!< Golomb-Rice coding parameter
+        uint32_t m_M; //!< Inverse false positive rate
+
+        Params(uint64_t siphash_k0 = 0, uint64_t siphash_k1 = 0, uint8_t P = 0,
+               uint32_t M = 1)
+            : m_siphash_k0(siphash_k0), m_siphash_k1(siphash_k1), m_P(P),
+              m_M(M) {}
+    };
+
 private:
-    uint64_t m_siphash_k0;
-    uint64_t m_siphash_k1;
-    uint8_t m_P;  //!< Golomb-Rice coding parameter
-    uint32_t m_M; //!< Inverse false positive rate
+    Params m_params;
     uint32_t m_N; //!< Number of elements in the filter
     uint64_t m_F; //!< Range of element hashes, F = N * M
     std::vector<uint8_t> m_encoded;
@@ -43,20 +52,16 @@
 
 public:
     /** Constructs an empty filter. */
-    GCSFilter(uint64_t siphash_k0 = 0, uint64_t siphash_k1 = 0, uint8_t P = 0,
-              uint32_t M = 0);
+    explicit GCSFilter(const Params &params = Params());
 
     /** Reconstructs an already-created filter from an encoding. */
-    GCSFilter(uint64_t siphash_k0, uint64_t siphash_k1, uint8_t P, uint32_t M,
-              std::vector<uint8_t> encoded_filter);
+    GCSFilter(const Params &params, std::vector<uint8_t> encoded_filter);
 
     /** Builds a new filter from the params and set of elements. */
-    GCSFilter(uint64_t siphash_k0, uint64_t siphash_k1, uint8_t P, uint32_t M,
-              const ElementSet &elements);
+    GCSFilter(const Params &params, const ElementSet &elements);
 
-    uint8_t GetP() const { return m_P; }
     uint32_t GetN() const { return m_N; }
-    uint32_t GetM() const { return m_M; }
+    const Params &GetParams() const { return m_params; }
     const std::vector<uint8_t> &GetEncoded() const { return m_encoded; }
 
     /**
@@ -90,23 +95,31 @@
     uint256 m_block_hash;
     GCSFilter m_filter;
 
+    bool BuildParams(GCSFilter::Params &params) const;
+
 public:
-    // Construct a new BlockFilter of the specified type from a block.
+    BlockFilter() = default;
+
+    //! Reconstruct a BlockFilter from parts.
+    BlockFilter(BlockFilterType filter_type, const uint256 &block_hash,
+                std::vector<uint8_t> filter);
+
+    //! Construct a new BlockFilter of the specified type from a block.
     BlockFilter(BlockFilterType filter_type, const CBlock &block,
                 const CBlockUndo &block_undo);
 
     BlockFilterType GetFilterType() const { return m_filter_type; }
-
+    const uint256 &GetBlockHash() const { return m_block_hash; }
     const GCSFilter &GetFilter() const { return m_filter; }
 
     const std::vector<uint8_t> &GetEncodedFilter() const {
         return m_filter.GetEncoded();
     }
 
-    // Compute the filter hash.
+    //! Compute the filter hash.
     uint256 GetHash() const;
 
-    // Compute the filter header given the previous one.
+    //! Compute the filter header given the previous one.
     uint256 ComputeHeader(const uint256 &prev_header) const;
 
     template <typename Stream> void Serialize(Stream &s) const {
@@ -122,16 +135,11 @@
 
         m_filter_type = static_cast<BlockFilterType>(filter_type);
 
-        switch (m_filter_type) {
-            case BlockFilterType::BASIC:
-                m_filter = GCSFilter(m_block_hash.GetUint64(0),
-                                     m_block_hash.GetUint64(1), BASIC_FILTER_P,
-                                     BASIC_FILTER_M, std::move(encoded_filter));
-                break;
-
-            default:
-                throw std::ios_base::failure("unknown filter_type");
+        GCSFilter::Params params;
+        if (!BuildParams(params)) {
+            throw std::ios_base::failure("unknown filter_type");
         }
+        m_filter = GCSFilter(params, std::move(encoded_filter));
     }
 };
 
diff --git a/src/blockfilter.cpp b/src/blockfilter.cpp
--- a/src/blockfilter.cpp
+++ b/src/blockfilter.cpp
@@ -80,7 +80,7 @@
 }
 
 uint64_t GCSFilter::HashToRange(const Element &element) const {
-    uint64_t hash = CSipHasher(m_siphash_k0, m_siphash_k1)
+    uint64_t hash = CSipHasher(m_params.m_siphash_k0, m_params.m_siphash_k1)
                         .Write(element.data(), element.size())
                         .Finalize();
     return MapIntoRange(hash, m_F);
@@ -97,16 +97,11 @@
     return hashed_elements;
 }
 
-GCSFilter::GCSFilter(uint64_t siphash_k0, uint64_t siphash_k1, uint8_t P,
-                     uint32_t M)
-    : m_siphash_k0(siphash_k0), m_siphash_k1(siphash_k1), m_P(P), m_M(M),
-      m_N(0), m_F(0) {}
-
-GCSFilter::GCSFilter(uint64_t siphash_k0, uint64_t siphash_k1, uint8_t P,
-                     uint32_t M, std::vector<uint8_t> encoded_filter)
-    : GCSFilter(siphash_k0, siphash_k1, P, M) {
-    m_encoded = std::move(encoded_filter);
+GCSFilter::GCSFilter(const Params &params)
+    : m_params(params), m_N(0), m_F(0), m_encoded{0} {}
 
+GCSFilter::GCSFilter(const Params &params, std::vector<uint8_t> encoded_filter)
+    : m_params(params), m_encoded(std::move(encoded_filter)) {
     VectorReader stream(GCS_SER_TYPE, GCS_SER_VERSION, m_encoded, 0);
 
     uint64_t N = ReadCompactSize(stream);
@@ -114,29 +109,28 @@
     if (m_N != N) {
         throw std::ios_base::failure("N must be <2^32");
     }
-    m_F = static_cast<uint64_t>(m_N) * static_cast<uint64_t>(m_M);
+    m_F = static_cast<uint64_t>(m_N) * static_cast<uint64_t>(m_params.m_M);
 
     // Verify that the encoded filter contains exactly N elements. If it has too
     // much or too little data, a std::ios_base::failure exception will be
     // raised.
     BitStreamReader<VectorReader> bitreader(stream);
     for (uint64_t i = 0; i < m_N; ++i) {
-        GolombRiceDecode(bitreader, m_P);
+        GolombRiceDecode(bitreader, m_params.m_P);
     }
     if (!stream.empty()) {
         throw std::ios_base::failure("encoded_filter contains excess data");
     }
 }
 
-GCSFilter::GCSFilter(uint64_t siphash_k0, uint64_t siphash_k1, uint8_t P,
-                     uint32_t M, const ElementSet &elements)
-    : GCSFilter(siphash_k0, siphash_k1, P, M) {
+GCSFilter::GCSFilter(const Params &params, const ElementSet &elements)
+    : m_params(params) {
     size_t N = elements.size();
     m_N = static_cast<uint32_t>(N);
     if (m_N != N) {
         throw std::invalid_argument("N must be <2^32");
     }
-    m_F = static_cast<uint64_t>(m_N) * static_cast<uint64_t>(m_M);
+    m_F = static_cast<uint64_t>(m_N) * static_cast<uint64_t>(m_params.m_M);
 
     CVectorWriter stream(GCS_SER_TYPE, GCS_SER_VERSION, m_encoded, 0);
 
@@ -151,7 +145,7 @@
     uint64_t last_value = 0;
     for (uint64_t value : BuildHashedSet(elements)) {
         uint64_t delta = value - last_value;
-        GolombRiceEncode(bitwriter, m_P, delta);
+        GolombRiceEncode(bitwriter, m_params.m_P, delta);
         last_value = value;
     }
 
@@ -171,7 +165,7 @@
     uint64_t value = 0;
     size_t hashes_index = 0;
     for (uint32_t i = 0; i < m_N; ++i) {
-        uint64_t delta = GolombRiceDecode(bitreader, m_P);
+        uint64_t delta = GolombRiceDecode(bitreader, m_params.m_P);
         value += delta;
 
         while (true) {
@@ -227,20 +221,37 @@
     return elements;
 }
 
+BlockFilter::BlockFilter(BlockFilterType filter_type, const uint256 &block_hash,
+                         std::vector<uint8_t> filter)
+    : m_filter_type(filter_type), m_block_hash(block_hash) {
+    GCSFilter::Params params;
+    if (!BuildParams(params)) {
+        throw std::invalid_argument("unknown filter_type");
+    }
+    m_filter = GCSFilter(params, std::move(filter));
+}
+
 BlockFilter::BlockFilter(BlockFilterType filter_type, const CBlock &block,
                          const CBlockUndo &block_undo)
     : m_filter_type(filter_type), m_block_hash(block.GetHash()) {
+    GCSFilter::Params params;
+    if (!BuildParams(params)) {
+        throw std::invalid_argument("unknown filter_type");
+    }
+    m_filter = GCSFilter(params, BasicFilterElements(block, block_undo));
+}
+
+bool BlockFilter::BuildParams(GCSFilter::Params &params) const {
     switch (m_filter_type) {
         case BlockFilterType::BASIC:
-            m_filter =
-                GCSFilter(m_block_hash.GetUint64(0), m_block_hash.GetUint64(1),
-                          BASIC_FILTER_P, BASIC_FILTER_M,
-                          BasicFilterElements(block, block_undo));
-            break;
-
-        default:
-            throw std::invalid_argument("unknown filter_type");
+            params.m_siphash_k0 = m_block_hash.GetUint64(0);
+            params.m_siphash_k1 = m_block_hash.GetUint64(1);
+            params.m_P = BASIC_FILTER_P;
+            params.m_M = BASIC_FILTER_M;
+            return true;
     }
+
+    return false;
 }
 
 uint256 BlockFilter::GetHash() const {
diff --git a/src/test/blockfilter_tests.cpp b/src/test/blockfilter_tests.cpp
--- a/src/test/blockfilter_tests.cpp
+++ b/src/test/blockfilter_tests.cpp
@@ -30,7 +30,7 @@
         excluded_elements.insert(std::move(element2));
     }
 
-    GCSFilter filter(0, 0, 10, 1 << 10, included_elements);
+    GCSFilter filter({0, 0, 10, 1 << 10}, included_elements);
     for (const auto &element : included_elements) {
         BOOST_CHECK(filter.Match(element));
 
@@ -40,6 +40,18 @@
     }
 }
 
+BOOST_AUTO_TEST_CASE(gcsfilter_default_constructor) {
+    GCSFilter filter;
+    BOOST_CHECK_EQUAL(filter.GetN(), 0);
+    BOOST_CHECK_EQUAL(filter.GetEncoded().size(), 1);
+
+    const GCSFilter::Params &params = filter.GetParams();
+    BOOST_CHECK_EQUAL(params.m_siphash_k0, 0);
+    BOOST_CHECK_EQUAL(params.m_siphash_k1, 0);
+    BOOST_CHECK_EQUAL(params.m_P, 0);
+    BOOST_CHECK_EQUAL(params.m_M, 1);
+}
+
 BOOST_AUTO_TEST_CASE(blockfilter_basic_test) {
     CScript included_scripts[5], excluded_scripts[3];
 
@@ -96,6 +108,19 @@
         BOOST_CHECK(
             !filter.Match(GCSFilter::Element(script.begin(), script.end())));
     }
+
+    // Test serialization/unserialization.
+    BlockFilter block_filter2;
+
+    CDataStream stream(SER_NETWORK, PROTOCOL_VERSION);
+    stream << block_filter;
+    stream >> block_filter2;
+
+    BOOST_CHECK_EQUAL(block_filter.GetFilterType(),
+                      block_filter2.GetFilterType());
+    BOOST_CHECK(block_filter.GetBlockHash() == block_filter2.GetBlockHash());
+    BOOST_CHECK(block_filter.GetEncodedFilter() ==
+                block_filter2.GetEncodedFilter());
 }
 
 BOOST_AUTO_TEST_CASE(blockfilters_json_test) {