Skip to content

Commit 2d44df4

Browse files
Add filter program
1 parent 6d51075 commit 2d44df4

File tree

9 files changed

+277
-1
lines changed

9 files changed

+277
-1
lines changed

build_helper.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
#!/bin/sh
22

33
cd native
4-
bazel build -c opt _meds_reader.so meds_reader_convert
4+
bazel build -c opt _meds_reader.so meds_reader_convert meds_reader_filter
55
cd ..
66

77
rm -f src/meds_reader/_meds_reader* src/meds_reader/meds_reader_convert*
88
cp native/bazel-bin/_meds_reader.so src/meds_reader/_meds_reader.so
99
cp native/bazel-bin/meds_reader_convert src/meds_reader/meds_reader_convert
10+
cp native/bazel-bin/meds_reader_filter src/meds_reader/meds_reader_filter

native/BUILD

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,16 @@ cc_library(
2626
],
2727
)
2828

29+
cc_library(
30+
name="filter_database",
31+
srcs=["filter_database.cc", "pdqsort.h"],
32+
hdrs=["filter_database.hh"],
33+
deps=[
34+
":mmap_file",
35+
],
36+
)
37+
38+
2939
cc_library(
3040
name="binary_version",
3141
hdrs=["binary_version.hh"],
@@ -84,6 +94,15 @@ cc_binary(
8494
],
8595
)
8696

