Skip to content

Commit 40bcee8

Browse files
committed
Move input handling to separate merge_inputs.smk
1 parent 4060723 commit 40bcee8

2 files changed

Lines changed: 144 additions & 142 deletions

File tree

rules/main.smk

Lines changed: 2 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -33,148 +33,8 @@ rule test_target:
3333
class InvalidConfigError(Exception):
3434
pass
3535

36-
# ------------- helper functions to collect, merge & download input files ------------------- #
37-
38-
def _parse_config_input(input):
39-
"""
40-
Parses information from an individual config-defined input, i.e. an element within `config.inputs` or `config.additional_inputs`
41-
and returns information snakemake rules can use to obtain the underlying data.
42-
43-
The structure of `input` is a dictionary with keys:
44-
- name:string (required)
45-
- metadata:string (optional) - a s3 URI or a local file path
46-
- sequences:string|dict[string,string] (optional) - either a s3 URI or a local file path, in which case
47-
it must include a '{segment}' wildcard substring, or a dict of segment → s3 URI or local file path,
48-
in which case it must not include the wildcard substring.
49-
50-
Returns a dictionary with optional keys:
51-
- metadata:string - the relative path to the metadata file. If the original data was remote then this represents
52-
the output of a rule which downloads the file
53-
- metadata_location:string - the URI for the remote file if applicable else `None`
54-
- sequences:function. Takes in wildcards and returns the relative path to the sequences FASTA for the provided
55-
segment wildcard, or returns `None` if this input doesn't define sequences for the provided segment.
56-
- sequences_location:function. Takes in wildcards and returns the URI for the remote file, or `None`, where applicable.
57-
58-
Raises InvalidConfigError
59-
"""
60-
name = input['name']
61-
lambda_none = lambda w: None
62-
63-
info = {'metadata': None, 'metadata_location': None, 'sequences': lambda_none, 'sequences_location': lambda_none}
64-
65-
def _source(uri, *, s3, local):
66-
if uri.startswith('s3://'):
67-
return s3
68-
elif uri.lower().startswith('http://') or uri.lower().startswith('https://'):
69-
raise InvalidConfigError("Workflow cannot yet handle HTTP[S] inputs")
70-
return local
71-
72-
if location:=input.get('metadata', False):
73-
info['metadata'] = _source(location, s3=f"data/{name}/metadata.tsv", local=location)
74-
info['metadata_location'] = _source(location, s3=location, local=None)
75-
76-
if location:=input.get('sequences', False):
77-
if isinstance(location, dict):
78-
info['sequences'] = lambda w: _source(location[w.segment], s3=f"data/{name}/sequences_{w.segment}.fasta", local=location[w.segment]) \
79-
if w.segment in location \
80-
else None
81-
info['sequences_location'] = lambda w: _source(location[w.segment], s3=location[w.segment], local=None) \
82-
if w.segment in location \
83-
else None
84-
elif isinstance(location, str):
85-
info['sequences'] = _source(location, s3=lambda w: f"data/{name}/sequences_{w.segment}.fasta", local=lambda w: location.format(segment=w.segment))
86-
info['sequences_location'] = _source(location, s3=lambda w: location.format(segment=w.segment), local=lambda_none)
87-
else:
88-
raise InvalidConfigError(f"Config input for {name} specifies sequences in an unknown format; must be dict or string")
89-
90-
return info
91-
92-
93-
def _gather_inputs():
94-
all_inputs = [*config['inputs'], *config.get('additional_inputs', [])]
95-
96-
if len(all_inputs)==0:
97-
raise InvalidConfigError("Config must define at least one element in config.inputs or config.additional_inputs lists")
98-
if not all([isinstance(i, dict) for i in all_inputs]):
99-
raise InvalidConfigError("All of the elements in config.inputs and config.additional_inputs lists must be dictionaries"
100-
"If you've used a command line '--config' double check your quoting.")
101-
if len({i['name'] for i in all_inputs})!=len(all_inputs):
102-
raise InvalidConfigError("Names of inputs (config.inputs and config.additional_inputs) must be unique")
103-
if not all(['name' in i and ('sequences' in i or 'metadata' in i) for i in all_inputs]):
104-
raise InvalidConfigError("Each input (config.inputs and config.additional_inputs) must have a 'name' and 'metadata' and/or 'sequences'")
105-
106-
return {i['name']: _parse_config_input(i) for i in all_inputs}
107-
108-
input_sources = _gather_inputs()
109-
110-
def input_metadata(wildcards):
111-
inputs = [info['metadata'] for info in input_sources.values() if info.get('metadata', None)]
112-
return inputs[0] if len(inputs)==1 else "results/metadata_merged.tsv"
113-
114-
def input_sequences(wildcards):
115-
inputs = list(filter(None, [info['sequences'](wildcards) for info in input_sources.values() if info.get('sequences', None)]))
116-
return inputs[0] if len(inputs)==1 else "results/sequences_merged_{segment}.fasta"
117-
118-
rule download_s3_sequences:
119-
output:
120-
sequences = "data/{input_name}/sequences_{segment}.fasta",
121-
params:
122-
address = lambda w: input_sources[w.input_name]['sequences_location'](w),
123-
no_sign_request=lambda w: "--no-sign-request" \
124-
if input_sources[w.input_name]['sequences_location'](w).startswith(NEXTSTRAIN_PUBLIC_BUCKET) \
125-
else "",
126-
shell:
127-
"""
128-
aws s3 cp {params.no_sign_request:q} {params.address:q} - | zstd -d > {output.sequences}
129-
"""
130-
131-
rule download_s3_metadata:
132-
output:
133-
metadata = "data/{input_name}/metadata.tsv",
134-
params:
135-
address = lambda w: input_sources[w.input_name]['metadata_location'],
136-
no_sign_request=lambda w: "--no-sign-request" \
137-
if input_sources[w.input_name]['metadata_location'].startswith(NEXTSTRAIN_PUBLIC_BUCKET) \
138-
else "",
139-
shell:
140-
"""
141-
aws s3 cp {params.no_sign_request:q} {params.address:q} - | zstd -d > {output.metadata}
142-
"""
143-
144-
rule merge_metadata:
145-
"""
146-
This rule should only be invoked if there are multiple defined metadata inputs
147-
(config.inputs + config.additional_inputs)
148-
"""
149-
input:
150-
**{name: info['metadata'] for name,info in input_sources.items() if info.get('metadata', None)}
151-
params:
152-
metadata = lambda w, input: list(map("=".join, input.items()))
153-
output:
154-
metadata = "results/metadata_merged.tsv"
155-
shell:
156-
r"""
157-
augur merge \
158-
--metadata {params.metadata:q} \
159-
--source-columns 'input_{{NAME}}' \
160-
--output-metadata {output.metadata}
161-
"""
162-
163-
rule merge_sequences:
164-
"""
165-
This rule should only be invoked if there are multiple defined metadata inputs
166-
(config.inputs + config.additional_inputs) for this particular segment
167-
"""
168-
input:
169-
lambda w: list(filter(None, [info['sequences'](w) for info in input_sources.values()]))
170-
output:
171-
sequences = "results/sequences_merged_{segment}.fasta"
172-
shell:
173-
r"""
174-
seqkit rmdup {input:q} > {output.sequences:q}
175-
"""
176-
177-
# -------------------------------------------------------------------------------------------- #
36+
# This uses the `InvalidConfigError` defined above
37+
include: "merge_inputs.smk"
17838

