diff --git a/src/avalanche/compactproofs.h b/src/avalanche/compactproofs.h --- a/src/avalanche/compactproofs.h +++ b/src/avalanche/compactproofs.h @@ -31,18 +31,8 @@ uint32_t index; avalanche::ProofRef proof; - class Formatter : public DifferenceFormatter { - public: - template void Ser(Stream &s, PrefilledProof pp) { - DifferenceFormatter::Ser(s, pp.index); - s << pp.proof; - } - - template void Unser(Stream &s, PrefilledProof &pp) { - DifferenceFormatter::Unser(s, pp.index); - s >> pp.proof; - } - }; + template void SerData(Stream &s) const { s << proof; } + template void UnserData(Stream &s) { s >> proof; } }; class CompactProofs { @@ -73,7 +63,7 @@ obj.shortproofidk0, obj.shortproofidk1, Using>>( obj.shortproofids), - Using>( + Using>( obj.prefilledProofs)); if (ser_action.ForRead() && obj.prefilledProofs.size() > 0) { diff --git a/src/serialize.h b/src/serialize.h --- a/src/serialize.h +++ b/src/serialize.h @@ -782,6 +782,28 @@ } }; +/** + * Helper for a list of items containing a differentially encoded index as their + * first member. See DifferenceFormatter for info about the index encoding. + * + * The index should be a public member of the object. + * SerData() const / UnserData() methods must be implemented to (de)serialize + * the remaining item data (SerData() is called on a const reference). + * + * To be used with a VectorFormatter.
+ */ +struct DifferentialIndexedItemFormatter : public DifferenceFormatter { + template void Ser(Stream &s, const T &v) { + DifferenceFormatter::Ser(s, v.index); + v.SerData(s); + } + + template void Unser(Stream &s, T &v) { + DifferenceFormatter::Unser(s, v.index); + v.UnserData(s); + } +}; + /** * Forward declarations */ diff --git a/src/test/serialize_tests.cpp b/src/test/serialize_tests.cpp --- a/src/test/serialize_tests.cpp +++ b/src/test/serialize_tests.cpp @@ -452,10 +452,67 @@ BOOST_CHECK(methodtest3 == methodtest4); } -BOOST_AUTO_TEST_CASE(difference_formatter) { - VectorFormatter formatter; +namespace { +struct DifferentialIndexedItem { + uint32_t index; + std::string text; + + template void SerData(Stream &s) const { s << text; } + template void UnserData(Stream &s) { s >> text; } + + bool operator==(const DifferentialIndexedItem &other) const { + return index == other.index && text == other.text; + } + bool operator!=(const DifferentialIndexedItem &other) const { + return !(*this == other); + } + + // Make boost happy + friend std::ostream &operator<<(std::ostream &os, + const DifferentialIndexedItem &item) { + os << "index: " << item.index << ", text: " << item.text; + return os; + } + + DifferentialIndexedItem() {} + DifferentialIndexedItem(uint32_t indexIn) + : index(indexIn), text(ToString(index)) {} +}; + +template +static void checkDifferentialEncodingRoundtrip() { + Formatter formatter; + + const std::vector indicesIn{0, 1, 2, 5, 10, 20, 50, 100}; + std::vector indicesOut; + + CDataStream ss(SER_DISK, PROTOCOL_VERSION); + formatter.Ser(ss, indicesIn); + formatter.Unser(ss, indicesOut); + BOOST_CHECK_EQUAL_COLLECTIONS(indicesIn.begin(), indicesIn.end(), + indicesOut.begin(), indicesOut.end()); +} + +template +static void checkDifferentialEncodingOverflow() { + Formatter formatter; { + const std::vector indicesIn{1, 0}; + + CDataStream ss(SER_DISK, PROTOCOL_VERSION); + BOOST_CHECK_EXCEPTION(formatter.Ser(ss, indicesIn), + std::ios_base::failure,
HasReason("differential value overflow")); + } +} +} // namespace + +BOOST_AUTO_TEST_CASE(difference_formatter) { + { + // Roundtrip with internals check + VectorFormatter formatter; + std::vector indicesIn{0, 1, 2, 5, 10, 20, 50, 100}; std::vector indicesOut; @@ -474,16 +531,15 @@ indicesOut.begin(), indicesOut.end()); } - { - std::vector indicesIn{1, 0}; - - CDataStream ss(SER_DISK, PROTOCOL_VERSION); - BOOST_CHECK_EXCEPTION(formatter.Ser(ss, indicesIn), - std::ios_base::failure, - HasReason("differential value overflow")); - } + checkDifferentialEncodingRoundtrip, + uint32_t>(); + checkDifferentialEncodingRoundtrip< + VectorFormatter, + DifferentialIndexedItem>(); { + // Checking 32 bits overflow requires manually creating the serialized + // stream, so only do it with uint32_t std::vector indicesOut; // Compute the number of MAX_SIZE increment we need to cause an overflow @@ -512,6 +568,8 @@ return ss; }; + VectorFormatter formatter; + auto noThrowStream = buildStream(remainder - 1); BOOST_CHECK_NO_THROW(formatter.Unser(noThrowStream, indicesOut)); @@ -520,6 +578,12 @@ std::ios_base::failure, HasReason("differential value overflow")); } + + checkDifferentialEncodingOverflow, + uint32_t>(); + checkDifferentialEncodingOverflow< + VectorFormatter, + DifferentialIndexedItem>(); } BOOST_AUTO_TEST_SUITE_END()