|
| 1 | +import abc |
1 | 2 | import contextlib |
2 | 3 | import inspect |
3 | 4 | import operator |
4 | 5 | import re |
5 | 6 | from collections.abc import Callable, Iterable |
6 | | -from typing import Optional, Union |
| 7 | +from functools import reduce |
| 8 | +from typing import Optional |
7 | 9 |
|
8 | 10 | from more_itertools import always_iterable |
9 | 11 |
|
@@ -59,7 +61,125 @@ def _DeprecatedFieldFunc(field, data): |
59 | 61 | return _DeprecatedFieldFunc |
60 | 62 |
|
61 | 63 |
|
62 | | -class DerivedField: |
| 64 | +class DerivedFieldBase(abc.ABC): |
| 65 | + @abc.abstractmethod |
| 66 | + def __call__(self, field, data): |
| 67 | + pass |
| 68 | + |
| 69 | + @abc.abstractmethod |
| 70 | + def __repr__(self) -> str: |
| 71 | + pass |
| 72 | + |
| 73 | + # Multiplication (left and right side) |
| 74 | + def __mul__(self, other) -> "DerivedFieldCombination": |
| 75 | + return DerivedFieldCombination([self, other], op=operator.mul) |
| 76 | + |
| 77 | + def __rmul__(self, other) -> "DerivedFieldCombination": |
| 78 | + return DerivedFieldCombination([self, other], op=operator.mul) |
| 79 | + |
| 80 | + # Division (left side) |
| 81 | + def __truediv__(self, other) -> "DerivedFieldCombination": |
| 82 | + return DerivedFieldCombination([self, other], op=operator.truediv) |
| 83 | + |
| 84 | + def __rtruediv__(self, other) -> "DerivedFieldCombination": |
| 85 | + return DerivedFieldCombination([other, self], op=operator.truediv) |
| 86 | + |
| 87 | + # Addition (left and right side) |
| 88 | + def __add__(self, other) -> "DerivedFieldCombination": |
| 89 | + return DerivedFieldCombination([self, other], op=operator.add) |
| 90 | + |
| 91 | + def __radd__(self, other) -> "DerivedFieldCombination": |
| 92 | + return DerivedFieldCombination([self, other], op=operator.add) |
| 93 | + |
| 94 | + # Subtraction (left and right side) |
| 95 | + def __sub__(self, other) -> "DerivedFieldCombination": |
| 96 | + return DerivedFieldCombination([self, other], op=operator.sub) |
| 97 | + |
| 98 | + def __rsub__(self, other) -> "DerivedFieldCombination": |
| 99 | + return DerivedFieldCombination([other, self], op=operator.sub) |
| 100 | + |
| 101 | + # Unary minus |
| 102 | + def __neg__(self) -> "DerivedFieldCombination": |
| 103 | + return DerivedFieldCombination([self], op=operator.neg) |
| 104 | + |
| 105 | + # Comparison operators |
| 106 | + def __leq__(self, other) -> "DerivedFieldCombination": |
| 107 | + return DerivedFieldCombination([self, other], op=operator.le) |
| 108 | + |
| 109 | + def __lt__(self, other) -> "DerivedFieldCombination": |
| 110 | + return DerivedFieldCombination([self, other], op=operator.lt) |
| 111 | + |
| 112 | + def __geq__(self, other) -> "DerivedFieldCombination": |
| 113 | + return DerivedFieldCombination([self, other], op=operator.ge) |
| 114 | + |
| 115 | + def __gt__(self, other) -> "DerivedFieldCombination": |
| 116 | + return DerivedFieldCombination([self, other], op=operator.gt) |
| 117 | + |
| 118 | + # def __eq__(self, other) -> "DerivedFieldCombination": |
| 119 | + # return DerivedFieldCombination([self, other], op=operator.eq) |
| 120 | + |
| 121 | + def __ne__(self, other) -> "DerivedFieldCombination": |
| 122 | + return DerivedFieldCombination([self, other], op=operator.ne) |
| 123 | + |
| 124 | + |
| 125 | +class DerivedFieldCombination(DerivedFieldBase): |
| 126 | + sampling_type: str | None |
| 127 | + terms: list |
| 128 | + op: Callable | None |
| 129 | + |
| 130 | + def __init__(self, terms: list, op=None): |
| 131 | + if not terms: |
| 132 | + raise ValueError("DerivedFieldCombination requires at least one term.") |
| 133 | + |
| 134 | + # Make sure all terms have the same sampling type |
| 135 | + sampling_types = set() |
| 136 | + for term in terms: |
| 137 | + if isinstance(term, DerivedField): |
| 138 | + sampling_types.add(term.sampling_type) |
| 139 | + |
| 140 | + if len(sampling_types) > 1: |
| 141 | + raise ValueError( |
| 142 | + "All terms in a DerivedFieldCombination must " |
| 143 | + "have the same sampling type." |
| 144 | + ) |
| 145 | + self.sampling_type = sampling_types.pop() if sampling_types else None |
| 146 | + self.terms = terms |
| 147 | + self.op = op |
| 148 | + |
| 149 | + def __call__(self, field, data): |
| 150 | + """ |
| 151 | + Return the value of the field in a given data object. |
| 152 | + """ |
| 153 | + qties = [] |
| 154 | + for term in self.terms: |
| 155 | + if isinstance(term, DerivedField): |
| 156 | + qties.append(data[term.name]) |
| 157 | + elif isinstance(term, DerivedFieldCombination): |
| 158 | + qties.append(term(field, data)) |
| 159 | + else: |
| 160 | + qties.append(term) |
| 161 | + |
| 162 | + if len(qties) == 1: |
| 163 | + return self.op(qties[0]) |
| 164 | + else: |
| 165 | + return reduce(self.op, qties) |
| 166 | + |
| 167 | + def __repr__(self): |
| 168 | + return f"DerivedFieldCombination(terms={self.terms!r}, op={self.op!r})" |
| 169 | + |
| 170 | + def getDependentFields(self): |
| 171 | + fields = [] |
| 172 | + for term in self.terms: |
| 173 | + if isinstance(term, DerivedField): |
| 174 | + fields.append(term.name) |
| 175 | + elif isinstance(term, DerivedFieldCombination): |
| 176 | + fields.extend(term.getDependentFields()) |
| 177 | + else: |
| 178 | + continue |
| 179 | + return fields |
| 180 | + |
| 181 | + |
| 182 | +class DerivedField(DerivedFieldBase): |
63 | 183 | """ |
64 | 184 | This is the base class used to describe a cell-by-cell derived field. |
65 | 185 |
|
@@ -499,128 +619,6 @@ def __copy__(self): |
499 | 619 | nodal_flag=self.nodal_flag, |
500 | 620 | ) |
501 | 621 |
|
502 | | - def _operator( |
503 | | - self, other: Union["DerivedField", float], op: Callable |
504 | | - ) -> "DerivedField": |
505 | | - my_units = self.ds.get_unit_from_registry(self.units) |
506 | | - if isinstance(other, DerivedField): |
507 | | - if self.sampling_type != other.sampling_type: |
508 | | - raise TypeError( |
509 | | - f"Cannot {op} fields with different sampling types: " |
510 | | - f"{self.sampling_type} and {other.sampling_type}" |
511 | | - ) |
512 | | - |
513 | | - def wrapped(field, data): |
514 | | - return op(self(data), other(data)) |
515 | | - |
516 | | - other_name = other.name[1] |
517 | | - other_units = self.ds.get_unit_from_registry(other.units) |
518 | | - |
519 | | - else: |
520 | | - # Special case when passing (value, "unit") tuple |
521 | | - if isinstance(other, tuple) and len(other) == 2: |
522 | | - other = self.ds.quan(*other) |
523 | | - |
524 | | - def wrapped(field, data): |
525 | | - return op(self(data), other) |
526 | | - |
527 | | - other_name = str(other) |
528 | | - other_units = getattr(other, "units", self.ds.get_unit_from_registry("1")) |
529 | | - |
530 | | - if op in (operator.add, operator.sub, operator.eq): |
531 | | - assert my_units.same_dimensions_as(other_units) |
532 | | - new_units = my_units |
533 | | - elif op in (operator.mul, operator.truediv): |
534 | | - new_units = op(my_units, other_units) |
535 | | - elif op in (operator.le, operator.lt, operator.ge, operator.gt, operator.ne): |
536 | | - # Comparison yield unitless fields |
537 | | - new_units = Unit("1") |
538 | | - else: |
539 | | - raise TypeError(f"Unsupported operator {op} for DerivedField") |
540 | | - |
541 | | - return DerivedField( |
542 | | - name=(self.name[0], f"{self.name[1]}_{op.__name__}_{other_name}"), |
543 | | - sampling_type=self.sampling_type, |
544 | | - function=wrapped, |
545 | | - units=new_units, |
546 | | - ds=self.ds, |
547 | | - ) |
548 | | - |
549 | | - # Multiplication (left and right side) |
550 | | - def __mul__(self, other: Union["DerivedField", float]) -> "DerivedField": |
551 | | - return self._operator(other, op=operator.mul) |
552 | | - |
553 | | - def __rmul__(self, other: Union["DerivedField", float]) -> "DerivedField": |
554 | | - return self._operator(other, op=operator.mul) |
555 | | - |
556 | | - # Division (left side) |
557 | | - def __truediv__(self, other: Union["DerivedField", float]) -> "DerivedField": |
558 | | - return self._operator(other, op=operator.truediv) |
559 | | - |
560 | | - # Addition (left and right side) |
561 | | - def __add__(self, other: Union["DerivedField", float]) -> "DerivedField": |
562 | | - return self._operator(other, op=operator.add) |
563 | | - |
564 | | - def __radd__(self, other: Union["DerivedField", float]) -> "DerivedField": |
565 | | - return self._operator(other, op=operator.add) |
566 | | - |
567 | | - # Subtraction (left and right side) |
568 | | - def __sub__(self, other: Union["DerivedField", float]) -> "DerivedField": |
569 | | - return self._operator(other, op=operator.sub) |
570 | | - |
571 | | - def __rsub__(self, other: Union["DerivedField", float]) -> "DerivedField": |
572 | | - return self._operator(-other, op=operator.add) |
573 | | - |
574 | | - # Unary minus |
575 | | - def __neg__(self) -> "DerivedField": |
576 | | - def wrapped(field, data): |
577 | | - return -self(data) |
578 | | - |
579 | | - return DerivedField( |
580 | | - name=(self.name[0], f"neg_{self.name[1]}"), |
581 | | - sampling_type=self.sampling_type, |
582 | | - function=wrapped, |
583 | | - units=self.units, |
584 | | - ds=self.ds, |
585 | | - ) |
586 | | - |
587 | | - # Division (right side, a bit more complex) |
588 | | - def __rtruediv__(self, other: Union["DerivedField", float]) -> "DerivedField": |
589 | | - units = self.ds.get_unit_from_registry(self.units) |
590 | | - |
591 | | - def wrapped(field, data): |
592 | | - return 1 / self(data) |
593 | | - |
594 | | - inverse_self = DerivedField( |
595 | | - name=(self.name[0], f"inverse_{self.name[1]}"), |
596 | | - sampling_type=self.sampling_type, |
597 | | - function=wrapped, |
598 | | - units=units**-1, |
599 | | - ds=self.ds, |
600 | | - ) |
601 | | - |
602 | | - return inverse_self * other |
603 | | - |
604 | | - # Comparison operators |
605 | | - def __leq__(self, other: Union["DerivedField", float]) -> "DerivedField": |
606 | | - return self._operator(other, op=operator.le) |
607 | | - |
608 | | - def __lt__(self, other: Union["DerivedField", float]) -> "DerivedField": |
609 | | - return self._operator(other, op=operator.lt) |
610 | | - |
611 | | - def __geq__(self, other: Union["DerivedField", float]) -> "DerivedField": |
612 | | - return self._operator(other, op=operator.ge) |
613 | | - |
614 | | - def __gt__(self, other: Union["DerivedField", float]) -> "DerivedField": |
615 | | - return self._operator(other, op=operator.gt) |
616 | | - |
617 | | - # Somehow, makes yt not work? |
618 | | - # def __eq__(self, other: Union["DerivedField", float]) -> "DerivedField": |
619 | | - # return self._operator(other, op=operator.eq) |
620 | | - |
621 | | - def __ne__(self, other: Union["DerivedField", float]) -> "DerivedField": |
622 | | - return self._operator(other, op=operator.ne) |
623 | | - |
624 | 622 |
|
625 | 623 | class FieldValidator: |
626 | 624 | """ |
|
0 commit comments