diff --git a/src/seeder/test/seeder_tests.cpp b/src/seeder/test/seeder_tests.cpp
--- a/src/seeder/test/seeder_tests.cpp
+++ b/src/seeder/test/seeder_tests.cpp
@@ -4,12 +4,163 @@
 
 #define BOOST_TEST_MODULE Bitcoin Seeder Test Suite
 
+#include <seeder/dns.h>
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <string>
+
 #include <boost/test/unit_test.hpp>
 
 BOOST_AUTO_TEST_SUITE(seeder)
 
-BOOST_AUTO_TEST_CASE(parse_name_simple) {
-    BOOST_CHECK_EQUAL(true, true);
+static const int BUFFER_LENGTH = 512;
+static const int QUERY_NAME_BUFFER_LENGTH = 256;
+static const int SIZE_OF_QUERY_TYPE = 2;
+static const int SIZE_OF_QUERY_CLASS = 2;
+// QCLASS = IN
+static const uint16_t QUERY_CLASS = 1;
+// QTYPE = A query
+static const uint16_t QUERY_TYPE = 1;
+// End of name field
+static const uint8_t END_OF_NAME_FIELD = 0;
+static const size_t MAX_LABEL_LENGTH = 63;
+
+// Builds a dummy DNS message whose question section holds queryName in
+// wire format: each dot-separated label is written as a length octet
+// followed by its characters, the name is closed by a zero octet, and QTYPE
+// and QCLASS follow as 16-bit big-endian fields. For example,
+// "www.mydomain.com" becomes \x03www\x08mydomain\x03com\x00\x00\x01\x00\x01.
+// The rest of the returned buffer is zero-filled.
+//
+// Octets are stored directly instead of being round-tripped through a text
+// stream: stream extraction drops length octets that are whitespace
+// characters (9-13 and 32), and inserting a uint16_t into a stream writes
+// text digits rather than two raw octets.
+std::array<uint8_t, BUFFER_LENGTH>
+CreateDNSQuestion(const std::string &queryName) {
+    // The encoded name takes queryName.size() + 2 octets: one more length
+    // octet than there are dots, plus the terminating zero octet.
+    BOOST_REQUIRE_LE(queryName.size() + 2 + SIZE_OF_QUERY_TYPE +
+                         SIZE_OF_QUERY_CLASS,
+                     static_cast<size_t>(BUFFER_LENGTH));
+
+    std::array<uint8_t, BUFFER_LENGTH> messageBuffer;
+    messageBuffer.fill(0);
+    size_t pos = 0;
+
+    size_t labelStart = 0;
+    while (true) {
+        size_t labelEnd = queryName.find('.', labelStart);
+        if (labelEnd == std::string::npos) {
+            labelEnd = queryName.size();
+        }
+        const size_t labelLength = labelEnd - labelStart;
+        // The length has to fit in a single octet. Labels longer than
+        // MAX_LABEL_LENGTH are still accepted so that invalid names can be
+        // built on purpose.
+        BOOST_REQUIRE_LE(labelLength, static_cast<size_t>(0xFF));
+        messageBuffer[pos++] = static_cast<uint8_t>(labelLength);
+        for (size_t i = labelStart; i < labelEnd; i++) {
+            messageBuffer[pos++] = static_cast<uint8_t>(queryName[i]);
+        }
+        if (labelEnd == queryName.size()) {
+            break;
+        }
+        labelStart = labelEnd + 1;
+    }
+    messageBuffer[pos++] = END_OF_NAME_FIELD;
+
+    // QTYPE and QCLASS are 16-bit fields in network (big-endian) byte order.
+    const auto writeUint16 = [&messageBuffer, &pos](uint16_t value) {
+        messageBuffer[pos++] = static_cast<uint8_t>(value >> 8);
+        messageBuffer[pos++] = static_cast<uint8_t>(value & 0xFF);
+    };
+    writeUint16(QUERY_TYPE);
+    writeUint16(QUERY_CLASS);
+
+    return messageBuffer;
+}
+
+// Returns a pointer one past the question section that CreateDNSQuestion
+// writes for queryName: queryName.size() + 2 octets of encoded name followed
+// by QTYPE and QCLASS.
+static const uint8_t *
+EndOfDNSQuestion(const std::array<uint8_t, BUFFER_LENGTH> &dnsMessage,
+                 const std::string &queryName) {
+    return dnsMessage.data() + queryName.size() + 2 + SIZE_OF_QUERY_TYPE +
+           SIZE_OF_QUERY_CLASS;
+}
+
+BOOST_AUTO_TEST_CASE(parse_name_happy_path) {
+    const std::string messageQueryName = "www.mydomain.com";
+    const std::array<uint8_t, BUFFER_LENGTH> dnsMessage =
+        CreateDNSQuestion(messageQueryName);
+    std::array<char, QUERY_NAME_BUFFER_LENGTH> queryName;
+    queryName.fill(0);
+    const size_t writeBufferSize = QUERY_NAME_BUFFER_LENGTH;
+    const uint8_t *messageBegin = dnsMessage.data();
+    const uint8_t *messageEnd = EndOfDNSQuestion(dnsMessage, messageQueryName);
+
+    int ret = parse_name(&messageBegin, messageEnd, dnsMessage.data(),
+                         queryName.data(), writeBufferSize);
+
+    BOOST_CHECK_EQUAL(ret, 0);
+    BOOST_CHECK_EQUAL(queryName.data(), messageQueryName);
+}
+
+// Test for insufficient output buffer size
+BOOST_AUTO_TEST_CASE(parse_name_insufficient_output_buffer_size) {
+    const std::string messageQueryName = "www.mydomain.com";
+    const std::array<uint8_t, BUFFER_LENGTH> dnsMessage =
+        CreateDNSQuestion(messageQueryName);
+    std::array<char, QUERY_NAME_BUFFER_LENGTH> queryName;
+    queryName.fill(0);
+    const uint8_t *messageBegin = dnsMessage.data();
+    const uint8_t *messageEnd = EndOfDNSQuestion(dnsMessage, messageQueryName);
+
+    // The output buffer is exactly as long as the name, i.e. one octet too
+    // small to also hold the terminating NUL, so parsing is expected to stop
+    // with -2 after all but the last character have been written.
+    int ret = parse_name(&messageBegin, messageEnd, dnsMessage.data(),
+                         queryName.data(), messageQueryName.size());
+    BOOST_CHECK_EQUAL(ret, -2);
+    BOOST_CHECK_EQUAL(queryName.data(),
+                      messageQueryName.substr(0, messageQueryName.size() - 1));
+}
+
+// Test for premature end of input buffer
+BOOST_AUTO_TEST_CASE(parse_name_premature_end_of_input_buffer) {
+    const std::string messageQueryName = "www.mydomain.com";
+    const std::array<uint8_t, BUFFER_LENGTH> dnsMessage =
+        CreateDNSQuestion(messageQueryName);
+    std::array<char, QUERY_NAME_BUFFER_LENGTH> queryName;
+    queryName.fill(0);
+    const size_t writeBufferSize = QUERY_NAME_BUFFER_LENGTH;
+    const uint8_t *messageBegin = dnsMessage.data();
+    // The end pointer passed for the DNS message is only two octets past the
+    // beginning: the first label's length octet and its first character.
+    int ret = parse_name(&messageBegin, messageBegin + 2, dnsMessage.data(),
+                         queryName.data(), writeBufferSize);
+    BOOST_CHECK_EQUAL(ret, -1);
+    BOOST_CHECK_EQUAL(queryName.data(), messageQueryName.substr(0, 1));
+}
+
+// Test for when a label of the name field is longer than MAX_LABEL_LENGTH
+BOOST_AUTO_TEST_CASE(parse_name_field_name_too_long) {
+    const std::string tooLongQName =
+        "www." + std::string(MAX_LABEL_LENGTH + 1, 'a') + ".com";
+    const std::array<uint8_t, BUFFER_LENGTH> dnsMessage =
+        CreateDNSQuestion(tooLongQName);
+    std::array<char, QUERY_NAME_BUFFER_LENGTH> queryName;
+    queryName.fill(0);
+    const size_t writeBufferSize = QUERY_NAME_BUFFER_LENGTH;
+    const uint8_t *messageBegin = dnsMessage.data();
+    const uint8_t *messageEnd = EndOfDNSQuestion(dnsMessage, tooLongQName);
+    int ret = parse_name(&messageBegin, messageEnd, dnsMessage.data(),
+                         queryName.data(), writeBufferSize);
+    BOOST_CHECK_EQUAL(ret, -1);
 }
 
 BOOST_AUTO_TEST_SUITE_END()