From 5ead1560c4e01c8c9435a33ea34789aaccd85fd1 Mon Sep 17 00:00:00 2001 From: 13015517713 <2430278602@qq.com> Date: Thu, 9 Mar 2023 16:24:04 +0800 Subject: [PATCH] fix up the problem in model_profiling when fed 'Modulelist' --- utils/model_profiling.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/utils/model_profiling.py b/utils/model_profiling.py index 361ace3..aad2c00 100644 --- a/utils/model_profiling.py +++ b/utils/model_profiling.py @@ -103,11 +103,23 @@ def module_profiling(self, input, output, verbose): self.n_params = 0 self.n_seconds = 0 num_children = 0 - for m in self.children(): + + def get_children(m): + children_list = [] + for child in m.children(): + if isinstance(child, nn.ModuleList): + children_list.extend(get_children(child)) + else: + children_list.append(child) + return children_list + + all_children = get_children(self) + num_children += len(all_children) + for m in all_children: self.n_macs += getattr(m, 'n_macs', 0) self.n_params += getattr(m, 'n_params', 0) self.n_seconds += getattr(m, 'n_seconds', 0) - num_children += 1 + ignore_zeros_t = [ nn.BatchNorm2d, nn.Dropout2d, nn.Dropout, nn.Sequential, nn.ReLU6, nn.ReLU, nn.MaxPool2d,