// Returns how many complete batches of `expected_batch_size` edges fit into
// `edge_count` edges. A trailing partial batch is not counted;
// fair_batch_offset() hands those leftover edges to the last batch.
//
// Returns 0 if `expected_batch_size` is 0 (instead of dividing by zero). The
// result saturates at UINT32_MAX instead of being silently truncated when the
// quotient does not fit into the 32 bit return type.
constexpr uint32_t count_of_batches(uint64_t const edge_count, uint64_t expected_batch_size)
{
    if (expected_batch_size == 0u)
    {
        return 0u;
    }

    uint64_t const full_batches = edge_count / expected_batch_size;
    uint64_t const representable_max = UINT32_MAX;
    return static_cast<uint32_t>(full_batches < representable_max ? full_batches : representable_max);
}

// Computes a batch size that spreads the edges remaining after splitting
// `edge_count` into batches of `max_batch_size` evenly over those batches.
//
// Note that, despite its name, `max_batch_size` is the lower bound of the
// result for edge_count > max_batch_size: the returned size is
// max_batch_size + remainder / batch_count (e.g. 38 edges, max 10 -> 12).
//
// If all edges fit into one batch, or if `max_batch_size` is 0 (which would
// otherwise divide by zero), `edge_count` is returned, i.e. a single batch
// holding every edge.
constexpr size_t fair_batch_size(uint64_t const edge_count, uint64_t max_batch_size)
{
    if (max_batch_size == 0u || edge_count <= max_batch_size)
    {
        return static_cast<size_t>(edge_count);
    }

    // edge_count > max_batch_size > 0 holds here, hence batch_count >= 1.
    uint64_t const batch_count = edge_count / max_batch_size;
    uint64_t const remaining = edge_count % max_batch_size;
    return static_cast<size_t>(max_batch_size + (remaining / batch_count));
}

