Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ SLURM_PARTITIONS = [
PARAM_SEARCH = parameter_search.ParameterSearch()

# Where is a large temp directory?
LARGE_TEMP_DIR = config.get("large_temp_dir", "/data/tmp")
LARGE_TEMP_DIR = config.get("large_temp_dir", "/data/tmp/")

#Different phoenix nodes seem to run at different speeds, so we can specify which node to run on
#This gets added as a slurm_extra for all the real read runs
Expand Down Expand Up @@ -3229,7 +3229,7 @@ rule speed_from_log_giraffe_stats:
mapper="giraffe.*"
threads: 1
resources:
mem_mb=200,
mem_mb=600,
runtime=5,
slurm_partition=choose_partition(5)
shell:
Expand Down Expand Up @@ -3286,7 +3286,7 @@ rule memory_from_log_giraffe_stat:
mapper="giraffe.*"
threads: 1
resources:
mem_mb=200,
mem_mb=600,
runtime=5,
slurm_partition=choose_partition(5)
shell:
Expand Down Expand Up @@ -4841,14 +4841,14 @@ rule mapped_names:
# Because GraphAligner doesn't output unmapped reads, we need to get the lengths from the original FASTQ
# So we get all the mapped reads and then all the unmapped reads. See <https://unix.stackexchange.com/a/588652>
"""
vg filter --only-mapped -t {threads} -T "name" {input.gam} | grep -v "#" | sed 's/\/[1-2]$//g' | sort -s -k 1b,1 > {output.mapped_names}
vg filter --only-mapped -t {threads} -T "name" {input.gam} | grep -v "#" | sed 's/\/[1-2]$//g' | LC_ALL=C sort -k 1b,1 > {output.mapped_names}
"""

# tsv of "mapped"/"unmapped" and the length of the original read
# This only does the mapped portion
rule length_by_mapping_mapped:
input:
length_by_name=os.path.join(READS_DIR, "{realness}/{tech}/stats/{sample}{trimmedness}.{subset}.read_length_by_name.tsv"),
read_length_by_name=os.path.join(READS_DIR, "{realness}/{tech}/stats/{sample}{trimmedness}.{subset}.read_length_by_name.tsv"),
mapped_names="{root}/stats/{reference}/{refgraph}/{mapper}/{realness}/{tech}/{sample}{trimmedness}.{subset}.mapped_names.tsv"
output:
tsv="{root}/stats/{reference}/{refgraph}/{mapper}/{realness}/{tech}/{sample}{trimmedness}.{subset}.length_by_mapping.mapped.tsv"
Expand All @@ -4861,15 +4861,16 @@ rule length_by_mapping_mapped:
# Because GraphAligner doesn't output unmapped reads, we need to get the lengths from the original FASTQ
# So we get all the mapped reads and then all the unmapped reads. See <https://unix.stackexchange.com/a/588652>
"""
join -j 1 {input.length_by_name} {input.mapped_names} | awk -v OFS='\t' '{{print "mapped",$2}}' > {output.tsv}
export LC_ALL=C;
join -j 1 <(sort -k 1b,1 {input.read_length_by_name}) <(sort -k 1b,1 {input.mapped_names}) | awk -v OFS='\t' '{{print "mapped", $2}}' > {output.tsv}
"""


# tsv of "mapped"/"unmapped" and the length of the original read
# This only does the unmapped portion
rule length_by_mapping_unmapped:
input:
length_by_name=os.path.join(READS_DIR, "{realness}/{tech}/stats/{sample}{trimmedness}.{subset}.read_length_by_name.tsv"),
read_length_by_name=os.path.join(READS_DIR, "{realness}/{tech}/stats/{sample}{trimmedness}.{subset}.read_length_by_name.tsv"),
mapped_names="{root}/stats/{reference}/{refgraph}/{mapper}/{realness}/{tech}/{sample}{trimmedness}.{subset}.mapped_names.tsv"
output:
tsv="{root}/stats/{reference}/{refgraph}/{mapper}/{realness}/{tech}/{sample}{trimmedness}.{subset}.length_by_mapping.unmapped.tsv"
Expand All @@ -4882,7 +4883,8 @@ rule length_by_mapping_unmapped:
# Because GraphAligner doesn't output unmapped reads, we need to get the lengths from the original FASTQ
# So we get all the mapped reads and then all the unmapped reads. See <https://unix.stackexchange.com/a/588652>
"""
join -j 1 -a 1 -v 2 {input.length_by_name} {input.mapped_names} | awk -v OFS='\t' '{{print "unmapped",$2}}' >> {output.tsv}
export LC_ALL=C;
join -j 1 -a 1 -v 2 <(sort -k 1b,1 {input.read_length_by_name}) <(sort -k 1b,1 {input.mapped_names}) | awk -v OFS='\t' '{{print "unmapped",$2}}' >> {output.tsv}
"""


Expand Down Expand Up @@ -4917,7 +4919,7 @@ rule read_length_by_name:
slurm_partition=choose_partition(720)
shell:
"""
seqkit fx2tab -n -l {input.fastq} | sort -k 1b,1 | awk -v OFS='\t' '{{print $1,$NF}}' >{output.tsv}
seqkit fx2tab -n -l {input.fastq} | LC_ALL=C sort -k 1b,1 | awk -v OFS='\t' '{{print $1,$NF}}' >{output.tsv}
"""
#How many base pairs are in the read file
rule read_bases_total:
Expand Down Expand Up @@ -4989,8 +4991,9 @@ rule mapq_by_correctness:

rule softclips_by_name_gam:
input:
fastq=fastq,
gam="{root}/aligned/{reference}/{refgraph}/{mapper}/{realness}/{tech}/{sample}{trimmedness}.{subset}.gam",
length_by_name="{root}/stats/{reference}/{refgraph}/{mapper}/{realness}/{tech}/{sample}{trimmedness}.{subset}.length_by_name.tsv",
read_length_by_name=os.path.join(READS_DIR, "{realness}/{tech}/stats/{sample}{trimmedness}.{subset}.read_length_by_name.tsv")
output:
mapped_tsv=temp("{root}/stats/{reference}/{refgraph}/{mapper}/{realness}/{tech}/{sample}{trimmedness}.{subset}.mapped_softclips_by_name.tsv"),
sorted_tsv=temp("{root}/stats/{reference}/{refgraph}/{mapper}/{realness}/{tech}/{sample}{trimmedness}.{subset}.sorted_mapped_softclips_by_name.tsv"),
Expand All @@ -5008,12 +5011,12 @@ rule softclips_by_name_gam:
shell:
# We need to make sure we have 0 values for softclips for reads that
# weren't mapped, so we can compute correct average softclips. We know
# the length_by_name TSV contains all read names.
# the read_length_by_name TSV contains all read names.
"""
set -x
time vg filter -t {threads} -T \"name;softclip_start;softclip_end\" {input.gam} | grep -v \"#\" > {output.mapped_tsv}
time sort -k 1b,1 {output.mapped_tsv} | sed 's/\/[1-2]\$//g' > {output.sorted_tsv}
time join -a 1 {input.length_by_name} {output.sorted_tsv} | awk -v OFS='\t' '{{ if ($NF < 3) {{ print $1,0,0 }} else {{ print $1,$3,$4 }} }}' > {output.tsv}
time LC_ALL=C sort -k 1b,1 {output.mapped_tsv} | sed 's/\/[1-2]\$//g' > {output.sorted_tsv}
time join -a 1 {input.read_length_by_name} {output.sorted_tsv} | awk -v OFS='\t' '{{ if ($NF < 3) {{ print $1,0,0 }} else {{ print $1,$3,$4 }} }}' > {output.tsv}
"""

rule softclips_by_name_other:
Expand Down Expand Up @@ -5068,7 +5071,6 @@ ruleorder: softclips_by_name_gam > softclips_by_name_other
# read name and the number of bases that didn't get put in the final alignment.
rule hardclips_by_name_gam:
input:
fastq=fastq,
gam="{root}/aligned/{reference}/{refgraph}/{mapper}/{realness}/{tech}/{sample}{trimmedness}.{subset}.gam",
read_length_by_name=os.path.join(READS_DIR, "{realness}/{tech}/stats/{sample}{trimmedness}.{subset}.read_length_by_name.tsv")
output:
Expand All @@ -5093,7 +5095,7 @@ rule hardclips_by_name_gam:
# length_by_mapping above).
# We use -a 2 on the join to exclude fully-unmapped reads.
"""
vg filter -t {threads} -T "name;length" {input.gam} | grep -v "#" | sort -k 1b,1 | sed 's/\/[1-2]\$//g' > {output.mapped_length_by_name}
vg filter -t {threads} -T "name;length" {input.gam} | grep -v "#" | LC_ALL=C sort -k 1b,1 | sed 's/\/[1-2]\$//g' > {output.mapped_length_by_name}
join -a 2 {input.read_length_by_name} {output.mapped_length_by_name} | awk -v OFS='\t' '{{print $1,$2-$3}}' > {output.tsv}
"""

Expand Down
63 changes: 57 additions & 6 deletions parameter_search.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3

import numpy as np
import scipy.stats as stats
from os.path import exists
from bidict import bidict
import argparse
Expand All @@ -17,13 +18,17 @@
value_range is a tuple of [range_start, range_end)
'''
class Parameter:
def __init__(self, name, datatype, min_val, max_val, default, sampling_strategy):
def __init__(self, name, datatype, min_val, max_val, default, sampling_strategy, mean):
self.name = name
self.datatype = datatype.lower()
self.min_val = min_val
self.max_val = max_val
self.default = default
self.sampling_strategy = sampling_strategy.lower()
if mean=="none":
self.mean = default
else:
self.mean = mean

def sample(self):
if self.min_val == self.max_val:
Expand All @@ -34,6 +39,15 @@ def sample(self):
elif self.sampling_strategy == "log":
log_sample = np.random.uniform(np.log(self.min_val), np.log(self.max_val))
return int(np.exp(log_sample))
elif self.sampling_strategy == "lognormal":
ln_sample = np.random.lognormal(np.log(self.mean), 1)
return int(ln_sample)
elif self.sampling_strategy == "truncated_normal":
mu = self.mean
sigma = 1
truncated_normal = stats.truncnorm((self.min_val - mu) / sigma, (self.max_val - mu) / sigma, loc=mu, scale=sigma)
tn_sample = truncated_normal.rvs(size=1)
return int(tn_sample)
else:
print("No sampling strategy " + self.sampling_strategy + " for type " + self.datatype)
elif self.datatype=="float":
Expand All @@ -43,6 +57,15 @@ def sample(self):
elif self.sampling_strategy == "log":
log_sample = np.random.uniform(np.log(self.min_val), np.log(self.max_val))
return np.exp(log_sample)
elif self.sampling_strategy == "lognormal":
ln_sample = np.random.lognormal(np.log(self.mean), 1)
return round(ln_sample, decimal_places)
elif self.sampling_strategy == "truncated_normal":
mu = self.mean
sigma = 1
truncated_normal = stats.truncnorm((self.min_val - mu) / sigma, (self.max_val - mu) / sigma, loc=mu, scale=sigma)
tn_sample = truncated_normal.rvs(size=1)
return round(tn_sample, decimal_places)
else:
print("No sampling strategy " + self.sampling_strategy + " for type " + self.datatype)

Expand All @@ -56,12 +79,14 @@ def __str__(self):
ParameterSearch is used to store information about a set of parameters.
Define the parameters to be searched in the parameter config file (default is CONFIG_FILE)
This must be a tsv with values:
#name type min_val max_val default sampling_strategy
#name type min_val max_val default sampling_strategy mean(OPTIONAL)
Where name is the name of the flag that giraffe uses
type is the data type (int or float)
min and max val are the range of values that the parameter can take
default is the default value from giraffe. This is used to unify old runs missing parameters
sampling_strategy is how we sample the values from the range ("uniform", "log").
You can also sample from the lognormal or truncated normal distributions ("lognormal", "truncated_normal").
If desired, you can add an optional 7th column to the tsv, titled mean, for use with the normal-distribution strategies.

Randomly sample the parameter space with sample_parameter_space(), giving it the number of sets to return.
This will write the sampled parameters to hash_to_parameters_file
Expand All @@ -88,11 +113,21 @@ def __init__(self, config=CONFIG_FILE, hash_to_parameters_file = HASH_TO_PARAMET
for line in f:
if line[0] != "#":
l = line.split()
if len(l) == 7:
if l[6] != "none":
mean_value = int(l[6]) if l[1] == "int" else float(l[6])
else:
mean_value = "none"
else:
mean_value = "none"

self.parameters.append(Parameter(l[0], l[1],
int(l[2]) if l[1] == "int" else float(l[2]),
int(l[3]) if l[1] == "int" else float(l[3]),
int(l[4]) if l[1] == "int" else float(l[4]),
l[5]) )
l[5],
mean_value),
)
f.close()

#This maps a hash string to the set of parameters it represents, as a list of parameter values,
Expand Down Expand Up @@ -181,8 +216,23 @@ def hash_to_parameter_string(self, hash_val):


#Sample the parameter space and write the new parameters to HASH_TO_PARAMETERS
def sample_parameter_space(self, count):
def sample_parameter_space(self, count, benchmark_default, benchmark_mean):
f = open(self.hash_to_parameters_file, "a")

if benchmark_default:
#add benchmark of default values
benchmark_tuple = tuple(param.default for param in self.parameters)
hash_val = self.parameter_tuple_to_hash(benchmark_tuple)
self.hash_to_parameters[hash_val] = benchmark_tuple
f.write("\n" + hash_val + "\t" + '\t'.join([str(x) for x in benchmark_tuple]))

if benchmark_mean:
#add benchmark of mean values
mean_tuple = tuple(param.mean for param in self.parameters)
hash_val = self.parameter_tuple_to_hash(mean_tuple)
self.hash_to_parameters[hash_val] = mean_tuple
f.write("\n" + hash_val + "\t" + '\t'.join([str(x) for x in mean_tuple]))

for i in range(count):
parameter_tuple = tuple([param.sample() for param in self.parameters])
hash_val = self.parameter_tuple_to_hash(parameter_tuple)
Expand All @@ -201,12 +251,13 @@ def main():
parser.add_argument('--config_file', default=CONFIG_FILE, help="Config file for which parameters to sample and how")
parser.add_argument('--output_file', default=HASH_TO_PARAMETERS_FILE, help="File holding the parameter sets to search and their identifying hash value")
parser.add_argument('--count', type=int, default=1000, help="How many parameters sets to sample [1000]")
parser.add_argument('--benchmark_default', type=bool, default=False, help="Whether or not to additionally run a benchmark of default parameters")
parser.add_argument('--benchmark_mean', type=bool, default=False, help="Whether or not to additionally run a benchmark of the mean parameters")

args = parser.parse_args()

param_search = ParameterSearch(args.config_file, args.output_file)
param_search.sample_parameter_space(args.count)

param_search.sample_parameter_space(args.count, args.benchmark_default, args.benchmark_mean)

if __name__ == "__main__":
main()