diff --git a/tools/accuracy_checker/accuracy_checker/adapters/README.md b/tools/accuracy_checker/accuracy_checker/adapters/README.md
index affd3350834..5bf7b631581 100644
--- a/tools/accuracy_checker/accuracy_checker/adapters/README.md
+++ b/tools/accuracy_checker/accuracy_checker/adapters/README.md
@@ -89,6 +89,11 @@ AccuracyChecker supports following set of adapters:
   * `output_name` - name of output layer (Optional).
   * `threshold` - minimal objectness score value for valid detections (Optional, default 0.001).
   * `num` - num parameter from DarkNet configuration file (Optional, default 5).
+* `yoloxs` - converting output of YOLOX model to `DetectionPrediction` representation.
+  * `score_threshold` - minimal accepted score for valid detections (Optional, default 0.001).
+  * `box_format_xywh` - enabling additional preprocessing when box format is xywh (Optional, default `False`).
+  * `boxes_out` - name of output layer with boxes (Optional, default `boxes`).
+  * `labels_out` - name of output layer with labels (Optional, default `labels`).
 * `yolo_v8_detection` - converting output of YOLO v8 family pretrained for object detection to `DetectionPrediction`.
   * `conf_threshold` - minimal confidence for filtering valid detections (Optional, default 0.25).
   * `multi_label` - allow to use multiple labels for the same box coordinates (Optional, default True).
diff --git a/tools/accuracy_checker/accuracy_checker/adapters/__init__.py b/tools/accuracy_checker/accuracy_checker/adapters/__init__.py
index d1814aeac5a..3526e3945f1 100644
--- a/tools/accuracy_checker/accuracy_checker/adapters/__init__.py
+++ b/tools/accuracy_checker/accuracy_checker/adapters/__init__.py
@@ -85,6 +85,7 @@
     YoloV5Adapter,
     YolorAdapter,
     YoloxAdapter,
+    YoloxsAdapter,
     YolofAdapter,
     # for adapter registration, it should be imported and added to __all__ list
     YoloV8DetectionAdapter
@@ -184,6 +185,7 @@
     'YoloV5Adapter',
     'YolorAdapter',
     'YoloxAdapter',
+    'YoloxsAdapter',
     'YolofAdapter',
     'YoloV8DetectionAdapter',
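Note on the new `yoloxs` entries documented above: they map onto a single adapter section of an accuracy-checker configuration. A minimal illustrative sketch, written as a Python dict that mirrors the YAML layout (the values simply restate the documented defaults; the surrounding dataset/launcher config is omitted):

```python
# Illustrative only: adapter section for the new `yoloxs` provider,
# using the options documented in the README hunk above.
yoloxs_adapter_config = {
    'type': 'yoloxs',           # adapter provider name
    'score_threshold': 0.001,   # minimal accepted score for valid detections
    'box_format_xywh': False,   # set True when boxes come as center-based xywh
    'boxes_out': 'boxes',       # output layer with box coordinates and objectness
    'labels_out': 'labels',     # output layer with class labels
}
```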
diff --git a/tools/accuracy_checker/accuracy_checker/adapters/yolo.py b/tools/accuracy_checker/accuracy_checker/adapters/yolo.py
index 0b7392e723b..93df767d938 100644
--- a/tools/accuracy_checker/accuracy_checker/adapters/yolo.py
+++ b/tools/accuracy_checker/accuracy_checker/adapters/yolo.py
@@ -798,6 +798,74 @@ def xywh2xyxy(x):
     return y


+class YoloxsAdapter(Adapter):
+    __provider__ = 'yoloxs'
+    prediction_types = (DetectionPrediction, )
+
+    @classmethod
+    def parameters(cls):
+        parameters = super().parameters()
+        parameters.update({
+            'score_threshold': NumberField(value_type=float, optional=True, min_value=0, default=0.001,
+                                           description="Minimal accepted score value for valid detections."),
+            'box_format_xywh': BoolField(optional=True, default=False,
+                                         description="Indicates that box output format is xywh."),
+            'boxes_out': StringField(optional=True, default='boxes', description="Boxes output layer name."),
+            'labels_out': StringField(optional=True, default='labels', description="Labels output layer name."),
+        })
+        return parameters
+
+    def configure(self):
+        self.score_threshold = self.get_value_from_config('score_threshold')
+        self.box_format_xywh = self.get_value_from_config('box_format_xywh')
+        self.boxes_out = self.get_value_from_config('boxes_out')
+        self.labels_out = self.get_value_from_config('labels_out')
+
+    def process(self, raw, identifiers, frame_meta):
+        result = []
+        raw_outputs = self._extract_predictions(raw, frame_meta)
+
+        for identifier, meta in zip(identifiers, frame_meta):
+            num_classes = 0
+            x_mins, y_mins, x_maxs, y_maxs = [], [], [], []
+
+            if self.additional_output_mapping:
+                boxes = np.array(raw_outputs[self.additional_output_mapping[self.boxes_out]]).squeeze()
+                labels = np.array(raw_outputs[self.additional_output_mapping[self.labels_out]]).squeeze()
+                if not labels.shape:
+                    result.append(DetectionPrediction(identifier, [], [], [], [], [], [], meta))
+                    continue
+                scores = boxes[:, 4]
+                boxes = boxes[:, :4]
+            else:
+                output = np.array(raw_outputs[self.output_blob])
+                num_classes = output.shape[1] - 5
+                labels = np.argmax(output[:, 5: 5 + num_classes], axis=1)
+                scores = output[:, 4]
+                boxes = output[:, :4]
+
+            if num_classes > 0:
+                class_max_confidences = np.max(output[:, 5: 5 + num_classes], axis=1)
+                scores *= class_max_confidences
+
+            if self.box_format_xywh:
+                boxes = xywh2xyxy(boxes)
+
+            mask = scores > self.score_threshold
+            scores = scores[mask]
+            labels = labels[mask]
+            boxes = boxes[mask]
+
+            image_resize_ratio = meta['scale_x']
+            if boxes.size > 0 and image_resize_ratio > 0:
+                x_mins, y_mins, x_maxs, y_maxs = boxes.T / image_resize_ratio
+
+            result.append(DetectionPrediction(
+                identifier, labels, scores, x_mins, y_mins, x_maxs, y_maxs, meta
+            ))
+        return result
+
+
 class YoloV8DetectionAdapter(Adapter):
     """
     class adapter for yolov8, yolov8 support multiple tasks, this class for object detection.
diff --git a/tools/accuracy_checker/accuracy_checker/annotation_converters/ms_coco.py b/tools/accuracy_checker/accuracy_checker/annotation_converters/ms_coco.py
index b6f9ad6feef..ac244e9bd82 100644
--- a/tools/accuracy_checker/accuracy_checker/annotation_converters/ms_coco.py
+++ b/tools/accuracy_checker/accuracy_checker/annotation_converters/ms_coco.py
@@ -206,9 +206,10 @@ def _create_representations(
             image_full_path = self.images_dir / image[1]
             if not check_file_existence(image_full_path):
                 content_errors.append('{}: does not exist'.format(image_full_path))
-            detection_annotation = DetectionAnnotation(image[1], image_labels, xmins, ymins, xmaxs, ymaxs)
-            detection_annotation.metadata['iscrowd'] = is_crowd
-            detection_annotations.append(detection_annotation)
+            if image_labels != []:
+                detection_annotation = DetectionAnnotation(image[1], image_labels, xmins, ymins, xmaxs, ymaxs)
+                detection_annotation.metadata['iscrowd'] = is_crowd
+                detection_annotations.append(detection_annotation)
             progress_reporter.update(image_id, 1)
         progress_reporter.finish()
diff --git a/tools/accuracy_checker/accuracy_checker/config/config_validator.py b/tools/accuracy_checker/accuracy_checker/config/config_validator.py
index 576287de60f..e7c7df167a4 100644
--- a/tools/accuracy_checker/accuracy_checker/config/config_validator.py
+++ b/tools/accuracy_checker/accuracy_checker/config/config_validator.py
@@ -84,7 +84,7 @@ class ConfigValidator(BaseValidator):
     WARN_ON_EXTRA_ARGUMENT = _ExtraArgumentBehaviour.WARN
     ERROR_ON_EXTRA_ARGUMENT = _ExtraArgumentBehaviour.ERROR
     IGNORE_ON_EXTRA_ARGUMENT = _ExtraArgumentBehaviour.IGNORE
-    acceptable_unknown_options = ['connector', '_command_line_mapping']
+    acceptable_unknown_options = ['connector', '_command_line_mapping', 'model']

     def __init__(self, config_uri, on_extra_argument=WARN_ON_EXTRA_ARGUMENT, fields=None, **kwargs):
         super().__init__(**kwargs)
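For reference, the single-output branch of `YoloxsAdapter.process` above reduces to: take objectness times the best class confidence as the score, optionally convert center-based xywh boxes to corners, threshold, and undo the preprocessing resize (`meta['scale_x']`). A standalone sketch of that arithmetic (the helper name and toy array are illustrative, not part of the adapter):

```python
import numpy as np

def decode_yolox_output(output, score_threshold=0.001, scale=0.5):
    # assumed row layout: [cx, cy, w, h, objectness, class scores...]
    num_classes = output.shape[1] - 5
    labels = np.argmax(output[:, 5:5 + num_classes], axis=1)
    # final score = objectness * best class confidence
    scores = output[:, 4] * np.max(output[:, 5:5 + num_classes], axis=1)
    boxes = output[:, :4].copy()
    # center-based xywh -> xyxy corners (what xywh2xyxy does in the adapter)
    boxes[:, 0], boxes[:, 2] = output[:, 0] - output[:, 2] / 2, output[:, 0] + output[:, 2] / 2
    boxes[:, 1], boxes[:, 3] = output[:, 1] - output[:, 3] / 2, output[:, 1] + output[:, 3] / 2
    mask = scores > score_threshold
    # rescale boxes back to the original image space (meta['scale_x'] in the adapter)
    return labels[mask], scores[mask], boxes[mask] / scale

# toy check: one box centered at (32, 32), most confident for class 1
toy = np.array([[32., 32., 16., 16., 0.9, 0.1, 0.8]])
print(decode_yolox_output(toy))  # -> label 1, score 0.72, box [48, 48, 80, 80]
```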
diff --git a/tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher.py b/tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher.py
index fb413a323fe..d8e6865bdd7 100644
--- a/tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher.py
+++ b/tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher.py
@@ -17,6 +17,8 @@
 from contextlib import contextmanager
 import sys
 import importlib
+import urllib.request
+import re
 from collections import OrderedDict
 import numpy as np

@@ -25,6 +27,7 @@

 MODULE_REGEX = r'(?:\w+)(?:(?:.\w+)*)'
 DEVICE_REGEX = r'(?P<device>cpu$|cuda)?'
+CHECKPOINT_URL_REGEX = r'^https?://.*\.pth(\?.*)?(#.*)?$'


 class PyTorchLauncher(Launcher):
@@ -38,6 +41,10 @@ def parameters(cls):
             'checkpoint': PathField(
                 check_exists=True, is_directory=False, optional=True, description='pre-trained model checkpoint'
             ),
+            'checkpoint_url': StringField(
+                optional=True, regex=CHECKPOINT_URL_REGEX, description='URL to pre-trained model checkpoint.'
+            ),
+            'state_key': StringField(optional=True, regex=r'\w+', description='pre-trained model checkpoint state key'),
             'python_path': PathField(
                 check_exists=True, is_directory=True, optional=True,
                 description='appendix for PYTHONPATH for making network module visible in current python environment'
@@ -47,6 +54,9 @@
                 key_type=str, validate_values=False, optional=True, default={},
                 description='keyword arguments for network module'
             ),
+            'init_method': StringField(
+                optional=True, regex=r'\w+', description='Method name to be called for module initialization.'
+            ),
             'device': StringField(default='cpu', regex=DEVICE_REGEX),
             'batch': NumberField(value_type=int, min_value=1, optional=True, description="Batch size.", default=1),
             'output_names': ListField(
@@ -79,13 +89,17 @@ def __init__(self, config_entry: dict, *args, **kwargs):
         module_kwargs = config_entry.get("module_kwargs", {})
         self.device = self.get_value_from_config('device')
         self.cuda = 'cuda' in self.device
+        checkpoint = config_entry.get('checkpoint')
+        if checkpoint is None:
+            checkpoint = config_entry.get('checkpoint_url')
         self.module = self.load_module(
             config_entry['module'],
             module_args,
             module_kwargs,
-            config_entry.get('checkpoint'),
+            checkpoint,
             config_entry.get('state_key'),
-            config_entry.get("python_path")
+            config_entry.get("python_path"),
+            config_entry.get("init_method")
         )

         self._batch = self.get_value_from_config('batch')
@@ -115,14 +129,25 @@ def batch(self):
         return self._batch

     @property
     def output_blob(self):
         return next(iter(self.output_names))

-    def load_module(self, model_cls, module_args, module_kwargs, checkpoint=None, state_key=None, python_path=None):
+    def load_module(self, model_cls, module_args, module_kwargs, checkpoint=None, state_key=None, python_path=None,
+                    init_method=None
+                    ):
         module_parts = model_cls.split(".")
         model_cls = module_parts[-1]
         model_path = ".".join(module_parts[:-1])
         with append_to_path(python_path):
             model_cls = importlib.import_module(model_path).__getattribute__(model_cls)
             module = model_cls(*module_args, **module_kwargs)
+            if init_method is not None:
+                if hasattr(model_cls, init_method):
+                    init_method = getattr(module, init_method)
+                    module = init_method()
+                else:
+                    raise ValueError(f'Could not call the method {init_method} in the module {model_cls}.')
+
             if checkpoint:
+                if isinstance(checkpoint, str) and re.match(CHECKPOINT_URL_REGEX, checkpoint):
+                    checkpoint = urllib.request.urlretrieve(checkpoint)[0]
                 checkpoint = self._torch.load(
                     checkpoint, map_location=None if self.cuda else self._torch.device('cpu')
                 )
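The checkpoint handling added above amounts to: accept either a local `.pth` file or a URL matching `CHECKPOINT_URL_REGEX`, download the URL to a temporary file, load it on CPU unless CUDA is requested, and optionally re-create the module through a named `init_method`. A rough standalone sketch of that flow, assuming the checkpoint stores a state dict (possibly nested under `state_key`); `load_weights` is an illustrative helper, not launcher API:

```python
import re
import urllib.request

import torch

CHECKPOINT_URL_REGEX = r'^https?://.*\.pth(\?.*)?(#.*)?$'

def load_weights(module, checkpoint, state_key=None, init_method=None):
    # optionally rebuild the module via a named method, as `init_method` does in the launcher
    if init_method is not None:
        module = getattr(module, init_method)()
    # a URL checkpoint is first fetched into a local temporary file
    if isinstance(checkpoint, str) and re.match(CHECKPOINT_URL_REGEX, checkpoint):
        checkpoint, _ = urllib.request.urlretrieve(checkpoint)
    state = torch.load(checkpoint, map_location=torch.device('cpu'))
    # `state_key` selects the weights inside a composite checkpoint, e.g. {'model': ...}
    if state_key:
        state = state[state_key]
    module.load_state_dict(state)
    return module.eval()
```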
diff --git a/tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher_readme.md b/tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher_readme.md
index 232398a0066..2bafc19f5a9 100644
--- a/tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher_readme.md
+++ b/tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher_readme.md
@@ -7,9 +7,12 @@ For enabling PyTorch launcher you need to add `framework: pytorch` in launchers
 * `device` - specifies which device will be used for infer (`cpu`, `cuda` and so on).
 * `module`- PyTorch network module for loading.
 * `checkpoint` - pre-trained model checkpoint (Optional).
+* `checkpoint_url` - URL to pre-trained model checkpoint (Optional).
+* `state_key` - key under which the model state is stored inside pre-trained checkpoint (Optional).
 * `python_path` - appendix for PYTHONPATH for making network module visible in current python environment (Optional).
 * `module_args` - list of positional arguments for network module (Optional).
 * `module_kwargs` - dictionary (`key`: `value` where `key` is argument name, `value` is argument value) which represent network module keyword arguments.
+* `init_method` - method name to be called for module initialization (Optional).
 * `adapter` - approach how raw output will be converted to representation of dataset problem, some adapters can be specific to framework. You can find detailed instruction how to use adapters [here](../adapters/README.md).
 * `batch` - batch size for running model (Optional, default 1).
 * `use_openvino_backend` - use torch.compile feature with `openvino` backend (Optional, default `False`)
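Putting the documented options together, an illustrative PyTorch launcher entry might look as follows (shown as a Python dict mirroring the YAML config; the module path, URL and method name are placeholders, not real artifacts):

```python
# Illustrative only: launcher entry using the new `checkpoint_url`,
# `state_key` and `init_method` options documented above.
pytorch_launcher_config = {
    'framework': 'pytorch',
    'device': 'cpu',
    'module': 'models.detector.Detector',                       # hypothetical network module
    'checkpoint_url': 'https://example.com/weights/model.pth',  # must end with .pth to match CHECKPOINT_URL_REGEX
    'state_key': 'model',        # key of the state dict inside the downloaded checkpoint
    'init_method': 'fuse',       # optional method called on the module after instantiation
    'batch': 1,
    'adapter': 'yoloxs',         # the adapter added earlier in this change
}
```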