Skip to content

Commit 3996556

Browse files
committed
reviewed changes and new case handling for error at start
1 parent eb317a8 commit 3996556

1 file changed

Lines changed: 20 additions & 18 deletions

File tree

  • src/graphomotor/features/trails

src/graphomotor/features/trails/time.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,44 +7,46 @@ def calculate_total_error_time(drawing: models.Drawing) -> dict[str, float]:
77
"""Calculate the total time spent making errors.
88
99
A contiguous "error chunk" is any sequence of rows where df["error"] != "E0".
10-
For each chunk, we find the midpoint time when the error started and the midpoint
11-
time when the error ended. The total error time is the sum of the durations of all
12-
error chunks.
10+
The start of each chunk is defined as the midpoint between the last timestamp with
11+
a "correct" entry and the first timestamp of an "error". The total error time is
12+
the sum of the durations of all error chunks.
1313
1414
Args:
1515
drawing: Drawing object containing drawing data.
1616
1717
Returns:
18-
Dictionary containing the total time spent in error states.
18+
Dictionary containing the total time (s) spent in error states.
1919
"""
2020
mask = drawing.data["error"] != "E0"
2121
if not mask.any():
2222
return {"total_error_time": 0.0}
2323

24-
chunk_start = (~mask.shift(fill_value=False) & mask).to_numpy().nonzero()[0]
25-
chunk_end = (mask.shift(fill_value=False) & ~mask).to_numpy().nonzero()[0]
24+
error_change = mask.astype(int).diff()
25+
chunk_starts = error_change[error_change == 1].index.tolist()
26+
chunk_ends = error_change[error_change == -1].index.tolist()
27+
28+
if mask.iloc[0]:
29+
chunk_starts = [0] + chunk_starts
2630

2731
if mask.iloc[-1]:
28-
chunk_end = list(chunk_end) + [len(drawing.data) - 1]
32+
chunk_ends = chunk_ends + [len(drawing.data)]
2933

3034
seconds = drawing.data["seconds"].to_numpy()
3135
total_error_time = 0.0
3236

33-
for start_idx, end_idx in zip(chunk_start, chunk_end):
34-
start_mid = (
37+
for start_idx, end_idx in zip(chunk_starts, chunk_ends):
38+
start_time = (
3539
(seconds[start_idx - 1] + seconds[start_idx]) / 2
3640
if start_idx > 0
3741
else seconds[0]
3842
)
3943

40-
if end_idx + 1 < len(drawing.data):
41-
end_mid = (seconds[end_idx] + seconds[end_idx - 1]) / 2
42-
else:
43-
if mask.iloc[end_idx]:
44-
end_mid = seconds[end_idx]
45-
else:
46-
end_mid = (seconds[end_idx] + seconds[end_idx - 1]) / 2
47-
print(start_mid, end_mid)
48-
total_error_time += end_mid - start_mid
44+
end_time = (
45+
(seconds[end_idx - 1] + seconds[end_idx]) / 2
46+
if end_idx < len(seconds)
47+
else seconds[-1]
48+
)
49+
50+
total_error_time += end_time - start_time
4951

5052
return {"total_error_time": float(total_error_time)}

0 commit comments

Comments
 (0)