-
Notifications
You must be signed in to change notification settings - Fork 21
Expand file tree
/
Copy pathprocess-unprocessed-tars
More file actions
executable file
·144 lines (121 loc) · 5.6 KB
/
process-unprocessed-tars
File metadata and controls
executable file
·144 lines (121 loc) · 5.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#!/usr/bin/env python3
"""
Download all unprocessed GISAID tar files from S3, process them to NDJSON,
and output a manifest of successfully processed tars.
"""
import io
import json
import os
import subprocess
import sys
import tarfile
import tempfile
from pathlib import Path

from augur.io.json import dump_ndjson, load_ndjson
def _fetch_tar_dir(source_path):
    """Return a local directory containing the *.tar files to process.

    If *source_path* is an s3:// prefix, mirror its tar files into
    tmp/unprocessed-tars (created if needed) via the AWS CLI; otherwise
    treat it as a local directory. Exits non-zero when a local path is
    missing; raises CalledProcessError if the S3 download fails.
    """
    if source_path.startswith("s3://"):
        tmp_dir = Path("tmp/unprocessed-tars")
        tmp_dir.mkdir(parents=True, exist_ok=True)
        print(f"Downloading tar files from {source_path}...", file=sys.stderr)
        subprocess.run([
            "aws", "s3", "cp", source_path, str(tmp_dir),
            "--recursive", "--exclude", "*", "--include", "*.tar",
            "--no-progress"
        ], check=True)
        return tmp_dir

    tmp_dir = Path(source_path)
    print(f"Using local tar files from {source_path}...", file=sys.stderr)
    if not tmp_dir.exists():
        print(f"ERROR: Local path does not exist: {source_path}", file=sys.stderr)
        sys.exit(1)
    return tmp_dir


def _safe_extract(tar_path, extract_path):
    """Extract *tar_path* into *extract_path*, sanitizing member paths.

    Leading slashes are stripped (so absolute paths can't land on the root
    filesystem) and any member that would still escape the extraction
    directory via ``..`` components raises ValueError.
    """
    root = extract_path.resolve()
    with tarfile.open(tar_path, 'r') as tar:
        members = tar.getmembers()
        for member in members:
            member.name = member.name.lstrip('/')
            target = (extract_path / member.name).resolve()
            # Reject parent-traversal: target must be root or inside it.
            if target != root and root not in target.parents:
                raise ValueError(f"Unsafe path in archive: {member.name}")
        tar.extractall(extract_path, members=members)


