Skip to content

Commit 5788815

Browse files
authored
Merge pull request #641 from czbiohub-sf/multi-head-classification-ssaf
Multi head classification ssaf
2 parents 513dc71 + 8c50f09 commit 5788815

File tree

15 files changed

+5558
-49
lines changed

15 files changed

+5558
-49
lines changed

ulc_mm_package/QtGUI/oracle.py

Lines changed: 83 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
66
"""
77

8-
import os
9-
import sys
10-
import traceback
11-
import socket
128
import enum
139
import logging
10+
import os
1411
import subprocess
12+
import socket
13+
import sys
14+
import traceback
15+
from typing import Optional
1516

1617
from os import (
1718
listdir,
@@ -70,11 +71,13 @@
7071
from ulc_mm_package.neural_nets.neural_network_constants import (
7172
AUTOFOCUS_MODEL_DIR,
7273
YOGO_MODEL_DIR,
74+
QC_STATUS,
7375
)
7476

7577
from ulc_mm_package.QtGUI.scope_op import ScopeOp
7678
from ulc_mm_package.QtGUI.form_gui import FormGUI
7779
from ulc_mm_package.QtGUI.liveview_gui import LiveviewGUI
80+
from PyQt5.QtWidgets import QPushButton
7881

7982

8083
QApplication.setAttribute(Qt.AA_EnableHighDpiScaling, True)
@@ -762,11 +765,86 @@ def _start_liveview(self, *args):
762765
def _end_liveview(self, *args):
763766
self.liveview_window.close()
764767

765-
def _start_intermission(self, msg=None, parasitemia_vis_path=""):
768+
def _start_intermission(
769+
self,
770+
msg=None,
771+
parasitemia_vis_path="",
772+
run_qc_status: Optional[int] = None,
773+
):
766774
if msg is None:
767775
# Retriggered intermission due to race condition
768776
return
769777

778+
# Display QC results if available
779+
# An Enum would be great but `pyqtsignal` on PyQt5 does not support Enums
780+
# and a workaround would be uglier
781+
if run_qc_status is not None:
782+
subsample_dir = self.scopeop.mscope.data_storage.get_subsample_folder_path()
783+
if not os.path.exists(subsample_dir):
784+
self.logger.error(
785+
f"Subsample images directory does not exist: {subsample_dir}"
786+
)
787+
subsample_dir = None
788+
if run_qc_status == QC_STATUS.GOOD.value:
789+
# Add a custom button labeled
790+
investigate_btn = None
791+
792+
msg_box = NoCloseMessageBox()
793+
msg_box.setWindowIcon(QIcon(ICON_PATH))
794+
msg_box.setIcon(QMessageBox.Icon.Information)
795+
msg_box.setWindowTitle("Run Quality: GOOD")
796+
msg_box.setText("✅ The run quality is GOOD.\n\n")
797+
798+
# Ok button
799+
msg_box.addButton(QMessageBox.Ok)
800+
801+
# Add button to open subsample images
802+
if subsample_dir is not None:
803+
investigate_btn = QPushButton("View subsample images")
804+
msg_box.addButton(investigate_btn, QMessageBox.ActionRole)
805+
msg_box.exec()
806+
if msg_box.clickedButton() == investigate_btn:
807+
# Path to the subsample images directory
808+
subsample_dir = (
809+
self.scopeop.mscope.data_storage.get_subsample_folder_path()
810+
)
811+
if os.name == "posix":
812+
subprocess.call(["xdg-open", subsample_dir])
813+
814+
elif run_qc_status == QC_STATUS.POOR.value:
815+
investigate_btn = None
816+
817+
msg_box = NoCloseMessageBox()
818+
msg_box.setWindowIcon(QIcon(ICON_PATH))
819+
msg_box.setIcon(QMessageBox.Icon.Critical)
820+
msg_box.setWindowTitle("Run Quality: POOR")
821+
msg_box.setText(
822+
"❌ The run quality is POOR.\n\nPlease RE-RUN THE SAMPLE to ensure valid results."
823+
)
824+
msg_box.setDetailedText(
825+
(
826+
"Open the run's subsample images to investigate its quality. The QC result is usually indicative of image quality, however "
827+
"in cases where the images are unusual for another reason (anemia, SCD, other hemoglobinopathies), the QC result may not be accurate."
828+
"If the images appear in-focus (i.e NOT blurry), and the cells are NOT coagulated, then you do not need to re-run the sample. If you are unsure, please re-run the sample."
829+
)
830+
)
831+
832+
# Ok button
833+
msg_box.addButton(QMessageBox.Ok)
834+
835+
# Add button to open subsample images
836+
if subsample_dir is not None:
837+
investigate_btn = QPushButton("View subsample images")
838+
msg_box.addButton(investigate_btn, QMessageBox.ActionRole)
839+
msg_box.exec()
840+
if msg_box.clickedButton() == investigate_btn:
841+
if os.name == "posix":
842+
subprocess.call(["xdg-open", subsample_dir])
843+
else:
844+
raise ValueError(
845+
f"Invalid run_qc_status: {run_qc_status}. Expected 'good' or 'poor'."
846+
)
847+
770848
self.display_message(
771849
QMessageBox.Icon.Information,
772850
"Run status",

ulc_mm_package/QtGUI/scope_op.py

Lines changed: 87 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
66
"""
77

