Skip to content

support farimot #3623

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
368 changes: 368 additions & 0 deletions docs/module_usage/tutorials/cv_modules/joint_detection_embedding.md

Large diffs are not rendered by default.

898 changes: 898 additions & 0 deletions docs/pipeline_usage/tutorials/cv_pipelines/multiobject_tracking.md

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
Global:
model: FairMOT-DLA-34_1088x608
mode: check_dataset # check_dataset/train/evaluate/predict
dataset_dir: "/paddle/dataset/paddlex/det/mot_examples"
device: gpu:0,1
output: "output"

CheckDataset:
convert:
enable: False
src_dataset_type: null
split:
enable: False
train_percent: null
val_percent: null

Train:
num_classes: 1
epochs_iters: 30
batch_size: 6
learning_rate: 0.0001
pretrain_weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/FairMOT-DLA-34_1088x608_pretrained.pdparams
warmup_steps: 0
resume_path: null
log_interval: 10
eval_interval: 1

Evaluate:
weight_path: "output/best_model/best_model.pdparams"
log_interval: 10

Predict:
batch_size: 1
model_dir: "output/best_model/inference"
input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_object_detection_002.png"
kernel_option:
run_mode: paddle


Export:
weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/FairMOT-DLA-34_1088x608_pretrained.pdparams
20 changes: 20 additions & 0 deletions paddlex/configs/pipelines/multiobject_tracking.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
pipeline_name: multiobject_tracking

SubModules:
Detector:
module_name: joint_detection_embedding
model_name: FairMOT-DLA-34_1088x608
model_dir: null
batch_size: 1
img_size: null
threshold: null
ReID: None
Tracker:
module_name: JDETracker
conf_thres: 0.4
tracked_thresh: 0.4
metric_type: cosine
min_box_area: 200
vertical_ratio: 1.6
MOT:
module_name: JDEMOT
2 changes: 2 additions & 0 deletions paddlex/inference/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
from .video_classification import VideoClasPredictor
from .video_detection import VideoDetPredictor

from .joint_detection_embedding import JDEPredictor

module_3d_bev_detection = import_module(".3d_bev_detection", "paddlex.inference.models")
BEVDet3DPredictor = getattr(module_3d_bev_detection, "BEVDet3DPredictor")

Expand Down
15 changes: 15 additions & 0 deletions paddlex/inference/models/joint_detection_embedding/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .predictor import JDEPredictor
80 changes: 80 additions & 0 deletions paddlex/inference/models/joint_detection_embedding/predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List

from ....modules.joint_detection_embedding.model_list import MODELS
from ..object_detection import DetPredictor
from ..object_detection.processors import ToBatch

from .result import JDEResult


class JDEPredictor(DetPredictor):

entities = MODELS

def __init__(self, *args, **kwargs):
"""Initializes DetPredictor.
Args:
*args: Arbitrary positional arguments passed to the superclass.
**kwargs: Arbitrary keyword arguments passed to the superclass.
"""
if "batch_size" in kwargs:
assert kwargs["batch_size"] == 1, "JDEPredictor only supports batch_size=1"
super().__init__(*args, **kwargs)

def _get_result_class(self):
return JDEResult

def process(self, batch_data: List[Any]):
"""
Process a batch of data through the preprocessing, inference, and postprocessing.

Args:
batch_data (List[Union[str, np.ndarray], ...]): A batch of input data (e.g., image file paths).
Returns:
dict: A dictionary containing the input path, raw image, class IDs, scores, and label names
for every instance of the batch. Keys include 'input_path', 'input_img', 'class_ids', 'scores', and 'label_names'.
"""
datas = batch_data.instances
# preprocess
for pre_op in self.pre_ops[:-1]:
datas = pre_op(datas)

# use `ToBatch` format batch inputs
batch_inputs = self.pre_ops[-1](datas)

# do infer
pred_dets, pred_embs = self.infer(batch_inputs)

return {
"input_path": batch_data.input_paths,
"input_img": [data["ori_img"] for data in datas],
"pred_dets": [pred_dets],
"pred_embs": [pred_embs],
}

def build_to_batch(self):
models_required_imgsize = ["FairMOT"]
if any(name in self.model_name for name in models_required_imgsize):
ordered_required_keys = (
"img_size",
"img",
"scale_factors",
)
else:
ordered_required_keys = ("img", "scale_factors")

