|
116 | 116 |
|
117 | 117 | """
|
118 | 118 |
|
| 119 | +import dataclasses |
119 | 120 | from abc import ABC
|
| 121 | +from csv import DictWriter |
| 122 | +from dataclasses import dataclass |
120 | 123 | from enum import Enum
|
| 124 | +from inspect import isclass |
121 | 125 | from pathlib import Path
|
| 126 | +from types import TracebackType |
122 | 127 | from typing import Any
|
123 | 128 | from typing import Callable
|
124 | 129 | from typing import Dict
|
125 | 130 | from typing import Generic
|
| 131 | +from typing import Iterable |
126 | 132 | from typing import Iterator
|
127 | 133 | from typing import List
|
| 134 | +from typing import Optional |
| 135 | +from typing import Type |
128 | 136 | from typing import TypeVar
|
129 | 137 |
|
| 138 | +import attr |
| 139 | + |
130 | 140 | from fgpyo import io
|
131 | 141 | from fgpyo.util import inspect
|
132 | 142 |
|
@@ -334,3 +344,332 @@ def fast_concat(*inputs: Path, output: Path) -> None:
|
334 | 344 | io.write_lines(
|
335 | 345 | path=output, lines_to_write=list(io.read_lines(input_path))[1:], append=True
|
336 | 346 | )
|
| 347 | + |
| 348 | + |
def is_metric(cls: Any) -> bool:
    """
    Report whether the given object is a usable Metric class.

    Args:
        cls: Any object; non-class objects are accepted and simply yield `False`.

    Returns:
        True only if `cls` is a class, is a subclass of `Metric`, and has been decorated as
        either a dataclass or an attrs class. False otherwise.
    """
    # `issubclass` raises on non-class arguments, so rule those out first.
    if not isclass(cls):
        return False

    if not issubclass(cls, Metric):
        return False

    return dataclasses.is_dataclass(cls) or attr.has(cls)
| 357 | + |
| 358 | + |
| 359 | +@dataclass(kw_only=True) |
| 360 | +class MetricFileFormat: |
| 361 | + """ |
| 362 | + Parameters describing the format and configuration of a delimited Metric file. |
| 363 | +
|
| 364 | + Most of these parameters, if specified, are passed through to `csv.DictReader`/`csv.DictWriter`. |
| 365 | + """ |
| 366 | + |
| 367 | + delimiter: str = "\t" |
| 368 | + comment: str = "#" |
| 369 | + |
| 370 | + |
| 371 | +@dataclass(frozen=True, kw_only=True) |
| 372 | +class MetricFileHeader: |
| 373 | + """ |
| 374 | + Header of a file. |
| 375 | +
|
| 376 | + A file's header contains an optional preface, consisting of lines prefixed by a comment |
| 377 | + character and/or empty lines, and a required row of fieldnames before the data rows begin. |
| 378 | +
|
| 379 | + Attributes: |
| 380 | + preface: A list of any lines preceding the fieldnames. |
| 381 | + fieldnames: The field names specified in the final line of the header. |
| 382 | + """ |
| 383 | + |
| 384 | + preface: list[str] |
| 385 | + fieldnames: list[str] |
| 386 | + |
| 387 | + |
def get_header(
    reader: io.ReadableFileHandle,
    file_format: MetricFileFormat,
) -> Optional[MetricFileHeader]:
    """
    Consume and return the header of an open delimited file.

    Lines are read from `reader` until the first line that neither starts with the comment
    prefix nor is blank. That line is split on the delimiter to produce the fieldnames; every
    line before it is collected (whitespace-stripped) into the preface. The reader is left
    positioned immediately after the fieldnames line.

    NB: `None` is returned rather than raising because the name of the source file is not in
    scope here, which makes a helpful error message difficult. The caller is responsible for
    raising if no header could be found.

    See original proof-of-concept here: https://github.com/fulcrumgenomics/fgpyo/pull/103

    Args:
        reader: An open, readable file handle.
        file_format: A dataclass containing (at minimum) the file's delimiter and the string
            prefixing any comment lines.

    Returns:
        A `MetricFileHeader` containing the field names and any preceding lines.
        None if the file was empty or contained only comments or empty lines.
    """
    preface_lines: list[str] = []
    fieldnames_line: Optional[str] = None

    for raw_line in reader:
        stripped = raw_line.strip()
        # The comment test is applied to the raw line, so an indented comment prefix does not
        # count as a comment.
        is_comment = raw_line.startswith(file_format.comment)

        if not is_comment and stripped != "":
            fieldnames_line = stripped
            break

        preface_lines.append(stripped)

    # The input was exhausted without encountering a fieldnames row.
    if fieldnames_line is None:
        return None

    return MetricFileHeader(
        preface=preface_lines,
        fieldnames=fieldnames_line.split(file_format.delimiter),
    )
| 428 | + |
| 429 | + |
class MetricWriter:
    """
    Write instances of a single `Metric` class to a delimited text file.

    The writer wraps a `csv.DictWriter`. When a new file is created the header row is written
    on construction; when appending, the existing file's header is validated instead. The
    writer may be used as a context manager, in which case the underlying file handle is closed
    on exit.
    """

    _metric_class: type[Metric]
    _fieldnames: list[str]
    _fout: io.WritableFileHandle
    _writer: DictWriter

    def __init__(
        self,
        filename: Path | str,
        metric_class: type[Metric],
        append: bool = False,
        delimiter: str = "\t",
        include_fields: list[str] | None = None,
        exclude_fields: list[str] | None = None,
        **kwds: Any,
    ) -> None:
        """
        Args:
            filename: Path to the file to write.
            metric_class: Metric class.
            append: If `True`, the file will be appended to. Otherwise, the specified file will be
                overwritten.
            delimiter: The output file delimiter.
            include_fields: If specified, only the listed fieldnames will be included when writing
                records to file. Fields will be written in the order provided.
                May not be used together with `exclude_fields`.
            exclude_fields: If specified, any listed fieldnames will be excluded when writing
                records to file.
                May not be used together with `include_fields`.
            **kwds: Additional keyword arguments forwarded to `csv.DictWriter` (for example
                `lineterminator`).

        Raises:
            AssertionError: If the provided metric class is not a dataclass- or attr-decorated
                subclass of `Metric`.
            AssertionError: If the provided filepath is not writable. (Or readable, if
                `append=True`.)
            ValueError: If both `include_fields` and `exclude_fields` are specified, or if either
                names a field that is not an attribute of the Metric.
            ValueError: If `append=True` and the existing file has no header, or its header does
                not match the fields that will be written.
        """

        filepath: Path = Path(filename)
        file_format = MetricFileFormat(delimiter=delimiter)

        # Raise explicitly (rather than via `assert`) so that the check is not stripped when
        # Python runs with `-O`. The exception type is unchanged for existing callers.
        if not is_metric(metric_class):
            raise AssertionError(
                "Metric class must be a dataclass- or attr-decorated subclass of `Metric`."
            )

        io.assert_path_is_writeable(filepath)
        if append:
            io.assert_path_is_readable(filepath)

        self._metric_class = metric_class
        self._fieldnames = _validate_output_fieldnames(
            metric_class=metric_class,
            include_fields=include_fields,
            exclude_fields=exclude_fields,
        )

        # When appending, the existing header must match the columns that will actually be
        # written (which may be a subset/re-ordering of the Metric's fields), otherwise the
        # appended rows would not line up with the header.
        if append:
            self._assert_header_matches_output(filepath, file_format)

        self._fout = io.to_writer(filepath, append=append)
        try:
            self._writer = DictWriter(
                f=self._fout,
                fieldnames=self._fieldnames,
                delimiter=delimiter,
                **kwds,
            )

            # If we aren't appending to an existing file, write the header before any rows
            if not append:
                self._writer.writeheader()
        except BaseException:
            # Don't leak the open handle if the writer could not be fully initialised.
            self._fout.close()
            raise

    def _assert_header_matches_output(
        self,
        filepath: Path,
        file_format: MetricFileFormat,
    ) -> None:
        """Check that an existing file's header equals the fieldnames this writer will emit."""
        with filepath.open("r") as fin:
            header = get_header(fin, file_format=file_format)

        if header is None:
            raise ValueError(f"Could not find a header in the provided file: {filepath}")

        if header.fieldnames != self._fieldnames:
            raise ValueError(
                "The provided file does not have the same field names as the fields to be "
                "written:\n"
                f"\tMetric: {self._metric_class.__name__}\n"
                f"\tFile: {filepath}\n"
                f"\tOutput fields: {', '.join(self._fieldnames)}\n"
                f"\tFile fields: {', '.join(header.fieldnames)}\n"
            )

    def __enter__(self) -> "MetricWriter":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # All three arguments are `None` when the `with` block exits without an exception.
        self.close()

    def close(self) -> None:
        """Close the underlying file handle."""
        self._fout.close()

    def write(self, metric: Metric) -> None:
        """
        Write a single Metric instance to file.

        The Metric is converted to a dictionary and then written using the underlying
        `csv.DictWriter`. If the `MetricWriter` was created using the `include_fields` or
        `exclude_fields` arguments, the attributes of the dataclass are subset and/or reordered
        accordingly before writing.

        Args:
            metric: An instance of the specified Metric.

        Raises:
            ValueError: If `metric` is not an instance of the writer's Metric class.
        """
        if not isinstance(metric, self._metric_class):
            raise ValueError(f"Must provide instances of {self._metric_class.__name__}")

        # Serialize the Metric to a dict for writing by the underlying `DictWriter`
        row = asdict(metric)

        # Filter and/or re-order output fields if necessary
        row = {fieldname: row[fieldname] for fieldname in self._fieldnames}

        self._writer.writerow(row)

    def writeall(self, metrics: Iterable[Metric]) -> None:
        """
        Write multiple Metric instances to file.

        Each Metric is converted to a dictionary and then written using the underlying
        `csv.DictWriter`. If the `MetricWriter` was created using the `include_fields` or
        `exclude_fields` arguments, the attributes of each Metric are subset and/or reordered
        accordingly before writing.

        Args:
            metrics: A sequence of instances of the specified Metric.

        Raises:
            ValueError: If any element is not an instance of the writer's Metric class.
        """
        for metric in metrics:
            self.write(metric)
