|
1 |
| -from typing import Optional, Sequence, Union |
| 1 | +from typing import Any, Sequence, Union |
2 | 2 |
|
3 | 3 | import mmcv
|
4 | 4 | import numpy as np
|
5 | 5 | import torch
|
6 | 6 |
|
7 |
| -from mmdeploy.utils import (Backend, get_backend, get_codebase, |
8 |
| - get_input_shape, get_task_type, load_config) |
9 |
| -from .utils import (create_input, init_backend_model, init_pytorch_model, |
10 |
| - run_inference, visualize) |
| 7 | +from mmdeploy.utils import get_input_shape, load_config |
11 | 8 |
|
12 | 9 |
|
13 | 10 | def inference_model(model_cfg: Union[str, mmcv.Config],
|
14 | 11 | deploy_cfg: Union[str, mmcv.Config],
|
15 |
| - model: Union[str, Sequence[str], torch.nn.Module], |
16 |
| - img: Union[str, np.ndarray], |
17 |
| - device: str, |
18 |
| - backend: Optional[Backend] = None, |
19 |
| - output_file: Optional[str] = None, |
20 |
| - show_result: bool = False): |
| 12 | + backend_files: Sequence[str], img: Union[str, np.ndarray], |
| 13 | + device: str) -> Any: |
21 | 14 | """Run inference with PyTorch or backend model and show results.
|
22 | 15 |
|
23 | 16 | Args:
|
24 | 17 | model_cfg (str | mmcv.Config): Model config file or Config object.
|
25 | 18 | deploy_cfg (str | mmcv.Config): Deployment config file or Config
|
26 | 19 | object.
|
27 |
| - model (str | list[str], torch.nn.Module): Input model or file(s). |
| 20 | + backend_files (Sequence[str]): Input backend model file(s). |
28 | 21 | img (str | np.ndarray): Input image file or numpy array for inference.
|
29 | 22 | device (str): A string specifying device type.
|
30 |
| - backend (Backend): Specifying backend type, defaults to `None`. |
31 |
| - output_file (str): Output file to save visualized image, defaults to |
32 |
| - `None`. Only valid if `show_result` is set to `False`. |
33 |
| - show_result (bool): Whether to show plotted image in windows, defaults |
34 |
| - to `False`. |
| 23 | +
|
| 24 | + Returns: |
| 25 | + Any: The inference results |
35 | 26 | """
|
36 | 27 | deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
|
37 | 28 |
|
38 |
| - codebase = get_codebase(deploy_cfg) |
39 |
| - task = get_task_type(deploy_cfg) |
40 |
| - input_shape = get_input_shape(deploy_cfg) |
41 |
| - if backend is None: |
42 |
| - backend = get_backend(deploy_cfg) |
43 |
| - |
44 |
| - if isinstance(model, str): |
45 |
| - model = [model] |
| 29 | + from mmdeploy.apis.utils import build_task_processor |
| 30 | + task_processor = build_task_processor(model_cfg, deploy_cfg, device) |
46 | 31 |
|
47 |
| - if isinstance(model, (list, tuple)): |
48 |
| - assert len(model) > 0, 'Model should have at least one element.' |
49 |
| - assert all([isinstance(m, str) for m in model]), 'All elements in the \ |
50 |
| - list should be str' |
| 32 | + model = task_processor.init_backend_model(backend_files) |
51 | 33 |
|
52 |
| - if backend == Backend.PYTORCH: |
53 |
| - model = init_pytorch_model(codebase, model_cfg, model[0], device) |
54 |
| - else: |
55 |
| - device_id = -1 if device == 'cpu' else 0 |
56 |
| - model = init_backend_model( |
57 |
| - model, |
58 |
| - model_cfg=model_cfg, |
59 |
| - deploy_cfg=deploy_cfg, |
60 |
| - device_id=device_id) |
61 |
| - |
62 |
| - model_inputs, _ = create_input(codebase, task, model_cfg, img, input_shape, |
63 |
| - device) |
| 34 | + input_shape = get_input_shape(deploy_cfg) |
| 35 | + model_inputs, _ = task_processor.create_input(img, input_shape) |
64 | 36 |
|
65 | 37 | with torch.no_grad():
|
66 |
| - result = run_inference(codebase, model_inputs, model) |
| 38 | + result = task_processor.run_inference(model, model_inputs) |
67 | 39 |
|
68 |
| - visualize( |
69 |
| - codebase, |
70 |
| - img, |
71 |
| - result=result, |
72 |
| - model=model, |
73 |
| - output_file=output_file, |
74 |
| - backend=backend, |
75 |
| - show_result=show_result) |
| 40 | + return result |
0 commit comments