2222
2323from __future__ import annotations
2424
25+ import warnings
2526from abc import ABC , abstractmethod
2627from collections import OrderedDict
2728
@@ -511,12 +512,10 @@ class StratifiedStandardize(Standardize):
511512 def __init__ (
512513 self ,
513514 stratification_idx : int ,
514- observed_task_values : Tensor ,
515515 all_task_values : Tensor ,
516516 batch_shape : torch .Size = torch .Size (), # noqa: B008
517517 min_stdv : float = 1e-8 ,
518518 dtype : torch .dtype = torch .double ,
519- default_task_value : int | None = None ,
520519 ) -> None :
521520 r"""Standardize outcomes (zero mean, unit variance) along stratification dim.
522521
@@ -526,28 +525,22 @@ def __init__(
526525 Args:
527526 stratification_idx: The index of the stratification dimension in the
528527 input tensor X.
529- observed_task_values: ``t``-dim tensor of task values that were actually
530- observed in the training data.
531528 all_task_values: ``t``-dim tensor of all possible task values that could
532529 appear in the dataset.
533530 batch_shape: The batch_shape of the training targets.
534531 min_stdv: The minimum standard deviation for which to perform
535532 standardization (if lower, only de-mean the data).
536533 dtype: The data type for internal computations.
537- default_task_value: The default task value that unexpected tasks are
538- mapped to. This is used in ``get_task_value_remapping``.
539534 """
540535 OutcomeTransform .__init__ (self )
541536 self ._stratification_idx = stratification_idx
542- observed_task_values = observed_task_values .unique (sorted = True )
537+ all_task_values = all_task_values .unique (sorted = True )
543538 self .strata_mapping = get_task_value_remapping (
544- observed_task_values = observed_task_values ,
545- all_task_values = all_task_values .unique (sorted = True ),
539+ all_task_values = all_task_values ,
546540 dtype = dtype ,
547- default_task_value = default_task_value ,
548541 )
549542 if self .strata_mapping is None :
550- self .strata_mapping = observed_task_values
543+ self .strata_mapping = all_task_values
551544 n_strata = self .strata_mapping .shape [0 ]
552545 self ._min_stdv = min_stdv
553546 self .register_buffer ("means" , torch .zeros (* batch_shape , n_strata , 1 ))
@@ -629,7 +622,20 @@ def _get_per_input_means_stdvs(
629622 - The per-input stdvs squared.
630623 """
631624 strata = X [..., self ._stratification_idx ].long ()
632- mapped_strata = self .strata_mapping [strata ].unsqueeze (- 1 ).long ()
625+ mapped_strata_float = self .strata_mapping [strata ]
626+ # Check for unobserved tasks (mapped to NaN) and warn
627+ unobserved_mask = torch .isnan (mapped_strata_float )
628+ if unobserved_mask .any ():
629+ warnings .warn (
630+ "Predictions are being made for tasks that were not observed "
631+ "during training. These tasks will use an identity transform "
632+ "(mean=0, stdv=1)." ,
633+ stacklevel = 3 ,
634+ )
635+ # Map unobserved tasks to index 0 temporarily for gather operation
636+ mapped_strata_float = mapped_strata_float .clone ()
637+ mapped_strata_float [unobserved_mask ] = 0.0
638+ mapped_strata = mapped_strata_float .unsqueeze (- 1 ).long ()
633639 # get means and stdvs for each strata
634640 n_extra_batch_dims = mapped_strata .ndim - 2 - len (self ._batch_shape )
635641 expand_shape = mapped_strata .shape [:n_extra_batch_dims ] + self .means .shape
@@ -643,12 +649,22 @@ def _get_per_input_means_stdvs(
643649 dim = - 2 ,
644650 index = mapped_strata ,
645651 )
652+ # Apply identity transform (mean=0, stdv=1) for unobserved tasks
653+ if unobserved_mask .any ():
654+ unobserved_mask_expanded = unobserved_mask .unsqueeze (- 1 )
655+ means = means .clone ()
656+ stdvs = stdvs .clone ()
657+ means [unobserved_mask_expanded ] = 0.0
658+ stdvs [unobserved_mask_expanded ] = 1.0
646659 if include_stdvs_sq :
647660 stdvs_sq = torch .gather (
648661 input = self ._stdvs_sq .expand (expand_shape ),
649662 dim = - 2 ,
650663 index = mapped_strata ,
651664 )
665+ if unobserved_mask .any ():
666+ stdvs_sq = stdvs_sq .clone ()
667+ stdvs_sq [unobserved_mask_expanded ] = 1.0
652668 else :
653669 stdvs_sq = None
654670 return means , stdvs , stdvs_sq
0 commit comments