// Returns the pair (begin, count) describing batch `current_batch_num`
// (0-based): the index of its first edge and the count of edges it holds.
//
// Every batch holds `fair_size` edges, except for the last one
// (current_batch_num + 1 == batches_count) which takes all edges up to
// `edge_count`. The count never reaches past `edge_count`; a batch that would
// begin at or behind the last edge yields (edge_count, 0) instead of wrapping
// around on the unsigned subtraction.
constexpr std::pair<uint64_t, uint64_t>
fair_batch_offset(uint64_t const fair_size, uint64_t current_batch_num, uint64_t batches_count, uint64_t edge_count)
{
    uint64_t const begin = current_batch_num * fair_size;
    if (begin >= edge_count)
    {
        return std::make_pair(edge_count, uint64_t{ 0 });
    }

    uint64_t const edges_left = edge_count - begin;
    bool const is_last_batch = (current_batch_num + 1) == batches_count;
    uint64_t const count = (is_last_batch || fair_size > edges_left) ? edges_left : fair_size;
    return std::make_pair(begin, count);
}
b/include/gdsb/mpi_graph_io.h @@ -132,6 +132,67 @@ std::tuple all_read_binary_graph_partition(MPI_File const in return std::make_tuple(data.vertex_count, edge_count); } +struct ReadBatch +{ + // Set this to the desired batch size. + uint64_t batch_size{ 1 }; + // Set this to the expected size of edge in bytes. + size_t edge_size_in_bytes{ 4 }; + // Leave this as is, will be used by all_read_binary_graph_batch(). + uint64_t count_read_in_edges{ 0 }; +}; + +// Parameter edges must be a pointer to the first element of the array/vector +// containing edges. +// +// We do not read from file exceeding the edge count of 'data' (therefore, not +// reading past EOF). Please make sure to handle 'read_batch' correctly which is +// an in-out parameter. We set the 'read_batch.count_read_in_edges' according to +// the edge count we could read from file. +template +void all_read_binary_graph_batch(MPI_File const input, BinaryGraphHeader const& data, Edges* const edges, ReadBatch& read_batch, MPI_Datatype const mpi_datatype) +{ + if (read_batch.batch_size > std::numeric_limits::max()) + { + // Have a look at the count type of MPI_File_read_all() + throw std::runtime_error("Count of edges exceeds MPI read count type (int) maximum."); + } + + if (read_batch.count_read_in_edges == data.edge_count) + { + return; + } + + uint64_t potential_count = read_batch.batch_size; + uint64_t potential_count_read_in_edges = read_batch.count_read_in_edges + potential_count; + if (potential_count_read_in_edges > data.edge_count) + { + potential_count = data.edge_count - read_batch.count_read_in_edges; + } + + int count = static_cast(potential_count); + read_batch.count_read_in_edges += potential_count; + + MPI_Status status; + int const read_all_error = MPI_File_read_all(input, edges, count, mpi_datatype, &status); + if (read_all_error != MPI_SUCCESS) + { + throw std::runtime_error("Could not successfully read all edges from MPI file."); + } + + // It's not really clear from several MPI 
documentations I've read which + // field actually contains the number of elements read from file given the + // count one wants to read. The status.MPI_SOURCE field seems to + // surprisingly provide that data point.. However, this seems to read beyond + // EOF and still be equal to the requested count. + // + // Still we won't do the following since that leads to a really unclear API. + // return status.MPI_SOURCE == count; + // + // TODO: Find a way to check if we read exceeding EOF. Perhaps we could + // better model the view on the file? +} + namespace binary { diff --git a/test/batcher.cpp b/test/batcher.cpp deleted file mode 100644 index ed22858..0000000 --- a/test/batcher.cpp +++ /dev/null @@ -1,70 +0,0 @@ -#include - -#include "test_graph.h" - -#include -#include -#include - -using namespace gdsb; - -TEST_CASE("Batcher") -{ - WeightedEdges32 edges{ { 0, { 1, 1.f } }, { 0, { 2, 1.f } }, { 0, { 3, 1.f } }, { 1, { 2, 1.f } }, - { 1, { 3, 1.f } }, { 2, { 5, 1.f } }, { 2, { 6, 1.f } } }; - - CHECK(edges.size() == 7); - - SECTION("Unsorted batch") - { - Batcher batcher(std::begin(edges), std::end(edges), 3); - Batch batch = batcher.next(1); - - CHECK(batch.begin->source == edges[0].source); - CHECK(batch.begin->target.vertex == edges[0].target.vertex); - } -} - -TEST_CASE("partition_batch_count, on enzymes graph") -{ - std::ifstream binary_graph(graph_path + directed_unweighted_graph_enzymes_bin); - BinaryGraphHeader header = read_binary_graph_header(binary_graph); - - SECTION("partition size 2") - { - uint32_t partition_size = 2; - - uint32_t partition_id = 0; - uint64_t edge_count = partition_batch_count(header.edge_count, partition_id, partition_size); - CHECK(edge_count == 84); - - partition_id = 1; - edge_count = partition_batch_count(header.edge_count, partition_id, partition_size); - CHECK(edge_count == 84); - } - - SECTION("partition size 5") - { - uint32_t partition_size = 5; - - uint32_t partition_id = 0; - uint64_t edge_count = 
partition_batch_count(header.edge_count, partition_id, partition_size); - CHECK(edge_count == 33); - - partition_id = 1; - edge_count = partition_batch_count(header.edge_count, partition_id, partition_size); - CHECK(edge_count == 33); - - partition_id = 2; - edge_count = partition_batch_count(header.edge_count, partition_id, partition_size); - CHECK(edge_count == 33); - - partition_id = 3; - edge_count = partition_batch_count(header.edge_count, partition_id, partition_size); - CHECK(edge_count == 33); - - partition_id = 4; - edge_count = partition_batch_count(header.edge_count, partition_id, partition_size); - CHECK(edge_count == 36); - } -} diff --git a/test/batcher_tests.cpp b/test/batcher_tests.cpp new file mode 100644 index 0000000..0647441 --- /dev/null +++ b/test/batcher_tests.cpp @@ -0,0 +1,273 @@ +#include + +#include "test_graph.h" + +#include +#include +#include + +using namespace gdsb; + +TEST_CASE("Batcher") +{ + WeightedEdges32 edges{ { 0, { 1, 1.f } }, { 0, { 2, 1.f } }, { 0, { 3, 1.f } }, { 1, { 2, 1.f } }, + { 1, { 3, 1.f } }, { 2, { 5, 1.f } }, { 2, { 6, 1.f } } }; + + CHECK(edges.size() == 7); + + SECTION("Unsorted batch") + { + Batcher batcher(std::begin(edges), std::end(edges), 3); + Batch batch = batcher.next(1); + + CHECK(batch.begin->source == edges[0].source); + CHECK(batch.begin->target.vertex == edges[0].target.vertex); + } +} + +TEST_CASE("partition_batch_count, on enzymes graph") +{ + std::ifstream binary_graph(graph_path + directed_unweighted_graph_enzymes_bin); + BinaryGraphHeader header = read_binary_graph_header(binary_graph); + + SECTION("partition size 2") + { + uint32_t partition_size = 2; + + uint32_t partition_id = 0; + uint64_t edge_count = partition_batch_count(header.edge_count, partition_id, partition_size); + CHECK(edge_count == 84); + + partition_id = 1; + edge_count = partition_batch_count(header.edge_count, partition_id, partition_size); + CHECK(edge_count == 84); + } + + SECTION("partition size 5") + { + uint32_t 
partition_size = 5; + + uint32_t partition_id = 0; + uint64_t edge_count = partition_batch_count(header.edge_count, partition_id, partition_size); + CHECK(edge_count == 33); + + partition_id = 1; + edge_count = partition_batch_count(header.edge_count, partition_id, partition_size); + CHECK(edge_count == 33); + + partition_id = 2; + edge_count = partition_batch_count(header.edge_count, partition_id, partition_size); + CHECK(edge_count == 33); + + partition_id = 3; + edge_count = partition_batch_count(header.edge_count, partition_id, partition_size); + CHECK(edge_count == 33); + + partition_id = 4; + edge_count = partition_batch_count(header.edge_count, partition_id, partition_size); + CHECK(edge_count == 36); + } +} + +TEST_CASE("count_of_batches()") +{ + SECTION("simple") + { + uint64_t constexpr edge_count = 30; + uint64_t constexpr max_batch_size = 10; + uint32_t cob = count_of_batches(edge_count, max_batch_size); + + CHECK(cob == 3); + } + + SECTION("simple with rest") + { + uint64_t constexpr edge_count = 33; + uint64_t constexpr max_batch_size = 10; + uint32_t cob = count_of_batches(edge_count, max_batch_size); + + CHECK(cob == 3); + } + + SECTION("simple with max rest") + { + uint64_t constexpr edge_count = 39; + uint64_t constexpr max_batch_size = 10; + uint32_t cob = count_of_batches(edge_count, max_batch_size); + + CHECK(cob == 3); + } +} + +TEST_CASE("fair_batch_size()") +{ + SECTION("simple") + { + uint64_t constexpr edge_count = 30; + uint64_t constexpr max_batch_size = 10; + uint32_t fbs = fair_batch_size(edge_count, max_batch_size); + + CHECK(fbs == 10); + } + + SECTION("simple with rest") + { + uint64_t constexpr edge_count = 33; + uint64_t constexpr max_batch_size = 10; + uint32_t fbs = fair_batch_size(edge_count, max_batch_size); + + CHECK(fbs == 11); + } + + SECTION("simple with max rest") + { + uint64_t constexpr edge_count = 39; + uint64_t constexpr max_batch_size = 10; + uint32_t fbs = fair_batch_size(edge_count, max_batch_size); + + CHECK(fbs 
== 13); + } + + SECTION("worst case, 2 remaining unincluded ") + { + uint64_t constexpr edge_count = 38; + uint64_t constexpr max_batch_size = 10; + uint32_t fbs = fair_batch_size(edge_count, max_batch_size); + + CHECK(fbs == 12); + } + + SECTION("safe") + { + uint64_t constexpr edge_count = 30u; + uint64_t constexpr max_batch_size = 100u; + uint32_t fbs = fair_batch_size(edge_count, max_batch_size); + REQUIRE(fbs == 30u); + } +} + +TEST_CASE("fair_batch_offset()") +{ + SECTION("simple") + { + uint64_t constexpr edge_count = 30u; + uint64_t constexpr max_batch_size = 10u; + uint32_t fbs = fair_batch_size(edge_count, max_batch_size); + REQUIRE(fbs == 10u); + + uint32_t cob = count_of_batches(edge_count, fbs); + REQUIRE(cob == 3); + + uint64_t current_batch_num = 0u; + auto [begin, count] = fair_batch_offset(fbs, current_batch_num, cob, edge_count); + + CHECK(begin == 0u); + CHECK(count == 10u); + } + + SECTION("simple with rest") + { + uint64_t constexpr edge_count = 33; + uint64_t constexpr max_batch_size = 10; + uint32_t fbs = fair_batch_size(edge_count, max_batch_size); + REQUIRE(fbs == 11); + + uint32_t cob = count_of_batches(edge_count, fbs); + REQUIRE(cob == 3); + + uint64_t current_batch_num = 0u; + auto [begin_0, count_0] = fair_batch_offset(fbs, current_batch_num, cob, edge_count); + + CHECK(begin_0 == 0u); + CHECK(count_0 == fbs); + + ++current_batch_num; + auto [begin_1, count_1] = fair_batch_offset(fbs, current_batch_num, cob, edge_count); + + CHECK(begin_1 == 11u); + CHECK(count_1 == fbs); + + ++current_batch_num; + auto [begin_2, count_2] = fair_batch_offset(fbs, current_batch_num, cob, edge_count); + + CHECK(begin_2 == 22u); + CHECK(count_2 == fbs); + } + + SECTION("simple with max rest") + { + uint64_t constexpr edge_count = 39; + uint64_t constexpr max_batch_size = 10; + uint32_t fbs = fair_batch_size(edge_count, max_batch_size); + + CHECK(fbs == 13); + + uint32_t cob = count_of_batches(edge_count, fbs); + REQUIRE(cob == 3); + + uint64_t 
current_batch_num = 0u; + auto [begin_0, count_0] = fair_batch_offset(fbs, current_batch_num, cob, edge_count); + + CHECK(begin_0 == 0u); + CHECK(count_0 == fbs); + + ++current_batch_num; + auto [begin_1, count_1] = fair_batch_offset(fbs, current_batch_num, cob, edge_count); + + CHECK(begin_1 == 13u); + CHECK(count_1 == fbs); + + ++current_batch_num; + auto [begin_2, count_2] = fair_batch_offset(fbs, current_batch_num, cob, edge_count); + + CHECK(begin_2 == 26u); + CHECK(count_2 == fbs); + } + + SECTION("worst case, 2 remaining unincluded ") + { + uint64_t constexpr edge_count = 38; + uint64_t constexpr max_batch_size = 10; + uint32_t fbs = fair_batch_size(edge_count, max_batch_size); + + CHECK(fbs == 12); + + uint32_t cob = count_of_batches(edge_count, fbs); + REQUIRE(cob == 3); + + uint64_t current_batch_num = 0u; + auto [begin_0, count_0] = fair_batch_offset(fbs, current_batch_num, cob, edge_count); + + CHECK(begin_0 == 0u); + CHECK(count_0 == fbs); + + ++current_batch_num; + auto [begin_1, count_1] = fair_batch_offset(fbs, current_batch_num, cob, edge_count); + + CHECK(begin_1 == 12u); + CHECK(count_1 == fbs); + + ++current_batch_num; + auto [begin_2, count_2] = fair_batch_offset(fbs, current_batch_num, cob, edge_count); + + CHECK(begin_2 == 24u); + CHECK(count_2 == fbs + 2); + } + + SECTION("safe") + { + uint64_t constexpr edge_count = 30u; + uint64_t constexpr max_batch_size = 100u; + uint32_t fbs = fair_batch_size(edge_count, max_batch_size); + REQUIRE(fbs == 30u); + + uint32_t cob = count_of_batches(edge_count, fbs); + REQUIRE(cob == 1); + + uint64_t current_batch_num = 0u; + auto [begin, count] = fair_batch_offset(fbs, current_batch_num, cob, edge_count); + + CHECK(begin == 0u); + CHECK(count == 30u); + } +} \ No newline at end of file diff --git a/test/mpi_graph_io_tests.cpp b/test/mpi_graph_io_tests.cpp index 02fc5df..b631f5d 100644 --- a/test/mpi_graph_io_tests.cpp +++ b/test/mpi_graph_io_tests.cpp @@ -622,4 +622,420 @@ TEST_CASE("MPI", "read") 
namespace
{

// One expected (source, target) pair of the enzymes test graph.
struct ExpectedEdge
{
    uint32_t source;
    uint32_t target;
};

// The edges of the directed, unweighted enzymes test graph in the order they
// are expected to be read from the binary file. Kept as one table so that the
// batch wise test below can check any batch by its offset.
constexpr ExpectedEdge enzymes_edges[] = {
    // batch 0
    { 2, 1 }, { 3, 1 }, { 4, 1 }, { 1, 2 }, { 3, 2 }, { 4, 2 }, { 25, 2 }, { 28, 2 }, { 1, 3 }, { 2, 3 },
    // batch 1
    { 4, 3 }, { 28, 3 }, { 29, 3 }, { 1, 4 }, { 2, 4 }, { 3, 4 }, { 5, 4 }, { 6, 4 }, { 29, 4 }, { 4, 5 },
    // batch 2
    { 6, 5 }, { 7, 5 }, { 30, 5 }, { 4, 6 }, { 5, 6 }, { 7, 6 }, { 8, 6 }, { 30, 6 }, { 5, 7 }, { 6, 7 },
    // batch 3
    { 8, 7 }, { 9, 7 }, { 6, 8 }, { 7, 8 }, { 9, 8 }, { 10, 8 }, { 11, 8 }, { 7, 9 }, { 8, 9 }, { 10, 9 },
    // batch 4
    { 8, 10 }, { 9, 10 }, { 11, 10 }, { 12, 10 }, { 13, 10 }, { 8, 11 }, { 10, 11 }, { 12, 11 }, { 13, 11 }, { 10, 12 },
    // batch 5
    { 11, 12 }, { 13, 12 }, { 27, 12 }, { 10, 13 }, { 11, 13 }, { 12, 13 }, { 26, 13 }, { 27, 13 }, { 15, 14 }, { 16, 14 },
    // batch 6
    { 17, 14 }, { 26, 14 }, { 14, 15 }, { 16, 15 }, { 17, 15 }, { 26, 15 }, { 14, 16 }, { 15, 16 }, { 17, 16 }, { 18, 16 },
    // batch 7
    { 14, 17 }, { 15, 17 }, { 16, 17 }, { 18, 17 }, { 16, 18 }, { 17, 18 }, { 19, 18 }, { 20, 18 }, { 18, 19 }, { 20, 19 },
    // batch 8
    { 21, 19 }, { 18, 20 }, { 19, 20 }, { 21, 20 }, { 19, 21 }, { 20, 21 }, { 22, 21 }, { 23, 21 }, { 24, 21 }, { 31, 21 },
    // batch 9
    { 21, 22 }, { 23, 22 }, { 24, 22 }, { 31, 22 }, { 36, 22 }, { 21, 23 }, { 22, 23 }, { 24, 23 }, { 36, 23 }, { 21, 24 },
    // batch 10
    { 22, 24 }, { 23, 24 }, { 34, 24 }, { 2, 25 }, { 28, 25 }, { 29, 25 }, { 30, 25 }, { 13, 26 }, { 14, 26 }, { 15, 26 },
    // batch 11
    { 27, 26 }, { 30, 26 }, { 12, 27 }, { 13, 27 }, { 26, 27 }, { 29, 27 }, { 30, 27 }, { 2, 28 }, { 3, 28 }, { 25, 28 },
    // batch 12
    { 29, 28 }, { 30, 28 }, { 3, 29 }, { 4, 29 }, { 25, 29 }, { 27, 29 }, { 28, 29 }, { 30, 29 }, { 5, 30 }, { 6, 30 },
    // batch 13
    { 25, 30 }, { 26, 30 }, { 27, 30 }, { 28, 30 }, { 29, 30 }, { 21, 31 }, { 22, 31 }, { 34, 31 }, { 35, 31 }, { 36, 31 },
    // batch 14
    { 33, 32 }, { 35, 32 }, { 37, 32 }, { 32, 33 }, { 34, 33 }, { 35, 33 }, { 37, 33 }, { 24, 34 }, { 31, 34 }, { 33, 34 },
    // batch 15 (the last batch additionally holds the 8 remaining edges)
    { 35, 34 }, { 36, 34 }, { 37, 34 }, { 31, 35 }, { 32, 35 }, { 33, 35 }, { 34, 35 }, { 36, 35 }, { 37, 35 },
    { 22, 36 }, { 23, 36 }, { 31, 36 }, { 34, 36 }, { 35, 36 }, { 32, 37 }, { 33, 37 }, { 34, 37 }, { 35, 37 }
};

constexpr uint64_t enzymes_edge_count = sizeof(enzymes_edges) / sizeof(enzymes_edges[0]);

} // namespace

