Skip to content

Commit 1ce2978

Browse files
committed
wip
wip wip wip
1 parent dd82a34 commit 1ce2978

File tree

3 files changed

+604
-2
lines changed

3 files changed

+604
-2
lines changed

fgpyo/io/__init__.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,15 @@
5656
from typing import Iterator
5757
from typing import Set
5858
from typing import TextIO
59+
from typing import TypeAlias
5960
from typing import Union
6061
from typing import cast
6162

6263
COMPRESSED_FILE_EXTENSIONS: Set[str] = {".gz", ".bgz"}
6364

65+
ReadableFileHandle: TypeAlias = Union[io.TextIOWrapper, TextIO, IO[Any]]
66+
WritableFileHandle: TypeAlias = Union[IO[Any], io.TextIOWrapper]
67+
6468

6569
def assert_path_is_readable(path: Path) -> None:
6670
"""Checks that file exists and returns True, else raises AssertionError
@@ -129,7 +133,7 @@ def assert_path_is_writeable(path: Path, parent_must_exist: bool = True) -> None
129133
raise AssertionError(f"No parent directories exist for: {path}")
130134

131135

132-
def to_reader(path: Path) -> Union[io.TextIOWrapper, TextIO, IO[Any]]:
136+
def to_reader(path: Path) -> ReadableFileHandle:
133137
"""Opens a Path for reading and based on extension uses open() or gzip.open()
134138
135139
Args:
@@ -147,7 +151,7 @@ def to_reader(path: Path) -> Union[io.TextIOWrapper, TextIO, IO[Any]]:
147151
return path.open(mode="r")
148152

149153

150-
def to_writer(path: Path, append: bool = False) -> Union[IO[Any], io.TextIOWrapper]:
154+
def to_writer(path: Path, append: bool = False) -> WritableFileHandle:
151155
"""Opens a Path for writing (or appending) and based on extension uses open() or gzip.open()
152156
153157
Args:

fgpyo/util/metric.py

+339
Original file line numberDiff line numberDiff line change
@@ -116,17 +116,27 @@
116116
117117
"""
118118

119+
import dataclasses
119120
from abc import ABC
121+
from csv import DictWriter
122+
from dataclasses import dataclass
120123
from enum import Enum
124+
from inspect import isclass
121125
from pathlib import Path
126+
from types import TracebackType
122127
from typing import Any
123128
from typing import Callable
124129
from typing import Dict
125130
from typing import Generic
131+
from typing import Iterable
126132
from typing import Iterator
127133
from typing import List
134+
from typing import Optional
135+
from typing import Type
128136
from typing import TypeVar
129137

138+
import attr
139+
130140
from fgpyo import io
131141
from fgpyo.util import inspect
132142