97+
cc_binary(
98+
name="meds_reader_filter",
99+
srcs=["meds_reader_filter.cc"],
100+
deps=[
101+
":filter_database",
102+
"@CLI11//:cli11",
103+
],
104+
)
105+
87106
pybind_extension(
88107
name="_meds_reader",
89108
srcs=[

native/filter_database.cc

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
#include "filter_database.hh"
2+
3+
#include <fcntl.h>
4+
#include <sys/mman.h>
5+
#include <sys/stat.h>
6+
#include <unistd.h>
7+
8+
#include <cstring>
9+
#include <filesystem>
10+
#include <fstream>
11+
#include <iostream>
12+
#include <string>
13+
#include <thread>
14+
15+
#include "mmap_file.hh"
16+
#include "pdqsort.h"
17+
18+
namespace {
19+
20+
void copy_subset(const std::filesystem::path& source_path,
21+
const std::filesystem::path& destination_path,
22+
const std::vector<size_t>& offsets) {
23+
MmapFile source(source_path);
24+
25+
absl::Span<const uint64_t> byte_offsets = source.data<uint64_t>();
26+
27+
std::string_view data = source.bytes();
28+
29+
std::ofstream destination(destination_path, std::ios_base::out |
30+
std::ios_base::binary |
31+
std::ios_base::trunc);
32+
33+
std::vector<const char*> result_pointers;
34+
result_pointers.reserve(offsets.size());
35+
36+
std::vector<uint64_t> result_byte_offsets;
37+
result_byte_offsets.reserve(offsets.size() + 1);
38+
39+
uint64_t current_offset = sizeof(uint64_t) * (offsets.size() + 1);
40+
41+
for (size_t offset : offsets) {
42+
uint64_t start = byte_offsets[offset];
43+
uint64_t end = byte_offsets[offset + 1];
44+
45+
result_byte_offsets.push_back(current_offset);
46+
result_pointers.push_back(data.data() + start);
47+
48+
current_offset += (end - start);
49+
}
50+
51+
result_byte_offsets.push_back(current_offset);
52+
53+
destination.write((const char*)result_byte_offsets.data(),
54+
sizeof(uint64_t) * result_byte_offsets.size());
55+
56+
for (size_t i = 0; i < result_pointers.size(); i++) {
57+
uint64_t length = result_byte_offsets[i + 1] - result_byte_offsets[i];
58+
destination.write(result_pointers[i], length);
59+
}
60+
}
61+
62+
void filter_database_property(const std::filesystem::path& source_path,
63+
const std::filesystem::path& destination_path,
64+
const std::vector<size_t>& offsets,
65+
const std::string& property_name) {
66+
std::filesystem::path source_property_path = source_path / property_name;
67+
std::filesystem::path destination_property_path =
68+
destination_path / property_name;
69+
70+
std::filesystem::create_directory(destination_property_path);
71+
72+
std::filesystem::path source_dictionary_path =
73+
source_property_path / "dictionary";
74+
75+
if (std::filesystem::exists(source_dictionary_path)) {
76+
std::filesystem::copy(source_dictionary_path,
77+
destination_property_path / "dictionary");
78+
}
79+
80+
std::filesystem::copy(source_property_path / "zdict",
81+
destination_property_path / "zdict");
82+
83+
copy_subset(source_property_path / "data",
84+
destination_property_path / "data", offsets);
85+
}
86+
87+
} // namespace
88+
89+
void filter_database(const char* source, const char* destination,
90+
const char* subject_ids_file, int num_threads) {
91+
std::filesystem::path source_path(source);
92+
std::filesystem::path destination_path(destination);
93+
94+
MmapFile subject_ids_data{std::string(subject_ids_file)};
95+
96+
absl::Span<const int64_t> unsorted_subject_ids =
97+
subject_ids_data.data<int64_t>();
98+
99+
std::vector<int64_t> sorted_subject_ids(std::begin(unsorted_subject_ids),
100+
std::end(unsorted_subject_ids));
101+
102+
pdqsort(std::begin(sorted_subject_ids), std::end(sorted_subject_ids));
103+
104+
absl::Span<const int64_t> subject_ids(sorted_subject_ids.data(),
105+
sorted_subject_ids.size());
106+
107+
std::filesystem::create_directory(destination_path);
108+
109+
std::filesystem::copy(source_path / "metadata",
110+
destination_path / "metadata");
111+
112+
{
113+
std::ofstream subject_ids_file(
114+
destination_path / "subject_id",
115+
std::ios_base::out | std::ios_base::binary | std::ios_base::trunc);
116+
117+
subject_ids_file.write((const char*)subject_ids.data(),
118+
sizeof(int64_t) * subject_ids.size());
119+
}
120+
121+
std::vector<size_t> offsets;
122+
offsets.reserve(subject_ids.size());
123+
124+
{
125+
MmapFile source_subject_ids_file(source_path / "subject_id");
126+
absl::Span<const int64_t> source_subject_ids =
127+
source_subject_ids_file.data<int64_t>();
128+
129+
auto first =
130+
std::lower_bound(std::begin(source_subject_ids),
131+
std::end(source_subject_ids), subject_ids.front());
132+
auto last =
133+
std::upper_bound(std::begin(source_subject_ids),
134+
std::end(source_subject_ids), subject_ids.back());
135+
136+
for (int64_t subject_id : subject_ids) {
137+
auto iter = std::lower_bound(first, last, subject_id);
138+
if (*iter != subject_id) {
139+
throw std::runtime_error(
140+
std::string("Could not find subject_id ") +
141+
std::to_string(subject_id) + " in database " +
142+
std::to_string(*iter));
143+
}
144+
145+
offsets.push_back(iter - std::begin(source_subject_ids));
146+
147+
first = ++iter;
148+
}
149+
}
150+
151+
{
152+
MmapFile source_subject_lengths_file(source_path /
153+
"meds_reader.length");
154+
absl::Span<const uint32_t> source_subject_lengths =
155+
source_subject_lengths_file.data<uint32_t>();
156+
std::vector<uint32_t> subject_lengths;
157+
subject_lengths.reserve(subject_ids.size());
158+
159+
for (size_t offset : offsets) {
160+
subject_lengths.push_back(source_subject_lengths[offset]);
161+
}
162+
163+
std::ofstream subject_lengths_file(
164+
destination_path / "meds_reader.length",
165+
std::ios_base::out | std::ios_base::binary | std::ios_base::trunc);
166+
167+
subject_lengths_file.write((const char*)subject_lengths.data(),
168+
sizeof(uint32_t) * subject_lengths.size());
169+
}
170+
171+
std::vector<std::string> properties;
172+
173+
{
174+
MmapFile property_file(source_path / "meds_reader.properties");
175+
176+
const char* data = property_file.bytes().data();
177+
const char* end =
178+
property_file.bytes().data() + property_file.bytes().size();
179+
180+
while (data != end) {
181+
size_t name_length = *(const size_t*)data;
182+
data += sizeof(size_t);
183+
properties.push_back(std::string(data, name_length));
184+
data += name_length;
185+
data += sizeof(int64_t);
186+
187+
std::cout << "Got property " << properties.back() << std::endl;
188+
}
189+
}
190+
191+
properties.push_back("meds_reader.null_map");
192+
193+
for (const auto& property : properties) {
194+
filter_database_property(source_path, destination_path, offsets,
195+
property);
196+
}
197+
198+
std::filesystem::copy(source_path / "meds_reader.properties",
199+
destination_path / "meds_reader.properties");
200+
201+
std::filesystem::copy(source_path / "meds_reader.version",
202+
destination_path / "meds_reader.version");
203+
}

native/filter_database.hh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#pragma once
2+
3+
void filter_database(const char* source, const char* destination,
4+
const char* subject_ids_file, int num_threads);

native/meds_reader_filter.cc

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#include <CLI/CLI.hpp>
2+
#include <iostream>
3+
4+
#include "filter_database.hh"
5+
6+
int main(int argc, char** argv) {
7+
CLI::App app{
8+
"meds_reader_filter is a program for converting a MEDS dataset to a "
9+
"meds_reader SubjectDatabase.",
10+
"meds_reader_convert"};
11+
argv = app.ensure_utf8(argv);
12+
13+
std::string source_dataset, destination_database, subject_ids_file;
14+
app.add_option("source_dataset", source_dataset,
15+
"A path to the source MEDS dataset")
16+
->required();
17+
app.add_option(
18+
"destination_database", destination_database,
19+
"A path of where to write the resulting meds_reader database.")
20+
->required();
21+
22+
app.add_option(
23+
"subject_ids_file", subject_ids_file,
24+
"A path of where to write the resulting meds_reader database.")
25+
->required();
26+
27+
int num_threads = 1;
28+
app.add_option("--num_threads", num_threads,
29+
"The number of threads to use when processing")
30+
->capture_default_str();
31+
32+
CLI11_PARSE(app, argc, argv);
33+
34+
filter_database(source_dataset.c_str(), destination_database.c_str(),
35+
subject_ids_file.c_str(), num_threads);
36+
37+
return 0;
38+
}

native/mmap_file.hh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <sys/stat.h>
66
#include <unistd.h>
77

8+
#include <cstring>
89
#include <filesystem>
910
#include <string_view>
1011

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ Tracker = "https://github.com/som-shahlab/meds_reader/issues"
2626

2727
[project.scripts]
2828
meds_reader_convert = "meds_reader:meds_reader_convert"
29+
meds_reader_filter = "meds_reader:meds_reader_filter"
2930
meds_reader_verify = "meds_reader:meds_reader_verify"
3031

3132
[tool.isort]

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def build_extensions(self) -> None:
119119
ext_modules=[
120120
BazelExtension("meds_reader._meds_reader", "_meds_reader.so", "native"),
121121
BazelExtension("meds_reader.meds_reader_convert", "meds_reader_convert", "native"),
122+
BazelExtension("meds_reader.meds_reader_filter", "meds_reader_filter", "native"),
122123
],
123124
cmdclass={"build_ext": cmake_build_ext},
124125
zip_safe=False,

src/meds_reader/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,14 @@ def meds_reader_convert():
105105
os.execv(executible, sys.argv)
106106

107107

108+
def meds_reader_filter():
109+
submodules = importlib.resources.files("meds_reader")
110+
for module in submodules.iterdir():
111+
if module.name.startswith("meds_reader_filter"):
112+
with importlib.resources.as_file(module) as executible:
113+
os.execv(executible, sys.argv)
114+
115+
108116
def _row_generator(database: _meds_reader.SubjectDatabase, data: pd.DataFrame):
109117
current_index = None
110118
current_rows: List[Any] = []

0 commit comments

Comments
 (0)