Skip to content

Commit 6ec9859

Browse files
authored
Merge pull request #264 from basf/master
Release v1.5.0
2 parents 6386483 + 28b5456 commit 6ec9859

29 files changed

+686
-1892
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,4 +173,4 @@ docs/_build/doctrees/*
173173
docs/_build/html/*
174174

175175

176-
dev/*
176+
dev/*

README.md

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ Mambular is a Python library for tabular deep learning. It includes models that
2323

2424
<h3>⚡ What's New ⚡</h3>
2525
<ul>
26-
<li>New Models: `Tangos`, `AutoInt`, `Trompt`</li>
26+
<li>New Models: `Tangos`, `AutoInt`, `Trompt`, `ModernNCA`</li>
2727
<li>Pretraining optionality for suitable models.</li>
2828
<li>Individual preprocessing: preprocess each feature differently, use pre-trained models for categorical encoding</li>
2929
<li>Extract latent representations of tables</li>
@@ -82,6 +82,8 @@ Mambular is a Python package that brings the power of advanced deep learning arc
8282
| `AutoInt` | Automatic Feature Interaction Learning via Self-Attentive Neural Networks introduced [here](https://arxiv.org/abs/1810.11921). |
8383
| `Trompt` | Trompt: Towards a Better Deep Neural Network for Tabular Data introduced [here](https://arxiv.org/abs/2305.18446). |
8484
| `Tangos` | Tangos: Regularizing Tabular Neural Networks through Gradient Orthogonalization and Specialization introduced [here](https://openreview.net/pdf?id=n6H86gW8u0d). |
85+
| `ModernNCA` | Revisiting Nearest Neighbor for Tabular Data: A Deep Tabular Baseline Two Decades Later introduced [here](https://arxiv.org/abs/2407.03257). |
86+
| `TabR` | TabR: Tabular Deep Learning Meets Nearest Neighbors in 2023 [here](https://arxiv.org/abs/2307.14338) |
8587

8688

8789

@@ -118,8 +120,11 @@ pip install mamba-ssm
118120

119121
<h2> Preprocessing </h2>
120122

121-
Mambular simplifies data preprocessing with a range of tools designed for easy transformation of tabular data.
122-
Specify a default method, or a dictionary defining individual preprocessing methods for each feature.
123+
Mambular uses pretab preprocessing: https://github.com/OpenTabular/PreTab
124+
125+
Hence, datatypes etc. are detected automatically and all preprocessing methods from pretab as well as from Sklearn.preprocessing are available.
126+
Additionally, you can specify that each feature is preprocessed differently, according to your requirements, by setting the `feature_preprocessing={}`argument during model initialization.
127+
For an overview over all available methods: [pretab](https://github.com/OpenTabular/PreTab)
123128

124129
<h3> Data Type Detection and Transformation </h3>
125130

docs/api/preprocessing/Preprocessor.rst

Lines changed: 0 additions & 5 deletions
This file was deleted.

docs/api/preprocessing/index.rst

Lines changed: 0 additions & 20 deletions
This file was deleted.

docs/api/utils/Preprocessor.rst

Lines changed: 0 additions & 5 deletions
This file was deleted.

docs/index.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131

3232
api/models/index
3333
api/base_models/index
34-
api/preprocessing/index
3534
api/data_utils/index
3635
api/configs/index
3736

mamba_tabular_summary.pdf

-79.3 KB
Binary file not shown.

mambular/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
from . import base_models, data_utils, models, preprocessing, utils
1+
from . import base_models, data_utils, models, utils
22
from .__version__ import __version__
33

44
__all__ = [
55
"__version__",
66
"base_models",
77
"data_utils",
88
"models",
9-
"preprocessing",
109
"utils",
1110
]

mambular/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,5 @@
1717

1818
# The following line *must* be the last in the module, exactly as formatted:
1919

20-
__version__ = "1.4.0"
20+
__version__ = "1.5.0"
2121

mambular/arch_utils/layer_utils/embedding_layer.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ def __init__(self, num_feature_info, cat_feature_info, emb_feature_info, config)
125125
if self.layer_norm_after_embedding:
126126
self.embedding_norm = nn.LayerNorm(self.d_model)
127127

128+
self.feature_info = (num_feature_info, cat_feature_info, emb_feature_info)
129+
128130
def forward(self, num_features, cat_features, emb_features):
129131
"""Defines the forward pass of the model.
130132
@@ -171,6 +173,8 @@ def forward(self, num_features, cat_features, emb_features):
171173

172174
# Process numerical embeddings based on embedding_type
173175
if self.embedding_type == "plr":
176+
# check pre-processing type compatibility with plr
177+
self.check_plr_embedding_compatibility(self.feature_info)
174178
# For PLR, pass all numerical features together
175179
if num_features is not None:
176180
num_features = torch.stack(num_features, dim=1).squeeze(
@@ -226,6 +230,21 @@ def forward(self, num_features, cat_features, emb_features):
226230
x = self.embedding_dropout(x)
227231

228232
return x
233+
234+
def check_plr_embedding_compatibility(self, feature_info:tuple):
235+
# List of incompatible preprocessing terms for PLR embedding
236+
incompatible_terms = ['ple', 'one-hot', 'polynomial', 'splines', 'sigmoid', 'rbf']
237+
238+
# Iterate through each dictionary in the tuple (data)
239+
for sub_dict in feature_info:
240+
# Iterate through each feature in the current dictionary
241+
for feature, properties in sub_dict.items():
242+
preprocessing = properties.get('preprocessing', '')
243+
244+
# Check for incompatible terms in the preprocessing string
245+
for term in incompatible_terms:
246+
if term in preprocessing:
247+
raise ValueError(f"PLR embedding type doesn't work with the '{term}' pre-processing method.\n")
229248

230249

231250
class OneHotEncoding(nn.Module):

mambular/base_models/modern_nca.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(
2222
self.save_hyperparameters(ignore=["feature_information"])
2323

2424
self.returns_ensemble = False
25-
self.uses_nca_candidates = True
25+
self.uses_candidates = True
2626

2727
self.T = config.temperature
2828
self.sample_rate = config.sample_rate
@@ -31,6 +31,7 @@ def __init__(
3131
*feature_information,
3232
config=config,
3333
)
34+
3435
input_dim = np.sum(
3536
[len(info) * self.hparams.d_model for info in feature_information]
3637
)
@@ -75,7 +76,7 @@ def forward(self, *data):
7576
x = self.post_encoder(x)
7677
return self.tabular_head(x)
7778

78-
def nca_train(self, *data, targets, candidate_x, candidate_y):
79+
def train_with_candidates(self, *data, targets, candidate_x, candidate_y):
7980
"""NCA-style training forward pass selecting candidates."""
8081
if self.hparams.use_embeddings:
8182
x = self.embedding_layer(*data)
@@ -85,6 +86,7 @@ def nca_train(self, *data, targets, candidate_x, candidate_y):
8586
B, S, D = candidate_x.shape
8687
candidate_x = candidate_x.reshape(B, S * D)
8788
else:
89+
8890
x = torch.cat([t for tensors in data for t in tensors], dim=1)
8991
candidate_x = torch.cat(
9092
[t for tensors in candidate_x for t in tensors], dim=1
@@ -129,7 +131,7 @@ def nca_train(self, *data, targets, candidate_x, candidate_y):
129131

130132
return logits
131133

132-
def nca_validate(self, *data, candidate_x, candidate_y):
134+
def validate_with_candidates(self, *data, candidate_x, candidate_y):
133135
"""Validation forward pass with NCA-style candidate selection."""
134136
if self.hparams.use_embeddings:
135137
x = self.embedding_layer(*data)
@@ -172,7 +174,7 @@ def nca_validate(self, *data, candidate_x, candidate_y):
172174

173175
return logits
174176

175-
def nca_predict(self, *data, candidate_x, candidate_y):
177+
def predict_with_candidates(self, *data, candidate_x, candidate_y):
176178
"""Prediction forward pass with candidate selection."""
177179
if self.hparams.use_embeddings:
178180
x = self.embedding_layer(*data)

0 commit comments

Comments
 (0)