Skip to content

Commit be8f8d6

Browse files
committed
🎨 Consistently overwrite config values with calculated values when building Zebra models
1 parent 0d0ec38 commit be8f8d6

2 files changed

Lines changed: 7 additions & 16 deletions

File tree

ice_station_zebra/models/encode_process_decode.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,28 +27,22 @@ def __init__(
2727
# Add one encoder per dataset
2828
self.encoders = [
2929
hydra.utils.instantiate(
30-
dict(
31-
{"input_space": input_space, "latent_space": latent_space_},
32-
**encoder,
33-
),
30+
dict(**encoder)
31+
| {"input_space": input_space, "latent_space": latent_space_}
3432
)
3533
for input_space in self.input_spaces
3634
]
3735

3836
# Add a processor
3937
self.processor = hydra.utils.instantiate(
40-
dict(
41-
{"n_latent_channels": latent_space_.channels * len(self.encoders)},
42-
**processor,
43-
),
38+
dict(**processor)
39+
| {"n_latent_channels": latent_space_.channels * len(self.encoders)}
4440
)
4541

4642
# Add a decoder
4743
self.decoder = hydra.utils.instantiate(
48-
dict(
49-
{"latent_space": latent_space_, "output_space": self.output_space},
50-
**decoder,
51-
),
44+
dict(**decoder)
45+
| {"latent_space": latent_space_, "output_space": self.output_space}
5246
)
5347

5448
# Register all modules that need to be trained

ice_station_zebra/models/zebra_model.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,7 @@ def forward(self, inputs: LightningBatch) -> torch.Tensor:
4646
def configure_optimizers(self) -> Optimizer:
4747
"""Construct the optimizer from the config"""
4848
return hydra.utils.instantiate(
49-
dict(
50-
{"params": self.model_list.parameters()},
51-
**self.optimizer_cfg,
52-
),
49+
dict(**self.optimizer_cfg) | {"params": self.model_list.parameters()}
5350
)
5451

5552
def loss(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)