Skip to content

Commit 9ba0628

Browse files
authored
Merge pull request #337 from ego-thales/api-mode-none
* api: summary keeps current mode (#331) * fix: forgot to rm 'Mode' from error msg * dev: revert Mode Enum removal -> now 'same' in place of None * readme: Update summary doc ('same' mode)
2 parents 1281f91 + 48ee0a7 commit 9ba0628

File tree

4 files changed

+15
-15
lines changed

4 files changed

+15
-15
lines changed

README.md

+5-4
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def summary(
104104
depth: int = 3,
105105
device: Optional[torch.device] = None,
106106
dtypes: Optional[List[torch.dtype]] = None,
107-
mode: str | None = None,
107+
mode: str = "same",
108108
row_settings: Optional[Iterable[str]] = None,
109109
verbose: int = 1,
110110
**kwargs: Any,
@@ -198,9 +198,10 @@ Args:
198198
Default: None
199199
200200
mode (str)
201-
Either "train" or "eval", which determines whether we call
202-
model.train() or model.eval() before calling summary().
203-
Default: "eval".
201+
Either "train", "eval" or "same", which determines whether we call
202+
model.train() or model.eval() before calling summary(). In any case,
203+
original model mode is restored at the end.
204+
Default: "same".
204205
205206
row_settings (Iterable[str]):
206207
Specify which features to show in a row. Currently supported: (

tests/torchinfo_xl_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_eval_order_doesnt_matter() -> None:
5757
model2 = torchvision.models.resnet18(
5858
weights=torchvision.models.ResNet18_Weights.DEFAULT
5959
)
60-
summary(model2, input_size=input_size)
60+
summary(model2, input_size=input_size, mode="eval")
6161
model2.eval()
6262
with torch.inference_mode():
6363
output2 = model2(input_tensor)
@@ -144,7 +144,7 @@ def test_tmva_net_column_totals() -> None:
144144
def test_google() -> None:
145145
google_net = torchvision.models.googlenet(init_weights=False)
146146

147-
summary(google_net, (1, 3, 112, 112), depth=7)
147+
summary(google_net, (1, 3, 112, 112), depth=7, mode="eval")
148148

149149
# Check googlenet in training mode since InceptionAux layers are used in
150150
# forward-prop in train mode but not in eval mode.

torchinfo/enums.py

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ class Mode(str, Enum):
1111

1212
TRAIN = "train"
1313
EVAL = "eval"
14+
SAME = "same"
1415

1516

1617
@unique

torchinfo/torchinfo.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def summary(
6262
depth: int = 3,
6363
device: torch.device | str | None = None,
6464
dtypes: list[torch.dtype] | None = None,
65-
mode: str | None = None,
65+
mode: str = "same",
6666
row_settings: Iterable[str] | None = None,
6767
verbose: int | None = None,
6868
**kwargs: Any,
@@ -156,9 +156,10 @@ class name as the key. If the forward pass is an expensive operation,
156156
Default: None
157157
158158
mode (str)
159-
Either "train" or "eval", which determines whether we call
160-
model.train() or model.eval() before calling summary().
161-
Default: "eval".
159+
Either "train", "eval" or "same", which determines whether we call
160+
model.train() or model.eval() before calling summary(). In any case,
161+
original model mode is restored at the end.
162+
Default: "same".
162163
163164
row_settings (Iterable[str]):
164165
Specify which features to show in a row. Currently supported: (
@@ -198,10 +199,7 @@ class name as the key. If the forward pass is an expensive operation,
198199
else:
199200
rows = {RowSettings(name) for name in row_settings}
200201

201-
if mode is None:
202-
model_mode = Mode.EVAL
203-
else:
204-
model_mode = Mode(mode)
202+
model_mode = Mode(mode)
205203

206204
if verbose is None:
207205
verbose = 0 if hasattr(sys, "ps1") and sys.ps1 else 1
@@ -286,7 +284,7 @@ def forward_pass(
286284
model.train()
287285
elif mode == Mode.EVAL:
288286
model.eval()
289-
else:
287+
elif mode != Mode.SAME:
290288
raise RuntimeError(
291289
f"Specified model mode ({list(Mode)}) not recognized: {mode}"
292290
)

0 commit comments

Comments
 (0)