| 547 | + |
| 548 | + |
def assert_is_metric(cls: type[Metric]) -> None:
    """
    Validate that the given class is a Metric, raising if it is not.

    Args:
        cls: A class object.

    Raises:
        TypeError: If the given class is not a Metric.
    """
    if is_metric(cls):
        return

    raise TypeError(f"Not a dataclass or attr decorated Metric: {cls}")
| 561 | + |
| 562 | + |
def asdict(metric: Metric) -> dict[str, Any]:
    """
    Convert a Metric instance to a dictionary.

    Args:
        metric: An instance of a dataclass- or attr-decorated Metric.

    Returns:
        A dictionary mapping the Metric's field names to their values, as produced by
        `dataclasses.asdict` or `attr.asdict` respectively.

    Raises:
        TypeError: If the type of the given instance is not a Metric.
    """
    metric_class = type(metric)
    assert_is_metric(metric_class)

    # Dispatch on the class rather than the instance: `attr.has` is documented to take a class.
    if dataclasses.is_dataclass(metric_class):
        return dataclasses.asdict(metric)

    if attr.has(metric_class):
        return attr.asdict(metric)

    # Raise explicitly instead of `assert False`, which is removed under `python -O` and would
    # make this function silently return `None`.
    raise AssertionError("Unreachable")