17939
rule filter_sequences_by_subtype:
18040
input:

rules/merge_inputs.smk

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# ------------- helper functions to collect, merge & download input files ------------------- #
2+
3+
def _parse_config_input(input):
4+
"""
5+
Parses information from an individual config-defined input, i.e. an element within `config.inputs` or `config.additional_inputs`
6+
and returns information snakemake rules can use to obtain the underlying data.
7+
8+
The structure of `input` is a dictionary with keys:
9+
- name:string (required)
10+
- metadata:string (optional) - a s3 URI or a local file path
11+
- sequences:string|dict[string,string] (optional) - either a s3 URI or a local file path, in which case
12+
it must include a '{segment}' wildcard substring, or a dict of segment → s3 URI or local file path,
13+
in which case it must not include the wildcard substring.
14+
15+
Returns a dictionary with optional keys:
16+
- metadata:string - the relative path to the metadata file. If the original data was remote then this represents
17+
the output of a rule which downloads the file
18+
- metadata_location:string - the URI for the remote file if applicable else `None`
19+
- sequences:function. Takes in wildcards and returns the relative path to the sequences FASTA for the provided
20+
segment wildcard, or returns `None` if this input doesn't define sequences for the provided segment.
21+
- sequences_location:function. Takes in wildcards and returns the URI for the remote file, or `None`, where applicable.
22+
23+
Raises InvalidConfigError
24+
"""
25+
name = input['name']
26+
lambda_none = lambda w: None
27+
28+
info = {'metadata': None, 'metadata_location': None, 'sequences': lambda_none, 'sequences_location': lambda_none}
29+
30+
def _source(uri, *, s3, local):
31+
if uri.startswith('s3://'):
32+
return s3
33+
elif uri.lower().startswith('http://') or uri.lower().startswith('https://'):
34+
raise InvalidConfigError("Workflow cannot yet handle HTTP[S] inputs")
35+
return local
36+
37+
if location:=input.get('metadata', False):
38+
info['metadata'] = _source(location, s3=f"data/{name}/metadata.tsv", local=location)
39+
info['metadata_location'] = _source(location, s3=location, local=None)
40+
41+
if location:=input.get('sequences', False):
42+
if isinstance(location, dict):
43+
info['sequences'] = lambda w: _source(location[w.segment], s3=f"data/{name}/sequences_{w.segment}.fasta", local=location[w.segment]) \
44+
if w.segment in location \
45+
else None
46+
info['sequences_location'] = lambda w: _source(location[w.segment], s3=location[w.segment], local=None) \
47+
if w.segment in location \
48+
else None
49+
elif isinstance(location, str):
50+
info['sequences'] = _source(location, s3=lambda w: f"data/{name}/sequences_{w.segment}.fasta", local=lambda w: location.format(segment=w.segment))
51+
info['sequences_location'] = _source(location, s3=lambda w: location.format(segment=w.segment), local=lambda_none)
52+
else:
53+
raise InvalidConfigError(f"Config input for {name} specifies sequences in an unknown format; must be dict or string")
54+
55+
return info
56+
57+
58+
def _gather_inputs():
59+
all_inputs = [*config['inputs'], *config.get('additional_inputs', [])]
60+
61+
if len(all_inputs)==0:
62+
raise InvalidConfigError("Config must define at least one element in config.inputs or config.additional_inputs lists")
63+
if not all([isinstance(i, dict) for i in all_inputs]):
64+
raise InvalidConfigError("All of the elements in config.inputs and config.additional_inputs lists must be dictionaries"
65+
"If you've used a command line '--config' double check your quoting.")
66+
if len({i['name'] for i in all_inputs})!=len(all_inputs):
67+
raise InvalidConfigError("Names of inputs (config.inputs and config.additional_inputs) must be unique")
68+
if not all(['name' in i and ('sequences' in i or 'metadata' in i) for i in all_inputs]):
69+
raise InvalidConfigError("Each input (config.inputs and config.additional_inputs) must have a 'name' and 'metadata' and/or 'sequences'")
70+
71+
return {i['name']: _parse_config_input(i) for i in all_inputs}
72+
73+
input_sources = _gather_inputs()
74+
75+
def input_metadata(wildcards):
76+
inputs = [info['metadata'] for info in input_sources.values() if info.get('metadata', None)]
77+
return inputs[0] if len(inputs)==1 else "results/metadata_merged.tsv"
78+
79+
def input_sequences(wildcards):
80+
inputs = list(filter(None, [info['sequences'](wildcards) for info in input_sources.values() if info.get('sequences', None)]))
81+
return inputs[0] if len(inputs)==1 else "results/sequences_merged_{segment}.fasta"
82+
83+
rule download_s3_sequences:
84+
output:
85+
sequences = "data/{input_name}/sequences_{segment}.fasta",
86+
params:
87+
address = lambda w: input_sources[w.input_name]['sequences_location'](w),
88+
no_sign_request=lambda w: "--no-sign-request" \
89+
if input_sources[w.input_name]['sequences_location'](w).startswith(NEXTSTRAIN_PUBLIC_BUCKET) \
90+
else "",
91+
shell:
92+
"""
93+
aws s3 cp {params.no_sign_request:q} {params.address:q} - | zstd -d > {output.sequences}
94+
"""
95+
96+
rule download_s3_metadata:
97+
output:
98+
metadata = "data/{input_name}/metadata.tsv",
99+
params:
100+
address = lambda w: input_sources[w.input_name]['metadata_location'],
101+
no_sign_request=lambda w: "--no-sign-request" \
102+
if input_sources[w.input_name]['metadata_location'].startswith(NEXTSTRAIN_PUBLIC_BUCKET) \
103+
else "",
104+
shell:
105+
"""
106+
aws s3 cp {params.no_sign_request:q} {params.address:q} - | zstd -d > {output.metadata}
107+
"""
108+
109+
rule merge_metadata:
110+
"""
111+
This rule should only be invoked if there are multiple defined metadata inputs
112+
(config.inputs + config.additional_inputs)
113+
"""
114+
input:
115+
**{name: info['metadata'] for name,info in input_sources.items() if info.get('metadata', None)}
116+
params:
117+
metadata = lambda w, input: list(map("=".join, input.items()))
118+
output:
119+
metadata = "results/metadata_merged.tsv"
120+
shell:
121+
r"""
122+
augur merge \
123+
--metadata {params.metadata:q} \
124+
--source-columns 'input_{{NAME}}' \
125+
--output-metadata {output.metadata}
126+
"""
127+
128+
rule merge_sequences:
129+
"""
130+
This rule should only be invoked if there are multiple defined metadata inputs
131+
(config.inputs + config.additional_inputs) for this particular segment
132+
"""
133+
input:
134+
lambda w: list(filter(None, [info['sequences'](w) for info in input_sources.values()]))
135+
output:
136+
sequences = "results/sequences_merged_{segment}.fasta"
137+
shell:
138+
r"""
139+
seqkit rmdup {input:q} > {output.sequences:q}
140+
"""
141+
142+
# -------------------------------------------------------------------------------------------- #

0 commit comments

Comments
 (0)