Skip to content

Commit 8228cd2

Browse files
authored
Enhance progress reporting in inference (#2156)
* Enhance progress reporting in inference with elapsed time and MofN complete columns; add tracker final pass status messages
* Improve final pass tracking with elapsed time reporting and enhance prediction queue handling to manage timeouts
* Add some highly duplicated debugging logging and timeout logic
* Fix inference dialog re-opening when no skeleton is present
* Lint
1 parent ad7c563 commit 8228cd2

File tree

2 files changed: 147 lines added (+147), 8 lines removed (-8).

2 files changed

+147
-8
lines changed

sleap/gui/app.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1569,7 +1569,10 @@ def _show_learning_dialog(self, mode: str):
15691569
# Update data in existing dialog widget.
15701570
self._child_windows[mode].labels = self.labels
15711571
self._child_windows[mode].labels_filename = self.state["filename"]
1572-
self._child_windows[mode].skeleton = self.labels.skeleton
1572+
try:
1573+
self._child_windows[mode].skeleton = self.labels.skeleton
1574+
except ValueError:
1575+
self._child_windows[mode].skeleton = None
15731576

15741577
self._child_windows[mode].update_file_lists()
15751578

sleap/nn/inference.py

+143-7
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import atexit
3434
import subprocess
3535
import rich.progress
36+
from queue import Empty
3637
import pandas as pd
3738
from rich.pretty import pprint
3839
from collections import deque
@@ -424,8 +425,11 @@ def process_batch(ex):
424425
"{task.description}",
425426
rich.progress.BarColumn(),
426427
"[progress.percentage]{task.percentage:>3.0f}%",
428+
rich.progress.MofNCompleteColumn(),
427429
"ETA:",
428430
rich.progress.TimeRemainingColumn(),
431+
"Elapsed:",
432+
rich.progress.TimeElapsedColumn(),
429433
RateColumn(),
430434
auto_refresh=False,
431435
refresh_per_second=self.report_rate,
@@ -1555,9 +1559,26 @@ def _make_labeled_frames_from_generator(
15551559
predicted_frames = []
15561560

15571561
def _object_builder():
1562+
n_timeouts = 0
15581563
while True:
1559-
ex = prediction_queue.get()
1564+
try:
1565+
# Get the next example from the queue.
1566+
ex = prediction_queue.get(timeout=10)
1567+
1568+
except Empty:
1569+
n_timeouts += 1
1570+
if n_timeouts >= 3:
1571+
# Too many timeouts, exit.
1572+
print(
1573+
"Timeout waiting for prediction queue, "
1574+
"exiting prediction loop."
1575+
)
1576+
break
1577+
continue
1578+
15601579
if ex is None:
1580+
# Poison pill, exit.
1581+
print("Got poison pill, exiting prediction loop.")
15611582
break
15621583

15631584
# Loop over frames.
@@ -1610,6 +1631,10 @@ def _object_builder():
16101631
prediction_queue.put(None)
16111632
object_builder.join()
16121633

1634+
print(
1635+
f"Finished building {len(predicted_frames):,} predicted frames.", flush=True
1636+
)
1637+
16131638
return predicted_frames
16141639

16151640
def export_model(
@@ -2617,9 +2642,26 @@ def _make_labeled_frames_from_generator(
26172642
predicted_frames = []
26182643

26192644
def _object_builder():
2645+
n_timeouts = 0
26202646
while True:
2621-
ex = prediction_queue.get()
2647+
try:
2648+
# Get the next example from the queue.
2649+
ex = prediction_queue.get(timeout=10)
2650+
2651+
except Empty:
2652+
n_timeouts += 1
2653+
if n_timeouts >= 3:
2654+
# Too many timeouts, exit.
2655+
print(
2656+
"Timeout waiting for prediction queue, "
2657+
"exiting prediction loop."
2658+
)
2659+
break
2660+
continue
2661+
26222662
if ex is None:
2663+
# Poison pill, exit.
2664+
print("Got poison pill, exiting prediction loop.")
26232665
break
26242666

26252667
if "n_valid" in ex:
@@ -2699,8 +2741,15 @@ def _object_builder():
26992741
prediction_queue.put(None)
27002742
object_builder.join()
27012743

2744+
print(
2745+
f"Finished building {len(predicted_frames):,} predicted frames.", flush=True
2746+
)
2747+
27022748
if self.tracker:
2749+
t0 = time()
2750+
print("Starting final pass of the tracker...", flush=True)
27032751
self.tracker.final_pass(predicted_frames)
2752+
print(f"Finished final pass of the tracker in {time() - t0:.2f} seconds.")
27042753

27052754
return predicted_frames
27062755

@@ -3253,9 +3302,26 @@ def _make_labeled_frames_from_generator(
32533302
predicted_frames = []
32543303

32553304
def _object_builder():
3305+
n_timeouts = 0
32563306
while True:
3257-
ex = prediction_queue.get()
3307+
try:
3308+
# Get the next example from the queue.
3309+
ex = prediction_queue.get(timeout=10)
3310+
3311+
except Empty:
3312+
n_timeouts += 1
3313+
if n_timeouts >= 3:
3314+
# Too many timeouts, exit.
3315+
print(
3316+
"Timeout waiting for prediction queue, "
3317+
"exiting prediction loop."
3318+
)
3319+
break
3320+
continue
3321+
32583322
if ex is None:
3323+
# Poison pill, exit.
3324+
print("Got poison pill, exiting prediction loop.")
32593325
break
32603326

32613327
if "n_valid" in ex:
@@ -3342,8 +3408,15 @@ def _object_builder():
33423408
prediction_queue.put(None)
33433409
object_builder.join()
33443410

3411+
print(
3412+
f"Finished building {len(predicted_frames):,} predicted frames.", flush=True
3413+
)
3414+
33453415
if self.tracker:
3416+
t0 = time()
3417+
print("Starting final pass of the tracker...", flush=True)
33463418
self.tracker.final_pass(predicted_frames)
3419+
print(f"Finished final pass of the tracker in {time() - t0:.2f} seconds.")
33473420

33483421
return predicted_frames
33493422

@@ -3792,10 +3865,27 @@ def _make_labeled_frames_from_generator(
37923865
predicted_frames = []
37933866

37943867
def _object_builder():
3868+
n_timeouts = 0
37953869
while True:
3796-
ex = prediction_queue.get()
3870+
try:
3871+
# Get the next example from the queue.
3872+
ex = prediction_queue.get(timeout=10)
3873+
3874+
except Empty:
3875+
n_timeouts += 1
3876+
if n_timeouts >= 3:
3877+
# Too many timeouts, exit.
3878+
print(
3879+
"Timeout waiting for prediction queue, "
3880+
"exiting prediction loop."
3881+
)
3882+
break
3883+
continue
3884+
37973885
if ex is None:
3798-
return
3886+
# Poison pill, exit.
3887+
print("Got poison pill, exiting prediction loop.")
3888+
break
37993889

38003890
# Loop over frames.
38013891
for image, video_ind, frame_ind, points, confidences, scores in zip(
@@ -3857,6 +3947,10 @@ def _object_builder():
38573947
prediction_queue.put(None)
38583948
object_builder.join()
38593949

3950+
print(
3951+
f"Finished building {len(predicted_frames):,} predicted frames.", flush=True
3952+
)
3953+
38603954
return predicted_frames
38613955

38623956

@@ -4499,9 +4593,26 @@ def _make_labeled_frames_from_generator(
44994593
predicted_frames = []
45004594

45014595
def _object_builder():
4596+
n_timeouts = 0
45024597
while True:
4503-
ex = prediction_queue.get()
4598+
try:
4599+
# Get the next example from the queue.
4600+
ex = prediction_queue.get(timeout=10)
4601+
4602+
except Empty:
4603+
n_timeouts += 1
4604+
if n_timeouts >= 3:
4605+
# Too many timeouts, exit.
4606+
print(
4607+
"Timeout waiting for prediction queue, "
4608+
"exiting prediction loop."
4609+
)
4610+
break
4611+
continue
4612+
45044613
if ex is None:
4614+
# Poison pill, exit.
4615+
print("Got poison pill, exiting prediction loop.")
45054616
break
45064617

45074618
# Loop over frames.
@@ -4573,6 +4684,10 @@ def _object_builder():
45734684
prediction_queue.put(None)
45744685
object_builder.join()
45754686

4687+
print(
4688+
f"Finished building {len(predicted_frames):,} predicted frames.", flush=True
4689+
)
4690+
45764691
return predicted_frames
45774692

45784693
def export_model(
@@ -4801,9 +4916,26 @@ def _make_labeled_frames_from_generator(
48014916
predicted_frames = []
48024917

48034918
def _object_builder():
4919+
n_timeouts = 0
48044920
while True:
4805-
ex = prediction_queue.get()
4921+
try:
4922+
# Get the next example from the queue.
4923+
ex = prediction_queue.get(timeout=10)
4924+
4925+
except Empty:
4926+
n_timeouts += 1
4927+
if n_timeouts >= 3:
4928+
# Too many timeouts, exit.
4929+
print(
4930+
"Timeout waiting for prediction queue, "
4931+
"exiting prediction loop."
4932+
)
4933+
break
4934+
continue
4935+
48064936
if ex is None:
4937+
# Poison pill, exit.
4938+
print("Got poison pill, exiting prediction loop.")
48074939
break
48084940

48094941
# Loop over frames.
@@ -4859,6 +4991,10 @@ def _object_builder():
48594991
prediction_queue.put(None)
48604992
object_builder.join()
48614993

4994+
print(
4995+
f"Finished building {len(predicted_frames):,} predicted frames.", flush=True
4996+
)
4997+
48624998
return predicted_frames
48634999

48645000

0 commit comments

Comments
 (0)