Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ Useful controls:
- Click a signal in the right-hand signal list or in the channel overview heatmap to switch channels.
- Use the Matplotlib zoom and pan tools on the main plot to inspect parts of the signal in detail.
- Click `reset` or press `home` to reset the zoom.
- Use the overlay buttons to toggle `trend`, `spike`, `drop`, and `nonwear` overlays.
- Use the overlay buttons to toggle `trend`, `spike`, and `watch` overlays.
- Use the `stats`, `events`, `captions`, and `help` tabs in the details panel to switch what metadata is shown.
- Scroll inside the details panel with the mouse wheel or the `^` / `v` buttons.
2 changes: 2 additions & 0 deletions captionizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def run(
from mhc.dataset import MHCDataset
from mhc.transformer import MHCTransformer
from mhc.constants import MHC_CHANNEL_CONFIG
from extractors.device_wear import DeviceWearExtractor
from extractors.statistical import StatisticalExtractor
from extractors.structural import StructuralExtractor
from models.local import LocalConfig, LocalModel
Expand All @@ -65,6 +66,7 @@ def run(
annotator = Annotator([
StatisticalExtractor(MHC_CHANNEL_CONFIG),
StructuralExtractor(MHC_CHANNEL_CONFIG),
DeviceWearExtractor(MHC_CHANNEL_CONFIG),
SemanticExtractor(MHC_CHANNEL_CONFIG),
])

Expand Down
76 changes: 53 additions & 23 deletions explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from annotator import Annotator
from extractors import ChannelConfig
from extractors.device_wear import DeviceWearExtractor
from extractors.statistical import StatisticalExtractor
from extractors.structural import StructuralExtractor
from mhc.constants import MHC_CHANNEL_CONFIG
Expand All @@ -36,31 +37,13 @@ def _parse_args() -> argparse.Namespace:
return parser.parse_args()


def _nan_regions(arr: np.ndarray, min_length: int = 30) -> list[tuple[int, int]]:
regions = []
in_region = False
for i, val in enumerate(np.isnan(arr)):
if val and not in_region:
start = i
in_region = True
elif not val and in_region:
if i - start >= min_length:
regions.append((start, i - 1))
in_region = False
if in_region and len(arr) - start >= min_length:
regions.append((start, len(arr) - 1))
return regions


def _format_detector_event(detector_name: str, result: object) -> str:
event_type = getattr(result, "event_type", "event")
score = getattr(result, "score", None)
score_suffix = "" if score in (None, 0, 0.0) else f" score={float(score):.2f}"

if event_type == "trend":
return (
f"{detector_name}: {result.direction} {result.start_minute}-{result.end_minute}{score_suffix}"
)
return f"{detector_name}: {result.direction} {result.start_minute}-{result.end_minute}{score_suffix}"
if event_type == "spike":
return f"{detector_name}: {event_type} @{result.spike_minute}{score_suffix}"
return f"{detector_name}: {event_type}{score_suffix}"
Expand All @@ -72,6 +55,8 @@ def _truncate(text: str, max_len: int = 34) -> str:


def _hit_target_label(hit_target: str) -> str:
if hit_target == "watch_wear":
return "watch non-wear"
if ":" in hit_target:
_, event_type = hit_target.split(":", 1)
return event_type.replace("_", " ")
Expand All @@ -93,13 +78,15 @@ def __init__(
self.annotator = Annotator([
StatisticalExtractor(channel_config),
StructuralExtractor(channel_config),
DeviceWearExtractor(channel_config),
])

self.row_index = min(max(0, row_index), len(self.dataset) - 1)
self.signal_index = min(max(0, signal_index), len(channel_config.names) - 1)

self.show_trends = True
self.show_spikes = True
self.show_watch_wear = True
self.detail_mode = "events"
self.details_scroll = 0
self.details_page_lines = 12
Expand Down Expand Up @@ -147,7 +134,7 @@ def __init__(
x0 = 0.805 + i * (width + 0.004)
ax = self.fig.add_axes([x0, 0.698, width, 0.038])
self.hit_target_buttons[detector_name] = Button(ax, _hit_target_label(detector_name))
overlay_labels = ["trend", "spike"]
overlay_labels = ["trend", "spike", "watch"]
self.overlay_buttons: dict[str, Button] = {}
start_x = 0.83
button_width = 0.035
Expand Down Expand Up @@ -208,15 +195,30 @@ def _available_detector_names(self) -> list[str]:
for detectors in self.channel_config.detectors.values()
for detector in detectors
}
names.add("watch_wear")
return sorted(names)

@staticmethod
def _matches_hit_target(hit_target: str, detector_name: str, result: object) -> bool:
if hit_target == "watch_wear":
return False
if ":" not in hit_target:
return detector_name == hit_target
target_detector_name, target_event_type = hit_target.split(":", 1)
return detector_name == target_detector_name and getattr(result, "event_type", None) == target_event_type

def _signal_has_hit_target(self, recording: Recording, signal_index: int, hit_target: str) -> bool:
if hit_target == "watch_wear":
return any(
annotation.caption_type == "watch_wear"
for annotation in recording.annotations_for_signal(signal_index)
)
row_signal_events = self._detector_events(recording.signal(signal_index))
return any(
self._matches_hit_target(hit_target, detector_name, result)
for detector_name, result in row_signal_events
)

def _set_row(self, row_index: int) -> None:
row_index = min(max(0, int(row_index)), len(self.dataset) - 1)
if row_index == self.row_index:
Expand Down Expand Up @@ -250,6 +252,8 @@ def _on_toggle(self, label: str) -> None:
self.show_trends = not self.show_trends
elif label == "spike":
self.show_spikes = not self.show_spikes
elif label == "watch":
self.show_watch_wear = not self.show_watch_wear
self._update_overlay_button_styles()
self.render(reset_zoom=False)

Expand Down Expand Up @@ -345,8 +349,8 @@ def _jump_to_hit(self, step: int) -> None:
for offset in range(1, n_rows * n_signals):
candidate = (flat_index + step * offset) % (n_rows * n_signals)
row_index, signal_index = divmod(candidate, n_signals)
row_signal_events = self._row_detector_events(row_index)[signal_index]
if any(self._matches_hit_target(self.hit_target, detector_name, result) for detector_name, result in row_signal_events):
recording = self._load_row_bundle(row_index)
if self._signal_has_hit_target(recording, signal_index, self.hit_target):
self.row_index = row_index
self.signal_index = signal_index
self.details_scroll = 0
Expand Down Expand Up @@ -385,6 +389,15 @@ def _overview_matrix(recording: Recording) -> np.ma.MaskedArray:
rows.append(normalized)
return np.ma.masked_invalid(np.vstack(rows))

@staticmethod
def _device_wear_windows(recording: Recording, signal_idx: int) -> list[tuple[int, int]]:
windows: list[tuple[int, int]] = []
for annotation in recording.annotations_for_signal(signal_idx):
if annotation.caption_type != "watch_wear" or annotation.window is None:
continue
windows.append(annotation.window)
return windows

def _style_widgets(self) -> None:
self.row_slider.label.set_visible(False)
self.row_slider.valtext.set_visible(False)
Expand Down Expand Up @@ -415,6 +428,7 @@ def _overlay_state(self, label: str) -> bool:
return {
"trend": self.show_trends,
"spike": self.show_spikes,
"watch": self.show_watch_wear,
}[label]

def _update_overlay_button_styles(self) -> None:
Expand Down Expand Up @@ -483,6 +497,7 @@ def render(self, reset_zoom: bool = False) -> None:
n_signals = recording.values.shape[0]
signal = recording.signal(self.signal_index)
detector_events = self._detector_events(signal)
wear_windows = self._device_wear_windows(recording, self.signal_index)
spike_labels = self._spike_labels(detector_events)
captions = self._captions_for_signal(recording, self.signal_index)
display_name = signal.display_name
Expand All @@ -506,6 +521,20 @@ def render(self, reset_zoom: bool = False) -> None:
ax.axis("off")

self.ax_main.plot(x[valid], y[valid], color="steelblue", linewidth=1.0, label="signal")
if self.show_watch_wear:
for start_minute, end_minute in wear_windows:
self.ax_main.axvspan(
start_minute,
end_minute,
ymin=0.82,
ymax=0.96,
facecolor="#7b3294",
edgecolor="#5c1f72",
linewidth=1.4,
alpha=0.22,
hatch="\\\\",
label="WatchNonWear",
)

for detector_name, result in detector_events:
if result.event_type == "trend" and self.show_trends:
Expand Down Expand Up @@ -587,8 +616,9 @@ def render(self, reset_zoom: bool = False) -> None:

caption_lines = []
for caption_type, values in captions.items():
display_type = "watch_nonwear" if caption_type == "watch_wear" else caption_type
for value in values[:3]:
caption_lines.append(f"{caption_type}: {value}")
caption_lines.append(f"{display_type}: {value}")
if not caption_lines:
caption_lines = ["No captions for this signal."]

Expand Down
2 changes: 1 addition & 1 deletion extractors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

DEFAULT_TEMPLATES_PATH = pathlib.Path(__file__).resolve().parent.parent / "templates" / "templates.json"

VALID_CAPTION_TYPES = ("statistical", "structural", "semantic")
VALID_CAPTION_TYPES = ("statistical", "structural", "semantic", "watch_wear")


_ACTIVITY_RE = re.compile(r"HKWorkoutActivityType(.+)$")
Expand Down
96 changes: 96 additions & 0 deletions extractors/device_wear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#
# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project.
#
# SPDX-License-Identifier: MIT
#
from __future__ import annotations

import numpy as np

from extractors import CaptionExtractor, ChannelConfig
from timef.schema import Annotation, Recording


class DeviceWearExtractor(CaptionExtractor):
caption_type = "watch_wear"

def __init__(
self,
config: ChannelConfig,
min_overlapping_channels: int = 2,
min_nonwear_duration: int = 30,
):
super().__init__(config)
self.min_overlapping_channels = min_overlapping_channels
self.min_nonwear_duration = min_nonwear_duration

def extract(self, row: Recording) -> list[Annotation]:
group_channels = self.config.groups.get("watch_device")
if not group_channels:
return []

channel_indices = tuple(
idx for idx, name in enumerate(row.channel_names)
if name in group_channels
)
if len(channel_indices) < self.min_overlapping_channels:
return []

hr_index = _heart_rate_channel_index(row, channel_indices)
active_energy_index = _active_energy_channel_index(row, channel_indices)
if hr_index is None or active_energy_index is None:
return []

hr_gap_mask = _gap_mask(row.signal(hr_index).data)
active_energy_gap_mask = _gap_mask(row.signal(active_energy_index).data)
nonwear_windows = _mask_to_windows(
hr_gap_mask & active_energy_gap_mask,
min_duration=self.min_nonwear_duration,
)
return [
Annotation(
caption_type=self.caption_type,
text=f"Potential Apple Watch non-wear from minute {start_minute} to {end_minute}.",
channel_idxs=channel_indices,
window=(start_minute, end_minute),
label="watch_device",
)
for start_minute, end_minute in nonwear_windows
]


def _gap_mask(series: np.ndarray) -> np.ndarray:
arr = np.asarray(series, dtype=float)
return np.isnan(arr) | (arr == 0)


def _mask_to_windows(mask: np.ndarray, min_duration: int) -> list[tuple[int, int]]:
if mask.size == 0 or not mask.any():
return []

padded = np.concatenate(([False], mask, [False]))
diffs = np.diff(padded.astype(np.int8))
starts = np.where(diffs == 1)[0]
ends = np.where(diffs == -1)[0] - 1
return [
(int(start), int(end))
for start, end in zip(starts.tolist(), ends.tolist())
if end - start + 1 >= min_duration
]


def _heart_rate_channel_index(row: Recording, channel_indices: tuple[int, ...]) -> int | None:
for idx in channel_indices:
name = row.channel_names[idx]
if name.endswith("HKQuantityTypeIdentifierHeartRate") or "heart_rate" in name:
return idx
return None


def _active_energy_channel_index(row: Recording, channel_indices: tuple[int, ...]) -> int | None:
for idx in channel_indices:
name = row.channel_names[idx]
if name.endswith("HKQuantityTypeIdentifierActiveEnergyBurned") or "active_energy" in name:
return idx
return None
1 change: 0 additions & 1 deletion extractors/structural.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def extract(self, row: Recording) -> list[Annotation]:
j += 1

signal_events.sort(key=lambda item: item[0], reverse=True)

results.extend(annotation for _, annotation in signal_events)

return results
7 changes: 7 additions & 0 deletions mhc/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@
]

SLEEP_CHANNELS = ["sleep:asleep", "sleep:inbed"]
WATCH_DEVICE_CHANNELS = [
"hk_watch:HKQuantityTypeIdentifierStepCount",
"hk_watch:HKQuantityTypeIdentifierDistanceWalkingRunning",
"hk_watch:HKQuantityTypeIdentifierHeartRate",
"hk_watch:HKQuantityTypeIdentifierActiveEnergyBurned",
]


MHC_CHANNEL_CONFIG = ChannelConfig(
Expand All @@ -67,6 +73,7 @@
groups={
"activity": frozenset(ACTIVITY_CHANNELS),
"sleep": frozenset(SLEEP_CHANNELS),
"watch_device": frozenset(WATCH_DEVICE_CHANNELS),
},
aggregators={"hk_watch:HKQuantityTypeIdentifierHeartRate": NonZeroAggregator()},
detectors={
Expand Down
3 changes: 2 additions & 1 deletion mhc_weekly/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from aggregators import NonZeroAggregator
from detectors.spike import SpikeDetector
from detectors.trend import TrendDetector
from mhc.constants import ACTIVITY_CHANNELS, CHANNEL_NAMES, CONTINUOUS_CHANNELS, SLEEP_CHANNELS
from mhc.constants import ACTIVITY_CHANNELS, CHANNEL_NAMES, CONTINUOUS_CHANNELS, SLEEP_CHANNELS, WATCH_DEVICE_CHANNELS

HOURLY_TEMPLATES_PATH = pathlib.Path(__file__).resolve().parent.parent / "templates" / "templates_hourly.json"

Expand All @@ -32,6 +32,7 @@
groups={
"activity": frozenset(ACTIVITY_CHANNELS),
"sleep": frozenset(SLEEP_CHANNELS),
"watch_device": frozenset(WATCH_DEVICE_CHANNELS),
},
templates_path=HOURLY_TEMPLATES_PATH,
time_unit="hours",
Expand Down
Loading
Loading