File tree Expand file tree Collapse file tree 6 files changed +19
-5
lines changed
Expand file tree Collapse file tree 6 files changed +19
-5
lines changed Original file line number Diff line number Diff line change 1919
2020VERSION_PATH = "../mambular/__version__.py"
2121with open (VERSION_PATH ) as f :
22- VERSION = f .readlines ()[- 1 ].split ()[- 1 ].strip ("\" '" )
23- release = VERSION
22+ lines = f .readlines ()
23+ for line in lines :
24+ if line .startswith ("__version__" ):
25+ version = line .split ("=" )[- 1 ].strip ().strip ('"' )
26+ release = version
2427
2528# -- General configuration ---------------------------------------------------
2629# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Original file line number Diff line number Diff line change 1717
1818# The following line *must* be the last in the module, exactly as formatted:
1919
20- __version__ = "1.3.1"
21-
20+ __version__ = "1.3.2"
Original file line number Diff line number Diff line change 44import torch
55from sklearn .metrics import accuracy_score , log_loss
66from .sklearn_parent import SklearnBase
7+ import numpy as np
78
89
910class SklearnBaseClassifier (SklearnBase ):
@@ -85,6 +86,8 @@ def build_model(
8586 The built classifier.
8687 """
8788
89+ num_classes = len (np .unique (y ))
90+
8891 return super ()._build_model (
8992 X ,
9093 y ,
@@ -94,6 +97,7 @@ def build_model(
9497 y_val = y_val ,
9598 embeddings = embeddings ,
9699 embeddings_val = embeddings_val ,
100+ num_classes = num_classes ,
97101 random_state = random_state ,
98102 batch_size = batch_size ,
99103 shuffle = shuffle ,
@@ -190,6 +194,7 @@ def fit(
190194 The fitted classifier.
191195 """
192196
197+ num_classes = len (np .unique (y ))
193198 return super ().fit (
194199 X = X ,
195200 y = y ,
@@ -215,6 +220,7 @@ def fit(
215220 train_metrics = train_metrics ,
216221 val_metrics = val_metrics ,
217222 rebuild = rebuild ,
223+ num_classes = num_classes ,
218224 ** trainer_kwargs ,
219225 )
220226
Original file line number Diff line number Diff line change @@ -93,6 +93,7 @@ def build_model(
9393 y_val = y_val ,
9494 embeddings = embeddings ,
9595 embeddings_val = embeddings_val ,
96+ num_classes = 1 ,
9697 random_state = random_state ,
9798 batch_size = batch_size ,
9899 shuffle = shuffle ,
@@ -198,6 +199,7 @@ def fit(
198199 y_val = y_val ,
199200 embeddings = embeddings ,
200201 embeddings_val = embeddings_val ,
202+ num_classes = 1 ,
201203 max_epochs = max_epochs ,
202204 random_state = random_state ,
203205 batch_size = batch_size ,
Original file line number Diff line number Diff line change @@ -120,6 +120,7 @@ def _build_model(
120120 y_val = None ,
121121 embeddings = None ,
122122 embeddings_val = None ,
123+ num_classes : int = None ,
123124 random_state : int = 101 ,
124125 batch_size : int = 128 ,
125126 shuffle : bool = True ,
@@ -223,6 +224,7 @@ def _build_model(
223224 weight_decay = (
224225 weight_decay if weight_decay is not None else self .config .weight_decay
225226 ),
227+ num_classes = num_classes ,
226228 train_metrics = train_metrics ,
227229 val_metrics = val_metrics ,
228230 optimizer_type = self .optimizer_type ,
@@ -273,6 +275,7 @@ def fit(
273275 y_val = None ,
274276 embeddings = None ,
275277 embeddings_val = None ,
278+ num_classes : int = None ,
276279 max_epochs : int = 100 ,
277280 random_state : int = 101 ,
278281 batch_size : int = 128 ,
@@ -357,6 +360,7 @@ def fit(
357360 y_val = y_val ,
358361 embeddings = embeddings ,
359362 embeddings_val = embeddings_val ,
363+ num_classes = num_classes ,
360364 random_state = random_state ,
361365 batch_size = batch_size ,
362366 shuffle = shuffle ,
Original file line number Diff line number Diff line change 11[tool .poetry ]
22name = " mambular"
33
4- version = " 1.3.1 "
4+ version = " 1.3.2 "
55
66description = " A python package for tabular deep learning with mamba blocks."
77authors = [" Anton Thielmann" , " Manish Kumar" , " Christoph Weisser" ]
You can’t perform that action at this time.
0 commit comments