Skip to content

Commit 4c0f19b

Browse files
razdoburdinDmitry Razdoburdin
andauthored
Native serialization for DynamicFlat index (#281)
This PR introduce native serialization for DynamicFlat index. Main changes are: 1. `auto_dynamic_assemble` now accepts lazy loader. That is mandatory for buffer-free deserialization. 2. new class `Deserializer' is introduced. It is responsible for conditional reading of overhead data (like names of temporary files) in case of legacy models. 3. `IDTranslator` is refactored to cover save and load to/from stream. --------- Co-authored-by: Dmitry Razdoburdin <drazdobu@intel.com>
1 parent c6c42c4 commit 4c0f19b

File tree

10 files changed

+411
-89
lines changed

10 files changed

+411
-89
lines changed

include/svs/core/data/simple.h

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,12 @@ class GenericSerializer {
136136
}
137137

138138
template <typename T, lib::LazyInvocable<size_t, size_t> F>
139-
static lib::lazy_result_t<F, size_t, size_t>
140-
load(const lib::ContextFreeLoadTable& table, std::istream& is, const F& lazy) {
139+
static lib::lazy_result_t<F, size_t, size_t> load(
140+
const lib::ContextFreeLoadTable& table,
141+
const lib::detail::Deserializer& deserializer,
142+
std::istream& is,
143+
const F& lazy
144+
) {
141145
auto datatype = lib::load_at<DataType>(table, "eltype");
142146
if (datatype != datatype_v<T>) {
143147
throw ANNEXCEPTION(
@@ -151,6 +155,10 @@ class GenericSerializer {
151155
size_t num_vectors = lib::load_at<size_t>(table, "num_vectors");
152156
size_t dims = lib::load_at<size_t>(table, "dims");
153157

158+
deserializer.read_name(is);
159+
deserializer.read_size(is);
160+
deserializer.read_binary<io::v1::Header>(is);
161+
154162
return io::load_dataset(is, lazy, num_vectors, dims);
155163
}
156164
};
@@ -474,13 +482,14 @@ class SimpleData {
474482

475483
static SimpleData load(
476484
const lib::ContextFreeLoadTable& table,
485+
const lib::detail::Deserializer& deserializer,
477486
std::istream& is,
478487
const allocator_type& allocator = {}
479488
)
480489
requires(!is_view)
481490
{
482491
return GenericSerializer::load<T>(
483-
table, is, lib::Lazy([&](size_t n_elements, size_t n_dimensions) {
492+
table, deserializer, is, lib::Lazy([&](size_t n_elements, size_t n_dimensions) {
484493
return SimpleData(n_elements, n_dimensions, allocator);
485494
})
486495
);
@@ -879,11 +888,15 @@ class SimpleData<T, Extent, Blocked<Alloc>> {
879888

880889
static SimpleData load(
881890
const lib::ContextFreeLoadTable& table,
891+
const lib::detail::Deserializer& deserializer,
882892
std::istream& is,
883893
const Blocked<Alloc>& allocator = {}
884894
) {
885895
return GenericSerializer::load<T>(
886-
table, is, lib::Lazy([&allocator](size_t n_elements, size_t n_dimensions) {
896+
table,
897+
deserializer,
898+
is,
899+
lib::Lazy([&allocator](size_t n_elements, size_t n_dimensions) {
887900
return SimpleData(n_elements, n_dimensions, allocator);
888901
})
889902
);

include/svs/core/translation.h

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -324,27 +324,36 @@ class IDTranslator {
324324
"external_to_internal_translation";
325325
static constexpr lib::Version save_version = lib::Version(0, 0, 0);
326326

327-
lib::SaveTable save(const lib::SaveContext& ctx) const {
328-
auto filename = ctx.generate_name("id_translation", "binary");
329-
// Save the translations to a file.
330-
auto stream = lib::open_write(filename);
331-
for (auto i = begin(), iend = end(); i != iend; ++i) {
332-
// N.B.: Apparently `std::pair` of integers is not trivially copyable ...
333-
lib::write_binary(stream, i->first);
334-
lib::write_binary(stream, i->second);
335-
}
327+
lib::SaveTable save_table() const {
336328
return lib::SaveTable(
337329
serialization_schema,
338330
save_version,
339331
{{"kind", kind},
340332
{"num_points", lib::save(size())},
341333
{"external_id_type", lib::save(datatype_v<external_id_type>)},
342-
{"internal_id_type", lib::save(datatype_v<internal_id_type>)},
343-
{"filename", lib::save(filename.filename())}}
334+
{"internal_id_type", lib::save(datatype_v<internal_id_type>)}}
344335
);
345336
}
346337

347-
static IDTranslator load(const lib::LoadTable& table) {
338+
void save(std::ostream& os) const {
339+
for (auto i = begin(), iend = end(); i != iend; ++i) {
340+
// N.B.: Apparently `std::pair` of integers is not trivially copyable ...
341+
lib::write_binary(os, i->first);
342+
lib::write_binary(os, i->second);
343+
}
344+
}
345+
346+
lib::SaveTable save(const lib::SaveContext& ctx) const {
347+
auto filename = ctx.generate_name("id_translation", "binary");
348+
// Save the translations to a file.
349+
auto os = lib::open_write(filename);
350+
save(os);
351+
auto table = save_table();
352+
table.insert("filename", lib::save(filename.filename()));
353+
return table;
354+
}
355+
356+
static void validate(const lib::ContextFreeLoadTable& table) {
348357
if (kind != lib::load_at<std::string>(table, "kind")) {
349358
throw ANNEXCEPTION("Mismatched kind!");
350359
}
@@ -357,21 +366,42 @@ class IDTranslator {
357366
if (internal_id_name != lib::load_at<std::string>(table, "internal_id_type")) {
358367
throw ANNEXCEPTION("Mismatched internal id types!");
359368
}
369+
}
360370

361-
// Now that we've more-or-less validated the metadata, time to start loading
362-
// the points.
371+
static IDTranslator load(const lib::ContextFreeLoadTable& table, std::istream& is) {
363372
auto num_points = lib::load_at<size_t>(table, "num_points");
373+
364374
auto translator = IDTranslator{};
365-
auto resolved = table.resolve_at("filename");
366-
auto stream = lib::open_read(resolved);
367375
for (size_t i = 0; i < num_points; ++i) {
368-
auto external_id = lib::read_binary<external_id_type>(stream);
369-
auto internal_id = lib::read_binary<internal_id_type>(stream);
376+
auto external_id = lib::read_binary<external_id_type>(is);
377+
auto internal_id = lib::read_binary<internal_id_type>(is);
370378
translator.insert_translation(external_id, internal_id);
371379
}
372380
return translator;
373381
}
374382

383+
static IDTranslator load(
384+
const lib::ContextFreeLoadTable& table,
385+
const lib::detail::Deserializer& deserializer,
386+
std::istream& is
387+
) {
388+
IDTranslator::validate(table);
389+
deserializer.read_name(is);
390+
deserializer.read_size(is);
391+
392+
return IDTranslator::load(table, is);
393+
}
394+
395+
static IDTranslator load(const lib::LoadTable& table) {
396+
IDTranslator::validate(table);
397+
398+
// Now that we've more-or-less validated the metadata, time to start loading
399+
// the points.
400+
auto resolved = table.resolve_at("filename");
401+
auto is = lib::open_read(resolved);
402+
return IDTranslator::load(table, is);
403+
}
404+
375405
private:
376406
template <class Begin, class End, class Map, class Modifier = lib::identity>
377407
void check(

include/svs/index/flat/dynamic_flat.h

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "svs/lib/invoke.h"
3737
#include "svs/lib/misc.h"
3838
#include "svs/lib/preprocessor.h"
39+
#include "svs/lib/stream.h"
3940
#include "svs/lib/threads.h"
4041

4142
namespace svs::index::flat {
@@ -403,6 +404,26 @@ template <typename Data, typename Dist> class DynamicFlatIndex {
403404
// Save the dataset in the separate data directory
404405
lib::save_to_disk(data_, data_directory);
405406
}
407+
408+
void save(std::ostream& os) {
409+
compact();
410+
411+
lib::begin_serialization(os);
412+
// Save data structures and translation to config directory
413+
lib::SaveTable save_table = lib::SaveTable(
414+
"dynamic_flat_config",
415+
save_version,
416+
{
417+
{"name", name()},
418+
{"translation", lib::detail::exit_hook(translator_.save_table())},
419+
}
420+
);
421+
lib::save_to_stream(save_table, os);
422+
translator_.save(os);
423+
424+
lib::save_to_stream(data_, os);
425+
}
426+
406427
constexpr std::string_view name() const { return "dynamic flat index"; }
407428

408429
///// Thread Pool Management
@@ -767,4 +788,67 @@ auto auto_dynamic_assemble(
767788
);
768789
}
769790

791+
auto load_translator(const lib::detail::Deserializer& deserializer, std::istream& is) {
792+
auto table = lib::detail::begin_deserialization(deserializer, is);
793+
auto translator = IDTranslator::load(
794+
table.template cast<toml::table>().at("translation").template cast<toml::table>(),
795+
deserializer,
796+
is
797+
);
798+
return translator;
799+
}
800+
801+
template <typename LazyDataLoader, typename Distance, typename ThreadPoolProto>
802+
auto auto_dynamic_assemble(
803+
const lib::detail::Deserializer& deserializer,
804+
std::istream& is,
805+
LazyDataLoader&& data_loader,
806+
Distance distance,
807+
ThreadPoolProto threadpool_proto,
808+
// Set this to `true` to use the identity map for ID translation.
809+
// This allows us to read files generated by the static index construction routines
810+
// to easily benchmark the static versus dynamic implementation.
811+
//
812+
// This is an internal API and should not be considered officially supported nor stable.
813+
bool SVS_UNUSED(debug_load_from_static) = false,
814+
svs::logging::logger_ptr logger = svs::logging::get()
815+
) {
816+
IDTranslator translator;
817+
// In legacy deserialization the order of directories isn't determined.
818+
auto name = deserializer.read_name_in_advance(is);
819+
820+
// We have to hardcode the file_name for legacy mode, since it was hardcoded when legacy
821+
// model was serialized
822+
bool translator_before_data =
823+
(name == "config/svs_config.toml") || deserializer.is_native();
824+
if (translator_before_data) {
825+
translator = load_translator(deserializer, is);
826+
}
827+
828+
// Load the dataset
829+
auto threadpool = threads::as_threadpool(std::move(threadpool_proto));
830+
auto data = svs::detail::dispatch_load(data_loader(), threadpool);
831+
auto datasize = data.size();
832+
833+
if (!translator_before_data) {
834+
translator = load_translator(deserializer, is);
835+
}
836+
837+
// Validate the translator
838+
auto translator_size = translator.size();
839+
if (translator_size != datasize) {
840+
throw ANNEXCEPTION(
841+
"Translator has {} IDs but should have {}", translator_size, datasize
842+
);
843+
}
844+
845+
return DynamicFlatIndex(
846+
std::move(data),
847+
std::move(translator),
848+
std::move(distance),
849+
std::move(threadpool),
850+
std::move(logger)
851+
);
852+
}
853+
770854
} // namespace svs::index::flat

include/svs/index/flat/flat.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,10 @@ class FlatIndex {
523523
lib::save_to_disk(data_, data_directory);
524524
}
525525

526-
void save(std::ostream& os) const { lib::save_to_stream(data_, os); }
526+
void save(std::ostream& os) const {
527+
lib::begin_serialization(os);
528+
lib::save_to_stream(data_, os);
529+
}
527530
};
528531

529532
///

0 commit comments

Comments
 (0)