Skip to content

Commit e65701a

Browse files
authored
[Feature]: Meta OrdinalClassifier estimator (#611)
* implementation of main methods * docstrings * demo notebook * tests,docs,api change * unit tests * docs * change calibration api * feedback adjustments
1 parent 25bd5c9 commit e65701a

File tree

10 files changed

+417
-13
lines changed

10 files changed

+417
-13
lines changed

docs/_scripts/meta-models.py

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,9 @@
6767
plt.clf()
6868

6969
# --8<-- [start:cross-validation-no-refit]
70-
# %%time
70+
# %%time
7171

72-
# Train an original model
72+
# Train an original model
7373
orig_model = LogisticRegression(solver="lbfgs")
7474
orig_model.fit(X, y)
7575

@@ -111,7 +111,7 @@
111111

112112
def plot_model(model):
113113
df = load_chicken(as_frame=True)
114-
114+
115115
_ = model.fit(df[["diet", "time"]], df["weight"])
116116
metric_df = (df[["diet", "time", "weight"]]
117117
.assign(pred=lambda d: model.predict(d[["diet", "time"]]))
@@ -280,7 +280,7 @@ def plot_model(model):
280280

281281

282282
# --8<-- [start:decay-functions]
283-
from sklego.meta._decay_utils import exponential_decay, linear_decay, sigmoid_decay, stepwise_decay
283+
from sklego.meta._decay_utils import exponential_decay, linear_decay, sigmoid_decay, stepwise_decay
284284

285285
fig = plt.figure(figsize=(12, 6))
286286

@@ -312,13 +312,13 @@ def plot_model(model):
312312
np.random.seed(42)
313313

314314
n1, n2, n3 = 100, 500, 50
315-
X = np.concatenate([np.random.normal(0, 1, (n1, 2)),
315+
X = np.concatenate([np.random.normal(0, 1, (n1, 2)),
316316
np.random.normal(2, 1, (n2, 2)),
317-
np.random.normal(3, 1, (n3, 2))],
317+
np.random.normal(3, 1, (n3, 2))],
318318
axis=0)
319-
y = np.concatenate([np.zeros((n1, 1)),
319+
y = np.concatenate([np.zeros((n1, 1)),
320320
np.ones((n2, 1)),
321-
np.zeros((n3, 1))],
321+
np.zeros((n3, 1))],
322322
axis=0).reshape(-1)
323323
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap);
324324
# --8<-- [end:make-blobs]
@@ -360,7 +360,7 @@ def false_negatives(mod, x, y):
360360
cf_mod = ConfusionBalancer(LogisticRegression(solver="lbfgs", max_iter=1000), alpha=1.0)
361361

362362
grid = GridSearchCV(
363-
cf_mod,
363+
cf_mod,
364364
param_grid={"alpha": np.linspace(-1.0, 3.0, 31)},
365365
scoring={
366366
"accuracy": make_scorer(accuracy_score),
@@ -464,4 +464,49 @@ def false_negatives(mod, x, y):
464464

465465
from sklearn.utils import estimator_html_repr
466466
with open(_static_path / "outlier-classifier-stacking.html", "w") as f:
467-
f.write(estimator_html_repr(stacker))
467+
f.write(estimator_html_repr(stacker))
468+
469+
# --8<-- [start:ordinal-classifier-data]
470+
import pandas as pd
471+
472+
url = "https://stats.idre.ucla.edu/stat/data/ologit.dta"
473+
df = pd.read_stata(url).assign(apply_codes = lambda t: t["apply"].cat.codes)
474+
475+
target = "apply_codes"
476+
features = [c for c in df.columns if c not in {target, "apply"}]
477+
478+
X, y = df[features].to_numpy(), df[target].to_numpy()
479+
df.head()
480+
# --8<-- [end:ordinal-classifier-data]
481+
482+
with open(_static_path / "ordinal_data.md", "w") as f:
483+
f.write(df.head().to_markdown(index=False))
484+
485+
# --8<-- [start:ordinal-classifier]
486+
from sklearn.linear_model import LogisticRegression
487+
from sklego.meta import OrdinalClassifier
488+
489+
ord_clf = OrdinalClassifier(LogisticRegression(), n_jobs=-1, use_calibration=False)
490+
_ = ord_clf.fit(X, y)
491+
ord_clf.predict_proba(X[0])
492+
# --8<-- [end:ordinal-classifier]
493+
494+
print(ord_clf.predict_proba(X[0]))
495+
496+
# --8<-- [start:ordinal-classifier-with-calibration]
497+
from sklearn.calibration import CalibratedClassifierCV
498+
from sklearn.linear_model import LogisticRegression
499+
from sklego.meta import OrdinalClassifier
500+
501+
calibration_kwargs = {...}
502+
503+
ord_clf = OrdinalClassifier(
504+
estimator=LogisticRegression(),
505+
use_calibration=True,
506+
calibration_kwargs=calibration_kwargs
507+
)
508+
509+
# This is equivalent to:
510+
estimator = CalibratedClassifierCV(LogisticRegression(), **calibration_kwargs)
511+
ord_clf = OrdinalClassifier(estimator)
512+
# --8<-- [end:ordinal-classifier-with-calibration]
68.4 KB
Loading
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
| apply | pared | public | gpa | apply_codes |
2+
|:----------------|--------:|---------:|------:|--------------:|
3+
| very likely | 0 | 0 | 3.26 | 2 |
4+
| somewhat likely | 1 | 0 | 3.21 | 1 |
5+
| unlikely | 1 | 1 | 3.94 | 0 |
6+
| somewhat likely | 0 | 0 | 2.81 | 1 |
7+
| somewhat likely | 0 | 0 | 2.53 | 1 |

docs/api/meta.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@
2525
show_root_full_path: true
2626
show_root_heading: true
2727

28+
::: sklego.meta.ordinal_classification.OrdinalClassifier
29+
options:
30+
show_root_full_path: true
31+
show_root_heading: true
32+
2833
::: sklego.meta.outlier_classifier.OutlierClassifier
2934
options:
3035
show_root_full_path: true

docs/user-guide/meta-models.md

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ The image below demonstrates what will happen.
100100

101101
![grouped](../_static/meta-models/grouped-df.png)
102102

103-
104103
We train 5 models in total because the model will also train a fallback automatically (you can turn this off via `use_fallback=False`).
105104

106105
The idea behind the fallback is that we can predict something if there is a group at prediction time which is unseen during training.
@@ -291,6 +290,7 @@ We'll perform an optimistic demonstration below.
291290
```py
292291
--8<-- "docs/_scripts/meta-models.py:confusion-balancer-results"
293292
```
293+
294294
It seems that we can pick a value for $\alpha$ such that the confusion matrix is balanced. there's also a modest increase in accuracy for this balancing moment.
295295

296296
It should be emphasized though that this feature is **experimental**. There have been dataset/model combinations where this effect seems to work very well while there have also been situations where this trick does not work at all.
@@ -327,7 +327,7 @@ ZIR (RFC+RFR) r²: 0.8992404366385873
327327
RFR r²: 0.8516522752031502
328328
```
329329

330-
## OutlierClassifier
330+
## Outlier Classifier
331331

332332
Outlier models are unsupervised so they don't have `predict_proba` or `score` methods.
333333

@@ -381,6 +381,66 @@ The `OutlierClassifier` can be combined with any classification model in the `St
381381

382382
--8<-- "docs/_static/meta-models/outlier-classifier-stacking.html"
383383

384+
## Ordinal Classification
385+
386+
Ordinal classification (sometimes also referred to as Ordinal Regression) involves predicting an ordinal target variable, where the classes have a meaningful order.
387+
Examples of this kind of problem are: predicting customer satisfaction on a scale from 1 to 5, predicting the severity of a disease, predicting the quality of a product, etc.
388+
389+
The [`OrdinalClassifier`][ordinal-classifier-api] is a meta-model that can be used to transform any classifier into an ordinal classifier by fitting N-1 binary classifiers, each handling a specific class boundary, namely: $P(y <= 1), P(y <= 2), ..., P(y <= N-1)$.
390+
391+
This implementation is based on the paper [A simple approach to ordinal classification][ordinal-classification-paper] and it allows to predict the ordinal probabilities of each sample belonging to a particular class.
392+
393+
??? tip "Graphical representation"
394+
An image (from the paper itself) is worth a thousand words:
395+
![ordinal-classification](../_static/meta-models/ordinal-classification.png)
396+
397+
!!! note "mord library"
398+
If you are looking for a library that implements other ordinal classification algorithms, you can have a look at the [mord][mord] library.
399+
400+
```py title="Ordinal Data"
401+
--8<-- "docs/_scripts/meta-models.py:ordinal-classifier-data"
402+
```
403+
404+
--8<-- "docs/_static/meta-models/ordinal_data.md"
405+
406+
Description of the dataset from [statsmodels tutorial][statsmodels-ordinal-regression]:
407+
408+
> This dataset is about the probability for undergraduate students to apply to graduate school given three exogenous variables:
409+
>
410+
> - their grade point average (`gpa`), a float between 0 and 4.
411+
> - `pared`, a binary that indicates if at least one parent went to graduate school.
412+
> - `public`, a binary that indicates if the current undergraduate institution of the student is > public or private.
413+
>
414+
> `apply`, the target variable is categorical with ordered categories: "unlikely" < "somewhat likely" < "very likely".
415+
>
416+
> [...]
417+
>
418+
> For more details see the the Documentation of OrderedModel, [the UCLA webpage][ucla-webpage].
419+
420+
The only transformation we are applying to the data is to convert the target variable to an ordinal categorical variable by mapping the ordered categories to integers using their (pandas) category codes.
421+
422+
We are now ready to train a [`OrdinalClassifier`][ordinal-classifier-api] on this dataset:
423+
424+
```py title="OrdinalClassifier"
425+
--8<-- "docs/_scripts/meta-models.py:ordinal-classifier"
426+
```
427+
428+
> [[0.54883853 0.36225347 0.088908]]
429+
430+
### Probability Calibration
431+
432+
The `OrdinalClassifier` emphasizes the importance of proper probability estimates for its functionality. It is recommended to use the [`CalibratedClassifierCV`][calibrated-classifier-api] class from scikit-learn to calibrate the probabilities of the binary classifiers.
433+
434+
Probability calibration is _not_ enabled by default, but we provide a convenient keyword argument `use_calibration` to enable it as follows:
435+
436+
```py title="OrdinalClassifier with probability calibration"
437+
--8<-- "docs/_scripts/meta-models.py:ordinal-classifier-with-calibration"
438+
```
439+
440+
### Computation Time
441+
442+
As a meta-estimator, the `OrdinalClassifier` fits N-1 binary classifiers, which may be computationally expensive, especially with a large number of samples, features, or a complex classifier.
443+
384444
[thresholder-api]: ../../api/meta#sklego.meta.thresholder.Thresholder
385445
[grouped-predictor-api]: ../../api/meta#sklego.meta.grouped_predictor.GroupedPredictor
386446
[grouped-transformer-api]: ../../api/meta#sklego.meta.grouped_transformer.GroupedTransformer
@@ -389,8 +449,14 @@ The `OutlierClassifier` can be combined with any classification model in the `St
389449
[confusion-balancer-api]: ../../api/meta#sklego.meta.confusion_balancer.ConfusionBalancer
390450
[zero-inflated-api]: ../../api/meta#sklego.meta.zero_inflated_regressor.ZeroInflatedRegressor
391451
[outlier-classifier-api]: ../../api/meta#sklego.meta.outlier_classifier.OutlierClassifier
452+
[ordinal-classifier-api]: ../../api/meta#sklego.meta.ordinal_classification.OrdinalClassifier
392453

393454
[standard-scaler-api]: https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html
394455
[stacking-classifier-api]: https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.StackingClassifier.html#sklearn.ensemble.StackingClassifier
395456
[dummy-regressor-api]: https://scikit-learn.org/stable/modules/generated/sklearn.dummy.DummyRegressor.html
396457
[imb-learn]: https://imbalanced-learn.org/stable/
458+
[ordinal-classification-paper]: https://www.cs.waikato.ac.nz/~eibe/pubs/ordinal_tech_report.pdf
459+
[mord]: https://pythonhosted.org/mord/
460+
[statsmodels-ordinal-regression]: https://www.statsmodels.org/dev/examples/notebooks/generated/ordinal_regression.html
461+
[ucla-webpage]: https://stats.oarc.ucla.edu/r/dae/ordinal-logistic-regression/
462+
[calibrated-classifier-api]: https://scikit-learn.org/stable/modules/generated/sklearn.calibration.CalibratedClassifierCV.html

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ sklego = ["data/*.zip"]
9595
[tool.ruff]
9696
line-length = 120
9797
extend-select = ["I"]
98+
exclude = ["docs"]
9899

99100
[tool.pytest.ini_options]
100101
markers = [

sklego/meta/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"GroupedEstimator",
66
"GroupedPredictor",
77
"GroupedTransformer",
8+
"OrdinalClassifier",
89
"OutlierRemover",
910
"SubjectiveClassifier",
1011
"Thresholder",
@@ -19,6 +20,7 @@
1920
from sklego.meta.grouped_estimator import GroupedEstimator
2021
from sklego.meta.grouped_predictor import GroupedPredictor
2122
from sklego.meta.grouped_transformer import GroupedTransformer
23+
from sklego.meta.ordinal_classification import OrdinalClassifier
2224
from sklego.meta.outlier_classifier import OutlierClassifier
2325
from sklego.meta.outlier_remover import OutlierRemover
2426
from sklego.meta.regression_outlier_detector import RegressionOutlierDetector

0 commit comments

Comments
 (0)