8-
import cv2
98
import logging
10-
import numpy as np
119

1210
from typing import Any
1311
from time import sleep, perf_counter
14-
from transitions import Machine, State
1512

13+
import cv2
14+
import numpy as np
1615
from PyQt5.QtCore import QObject, pyqtSignal, pyqtSlot
16+
from transitions import Machine, State
1717

1818
from ulc_mm_package.hardware.scope import MalariaScope, GPIOEdge
1919
from ulc_mm_package.hardware.scope_routines import Routines
@@ -38,7 +38,12 @@
3838
LEDNoPower,
3939
)
4040

41-
from ulc_mm_package.neural_nets.neural_network_constants import IMG_RESIZED_DIMS
41+
from ulc_mm_package.neural_nets.neural_network_constants import (
42+
IMG_RESIZED_DIMS,
43+
QC_GOODNESS_THRESHOLD,
44+
QC_STATUS,
45+
PERC_OF_IMAGES_GOOD,
46+
)
4247
from ulc_mm_package.neural_nets.YOGOInference import YOGO, ClassCountResult
4348
from ulc_mm_package.neural_nets.neural_network_constants import (
4449
YOGO_CLASS_LIST,
@@ -84,7 +89,7 @@ class NamedMachine(Machine):
8489

8590
class ScopeOp(QObject, NamedMachine):
8691
setup_done = pyqtSignal()
87-
experiment_done = pyqtSignal(str, str)
92+
experiment_done = pyqtSignal(str, str, int)
8893
reset_done = pyqtSignal()
8994

9095
yield_mscope = pyqtSignal(MalariaScope)
@@ -301,6 +306,34 @@ def update_thumbnails(self):
301306
)
302307
)
303308

309+
def run_status_from_qc_results(self, qc_results: np.ndarray) -> QC_STATUS:
310+
"""Logic for determining whether a run was decent based on the QC results at the end of a run.
311+
312+
Parameters
313+
----------
314+
qc_results : np.ndarray
315+
Array of QC results from the QC model.
316+
317+
Returns
318+
-------
319+
QC_STATUS
320+
Status of the run based on the QC results.
321+
- "good" if X% of results are below the QC_GOODNESS_THRESHOLD (i.e considered 'good')
322+
- "poor" if Y% of results are above the threshold
323+
"""
324+
325+
num_good = (qc_results <= QC_GOODNESS_THRESHOLD).sum()
326+
num_total = len(qc_results)
327+
328+
if num_total == 0:
329+
self.logger.warning("No QC results available. Cannot determine run status.")
330+
raise ValueError("Run status cannot be determined without QC results.")
331+
332+
if num_good / num_total >= PERC_OF_IMAGES_GOOD:
333+
return QC_STATUS.GOOD
334+
else:
335+
return QC_STATUS.POOR
336+
304337
def setup(self):
305338
self.create_timers.emit()
306339

