-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_existing_model.py
More file actions
465 lines (372 loc) · 17.7 KB
/
test_existing_model.py
File metadata and controls
465 lines (372 loc) · 17.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
#!/usr/bin/env python3
"""
Simple test script to compare existing ONNX models with and without simplification.
This script takes an existing ONNX model, simplifies it, and compares:
- File sizes
- Node counts
- Inference performance
- Output accuracy
Usage:
python test_existing_model.py --model path/to/model.onnx
"""
import argparse
import os
import time
import numpy as np
import onnx
import onnxruntime as ort
import onnxsim
from pathlib import Path
from collections import defaultdict
def analyze_model_info(model_path):
    """Load an ONNX model and collect size, parameter and operator statistics.

    Args:
        model_path: Path of the ONNX file; passed to ``onnx.load`` and
            ``os.path.getsize``.

    Returns:
        A dict with the file size in MB, node count, total initializer
        element count, number of initializer tensors, operator type
        names/counts, per-initializer details, element totals grouped by
        dtype, tensor counts grouped by shape string, and the
        ``(name, dims)`` pairs of the graph inputs and outputs.
    """
    # Names for the numeric ``data_type`` codes stored on initializers.
    dtype_names = {
        1: 'float32', 2: 'uint8', 3: 'int8', 4: 'uint16', 5: 'int16',
        6: 'int32', 7: 'int64', 8: 'string', 9: 'bool', 10: 'float16',
        11: 'double', 12: 'uint32', 13: 'uint64'
    }

    def describe(value_infos):
        # dim_value is read as-is for every declared dimension.
        return [(vi.name, [d.dim_value for d in vi.type.tensor_type.shape.dim])
                for vi in value_infos]

    model = onnx.load(model_path)
    graph = model.graph
    size_mb = os.path.getsize(model_path) / (1024 * 1024)

    # Walk the initializers once, filling every parameter statistic.
    per_tensor = {}
    element_total = 0
    elements_by_dtype = defaultdict(int)
    tensors_by_shape = defaultdict(int)
    for tensor in graph.initializer:
        has_dims = bool(tensor.dims)
        n_elements = np.prod(tensor.dims) if has_dims else 1
        type_name = dtype_names.get(tensor.data_type, f'unknown({tensor.data_type})')
        shape_key = 'x'.join(map(str, tensor.dims)) if has_dims else 'scalar'

        element_total += n_elements
        per_tensor[tensor.name] = {
            'shape': list(tensor.dims) if has_dims else [],
            'size': n_elements,
            'dtype': type_name
        }
        elements_by_dtype[type_name] += n_elements
        tensors_by_shape[shape_key] += 1

    # Count nodes per operator type; the distinct types are the keys.
    per_op = defaultdict(int)
    for node in graph.node:
        per_op[node.op_type] += 1
    distinct_ops = set(per_op)

    return {
        'file_size_mb': size_mb,
        'num_nodes': len(graph.node),
        'num_params': element_total,
        'num_param_tensors': len(graph.initializer),
        'num_op_types': len(distinct_ops),
        'op_types': sorted(distinct_ops),
        'op_counts': dict(per_op),
        'param_info': per_tensor,
        'param_by_type': dict(elements_by_dtype),
        'param_by_shape': dict(tensors_by_shape),
        'inputs': describe(graph.input),
        'outputs': describe(graph.output)
    }
