Skip to content

Commit c6c42c4

Browse files
razdoburdin (Dmitry Razdoburdin) and Copilot
authored
Native serialization to a stream for FlatIndex (#280)
Reopening of #275 for the developer branch.

Co-authored-by: Dmitry Razdoburdin <drazdobu@intel.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 2a9f81c commit c6c42c4

File tree

11 files changed

+493
-101
lines changed

11 files changed

+493
-101
lines changed

include/svs/core/data/io.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,22 @@ void populate_impl(
7979
}
8080
}
8181

82+
template <data::MemoryDataset Data> void populate(std::istream& is, Data& data) {
83+
auto accessor = DefaultWriteAccessor();
84+
85+
size_t num_vectors = data.size();
86+
size_t dims = data.dimensions();
87+
88+
auto max_lines = Dynamic;
89+
auto nvectors = std::min(num_vectors, max_lines);
90+
91+
auto reader = lib::VectorReader<typename Data::element_type>(dims);
92+
for (size_t i = 0; i < nvectors; ++i) {
93+
reader.read(is);
94+
accessor.set(data, i, reader.data());
95+
}
96+
}
97+
8298
// Intercept the native file to perform dispatch on the actual file type.
8399
template <data::MemoryDataset Data, typename WriteAccessor>
84100
void populate_impl(
@@ -120,6 +136,15 @@ void save(const Dataset& data, const File& file, const lib::UUID& uuid = lib::Ze
120136
return save(data, accessor, file, uuid);
121137
}
122138

139+
template <data::ImmutableMemoryDataset Dataset>
140+
void save(const Dataset& data, std::ostream& os) {
141+
auto accessor = DefaultReadAccessor();
142+
auto writer = svs::io::v1::StreamWriter<void>(os);
143+
for (size_t i = 0; i < data.size(); ++i) {
144+
writer << accessor.get(data, i);
145+
}
146+
}
147+
123148
///
124149
/// @brief Save the dataset as a "*vecs" file.
125150
///
@@ -169,6 +194,14 @@ lib::lazy_result_t<F, size_t, size_t> load_dataset(const File& file, const F& la
169194
return load_impl(detail::to_native(file), default_accessor, lazy);
170195
}
171196

197+
template <lib::LazyInvocable<size_t, size_t> F>
198+
lib::lazy_result_t<F, size_t, size_t>
199+
load_dataset(std::istream& is, const F& lazy, size_t num_vectors, size_t dims) {
200+
auto data = lazy(num_vectors, dims);
201+
populate(is, data);
202+
return data;
203+
}
204+
172205
// Return whether or not a file is directly loadable via file-extension.
173206
inline bool special_by_file_extension(std::string_view path) {
174207
return (path.ends_with("svs") || path.ends_with("vecs") || path.ends_with("bin"));

include/svs/core/data/simple.h

Lines changed: 79 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,24 +75,42 @@ class GenericSerializer {
7575
}
7676

7777
template <data::ImmutableMemoryDataset Data>
78-
static lib::SaveTable save(const Data& data, const lib::SaveContext& ctx) {
78+
static lib::SaveTable save_table(const Data& data) {
7979
using T = typename Data::element_type;
80-
// UUID used to identify the file.
81-
auto uuid = lib::UUID{};
82-
auto filename = ctx.generate_name("data");
83-
io::save(data, io::NativeFile(filename), uuid);
84-
return lib::SaveTable(
80+
auto table = lib::SaveTable(
8581
serialization_schema,
8682
save_version,
8783
{
8884
{"name", "uncompressed"},
89-
{"binary_file", lib::save(filename.filename())},
9085
{"dims", lib::save(data.dimensions())},
9186
{"num_vectors", lib::save(data.size())},
92-
{"uuid", uuid.str()},
9387
{"eltype", lib::save(datatype_v<T>)},
9488
}
9589
);
90+
return table;
91+
}
92+
93+
template <data::ImmutableMemoryDataset Data, class FileName_t>
94+
static lib::SaveTable
95+
save_table(const Data& data, const FileName_t& filename, const lib::UUID& uuid) {
96+
auto table = save_table(data);
97+
table.insert("binary_file", filename);
98+
table.insert("uuid", uuid.str());
99+
return table;
100+
}
101+
102+
template <data::ImmutableMemoryDataset Data>
103+
static lib::SaveTable save(const Data& data, const lib::SaveContext& ctx) {
104+
// UUID used to identify the file.
105+
auto uuid = lib::UUID{};
106+
auto filename = ctx.generate_name("data");
107+
io::save(data, io::NativeFile(filename), uuid);
108+
return save_table(data, lib::save(filename.filename()), uuid);
109+
}
110+
111+
template <data::ImmutableMemoryDataset Data>
112+
static void save(const Data& data, std::ostream& os) {
113+
io::save(data, os);
96114
}
97115

98116
template <typename T, lib::LazyInvocable<size_t, size_t> F>
@@ -116,6 +134,25 @@ class GenericSerializer {
116134
}
117135
return io::load_dataset(binaryfile.value(), lazy);
118136
}
137+
138+
template <typename T, lib::LazyInvocable<size_t, size_t> F>
139+
static lib::lazy_result_t<F, size_t, size_t>
140+
load(const lib::ContextFreeLoadTable& table, std::istream& is, const F& lazy) {
141+
auto datatype = lib::load_at<DataType>(table, "eltype");
142+
if (datatype != datatype_v<T>) {
143+
throw ANNEXCEPTION(
144+
"Trying to load an uncompressed dataset with element types {} to a dataset "
145+
"with element types {}.",
146+
name(datatype),
147+
name<datatype_v<T>>()
148+
);
149+
}
150+
151+
size_t num_vectors = lib::load_at<size_t>(table, "num_vectors");
152+
size_t dims = lib::load_at<size_t>(table, "dims");
153+
154+
return io::load_dataset(is, lazy, num_vectors, dims);
155+
}
119156
};
120157

121158
struct Matcher {
@@ -405,6 +442,10 @@ class SimpleData {
405442
return GenericSerializer::save(*this, ctx);
406443
}
407444

445+
void save(std::ostream& os) const { return GenericSerializer::save(*this, os); }
446+
447+
lib::SaveTable save_table() const { return GenericSerializer::save_table(*this); }
448+
408449
static bool check_load_compatibility(std::string_view schema, lib::Version version) {
409450
return GenericSerializer::check_compatibility(schema, version);
410451
}
@@ -431,6 +472,20 @@ class SimpleData {
431472
);
432473
}
433474

475+
static SimpleData load(
476+
const lib::ContextFreeLoadTable& table,
477+
std::istream& is,
478+
const allocator_type& allocator = {}
479+
)
480+
requires(!is_view)
481+
{
482+
return GenericSerializer::load<T>(
483+
table, is, lib::Lazy([&](size_t n_elements, size_t n_dimensions) {
484+
return SimpleData(n_elements, n_dimensions, allocator);
485+
})
486+
);
487+
}
488+
434489
///
435490
/// @brief Try to automatically load the dataset.
436491
///
@@ -805,6 +860,10 @@ class SimpleData<T, Extent, Blocked<Alloc>> {
805860
return GenericSerializer::save(*this, ctx);
806861
}
807862

863+
void save(std::ostream& os) const { return GenericSerializer::save(*this, os); }
864+
865+
lib::SaveTable save_table() const { return GenericSerializer::save_table(*this); }
866+
808867
static bool check_load_compatibility(std::string_view schema, lib::Version version) {
809868
return GenericSerializer::check_compatibility(schema, version);
810869
}
@@ -818,6 +877,18 @@ class SimpleData<T, Extent, Blocked<Alloc>> {
818877
);
819878
}
820879

880+
static SimpleData load(
881+
const lib::ContextFreeLoadTable& table,
882+
std::istream& is,
883+
const Blocked<Alloc>& allocator = {}
884+
) {
885+
return GenericSerializer::load<T>(
886+
table, is, lib::Lazy([&allocator](size_t n_elements, size_t n_dimensions) {
887+
return SimpleData(n_elements, n_dimensions, allocator);
888+
})
889+
);
890+
}
891+
821892
static SimpleData
822893
load(const std::filesystem::path& path, const Blocked<Alloc>& allocator = {}) {
823894
if (detail::is_likely_reload(path)) {

include/svs/core/io/native.h

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -344,28 +344,16 @@ struct Header {
344344
static_assert(sizeof(Header) == header_size, "Mismatch in Native io::v1 header sizes!");
345345
static_assert(std::is_trivially_copyable_v<Header>, "Header must be trivially copyable!");
346346

347-
template <typename T = void> class Writer {
347+
// CRTP
348+
template <typename T, class Derived> class Writer {
348349
public:
349-
Writer(
350-
const std::string& path,
351-
size_t dimension,
352-
lib::UUID uuid = lib::UUID(lib::ZeroInitializer())
353-
)
354-
: dimension_{dimension}
355-
, uuid_{uuid}
356-
, stream_{lib::open_write(path, std::ofstream::out | std::ofstream::binary)} {
357-
// Write a temporary header.
358-
stream_.seekp(0, std::ofstream::beg);
359-
lib::write_binary(stream_, Header());
360-
}
361-
362-
size_t dimensions() const { return dimension_; }
363350
void overwrite_num_vectors(size_t num_vectors) { vectors_written_ = num_vectors; }
364351

365352
// TODO: Error checking to make sure the length is correct.
366353
template <typename U> Writer& append(U&& v) {
354+
std::ostream& os = static_cast<Derived*>(this)->stream();
367355
for (const auto& i : v) {
368-
lib::write_binary(stream_, lib::io_convert<T>(i));
356+
lib::write_binary(os, lib::io_convert<T>(i));
369357
}
370358
++vectors_written_;
371359
return *this;
@@ -374,21 +362,45 @@ template <typename T = void> class Writer {
374362
template <typename... Ts>
375363
requires std::is_same_v<T, void>
376364
Writer& append(std::tuple<Ts...>&& v) {
377-
lib::foreach (v, [&](const auto& x) { lib::write_binary(stream_, x); });
365+
std::ostream& os = static_cast<Derived*>(this)->stream();
366+
lib::foreach (v, [&](const auto& x) { lib::write_binary(os, x); });
378367
++vectors_written_;
379368
return *this;
380369
}
381370

382371
template <typename U> Writer& operator<<(U&& v) { return append(std::forward<U>(v)); }
383372

373+
protected:
374+
size_t vectors_written_ = 0;
375+
};
376+
377+
template <typename T = void> class FileWriter : public Writer<T, FileWriter<T>> {
378+
public:
379+
FileWriter(
380+
const std::string& path,
381+
size_t dimension,
382+
lib::UUID uuid = lib::UUID(lib::ZeroInitializer())
383+
)
384+
: dimension_{dimension}
385+
, uuid_{uuid}
386+
, stream_{lib::open_write(path, std::ofstream::out | std::ofstream::binary)} {
387+
// Write a temporary header.
388+
stream_.seekp(0, std::ofstream::beg);
389+
lib::write_binary(stream_, Header());
390+
}
391+
392+
std::ostream& stream() { return stream_; }
393+
394+
size_t dimensions() const { return dimension_; }
395+
384396
void flush() { stream_.flush(); }
385397

386398
void writeheader(bool resume = true) {
387399
auto position = stream_.tellp();
388400
// Write to the header the number of vectors actually written.
389401
stream_.seekp(0);
390402
assert(stream_.good());
391-
lib::write_binary(stream_, Header(vectors_written_, dimension_, uuid_));
403+
lib::write_binary(stream_, Header(this->vectors_written_, dimension_, uuid_));
392404
if (resume) {
393405
stream_.seekp(position, std::ofstream::beg);
394406
}
@@ -402,20 +414,30 @@ template <typename T = void> class Writer {
402414
//
403415
// We delete the copy constructor and copy assignment operators because
404416
// `std::ofstream` isn't copyable anyways.
405-
Writer(const Writer&) = delete;
406-
Writer& operator=(const Writer&) = delete;
407-
Writer(Writer&&) = delete;
408-
Writer& operator=(Writer&&) = delete;
417+
FileWriter(const FileWriter&) = delete;
418+
FileWriter& operator=(const FileWriter&) = delete;
419+
FileWriter(FileWriter&&) = delete;
420+
FileWriter& operator=(FileWriter&&) = delete;
409421

410422
// Write the header for the file.
411-
~Writer() noexcept { writeheader(); }
423+
~FileWriter() noexcept { writeheader(); }
412424

413425
private:
414426
size_t dimension_;
415427
lib::UUID uuid_;
416428
std::ofstream stream_;
417429
size_t writes_this_vector_ = 0;
418-
size_t vectors_written_ = 0;
430+
};
431+
432+
template <typename T = void> class StreamWriter : public Writer<T, StreamWriter<T>> {
433+
public:
434+
StreamWriter(std::ostream& os)
435+
: stream_{os} {}
436+
437+
std::ostream& stream() { return stream_; }
438+
439+
private:
440+
std::ostream& stream_;
419441
};
420442

421443
///
@@ -449,13 +471,13 @@ class NativeFile {
449471
}
450472

451473
template <typename T>
452-
Writer<T> writer(
474+
FileWriter<T> writer(
453475
lib::Type<T> SVS_UNUSED(type), size_t dimension, lib::UUID uuid = lib::ZeroUUID
454476
) const {
455-
return Writer<T>(path_, dimension, uuid);
477+
return FileWriter<T>(path_, dimension, uuid);
456478
}
457479

458-
Writer<> writer(size_t dimensions, lib::UUID uuid = lib::ZeroUUID) const {
480+
FileWriter<> writer(size_t dimensions, lib::UUID uuid = lib::ZeroUUID) const {
459481
return writer(lib::Type<void>(), dimensions, uuid);
460482
}
461483

@@ -715,7 +737,7 @@ class NativeFile {
715737
public:
716738
using compatible_file_types = lib::Types<vtest::NativeFile, v1::NativeFile>;
717739

718-
template <typename T> using Writer = v1::Writer<T>;
740+
template <typename T> using Writer = v1::FileWriter<T>;
719741

720742
explicit NativeFile(std::filesystem::path path)
721743
: path_{std::move(path)} {}

include/svs/index/flat/flat.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,8 @@ class FlatIndex {
522522
void save(const std::filesystem::path& data_directory) const {
523523
lib::save_to_disk(data_, data_directory);
524524
}
525+
526+
void save(std::ostream& os) const { lib::save_to_stream(data_, os); }
525527
};
526528

527529
///

0 commit comments

Comments
 (0)