Skip to content

Commit a4018a1

Browse files
committed
make DenormMixin a lightning module
1 parent 70f6bef commit a4018a1

1 file changed

Lines changed: 49 additions & 12 deletions

File tree

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,55 @@
1+
import lightning as L
12
from the_well.data.normalization import ZScoreNormalization
23

34
from autocast.types.batch import Batch
45
from autocast.types.types import Tensor
56

67

7-
class DenormMixin:
8+
class DenormMixin(L.LightningModule):
89
"""
910
Mixin class to provide denormalization functionality for models.
1011
11-
Taken from The Well Trainer.denormalize(), see:
12+
Based on The Well Trainer.denormalize(), see:
1213
https://github.com/PolymathicAI/the_well/blob/6cd3c44ef832855a5abae87d555bf0f0f52b1fa7/the_well/benchmark/trainer/training.py#L190
1314
"""
1415

16+
norm: ZScoreNormalization | None = None
17+
denormalize_predictions: bool = True
18+
19+
def on_fit_start(self):
20+
"""Automatically connect to datamodule's normalizer at training start."""
21+
self._connect_normalizer()
22+
# Call parent hook if it exists (for multiple inheritance)
23+
if hasattr(super(), "on_fit_start"):
24+
super().on_fit_start()
25+
26+
def on_predict_start(self):
27+
"""Automatically connect to datamodule's normalizer at prediction start."""
28+
self._connect_normalizer()
29+
# Call parent hook if it exists (for multiple inheritance)
30+
if hasattr(super(), "on_predict_start"):
31+
super().on_predict_start()
32+
33+
def _connect_normalizer(self):
34+
"""
35+
Helper to connect to datamodule's normalizer.
36+
37+
Looks for the normalizer in trainer.datamodule.train_dataset.norm
38+
and sets self.normalizer if found.
39+
"""
40+
if not hasattr(self, "trainer"):
41+
return
42+
43+
if hasattr(self.trainer, "datamodule"):
44+
datamodule = self.trainer.datamodule
45+
if hasattr(datamodule, "train_dataset") and hasattr(
46+
datamodule.train_dataset, "norm"
47+
):
48+
self.normalizer = datamodule.train_dataset.norm
49+
1550
def denormalize_batch(
1651
self,
1752
batch: Batch,
18-
norm: ZScoreNormalization,
1953
) -> Batch:
2054
"""
2155
Denormalize the input batch.
@@ -24,20 +58,23 @@ def denormalize_batch(
2458
----------
2559
batch : Batch
2660
The input batch containing normalized data.
27-
norm : type[ZScoreNormalization]
28-
The normalization class used for denormalization.
2961
3062
Returns
3163
-------
3264
Batch
3365
The denormalized batch.
3466
"""
67+
if self.norm is None:
68+
return batch
69+
3570
return Batch(
36-
input_fields=norm.denormalize_flattened(batch.input_fields, "variable"),
71+
input_fields=self.norm.denormalize_flattened(
72+
batch.input_fields, "variable"
73+
),
3774
output_fields=batch.output_fields,
3875
constant_scalars=batch.constant_scalars,
3976
constant_fields=(
40-
norm.denormalize_flattened(batch.constant_fields, "constant")
77+
self.norm.denormalize_flattened(batch.constant_fields, "constant")
4178
if batch.constant_fields
4279
else None
4380
),
@@ -46,7 +83,6 @@ def denormalize_batch(
4683
def denormalize_tensor(
4784
self,
4885
tensor: Tensor,
49-
norm: ZScoreNormalization,
5086
delta=False,
5187
) -> Tensor:
5288
"""
@@ -56,8 +92,6 @@ def denormalize_tensor(
5692
----------
5793
tensor : Tensor
5894
The normalized tensor to be denormalized.
59-
norm : type[ZScoreNormalization]
60-
The normalization class used for denormalization.
6195
delta : bool, optional
6296
Whether to apply delta denormalization. Default is False.
6397
@@ -66,9 +100,12 @@ def denormalize_tensor(
66100
Tensor
67101
The denormalized tensor.
68102
"""
103+
if self.norm is None:
104+
return tensor
105+
69106
if delta:
70-
denorm_tensor = norm.delta_denormalize_flattened(tensor, "variable")
107+
denorm_tensor = self.norm.delta_denormalize_flattened(tensor, "variable")
71108
else:
72-
denorm_tensor = norm.denormalize_flattened(tensor, "variable")
109+
denorm_tensor = self.norm.denormalize_flattened(tensor, "variable")
73110

74111
return denorm_tensor

0 commit comments

Comments
 (0)