def analyze_parameter_changes(original_info, simplified_info):
    """Print and return a breakdown of how the parameters changed.

    Args:
        original_info: Statistics dict for the original model. The keys read
            here are ``param_info``, ``num_params``, ``num_param_tensors``,
            ``param_by_type`` and ``op_counts``.
        simplified_info: The same statistics for the simplified model.

    Returns:
        A dict with ``removed_params`` and ``added_params`` (sets of tensor
        names), ``changed_params`` (name-sorted list of
        ``(name, original_size, simplified_size)``), ``param_change`` and
        ``tensor_change`` (simplified minus original).
    """
    orig_params = original_info['param_info']
    simp_params = simplified_info['param_info']

    # Find parameter changes
    removed_params = set(orig_params) - set(simp_params)
    added_params = set(simp_params) - set(orig_params)
    common_params = set(orig_params) & set(simp_params)

    print(f"\n🔍 DETAILED PARAMETER ANALYSIS:")
    print(f"{'='*60}")
    print(f"Original parameters: {original_info['num_param_tensors']} tensors, {original_info['num_params']:,} elements")
    print(f"Simplified parameters: {simplified_info['num_param_tensors']} tensors, {simplified_info['num_params']:,} elements")

    param_change = simplified_info['num_params'] - original_info['num_params']
    tensor_change = simplified_info['num_param_tensors'] - original_info['num_param_tensors']

    # A model without initializers has no meaningful relative change; avoid
    # dividing by zero and say so instead.
    if original_info['num_params']:
        relative = f"{param_change / original_info['num_params'] * 100:+.2f}%"
    else:
        relative = "n/a"
    print(f"Parameter change: {param_change:+,} elements ({relative})")
    print(f"Tensor change: {tensor_change:+d} tensors")

    def report(heading, bullet, total_label, names, params):
        # Show at most five tensors, but total over *all* of them so the
        # printed sum matches the tensor count in the heading.
        ordered = sorted(names)
        print(f"\n{heading} ({len(ordered)} tensors):")
        for name in ordered[:5]:
            entry = params[name]
            print(f" {bullet} {name}: {entry['shape']} ({entry['size']:,} elements, {entry['dtype']})")
        if len(ordered) > 5:
            print(f" ... and {len(ordered) - 5} more")
        total = sum(params[name]['size'] for name in ordered)
        print(f" {total_label}: {total:,} elements")

    # Analyze removed parameters
    if removed_params:
        report("❌ Removed Parameters", "-", "Total removed", removed_params, orig_params)

    # Analyze added parameters
    if added_params:
        report("➕ Added Parameters", "+", "Total added", added_params, simp_params)

    # Analyze changed parameters (same name, different size). Names are
    # sorted so the listing and the returned list are stable across runs.
    changed_params = [
        (name, orig_params[name]['size'], simp_params[name]['size'])
        for name in sorted(common_params)
        if orig_params[name]['size'] != simp_params[name]['size']
    ]
    if changed_params:
        print(f"\n📝 Modified Parameters ({len(changed_params)} tensors):")
        for param_name, orig_size, simp_size in changed_params[:5]:
            size_change = simp_size - orig_size
            orig_shape = orig_params[param_name]['shape']
            simp_shape = simp_params[param_name]['shape']
            print(f" ~ {param_name}: {orig_shape} → {simp_shape} ({size_change:+,} elements)")
        if len(changed_params) > 5:
            print(f" ... and {len(changed_params) - 5} more")

    # Analyze data type changes
    print(f"\n📊 Parameter Distribution by Data Type:")
    print(f"{'Type':<12} {'Original':<15} {'Simplified':<15} {'Change':<10}")
    print("-" * 55)
    orig_by_type = original_info['param_by_type']
    simp_by_type = simplified_info['param_by_type']
    for dtype in sorted(set(orig_by_type) | set(simp_by_type)):
        orig_count = orig_by_type.get(dtype, 0)
        simp_count = simp_by_type.get(dtype, 0)
        change = simp_count - orig_count
        print(f"{dtype:<12} {orig_count:<15,} {simp_count:<15,} {change:+10,}")

    # Analyze operator changes (only operators whose count differs)
    print(f"\n🔧 Operator Changes:")
    orig_ops = original_info['op_counts']
    simp_ops = simplified_info['op_counts']
    for op_type in sorted(set(orig_ops) | set(simp_ops)):
        orig_count = orig_ops.get(op_type, 0)
        simp_count = simp_ops.get(op_type, 0)
        if orig_count != simp_count:
            change = simp_count - orig_count
            print(f" {op_type}: {orig_count} → {simp_count} ({change:+d})")

    return {
        'removed_params': removed_params,
        'added_params': added_params,
        'changed_params': changed_params,
        'param_change': param_change,
        'tensor_change': tensor_change
    }
def simplify_onnx_model(input_path, output_path):
    """Run onnxsim on ``input_path`` and write the result to ``output_path``.

    Returns:
        True when the simplified model passed onnxsim's check and was saved;
        False when the check failed or any exception was raised (the error
        is printed, not propagated).
    """
    print(f"Simplifying ONNX model: {input_path} -> {output_path}...")
    try:
        source_model = onnx.load(input_path)
        # check_n=3 asks onnxsim to validate the result with 3 inputs.
        result_model, passed = onnxsim.simplify(
            source_model,
            check_n=3,
            perform_optimization=True
        )
        if not passed:
            print("✗ Model simplification validation failed!")
            return False
        onnx.save(result_model, output_path)
        print(f"✓ Simplified model saved to {output_path}")
        return True
    except Exception as e:
        print(f"✗ Error during simplification: {e}")
        return False
