Draft: Ordinal model V2 cleaning #296

Open · wants to merge 2 commits into v2
6 changes: 0 additions & 6 deletions docs/leaspy.models.utils.ordinal.rst

This file was deleted.

46 changes: 31 additions & 15 deletions leaspy/io/data/dataset.py
@@ -190,7 +190,7 @@ def get_times_patient(self, i: int) -> torch.FloatTensor:
"""
return self.timepoints[i, :self.n_visits_per_individual[i]]

-    def get_event_patient(self, idx_patient: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    def get_event_patient(self, idx_patient: int) -> tuple[torch.Tensor, torch.Tensor]:
"""
Get ages at event for patient number ``idx_patient``

@@ -297,7 +297,27 @@ def move_to_device(self, device: torch.device) -> None:
if self._one_hot_encoding is not None:
self._one_hot_encoding = {k: t.to(device) for k, t in self._one_hot_encoding.items()}

-    def get_one_hot_encoding(self, *, sf: bool, ordinal_infos: KwargsType):
+    def get_max_levels(self) -> dict[str, int]:
+        df = self.to_pandas().dropna(how="all").sort_index()[self.headers]
+        return {feature: int(s.max()) for feature, s in df.items()}
+
+    def get_mask(self) -> torch.Tensor:
+        max_levels = self.get_max_levels()
+        max_level = max(max_levels.values())
+        return torch.stack(
+            [
+                torch.cat(
+                    [
+                        torch.ones(ft_max_level),
+                        torch.zeros(max_level - ft_max_level),
+                    ],
+                    dim=-1,
+                )
+                for ft_max_level in max_levels.values()
+            ],
+        )
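Note for reviewers: the new `get_mask` pads each feature's valid levels with ones up to that feature's own max level, then with zeros up to the global max level, so all features share one rectangular (n_features, max_level) tensor. A minimal sketch of the result, with hypothetical feature names and levels (not from this PR):

```python
import torch

# Hypothetical example: three ordinal features with max levels 2, 4 and 3.
max_levels = {"itemA": 2, "itemB": 4, "itemC": 3}
max_level = max(max_levels.values())  # 4

mask = torch.stack(
    [
        torch.cat([torch.ones(m), torch.zeros(max_level - m)], dim=-1)
        for m in max_levels.values()
    ]
)
# mask has shape (n_features, max_level):
# tensor([[1., 1., 0., 0.],
#         [1., 1., 1., 1.],
#         [1., 1., 1., 0.]])
```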

+    def get_one_hot_encoding(self, *, sf: bool):
"""
Builds the one-hot encoding of ordinal data once and for all and returns it.

@@ -316,32 +336,28 @@ def get_one_hot_encoding(self, *, sf: bool, ordinal_infos: KwargsType):
"""
if self._one_hot_encoding is not None:
return self._one_hot_encoding[sf]

-        ## Check the data & construct the one-hot encodings once for all for fast look-up afterwards
-
+        max_levels = self.get_max_levels()
+        max_level = max(max_levels.values())
# Check for values different than non-negative integers
if (self.values != self.values.round()).any() or (self.values < 0).any():
raise LeaspyInputError(
"Please make sure your data contains only integers >= 0 when using ordinal noise modelling.")

"Please make sure your data contains only integers >= 0 when using ordinal noise modelling."
)
-        # First of all check consistency of features given in ordinal_infos compared to the ones in the dataset (names & order!)
-        ordinal_feat_names = list(ordinal_infos['max_levels'])
-        if ordinal_feat_names != self.headers:
+        if list(max_levels.keys()) != self.headers:
        raise LeaspyInputError(
-                f"Features stored in ordinal model ({ordinal_feat_names}) are not consistent with features in data ({self.headers})")
+                f"Features stored in ordinal model ({max_levels}) are not consistent with features in data ({self.headers})"
+            )
# Now check that integers are within the expected range, per feature [0, max_level_ft]
# (masked values are encoded by 0 at this point)
vals = self.values.long()
vals_issues = {
'unexpected': [],
'missing': [],
}
-        for ft_i, (ft, max_level_ft) in enumerate(ordinal_infos['max_levels'].items()):
+        for ft_i, (ft, max_level_ft) in enumerate(max_levels.items()):
expected_codes = set(range(0, max_level_ft + 1)) # max level is included

vals_ft = vals[:, :, ft_i]

if not self.no_warning:
# replacing masked values by -1 (which was guaranteed not to be part of input from first check, all >= 0)
actual_vals_ft = vals_ft.where(self.mask[:, :, ft_i].bool(), torch.tensor(-1))
@@ -366,7 +382,7 @@ def get_one_hot_encoding(self, *, sf: bool, ordinal_infos: KwargsType):
+ '\n'.join(vals_issues['missing']))

# one-hot encode all the values after the checks & clipping
-        vals_pdf = torch.nn.functional.one_hot(vals, num_classes=ordinal_infos['max_level'] + 1)
+        vals_pdf = torch.nn.functional.one_hot(vals, num_classes=max_level + 1)
# build the survival function by simple (1 - cumsum) and remove the useless P(X >= 0) = 1
vals_sf = discrete_sf_from_pdf(vals_pdf)
# cache the values to retrieve them fast afterwards
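`discrete_sf_from_pdf` itself is not part of this diff; assuming it implements exactly what the comment above describes (survival function via 1 - cumsum, with the trivial P(X >= 0) = 1 removed), a minimal sketch would be:

```python
import torch

def discrete_sf_from_pdf(pdf: torch.Tensor) -> torch.Tensor:
    # Sketch under the stated assumption: entry l of the output is
    # P(X >= l + 1) = 1 - CDF(l), over the last (level) dimension.
    # Dropping the last entry removes the always-zero P(X >= max_level + 1),
    # equivalently omitting the always-one P(X >= 0) head.
    cdf = pdf.cumsum(dim=-1)
    return (1.0 - cdf)[..., :-1]

# One-hot pdf of the value 2 with max_level = 3:
pdf = torch.tensor([0.0, 0.0, 1.0, 0.0])
print(discrete_sf_from_pdf(pdf))  # tensor([1., 1., 0.])
```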
1 change: 0 additions & 1 deletion leaspy/models/abstract_model.py
@@ -199,7 +199,6 @@ def parameters(self) -> DictParamsTorch:
for p in self.hyperparameters_names + self.parameters_names
}

-    @abstractmethod
def to_dict(self) -> KwargsType:
"""
Export model as a dictionary ready for export.
6 changes: 1 addition & 5 deletions leaspy/models/abstract_multivariate_model.py
@@ -6,9 +6,6 @@
from leaspy.models.abstract_model import AbstractModel, InitializationMethod
from leaspy.models.obs_models import observation_model_factory

-# WIP
-# from leaspy.models.utils.initialization.model_initialization import initialize_parameters
-# from leaspy.models.utils.ordinal import OrdinalModelMixin
from leaspy.io.data.dataset import Dataset
from leaspy.variables.specs import (
NamedVariables,
@@ -17,7 +14,6 @@
PopulationLatentVariable,
IndividualLatentVariable,
LinkedVariable,
-    VariablesValuesRO,
)
from leaspy.variables.distributions import Normal
from leaspy.utils.functional import (
@@ -33,7 +29,7 @@


@doc_with_super()
-class AbstractMultivariateModel(AbstractModel): # OrdinalModelMixin,
+class AbstractMultivariateModel(AbstractModel):
"""
Contains the common attributes & methods of the multivariate models.

1 change: 1 addition & 0 deletions leaspy/models/joint.py
@@ -19,6 +19,7 @@
from leaspy.utils.typing import KwargsType
from leaspy.models.obs_models import observation_model_factory
import pandas as pd
+import warnings
from leaspy.utils.typing import DictParams, Optional
from leaspy.exceptions import LeaspyInputError

29 changes: 7 additions & 22 deletions leaspy/models/obs_models/_ordinal.py
@@ -9,25 +9,9 @@


class OrdinalObservationModel(ObservationModel):
-    max_level: int
-    max_levels: dict[str, int]
-    _mask: torch.Tensor
-    string_for_json = 'ordinal'
+    string_for_json: str = "ordinal"

-    @property
-    def ordinal_infos(self) -> dict:
-        """Property to return the ordinal info dictionary."""
-        # Maybe not put all ordinal infos in the ObservationModel but in the model itself
-        return dict(
-            max_level= self.max_level,
-            max_levels= self.max_levels,
-            mask= self._mask,
-        )
-
-    def __init__(
-        self,
-        **extra_vars: VariableInterface,
-    ):
+    def __init__(self, **extra_vars: VariableInterface):
super().__init__(
name="y",
getter=self.y_getter,
@@ -41,7 +25,8 @@ def y_getter(self, dataset: Dataset) -> WeightedTensor:
"Provided dataset is not valid. "
"Both values and mask should be not None."
)
-        pdf = dataset.get_one_hot_encoding(sf=False, ordinal_infos=self.ordinal_infos)
-        mask = torch.ones_like(pdf)
-        mask[..., 1:] = self._mask  # Add +1 on last dimension for level 0
-        return WeightedTensor(pdf, weight=dataset.mask.to(torch.bool).unsqueeze(-1) * mask)
+        pdf = dataset.get_one_hot_encoding(sf=False)
+        mask_ = torch.ones_like(pdf)
+        mask_[..., 1:] = dataset.get_mask()  # Add +1 on last dimension for level 0
+        return WeightedTensor(pdf, weight=dataset.mask.to(torch.bool).unsqueeze(-1) * mask_)
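The returned weight combines two masks: the dataset's missing-value mask, broadcast over the new one-hot level dimension, and the per-feature level mask from `dataset.get_mask()`, which zeroes out levels a feature can never reach (level 0 is always kept). A small shape sketch, with hypothetical dimensions:

```python
import torch

# Hypothetical shapes: 2 individuals, 3 visits, 2 features, max_level = 3.
pdf = torch.rand(2, 3, 2, 3 + 1)            # stand-in for the one-hot encoding
visit_mask = torch.ones(2, 3, 2)            # 1 = observed, 0 = missing value
level_mask = torch.tensor([[1., 1., 1.],    # feature 1 uses all 3 levels
                           [1., 1., 0.]])   # feature 2 stops at level 2

weight = torch.ones_like(pdf)
weight[..., 1:] = level_mask                # broadcast over individuals & visits
weight = visit_mask.to(torch.bool).unsqueeze(-1) * weight
# weight[i, j, ft, l] == 0 iff visit (i, j) of feature ft is missing
# or level l exceeds that feature's own max level
```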
