import lightning as L
from the_well.data.normalization import ZScoreNormalization

from autocast.types.batch import Batch
from autocast.types.types import Tensor

67
class DenormMixin(L.LightningModule):
    """
    Mixin class to provide denormalization functionality for models.

    Based on The Well Trainer.denormalize(), see:
    https://github.com/PolymathicAI/the_well/blob/6cd3c44ef832855a5abae87d555bf0f0f52b1fa7/the_well/benchmark/trainer/training.py#L190
    """

    # Normalizer auto-connected from the datamodule; None means
    # denormalization is a no-op (inputs are passed through unchanged).
    norm: ZScoreNormalization | None = None
    denormalize_predictions: bool = True

    def on_fit_start(self):
        """Automatically connect to datamodule's normalizer at training start."""
        self._connect_normalizer()
        # Call parent hook if it exists (for multiple inheritance)
        if hasattr(super(), "on_fit_start"):
            super().on_fit_start()

    def on_predict_start(self):
        """Automatically connect to datamodule's normalizer at prediction start."""
        self._connect_normalizer()
        # Call parent hook if it exists (for multiple inheritance)
        if hasattr(super(), "on_predict_start"):
            super().on_predict_start()

    def _connect_normalizer(self):
        """
        Helper to connect to datamodule's normalizer.

        Looks for the normalizer in trainer.datamodule.train_dataset.norm
        and sets self.norm if found.
        """
        if not hasattr(self, "trainer"):
            return

        if hasattr(self.trainer, "datamodule"):
            datamodule = self.trainer.datamodule
            if hasattr(datamodule, "train_dataset") and hasattr(
                datamodule.train_dataset, "norm"
            ):
                # BUG FIX: previously this assigned to ``self.normalizer``,
                # which nothing reads — denormalize_batch/denormalize_tensor
                # and the ``self.norm is None`` guards all read ``self.norm``,
                # so the connected normalizer was silently ignored.
                self.norm = datamodule.train_dataset.norm

    def denormalize_batch(
        self,
        batch: Batch,
    ) -> Batch:
        """
        Denormalize the input batch.

        Parameters
        ----------
        batch : Batch
            The input batch containing normalized data.

        Returns
        -------
        Batch
            The denormalized batch. If no normalizer is connected
            (``self.norm is None``), the batch is returned unchanged.
        """
        if self.norm is None:
            return batch

        return Batch(
            input_fields=self.norm.denormalize_flattened(
                batch.input_fields, "variable"
            ),
            # Output fields are left as-is; only inputs and constants are
            # denormalized here.
            output_fields=batch.output_fields,
            constant_scalars=batch.constant_scalars,
            constant_fields=(
                self.norm.denormalize_flattened(batch.constant_fields, "constant")
                if batch.constant_fields
                else None
            ),
        )

    def denormalize_tensor(
        self,
        tensor: Tensor,
        delta=False,
    ) -> Tensor:
        """
        Denormalize a tensor of flattened variable fields.

        Parameters
        ----------
        tensor : Tensor
            The normalized tensor to be denormalized.
        delta : bool, optional
            Whether to apply delta denormalization. Default is False.

        Returns
        -------
        Tensor
            The denormalized tensor. If no normalizer is connected
            (``self.norm is None``), the tensor is returned unchanged.
        """
        if self.norm is None:
            return tensor

        if delta:
            denorm_tensor = self.norm.delta_denormalize_flattened(tensor, "variable")
        else:
            denorm_tensor = self.norm.denormalize_flattened(tensor, "variable")

        return denorm_tensor
0 commit comments