Skip to content

Commit a770cc7

Browse files
Add TrainGCNV input specifying subset list of samples for training (#294)
1 parent ab8a855 commit a770cc7

File tree

2 files changed

+76
-8
lines changed

2 files changed

+76
-8
lines changed

wdl/TrainGCNV.wdl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,12 @@ workflow TrainGCNV {
1818
File reference_index # Index (.fai), must be in same dir as fasta
1919
File reference_dict # Dictionary (.dict), must be in same dir as fasta
2020
21+
# Options for subsetting samples for training. Both options require providing sv_pipeline_base_docker
22+
# Assumes all other inputs correspond to the full sample list. Intended for Terra
2123
Int? n_samples_subsample # Number of samples to subsample from provided sample list for trainGCNV (rec: ~100)
2224
Int subsample_seed = 42
25+
# Subset of full sample list on which to train the gCNV model. Overrides n_samples_subsample if both provided
26+
Array[String]? sample_ids_training_subset
2327

2428
# Condense read counts
2529
Int? condense_num_bins
@@ -85,7 +89,7 @@ workflow TrainGCNV {
8589
String linux_docker
8690
String gatk_docker
8791
String condense_counts_docker
88-
String? sv_pipeline_base_docker # required if using n_samples_subsample to select samples
92+
String? sv_pipeline_base_docker # required if using n_samples_subsample or sample_ids_training_subset to subset samples
8993
9094
# Runtime configuration overrides
9195
RuntimeAttr? condense_counts_runtime_attr
@@ -100,20 +104,31 @@ workflow TrainGCNV {
100104
RuntimeAttr? runtime_attr_explode
101105
}
102106

103-
if (defined(n_samples_subsample)) {
107+
if (defined(sample_ids_training_subset)) {
108+
call util.GetSubsampledIndices {
109+
input:
110+
all_strings = write_lines(samples),
111+
subset_strings = write_lines(select_first([sample_ids_training_subset])),
112+
prefix = cohort,
113+
sv_pipeline_base_docker = select_first([sv_pipeline_base_docker])
114+
}
115+
}
116+
117+
if (defined(n_samples_subsample) && !defined(sample_ids_training_subset)) {
104118
call util.RandomSubsampleStringArray {
105119
input:
106-
strings = samples,
120+
strings = write_lines(samples),
107121
seed = subsample_seed,
108122
subset_size = select_first([n_samples_subsample]),
109123
prefix = cohort,
110124
sv_pipeline_base_docker = select_first([sv_pipeline_base_docker])
111125
}
112126
}
113127
114-
Array[Int] sample_indices = select_first([RandomSubsampleStringArray.subsample_indices_array, range(length(samples))])
128+
Array[Int] sample_indices = select_first([GetSubsampledIndices.subsample_indices_array, RandomSubsampleStringArray.subsample_indices_array, range(length(samples))])
115129
116130
scatter (i in sample_indices) {
131+
String sample_ids_ = samples[i]
117132
call cov.CondenseReadCounts as CondenseReadCounts {
118133
input:
119134
counts = count_files[i],
@@ -138,7 +153,7 @@ workflow TrainGCNV {
138153
preprocessed_intervals = CountsToIntervals.out,
139154
filter_intervals = filter_intervals,
140155
counts = CondenseReadCounts.out,
141-
count_entity_ids = select_first([RandomSubsampleStringArray.subsampled_strings_array, samples]),
156+
count_entity_ids = sample_ids_,
142157
cohort_entity_id = cohort,
143158
contig_ploidy_priors = contig_ploidy_priors,
144159
num_intervals_per_scatter = num_intervals_per_scatter,

wdl/Utils.wdl

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ task RunQC {
159159

160160
task RandomSubsampleStringArray {
161161
input {
162-
Array[String] strings
162+
File strings
163163
Int seed
164164
Int subset_size
165165
String prefix
@@ -172,7 +172,7 @@ task RandomSubsampleStringArray {
172172

173173
RuntimeAttr default_attr = object {
174174
cpu_cores: 1,
175-
mem_gb: 3.75,
175+
mem_gb: 1,
176176
disk_gb: 10,
177177
boot_disk_gb: 10,
178178
preemptible_tries: 3,
@@ -185,7 +185,7 @@ task RandomSubsampleStringArray {
185185
set -euo pipefail
186186
python3 <<CODE
187187
import random
188-
string_array = ['~{sep="','" strings}']
188+
string_array = [line.rstrip() for line in open("~{strings}", 'r')]
189189
array_len = len(string_array)
190190
if ~{subset_size} > array_len:
191191
raise ValueError("Subsample quantity ~{subset_size} cannot > array length %d" % array_len)
@@ -218,6 +218,59 @@ task RandomSubsampleStringArray {
218218
}
219219
}
220220

221+
task GetSubsampledIndices {
222+
input {
223+
File all_strings
224+
File subset_strings
225+
String prefix
226+
String sv_pipeline_base_docker
227+
RuntimeAttr? runtime_attr_override
228+
}
229+
230+
String subsample_indices_filename = "~{prefix}.subsample_indices.list"
231+
232+
RuntimeAttr default_attr = object {
233+
cpu_cores: 1,
234+
mem_gb: 1,
235+
disk_gb: 10,
236+
boot_disk_gb: 10,
237+
preemptible_tries: 3,
238+
max_retries: 1
239+
}
240+
RuntimeAttr runtime_attr = select_first([runtime_attr_override, default_attr])
241+
242+
command <<<
243+
244+
set -euo pipefail
245+
python3 <<CODE
246+
all_strings = [line.rstrip() for line in open("~{all_strings}", 'r')]
247+
subset_strings = {line.rstrip() for line in open("~{subset_strings}", 'r')}
248+
if not subset_strings.issubset(set(all_strings)):
249+
raise ValueError("Subset list must be a subset of full list")
250+
with open("~{subsample_indices_filename}", 'w') as indices:
251+
for i, string in enumerate(all_strings):
252+
if string in subset_strings:
253+
indices.write(f"{i}\n")
254+
CODE
255+
256+
>>>
257+
258+
output {
259+
Array[Int] subsample_indices_array = read_lines(subsample_indices_filename)
260+
}
261+
262+
runtime {
263+
cpu: select_first([runtime_attr.cpu_cores, default_attr.cpu_cores])
264+
memory: select_first([runtime_attr.mem_gb, default_attr.mem_gb]) + " GiB"
265+
disks: "local-disk " + select_first([runtime_attr.disk_gb, default_attr.disk_gb]) + " HDD"
266+
bootDiskSizeGb: select_first([runtime_attr.boot_disk_gb, default_attr.boot_disk_gb])
267+
docker: sv_pipeline_base_docker
268+
preemptible: select_first([runtime_attr.preemptible_tries, default_attr.preemptible_tries])
269+
maxRetries: select_first([runtime_attr.max_retries, default_attr.max_retries])
270+
}
271+
}
272+
273+
221274
task SubsetPedFile {
222275
input {
223276
File ped_file

0 commit comments

Comments
 (0)