Skip to content

Commit 89db955

Browse files
authored
Feature/issue 166/multiple metrics (#233)
* orchestrator input changes * loop over metric * names to function calls. * multipule metrics, pal * multiple metrics, pal working fix for orchestrator * cli fix * revert bug check change * threshold change * threshold parsing. * cli str fix. * working tests * documentation. * post-merge fixes. * Json error * ruff error * reverting old test change. * typer error * mypy errors * test for typer error. * removing validation. * removing redundant test causing error. * multiple metrics tests for cli and orchestrator. * extra test for missmatch error. * review changes. * mypy error demands tuple unpacking
1 parent 2165b3e commit 89db955

11 files changed

Lines changed: 361 additions & 157 deletions

File tree

src/wristpy/core/cli.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,36 @@ def version_check(version: bool) -> None:
6262
raise typer.Exit()
6363

6464

65+
def _parse_thresholds(thresholds: list[str]) -> list[tuple[float, float, float]]:
66+
"""Parse the threshold strings into a list of tuples.
67+
68+
Args:
69+
thresholds: List of threshold strings, each containing three space-separated
70+
floats.
71+
72+
Returns:
73+
List of tuple float triplets containing the parsed threshold values.
74+
75+
Raises:
76+
typer.BadParameter: If any threshold triplet does not contain exactly three
77+
floats
78+
typer.BadParameter: If threshold format is invalid or values cannot be parsed.
79+
"""
80+
parsed = []
81+
for triplet_str in thresholds:
82+
parts = triplet_str.strip().split()
83+
if len(parts) != 3:
84+
raise typer.BadParameter(
85+
f"Threshold triplet must have exactly 3 floats: {triplet_str}"
86+
)
87+
try:
88+
values = [float(part) for part in parts]
89+
except ValueError:
90+
raise typer.BadParameter(f"Invalid float in threshold: {triplet_str}")
91+
parsed.append((values[0], values[1], values[2]))
92+
return parsed
93+
94+
6595
@app.command()
6696
def main(
6797
input: pathlib.Path = typer.Argument(
@@ -87,22 +117,24 @@ def main(
87117
"Must choose one of 'none', 'ggir', or 'gradient'.",
88118
case_sensitive=False,
89119
),
90-
activity_metric: ActivityMetric = typer.Option(
91-
ActivityMetric.enmo,
120+
activity_metric: list[ActivityMetric] = typer.Option(
121+
[ActivityMetric.enmo],
92122
"-a",
93123
"--activity-metric",
94-
help="Metric used for physical activity categorization. "
95-
"Choose from 'enmo', 'mad', 'ag_count', or 'mims'. ",
124+
help="Metric(s) used for physical activity categorization. "
125+
"Choose from 'enmo', 'mad', 'ag_count', or 'mims'. "
126+
"Use multiple times for multiple metrics: '-a enmo -a mad' etc.",
96127
case_sensitive=False,
97128
),
98-
thresholds: tuple[float, float, float] = typer.Option(
129+
# Typer does not support list[tuple[...]]
130+
thresholds: list[str] = typer.Option(
99131
None,
100132
"-t",
101133
"--thresholds",
102134
help="Provide three thresholds for light, moderate, and vigorous activity. "
103-
"Exactly three values must be >= 0, given in ascending order, "
104-
"and separated by a space. (e.g. '-t 0.1 1.0 1.5').",
105-
min=0,
135+
"One threshold set per activity metric, in the same order as metrics. "
136+
"Format: three space-separated values >= 0 in ascending order. "
137+
"Example: -t '0.1 1.0 1.5' or -a enmo -a mad -t '0.1 1.0 1.5' -t '0.2 2.0 3.0'",
106138
),
107139
nonwear_algorithm: list[NonwearAlgorithms] = typer.Option(
108140
[NonwearAlgorithms.ggir],
@@ -146,6 +178,8 @@ def main(
146178
logger.setLevel(log_level)
147179

148180
nonwear_algorithms = [algo.value for algo in nonwear_algorithm]
181+
activity_metrics = [metric.value for metric in activity_metric]
182+
parsed_thresholds = _parse_thresholds(thresholds) if thresholds else None
149183
calibrator_value = None if calibrator == Calibrator.none else calibrator.value
150184

151185
logger.debug("Running wristpy. arguments given: %s", locals())
@@ -154,8 +188,8 @@ def main(
154188
input=input,
155189
output=output,
156190
calibrator=calibrator_value,
157-
activity_metric=activity_metric.value,
158-
thresholds=None if thresholds is None else thresholds,
191+
activity_metric=activity_metrics, # type: ignore[arg-type] # Covered by ActivityMetric Enum class
192+
thresholds=parsed_thresholds,
159193
epoch_length=epoch_length,
160194
nonwear_algorithm=nonwear_algorithms, # type: ignore[arg-type] # Covered by NonwearAlgorithm Enum class
161195
verbosity=log_level,

src/wristpy/core/computations.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def _moving(
1414
*,
1515
centered: bool = False,
1616
continuous: bool = False,
17+
name: str | None = None,
1718
) -> models.Measurement:
1819
"""Internal handler of rolling window functions.
1920
@@ -24,6 +25,7 @@ def _moving(
2425
centered: If true, centers the window. Defaults to False.
2526
continuous: If true, applies the window to every measurement. If false,
2627
groups measurements into chunks of epoch_length. Defaults to False.
28+
name: The name of the Measurement object.
2729
2830
Returns:
2931
The measurement with the rolling function applied to it.
@@ -54,25 +56,26 @@ def _moving(
5456
.agg(aggregator())
5557
)
5658

57-
return models.Measurement.from_data_frame(aggregated_df.collect())
59+
return models.Measurement.from_data_frame(aggregated_df.collect(), name=name)
5860

5961

6062
def moving_mean(
61-
array: models.Measurement, epoch_length: float = 5
63+
array: models.Measurement, epoch_length: float = 5, name: str | None = None
6264
) -> models.Measurement:
6365
"""Calculate the moving mean of the sensor data in array.
6466
6567
Args:
6668
array: The Measurement object with the sensor data we want to take the mean of
6769
epoch_length: The length, in seconds, of the window.
70+
name: The name of the Measurement object.
6871
6972
Returns:
7073
The moving mean of the array in a new Measurement instance.
7174
7275
Raises:
7376
ValueError: If the epoch length is not an integer or is less than 1.
7477
"""
75-
return _moving(array, epoch_length, "mean")
78+
return _moving(array, epoch_length, "mean", name=name)
7679

7780

7881
def moving_std(

src/wristpy/core/models.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,23 @@
1616
class Measurement(BaseModel):
1717
"""A single measurement of a sensor and its corresponding time."""
1818

19+
name: str | None = None
1920
measurements: np.ndarray
2021
time: pl.Series
2122

2223
@classmethod
23-
def from_data_frame(cls, data_frame: pl.DataFrame) -> "Measurement":
24+
def from_data_frame(
25+
cls, data_frame: pl.DataFrame, name: str | None = None
26+
) -> "Measurement":
2427
"""Creates a measurement from a Polars DataFrame.
2528
2629
Args:
2730
data_frame: The Polars DataFrame, must have a time column. All
2831
non-time columns will be used as the 'measurements' input.
32+
name: Optional name describing the type of measurement.
2933
"""
3034
return Measurement(
35+
name=name,
3136
measurements=data_frame.drop("time").to_numpy().squeeze(),
3237
time=data_frame["time"],
3338
)

0 commit comments

Comments
 (0)