Skip to content

Commit 1337cdf

Browse files
authored
Move resnet example to models/ (#47)
Better code organization, instead of having two seperate related directories. Also moves llama requirements to it's own directory.
1 parent 6b10945 commit 1337cdf

9 files changed

Lines changed: 311 additions & 659 deletions

File tree

examples/requirements.txt

Lines changed: 0 additions & 1 deletion
This file was deleted.

examples/resnet50.ipynb

Lines changed: 0 additions & 658 deletions
This file was deleted.

models/resnet/README.md

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# ResNet-50 with IREE ONNX Runtime EP
2+
3+
Image classification using ResNet-50 via the IREE Execution Provider for ONNX Runtime.
4+
5+
## Setup
6+
7+
### 1. Download the model and labels
8+
9+
Download the ONNX model from the ONNX Model Zoo mirror on Hugging Face and the ImageNet labels file:
10+
11+
```bash
12+
mkdir -p resnet50-assets
13+
14+
curl -L \
15+
https://huggingface.co/onnxmodelzoo/resnet50_Opset18_torch_hub/resolve/main/resnet50_Opset18_torch_hub.onnx \
16+
-o resnet50-assets/model.onnx
17+
18+
curl -L \
19+
https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json \
20+
-o resnet50-assets/imagenet-simple-labels.json
21+
```
22+
23+
If you already have the checked-in `examples/model.onnx` and `examples/imagenet-simple-labels.json`, the script uses those paths by default.
24+
25+
### 2. Run
26+
27+
From the `models/resnet` directory:
28+
29+
```bash
30+
cd models/resnet
31+
```
32+
33+
For CPU execution:
34+
35+
```bash
36+
python run.py \
37+
--image images/dog.jpg \
38+
--image images/plane.jpg \
39+
--driver local-task \
40+
--target none
41+
```
42+
43+
To use separately downloaded assets instead of the checked-in `examples/` copies:
44+
45+
```bash
46+
python run.py \
47+
--model resnet50-assets/model.onnx \
48+
--labels resnet50-assets/imagenet-simple-labels.json \
49+
--image images/dog.jpg \
50+
--driver local-task \
51+
--target none
52+
```
53+
54+
For GPU execution, pass the appropriate driver and target architecture, for example:
55+
56+
```bash
57+
python run.py \
58+
--image images/dog.jpg \
59+
--driver hip \
60+
--target gfx1201
61+
```
62+
63+
Use `--top-k N` to control how many predictions are printed and `--verbose` for detailed ONNX Runtime logging.

models/resnet/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pillow

models/resnet/run.py

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
"""ResNet-50 image classification with ONNX Runtime and the IREE EP."""
2+
3+
from __future__ import annotations
4+
5+
import argparse
6+
import json
7+
import logging
8+
import time
9+
from pathlib import Path
10+
11+
import numpy as np
12+
import onnxruntime as ort
13+
import onnxruntime_ep_iree as iree_ep
14+
from PIL import Image
15+
16+
LOGGER = logging.getLogger(__name__)
17+
REPO_ROOT = Path(__file__).resolve().parents[2]
18+
DEFAULT_MODEL_PATH = REPO_ROOT / "examples" / "model.onnx"
19+
DEFAULT_LABELS_PATH = REPO_ROOT / "examples" / "imagenet-simple-labels.json"
20+
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
21+
STDDEV = np.array([0.229, 0.224, 0.225], dtype=np.float32)
22+
23+
24+
def parse_args() -> argparse.Namespace:
25+
parser = argparse.ArgumentParser(
26+
description="Run ResNet-50 image classification through the IREE ONNX Runtime EP."
27+
)
28+
parser.add_argument(
29+
"--model",
30+
type=Path,
31+
default=DEFAULT_MODEL_PATH,
32+
help=f"Path to the ONNX model (default: {DEFAULT_MODEL_PATH}).",
33+
)
34+
parser.add_argument(
35+
"--labels",
36+
type=Path,
37+
default=DEFAULT_LABELS_PATH,
38+
help=f"Path to the ImageNet labels JSON (default: {DEFAULT_LABELS_PATH}).",
39+
)
40+
parser.add_argument(
41+
"--image",
42+
type=Path,
43+
action="append",
44+
required=True,
45+
help="Input image to classify. Repeat for multiple images.",
46+
)
47+
parser.add_argument(
48+
"--driver",
49+
default="local-task",
50+
help="IREE driver to use, for example local-task or hip.",
51+
)
52+
parser.add_argument(
53+
"--target",
54+
default="none",
55+
help="IREE target arch, for example none on CPU or gfx1201 on RDNA4.",
56+
)
57+
parser.add_argument(
58+
"--top-k",
59+
type=int,
60+
default=5,
61+
help="Number of predictions to print per image.",
62+
)
63+
parser.add_argument(
64+
"--verbose",
65+
action="store_true",
66+
help="Enable verbose ONNX Runtime and script logging.",
67+
)
68+
return parser.parse_args()
69+
70+
71+
def configure_logging(verbose: bool) -> None:
72+
logging.basicConfig(
73+
level=logging.DEBUG if verbose else logging.INFO,
74+
format="%(levelname)s %(message)s",
75+
)
76+
ort.set_default_logger_severity(0 if verbose else 2)
77+
78+
79+
def validate_path(path: Path, description: str) -> Path:
80+
resolved = path.expanduser().resolve()
81+
if not resolved.exists():
82+
raise FileNotFoundError(f"{description} not found: {resolved}")
83+
return resolved
84+
85+
86+
def load_labels(path: Path) -> list[str]:
87+
with path.open() as f:
88+
labels = json.load(f)
89+
if not isinstance(labels, list):
90+
raise ValueError(f"Expected a JSON list of labels in {path}")
91+
return labels
92+
93+
94+
def register_iree_ep() -> None:
95+
ep_name = iree_ep.get_ep_name()
96+
ep_library = iree_ep.get_library_path()
97+
LOGGER.debug("Registering execution provider %s from %s", ep_name, ep_library)
98+
ort.register_execution_provider_library(ep_name, ep_library)
99+
100+
101+
def get_iree_device(driver: str):
102+
ep_devices = ort.get_ep_devices()
103+
for dev in ep_devices:
104+
if dev.device.metadata.get("iree.driver") == driver:
105+
LOGGER.debug("Selected IREE device metadata: %s", dev.device.metadata)
106+
return dev
107+
108+
available = sorted(
109+
{
110+
dev.device.metadata.get("iree.driver")
111+
for dev in ep_devices
112+
if dev.device.metadata.get("iree.driver")
113+
}
114+
)
115+
raise RuntimeError(
116+
f"IREE device with driver '{driver}' not found. Available drivers: {available}"
117+
)
118+
119+
120+
def create_session(model_path: Path, target: str, driver: str):
121+
register_iree_ep()
122+
iree_device = get_iree_device(driver)
123+
124+
sess_options = ort.SessionOptions()
125+
sess_options.add_provider_for_devices(
126+
[iree_device],
127+
{
128+
"target_arch": target,
129+
"opt_level": "O3",
130+
},
131+
)
132+
session = ort.InferenceSession(
133+
str(model_path),
134+
sess_options=sess_options,
135+
enable_fallback=False,
136+
)
137+
return session, iree_device
138+
139+
140+
def get_model_io(session: ort.InferenceSession) -> tuple[str, str, int, int]:
141+
inputs = session.get_inputs()
142+
outputs = session.get_outputs()
143+
if len(inputs) != 1:
144+
raise ValueError(f"Expected a single model input, found {len(inputs)}")
145+
if len(outputs) != 1:
146+
raise ValueError(f"Expected a single model output, found {len(outputs)}")
147+
148+
model_input = inputs[0]
149+
if len(model_input.shape) != 4:
150+
raise ValueError(
151+
f"Expected a 4D NCHW input tensor, got shape {model_input.shape}"
152+
)
153+
154+
_, channels, height, width = model_input.shape
155+
if channels != 3:
156+
raise ValueError(f"Expected 3 input channels, got {channels}")
157+
if not isinstance(height, int) or not isinstance(width, int):
158+
raise ValueError(f"Expected static image size, got shape {model_input.shape}")
159+
160+
return model_input.name, outputs[0].name, height, width
161+
162+
163+
def preprocess_image(image_path: Path, height: int, width: int) -> np.ndarray:
164+
image = Image.open(image_path).convert("RGB")
165+
resampling = getattr(Image, "Resampling", Image)
166+
image = image.resize((width, height), resample=resampling.BILINEAR)
167+
168+
image_data = np.asarray(image, dtype=np.float32).transpose(2, 0, 1)
169+
image_data = image_data / 255.0
170+
image_data = (image_data - MEAN[:, None, None]) / STDDEV[:, None, None]
171+
return image_data.reshape(1, 3, height, width).astype(np.float32)
172+
173+
174+
def softmax(values: np.ndarray) -> np.ndarray:
175+
values = values.reshape(-1)
176+
shifted = values - np.max(values)
177+
exp_values = np.exp(shifted)
178+
return exp_values / np.sum(exp_values)
179+
180+
181+
def run_inference(
182+
session: ort.InferenceSession,
183+
input_name: str,
184+
output_name: str,
185+
image_tensor: np.ndarray,
186+
) -> tuple[np.ndarray, float]:
187+
start = time.perf_counter()
188+
output = session.run([output_name], {input_name: image_tensor})[0]
189+
elapsed_ms = (time.perf_counter() - start) * 1000.0
190+
return softmax(np.asarray(output)), elapsed_ms
191+
192+
193+
def print_predictions(
194+
image_path: Path,
195+
probabilities: np.ndarray,
196+
labels: list[str],
197+
elapsed_ms: float,
198+
top_k: int,
199+
) -> None:
200+
top_k = min(top_k, len(labels))
201+
top_indices = np.argsort(probabilities)[::-1][:top_k]
202+
best_index = int(top_indices[0])
203+
204+
print(f"Image: {image_path}")
205+
print(f"Inference time: {elapsed_ms:.2f} ms")
206+
print(
207+
"Top prediction: "
208+
f"{labels[best_index]} ({probabilities[best_index] * 100.0:.2f}%)"
209+
)
210+
print(f"Top {top_k} predictions:")
211+
for rank, index in enumerate(top_indices, start=1):
212+
print(f" {rank}. {labels[index]} ({probabilities[index] * 100.0:.2f}%)")
213+
print()
214+
215+
216+
def main() -> None:
217+
args = parse_args()
218+
configure_logging(args.verbose)
219+
220+
model_path = validate_path(args.model, "Model")
221+
labels_path = validate_path(args.labels, "Labels")
222+
image_paths = [validate_path(path, "Image") for path in args.image]
223+
224+
labels = load_labels(labels_path)
225+
session, iree_device = create_session(model_path, args.target, args.driver)
226+
input_name, output_name, height, width = get_model_io(session)
227+
228+
LOGGER.info(
229+
"Running ResNet-50 on IREE driver=%s target=%s input=%s output=%s size=%dx%d",
230+
iree_device.device.metadata.get("iree.driver"),
231+
args.target,
232+
input_name,
233+
output_name,
234+
width,
235+
height,
236+
)
237+
238+
for image_path in image_paths:
239+
image_tensor = preprocess_image(image_path, height, width)
240+
probabilities, elapsed_ms = run_inference(
241+
session, input_name, output_name, image_tensor
242+
)
243+
print_predictions(image_path, probabilities, labels, elapsed_ms, args.top_k)
244+
245+
246+
if __name__ == "__main__":
247+
main()

0 commit comments

Comments
 (0)