From 6192186c7e74d7032d0d99d99f27e63ed09a778c Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Tue, 4 Jun 2024 17:41:36 +0530 Subject: [PATCH 1/2] add func --- optimum/amd/ryzenai/utils.py | 58 +++++++++++++++++++++++++++++ tests/brevitas/test_onnx_export.py | 2 +- tests/brevitas/test_quantization.py | 5 ++- 3 files changed, 62 insertions(+), 3 deletions(-) diff --git a/optimum/amd/ryzenai/utils.py b/optimum/amd/ryzenai/utils.py index 6c3bffae..bec2c418 100644 --- a/optimum/amd/ryzenai/utils.py +++ b/optimum/amd/ryzenai/utils.py @@ -2,11 +2,16 @@ # Licensed under the MIT License. +import logging import os +import random import onnxruntime as ort +from PIL import Image, ImageDraw, ImageFont +logger = logging.getLogger(__name__) + ONNX_WEIGHTS_NAME = "model.onnx" ONNX_WEIGHTS_NAME_STATIC = "model_static.onnx" @@ -25,3 +30,56 @@ def validate_provider_availability(provider: str): raise ValueError( f"Asked to use {provider} as an ONNX Runtime execution provider, but the available execution providers are {available_providers}." ) + + +def plot_bbox(image, detections, output_path="plot_bbox_output.png"): + """ + Plots labels and bounding boxes on an image. + + Args: + image_path (str): Path to the image. + detections (list): List of detections where each detection is a dictionary with keys 'label', 'bbox'. + The 'bbox' should be a list or tuple of the form [x_min, y_min, x_max, y_max]. + + Returns: + PIL.Image: Image with bounding boxes plotted. + """ + if isinstance(image, str): + image = Image.open(image) + + draw = ImageDraw.Draw(image) + font = ImageFont.load_default() + + # Generate a unique color for each label + colors = {} + + for detection in detections: + label = f"{detection['label']} {detection['score']:.2f}" + label_color_txt = f"{detection['label']}" + bbox = detection["box"] + + if label_color_txt not in colors: + colors[label_color_txt] = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) + + color = colors[label_color_txt] + + # Draw the bounding box + box = [bbox["xmin"], bbox["ymin"], bbox["xmax"], bbox["ymax"]] + draw.rectangle(box, outline=color, width=2) + + # Determine the text color (black or white) based on the brightness of the bounding box color + brightness = color[0] * 0.299 + color[1] * 0.587 + color[2] * 0.114 + text_color = (0, 0, 0) if brightness > 186 else (255, 255, 255) + + # Draw the label background + text_bbox = draw.textbbox((box[0], box[1]), label, font=font) + text_bg_bbox = [box[0], box[1] - (text_bbox[3] - text_bbox[1]), box[0] + (text_bbox[2] - text_bbox[0]), box[1]] + draw.rectangle(text_bg_bbox, fill=color) + + # Draw the label text + draw.text((box[0], box[1] - (text_bbox[3] - text_bbox[1])), label, fill=text_color, font=font) + + image.save(output_path) + logger.info(f"Image with bounding boxes saved to {output_path}") + + return image diff --git a/tests/brevitas/test_onnx_export.py b/tests/brevitas/test_onnx_export.py index 836d21b8..f92bdb32 100644 --- a/tests/brevitas/test_onnx_export.py +++ b/tests/brevitas/test_onnx_export.py @@ -9,11 +9,11 @@ import onnx import torch -from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode from parameterized import parameterized from testing_utils import SUPPORTED_MODELS_TINY, VALIDATE_EXPORT_ON_SHAPES, get_quantized_model +from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager from optimum.amd.brevitas.export import find_and_insert_matmulinteger from optimum.exporters import TasksManager from optimum.exporters.onnx import ( diff --git a/tests/brevitas/test_quantization.py b/tests/brevitas/test_quantization.py index e37992f1..91a81ec0 100644 --- a/tests/brevitas/test_quantization.py +++ b/tests/brevitas/test_quantization.py @@ -4,11 +4,12 @@ import unittest import torch -from brevitas.nn.quant_linear import QuantLinear -from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector, DynamicActQuantProxyFromInjector from parameterized import parameterized from testing_utils import SUPPORTED_MODELS_TINY, get_quantized_model +from brevitas.nn.quant_linear import QuantLinear +from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector, DynamicActQuantProxyFromInjector + def _get_all_model_ids(model_type: str): if isinstance(SUPPORTED_MODELS_TINY[model_type], str): From b0e2695f7f8ab0f3dededc12270d96787c731c6d Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Tue, 4 Jun 2024 17:59:23 +0530 Subject: [PATCH 2/2] fix style --- tests/brevitas/test_onnx_export.py | 2 +- tests/brevitas/test_quantization.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/brevitas/test_onnx_export.py b/tests/brevitas/test_onnx_export.py index f92bdb32..836d21b8 100644 --- a/tests/brevitas/test_onnx_export.py +++ b/tests/brevitas/test_onnx_export.py @@ -9,11 +9,11 @@ import onnx import torch +from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode from parameterized import parameterized from testing_utils import SUPPORTED_MODELS_TINY, VALIDATE_EXPORT_ON_SHAPES, get_quantized_model -from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager from optimum.amd.brevitas.export import find_and_insert_matmulinteger from optimum.exporters import TasksManager from optimum.exporters.onnx import ( diff --git a/tests/brevitas/test_quantization.py b/tests/brevitas/test_quantization.py index 91a81ec0..e37992f1 100644 --- a/tests/brevitas/test_quantization.py +++ b/tests/brevitas/test_quantization.py @@ -4,11 +4,10 @@ import unittest import torch -from parameterized import parameterized -from testing_utils import SUPPORTED_MODELS_TINY, get_quantized_model - from brevitas.nn.quant_linear import QuantLinear from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector, DynamicActQuantProxyFromInjector +from parameterized import parameterized +from testing_utils import SUPPORTED_MODELS_TINY, get_quantized_model def _get_all_model_ids(model_type: str):