Skip to content

how to use yolov11s-seg supervision onnx runtime? #1787

@pranta-barua007

Description

@pranta-barua007

Dear @onuralpszr, I saw a similar case in #1626 and tried to adapt it to my own segmentation use case, but it doesn't seem to be working properly.

here is how I am exporting my model with ultralytics

# Export the fine-tuned model to ONNX, requesting NMS as part of the export.
ft_loaded_best_model.export(
format="onnx",
nms=True,  # with this flag the export log reports a (1, 300, 38) detections output
data="/content/disease__instance_segmented/data.yaml",
) # creates 'best.onnx'

which prints the following to the console:

Ultralytics 8.3.75 🚀 Python-3.11.11 torch-2.5.1+cu124 CPU (Intel Xeon 2.00GHz)
YOLO11s-seg summary (fused): 265 layers, 10,068,364 parameters, 0 gradients, 35.3 GFLOPs

PyTorch: starting from '/content/drive/MyDrive/ML/DENTAL_THESIS/fine_tuned/segment/train/weights/best.pt' with input shape (1, 3, 640, 640) BCHW and output shape(s) ((1, 300, 38), (1, 32, 160, 160)) (19.6 MB)

ONNX: starting export with onnx 1.17.0 opset 19...
ONNX: slimming with onnxslim 0.1.48...
ONNX: export success ✅ 4.2s, saved as '/content/drive/MyDrive/ML/DENTAL_THESIS/fine_tuned/segment/train/weights/best.onnx' (38.7 MB)

Export complete (5.5s)
Results saved to /content/drive/MyDrive/ML/DENTAL_THESIS/fine_tuned/segment/train/weights
Predict:         yolo predict task=segment model=/content/drive/MyDrive/ML/DENTAL_THESIS/fine_tuned/segment/train/weights/best.onnx imgsz=640  
Validate:        yolo val task=segment model=/content/drive/MyDrive/ML/DENTAL_THESIS/fine_tuned/segment/train/weights/best.onnx imgsz=640 data=/content/dental_disease__instance_segmented-7/data.yaml  
Visualize:       https://netron.app/
/content/drive/MyDrive/ML/DENTAL_THESIS/fine_tuned/segment/train/weights/best.onnx

I have 4 classes in my model.
Since I applied NMS, I think my output0 is already transposed,
where the first 4 values are the bbox, the 5th is the probability, the 6th is the class id, and the remaining 32 (from the 7th onward) are the mask coefficients; the 300 means the model will detect up to 300 results. Please correct me if my interpretation is wrong.

here is my implementation

def xywh2xyxy(x):
    """Convert centre-format boxes to corner-format boxes.

    The last axis of ``x`` holds ``(cx, cy, w, h, ...)``. The first four
    entries are rewritten as ``(x1, y1, x2, y2)``; any further entries on
    that axis are carried over unchanged. The input array is not modified.

    Args:
        x: numpy array whose last axis has at least four entries.

    Returns:
        A new array with the same shape and dtype as ``x``.
    """
    half_extent = x[..., 2:4] / 2
    centre = x[..., 0:2]

    # Start from a copy so trailing columns and the dtype are preserved.
    corners = np.copy(x)
    corners[..., 0:2] = centre - half_extent
    corners[..., 2:4] = centre + half_extent
    return corners