@@ -334,3 +344,332 @@ def fast_concat(*inputs: Path, output: Path) -> None:
334344
io.write_lines(
335345
path=output, lines_to_write=list(io.read_lines(input_path))[1:], append=True
336346
)
347+
348+
349+
def is_metric(cls: Any) -> bool:
350+
"""True if the given class is a Metric."""
351+
352+
return (
353+
isclass(cls)
354+
and issubclass(cls, Metric)
355+
and (dataclasses.is_dataclass(cls) or attr.has(cls))
356+
)
357+
358+
359+
@dataclass(kw_only=True)
360+
class MetricFileFormat:
361+
"""
362+
Parameters describing the format and configuration of a delimited Metric file.
363+
364+
Most of these parameters, if specified, are passed through to `csv.DictReader`/`csv.DictWriter`.
365+
"""
366+
367+
delimiter: str = "\t"
368+
comment: str = "#"
369+
370+
371+
@dataclass(frozen=True, kw_only=True)
372+
class MetricFileHeader:
373+
"""
374+
Header of a file.
375+
376+
A file's header contains an optional preface, consisting of lines prefixed by a comment
377+
character and/or empty lines, and a required row of fieldnames before the data rows begin.
378+
379+
Attributes:
380+
preface: A list of any lines preceding the fieldnames.
381+
fieldnames: The field names specified in the final line of the header.
382+
"""
383+
384+
preface: list[str]
385+
fieldnames: list[str]
386+
387+
388+
def get_header(
389+
reader: io.ReadableFileHandle,
390+
file_format: MetricFileFormat,
391+
) -> Optional[MetricFileHeader]:
392+
"""
393+
Read the header from an open file.
394+
395+
The first row after any commented or empty lines will be used as the fieldnames.
396+
397+
Lines preceding the fieldnames will be returned in the `preface.`
398+
399+
NB: This function returns `Optional` instead of raising an error because the name of the
400+
source file is not in scope, making it difficult to provide a helpful error message. It is
401+
the responsibility of the caller to raise an error if the file is empty.
402+
403+
See original proof-of-concept here: https://github.com/fulcrumgenomics/fgpyo/pull/103
404+
405+
Args:
406+
reader: An open, readable file handle.
407+
file_format: A dataclass containing (at minimum) the file's delimiter and the string
408+
prefixing any comment lines.
409+
410+
Returns:
411+
A `FileHeader` containing the field names and any preceding lines.
412+
None if the file was empty or contained only comments or empty lines.
413+
"""
414+
415+
preface: list[str] = []
416+
417+
for line in reader:
418+
if line.startswith(file_format.comment) or line.strip() == "":
419+
preface.append(line.strip())
420+
else:
421+
break
422+
else:
423+
return None
424+
425+
fieldnames = line.strip().split(file_format.delimiter)
426+
427+
return MetricFileHeader(preface=preface, fieldnames=fieldnames)
428+
429+
430+
class MetricWriter:
431+
_metric_class: type[Metric]
432+
_fieldnames: list[str]
433+
_fout: io.WritableFileHandle
434+
_writer: DictWriter
435+
436+
def __init__(
437+
self,
438+
filename: Path | str,
439+
metric_class: type[Metric],
440+
append: bool = False,
441+
delimiter: str = "\t",
442+
include_fields: list[str] | None = None,
443+
exclude_fields: list[str] | None = None,
444+
**kwds: Any,
445+
) -> None:
446+
"""
447+
Args:
448+
path: Path to the file to write.
449+
metric_class: Metric class.
450+
append: If `True`, the file will be appended to. Otherwise, the specified file will be
451+
overwritten.
452+
delimiter: The output file delimiter.
453+
include_fields: If specified, only the listed fieldnames will be included when writing
454+
records to file. Fields will be written in the order provided.
455+
May not be used together with `exclude_fields`.
456+
exclude_fields: If specified, any listed fieldnames will be excluded when writing
457+
records to file.
458+
May not be used together with `include_fields`.
459+
460+
Raises:
461+
AssertionError: If the provided metric class is not a dataclass- or attr-decorated
462+
subclass of `Metric`.
463+
AssertionError: If the provided filepath is not writable. (Or readable, if
464+
`append=True`.)
465+
"""
466+
467+
filepath: Path = filename if isinstance(filename, Path) else Path(filename)
468+
file_format = MetricFileFormat(delimiter=delimiter)
469+
470+
assert is_metric(
471+
metric_class
472+
), "Metric class must be a dataclass- or attr-decorated subclass of `Metric`."
473+
io.assert_path_is_writeable(filepath)
474+
if append:
475+
io.assert_path_is_readable(filepath)
476+
assert_file_header_matches_metric(filepath, metric_class, file_format)
477+
478+
self._metric_class = metric_class
479+
self._fieldnames = _validate_output_fieldnames(
480+
metric_class=metric_class,
481+
include_fields=include_fields,
482+
exclude_fields=exclude_fields,
483+
)
484+
self._fout = io.to_writer(filepath, append=append)
485+
self._writer = DictWriter(
486+
f=self._fout,
487+
fieldnames=self._fieldnames,
488+
delimiter=delimiter,
489+
)
490+
491+
# If we aren't appending to an existing file, write the header before any rows
492+
if not append:
493+
self._writer.writeheader()
494+
495+
def __enter__(self) -> "MetricWriter":
496+
return self
497+
498+
def __exit__(
499+
self,
500+
exc_type: Type[BaseException],
501+
exc_value: BaseException,
502+
traceback: TracebackType,
503+
) -> None:
504+
self.close()
505+
506+
def close(self) -> None:
507+
"""Close the underlying file handle."""
508+
self._fout.close()
509+
510+
def write(self, metric: Metric) -> None:
511+
"""
512+
Write a single Metric instance to file.
513+
514+
The Metric is converted to a dictionary and then written using the underlying
515+
`csv.DictWriter`. If the `MetricWriter` was created using the `include_fields` or
516+
`exclude_fields` arguments, the attributes of the dataclass are subset and/or reordered
517+
accordingly before writing.
518+
519+
Args:
520+
metric: An instance of the specified Metric.
521+
"""
522+
if not isinstance(metric, self._metric_class):
523+
raise ValueError(f"Must provide instances of {self._metric_class.__name__}")
524+
525+
# Serialize the Metric to a dict for writing by the underlying `DictWriter`
526+
row = asdict(metric)
527+
528+
# Filter and/or re-order output fields if necessary
529+
row = {fieldname: row[fieldname] for fieldname in self._fieldnames}
530+
531+
self._writer.writerow(row)
532+
533+
def writeall(self, metrics: Iterable[Metric]) -> None:
534+
"""
535+
Write multiple Metric instances to file.
536+
537+
Each Metric is converted to a dictionary and then written using the underlying
538+
`csv.DictWriter`. If the `MetricWriter` was created using the `include_fields` or
539+
`exclude_fields` arguments, the attributes of each Metric are subset and/or reordered
540+
accordingly before writing.
541+
542+
Args:
543+
metrics: A sequence of instances of the specified Metric.
544+
"""
545+
for metric in metrics:
546+
self.write(metric)
547+
548+
549+
def assert_is_metric(cls: type[Metric]) -> None:
550+
"""
551+
Assert that the given class is a Metric.
552+
553+
Args:
554+
cls: A class object.
555+
556+
Raises:
557+
TypeError: If the given class is not a Metric.
558+
"""
559+
if not is_metric(cls):
560+
raise TypeError(f"Not a dataclass or attr decorated Metric: {cls}")
561+
562+
563+
def asdict(metric: Metric) -> dict[str, Any]:
564+
"""Convert a Metric instance to a dictionary."""
565+
assert_is_metric(type(metric))
566+
567+
if dataclasses.is_dataclass(metric):
568+
return dataclasses.asdict(metric)
569+
elif attr.has(metric):
570+
return attr.asdict(metric)
571+
else:
572+
assert False, "Unreachable"
573+
574+
575+
def get_fieldnames(metric_class: type[Metric]) -> list[str]:
576+
"""
577+
Get the fieldnames of the specified metric class.
578+
579+
Args:
580+
metric_class: A Metric class.
581+
582+
Returns:
583+
A list of fieldnames.
584+
585+
Raises:
586+
TypeError: If the given class is not a Metric.
587+
"""
588+
assert_is_metric(metric_class)
589+
590+
if dataclasses.is_dataclass(metric_class):
591+
return [f.name for f in dataclasses.fields(metric_class)]
592+
elif attr.has(metric_class):
593+
return [f.name for f in attr.fields(metric_class)]
594+
else:
595+
assert False, "Unreachable"
596+
597+
598+
def assert_file_header_matches_metric(
599+
path: Path,
600+
metric_class: type[Metric],
601+
file_format: MetricFileFormat,
602+
) -> None:
603+
"""
604+
Check that the specified file has a header and its fields match those of the provided Metric.
605+
"""
606+
with path.open("r") as fin:
607+
header: MetricFileHeader = get_header(fin, file_format=file_format)
608+
609+
if header is None:
610+
raise ValueError(f"Could not find a header in the provided file: {path}")
611+
612+
if header.fieldnames != get_fieldnames(metric_class):
613+
raise ValueError(
614+
"The provided file does not have the same field names as the provided dataclass:\n"
615+
f"\tDataclass: {metric_class.__name__}\n"
616+
f"\tFile: {path}\n"
617+
f"\tDataclass fields: {', '.join(get_fieldnames(metric_class))}\n"
618+
f"\tFile: {', '.join(header.fieldnames)}\n"
619+
)
620+
621+
622+
def assert_fieldnames_are_metric_attributes(
623+
specified_fieldnames: list[str],
624+
metric_class: type[MetricType],
625+
) -> None:
626+
"""
627+
Check that all of the specified fields are attributes on the given Metric.
628+
629+
Raises:
630+
ValueError: if any of the specified fieldnames are not an attribute on the given Metric.
631+
"""
632+
invalid_fieldnames = [f for f in specified_fieldnames if f not in get_fieldnames(metric_class)]
633+
634+
if len(invalid_fieldnames) > 0:
635+
raise ValueError(
636+
"One or more of the specified fields are not attributes on the Metric "
637+
+ f"{metric_class.__name__}: "
638+
+ ", ".join(invalid_fieldnames)
639+
)
640+
641+
642+
def _validate_output_fieldnames(
643+
metric_class: type[MetricType],
644+
include_fields: list[str] | None = None,
645+
exclude_fields: list[str] | None = None,
646+
) -> list[str]:
647+
"""
648+
Subset and/or re-order the dataclass's fieldnames based on the specified include/exclude lists.
649+
650+
* Only one of `include_fields` and `exclude_fields` may be specified.
651+
* All fieldnames specified in `include_fields` must be fields on `dataclass_type`. If this
652+
argument is specified, fields will be returned in the order they appear in the list.
653+
* All fieldnames specified in `exclude_fields` must be fields on `dataclass_type`. (This is
654+
technically unnecessary, but is a safeguard against passing an incorrect list.)
655+
* If neither `include_fields` or `exclude_fields` are specified, return the `dataclass_type`'s
656+
fieldnames.
657+
658+
Raises:
659+
ValueError: If both `include_fields` and `exclude_fields` are specified.
660+
"""
661+
662+
if include_fields is not None and exclude_fields is not None:
663+
raise ValueError(
664+
"Only one of `include_fields` and `exclude_fields` may be specified, not both."
665+
)
666+
elif exclude_fields is not None:
667+
assert_fieldnames_are_metric_attributes(exclude_fields, metric_class)
668+
output_fieldnames = [f for f in get_fieldnames(metric_class) if f not in exclude_fields]
669+
elif include_fields is not None:
670+
assert_fieldnames_are_metric_attributes(include_fields, metric_class)
671+
output_fieldnames = include_fields
672+
else:
673+
output_fieldnames = get_fieldnames(metric_class)
674+
675+
return output_fieldnames

0 commit comments

Comments
 (0)