Skip to content

Commit 8dc72a0

Browse files
committed
add CLI integration tests
1 parent 6a4d058 commit 8dc72a0

File tree

12 files changed

+332
-95
lines changed

12 files changed

+332
-95
lines changed

src/gwascatalog/sumstatapp/cli/__main__.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -219,16 +219,6 @@ def _run_validate(args: argparse.Namespace, parser: argparse.ArgumentParser) ->
219219

220220
workers: int = max(1, args.workers)
221221

222-
# Warn about duplicate stems that would clobber output files
223-
from gwascatalog.sumstatapp.cli._validate import output_stem
224-
225-
stems = [output_stem(f) for f in files]
226-
if len(stems) != len(set(stems)):
227-
print(
228-
"WARNING: Duplicate file stems detected — output files may be overwritten",
229-
file=sys.stderr,
230-
)
231-
232222
print(f"Validating {len(files)} file(s) with {workers} worker(s)")
233223
print(f"Output: {output_dir}\n")
234224

src/gwascatalog/sumstatapp/cli/_validate.py

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import time
1313
from dataclasses import dataclass
1414
from pathlib import Path
15-
from typing import TYPE_CHECKING
15+
from typing import TYPE_CHECKING, Literal
1616

1717
from gwascatalog.sumstatlib import (
1818
CNVSumstatModel,
@@ -60,29 +60,14 @@ def _get_model(variation_type: str) -> type[CNVSumstatModel] | type[GeneSumstatM
6060
raise ValueError(f"Unsupported variation type: {variation_type}")
6161

6262

63-
def output_stem(path: Path) -> str:
64-
"""Derive an output file stem, stripping archive and tabular extensions.
65-
66-
Examples::
67-
68-
output_stem(Path("study.tsv.gz")) == "study"
69-
output_stem(Path("study.tsv")) == "study"
70-
output_stem(Path("study.csv.gz")) == "study"
71-
"""
72-
stem = path.stem
73-
if path.suffix == ".gz":
74-
stem = Path(stem).stem
75-
return stem
76-
77-
7863
def _write_error_report(error_path: Path, errors: list[SumstatError]) -> None:
7964
"""Write validation errors to a human-readable TSV file."""
65+
dict_errors = [dict(e) for e in errors]
66+
8067
with error_path.open("w", encoding="utf-8", newline="") as f:
81-
writer = csv.writer(f, delimiter="\t")
82-
writer.writerow(["row", "column", "message"])
83-
for e in errors:
84-
column = e["loc"] if e["loc"] is not None else ""
85-
writer.writerow([e["row"], column, e["msg"]])
68+
writer = csv.DictWriter(f, delimiter="\t", fieldnames=["row", "column", "msg"])
69+
writer.writeheader()
70+
writer.writerows(dict_errors)
8671

8772

8873
def _compute_md5(path: Path) -> str:
@@ -102,7 +87,8 @@ def validate_file(
10287
output_dir: str,
10388
variation_type: str,
10489
assembly: str | None,
105-
primary_effect_size: str | None,
90+
primary_effect_size: Literal["beta", "odds_ratio", "hazard_ratio", "z_score"]
91+
| None,
10692
allow_zero_pvalues: bool,
10793
) -> FileResult:
10894
"""Validate a single summary statistics file and write results.
@@ -115,9 +101,13 @@ def validate_file(
115101
"""
116102
inp = Path(input_path)
117103
out_dir = Path(output_dir)
118-
stem = output_stem(inp)
119-
output_path = out_dir / f"{stem}.tsv.gz"
120-
error_path = out_dir / f"{stem}.errors.tsv"
104+
output_path = out_dir / f"validated_{inp.stem}.tsv.gz"
105+
error_path = out_dir / f"{inp.stem}.errors.tsv"
106+
107+
if output_path.exists():
108+
raise FileExistsError(output_path)
109+
if error_path.exists():
110+
raise FileExistsError(error_path)
121111

122112
start = time.monotonic()
123113

@@ -133,8 +123,9 @@ def validate_file(
133123

134124
rows_processed = 0
135125
valid_count = 0
126+
writer = table.open_writer(output_path, compress=True)
136127

137-
for row in table.open_writer(output_path, compress=True):
128+
for row in writer:
138129
rows_processed += 1
139130
if row.is_valid:
140131
valid_count += 1
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
romosome,base_pair_start,base_pair_end,neg_log10_p_value,p_value,beta,standard_error,statistical_model_type
2+
1,16600001,,9.45,3.54813E-10,0.048,0.008,additive
3+
,86415001,86425000,13.661,2.18273E-14,-0.035,,additive
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
hgnc_symbol,ensembl_gene_id,p_value
2+
ISG20,ENSG00000172183,0.0001
3+
,ENSG00000128886,
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
chromosome,base_pair_start,base_pair_end,neg_log10_p_value,beta,standard_error,statistical_model_type,extra_test_column
2+
1,16600001,16605000,9.45,0.048,0.008,additive,test1
3+
X,86415001,86425000,13.661,-0.035,0.003,additive,test2
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
ensembl_gene_id,p_value,chromosome,base_pair_start,base_pair_end,beta,standard_error,extra_test_column
2+
ENSG00000128886,0.1,15,43772605,43777315,0.420,2,test1
3+
ENSG00000172183,0.0001,15,88635618,88656483,0.048,0.0006,test2

sumstatlib/src/gwascatalog/sumstatlib/cnv/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class CNVSumstatModel(BaseSumstatModel):
4343
- allow_zero_pvalues (bool, optional):
4444
"""
4545

46-
MIN_RECORDS: ClassVar[None] = MIN_CNV_RECORDS
46+
MIN_RECORDS: ClassVar[int] = MIN_CNV_RECORDS
4747
FIELD_MAP: ClassVar[Mapping[str, int]] = CNV_FIELD_INDEX_MAP
4848
VALID_FIELD_NAMES: ClassVar[list[str]] = list(CNV_FIELD_INDEX_MAP.keys())
4949

sumstatlib/src/gwascatalog/sumstatlib/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,4 @@
4747

4848
# see decision docs for justification
4949
MIN_GENE_RECORDS: Final[int] = 10_000
50-
MIN_CNV_RECORDS: Final = None
50+
MIN_CNV_RECORDS: Final[int] = 10_000

sumstatlib/src/gwascatalog/sumstatlib/sumstattable.py

Lines changed: 22 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ class SumstatConfig(TypedDict):
2727
"""Runtime configuration for validating summary stats"""
2828

2929
allow_zero_p_values: bool
30-
assembly: GenomeAssembly
30+
assembly: GenomeAssembly | None
3131
primary_effect_size: Literal["beta", "odds_ratio", "hazard_ratio", "z_score"] | None
3232

3333

3434
class SumstatError(TypedDict):
3535
"""A parsed pydantic ValidationError"""
3636

3737
row: int
38-
loc: int | None
38+
column: str | int | None
3939
msg: str
4040

4141

@@ -59,7 +59,6 @@ def __init__(
5959
data_model: type[CNVSumstatModel | GeneSumstatModel],
6060
input_path: Path,
6161
config: SumstatConfig,
62-
min_records: int | None = None,
6362
):
6463
self._data_model = data_model
6564
self._path = Path(input_path)
@@ -69,35 +68,36 @@ def __init__(
6968
if not self._path.exists():
7069
raise FileNotFoundError(self._path)
7170

72-
if min_records is None:
73-
self._min_records = self._data_model.MIN_RECORDS
74-
else:
75-
self._min_records = min_records
76-
7771
n_rows = self.n_rows
78-
if self._min_records is not None and n_rows < self._min_records:
79-
raise ValueError(f"Not enough rows in file: {n_rows=} {self._min_records=}")
72+
if n_rows < self.data_model.MIN_RECORDS:
73+
warning = f"""
74+
It looks like you only have {n_rows} rows in {self._path}.
75+
{self.data_model} recommends at least {self.data_model.MIN_RECORDS} (before
76+
any QC steps). Please include all results, not just top hits.
77+
The GWAS Catalog inclusion criteria requires studies to be genome-wide.
78+
Please get in touch with gwas-subs@ebi.ac.uk if you have any questions.
79+
"""
80+
logger.warning(warning)
8081

8182
# Validate first row to check column structure — fail fast on bad columns
8283
_ = self.output_fieldnames
8384

8485
def _open_sumstat(self) -> IO[str]:
86+
# don't forget to strip UTF-8 BOM from Excel-exported files
87+
# newline = "" is best for CSV files - let the dictreader parser handle it
8588
if _is_gzip(self._path):
86-
return gzip.open(self._path, "rt", encoding="utf-8", newline=None)
87-
return self._path.open(mode="rt", encoding="utf-8", newline=None)
89+
return gzip.open(self._path, "rt", encoding="utf-8-sig", newline="")
90+
return self._path.open(mode="rt", encoding="utf-8-sig", newline="")
8891

8992
def parse_csv(self, sample_size: int = 4096) -> Generator[dict]:
9093
"""Automatically detect CSV delimiter and yield each row as a dict"""
9194
with self._open_sumstat() as f:
9295
sample = f.read(sample_size)
9396
sniffer = csv.Sniffer()
9497
dialect = sniffer.sniff(sample, delimiters=",\t;| ")
95-
96-
if not sniffer.has_header(sample):
97-
raise ValueError("file doesn't appear to contain a header")
98-
99-
f.seek(0) # reset to start of the file
98+
f.seek(0)
10099
reader = csv.DictReader(f, dialect=dialect)
100+
101101
yield from reader
102102

103103
@cached_property
@@ -119,21 +119,10 @@ def output_fieldnames(self) -> list[str]:
119119
ValidationError: If the first row fails validation, indicating
120120
an invalid column set (e.g. missing required columns).
121121
"""
122-
first_row = next(self.parse_csv())
123-
try:
124-
instance = self._data_model.model_validate(first_row, context=self._config)
125-
except ValidationError as e:
126-
logger.critical(f"First row of {self._path.name} failed validation")
127-
logger.critical(f"{ValidationError}")
128-
msg = (
129-
f"The first row of {self._path.name} failed validation. "
130-
"This usually means the file has missing or incorrectly "
131-
"named columns. Valid column names include: "
132-
f"{self.data_model.VALID_FIELD_NAMES}"
133-
)
134-
raise ValueError(msg) from e
122+
present = next(self.parse_csv(), None)
123+
if present is None:
124+
raise ValueError(f"Can't read anything from {self._path}")
135125

136-
present = list(instance.model_dump(exclude_none=True).keys())
137126
field_map = self._data_model.FIELD_MAP
138127

139128
# get a list fields sorted by their field map index
@@ -157,30 +146,6 @@ def n_rows(self) -> int:
157146
next(f, None) # skip header
158147
return sum(1 for _ in f)
159148

160-
def validate_rows(self) -> Generator[dict]:
161-
"""Validate all rows, storing errors in self._errors and yielding validated
162-
rows.
163-
"""
164-
for i, row in enumerate(self.parse_csv()):
165-
try:
166-
validated = self._data_model.model_validate(
167-
row, context=self._config
168-
).model_dump()
169-
except ValidationError as exc:
170-
for error in exc.errors():
171-
location = int(error["loc"][0])
172-
self._errors.append(
173-
SumstatError(row=i, loc=location, msg=error["msg"])
174-
)
175-
176-
if len(self._errors) >= self.MAX_ERRORS:
177-
logger.critical(
178-
f"Stopped validation after {self.MAX_ERRORS} errors"
179-
)
180-
break
181-
else:
182-
yield validated
183-
184149
@property
185150
def errors(self) -> list[SumstatError]:
186151
"""Return all row errors encountered"""
@@ -259,11 +224,11 @@ def __iter__(self) -> Generator[ValidatedRow]:
259224
except ValidationError as exc:
260225
for error in exc.errors():
261226
try:
262-
location = int(error["loc"][0])
227+
location = error["loc"][0]
263228
except IndexError:
264229
location = None
265230
self._table.add_error(
266-
SumstatError(row=i, loc=location, msg=error["msg"])
231+
SumstatError(row=i, column=location, msg=error["msg"])
267232
)
268233
yield ValidatedRow(row_number=i, is_valid=False)
269234

tests/conftest.py

Whitespace-only changes.

0 commit comments

Comments
 (0)