@@ -559,6 +592,48 @@ def _end_experiment(self, *args):
559592

560593
self.mscope.reset_for_end_experiment()
561594

595+
# Run the QC model a small partition of the data
596+
self.logger.info("Running QC on images.")
597+
qc_results = None
598+
try:
599+
zf = self.mscope.data_storage.get_read_only_zarr()
600+
601+
try:
602+
img_indices = np.linspace(0, zf.initialized - 1, 50).astype(int)
603+
for idx in img_indices:
604+
img = zf[:, :, idx]
605+
self.mscope.qc.asyn(img)
606+
qc_results = self.mscope.qc.get_asyn_results(timeout=None)
607+
qc_results = [self.mscope.qc._sigmoid(x.result) for x in qc_results]
608+
except Exception as e:
609+
self.logger.error(
610+
f"Unexpected error while submitting images to QC model: {e}. Skipping QC..."
611+
)
612+
except Exception as e:
613+
self.logger.error(f"Failed to get zarr data for QC: {e}.\nSkipping QC...")
614+
615+
# Log QC results
616+
if qc_results:
617+
self.did_run_pass_qc = None
618+
qc_results_np = np.array([x[0][0] for x in qc_results])
619+
num_qc_results_good = (qc_results_np <= QC_GOODNESS_THRESHOLD).sum()
620+
num_imgs_qc = len(qc_results_np)
621+
self.logger.info(
622+
f"QC all results: {qc_results_np}\n"
623+
f"QC mean: {qc_results_np.mean():.3f}, "
624+
f"QC stdev: {qc_results_np.std():.3f}, "
625+
f"QC best score: {qc_results_np.min():.3f}, "
626+
f"QC worst score: {qc_results_np.max():.3f}, "
627+
f"QC number of images good: {num_qc_results_good}/{num_imgs_qc} ({num_qc_results_good/num_imgs_qc:.2%})%"
628+
)
629+
self.did_run_pass_qc = self.run_status_from_qc_results(qc_results_np).value
630+
else:
631+
self.logger.warning("No QC results available. Skipping QC...")
632+
633+
# Save qc results
634+
self.mscope.data_storage.save_qc_data(img_indices, qc_results_np)
635+
self.finishing_experiment.emit(80)
636+
562637
# Turn camera back on
563638
self.mscope.camera.startAcquisition()
564639

@@ -577,12 +652,15 @@ def _end_experiment(self, *args):
577652
def _start_intermission(self, msg):
578653
parasitemia_vis_path = self.mscope.data_storage.get_parasitemia_vis_filename()
579654

655+
# Display parasitemia visualization if it exists
580656
if parasitemia_vis_path.exists():
581657
self.experiment_done.emit(
582-
msg + PARASITEMIA_VIS_MSG, str(parasitemia_vis_path)
658+
msg + PARASITEMIA_VIS_MSG,
659+
str(parasitemia_vis_path),
660+
self.did_run_pass_qc,
583661
)
584662
else:
585-
self.experiment_done.emit(msg, "")
663+
self.experiment_done.emit(msg, "", self.did_run_pass_qc)
586664

587665
@pyqtSlot(np.ndarray, float)
588666
def run_autobrightness(self, img, _timestamp):
@@ -701,10 +779,7 @@ def run_autofocus(self, img, _timestamp):
701779

702780
if not self.autofocus_done:
703781
if len(self.autofocus_batch) < AF_BATCH_SIZE:
704-
resized_img = cv2.resize(
705-
img, IMG_RESIZED_DIMS, interpolation=cv2.INTER_CUBIC
706-
)
707-
self.autofocus_batch.append(resized_img)
782+
self.autofocus_batch.append(img)
708783

709784
if self.running:
710785
self.img_signal.connect(self.run_autofocus)
@@ -931,7 +1006,7 @@ def run_experiment(self, img, timestamp) -> None:
9311006
raw_focus_err,
9321007
filtered_focus_err,
9331008
focus_adjustment,
934-
) = self.PSSAF_routine.send(resized_img)
1009+
) = self.PSSAF_routine.send(img)
9351010
except MotorControllerError as e:
9361011
if not SIMULATION:
9371012
self.logger.error(

ulc_mm_package/hardware/scope.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
22
The malaria scope object, containing all the different hardware
3-
periperhals which make up the malaria scope.
3+
peripherals which make up the malaria scope.
44
55
Components
66
- Motorcontroller (DRV8825 Nema)
@@ -50,6 +50,7 @@
5050
from ulc_mm_package.image_processing.flow_control import FlowController
5151
from ulc_mm_package.neural_nets.YOGOInference import YOGO
5252
from ulc_mm_package.neural_nets.AutofocusInference import AutoFocus
53+
from ulc_mm_package.neural_nets.QCInference import QC
5354
from ulc_mm_package.neural_nets.NCSModel import GPUError
5455
from ulc_mm_package.neural_nets.predictions_handler import PredictionsHandler
5556

@@ -129,7 +130,7 @@ def reset_pneumatic_and_led_and_flow_control(self) -> None:
129130

130131
def reset_for_end_experiment(self) -> None:
131132
"""
132-
Reset syringe, turn LED off, reset flow control, reset YOGO / Autofoucs,
133+
Reset syringe, turn LED off, reset flow control, reset YOGO / Autofocus,
133134
and close data storage.
134135
"""
135136

@@ -322,6 +323,7 @@ def _init_GPU(self):
322323
self.logger.info("Initializing GPU...")
323324
self.autofocus_model = AutoFocus()
324325
self.cell_diagnosis_model = YOGO()
326+
self.qc = QC()
325327
self.gpu_enabled = True
326328
except GPUError as e:
327329
self.logger.error(f"GPU initialization failed. {e}")

ulc_mm_package/hardware/scope_routines.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,16 @@ def singleShotAutofocusRoutine(
7676
The number of steps that the motor was moved.
7777
"""
7878

79-
ssaf_steps_from_focus = mscope.autofocus_model(img_arr)
80-
steps_to_move = -round(np.mean(ssaf_steps_from_focus))
79+
res = mscope.autofocus_model(img_arr)
80+
81+
# The indices of the first output correspond to below/in/above focus.
82+
# We subtract one so that it maps to -1, 0, +1, and then subsequently take the element-wise product with the
83+
# magnitude vector to get the number of steps away from focus.
84+
above_or_below_vec = np.array([np.argmax(r[0][0]) - 1 for r in res])
85+
magnitude_vec = np.array([np.argmax(r[1][0]) for r in res])
86+
steps_vec = above_or_below_vec * magnitude_vec
87+
88+
steps_to_move = -round(np.mean(steps_vec))
8189

8290
try:
8391
dir = Direction.CW if steps_to_move > 0 else Direction.CCW
@@ -147,11 +155,21 @@ def periodicAutofocusWrapper(
147155
mscope.autofocus_model.asyn(img, img_counter)
148156
results = mscope.autofocus_model.get_asyn_results(timeout=0.005) or []
149157

150-
for res in sorted(results, key=lambda res: res.id):
158+
for res in sorted(results, key=lambda res: res[0].id):
151159
move_counter += 1
152-
153-
steps_from_focus = res.result.item()
154-
filtered_error = ssaf_filter.update_and_get_val(steps_from_focus)
160+
direction_asyn_result = res[0]
161+
mag_asyn_result = res[1]
162+
163+
direction = (
164+
np.argmax(direction_asyn_result.result) - 1
165+
) # Subtract 1 so that it maps to -1, 0, +1
166+
magnitude = np.argmax(mag_asyn_result.result)
167+
mag_conf = mag_asyn_result.result[0][magnitude]
168+
if mag_conf >= nn_constants.MAG_CONF_THRESH:
169+
steps_from_focus = direction * magnitude
170+
filtered_error = ssaf_filter.update_and_get_val(
171+
steps_from_focus
172+
)
155173

156174
throttle_counter = 0
157175

0 commit comments

Comments
 (0)