def _process_tar(tar_path):
    """Extract one tar and run its contents through the curate pipeline.

    Runs ``augur curate passthru | ./bin/transform-to-gisaid-cache`` over
    the first ``*.metadata.tsv`` / ``*.sequences.fasta`` pair found in the
    archive. Returns the list of parsed NDJSON records, or None when the
    tar is missing expected files or either pipeline stage fails (a
    warning is printed to stderr in each case).
    """
    with tempfile.TemporaryDirectory() as extract_dir:
        extract_path = Path(extract_dir)
        _safe_extract(tar_path, extract_path)

        metadata_files = list(extract_path.glob("**/*.metadata.tsv"))
        sequence_files = list(extract_path.glob("**/*.sequences.fasta"))
        if not metadata_files or not sequence_files:
            print(f"  WARNING: Missing metadata or sequences in {tar_path.name}", file=sys.stderr)
            return None

        # First stage inherits our stderr rather than piping it: a PIPE
        # nobody reads can fill up and deadlock a chatty child process.
        curate = subprocess.Popen([
            "augur", "curate", "passthru",
            "--metadata", str(metadata_files[0]),
            "--fasta", str(sequence_files[0]),
            "--seq-id-column", "strain",
            "--seq-field", "sequence"
        ], stdout=subprocess.PIPE)
        transform = subprocess.Popen(
            ["./bin/transform-to-gisaid-cache"],
            stdin=curate.stdout, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        # Close our copy so curate receives SIGPIPE if transform exits early.
        curate.stdout.close()
        stdout, stderr = transform.communicate()
        curate.wait()

        # Check BOTH stages: a dead curate gives transform a clean EOF, so
        # transform alone "succeeding" would silently drop this tar's records.
        if curate.returncode != 0:
            print(f"  WARNING: augur curate failed for {tar_path.name} (exit {curate.returncode})", file=sys.stderr)
            return None
        if transform.returncode != 0:
            print(f"  WARNING: Transform failed for {tar_path.name}: {stderr.decode()}", file=sys.stderr)
            return None

        return list(load_ndjson(io.BytesIO(stdout)))


def _write_outputs(all_records, processed_tars, output_ndjson, manifest_file):
    """Write the combined NDJSON file and the manifest of processed tar names."""
    print(f"\nWriting {len(all_records)} total records to {output_ndjson}", file=sys.stderr)
    with open(output_ndjson, 'w') as f:
        # One JSON object per line (NDJSON); dump_ndjson writes to stdout,
        # so serialize to the file handle directly.
        for record in all_records:
            json.dump(record, f)
            f.write('\n')
    print(f"Wrote {len(all_records)} lines to {output_ndjson}", file=sys.stderr)

    with open(manifest_file, 'w') as f:
        f.writelines(f"{tar_name}\n" for tar_name in processed_tars)


def _print_summary(processed_tars, total_tars):
    """Report per-tar success/failure counts to stderr."""
    print(f"\nSuccessfully processed {len(processed_tars)}/{total_tars} tar files", file=sys.stderr)
    if processed_tars:
        print("Processed tars:", file=sys.stderr)
        for tar_name in processed_tars:
            print(f"  ✓ {tar_name}", file=sys.stderr)
    failed_count = total_tars - len(processed_tars)
    if failed_count > 0:
        print(f"\nWarning: {failed_count} tar files failed to process", file=sys.stderr)


def main():
    """Convert unprocessed GISAID tars to a combined NDJSON plus a manifest.

    Command-line arguments:
      1: source -- s3://... prefix OR a local directory of *.tar files
      2: path for the combined NDJSON output
      3: path for the manifest listing successfully processed tar names

    Individual tar failures are logged and skipped, never fatal; empty
    outputs are written when no tar files are found so Snakemake targets
    still materialize.
    """
    if len(sys.argv) != 4:
        print(f"Usage: {sys.argv[0]} <source> <output-ndjson> <manifest>", file=sys.stderr)
        sys.exit(1)
    source_path, output_ndjson, manifest_file = sys.argv[1:4]

    tar_dir = _fetch_tar_dir(source_path)
    # Sort by filename: oldest first, so newer records overwrite older
    # ones during any downstream deduplication.
    tar_files = sorted(tar_dir.glob("*.tar"))
    if not tar_files:
        print("No tar files found", file=sys.stderr)
        # Create empty outputs so Snakemake doesn't fail.
        Path(output_ndjson).touch()
        Path(manifest_file).write_text("")
        return

    print(f"Found {len(tar_files)} tar files to process:", file=sys.stderr)
    for tar in tar_files:
        print(f"  - {tar.name}", file=sys.stderr)
    print("", file=sys.stderr)

    processed_tars = []
    all_records = []
    for tar_path in tar_files:
        print(f"Processing {tar_path.name}...", file=sys.stderr)
        try:
            records = _process_tar(tar_path)
        except Exception as e:
            print(f"  ERROR processing {tar_path.name}: {e}", file=sys.stderr)
            continue
        if records is None:
            continue  # warning already printed by _process_tar
        all_records.extend(records)
        processed_tars.append(tar_path.name)
        print(f"  Processed {len(records)} records", file=sys.stderr)

    _write_outputs(all_records, processed_tars, output_ndjson, manifest_file)
    _print_summary(processed_tars, len(tar_files))


if __name__ == "__main__":
    main()