Skip to content

Qualcomm AI Engine Direct - GA mobilevit #11208

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions backends/qualcomm/aot/wrappers/OpWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,13 @@ class OpWrapper final {
std::unique_ptr<QuantizeParamsWrapper> quantize_param_wrapper =
std::make_unique<UndefinedQuantizeParamsWrapper>();
constexpr std::uint32_t kBytes = 0;
std::vector<uint8_t> dynamic_dims(rank, 0);
std::shared_ptr<TensorWrapper> tensor_wrapper = CreateTensorWrapper(
QNN_TENSOR_TYPE_STATIC,
data_type,
std::move(quantize_param_wrapper),
rank,
dims,
dynamic_dims.data(),
nullptr,
kBytes,
data,
copy_data);
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def get_dynamic_dimension(self, dims):
nominal_dims.append(dim)
dynamic_dims.append(0)

return dynamic_dims, nominal_dims
return dynamic_dims if any(dynamic_dims) else [], nominal_dims

def define_custom_tensor_wrapper(
self,
Expand Down
36 changes: 36 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4180,6 +4180,42 @@ def test_pvt(self):
self.assertGreaterEqual(msg["top_1"], 65)
self.assertGreaterEqual(msg["top_5"], 85)

def test_mobilevit1(self):
if not self.required_envs([self.image_dataset]):
self.skipTest("missing required envs")

cmds = [
"python",
f"{self.executorch_root}/examples/qualcomm/oss_scripts/mobilevit1.py"
"--dataset",
self.image_dataset,
"--artifact",
self.artifact_dir,
"--build_folder",
self.build_folder,
"--device",
self.device,
"--model",
self.model,
"--ip",
self.ip,
"--port",
str(self.port),
]
if self.host:
cmds.extend(["--host", self.host])

p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
with Listener((self.ip, self.port)) as listener:
conn = listener.accept()
p.communicate()
msg = json.loads(conn.recv())
if "Error" in msg:
self.fail(msg["Error"])
else:
self.assertGreaterEqual(msg["top_1"], 70)
self.assertGreaterEqual(msg["top_5"], 85)

def test_regnet(self):
if not self.required_envs([self.image_dataset]):
self.skipTest("missing required envs")
Expand Down
173 changes: 173 additions & 0 deletions examples/qualcomm/oss_scripts/mobilevit1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
import os
from multiprocessing.connection import Client

import numpy as np

import torch
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.qualcomm.utils import (
build_executorch_binary,
make_output_dir,
parse_skip_delegation_node,
setup_common_args_and_variables,
SimpleADB,
topk_accuracy,
)
from PIL import Image
from torchvision import datasets
from transformers import AutoModelForImageClassification, MobileViTFeatureExtractor


def get_imagenet_dataset(dataset_path, data_size, shuffle=True):

def get_data_loader():
imagenet_data = datasets.ImageFolder(dataset_path)
return torch.utils.data.DataLoader(
imagenet_data,
shuffle=shuffle,
)

# prepare input data
inputs, targets, input_list = [], [], ""
data_loader = get_data_loader()
feature_extractor = MobileViTFeatureExtractor.from_pretrained(
"apple/mobilevit-xx-small"
)
for index, data in enumerate(data_loader.dataset.imgs):
if index >= data_size:
break
data_path, target = data
image = Image.open(data_path).convert("RGB")
feature = feature_extractor(images=image, return_tensors="pt")
inputs.append((feature["pixel_values"],))
targets.append(torch.tensor(target))
input_list += f"input_{index}_0.raw\n"

return inputs, targets, input_list


def main(args):
skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

# ensure the working directory exist.
os.makedirs(args.artifact, exist_ok=True)

if not args.compile_only and args.device is None:
raise RuntimeError(
"device serial is required if not compile only. "
"Please specify a device serial by -s/--device argument."
)

data_num = 100
if args.ci:
inputs = [(torch.rand(1, 3, 224, 224),)]
logging.warning(
"This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
)
else:
inputs, targets, input_list = get_imagenet_dataset(
dataset_path=f"{args.dataset}",
data_size=data_num,
)

module = (
AutoModelForImageClassification.from_pretrained("apple/mobilevit-xx-small")
.eval()
.to("cpu")
)

pte_filename = "mobilevit1_qnn_q16"
build_executorch_binary(
module.eval(),
inputs[0],
args.model,
f"{args.artifact}/{pte_filename}",
inputs,
skip_node_id_set=skip_node_id_set,
skip_node_op_set=skip_node_op_set,
quant_dtype=QuantDtype.use_16a16w,
shared_buffer=args.shared_buffer,
)

if args.compile_only:
return

adb = SimpleADB(
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
build_path=f"{args.build_folder}",
pte_path=f"{args.artifact}/{pte_filename}.pte",
workspace=f"/data/local/tmp/executorch/{pte_filename}",
device_id=args.device,
host_id=args.host,
soc_model=args.model,
shared_buffer=args.shared_buffer,
)
adb.push(inputs=inputs, input_list=input_list)
adb.execute()

# collect output data
output_data_folder = f"{args.artifact}/outputs"
make_output_dir(output_data_folder)

adb.pull(output_path=args.artifact)

# top-k analysis
predictions = []
for i in range(data_num):
predictions.append(
np.fromfile(
os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
)
)

k_val = [1, 5]
topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
if args.ip and args.port != -1:
with Client((args.ip, args.port)) as conn:
conn.send(json.dumps({f"top_{k}": topk[i] for i, k in enumerate(k_val)}))
else:
for i, k in enumerate(k_val):
print(f"top_{k}->{topk[i]}%")


if __name__ == "__main__":
parser = setup_common_args_and_variables()

parser.add_argument(
"-d",
"--dataset",
help=(
"path to the validation folder of ImageNet dataset. "
"e.g. --dataset imagenet-mini/val "
"for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
),
type=str,
required=False,
)

parser.add_argument(
"-a",
"--artifact",
help="path for storing generated artifacts by this example. "
"Default ./mobilevit1",
default="./mobilevit1",
type=str,
)

args = parser.parse_args()
try:
main(args)
except Exception as e:
if args.ip and args.port != -1:
with Client((args.ip, args.port)) as conn:
conn.send(json.dumps({"Error": str(e)}))
else:
raise Exception(e)
Loading