|
1 | 1 | """Internal data model.""" |
2 | 2 |
|
| 3 | +import pathlib |
3 | 4 | from typing import Optional |
4 | 5 |
|
5 | 6 | import numpy as np |
6 | 7 | import polars as pl |
| 8 | +import pydantic |
7 | 9 | from pydantic import BaseModel, field_validator |
8 | 10 |
|
| 11 | +from wristpy.core import config, exceptions |
| 12 | + |
| 13 | +VALID_FILE_TYPES = (".csv", ".parquet") |
| 14 | + |
| 15 | +logger = config.get_logger() |
| 16 | + |
9 | 17 |
|
10 | 18 | class Measurement(BaseModel): |
11 | 19 | """A single measurement of a sensor and its corresponding time.""" |
@@ -125,3 +133,59 @@ def validate_acceleration(cls, v: Measurement) -> Measurement: |
125 | 133 | if v.measurements.ndim != 2 or v.measurements.shape[1] != 3: |
126 | 134 | raise ValueError("acceleration must be a 2D array with 3 columns") |
127 | 135 | return v |
| 136 | + |
| 137 | + |
class OrchestratorResults(pydantic.BaseModel):
    """Pydantic model containing the results of orchestrator.run().

    Every field is a Measurement. `save_results` writes all of them side by
    side into a single csv or parquet file.
    """

    enmo: Measurement
    anglez: Measurement
    physical_activity_levels: Measurement
    nonwear_epoch: Measurement
    sleep_windows_epoch: Measurement

    def save_results(self, output: pathlib.Path) -> None:
        """Convert the results to a polars DataFrame and save it to disk.

        The DataFrame has a "time" column taken from `enmo`, followed by one
        column per model field holding that field's measurements.

        Args:
            output: The path and file name of the data to be saved. The suffix
                must be .csv or .parquet. Missing parent directories are
                created.

        Raises:
            InvalidFileTypeError: If the output file path ends with any
                extension other than csv or parquet.
        """
        logger.debug("Saving results.")
        # Validate before touching the filesystem so that a bad extension does
        # not leave behind freshly created directories.
        self.validate_output(output=output)
        output.parent.mkdir(parents=True, exist_ok=True)

        # Iterating a pydantic model yields (field_name, value) pairs, so each
        # field becomes a column named after it.
        # NOTE(review): only enmo's time is written; this assumes every
        # measurement shares enmo's time base - confirm against the caller.
        results_dataframe = pl.DataFrame(
            {"time": self.enmo.time}
            | {name: value.measurements for name, value in self}
        )

        if output.suffix == ".csv":
            results_dataframe.write_csv(output, separator=",")
        elif output.suffix == ".parquet":
            results_dataframe.write_parquet(output)
        else:
            # Not reachable while validate_output and this dispatch agree; kept
            # as a guard in case VALID_FILE_TYPES gains an entry with no writer.
            raise exceptions.InvalidFileTypeError(
                f"File type must be one of {VALID_FILE_TYPES}"
            )

        logger.debug("results saved in: %s", output)

    @classmethod
    def validate_output(cls, output: pathlib.Path) -> None:
        """Validate that the output path has a supported file extension.

        Only the suffix is inspected; the path itself is not required to exist.

        Args:
            output: The name of the file to be saved, and the directory it will
                be saved in. Must be a .csv or .parquet file.

        Raises:
            InvalidFileTypeError: If the output file path ends with any
                extension other than csv or parquet.
        """
        if output.suffix not in VALID_FILE_TYPES:
            # The trailing space keeps the two sentences apart once the
            # adjacent string literals are concatenated.
            raise exceptions.InvalidFileTypeError(
                f"The extension: {output.suffix} is not supported. "
                "Please save the file as .csv or .parquet",
            )
0 commit comments