# --- Reporting and test-driver functions ---
def print_comparison_table(original_info, simplified_info, original_time, simplified_time):
    """Print a side-by-side table of original vs. simplified model statistics.

    Args:
        original_info: Statistics dict of the original model; the keys read
            are ``file_size_mb``, ``num_nodes``, ``num_params``,
            ``num_param_tensors`` and ``num_op_types``.
        simplified_info: The same statistics for the simplified model.
        original_time: Inference time of the original model, printed as ms.
        simplified_time: Inference time of the simplified model, printed as ms.
    """
    def percent_change(new, old):
        # A zero baseline has no finite relative change: report 0 -> 0 as
        # "no change" and 0 -> anything else as infinite, rather than
        # raising ZeroDivisionError or inventing a number.
        if old == 0:
            if new == 0:
                return 0.0
            return float('inf') if new > 0 else float('-inf')
        return (new / old - 1) * 100

    def print_row(label, old, new, value_spec):
        # value_spec is the format spec shared by both value columns.
        print(f"{label:<30} {old:{value_spec}} {new:{value_spec}} "
              f"{percent_change(new, old):<+10.1f}%")

    print("\n" + "="*80)
    print(" MODEL COMPARISON")
    print("="*80)
    print(f"{'Metric':<30} {'Original':<20} {'Simplified':<20} {'Change':<10}")
    print("-"*80)

    print_row('File Size (MB)', original_info['file_size_mb'],
              simplified_info['file_size_mb'], '<20.2f')
    print_row('Number of Nodes', original_info['num_nodes'],
              simplified_info['num_nodes'], '<20')
    print_row('Number of Parameters', original_info['num_params'],
              simplified_info['num_params'], '<20,')
    print_row('Parameter Tensors', original_info['num_param_tensors'],
              simplified_info['num_param_tensors'], '<20')
    print_row('Operator Types', original_info['num_op_types'],
              simplified_info['num_op_types'], '<20')
    print_row('Inference Time (ms)', original_time, simplified_time, '<20.2f')

    print("-"*80)
