Skip to content

Commit 2ed2b82

Browse files
committed
🐛 Use the implicit self._modules instead of a manual ModuleList to keep track of child modules. This fixes the incorrectly displayed summary table
1 parent f3e7616 commit 2ed2b82

2 files changed

Lines changed: 12 additions & 10 deletions

File tree

ice_station_zebra/models/encode_process_decode.py

Lines changed: 4 additions & 5 deletions
Diff (reconstructed unified diff; the extraction had fused the old/new line-number columns into the text):

@@ -32,6 +32,10 @@ def __init__(
             )
             for input_space in self.input_spaces
         ]
+        # We have to explicitly register each encoder as list[Module] will not be
+        # automatically picked up by PyTorch
+        for idx, encoder in enumerate(self.encoders):
+            self.add_module(f"encoder_{idx}", encoder)

         # Add a processor
         self.processor = hydra.utils.instantiate(
@@ -45,11 +49,6 @@ def __init__(
             | {"latent_space": latent_space_, "output_space": self.output_space}
         )

-        # Register all modules that need to be trained
-        self.model_list.extend(self.encoders)
-        self.model_list.append(self.processor)
-        self.model_list.append(self.decoder)
-
     def forward(self, inputs: LightningBatch) -> torch.Tensor:
         """Forward step of the model

ice_station_zebra/models/zebra_model.py

Lines changed: 8 additions & 5 deletions
Diff (reconstructed unified diff; the extraction had fused the old/new line-number columns into the text):

@@ -1,8 +1,9 @@
+import itertools
 from abc import ABC, abstractmethod
+from collections.abc import Iterator

 import hydra
 import torch
-import torch.nn as nn
 from lightning import LightningModule
 from omegaconf import DictConfig
 from torch.optim import Optimizer
@@ -28,9 +29,6 @@ def __init__(
         self.input_spaces = [DataSpace.from_dict(space) for space in input_spaces]
         self.output_space = DataSpace.from_dict(output_space)

-        # Initialise an empty module list
-        self.model_list = nn.ModuleList()
-
         # Store the optimizer config
         self.optimizer_cfg = optimizer

@@ -46,12 +44,17 @@ def forward(self, inputs: LightningBatch) -> torch.Tensor:
     def configure_optimizers(self) -> Optimizer:
         """Construct the optimizer from the config"""
         return hydra.utils.instantiate(
-            dict(**self.optimizer_cfg) | {"params": self.model_list.parameters()}
+            dict(**self.optimizer_cfg) | {"params": self.parameters()}
         )

     def loss(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         return torch.nn.functional.l1_loss(output, target)

+    def parameters(self) -> Iterator[torch.nn.Parameter]:
+        return itertools.chain(
+            *[module.parameters() for module in self._modules.values()]
+        )
+
     def test_step(
         self, batch: LightningBatch, batch_idx: int
     ) -> dict[str, torch.Tensor]:

0 commit comments

Comments
 (0)