Skip to content

Commit 4c82368

Browse files
committed
feat[DEI-263]: year filter added for multi-year aggregation
1 parent c440e69 commit 4c82368

3 files changed

Lines changed: 60 additions & 2 deletions

File tree

decoimpact/business/entities/rules/time_aggregation_rule.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
TimeAggregationRule
1212
"""
1313

14-
from typing import List
14+
from typing import List, Optional
1515

1616
import numpy as _np
1717
import xarray as _xr
@@ -35,18 +35,32 @@ def __init__(
3535
name: str,
3636
input_variable_names: List[str],
3737
operation_type: TimeOperationType,
38+
start_year: Optional[int] = None,
39+
end_year: Optional[int] = None,
3840
):
3941
super().__init__(name, input_variable_names)
4042
self._settings = TimeOperationSettings({"month": "ME", "year": "YE"})
4143
self._settings.percentile_value = 0
4244
self._settings.operation_type = operation_type
4345
self._settings.time_scale = "year"
46+
self._start_year = start_year
47+
self._end_year = end_year
4448

4549
@property
4650
def settings(self):
4751
"""Time operation settings"""
4852
return self._settings
4953

54+
@property
55+
def start_year(self) -> Optional[int]:
56+
"""Start year for the aggregation (inclusive)"""
57+
return self._start_year
58+
59+
@property
60+
def end_year(self) -> Optional[int]:
61+
"""End year for the aggregation (inclusive)"""
62+
return self._end_year
63+
5064
def validate(self, logger: ILogger) -> bool:
5165
"""Validates if the rule is valid
5266
@@ -86,6 +100,10 @@ def execute(self, value_array: _xr.DataArray, logger: ILogger) -> _xr.DataArray:
86100
result = _xr.DataArray(data=_np.empty(0), dims=("time",), coords={"time": []})
87101
# perform aggregations in case of multi-year monthly average
88102
if TimeOperationType.MULTI_YEAR_MONTHLY_AVERAGE == settings.operation_type:
103+
start = str(self._start_year) if self._start_year is not None else None
104+
end = str(self._end_year) if self._end_year is not None else None
105+
slice_obj = slice(start, end)
106+
value_array = value_array.sel({time_dim_name: slice_obj})
89107
grouped_values = value_array.groupby(f"{time_dim_name}.month")
90108
result = self._perform_grouping_operation(
91109
grouped_values, settings.operation_type, time_dim_name

decoimpact/data/api/i_time_aggregation_rule_data.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
from abc import ABC, abstractmethod
1717

18+
from pyparsing import Optional
19+
1820
from decoimpact.data.api.i_rule_data import IRuleData
1921
from decoimpact.data.api.time_operation_type import TimeOperationType
2022

@@ -41,3 +43,13 @@ def percentile_value(self) -> float:
4143
@abstractmethod
4244
def time_scale(self) -> str:
4345
"""Time scale"""
46+
47+
@property
48+
@abstractmethod
49+
def start_year(self) -> Optional[int]:
50+
"""Start year for the aggregation (inclusive)"""
51+
52+
@property
53+
@abstractmethod
54+
def end_year(self) -> Optional[int]:
55+
"""End year for the aggregation (inclusive)"""

tests/business/entities/rules/test_time_aggregation_rule.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def test_execute_value_array_aggregate_time_months_percentile():
438438
)
439439

440440

441-
def test_execute_value_array_aggregate_time_multi_yearly_month_average():
441+
def test_multi_year_monthly_average():
442442
"""Aggregate input_variable_names of a TimeAggregationRule (average, months)"""
443443

444444
# create test set
@@ -463,6 +463,34 @@ def test_execute_value_array_aggregate_time_multi_yearly_month_average():
463463
)
464464

465465

466+
def test_multi_yearly_month_average_with_year_range():
467+
"""Aggregate input_variable_names of a TimeAggregationRule (average, months) with year range"""
468+
469+
# create test set
470+
logger = Mock(ILogger)
471+
rule = TimeAggregationRule(
472+
name="test",
473+
input_variable_names=["foo"],
474+
operation_type=TimeOperationType.MULTI_YEAR_MONTHLY_AVERAGE,
475+
start_year=2020,
476+
end_year=2020,
477+
)
478+
rule.settings.time_scale = "month"
479+
480+
time_aggregation = rule.execute(value_array_multi_year_monthly, logger)
481+
482+
# result_data = [0.15, 0.45, 0.3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
483+
result_data = [0.1, 0.7, 0.2, 4, 5, 6, 7, 8, 9, 10, 11, 12]
484+
result_array = _xr.DataArray(
485+
result_data, coords=[result_time_multi_year_monthly], dims=["time_monthly"]
486+
)
487+
488+
# Assert
489+
assert (
490+
_xr.testing.assert_allclose(time_aggregation, result_array, atol=1e-11) is None
491+
)
492+
493+
466494
def test_operation_type_not_implemented():
467495
"""Test that the time aggregation rule gives an error
468496
if no operation_type is given"""

0 commit comments

Comments
 (0)