def test_model_simplification(model_path):
    """Simplify an ONNX model and report how it compares with the original.

    The simplified model is written to a ``simplification_test`` directory
    next to the input file. Each stage prints its progress; on any failure
    the function prints the error and returns early.

    Args:
        model_path: Path to the ONNX model file (string or ``Path``).

    Returns:
        None. All results are printed.
    """
    def percent_change(new, old):
        # Guard zero baselines (e.g. a model without initializers) so the
        # summary cannot raise ZeroDivisionError after all the work is done.
        if old == 0:
            if new == 0:
                return 0.0
            return float('inf') if new > 0 else float('-inf')
        return (new / old - 1) * 100

    model_path = Path(model_path)
    if not model_path.exists():
        print(f"Error: Model file {model_path} does not exist!")
        return

    print(f"Testing ONNX model: {model_path.name}")
    print("=" * 60)

    # Create output paths
    output_dir = model_path.parent / "simplification_test"
    output_dir.mkdir(exist_ok=True)
    simplified_path = output_dir / f"{model_path.stem}_simplified.onnx"

    # 1. Analyze original model
    print("\n1. Analyzing original model...")
    original_info = analyze_model_info(model_path)
    print(f"✓ Original model: {original_info['num_nodes']} nodes, "
          f"{original_info['file_size_mb']:.2f} MB, {original_info['num_params']:,} parameters")
    print(f" Inputs: {original_info['inputs']}")
    print(f" Outputs: {original_info['outputs']}")

    # 2. Simplify model
    print("\n2. Simplifying model...")
    success = simplify_onnx_model(str(model_path), str(simplified_path))
    if not success:
        print("Simplification failed!")
        return

    # 3. Analyze simplified model
    print("\n3. Analyzing simplified model...")
    simplified_info = analyze_model_info(simplified_path)
    print(f"✓ Simplified model: {simplified_info['num_nodes']} nodes, "
          f"{simplified_info['file_size_mb']:.2f} MB, {simplified_info['num_params']:,} parameters")

    # 4. Analyze parameter changes (prints its own section header)
    analyze_parameter_changes(original_info, simplified_info)

    # 5. Create test inputs
    print("\n5. Creating test inputs...")
    try:
        test_inputs = create_dummy_inputs(model_path)
        print(f"✓ Created {len(test_inputs)} test inputs")
        for name, arr in test_inputs.items():
            print(f" {name}: {arr.shape}")
    except Exception as e:
        print(f"✗ Error creating test inputs: {e}")
        return

    # 6. Create ONNX Runtime sessions
    print("\n6. Creating ONNXRuntime sessions...")
    try:
        original_session = ort.InferenceSession(str(model_path),
                                                providers=['CPUExecutionProvider'])
        simplified_session = ort.InferenceSession(str(simplified_path),
                                                  providers=['CPUExecutionProvider'])
        print("✓ Sessions created successfully")
    except Exception as e:
        print(f"✗ Error creating sessions: {e}")
        return

    # 7. Benchmark and compare
    print("\n7. Running inference tests...")
    try:
        original_time, original_outputs = benchmark_model(original_session, test_inputs)
        simplified_time, simplified_outputs = benchmark_model(simplified_session, test_inputs)
        print(f"✓ Original model: {original_time:.2f} ms/inference")
        print(f"✓ Simplified model: {simplified_time:.2f} ms/inference")

        # Compare outputs
        outputs_match, diff_stats = compare_outputs(original_outputs, simplified_outputs)
        if outputs_match:
            print("✓ Model outputs match within tolerance!")
        else:
            print("⚠ Model outputs differ:")
            if isinstance(diff_stats, dict):
                print(f" Max absolute difference: {diff_stats['max_absolute_diff']:.2e}")
                print(f" Mean absolute difference: {diff_stats['mean_absolute_diff']:.2e}")
            else:
                # compare_outputs returns a plain message (not a stats dict)
                # when the output count or a shape differs; show it instead
                # of failing on the dict lookup.
                print(f" {diff_stats}")
    except Exception as e:
        print(f"✗ Error during inference testing: {e}")
        return

    # 8. Print comparison
    print_comparison_table(original_info, simplified_info, original_time, simplified_time)

    # 9. Summary
    print("\n" + "="*80)
    print(" SUMMARY")
    print("="*80)
    size_change = percent_change(simplified_info['file_size_mb'], original_info['file_size_mb'])
    node_change = percent_change(simplified_info['num_nodes'], original_info['num_nodes'])
    param_change = percent_change(simplified_info['num_params'], original_info['num_params'])
    time_change = percent_change(simplified_time, original_time)
    print(f"File size change: {size_change:+.1f}%")
    print(f"Node count change: {node_change:+.1f}%")
    print(f"Parameter count change: {param_change:+.1f}%")
    print(f"Inference time change: {time_change:+.1f}%")
    print(f"Output accuracy: {'Preserved' if outputs_match else 'May differ'}")

    if param_change > 0:
        print("\n💡 Parameter count increased - this can happen due to:")
        print(" • Constant propagation creating new explicit parameters")
        print(" • Broadcasting elimination (expanding implicit broadcasts)")
        print(" • Operator decomposition (complex ops → simpler ops + params)")
        print(" • Precision normalization (mixed types → consistent types)")

    print("\nFiles:")
    print(f" Original: {model_path}")
    print(f" Simplified: {simplified_path}")
def create_dummy_inputs(model_path):
    """Create one dummy array per real graph input of an ONNX model.

    Graph inputs that are also initializers are skipped. Dimensions whose
    ``dim_value`` is not positive are replaced by 1. The array dtype follows
    the input's declared ``elem_type``; element types that are not in the
    table below fall back to float32.

    Args:
        model_path: Path of the ONNX file; passed to ``onnx.load``.

    Returns:
        Dict mapping input name to a NumPy array: random normal values for
        floating-point inputs, zeros for other inputs with dimensions, and
        a single 1 for inputs without any dimensions.
    """
    model = onnx.load(model_path)

    # Build the name set once instead of scanning all initializers per input.
    initializer_names = {init.name for init in model.graph.initializer}

    # ONNX TensorProto.DataType code -> NumPy dtype (8 = string is omitted).
    elem_type_to_dtype = {
        1: np.float32, 2: np.uint8, 3: np.int8, 4: np.uint16, 5: np.int16,
        6: np.int32, 7: np.int64, 9: np.bool_, 10: np.float16,
        11: np.float64, 12: np.uint32, 13: np.uint64
    }

    inputs = {}
    for inp in model.graph.input:
        # Skip if it's an initializer (constant)
        if inp.name in initializer_names:
            continue

        tensor_type = inp.type.tensor_type
        dtype = np.dtype(elem_type_to_dtype.get(tensor_type.elem_type, np.float32))

        # Dynamic (unset / non-positive) dimensions default to 1.
        shape = [dim.dim_value if dim.dim_value > 0 else 1
                 for dim in tensor_type.shape.dim]

        if not shape:
            if tensor_type.HasField('shape'):
                # Declared rank-0 input: feed a true 0-d scalar.
                # NOTE(review): ONNX Runtime is expected to reject a rank-1
                # feed for a rank-0 input - confirm against the runtime.
                inputs[inp.name] = np.ones((), dtype=dtype)
            else:
                # Shape not declared at all: keep the one-element guess.
                inputs[inp.name] = np.ones((1,), dtype=dtype)
        elif dtype.kind == 'f':
            inputs[inp.name] = np.random.randn(*shape).astype(dtype)
        else:
            # Zeros are used for integer/bool inputs (presumably indices or
            # masks, where 0 is a safe value - verify for your model).
            inputs[inp.name] = np.zeros(shape, dtype=dtype)

    return inputs
