Skip to content

Commit 6da914c

Browse files
committed
Further typer modifications
1 parent 9bd758e commit 6da914c

2 files changed

Lines changed: 120 additions & 194 deletions

File tree

src/wristpy/core/cli.py

Lines changed: 58 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22

33
import logging
44
import pathlib
5+
from enum import Enum
56
from typing import List, Literal, Optional, Tuple, Union, cast
67

78
import typer
89

9-
from wristpy.core import config, orchestrator
10+
from wristpy.core import config
1011

1112
logger = config.get_logger()
1213
app = typer.Typer(
@@ -15,35 +16,36 @@
1516
)
1617

1718

18-
def _none_or_float_list(value: str) -> Optional[List[float]]:
19-
"""Helper function to process thresholds argument."""
20-
if value.lower() == "none":
21-
return None
22-
try:
23-
float_list = [float(v) for v in value.split(",")]
24-
if len(float_list) != 3:
25-
raise typer.BadParameter(
26-
f"Invalid value: {value}."
27-
"Must be a comma-separated list of exactly three numbers or 'None'."
28-
)
29-
return float_list
30-
except ValueError:
31-
raise typer.BadParameter(
32-
f"Invalid value: {value}. Must be a comma-separated list or 'None'."
33-
)
34-
35-
36-
def _parse_nonwear_algorithms(algorithm_name: str) -> List[str]:
37-
"""Parse comma-separated non-wear algorithm names."""
38-
valid_algorithm_names = ["ggir", "cta", "detach"]
39-
algorithms = [algo.strip().lower() for algo in algorithm_name.split(",")]
40-
for algo in algorithms:
41-
if algo not in valid_algorithm_names:
42-
raise typer.BadParameter(
43-
f"Invalid algorithm: '{algo}'. Must be one of: "
44-
f"{', '.join(valid_algorithm_names)}."
45-
)
46-
return algorithms
19+
class Calibrator(str, Enum):
20+
"""Setting a calibrator class for typer.
21+
22+
This class is used to define the literal types that are allowed for
23+
calibration, and parsing the strings for the orchestrator.
24+
"""
25+
26+
none = "none"
27+
ggir = "ggir"
28+
gradient = "gradient"
29+
30+
31+
class ActivityMetric(str, Enum):
32+
"""Valid activity metrics for physical activity categorization."""
33+
34+
enmo = "enmo"
35+
mad = "mad"
36+
ag_count = "ag_count"
37+
38+
39+
class NonwearAlgorithms(str, Enum):
40+
"""Setting a nonwear algorithm class for typer.
41+
42+
This class is used to define the literal types that are allowed for
43+
nonwear algorithms, and parsing the strings for the orchestrator.
44+
"""
45+
46+
ggir = "ggir"
47+
cta = "cta"
48+
detach = "detach"
4749

4850

4951
@app.command()
@@ -64,50 +66,47 @@ def main(
6466
help="Format for save files when processing directories. "
6567
"Leave as None when processing single files.",
6668
),
67-
calibrator: Union[
69+
calibrator: Calibrator = typer.Option(
6870
None,
69-
Literal["ggir", "gradient"],
70-
] = typer.Option(
71-
"none",
7271
"-c",
7372
"--calibrator",
74-
help="Pick which calibrator to use.",
73+
help="Pick which calibrator to use."
74+
"Must choose one of 'none', 'ggir', or 'gradient'.",
7575
case_sensitive=False,
76-
callback=lambda x: x.lower() if x else x,
7776
),
78-
activity_metric: str = typer.Option(
79-
"enmo",
77+
activity_metric: ActivityMetric = typer.Option(
78+
ActivityMetric.enmo,
8079
"-a",
8180
"--activity-metric",
82-
help="Pick which physical activity metric should be used for physical activity categorization.",
81+
help="Metric should be used for physical activity categorization."
82+
"Choose from 'enmo', 'mad', or 'ag_count'.",
8383
case_sensitive=False,
84-
callback=lambda x: x.lower() if x else x,
8584
),
86-
thresholds: Optional[str] = typer.Option(
85+
thresholds: tuple[float, float, float] = typer.Option(
8786
None,
8887
"-t",
8988
"--thresholds",
9089
help="Provide three thresholds for light, moderate, and vigorous activity. "
91-
"Exactly three values must be given in ascending order, and comma seperated.",
92-
callback=_none_or_float_list,
90+
"Exactly three values must be >= 0, given in ascending order,"
91+
" and separated by a space.",
92+
min=0,
9393
),
94-
nonwear_algorithm: List[str] = typer.Option(
95-
["ggir"],
96-
"-nw",
94+
nonwear_algorithm: list[NonwearAlgorithms] = typer.Option(
95+
[NonwearAlgorithms.ggir],
96+
"-n",
9797
"--nonwear-algorithm",
9898
help="Specify the non-wear detection algorithm(s) to use. "
99-
"Specify one or more of 'ggir', 'cta', 'detach' as a comma-separated list "
100-
"(e.g. 'ggir,detach'). "
99+
"Specify one or more of 'ggir', 'cta', 'detach'. "
100+
"(e.g. '-n ggir -n cta'). "
101101
"When multiple algorithms are specified, majority voting will be applied.",
102-
callback=_parse_nonwear_algorithms,
103102
),
104103
epoch_length: int = typer.Option(
105104
5,
106105
"-e",
107106
"--epoch-length",
108107
help="Specify the sampling rate in seconds for all metrics. "
109-
"Must be greater than 0.",
110-
min=0,
108+
"Must be greater than 1.",
109+
min=1,
111110
),
112111
verbosity: int = typer.Option(
113112
0,
@@ -122,6 +121,8 @@ def main(
122121
),
123122
) -> None:
124123
"""Run wristpy orchestrator with command line arguments."""
124+
from wristpy.core import orchestrator
125+
125126
if version:
126127
typer.echo(f"Wristpy version: {config.get_version()}")
127128
raise typer.Exit()
@@ -134,17 +135,17 @@ def main(
134135
log_level = logging.DEBUG
135136
logger.setLevel(log_level)
136137

138+
nonwear_algorithms = [algo.value for algo in nonwear_algorithm]
139+
137140
logger.debug("Running wristpy. arguments given: %s", locals())
138141
orchestrator.run(
139142
input=input,
140143
output=output,
141-
calibrator=None if calibrator == "none" else calibrator,
142-
activity_metric=activity_metric,
143-
thresholds=None
144-
if thresholds is None
145-
else cast(Tuple[float, float, float], tuple(thresholds)),
144+
calibrator=calibrator.value if calibrator else None,
145+
activity_metric=activity_metric.value,
146+
thresholds=None if thresholds is None else thresholds,
146147
epoch_length=epoch_length,
147-
nonwear_algorithm=nonwear_algorithm,
148+
nonwear_algorithm=nonwear_algorithms,
148149
verbosity=log_level,
149150
output_filetype=output_filetype,
150151
)

0 commit comments

Comments
 (0)