diff --git a/src/arith_uint256.cpp b/src/arith_uint256.cpp index b3a305ba99..ca8330005e 100644 --- a/src/arith_uint256.cpp +++ b/src/arith_uint256.cpp @@ -1,239 +1,239 @@ // Copyright (c) 2009-2010 Satoshi Nakamoto // Copyright (c) 2009-2016 The Bitcoin Core developers // Distributed under the MIT software license, see the accompanying // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include #include #include #include #include #include template base_uint::base_uint(const std::string &str) { static_assert(BITS / 32 > 0 && BITS % 32 == 0, "Template parameter BITS must be a positive multiple of 32."); SetHex(str); } template base_uint &base_uint::operator<<=(unsigned int shift) { base_uint a(*this); for (int i = 0; i < WIDTH; i++) pn[i] = 0; int k = shift / 32; shift = shift % 32; for (int i = 0; i < WIDTH; i++) { if (i + k + 1 < WIDTH && shift != 0) pn[i + k + 1] |= (a.pn[i] >> (32 - shift)); if (i + k < WIDTH) pn[i + k] |= (a.pn[i] << shift); } return *this; } template base_uint &base_uint::operator>>=(unsigned int shift) { base_uint a(*this); for (int i = 0; i < WIDTH; i++) pn[i] = 0; int k = shift / 32; shift = shift % 32; for (int i = 0; i < WIDTH; i++) { if (i - k - 1 >= 0 && shift != 0) pn[i - k - 1] |= (a.pn[i] << (32 - shift)); if (i - k >= 0) pn[i - k] |= (a.pn[i] >> shift); } return *this; } template base_uint &base_uint::operator*=(uint32_t b32) { uint64_t carry = 0; for (int i = 0; i < WIDTH; i++) { uint64_t n = carry + (uint64_t)b32 * pn[i]; pn[i] = n & 0xffffffff; carry = n >> 32; } return *this; } template base_uint &base_uint::operator*=(const base_uint &b) { - base_uint a = *this; - *this = 0; + base_uint a; for (int j = 0; j < WIDTH; j++) { uint64_t carry = 0; for (int i = 0; i + j < WIDTH; i++) { - uint64_t n = carry + pn[i + j] + (uint64_t)a.pn[j] * b.pn[i]; - pn[i + j] = n & 0xffffffff; + uint64_t n = carry + a.pn[i + j] + (uint64_t)pn[j] * b.pn[i]; + a.pn[i + j] = n & 0xffffffff; carry = n >> 32; } } + *this = a; return *this; } template base_uint &base_uint::operator/=(const base_uint &b) { // make a copy, so we can shift. base_uint div = b; // make a copy, so we can subtract. base_uint num = *this; // the quotient. *this = 0; int num_bits = num.bits(); int div_bits = div.bits(); if (div_bits == 0) throw uint_error("Division by zero"); // the result is certainly 0. if (div_bits > num_bits) return *this; int shift = num_bits - div_bits; // shift so that div and num align. div <<= shift; while (shift >= 0) { if (num >= div) { num -= div; // set a bit of the result. pn[shift / 32] |= (1 << (shift & 31)); } // shift back. div >>= 1; shift--; } // num now contains the remainder of the division. return *this; } template int base_uint::CompareTo(const base_uint &b) const { for (int i = WIDTH - 1; i >= 0; i--) { if (pn[i] < b.pn[i]) return -1; if (pn[i] > b.pn[i]) return 1; } return 0; } template bool base_uint::EqualTo(uint64_t b) const { for (int i = WIDTH - 1; i >= 2; i--) { if (pn[i]) return false; } if (pn[1] != (b >> 32)) return false; if (pn[0] != (b & 0xfffffffful)) return false; return true; } template double base_uint::getdouble() const { double ret = 0.0; double fact = 1.0; for (int i = 0; i < WIDTH; i++) { ret += fact * pn[i]; fact *= 4294967296.0; } return ret; } template std::string base_uint::GetHex() const { return ArithToUint256(*this).GetHex(); } template void base_uint::SetHex(const char *psz) { *this = UintToArith256(uint256S(psz)); } template void base_uint::SetHex(const std::string &str) { SetHex(str.c_str()); } template std::string base_uint::ToString() const { return (GetHex()); } template unsigned int base_uint::bits() const { for (int pos = WIDTH - 1; pos >= 0; pos--) { if (pn[pos]) { for (int nbits = 31; nbits > 0; nbits--) { if (pn[pos] & 1U << nbits) { return 32 * pos + nbits + 1; } } return 32 * pos + 1; } } return 0; } // Explicit instantiations for base_uint<256> template base_uint<256>::base_uint(const std::string &); template base_uint<256> &base_uint<256>::operator<<=(unsigned int); template base_uint<256> &base_uint<256>::operator>>=(unsigned int); template base_uint<256> &base_uint<256>::operator*=(uint32_t b32); template base_uint<256> &base_uint<256>::operator*=(const base_uint<256> &b); template base_uint<256> &base_uint<256>::operator/=(const base_uint<256> &b); template int base_uint<256>::CompareTo(const base_uint<256> &) const; template bool base_uint<256>::EqualTo(uint64_t) const; template double base_uint<256>::getdouble() const; template std::string base_uint<256>::GetHex() const; template std::string base_uint<256>::ToString() const; template void base_uint<256>::SetHex(const char *); template void base_uint<256>::SetHex(const std::string &); template unsigned int base_uint<256>::bits() const; // This implementation directly uses shifts instead of going through an // intermediate MPI representation. arith_uint256 &arith_uint256::SetCompact(uint32_t nCompact, bool *pfNegative, bool *pfOverflow) { int nSize = nCompact >> 24; uint32_t nWord = nCompact & 0x007fffff; if (nSize <= 3) { nWord >>= 8 * (3 - nSize); *this = nWord; } else { *this = nWord; *this <<= 8 * (nSize - 3); } if (pfNegative) *pfNegative = nWord != 0 && (nCompact & 0x00800000) != 0; if (pfOverflow) *pfOverflow = nWord != 0 && ((nSize > 34) || (nWord > 0xff && nSize > 33) || (nWord > 0xffff && nSize > 32)); return *this; } uint32_t arith_uint256::GetCompact(bool fNegative) const { int nSize = (bits() + 7) / 8; uint32_t nCompact = 0; if (nSize <= 3) { nCompact = GetLow64() << 8 * (3 - nSize); } else { arith_uint256 bn = *this >> 8 * (nSize - 3); nCompact = bn.GetLow64(); } // The 0x00800000 bit denotes the sign. // Thus, if it is already set, divide the mantissa by 256 and increase the // exponent. if (nCompact & 0x00800000) { nCompact >>= 8; nSize++; } assert((nCompact & ~0x007fffff) == 0); assert(nSize < 256); nCompact |= nSize << 24; nCompact |= (fNegative && (nCompact & 0x007fffff) ? 0x00800000 : 0); return nCompact; } uint256 ArithToUint256(const arith_uint256 &a) { uint256 b; for (int x = 0; x < a.WIDTH; ++x) WriteLE32(b.begin() + x * 4, a.pn[x]); return b; } arith_uint256 UintToArith256(const uint256 &a) { arith_uint256 b; for (int x = 0; x < b.WIDTH; ++x) b.pn[x] = ReadLE32(a.begin() + x * 4); return b; } diff --git a/src/test/uint256_tests.cpp b/src/test/uint256_tests.cpp index 54a5a63f7e..8ea52efa22 100644 --- a/src/test/uint256_tests.cpp +++ b/src/test/uint256_tests.cpp @@ -1,283 +1,295 @@ // Copyright (c) 2011-2016 The Bitcoin Core developers // Distributed under the MIT software license, see the accompanying // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include #include #include #include #include #include #include #include #include #include #include #include #include BOOST_FIXTURE_TEST_SUITE(uint256_tests, BasicTestingSetup) const uint8_t R1Array[] = "\x9c\x52\x4a\xdb\xcf\x56\x11\x12\x2b\x29\x12\x5e\x5d\x35\xd2\xd2" "\x22\x81\xaa\xb5\x33\xf0\x08\x32\xd5\x56\xb1\xf9\xea\xe5\x1d\x7d"; const char R1ArrayHex[] = "7D1DE5EAF9B156D53208F033B5AA8122D2d2355d5e12292b121156cfdb4a529c"; const uint256 R1L = uint256(std::vector(R1Array, R1Array + 32)); const uint160 R1S = uint160(std::vector(R1Array, R1Array + 20)); const uint8_t R2Array[] = "\x70\x32\x1d\x7c\x47\xa5\x6b\x40\x26\x7e\x0a\xc3\xa6\x9c\xb6\xbf" "\x13\x30\x47\xa3\x19\x2d\xda\x71\x49\x13\x72\xf0\xb4\xca\x81\xd7"; const uint256 R2L = uint256(std::vector(R2Array, R2Array + 32)); const uint160 R2S = uint160(std::vector(R2Array, R2Array + 20)); const uint8_t ZeroArray[] = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"; const uint256 ZeroL = uint256(std::vector(ZeroArray, ZeroArray + 32)); const uint160 ZeroS = uint160(std::vector(ZeroArray, ZeroArray + 20)); const uint8_t OneArray[] = "\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"; const uint256 OneL = uint256(std::vector(OneArray, OneArray + 32)); const uint160 OneS = uint160(std::vector(OneArray, OneArray + 20)); const uint8_t MaxArray[] = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"; const uint256 MaxL = uint256(std::vector(MaxArray, MaxArray + 32)); const uint160 MaxS = uint160(std::vector(MaxArray, MaxArray + 20)); static std::string ArrayToString(const uint8_t A[], unsigned int width) { std::stringstream Stream; Stream << std::hex; for (unsigned int i = 0; i < width; ++i) { Stream << std::setw(2) << std::setfill('0') << (unsigned int)A[width - i - 1]; } return Stream.str(); } // constructors, equality, inequality BOOST_AUTO_TEST_CASE(basics) { BOOST_CHECK(1 == 0 + 1); // constructor uint256(vector): BOOST_CHECK(R1L.ToString() == ArrayToString(R1Array, 32)); BOOST_CHECK(R1S.ToString() == ArrayToString(R1Array, 20)); BOOST_CHECK(R2L.ToString() == ArrayToString(R2Array, 32)); BOOST_CHECK(R2S.ToString() == ArrayToString(R2Array, 20)); BOOST_CHECK(ZeroL.ToString() == ArrayToString(ZeroArray, 32)); BOOST_CHECK(ZeroS.ToString() == ArrayToString(ZeroArray, 20)); BOOST_CHECK(OneL.ToString() == ArrayToString(OneArray, 32)); BOOST_CHECK(OneS.ToString() == ArrayToString(OneArray, 20)); BOOST_CHECK(MaxL.ToString() == ArrayToString(MaxArray, 32)); BOOST_CHECK(MaxS.ToString() == ArrayToString(MaxArray, 20)); BOOST_CHECK(OneL.ToString() != ArrayToString(ZeroArray, 32)); BOOST_CHECK(OneS.ToString() != ArrayToString(ZeroArray, 20)); // == and != BOOST_CHECK(R1L != R2L && R1S != R2S); BOOST_CHECK(ZeroL != OneL && ZeroS != OneS); BOOST_CHECK(OneL != ZeroL && OneS != ZeroS); BOOST_CHECK(MaxL != ZeroL && MaxS != ZeroS); // String Constructor and Copy Constructor BOOST_CHECK(uint256S("0x" + R1L.ToString()) == R1L); BOOST_CHECK(uint256S("0x" + R2L.ToString()) == R2L); BOOST_CHECK(uint256S("0x" + ZeroL.ToString()) == ZeroL); BOOST_CHECK(uint256S("0x" + OneL.ToString()) == OneL); BOOST_CHECK(uint256S("0x" + MaxL.ToString()) == MaxL); BOOST_CHECK(uint256S(R1L.ToString()) == R1L); BOOST_CHECK(uint256S(" 0x" + R1L.ToString() + " ") == R1L); BOOST_CHECK(uint256S("") == ZeroL); BOOST_CHECK(R1L == uint256S(R1ArrayHex)); BOOST_CHECK(uint256(R1L) == R1L); BOOST_CHECK(uint256(ZeroL) == ZeroL); BOOST_CHECK(uint256(OneL) == OneL); BOOST_CHECK(uint160S("0x" + R1S.ToString()) == R1S); BOOST_CHECK(uint160S("0x" + R2S.ToString()) == R2S); BOOST_CHECK(uint160S("0x" + ZeroS.ToString()) == ZeroS); BOOST_CHECK(uint160S("0x" + OneS.ToString()) == OneS); BOOST_CHECK(uint160S("0x" + MaxS.ToString()) == MaxS); BOOST_CHECK(uint160S(R1S.ToString()) == R1S); BOOST_CHECK(uint160S(" 0x" + R1S.ToString() + " ") == R1S); BOOST_CHECK(uint160S("") == ZeroS); BOOST_CHECK(R1S == uint160S(R1ArrayHex)); BOOST_CHECK(uint160(R1S) == R1S); BOOST_CHECK(uint160(ZeroS) == ZeroS); BOOST_CHECK(uint160(OneS) == OneS); } static void CheckComparison(const uint256 &a, const uint256 &b) { BOOST_CHECK(a < b); BOOST_CHECK(a <= b); BOOST_CHECK(b > a); BOOST_CHECK(b >= a); } static void CheckComparison(const uint160 &a, const uint160 &b) { BOOST_CHECK(a < b); BOOST_CHECK(a <= b); BOOST_CHECK(b > a); BOOST_CHECK(b >= a); } // <= >= < > BOOST_AUTO_TEST_CASE(comparison) { uint256 LastL; for (int i = 0; i < 256; i++) { uint256 TmpL; *(TmpL.begin() + (i >> 3)) |= 1 << (i & 7); CheckComparison(LastL, TmpL); LastL = TmpL; BOOST_CHECK(LastL <= LastL); BOOST_CHECK(LastL >= LastL); } CheckComparison(ZeroL, R1L); CheckComparison(R1L, R2L); CheckComparison(ZeroL, OneL); CheckComparison(OneL, MaxL); CheckComparison(R1L, MaxL); CheckComparison(R2L, MaxL); uint160 LastS; for (int i = 0; i < 160; i++) { uint160 TmpS; *(TmpS.begin() + (i >> 3)) |= 1 << (i & 7); CheckComparison(LastS, TmpS); LastS = TmpS; BOOST_CHECK(LastS <= LastS); BOOST_CHECK(LastS >= LastS); } CheckComparison(ZeroS, R1S); CheckComparison(R2S, R1S); CheckComparison(ZeroS, OneS); CheckComparison(OneS, MaxS); CheckComparison(R1S, MaxS); CheckComparison(R2S, MaxS); } // GetHex SetHex begin() end() size() GetLow64 GetSerializeSize, Serialize, // Unserialize BOOST_AUTO_TEST_CASE(methods) { BOOST_CHECK(R1L.GetHex() == R1L.ToString()); BOOST_CHECK(R2L.GetHex() == R2L.ToString()); BOOST_CHECK(OneL.GetHex() == OneL.ToString()); BOOST_CHECK(MaxL.GetHex() == MaxL.ToString()); uint256 TmpL(R1L); BOOST_CHECK(TmpL == R1L); TmpL.SetHex(R2L.ToString()); BOOST_CHECK(TmpL == R2L); TmpL.SetHex(ZeroL.ToString()); BOOST_CHECK(TmpL == uint256()); TmpL.SetHex(R1L.ToString()); BOOST_CHECK(memcmp(R1L.begin(), R1Array, 32) == 0); BOOST_CHECK(memcmp(TmpL.begin(), R1Array, 32) == 0); BOOST_CHECK(memcmp(R2L.begin(), R2Array, 32) == 0); BOOST_CHECK(memcmp(ZeroL.begin(), ZeroArray, 32) == 0); BOOST_CHECK(memcmp(OneL.begin(), OneArray, 32) == 0); BOOST_CHECK(R1L.size() == sizeof(R1L)); BOOST_CHECK(sizeof(R1L) == 32); BOOST_CHECK(R1L.size() == 32); BOOST_CHECK(R2L.size() == 32); BOOST_CHECK(ZeroL.size() == 32); BOOST_CHECK(MaxL.size() == 32); BOOST_CHECK(R1L.begin() + 32 == R1L.end()); BOOST_CHECK(R2L.begin() + 32 == R2L.end()); BOOST_CHECK(OneL.begin() + 32 == OneL.end()); BOOST_CHECK(MaxL.begin() + 32 == MaxL.end()); BOOST_CHECK(TmpL.begin() + 32 == TmpL.end()); BOOST_CHECK(GetSerializeSize(R1L, 0, PROTOCOL_VERSION) == 32); BOOST_CHECK(GetSerializeSize(ZeroL, 0, PROTOCOL_VERSION) == 32); CDataStream ss(0, PROTOCOL_VERSION); ss << R1L; BOOST_CHECK(ss.str() == std::string(R1Array, R1Array + 32)); ss >> TmpL; BOOST_CHECK(R1L == TmpL); ss.clear(); ss << ZeroL; BOOST_CHECK(ss.str() == std::string(ZeroArray, ZeroArray + 32)); ss >> TmpL; BOOST_CHECK(ZeroL == TmpL); ss.clear(); ss << MaxL; BOOST_CHECK(ss.str() == std::string(MaxArray, MaxArray + 32)); ss >> TmpL; BOOST_CHECK(MaxL == TmpL); ss.clear(); BOOST_CHECK(R1S.GetHex() == R1S.ToString()); BOOST_CHECK(R2S.GetHex() == R2S.ToString()); BOOST_CHECK(OneS.GetHex() == OneS.ToString()); BOOST_CHECK(MaxS.GetHex() == MaxS.ToString()); uint160 TmpS(R1S); BOOST_CHECK(TmpS == R1S); TmpS.SetHex(R2S.ToString()); BOOST_CHECK(TmpS == R2S); TmpS.SetHex(ZeroS.ToString()); BOOST_CHECK(TmpS == uint160()); TmpS.SetHex(R1S.ToString()); BOOST_CHECK(memcmp(R1S.begin(), R1Array, 20) == 0); BOOST_CHECK(memcmp(TmpS.begin(), R1Array, 20) == 0); BOOST_CHECK(memcmp(R2S.begin(), R2Array, 20) == 0); BOOST_CHECK(memcmp(ZeroS.begin(), ZeroArray, 20) == 0); BOOST_CHECK(memcmp(OneS.begin(), OneArray, 20) == 0); BOOST_CHECK(R1S.size() == sizeof(R1S)); BOOST_CHECK(sizeof(R1S) == 20); BOOST_CHECK(R1S.size() == 20); BOOST_CHECK(R2S.size() == 20); BOOST_CHECK(ZeroS.size() == 20); BOOST_CHECK(MaxS.size() == 20); BOOST_CHECK(R1S.begin() + 20 == R1S.end()); BOOST_CHECK(R2S.begin() + 20 == R2S.end()); BOOST_CHECK(OneS.begin() + 20 == OneS.end()); BOOST_CHECK(MaxS.begin() + 20 == MaxS.end()); BOOST_CHECK(TmpS.begin() + 20 == TmpS.end()); BOOST_CHECK(GetSerializeSize(R1S, 0, PROTOCOL_VERSION) == 20); BOOST_CHECK(GetSerializeSize(ZeroS, 0, PROTOCOL_VERSION) == 20); ss << R1S; BOOST_CHECK(ss.str() == std::string(R1Array, R1Array + 20)); ss >> TmpS; BOOST_CHECK(R1S == TmpS); ss.clear(); ss << ZeroS; BOOST_CHECK(ss.str() == std::string(ZeroArray, ZeroArray + 20)); ss >> TmpS; BOOST_CHECK(ZeroS == TmpS); ss.clear(); ss << MaxS; BOOST_CHECK(ss.str() == std::string(MaxArray, MaxArray + 20)); ss >> TmpS; BOOST_CHECK(MaxS == TmpS); ss.clear(); } BOOST_AUTO_TEST_CASE(conversion) { BOOST_CHECK(ArithToUint256(UintToArith256(ZeroL)) == ZeroL); BOOST_CHECK(ArithToUint256(UintToArith256(OneL)) == OneL); BOOST_CHECK(ArithToUint256(UintToArith256(R1L)) == R1L); BOOST_CHECK(ArithToUint256(UintToArith256(R2L)) == R2L); BOOST_CHECK(UintToArith256(ZeroL) == 0); BOOST_CHECK(UintToArith256(OneL) == 1); BOOST_CHECK(ArithToUint256(0) == ZeroL); BOOST_CHECK(ArithToUint256(1) == OneL); BOOST_CHECK(arith_uint256(R1L.GetHex()) == UintToArith256(R1L)); BOOST_CHECK(arith_uint256(R2L.GetHex()) == UintToArith256(R2L)); BOOST_CHECK(R1L.GetHex() == UintToArith256(R1L).GetHex()); BOOST_CHECK(R2L.GetHex() == UintToArith256(R2L).GetHex()); } +BOOST_AUTO_TEST_CASE(operator_with_self) { + arith_uint256 v = UintToArith256(uint256S("02")); + v *= v; + BOOST_CHECK(v == UintToArith256(uint256S("04"))); + v /= v; + BOOST_CHECK(v == UintToArith256(uint256S("01"))); + v += v; + BOOST_CHECK(v == UintToArith256(uint256S("02"))); + v -= v; + BOOST_CHECK(v == UintToArith256(uint256S("0"))); +} + BOOST_AUTO_TEST_SUITE_END()