def benchmark_model(session, input_dict, num_runs=5, warmup_runs=2):
    """Measure the average wall-clock time of ``session.run``.

    Args:
        session: Object exposing ``run(output_names, input_dict)``.
        input_dict: Inputs passed unchanged to every ``run`` call.
        num_runs: Number of timed runs; must be at least 1.
        warmup_runs: Untimed runs executed before timing starts.

    Returns:
        ``(avg_time_ms, outputs)`` where ``outputs`` is the result of the
        last timed run.

    Raises:
        ValueError: If ``num_runs`` is less than 1.
    """
    if num_runs < 1:
        # Without a timed run there is no average and no output to return.
        raise ValueError("num_runs must be at least 1")

    # Warmup runs
    for _ in range(warmup_runs):
        session.run(None, input_dict)

    # perf_counter is monotonic and high-resolution; time.time() can report
    # a zero or even negative interval for short runs.
    outputs = None
    start_time = time.perf_counter()
    for _ in range(num_runs):
        outputs = session.run(None, input_dict)
    elapsed = time.perf_counter() - start_time

    avg_time_ms = elapsed / num_runs * 1000
    return avg_time_ms, outputs
def compare_outputs(outputs1, outputs2, tolerance=1e-5):
    """Compare two sequences of output arrays element-wise.

    Args:
        outputs1: Sequence of NumPy arrays.
        outputs2: Sequence of NumPy arrays to compare against ``outputs1``.
        tolerance: Used as both ``atol`` and ``rtol`` for ``np.allclose``.

    Returns:
        ``(all_match, details)``. When the number of outputs or a shape
        differs, ``all_match`` is False and ``details`` is a message string.
        Otherwise ``details`` is a dict with ``max_absolute_diff``,
        ``mean_absolute_diff`` and ``num_outputs``.
    """
    if len(outputs1) != len(outputs2):
        return False, "Different number of outputs"

    all_match = True
    max_diff = 0.0
    diff_sum = 0.0
    total_elements = 0

    for i, (out1, out2) in enumerate(zip(outputs1, outputs2)):
        if out1.shape != out2.shape:
            return False, f"Output {i} shape mismatch: {out1.shape} vs {out2.shape}"
        if out1.size == 0:
            # Nothing to compare; np.max would raise on an empty array.
            continue

        # Compare in float64: subtracting unsigned integers wraps around
        # and subtracting bool arrays raises in NumPy.
        a = np.asarray(out1, dtype=np.float64)
        b = np.asarray(out2, dtype=np.float64)
        abs_diff = np.abs(a - b)

        max_diff = max(max_diff, float(np.max(abs_diff)))
        diff_sum += float(np.sum(abs_diff))
        total_elements += out1.size

        if not np.allclose(a, b, atol=tolerance, rtol=tolerance):
            all_match = False

    # No elements at all (no outputs, or only empty ones) means no difference.
    mean_diff = diff_sum / total_elements if total_elements else 0.0

    return all_match, {
        'max_absolute_diff': max_diff,
        'mean_absolute_diff': mean_diff,
        'num_outputs': len(outputs1)
    }
def main():
    """Read ``--model`` from the command line and run the simplification test."""
    cli = argparse.ArgumentParser(description="Test ONNX model simplification")
    cli.add_argument("--model", required=True, help="Path to ONNX model file")
    options = cli.parse_args()
    test_model_simplification(options.model)


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()