class YOLOv11:
    """ONNX Runtime wrapper for a YOLO11 segmentation model exported with NMS.

    The model is expected to produce two outputs:

    * ``outputs[0]`` -- detections, shape ``(1, max_det, 6 + C)``; per row:
      4 box values, confidence, class id, then ``C`` mask coefficients.
    * ``outputs[1]`` -- mask prototypes, shape ``(1, C, mh, mw)``.

    For the export discussed here that is ``(1, 300, 38)`` and
    ``(1, 32, 160, 160)``.

    Args:
        path: Path to the ``.onnx`` file.
        conf_thres: Detections with a confidence at or below this value are
            dropped.
        iou_thres: Stored as ``iou_threshold`` for API compatibility; no
            method of this class reads it (suppression is expected to have
            happened inside the exported model).
        box_format: Layout of the 4 box values emitted by the model.
            ``"xyxy"`` (default) means corners, which is what Ultralytics
            emits for ``nms=True`` exports; pass ``"xywh"`` if your export
            emits centre/size boxes instead.

    Raises:
        ValueError: If ``box_format`` is neither ``"xyxy"`` nor ``"xywh"``.
    """

    def __init__(self, path, conf_thres=0.7, iou_thres=0.5, box_format="xyxy"):
        if box_format not in ("xyxy", "xywh"):
            raise ValueError(
                f"box_format must be 'xyxy' or 'xywh', got {box_format!r}"
            )
        self.conf_threshold = conf_thres
        self.iou_threshold = iou_thres
        self.box_format = box_format
        # Initialize the ONNX model
        self.initialize_model(path)

    def __call__(self, image):
        """Shorthand for :meth:`detect_objects`."""
        return self.detect_objects(image)

    def initialize_model(self, path):
        """Create the inference session and cache its input/output metadata."""
        self.session = onnxruntime.InferenceSession(
            path, providers=onnxruntime.get_available_providers()
        )
        self.get_input_details()
        self.get_output_details()

    def detect_objects(self, image):
        """Run the full pipeline on one BGR image.

        Returns:
            ``(boxes, scores, class_ids, masks)`` as described in
            :meth:`process_output`. The same values are also stored on the
            instance under the same names.
        """
        input_tensor = self.prepare_input(image)
        outputs = self.inference(input_tensor)
        self.boxes, self.scores, self.class_ids, self.masks = self.process_output(outputs)
        return self.boxes, self.scores, self.class_ids, self.masks

    def prepare_input(self, image):
        """Turn a BGR ``(H, W, 3)`` image into a ``(1, 3, h, w)`` float32 tensor.

        The original image size is remembered so that boxes and masks can be
        mapped back to it later.
        """
        self.img_height, self.img_width = image.shape[:2]
        input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # NOTE(review): this is a plain stretch to the network size, not a
        # letterbox; rescale_boxes() and extract_masks() assume the same.
        input_img = cv2.resize(input_img, (self.input_width, self.input_height))
        input_img = input_img / 255.0
        input_img = input_img.transpose(2, 0, 1)
        input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)
        return input_tensor

    def inference(self, input_tensor):
        """Run the session on ``input_tensor`` and return all model outputs."""
        outputs = self.session.run(self.output_names, {self.input_names[0]: input_tensor})
        return outputs

    def process_output(self, outputs):
        """Decode raw model outputs into boxes, scores, class ids and masks.

        Args:
            outputs: ``[detections, mask_prototypes]`` as returned by
                :meth:`inference`.

        Returns:
            A 4-tuple of numpy arrays (also when nothing is detected, in
            which case every array has a leading dimension of 0):

            * boxes -- ``(N, 4)`` corner-format boxes in original image pixels
            * scores -- ``(N,)`` confidences
            * class_ids -- ``(N,)`` int32 class indices
            * masks -- ``(N, img_height, img_width)`` uint8, values 0 or 255
        """
        # Index the batch axis explicitly; squeeze() would also collapse the
        # detection axis if the model were exported with a single slot.
        predictions = np.asarray(outputs[0])[0]
        mask_protos = outputs[1]

        # Filter predictions using the confidence score (index 4)
        keep = predictions[:, 4] > self.conf_threshold
        predictions = predictions[keep]
        scores = predictions[:, 4]

        if len(predictions) == 0:
            # Same array types as the non-empty case so callers need no
            # special handling.
            return (
                np.zeros((0, 4), dtype=np.float32),
                scores,
                np.zeros((0,), dtype=np.int32),
                np.zeros((0, self.img_height, self.img_width), dtype=np.uint8),
            )

        boxes = self.extract_boxes(predictions)
        class_ids = predictions[:, 5].astype(np.int32)
        masks = self.extract_masks(predictions, mask_protos, boxes)

        return boxes, scores, class_ids, masks

    def extract_boxes(self, predictions):
        """Return ``(N, 4)`` corner-format boxes in original image pixels."""
        boxes = self.rescale_boxes(predictions[:, :4])
        if self.box_format == "xywh":
            # Only centre/size exports need converting; corner boxes are
            # already in the output layout.
            boxes = xywh2xyxy(boxes)
        return boxes

    def rescale_boxes(self, boxes):
        """Scale boxes from network input size to original image size.

        Works for both box layouts because x-like and y-like entries
        alternate in each of them. Returns a new float32 array.
        """
        x_ratio = self.img_width / self.input_width
        y_ratio = self.img_height / self.input_height
        scale = np.array([x_ratio, y_ratio, x_ratio, y_ratio], dtype=np.float32)
        return np.asarray(boxes, dtype=np.float32) * scale

    def extract_masks(self, predictions, mask_protos, boxes=None):
        """Build one binary mask per detection at original image resolution.

        Args:
            predictions: ``(N, 6 + C)`` detections; mask coefficients start
                at column 6.
            mask_protos: ``(1, C, mh, mw)`` prototype tensor.
            boxes: Optional ``(N, 4)`` corner-format boxes in original image
                pixels. When given, each mask is zeroed outside its box so
                that prototype activations elsewhere in the image do not
                leak into the instance mask.

        Returns:
            ``(N, img_height, img_width)`` uint8 array with values 0 or 255.

        Raises:
            ValueError: If the detections carry fewer mask coefficients than
                there are prototype channels.
        """
        protos = np.asarray(mask_protos)[0]  # (C, mh, mw)
        num_protos = protos.shape[0]

        seg_coeffs = predictions[:, 6:6 + num_protos]
        if seg_coeffs.shape[1] != num_protos:
            raise ValueError(
                f"expected {num_protos} mask coefficients per detection, "
                f"got {seg_coeffs.shape[1]}"
            )

        # Weighted sum of prototypes -> per-detection mask logits.
        mask_logits = np.einsum('nc,chw->nhw', seg_coeffs, protos)

        height, width = self.img_height, self.img_width
        final_masks = np.zeros((len(mask_logits), height, width), dtype=np.uint8)

        for idx, logits in enumerate(mask_logits):
            # Upsample the logits first and binarise afterwards: this gives
            # smooth edges instead of enlarged prototype-resolution blocks.
            resized = cv2.resize(logits, (width, height), interpolation=cv2.INTER_LINEAR)
            # logit > 0 is the same decision as sigmoid(logit) > 0.5, without
            # the exp() overflow warnings.
            binary = resized > 0

            if boxes is None:
                final_masks[idx][binary] = 255
                continue

            x1, y1, x2, y2 = boxes[idx][:4]
            left = int(np.clip(np.floor(x1), 0, width))
            top = int(np.clip(np.floor(y1), 0, height))
            right = int(np.clip(np.ceil(x2), 0, width))
            bottom = int(np.clip(np.ceil(y2), 0, height))
            window = binary[top:bottom, left:right]
            final_masks[idx, top:bottom, left:right] = window.astype(np.uint8) * 255

        return final_masks

    def get_input_details(self):
        """Cache input names and the spatial size read from the first input."""
        model_inputs = self.session.get_inputs()
        self.input_names = [inp.name for inp in model_inputs]
        self.input_shape = model_inputs[0].shape
        # Shape is indexed as (batch, channels, height, width).
        self.input_height = self.input_shape[2]
        self.input_width = self.input_shape[3]

    def get_output_details(self):
        """Cache the names of all model outputs."""
        model_outputs = self.session.get_outputs()
        self.output_names = [out.name for out in model_outputs]

Metadata

Metadata

Assignees

No one assigned

    Labels

    questionFurther information is requested

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions