Skip to content

Commit 546fad2

Browse files
committed
Randomize barcode downsampling
1 parent 6258dbf commit 546fad2

4 files changed

Lines changed: 32 additions & 29 deletions

File tree

src/common/barcode_index/barcode_index_builder.hpp

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include "alignment/sequence_mapper_notifier.hpp"
1515
#include "alignment/sequence_mapper.hpp"
1616

17+
#include <algorithm>
18+
#include <cmath>
1719
#include <random>
1820
#include <string>
1921
#include <unordered_set>
@@ -199,31 +201,29 @@ class FrameBarcodeIndexBuilder {
199201
template<class ReadType>
200202
void ConstructBarcodeIndex(io::ReadStreamList<ReadType> read_streams,
201203
FrameBarcodeIndex<Graph> &barcode_index,
202-
const io::SequencingLibraryBase &lib,
203204
bool is_tellseq);
204205

205-
void DownsampleBarcodeIndex(FrameBarcodeIndex<Graph> &downsampled_index, FrameBarcodeIndex<Graph> &original_index, double sampling_factor) {
206+
void DownsampleBarcodeIndex(FrameBarcodeIndex<Graph> &downsampled_index,
207+
FrameBarcodeIndex<Graph> &original_index,
208+
double sampling_factor,
209+
int seed) {
206210
std::unordered_set<BarcodeId> barcodes;
207-
std::unordered_set<BarcodeId> passed_barcodes;
208-
BarcodeId min_barcode = std::numeric_limits<BarcodeId>::max();
209-
BarcodeId max_barcode = std::numeric_limits<BarcodeId>::min();
210211
for (auto it = original_index.begin(); it != original_index.end(); ++it) {
211212
const auto &barcode_distribution = it->second.GetDistribution();
212213
for (const auto &entry: barcode_distribution) {
213-
BarcodeId current_barcode = entry.first;
214-
barcodes.insert(current_barcode);
215-
min_barcode = std::min(min_barcode, current_barcode);
216-
max_barcode = std::max(max_barcode, current_barcode);
214+
barcodes.insert(entry.first);
217215
}
218216
}
219217
INFO("Number of encountered barcodes: " << barcodes.size());
220-
INFO("Barcode id range: " << min_barcode << ", " << max_barcode);
221-
double barcode_thr = static_cast<double>(max_barcode - min_barcode) * sampling_factor;
222-
for (const auto &barcode: barcodes) {
223-
if (math::le(static_cast<double>(barcode - min_barcode), barcode_thr)) {
224-
passed_barcodes.insert(barcode);
225-
}
226-
}
218+
size_t target = static_cast<size_t>(std::round(static_cast<double>(barcodes.size()) * sampling_factor));
219+
std::unordered_set<BarcodeId> passed_barcodes;
220+
passed_barcodes.reserve(target);
221+
std::mt19937 rng(seed);
222+
std::sample(barcodes.begin(),
223+
barcodes.end(),
224+
std::inserter(passed_barcodes, passed_barcodes.end()),
225+
target,
226+
rng);
227227
INFO("Passed barcodes: " << passed_barcodes.size());
228228

229229
downsampled_index.InitialFillMap();
@@ -248,7 +248,6 @@ class FrameBarcodeIndexBuilder {
248248
template<class ReadType>
249249
void FrameBarcodeIndexBuilder::ConstructBarcodeIndex(io::ReadStreamList<ReadType> read_streams,
250250
FrameBarcodeIndex<Graph> &barcode_index,
251-
const io::SequencingLibraryBase &lib,
252251
bool is_tellseq) {
253252
{
254253
size_t starting_barcode = 0;

src/projects/splitter/barcode_index_construction.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ void ConstructBarcodeIndex(barcode_index::FrameBarcodeIndex<debruijn_graph::Grap
3535
FrameBarcodeIndexBuilder barcode_index_builder(graph, mapper, barcode_prefices, frame_size, nthreads);
3636
bool is_tellseq = lib.type() == io::LibraryType::TellSeqReads;
3737
if (!is_tellseq) {
38-
barcode_index_builder.ConstructBarcodeIndex(io::paired_easy_readers(lib, false, 0), barcode_index, lib, is_tellseq);
38+
barcode_index_builder.ConstructBarcodeIndex(io::paired_easy_readers(lib, false, 0), barcode_index, is_tellseq);
3939
}
40-
if (is_tellseq) {
40+
else {
4141
INFO("Constructing from tellseq lib");
42-
barcode_index_builder.ConstructBarcodeIndex(io::tellseq_easy_readers(lib, false, 0), barcode_index, lib, is_tellseq);
42+
barcode_index_builder.ConstructBarcodeIndex(io::tellseq_easy_readers(lib, false, 0), barcode_index, is_tellseq);
4343
}
4444
INFO("Barcode index construction finished.");
4545

@@ -69,14 +69,15 @@ void DownsampleBarcodeIndex(const debruijn_graph::Graph &graph,
6969
unsigned nthreads,
7070
barcode_index::FrameBarcodeIndex<debruijn_graph::Graph> &barcode_index,
7171
barcode_index::FrameBarcodeIndex<debruijn_graph::Graph> &downsampled_index,
72-
double sampling_fraction) {
72+
double sampling_fraction,
73+
int seed) {
7374
VERIFY_MSG(math::ls(sampling_fraction, 1.0), "Sampling fraction must be less than 1");
7475
const size_t mapping_k = 31;
7576
const std::vector<std::string> barcode_prefices = {"BC:Z:", "BX:Z:"};
7677
debruijn_graph::Graph empty_graph(mapping_k);
7778
alignment::BWAReadMapper<debruijn_graph::Graph> mapper(empty_graph);
7879
FrameBarcodeIndexBuilder barcode_index_builder(graph, mapper, barcode_prefices, barcode_index.GetFrameSize(), nthreads);
79-
barcode_index_builder.DownsampleBarcodeIndex(downsampled_index, barcode_index, sampling_fraction);
80+
barcode_index_builder.DownsampleBarcodeIndex(downsampled_index, barcode_index, sampling_fraction, seed);
8081
}
8182

8283
}

src/projects/splitter/barcode_index_construction.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,6 @@ void DownsampleBarcodeIndex(const debruijn_graph::Graph &graph,
3636
unsigned nthreads,
3737
barcode_index::FrameBarcodeIndex<debruijn_graph::Graph> &barcode_index,
3838
barcode_index::FrameBarcodeIndex<debruijn_graph::Graph> &downsampled_index,
39-
double sampling_factor);
39+
double sampling_factor,
40+
int seed);
4041
}

src/projects/splitter/main.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ struct gcfg {
5757
size_t frame_size = 40000;
5858
size_t read_linkage_distance = 40000;
5959
double sampling_factor = 1.0;
60+
int seed = 999;
6061

6162
//graph construction
6263
double graph_score_threshold = 2.0;
@@ -89,6 +90,7 @@ static void process_cmdline(int argc, char** argv, gcfg& cfg) {
8990
(option("-l") & integer("value", cfg.libindex)) % "library index (0-based, default: 0)",
9091
(option("-t") & integer("value", cfg.nthreads)) % "# of threads to use",
9192
(option("--sampling-factor") & value("sampling-factor", cfg.sampling_factor)) % "Sampling factor for read downsampling",
93+
(option("--seed") & value("seed", cfg.seed)) % "Seed for barcode downsampling",
9294
(option("--assembly-info") & value("assembly-info", assembly_info))
9395
% "Path to metaflye assembly_info.txt file (meta mode, metaFlye graphs only)",
9496
(with_prefix("-G",
@@ -162,11 +164,11 @@ struct TimeTracerRAII {
162164
gfa::GFAReader ReadGraph(const gcfg &cfg,
163165
debruijn_graph::Graph &graph,
164166
io::IdMapper<std::string> *id_mapper) {
165-
gfa::GFAReader gfa(cfg.graph);
166-
gfa.to_graph(graph, id_mapper);
167-
INFO("GFA segments: " << gfa.num_edges() << ", links: " << gfa.num_links() << ", paths: "
168-
<< gfa.num_paths());
169-
return gfa;
167+
gfa::GFAReader gfa(cfg.graph);
168+
gfa.to_graph(graph, id_mapper);
169+
INFO("GFA segments: " << gfa.num_edges() << ", links: " << gfa.num_links() << ", paths: "
170+
<< gfa.num_paths());
171+
return gfa;
170172
}
171173

172174
std::unordered_set<debruijn_graph::EdgeId> ParseRepetitiveEdges(const debruijn_graph::Graph &graph,
@@ -401,7 +403,7 @@ int main(int argc, char** argv) {
401403
if (not math::eq(cfg.sampling_factor, 1.0)) {
402404
INFO("Downsampling the barcode index with factor " << cfg.sampling_factor);
403405
cont_index::DownsampleBarcodeIndex(graph, cfg.nthreads, barcode_index, downsampled_index,
404-
cfg.sampling_factor);
406+
cfg.sampling_factor, cfg.seed);
405407
barcode_extractor_ptr = std::make_shared<BarcodeExtractor>(downsampled_index, graph);
406408
}
407409

0 commit comments

Comments
 (0)