// Reads the whole enzymes graph batch by batch and compares every edge of
// every batch with the expected edge at the same position of the file.
TEST_CASE("MPI, all_read_binary_graph_batch, batch for batch")
{
    std::filesystem::path file_path(graph_path + directed_unweighted_graph_enzymes_bin);
    mpi::FileWrapper binary_graph{ file_path };

    BinaryGraphHeader header = mpi::read_binary_graph_header(binary_graph.get());
    REQUIRE(header.vertex_id_byte_size == sizeof(Vertex32));
    REQUIRE(header.directed);
    REQUIRE(!header.weighted);
    REQUIRE(!header.dynamic);
    // Guards every index into the expectation table below.
    REQUIRE(header.edge_count == enzymes_edge_count);

    mpi::MPIEdge32 mpi_edge_t;

    // The fair size is kept in its own constant: read_batch.batch_size is
    // overwritten with the count of each batch and must not be used to derive
    // the offsets of the following batches.
    uint32_t constexpr max_batch_size = 10u;
    uint64_t const fair_size = fair_batch_size(header.edge_count, max_batch_size);
    REQUIRE(fair_size == 10u);

    uint32_t const cob = count_of_batches(header.edge_count, fair_size);
    REQUIRE(cob == 16u);

    mpi::ReadBatch read_batch;
    read_batch.edge_size_in_bytes = sizeof(Edge32);

    uint64_t const remaining_edges = 8u;
    Edges32 edges;

    for (uint32_t current_batch = 0; current_batch < cob; ++current_batch)
    {
        auto const [offset, count] = fair_batch_offset(fair_size, current_batch, cob, header.edge_count);

        bool const is_last_batch = (current_batch + 1u) == cob;
        uint64_t const expected_count = is_last_batch ? fair_size + remaining_edges : fair_size;
        REQUIRE(offset == current_batch * fair_size);
        REQUIRE(count == expected_count);

        edges.resize(count);
        REQUIRE(edges.size() == count);
        read_batch.batch_size = count;
        mpi::all_read_binary_graph_batch(binary_graph.get(), header, &(edges[0]), read_batch, mpi_edge_t.get());
        REQUIRE(read_batch.count_read_in_edges == offset + count);

        // The index advances unconditionally, so one mismatching edge can not
        // shift the comparison of the edges following it.
        for (size_t idx = 0; idx < edges.size(); ++idx)
        {
            INFO("batch " << current_batch << ", edge " << idx);
            ExpectedEdge const& expected = enzymes_edges[offset + idx];
            CHECK(edges[idx].source == expected.source);
            CHECK(edges[idx].target == expected.target);
        }
    }

    CHECK(read_batch.count_read_in_edges == header.edge_count);
}

