1
1
import importlib .util
2
2
import os
3
+ import re
3
4
import subprocess
4
5
import sys
5
6
import unittest
6
7
from unittest .mock import mock_open , patch
7
8
8
9
import pytest
9
10
10
- # Dynamically import the script
11
- script_path = os .path .join (".ci" , "scripts" , "gather_benchmark_configs.py" )
12
- spec = importlib .util .spec_from_file_location ("gather_benchmark_configs" , script_path )
13
- gather_benchmark_configs = importlib .util .module_from_spec (spec )
14
- spec .loader .exec_module (gather_benchmark_configs )
15
-
16
11
17
12
@pytest .mark .skipif (
18
13
sys .platform != "linux" , reason = "The script under test runs on Linux runners only"
19
14
)
20
15
class TestGatehrBenchmarkConfigs (unittest .TestCase ):
21
16
17
+ @classmethod
18
+ def setUpClass (cls ):
19
+ # Dynamically import the script
20
+ script_path = os .path .join (".ci" , "scripts" , "gather_benchmark_configs.py" )
21
+ spec = importlib .util .spec_from_file_location (
22
+ "gather_benchmark_configs" , script_path
23
+ )
24
+ cls .gather_benchmark_configs = importlib .util .module_from_spec (spec )
25
+ spec .loader .exec_module (cls .gather_benchmark_configs )
26
+
22
27
def test_extract_all_configs_android (self ):
23
- android_configs = gather_benchmark_configs .extract_all_configs (
24
- gather_benchmark_configs .BENCHMARK_CONFIGS , "android"
28
+ android_configs = self . gather_benchmark_configs .extract_all_configs (
29
+ self . gather_benchmark_configs .BENCHMARK_CONFIGS , "android"
25
30
)
26
31
self .assertIn ("xnnpack_q8" , android_configs )
27
32
self .assertIn ("qnn_q8" , android_configs )
28
33
self .assertIn ("llama3_spinquant" , android_configs )
29
34
self .assertIn ("llama3_qlora" , android_configs )
30
35
31
36
def test_extract_all_configs_ios (self ):
32
- ios_configs = gather_benchmark_configs .extract_all_configs (
33
- gather_benchmark_configs .BENCHMARK_CONFIGS , "ios"
37
+ ios_configs = self . gather_benchmark_configs .extract_all_configs (
38
+ self . gather_benchmark_configs .BENCHMARK_CONFIGS , "ios"
34
39
)
35
40
36
41
self .assertIn ("xnnpack_q8" , ios_configs )
@@ -40,51 +45,114 @@ def test_extract_all_configs_ios(self):
40
45
self .assertIn ("llama3_spinquant" , ios_configs )
41
46
self .assertIn ("llama3_qlora" , ios_configs )
42
47
48
+ def test_skip_disabled_configs (self ):
49
+ # Use patch as a context manager to avoid modifying DISABLED_CONFIGS and BENCHMARK_CONFIGS
50
+ with patch .dict (
51
+ self .gather_benchmark_configs .DISABLED_CONFIGS ,
52
+ {
53
+ "mv3" : [
54
+ self .gather_benchmark_configs .DisabledConfig (
55
+ config_name = "disabled_config1" ,
56
+ github_issue = "https://github.com/org/repo/issues/123" ,
57
+ ),
58
+ self .gather_benchmark_configs .DisabledConfig (
59
+ config_name = "disabled_config2" ,
60
+ github_issue = "https://github.com/org/repo/issues/124" ,
61
+ ),
62
+ ]
63
+ },
64
+ ), patch .dict (
65
+ self .gather_benchmark_configs .BENCHMARK_CONFIGS ,
66
+ {
67
+ "ios" : [
68
+ "disabled_config1" ,
69
+ "disabled_config2" ,
70
+ "enabled_config1" ,
71
+ "enabled_config2" ,
72
+ ]
73
+ },
74
+ ):
75
+ result = self .gather_benchmark_configs .generate_compatible_configs (
76
+ "mv3" , target_os = "ios"
77
+ )
78
+
79
+ # Assert that disabled configs are excluded
80
+ self .assertNotIn ("disabled_config1" , result )
81
+ self .assertNotIn ("disabled_config2" , result )
82
+ # Assert enabled configs are included
83
+ self .assertIn ("enabled_config1" , result )
84
+ self .assertIn ("enabled_config2" , result )
85
+
86
+ def test_disabled_configs_have_github_links (self ):
87
+ github_issue_regex = re .compile (r"https://github\.com/.+/.+/issues/\d+" )
88
+
89
+ for (
90
+ model_name ,
91
+ disabled_configs ,
92
+ ) in self .gather_benchmark_configs .DISABLED_CONFIGS .items ():
93
+ for disabled in disabled_configs :
94
+ with self .subTest (model_name = model_name , config = disabled .config_name ):
95
+ # Assert that disabled is an instance of DisabledConfig
96
+ self .assertIsInstance (
97
+ disabled , self .gather_benchmark_configs .DisabledConfig
98
+ )
99
+
100
+ # Assert that github_issue is provided and matches the expected pattern
101
+ self .assertTrue (
102
+ disabled .github_issue
103
+ and github_issue_regex .match (disabled .github_issue ),
104
+ f"Invalid or missing GitHub issue link for '{ disabled .config_name } ' in model '{ model_name } '." ,
105
+ )
106
+
43
107
def test_generate_compatible_configs_llama_model (self ):
44
108
model_name = "meta-llama/Llama-3.2-1B"
45
109
target_os = "ios"
46
- result = gather_benchmark_configs .generate_compatible_configs (
110
+ result = self . gather_benchmark_configs .generate_compatible_configs (
47
111
model_name , target_os
48
112
)
49
113
expected = ["llama3_fb16" , "llama3_coreml_ane" ]
50
114
self .assertEqual (result , expected )
51
115
52
116
target_os = "android"
53
- result = gather_benchmark_configs .generate_compatible_configs (
117
+ result = self . gather_benchmark_configs .generate_compatible_configs (
54
118
model_name , target_os
55
119
)
56
120
expected = ["llama3_fb16" ]
57
121
self .assertEqual (result , expected )
58
122
59
123
def test_generate_compatible_configs_quantized_llama_model (self ):
60
124
model_name = "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"
61
- result = gather_benchmark_configs .generate_compatible_configs (model_name , None )
125
+ result = self .gather_benchmark_configs .generate_compatible_configs (
126
+ model_name , None
127
+ )
62
128
expected = ["llama3_spinquant" ]
63
129
self .assertEqual (result , expected )
64
130
65
131
model_name = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
66
- result = gather_benchmark_configs .generate_compatible_configs (model_name , None )
132
+ result = self .gather_benchmark_configs .generate_compatible_configs (
133
+ model_name , None
134
+ )
67
135
expected = ["llama3_qlora" ]
68
136
self .assertEqual (result , expected )
69
137
70
138
def test_generate_compatible_configs_non_genai_model (self ):
71
139
model_name = "mv2"
72
140
target_os = "xplat"
73
- result = gather_benchmark_configs .generate_compatible_configs (
141
+ result = self . gather_benchmark_configs .generate_compatible_configs (
74
142
model_name , target_os
75
143
)
76
144
expected = ["xnnpack_q8" ]
77
145
self .assertEqual (result , expected )
78
146
79
147
target_os = "android"
80
- result = gather_benchmark_configs .generate_compatible_configs (
148
+ result = self . gather_benchmark_configs .generate_compatible_configs (
81
149
model_name , target_os
82
150
)
83
151
expected = ["xnnpack_q8" , "qnn_q8" ]
84
152
self .assertEqual (result , expected )
85
153
86
154
target_os = "ios"
87
- result = gather_benchmark_configs .generate_compatible_configs (
155
+ result = self . gather_benchmark_configs .generate_compatible_configs (
88
156
model_name , target_os
89
157
)
90
158
expected = ["xnnpack_q8" , "coreml_fp16" , "mps" ]
@@ -93,22 +161,22 @@ def test_generate_compatible_configs_non_genai_model(self):
93
161
def test_generate_compatible_configs_unknown_model (self ):
94
162
model_name = "unknown_model"
95
163
target_os = "ios"
96
- result = gather_benchmark_configs .generate_compatible_configs (
164
+ result = self . gather_benchmark_configs .generate_compatible_configs (
97
165
model_name , target_os
98
166
)
99
167
self .assertEqual (result , [])
100
168
101
169
def test_is_valid_huggingface_model_id_valid (self ):
102
170
valid_model = "meta-llama/Llama-3.2-1B"
103
171
self .assertTrue (
104
- gather_benchmark_configs .is_valid_huggingface_model_id (valid_model )
172
+ self . gather_benchmark_configs .is_valid_huggingface_model_id (valid_model )
105
173
)
106
174
107
175
@patch ("builtins.open" , new_callable = mock_open )
108
176
@patch ("os.getenv" , return_value = None )
109
177
def test_set_output_no_github_env (self , mock_getenv , mock_file ):
110
178
with patch ("builtins.print" ) as mock_print :
111
- gather_benchmark_configs .set_output ("test_name" , "test_value" )
179
+ self . gather_benchmark_configs .set_output ("test_name" , "test_value" )
112
180
mock_print .assert_called_with ("::set-output name=test_name::test_value" )
113
181
114
182
def test_device_pools_contains_all_devices (self ):
@@ -120,7 +188,7 @@ def test_device_pools_contains_all_devices(self):
120
188
"google_pixel_8_pro" ,
121
189
]
122
190
for device in expected_devices :
123
- self .assertIn (device , gather_benchmark_configs .DEVICE_POOLS )
191
+ self .assertIn (device , self . gather_benchmark_configs .DEVICE_POOLS )
124
192
125
193
def test_gather_benchmark_configs_cli (self ):
126
194
args = {
0 commit comments