-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathflops.py
More file actions
executable file
·31 lines (24 loc) · 1002 Bytes
/
flops.py
File metadata and controls
executable file
·31 lines (24 loc) · 1002 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#!/usr/bin/env python3
# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.
import timm
from deepspeed.accelerator import get_accelerator
from deepspeed.profiling.flops_profiler import get_model_profile
# timm architecture names to profile, in the order they are reported.
models = ['resnet18', 'resnet50', 'vit_small_patch16_224']

# Model configuration forwarded to timm.create_model for every architecture.
num_classes, in_channels = 14, 11

# Shape of the input handed to the profiler: a batch of square patches,
# laid out as (batch, channels, patch, patch).
batch_size, patch_size = 64, 224
input_shape = (batch_size, in_channels) + (patch_size, patch_size)
# For each architecture: build the model, report its storage footprint, then
# run DeepSpeed's FLOPs profiler on it.
for model in models:
    print(f'Model: {model}')
    net = timm.create_model(model, num_classes=num_classes, in_chans=in_channels)

    # Storage footprint = bytes held by parameters plus bytes held by buffers
    # (element count times per-element size), reported in MB (10^6 bytes).
    param_bytes = sum(p.nelement() * p.element_size() for p in net.parameters())
    buffer_bytes = sum(b.nelement() * b.element_size() for b in net.buffers())
    mem = (param_bytes + buffer_bytes) / 1_000_000
    print(f'Memory: {mem:.2f} MB')

    # Profile inside the accelerator's device-0 context. The return value is
    # discarded; presumably the profiler prints its own report -- confirm
    # against the DeepSpeed documentation.
    with get_accelerator().device(0):
        get_model_profile(
            model=net,
            input_shape=input_shape,
            detailed=False,
            module_depth=0,
        )