| 573 | + |
| 574 | + |
def get_fieldnames(metric_class: type[Metric]) -> list[str]:
    """
    Get the fieldnames of the specified metric class.

    Args:
        metric_class: A Metric class.

    Returns:
        A list of fieldnames, in the order reported by `dataclasses.fields` or `attr.fields`.

    Raises:
        TypeError: If the given class is not a Metric.
    """
    assert_is_metric(metric_class)

    if dataclasses.is_dataclass(metric_class):
        return [field.name for field in dataclasses.fields(metric_class)]

    if attr.has(metric_class):
        return [field.name for field in attr.fields(metric_class)]

    # Raise explicitly instead of `assert False`, which is removed under `python -O` and would
    # make this function silently return `None`.
    raise AssertionError("Unreachable")
| 596 | + |
| 597 | + |
def assert_file_header_matches_metric(
    path: Path,
    metric_class: type[Metric],
    file_format: MetricFileFormat,
) -> None:
    """
    Check that the specified file has a header and its fields match those of the provided Metric.

    Args:
        path: Path to an existing delimited file.
        metric_class: The Metric class whose fieldnames the file's header must equal, in order.
        file_format: The delimiter and comment prefix used to parse the file's header.

    Raises:
        ValueError: If no header could be found in the file.
        ValueError: If the header's fieldnames differ from the Metric's fieldnames.
        TypeError: If the given class is not a Metric.
    """
    with path.open("r") as fin:
        # `get_header` returns None when the file is empty or holds only comments/blank lines.
        header: Optional[MetricFileHeader] = get_header(fin, file_format=file_format)

    if header is None:
        raise ValueError(f"Could not find a header in the provided file: {path}")

    # Look the Metric's fieldnames up once and reuse them for the comparison and the message.
    metric_fieldnames = get_fieldnames(metric_class)

    if header.fieldnames != metric_fieldnames:
        raise ValueError(
            "The provided file does not have the same field names as the provided dataclass:\n"
            f"\tDataclass: {metric_class.__name__}\n"
            f"\tFile: {path}\n"
            f"\tDataclass fields: {', '.join(metric_fieldnames)}\n"
            f"\tFile fields: {', '.join(header.fieldnames)}\n"
        )
