
[BUG] MONAI Deploy App Stuck Due to Downstream Operator Not Receiving Data #531

Open
@Omar-Faisal

Issue: MONAI Deploy App Stuck Due to Downstream Operator Not Receiving Data

Description

I am deploying a bone age prediction model with the MONAI Deploy App SDK, and my pipeline is not passing data to its downstream operators. The application starts successfully, but execution halts with a "No receiver connected to transmitter" error: one entity in the execution graph never ticks, and the scheduler stops in a deadlock.

Expected Behavior

The application should:

  1. Load a DICOM study from the specified input path using DICOMDataLoaderOperator.
  2. Select the correct DICOM series using DICOMSeriesSelectorOperator.
  3. Convert the series into a volume using DICOMSeriesToVolumeOperator.
  4. Run inference on the resulting image using the custom DICOMBoneAgeOperator.
  5. Output the predicted bone age in months.

Actual Behavior

  • The pipeline initializes, and the operators appear to be connected correctly (the port names are logged below).

  • However, as soon as the graph is activated, the following error appears:

    [error] [entity_executor.cpp:309] [E00025] No receiver connected to transmitter of DownstreamReceptiveSchedulingTerm 30 of entity "unnamed_operator_1". The entity will never tick.

  • The scheduler then stops, and the pipeline does not appear to progress past the data loading stage:

    2025-03-12 08:01:17.596 INFO  gxf/std/greedy_scheduler.cpp@372: Scheduler stopped: Some entities are waiting for execution, but there are no periodic or async entities to get out of the deadlock.

  • The DICOM loader exposes the expected output port and the connections look correct, but no data flows through the pipeline.
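One way to test the "missing output" hypothesis is to list every declared port and check both directions: inputs that nothing feeds, and outputs that nothing receives. The sketch below is plain Python with no SDK calls. The port table is transcribed from the port names the app logs; the input ports of the two built-in middle operators are an assumption (taken from the `add_flow()` calls), and `find_dangling_ports` is a hypothetical helper. As far as I understand, Holoscan (which the SDK runs on) by default gives each output port a condition that only lets the operator tick when a downstream receiver can accept a message, so an output with no receiver at all would match the E00025 wording. `DICOMBoneAgeOperator` is the only operator created without a `name`, which makes it a likely candidate for "unnamed_operator_1".

```python
# Port table transcribed from the app's log output; the input ports of the
# built-in operators are assumed to match the names used in add_flow().
PORTS = {
    "dicom_loader": {"inputs": set(), "outputs": {"dicom_study_list"}},
    "series_selector": {"inputs": {"dicom_study_list"}, "outputs": {"study_selected_series_list"}},
    "series_to_volume": {"inputs": {"study_selected_series_list"}, "outputs": {"image"}},
    "dicom_processor": {"inputs": {"image"}, "outputs": {"predicted_age"}},
}

# The three add_flow() calls from compose(): (source, target, {(out_port, in_port)}).
FLOWS = [
    ("dicom_loader", "series_selector", {("dicom_study_list", "dicom_study_list")}),
    ("series_selector", "series_to_volume",
     {("study_selected_series_list", "study_selected_series_list")}),
    ("series_to_volume", "dicom_processor", {("image", "image")}),
]


def find_dangling_ports(ports, flows):
    """Hypothetical helper: report inputs nothing feeds and outputs nothing receives."""
    fed = set()       # (operator, input_port) pairs that receive data
    consumed = set()  # (operator, output_port) pairs that have a receiver
    for src, dst, pairs in flows:
        for out_port, in_port in pairs:
            consumed.add((src, out_port))
            fed.add((dst, in_port))
    problems = []
    for name in sorted(ports):
        for port in sorted(ports[name]["inputs"]):
            if (name, port) not in fed:
                problems.append(f"input {name}.{port} is never fed")
        for port in sorted(ports[name]["outputs"]):
            if (name, port) not in consumed:
                problems.append(f"output {name}.{port} has no receiver")
    return problems


for problem in find_dangling_ports(PORTS, FLOWS):
    print(problem)
# prints: output dicom_processor.predicted_age has no receiver
```

If `predicted_age` is indeed the dangling port, the two usual options are to connect it to a downstream (sink) operator or to declare it optional. Some of the SDK's bundled operators write the latter as `spec.output("predicted_age").condition(ConditionType.NONE)`, with `ConditionType` imported from `monai.deploy.core`; this is worth verifying against the installed SDK version.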

Environment Details

  • Operating System: Ubuntu 22.04
  • Python Version: 3.8
  • MONAI Deploy App SDK Version: Latest
  • Installed Dependencies:
    pip install gdcm pylibjpeg pylibjpeg-libjpeg pylibjpeg-openjpeg pylibjpeg-rle
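Because "Latest" depends on when the environment was built, the exact installed versions could be recorded with the standard library's `importlib.metadata` (available since Python 3.8). The distribution names in `PACKAGES` are the usual PyPI names and are an assumption about which packages matter here; `installed_versions` is a hypothetical helper.

```python
from importlib.metadata import PackageNotFoundError, version

# PyPI distribution names assumed to be relevant for this report.
PACKAGES = ["monai-deploy-app-sdk", "holoscan", "torch", "pydicom"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is absent."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name}: {ver or 'not installed'}")
```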

Logs

🚀 Starting Bone Age Prediction App...
[info] [fragment.cpp:586] Loading extensions from configs...
[2025-03-12 08:01:17,396] [INFO] (root) - Parsed args: Namespace(argv=['app.py', '--input', '/home/omar/bone_age_deploy/monai_deploy_app/input', '--output', '/home/omar/bone_age_deploy/monai_deploy_app/output'], input=PosixPath('/home/omar/bone_age_deploy/monai_deploy_app/input'), log_level=None, model=None, output=PosixPath('/home/omar/bone_age_deploy/monai_deploy_app/output'), workdir=None)
[2025-03-12 08:01:17,397] [INFO] (BoneAgePredictionApp) - ✅ Output ports of dicom_loader: ['dicom_study_list']
[2025-03-12 08:01:17,397] [INFO] (BoneAgePredictionApp) - ✅ Output ports of series_selector: ['study_selected_series_list']
[2025-03-12 08:01:17,397] [INFO] (BoneAgePredictionApp) - ✅ Output ports of series_to_volume: ['image']
[2025-03-12 08:01:17,397] [INFO] (BoneAgePredictionApp) - ✅ Input ports of dicom_processor: ['image']
[2025-03-12 08:01:17,397] [INFO] (BoneAgePredictionApp) - ✅ Operators successfully connected.
[info] [gxf_executor.cpp:252] Creating context
[info] [gxf_executor.cpp:1974] Activating Graph...
[error] [entity_executor.cpp:309] [E00025] No receiver connected to transmitter of DownstreamReceptiveSchedulingTerm 30 of entity "unnamed_operator_1". The entity will never tick.
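The error line carries the two details worth matching against the operator list: the error code and the entity name. A small regex sketch (a hypothetical helper, not part of the SDK) can pull them out of a saved log; the pattern is written against the single E00025 line above and may need adjusting for other messages.

```python
import re

# Pattern written against the E00025 line above; other messages may differ.
ERROR_RE = re.compile(r'\[(?P<code>E\d+)\].*?entity "(?P<entity>[^"]+)"')


def parse_gxf_errors(log_text):
    """Return (error_code, entity_name) pairs for every matching log line."""
    return [
        (m.group("code"), m.group("entity"))
        for m in map(ERROR_RE.search, log_text.splitlines())
        if m
    ]


LOG_LINE = (
    '[error] [entity_executor.cpp:309] [E00025] No receiver connected to transmitter '
    'of DownstreamReceptiveSchedulingTerm 30 of entity "unnamed_operator_1". '
    'The entity will never tick.'
)
print(parse_gxf_errors(LOG_LINE))  # [('E00025', 'unnamed_operator_1')]
```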

Full Code of app.py

import logging
from pathlib import Path
import sys
import os

# ✅ Ensure the inference model is accessible
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import torch
import numpy as np
import pydicom
from monai.deploy.conditions import CountCondition
from monai.deploy.core import Application, Operator
from monai.deploy.operators import (
    DICOMDataLoaderOperator,
    DICOMSeriesSelectorOperator,
    DICOMSeriesToVolumeOperator
)
from inference.inference import BoneAgeInference  # Import your inference model

# ------------------------------------------------------------------------------------------------------
# ✅ Configuration for the model
# ------------------------------------------------------------------------------------------------------
CONFIG = {
    "model_name": "microsoft/swin-large-patch4-window7-224",
    "feature_dim": 1024,
    "dropout": 0.3,
    "use_cuda": torch.cuda.is_available()
}

MODEL_WEIGHTS_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "model", "best_model.pth"))

# Initialize the inference model
inference_model = BoneAgeInference(MODEL_WEIGHTS_PATH, CONFIG)

# ------------------------------------------------------------------------------------------------------
# ✅ Define Bone Age Inference Operator
# ------------------------------------------------------------------------------------------------------
class DICOMBoneAgeOperator(Operator):
    """Operator to process DICOM images and predict bone age."""

    def __init__(self, fragment):
        super().__init__(fragment)
        self.logger = logging.getLogger(type(self).__name__)

    def setup(self, spec):
        """Define input and output ports for MONAI Deploy."""
        spec.input("image")  # ✅ Ensuring correct input name
        spec.output("predicted_age")  # ✅ Output port for predicted bone age

    def compute(self, input_context, output_context, execution_context):
        """Process DICOM input and predict bone age."""
        self.logger.info("🔹 Receiving image for inference...")
        dicom_image = input_context.receive("image")  # ✅ Correct input

        if dicom_image is None:
            raise ValueError("❌ No image received for processing.")
        
        self.logger.info(f"✅ Received Image UID: {dicom_image.SOPInstanceUID}")

        # Extract gender from metadata
        gender = getattr(dicom_image, "PatientSex", "M")  # Default: Male
        self.logger.info(f"🔹 Extracted gender: {gender}")

        # Convert DICOM image to NumPy format
        pixel_array = dicom_image.pixel_array.astype(np.float32)
        image_np = np.uint8(255 * (pixel_array / np.max(pixel_array)))  # Normalize

        # Run inference
        self.logger.info("🔹 Running inference model...")
        predicted_age = inference_model.predict_from_numpy(image_np, gender)

        # ✅ Store prediction
        output_context.write("predicted_age", predicted_age)
        self.logger.info(f"✅ Predicted Bone Age: {predicted_age:.2f} months for Gender: {gender}")

# ------------------------------------------------------------------------------------------------------
# ✅ Define the MONAI Deploy Application
# ------------------------------------------------------------------------------------------------------
class BoneAgePredictionApp(Application):
    """MONAI Deploy App for bone age prediction from DICOM images."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.logger = logging.getLogger(type(self).__name__)

    def run(self, *args, **kwargs):
        """Run the application with logging."""
        self.logger.info("🚀 Starting Bone Age Prediction App...")
        super().run(*args, **kwargs)
        self.logger.info("✅ App execution completed.")

    def compose(self):
        """Define MONAI Deploy pipeline."""
        # ✅ Initialize MONAI Deploy App Context
        app_context = Application.init_app_context(self.argv)

        # ✅ Set input/output paths
        app_input_path = Path(app_context.input_path)
        app_output_path = Path(app_context.output_path)

        self.logger.info(f"📂 Input Folder: {app_input_path}")
        self.logger.info(f"📂 Output Folder: {app_output_path}")

        # ✅ Define MONAI Deploy Operators
        dicom_loader = DICOMDataLoaderOperator(
            self, CountCondition(self, 1), input_folder=Path(app_input_path), force=True, name="dicom_loader_op"
        )
        series_selector = DICOMSeriesSelectorOperator(self, name="dicom_series_selector_op")
        series_to_volume = DICOMSeriesToVolumeOperator(self, name="dicom_series_to_volume_op")
        dicom_processor = DICOMBoneAgeOperator(self)

        # ✅ Add Operators
        self.add_operator(dicom_loader)
        self.add_operator(series_selector)
        self.add_operator(series_to_volume)
        self.add_operator(dicom_processor)

        # 🔹 DEBUGGING: Check if each operator is passing data correctly
        self.logger.info(f"✅ Output ports of dicom_loader: {list(dicom_loader.spec.outputs.keys())}")
        self.logger.info(f"✅ Output ports of series_selector: {list(series_selector.spec.outputs.keys())}")
        self.logger.info(f"✅ Output ports of series_to_volume: {list(series_to_volume.spec.outputs.keys())}")
        self.logger.info(f"✅ Input ports of dicom_processor: {list(dicom_processor.spec.inputs.keys())}")

        # ✅ Ensure correct connections between operators
        self.add_flow(dicom_loader, series_selector, {("dicom_study_list", "dicom_study_list")})
        self.add_flow(series_selector, series_to_volume, {("study_selected_series_list", "study_selected_series_list")})
        self.add_flow(series_to_volume, dicom_processor, {("image", "image")})  # ✅ Corrected port names

        self.logger.info("✅ Operators successfully connected.")


# ------------------------------------------------------------------------------------------------------
# ✅ Run the Application
# ------------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(message)s")
    app = BoneAgePredictionApp()
    app.run()
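Separately from the deadlock, the normalization step in `compute()` divides by `np.max(pixel_array)`, which produces NaN for an all-zero image and ignores a non-zero minimum. Below is a hedged min-max sketch with a guard; `to_uint8` is a hypothetical helper, and whether the model expects this scaling depends on how `BoneAgeInference` was trained, which the report does not show.

```python
import numpy as np


def to_uint8(pixels):
    """Hypothetical helper: min-max scale an array to uint8 without dividing by zero."""
    arr = np.asarray(pixels, dtype=np.float32)
    lo, hi = float(arr.min()), float(arr.max())
    if hi <= lo:
        # Constant image (e.g. all zeros): return zeros instead of NaN.
        return np.zeros(arr.shape, dtype=np.uint8)
    scaled = (arr - lo) / (hi - lo)  # now in [0, 1]
    return np.round(scaled * 255.0).astype(np.uint8)


print(to_uint8([[0, 512], [1024, 2048]]).tolist())  # [[0, 64], [128, 255]]
```

One more thing to check once data does flow: DICOMSeriesToVolumeOperator emits a `monai.deploy.core.Image` rather than a pydicom dataset, so the array would likely come from `image.asnumpy()` instead of `pixel_array`, and DICOM attributes such as `SOPInstanceUID` or `PatientSex` would not be plain attributes on it (so the `getattr(..., "PatientSex", "M")` fallback would probably always return "M"). Likewise, the Holoscan-based output context sends data with `op_output.emit(value, "predicted_age")` rather than `write(...)`.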

Question

  • How can I resolve the deadlock error?
  • Could there be a missing output definition in one of the operators?

Any guidance would be greatly appreciated! 🚀

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests