Commit 3fd73a1

Update on "[executorch][core] NamedDataMap interface"

Add NamedDataMap interface to runtime.

Differential Revision: [D66834552](https://our.internmc.facebook.com/intern/diff/D66834552/)

[ghstack-poisoned]

2 parents: 61d6a8b + b4af0df


57 files changed: +4,183 −213 lines

Diff for: .ci/scripts/gather_benchmark_configs.py (+49 −1)
@@ -10,7 +10,7 @@
 import os
 import re
 import sys
-from typing import Any, Dict, List
+from typing import Any, Dict, List, NamedTuple

 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
 from examples.models import MODEL_NAME_TO_MODEL
@@ -47,6 +47,46 @@
 }


+class DisabledConfig(NamedTuple):
+    config_name: str
+    github_issue: str  # Link to the GitHub issue
+
+
+# Updated DISABLED_CONFIGS
+DISABLED_CONFIGS: Dict[str, List[DisabledConfig]] = {
+    "resnet50": [
+        DisabledConfig(
+            config_name="qnn_q8",
+            github_issue="https://github.com/pytorch/executorch/issues/7892",
+        ),
+    ],
+    "w2l": [
+        DisabledConfig(
+            config_name="qnn_q8",
+            github_issue="https://github.com/pytorch/executorch/issues/7634",
+        ),
+    ],
+    "mobilebert": [
+        DisabledConfig(
+            config_name="mps",
+            github_issue="https://github.com/pytorch/executorch/issues/7904",
+        ),
+    ],
+    "edsr": [
+        DisabledConfig(
+            config_name="mps",
+            github_issue="https://github.com/pytorch/executorch/issues/7905",
+        ),
+    ],
+    "llama": [
+        DisabledConfig(
+            config_name="mps",
+            github_issue="https://github.com/pytorch/executorch/issues/7907",
+        ),
+    ],
+}
+
+
 def extract_all_configs(data, target_os=None):
     if isinstance(data, dict):
         # If target_os is specified, include "xplat" and the specified branch
@@ -117,6 +157,14 @@ def generate_compatible_configs(model_name: str, target_os=None) -> List[str]:
         # Skip unknown models with a warning
         logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")

+    # Remove disabled configs for the given model
+    disabled_configs = DISABLED_CONFIGS.get(model_name, [])
+    disabled_config_names = {disabled.config_name for disabled in disabled_configs}
+    for disabled in disabled_configs:
+        print(
+            f"Excluding disabled config: '{disabled.config_name}' for model '{model_name}' on '{target_os}'. Linked GitHub issue: {disabled.github_issue}"
+        )
+    configs = [config for config in configs if config not in disabled_config_names]
     return configs
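The new DISABLED_CONFIGS filtering above reduces to a set difference over config names. Below is a minimal, self-contained sketch of that behavior; the filter_disabled helper and the single-entry sample data are illustrative only, not part of the script.

    # Illustrative sketch of the filtering added above; filter_disabled is a
    # hypothetical helper, not a function in gather_benchmark_configs.py.
    from typing import Dict, List, NamedTuple


    class DisabledConfig(NamedTuple):
        config_name: str
        github_issue: str  # Link to the GitHub issue


    DISABLED_CONFIGS: Dict[str, List[DisabledConfig]] = {
        "resnet50": [
            DisabledConfig(
                config_name="qnn_q8",
                github_issue="https://github.com/pytorch/executorch/issues/7892",
            ),
        ],
    }


    def filter_disabled(model_name: str, configs: List[str]) -> List[str]:
        # Drop any config that is disabled for this model.
        disabled = {d.config_name for d in DISABLED_CONFIGS.get(model_name, [])}
        return [c for c in configs if c not in disabled]


    print(filter_disabled("resnet50", ["xnnpack_q8", "qnn_q8"]))  # ['xnnpack_q8']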

Diff for: .ci/scripts/tests/test_gather_benchmark_configs.py (+89 −21)
@@ -1,36 +1,41 @@
 import importlib.util
 import os
+import re
 import subprocess
 import sys
 import unittest
 from unittest.mock import mock_open, patch

 import pytest

-# Dynamically import the script
-script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
-spec = importlib.util.spec_from_file_location("gather_benchmark_configs", script_path)
-gather_benchmark_configs = importlib.util.module_from_spec(spec)
-spec.loader.exec_module(gather_benchmark_configs)
-

 @pytest.mark.skipif(
     sys.platform != "linux", reason="The script under test runs on Linux runners only"
 )
 class TestGatehrBenchmarkConfigs(unittest.TestCase):

+    @classmethod
+    def setUpClass(cls):
+        # Dynamically import the script
+        script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
+        spec = importlib.util.spec_from_file_location(
+            "gather_benchmark_configs", script_path
+        )
+        cls.gather_benchmark_configs = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(cls.gather_benchmark_configs)
+
     def test_extract_all_configs_android(self):
-        android_configs = gather_benchmark_configs.extract_all_configs(
-            gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
+        android_configs = self.gather_benchmark_configs.extract_all_configs(
+            self.gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
         )
         self.assertIn("xnnpack_q8", android_configs)
         self.assertIn("qnn_q8", android_configs)
         self.assertIn("llama3_spinquant", android_configs)
         self.assertIn("llama3_qlora", android_configs)

     def test_extract_all_configs_ios(self):
-        ios_configs = gather_benchmark_configs.extract_all_configs(
-            gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
+        ios_configs = self.gather_benchmark_configs.extract_all_configs(
+            self.gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
         )

         self.assertIn("xnnpack_q8", ios_configs)
@@ -40,51 +45,114 @@ def test_extract_all_configs_ios(self):
         self.assertIn("llama3_spinquant", ios_configs)
         self.assertIn("llama3_qlora", ios_configs)

+    def test_skip_disabled_configs(self):
+        # Use patch as a context manager to avoid modifying DISABLED_CONFIGS and BENCHMARK_CONFIGS
+        with patch.dict(
+            self.gather_benchmark_configs.DISABLED_CONFIGS,
+            {
+                "mv3": [
+                    self.gather_benchmark_configs.DisabledConfig(
+                        config_name="disabled_config1",
+                        github_issue="https://github.com/org/repo/issues/123",
+                    ),
+                    self.gather_benchmark_configs.DisabledConfig(
+                        config_name="disabled_config2",
+                        github_issue="https://github.com/org/repo/issues/124",
+                    ),
+                ]
+            },
+        ), patch.dict(
+            self.gather_benchmark_configs.BENCHMARK_CONFIGS,
+            {
+                "ios": [
+                    "disabled_config1",
+                    "disabled_config2",
+                    "enabled_config1",
+                    "enabled_config2",
+                ]
+            },
+        ):
+            result = self.gather_benchmark_configs.generate_compatible_configs(
+                "mv3", target_os="ios"
+            )
+
+            # Assert that disabled configs are excluded
+            self.assertNotIn("disabled_config1", result)
+            self.assertNotIn("disabled_config2", result)
+            # Assert enabled configs are included
+            self.assertIn("enabled_config1", result)
+            self.assertIn("enabled_config2", result)
+
+    def test_disabled_configs_have_github_links(self):
+        github_issue_regex = re.compile(r"https://github\.com/.+/.+/issues/\d+")
+
+        for (
+            model_name,
+            disabled_configs,
+        ) in self.gather_benchmark_configs.DISABLED_CONFIGS.items():
+            for disabled in disabled_configs:
+                with self.subTest(model_name=model_name, config=disabled.config_name):
+                    # Assert that disabled is an instance of DisabledConfig
+                    self.assertIsInstance(
+                        disabled, self.gather_benchmark_configs.DisabledConfig
+                    )
+
+                    # Assert that github_issue is provided and matches the expected pattern
+                    self.assertTrue(
+                        disabled.github_issue
+                        and github_issue_regex.match(disabled.github_issue),
+                        f"Invalid or missing GitHub issue link for '{disabled.config_name}' in model '{model_name}'.",
+                    )
+
     def test_generate_compatible_configs_llama_model(self):
         model_name = "meta-llama/Llama-3.2-1B"
         target_os = "ios"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         expected = ["llama3_fb16", "llama3_coreml_ane"]
         self.assertEqual(result, expected)

         target_os = "android"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         expected = ["llama3_fb16"]
         self.assertEqual(result, expected)

     def test_generate_compatible_configs_quantized_llama_model(self):
         model_name = "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"
-        result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
+        result = self.gather_benchmark_configs.generate_compatible_configs(
+            model_name, None
+        )
         expected = ["llama3_spinquant"]
         self.assertEqual(result, expected)

         model_name = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
-        result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
+        result = self.gather_benchmark_configs.generate_compatible_configs(
+            model_name, None
+        )
         expected = ["llama3_qlora"]
         self.assertEqual(result, expected)

     def test_generate_compatible_configs_non_genai_model(self):
         model_name = "mv2"
         target_os = "xplat"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         expected = ["xnnpack_q8"]
         self.assertEqual(result, expected)

         target_os = "android"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         expected = ["xnnpack_q8", "qnn_q8"]
         self.assertEqual(result, expected)

         target_os = "ios"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         expected = ["xnnpack_q8", "coreml_fp16", "mps"]
@@ -93,22 +161,22 @@ def test_generate_compatible_configs_non_genai_model(self):
     def test_generate_compatible_configs_unknown_model(self):
         model_name = "unknown_model"
         target_os = "ios"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
            model_name, target_os
         )
         self.assertEqual(result, [])

     def test_is_valid_huggingface_model_id_valid(self):
         valid_model = "meta-llama/Llama-3.2-1B"
         self.assertTrue(
-            gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
+            self.gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
         )

     @patch("builtins.open", new_callable=mock_open)
     @patch("os.getenv", return_value=None)
     def test_set_output_no_github_env(self, mock_getenv, mock_file):
         with patch("builtins.print") as mock_print:
-            gather_benchmark_configs.set_output("test_name", "test_value")
+            self.gather_benchmark_configs.set_output("test_name", "test_value")
             mock_print.assert_called_with("::set-output name=test_name::test_value")

     def test_device_pools_contains_all_devices(self):
@@ -120,7 +188,7 @@ def test_device_pools_contains_all_devices(self):
             "google_pixel_8_pro",
         ]
         for device in expected_devices:
-            self.assertIn(device, gather_benchmark_configs.DEVICE_POOLS)
+            self.assertIn(device, self.gather_benchmark_configs.DEVICE_POOLS)

     def test_gather_benchmark_configs_cli(self):
         args = {
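A note on the patch.dict usage in test_skip_disabled_configs above: the override is scoped to the with block, so module-level dicts such as DISABLED_CONFIGS and BENCHMARK_CONFIGS are restored when the test exits. A minimal standalone sketch of that behavior, where CONFIGS is an illustrative stand-in for those dicts:

    # Standalone illustration of patch.dict scoping; CONFIGS stands in for the
    # module-level BENCHMARK_CONFIGS / DISABLED_CONFIGS dicts patched in the test.
    from unittest.mock import patch

    CONFIGS = {"ios": ["enabled_config1"]}

    with patch.dict(CONFIGS, {"ios": ["disabled_config1", "enabled_config1"]}):
        # Inside the block the override is visible...
        assert CONFIGS["ios"] == ["disabled_config1", "enabled_config1"]

    # ...and the original contents are restored on exit.
    assert CONFIGS["ios"] == ["enabled_config1"]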

Diff for: CMakeLists.txt (−2)
@@ -164,8 +164,6 @@ else()
   set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O2")
 endif()

-set(CMAKE_CXX_FLAGS_DEBUG "-O0 -g")
-
 option(EXECUTORCH_BUILD_ANDROID_JNI "Build Android JNI" OFF)

 option(EXECUTORCH_BUILD_ARM_BAREMETAL

Diff for: backends/cadence/aot/TARGETS (+1)
@@ -65,6 +65,7 @@ python_library(
         "//executorch/backends/cadence/aot/quantizer:fusion_pass",
         "//executorch/backends/cadence/runtime:runtime",
         "//executorch/backends/cadence/aot/quantizer:quantizer",
+        "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
         "//executorch/backends/transforms:decompose_sdpa",
         "//executorch/backends/transforms:remove_clone_ops",
         "//executorch/exir:lib",

Diff for: backends/cadence/aot/compiler.py (+3 −22)
@@ -24,7 +24,6 @@
 from executorch.backends.cadence.aot.utils import (
     get_default_memory_config,
     MemoryConfig,
-    model_is_quantized,
 )
 from executorch.devtools import generate_etrecord
 from executorch.exir import (
@@ -38,7 +37,6 @@
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
 from torch._inductor.decomposition import remove_decompositions
-from torch.ao.quantization.pt2e.export_utils import model_is_exported
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

 from torch.export import export
@@ -158,26 +156,10 @@ def export_program(
 ) -> ExportedProgram:
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

-    # We don't support training mode. Make the model inference mode by
-    # calling model.eval() or an equivalent call for quantized models.
-    # GraphModules cannot call eval(), so we skip them.
-    if not isinstance(model, torch.fx.GraphModule):
-        if hasattr(model, "eval"):
-            model.eval()
-    else:
-        # If the model is quantized, call the suggested torch.ao.quantization API
-        # which only does dropout and batchnorm.
-        if model_is_quantized(model):
-            torch.ao.quantization.move_exported_model_to_eval(model)
-        else:
-            # If we get a GraphModule which is _not_ quantized, then it should already
-            # have been exported.
-            assert model_is_exported(model), "model should be from an ExportedProgram"
-
     # Prevent mkldnn decompositions
     torch._C._set_mkldnn_enabled(False)

-    # else: capture the model and return it.
+    # Export the model and return it.
     expo_program = export(model, inputs, strict=True)

     if dump_graphs:
@@ -206,8 +188,8 @@ def export_to_edge(
             _skip_dim_order=True,
             # Allow specific non-core aten ops in the IR.
             _core_aten_ops_exception_list=[
+                torch.ops.aten._native_batch_norm_legit_functional.default,
                 torch.ops.aten.linear.default,
-                torch.ops.aten.native_batch_norm.default,
                 torch.ops.aten.linalg_vector_norm.default,
                 torch.ops.aten.unfold.default,
                 torch.ops.aten.angle.default,
@@ -226,10 +208,9 @@ def export_to_cadence(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
     dump_graphs: bool = False,
-    output_dir: Optional[str] = None,
     opt_level: int = 1,
 ) -> EdgeProgramManager:
-    edge_prog_manager = export_to_edge(model, inputs)
+    edge_prog_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)
     cadence_passes = get_cadence_passes(opt_level)

     # Run a couple required passes for quant/dequant ops
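With this change, export_to_cadence no longer accepts output_dir and now forwards dump_graphs to export_to_edge, and export_program no longer puts the model in inference mode, so callers handle eval() themselves. A hedged usage sketch under those assumptions; SmallModel and the example inputs are hypothetical, and only the signature visible in the diff (model, inputs, dump_graphs, opt_level) is assumed:

    # Hypothetical caller of the updated API; nothing beyond the signature shown
    # in the diff is assumed.
    import torch
    from executorch.backends.cadence.aot.compiler import export_to_cadence


    class SmallModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x)


    model = SmallModel().eval()  # callers call eval(); export_program no longer does
    inputs = (torch.randn(1, 8),)
    edge_prog_manager = export_to_cadence(model, inputs, dump_graphs=True, opt_level=1)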

Diff for: backends/cadence/aot/export_example.py (+4 −4)
@@ -23,13 +23,13 @@
 from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
 from executorch.backends.cadence.runtime import runtime
 from executorch.backends.cadence.runtime.executor import BundledProgramManager
-from executorch.exir import ExecutorchProgramManager
-from torch import nn
-from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
-from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
     QuantizationConfig,
     QuantizationSpec,
 )
+from executorch.exir import ExecutorchProgramManager
+from torch import nn
+from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver

 from .utils import save_bpte_program, save_pte_program
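The import move above points QuantizationConfig and QuantizationSpec at ExecuTorch's own copy of the XNNPACK quantizer utilities rather than torch.ao. A hedged sketch of how these symbols are typically combined in export_example.py-style flows; the field names and observer arguments are assumptions, not taken from this diff:

    # Assumed field names for QuantizationSpec / QuantizationConfig; verify against
    # executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils before relying on this.
    import torch
    from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
        QuantizationConfig,
        QuantizationSpec,
    )
    from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver

    act_qspec = QuantizationSpec(
        dtype=torch.uint8,
        quant_min=0,
        quant_max=255,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=HistogramObserver,
    )
    wgt_qspec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-127,
        quant_max=127,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=MinMaxObserver,
    )
    qconfig = QuantizationConfig(
        input_activation=act_qspec,
        output_activation=act_qspec,
        weight=wgt_qspec,
        bias=None,
    )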

Diff for: backends/cadence/aot/quantizer/TARGETS (+1)
@@ -34,6 +34,7 @@ python_library(
         ":patterns",
         ":utils",
         "//caffe2:torch",
+        "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer_utils",
     ],
 )