| 620 | + |
| 621 | + |
def assert_fieldnames_are_metric_attributes(
    specified_fieldnames: list[str],
    metric_class: type[MetricType],
) -> None:
    """
    Check that all of the specified fields are attributes on the given Metric.

    Args:
        specified_fieldnames: The fieldnames to validate.
        metric_class: The Metric class whose attributes the fieldnames must belong to.

    Raises:
        ValueError: if any of the specified fieldnames are not an attribute on the given Metric.
        TypeError: If the given class is not a Metric.
    """
    # Resolve the Metric's fieldnames once, rather than once per candidate fieldname.
    valid_fieldnames = frozenset(get_fieldnames(metric_class))

    invalid_fieldnames = [f for f in specified_fieldnames if f not in valid_fieldnames]

    if invalid_fieldnames:
        raise ValueError(
            "One or more of the specified fields are not attributes on the Metric "
            + f"{metric_class.__name__}: "
            + ", ".join(invalid_fieldnames)
        )
| 640 | + |
| 641 | + |
def _validate_output_fieldnames(
    metric_class: type[MetricType],
    include_fields: list[str] | None = None,
    exclude_fields: list[str] | None = None,
) -> list[str]:
    """
    Subset and/or re-order the Metric's fieldnames based on the specified include/exclude lists.

    * Only one of `include_fields` and `exclude_fields` may be specified.
    * All fieldnames specified in `include_fields` must be fields on `metric_class`. If this
      argument is specified, fields will be returned in the order they appear in the list.
    * All fieldnames specified in `exclude_fields` must be fields on `metric_class`. (This is
      technically unnecessary, but is a safeguard against passing an incorrect list.)
    * If neither `include_fields` or `exclude_fields` are specified, return the `metric_class`'s
      fieldnames.

    Args:
        metric_class: The Metric class whose fieldnames are being selected.
        include_fields: If specified, the fieldnames to return, in the order given.
        exclude_fields: If specified, the fieldnames to omit from the Metric's fieldnames.

    Returns:
        A new list of the fieldnames to write. The caller's lists are never returned directly.

    Raises:
        ValueError: If both `include_fields` and `exclude_fields` are specified.
        ValueError: If any specified fieldname is not an attribute on the Metric.
        TypeError: If the given class is not a Metric.
    """

    if include_fields is not None and exclude_fields is not None:
        raise ValueError(
            "Only one of `include_fields` and `exclude_fields` may be specified, not both."
        )

    if exclude_fields is not None:
        assert_fieldnames_are_metric_attributes(exclude_fields, metric_class)
        return [f for f in get_fieldnames(metric_class) if f not in exclude_fields]

    if include_fields is not None:
        assert_fieldnames_are_metric_attributes(include_fields, metric_class)
        # Return a copy so that later mutation of the caller's list cannot change the columns
        # of a writer that has already emitted its header.
        return list(include_fields)

    return get_fieldnames(metric_class)
0 commit comments