Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mambular/data_utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def __init__(
labels=None,
regression=True,
):
assert cat_features_list or num_features_list

self.cat_features_list = cat_features_list # Categorical features tensors
self.num_features_list = num_features_list # Numerical features tensors
self.embeddings_list = embeddings_list # Embeddings tensors (optional)
Expand All @@ -44,7 +46,8 @@ def __init__(
self.labels = None # No labels in prediction mode

def __len__(self):
return len(self.num_features_list[0]) # Use numerical features length
_feats = self.num_features_list if self.num_features_list else self.cat_features_list
return len(_feats[0])

def __getitem__(self, idx):
"""Retrieves the features and label for a given index.
Expand Down