TEST_CASE("MPI, all_read_binary_graph_batch, read issues")
{
    std::filesystem::path file_path(graph_path + directed_unweighted_graph_enzymes_bin);
    mpi::FileWrapper binary_graph{ file_path };

    BinaryGraphHeader header = mpi::read_binary_graph_header(binary_graph.get());
    REQUIRE(header.vertex_id_byte_size == sizeof(Vertex32));
    REQUIRE(header.directed);
    REQUIRE(!header.weighted);
    REQUIRE(!header.dynamic);

    mpi::MPIEdge32 mpi_edge_t;

    mpi::ReadBatch read_batch;
    read_batch.edge_size_in_bytes = sizeof(Edge32);

    SECTION("Read Success")
    {
        read_batch.batch_size = header.edge_count;
        Edges32 edges(header.edge_count);
        mpi::all_read_binary_graph_batch(binary_graph.get(), header, &(edges[0]), read_batch, mpi_edge_t.get());
        CHECK(read_batch.count_read_in_edges == header.edge_count);
    }

    // A batch size that does not fit into the count type (int) of the MPI read
    // must be rejected.
    SECTION("Batch size exceeds MPI count type")
    {
        Edges32 edges(header.edge_count);
        // NOTE(review): the template argument of numeric_limits was lost in
        // extraction; the type of batch_size is used here - confirm.
        read_batch.batch_size = std::numeric_limits<uint64_t>::max();
        CHECK_THROWS(mpi::all_read_binary_graph_batch(binary_graph.get(), header, &(edges[0]), read_batch, mpi_edge_t.get()));
    }

    // Requesting more edges than the file holds must neither read beyond the
    // edge count nor touch the surplus element of the buffer.
    SECTION("Read exceeds counter")
    {
        read_batch.batch_size = header.edge_count + 1;
        Edges32 edges(read_batch.batch_size, Edge32{ 0u, 0u });
        mpi::all_read_binary_graph_batch(binary_graph.get(), header, &(edges[0]), read_batch, mpi_edge_t.get());
        CHECK(edges[header.edge_count].source == 0u);
        CHECK(read_batch.count_read_in_edges == header.edge_count);
    }
}