-
Notifications
You must be signed in to change notification settings - Fork 163
Expand file tree
/
Copy pathdemo.py
More file actions
110 lines (97 loc) · 3.88 KB
/
demo.py
File metadata and controls
110 lines (97 loc) · 3.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# ---------------------------------------------------------------------
# Copyright (c) 2025 Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
# ---------------------------------------------------------------------
from __future__ import annotations
import torch
from PIL import Image
from qai_hub_models.models.cvt.app import CVTApp
from qai_hub_models.models.cvt.model import CVT, MODEL_ASSET_VERSION, MODEL_ID
from qai_hub_models.utils.args import (
add_output_dir_arg,
demo_model_from_cli_args,
get_model_cli_parser,
get_on_device_demo_parser,
validate_on_device_demo_args,
)
from qai_hub_models.utils.asset_loaders import (
CachedWebModelAsset,
load_image,
load_json,
)
from qai_hub_models.utils.display import display_or_save_image
# Sample 6-camera images for demo. Each asset filename is the camera name
# with a ".jpg" suffix, so the mapping is generated from the name list.
_CAMERA_NAMES = (
    "CAM_FRONT_LEFT",
    "CAM_FRONT",
    "CAM_FRONT_RIGHT",
    "CAM_BACK_LEFT",
    "CAM_BACK",
    "CAM_BACK_RIGHT",
)
CAMERAS = {
    name: CachedWebModelAsset.from_asset_store(
        MODEL_ID, MODEL_ASSET_VERSION, f"{name}.jpg"
    )
    for name in _CAMERA_NAMES
}
# Per-camera calibration / metadata used by the demo.
CAM_METADATA = CachedWebModelAsset.from_asset_store(
    MODEL_ID, MODEL_ASSET_VERSION, "inputs.json"
)
def cvt_demo(
    model_type: type[CVT],
    model_id: str,
    cameras: dict[str, CachedWebModelAsset],
    cam_metadata: CachedWebModelAsset,
    is_test: bool = False,
) -> list[Image.Image] | torch.Tensor:
    """Run the CVT bird's-eye-view demo on the bundled 6-camera sample.

    Parses CLI args (including a ``--ckpt`` choice of ``both``/``vehicles``/
    ``road``), runs inference with the selected checkpoint(s), and displays or
    saves the resulting BEV map images unless ``is_test`` is set.

    Args:
        model_type: CVT model class used to build the CLI parser and model.
        model_id: Model identifier passed through to model construction.
        cameras: Mapping of camera name -> image asset for the six views.
        cam_metadata: Asset holding the per-camera metadata JSON.
        is_test: When True, parse no CLI args, skip display, and request raw
            (tensor) output from the app.

    Returns:
        The predictions from the last processed checkpoint: a list of PIL
        images normally, or a raw ``torch.Tensor`` when ``is_test`` is True.
    """
    parser = get_model_cli_parser(model_type)
    parser = get_on_device_demo_parser(parser)
    add_output_dir_arg(parser)
    parser.add_argument(
        "--ckpt", type=str, default="both", choices=["both", "vehicles", "road"]
    )
    args = parser.parse_args([] if is_test else None)

    # Fetch and load the camera images and metadata once; both checkpoints
    # (if "both" is selected) reuse the same inputs.
    # NOTE: the original made a shallow dict copy of `cameras` only to iterate
    # its values — iterate the values directly instead.
    images = [load_image(str(asset.fetch())) for asset in cameras.values()]
    camera_metadata = load_json(cam_metadata.fetch())

    def process_model(ckpt_name: str) -> list[Image.Image] | torch.Tensor:
        # Build the model for one checkpoint and run inference on the images.
        args.ckpt_name = ckpt_name
        model = demo_model_from_cli_args(model_type, model_id, args)
        validate_on_device_demo_args(args, model_id)
        # Input spec layout appears to be (batch, cams, C, H, W); take H, W.
        h, w = CVT.get_input_spec()["image"][0][3:]
        app = CVTApp(
            model,  # type: ignore[arg-type]
            ckpt_name=ckpt_name,
            target_height=h,
            target_width=w,
        )
        return app.predict_from_images(images, camera_metadata, raw_output=is_test)

    if args.ckpt == "both":
        vehicle_maps = process_model("vehicles_50k")
        road_maps = process_model("road_75k")
        if not is_test:
            for v_img, r_img in zip(vehicle_maps, road_maps, strict=False):
                # Overlay the vehicle and road BEV maps for a combined view.
                combined = Image.blend(v_img, r_img, alpha=0.5)
                display_or_save_image(combined, args.output_dir, "cvt_bev_combined.png")
                display_or_save_image(v_img, args.output_dir, "cvt_bev_vehicle.png")
                display_or_save_image(r_img, args.output_dir, "cvt_bev_road.png")
        return vehicle_maps

    # Single-checkpoint path: map the CLI choice to its checkpoint name.
    ckpt_name = "vehicles_50k" if args.ckpt == "vehicles" else "road_75k"
    maps = process_model(ckpt_name)
    if not is_test:
        for i, img in enumerate(maps):
            display_or_save_image(img, args.output_dir, f"cvt_bev_{args.ckpt}_{i}.png")
    return maps
def main(is_test: bool = False) -> None:
    """Entry point: run the CVT BEV demo on the bundled sample assets."""
    cvt_demo(
        CVT,
        MODEL_ID,
        CAMERAS,
        CAM_METADATA,
        is_test,
    )


if __name__ == "__main__":
    main()