return ToBatch(ordered_required_keys=ordered_required_keys)
103 changes: 103 additions & 0 deletions paddlex/inference/models/joint_detection_embedding/processors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple, Union, Optional

import cv2
import numpy as np
from numpy import ndarray


Boxes = List[dict]
Number = Union[int, float]


class JDEPostProcess:
"""Save Result Transform

This class is responsible for post-processing detection results, including
thresholding, non-maximum suppression (NMS), and restructuring the boxes
based on the input type (normal or rotated object detection).
"""

def __init__(self, labels: Optional[List[str]] = None) -> None:
"""Initialize the DetPostProcess class.

Args:
threshold (float, optional): The threshold to apply to the detection scores. Defaults to 0.5.
labels (Optional[List[str]], optional): The list of labels for the detection categories. Defaults to None.
layout_postprocess (bool, optional): Whether to apply layout post-processing. Defaults to False.
"""
super().__init__()
self.labels = labels

def apply(
self,
boxes: ndarray,
img_size: Tuple[int, int],
threshold: Union[float, dict],
) -> Boxes:
"""Apply post-processing to the detection boxes.

Args:
boxes (ndarray): The input detection boxes with scores.
img_size (tuple): The original image size.

Returns:
Boxes: The post-processed detection boxes.
"""
if isinstance(threshold, float):
expect_boxes = (boxes[:, 1] > threshold) & (boxes[:, 0] > -1)
boxes = boxes[expect_boxes, :]
elif isinstance(threshold, dict):
category_filtered_boxes = []
for cat_id in np.unique(boxes[:, 0]):
category_boxes = boxes[boxes[:, 0] == cat_id]
category_threshold = threshold.get(int(cat_id), 0.5)
selected_indices = (category_boxes[:, 1] > category_threshold) & (
category_boxes[:, 0] > -1
)
category_filtered_boxes.append(category_boxes[selected_indices])
boxes = (
np.vstack(category_filtered_boxes)
if category_filtered_boxes
else np.array([])
)

return boxes

def __call__(
self,
batch_outputs: List[dict],
datas: List[dict],
threshold: Optional[Union[float, dict]] = None,
) -> List[Boxes]:
"""Apply the post-processing to a batch of outputs.

Args:
batch_outputs (List[dict]): The list of detection outputs.
datas (List[dict]): The list of input data.

Returns:
List[Boxes]: The list of post-processed detection boxes.
"""
outputs = []
for data, output in zip(datas, batch_outputs):
boxes = self.apply(
output["boxes"],
data["ori_img_size"],
threshold,
)
outputs.append(boxes)
return outputs
33 changes: 33 additions & 0 deletions paddlex/inference/models/joint_detection_embedding/result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

from ...common.result import BaseResult, JsonMixin


class JDEResult(BaseResult):

def __init__(self, data: dict) -> None:
super().__init__(data)

def _to_str(self, *args, **kwargs):
data = copy.deepcopy(self)
data.pop("input_img")
return JsonMixin._to_str(data, *args, **kwargs)

def _to_json(self, *args, **kwargs):
data = copy.deepcopy(self)
data.pop("input_img")
return JsonMixin._to_json(data, *args, **kwargs)
8 changes: 8 additions & 0 deletions paddlex/inference/models/object_detection/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
PadStride,
ReadImage,
Resize,
LetterBoxResize,
ToBatch,
ToCHWImage,
WarpAffine,
Expand Down Expand Up @@ -273,6 +274,12 @@ def build_resize(self, target_size, keep_ratio=False, interp=2):
op = Resize(target_size=target_size[::-1], keep_ratio=keep_ratio, interp=interp)
return op

@register("LetterBoxResize")
def build_letterbox_resize(self, target_size):
assert target_size
op = LetterBoxResize(target_size=target_size)
return op

@register("NormalizeImage")
def build_normalize(
self,
Expand Down Expand Up @@ -345,4 +352,5 @@ def build_postprocess(self):
self.layout_merge_bboxes_mode = self.config.get(
"layout_merge_bboxes_mode", None
)
self.labels = self.config["label_list"]
return DetPostProcess(labels=self.config["label_list"])
Loading