Skip to content

Commit 1bc36c7

Browse files
haowhsu-quicCheng-Hsin Weng
and
Cheng-Hsin Weng
authored
Qualcomm AI Engine Direct - GA Model Enablement (deit) (pytorch#11065)
On behalf of @chenweng-quic ### Summary - support e2e script / test case for GA deit model - perf: 8a8w 2.7ms/inf (SM8750) - acc: top1/5 ~= 80%/90% ### Test plan ```bash python backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleOssScript.test_deit --device <device_id> --host <host> --model <soc_model> --build_folder build-android --executorch_root . --image_dataset imagenet-mini/val --artifact deit_artifact ``` --------- Co-authored-by: Cheng-Hsin Weng <[email protected]>
1 parent 311489f commit 1bc36c7

File tree

2 files changed

+183
-0
lines changed

2 files changed

+183
-0
lines changed

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3880,6 +3880,41 @@ def test_conv_former(self):
38803880
self.assertGreaterEqual(msg["top_1"], 60)
38813881
self.assertGreaterEqual(msg["top_5"], 80)
38823882

3883+
def test_deit(self):
3884+
if not self.required_envs([self.image_dataset]):
3885+
self.skipTest("missing required envs")
3886+
cmds = [
3887+
"python",
3888+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/deit.py",
3889+
"--dataset",
3890+
self.image_dataset,
3891+
"--artifact",
3892+
self.artifact_dir,
3893+
"--build_folder",
3894+
self.build_folder,
3895+
"--device",
3896+
self.device,
3897+
"--model",
3898+
self.model,
3899+
"--ip",
3900+
self.ip,
3901+
"--port",
3902+
str(self.port),
3903+
]
3904+
if self.host:
3905+
cmds.extend(["--host", self.host])
3906+
3907+
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
3908+
with Listener((self.ip, self.port)) as listener:
3909+
conn = listener.accept()
3910+
p.communicate()
3911+
msg = json.loads(conn.recv())
3912+
if "Error" in msg:
3913+
self.fail(msg["Error"])
3914+
else:
3915+
self.assertGreaterEqual(msg["top_1"], 75)
3916+
self.assertGreaterEqual(msg["top_5"], 90)
3917+
38833918
def test_dino_v2(self):
38843919
if not self.required_envs([self.image_dataset]):
38853920
self.skipTest("missing required envs")

examples/qualcomm/oss_scripts/deit.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import getpass
8+
import json
9+
import os
10+
from multiprocessing.connection import Client
11+
12+
import numpy as np
13+
from executorch.backends.qualcomm._passes.qnn_pass_manager import (
14+
get_capture_program_passes,
15+
)
16+
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
17+
from executorch.examples.qualcomm.utils import (
18+
build_executorch_binary,
19+
get_imagenet_dataset,
20+
make_output_dir,
21+
parse_skip_delegation_node,
22+
setup_common_args_and_variables,
23+
SimpleADB,
24+
topk_accuracy,
25+
)
26+
from transformers import AutoConfig, AutoModelForImageClassification
27+
28+
29+
def get_instance():
30+
module = (
31+
AutoModelForImageClassification.from_pretrained(
32+
"facebook/deit-base-distilled-patch16-224"
33+
)
34+
.eval()
35+
.to("cpu")
36+
)
37+
38+
return module
39+
40+
41+
def main(args):
42+
skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)
43+
44+
os.makedirs(args.artifact, exist_ok=True)
45+
config = AutoConfig.from_pretrained("facebook/deit-base-distilled-patch16-224")
46+
data_num = 100
47+
height = config.image_size
48+
width = config.image_size
49+
inputs, targets, input_list = get_imagenet_dataset(
50+
dataset_path=f"{args.dataset}",
51+
data_size=data_num,
52+
image_shape=(height, width),
53+
crop_size=(height, width),
54+
)
55+
56+
# Get the Deit model.
57+
model = get_instance()
58+
pte_filename = "deit_qnn"
59+
60+
# lower to QNN
61+
passes_job = get_capture_program_passes()
62+
build_executorch_binary(
63+
model,
64+
inputs[0],
65+
args.model,
66+
f"{args.artifact}/{pte_filename}",
67+
dataset=inputs,
68+
skip_node_id_set=skip_node_id_set,
69+
skip_node_op_set=skip_node_op_set,
70+
quant_dtype=QuantDtype.use_8a8w,
71+
passes_job=passes_job,
72+
shared_buffer=args.shared_buffer,
73+
)
74+
75+
if args.compile_only:
76+
return
77+
78+
workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{pte_filename}"
79+
pte_path = f"{args.artifact}/{pte_filename}.pte"
80+
81+
adb = SimpleADB(
82+
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
83+
build_path=f"{args.build_folder}",
84+
pte_path=pte_path,
85+
workspace=workspace,
86+
device_id=args.device,
87+
host_id=args.host,
88+
soc_model=args.model,
89+
)
90+
adb.push(inputs=inputs, input_list=input_list)
91+
adb.execute()
92+
93+
# collect output data
94+
output_data_folder = f"{args.artifact}/outputs"
95+
make_output_dir(output_data_folder)
96+
97+
adb.pull(output_path=args.artifact)
98+
99+
# top-k analysis
100+
predictions = []
101+
for i in range(data_num):
102+
predictions.append(
103+
np.fromfile(
104+
os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
105+
)
106+
)
107+
108+
k_val = [1, 5]
109+
topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
110+
if args.ip and args.port != -1:
111+
with Client((args.ip, args.port)) as conn:
112+
conn.send(json.dumps({f"top_{k}": topk[i] for i, k in enumerate(k_val)}))
113+
else:
114+
for i, k in enumerate(k_val):
115+
print(f"top_{k}->{topk[i]}%")
116+
117+
118+
if __name__ == "__main__":
119+
parser = setup_common_args_and_variables()
120+
parser.add_argument(
121+
"-a",
122+
"--artifact",
123+
help="path for storing generated artifacts and output by this example. Default ./deit_qnn",
124+
default="./deit_qnn",
125+
type=str,
126+
)
127+
128+
parser.add_argument(
129+
"-d",
130+
"--dataset",
131+
help=(
132+
"path to the validation folder of ImageNet dataset. "
133+
"e.g. --dataset imagenet-mini/val "
134+
"for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
135+
),
136+
type=str,
137+
required=True,
138+
)
139+
140+
args = parser.parse_args()
141+
try:
142+
main(args)
143+
except Exception as e:
144+
if args.ip and args.port != -1:
145+
with Client((args.ip, args.port)) as conn:
146+
conn.send(json.dumps({"Error": str(e)}))
147+
else:
148+
raise Exception(e)

0 commit comments

Comments
 (0)