-
Notifications
You must be signed in to change notification settings - Fork 54
Expand file tree
/
Copy pathcatboost_reranker.py
More file actions
112 lines (91 loc) · 4.24 KB
/
catboost_reranker.py
File metadata and controls
112 lines (91 loc) · 4.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright 2026 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing as tp
import pandas as pd
from catboost import CatBoostClassifier, CatBoostRanker, Pool
from rectools import Columns
from .candidate_ranking import Reranker
class CatBoostReranker(Reranker):
    """
    Reranker that delegates scoring to a CatBoost model.

    Works with either a `CatBoostClassifier` or a `CatBoostRanker`. The training data is wrapped
    into a CatBoost `Pool`; extra keyword arguments for the pool constructor and for the model's
    `fit` call can be supplied at construction time.
    """

    def __init__(
        self,
        model: tp.Union[CatBoostClassifier, CatBoostRanker],
        fit_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
        pool_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
    ):
        """
        Store the model together with optional `fit` and `Pool` keyword arguments.

        Parameters
        ----------
        model : CatBoostClassifier | CatBoostRanker
            CatBoost model instance that will be fitted on candidates.
        fit_kwargs : dict(str -> any), optional, default ``None``
            Extra keyword arguments forwarded to ``model.fit``. On a key collision they take
            precedence over the arguments prepared by this class.
        pool_kwargs : dict(str -> any), optional, default ``None``
            Extra keyword arguments forwarded to the CatBoost `Pool` constructor. On a key
            collision they take precedence over the arguments prepared by this class.
        """
        super().__init__(model)
        # Any model that is not a CatBoostClassifier is handled as a CatBoostRanker.
        self.is_classifier = isinstance(model, CatBoostClassifier)
        self.fit_kwargs = fit_kwargs
        self.pool_kwargs = pool_kwargs

    def prepare_training_pool(self, candidates_with_target: pd.DataFrame) -> Pool:
        """
        Build the CatBoost `Pool` used for training.

        The user, item and target columns are dropped from the feature matrix and the target
        column is used as the label. For a ranker the rows are additionally sorted by user and
        the user ids are passed as ``group_id``; for a classifier the row order is left as is
        and no group information is passed.

        Parameters
        ----------
        candidates_with_target : pd.DataFrame
            Candidates with user and item identifiers, feature columns and the target column.

        Returns
        -------
        Pool
            Pool built from the prepared arguments merged with the user-supplied `pool_kwargs`.
        """
        with_groups = not self.is_classifier

        frame = candidates_with_target
        if with_groups:
            # Sorting by user places all rows of one user next to each other before the
            # user ids are handed over as group ids.
            frame = frame.sort_values(by=[Columns.User])

        pool_args = {
            "data": frame.drop(columns=Columns.UserItem + [Columns.Target]),
            "label": frame[Columns.Target],
        }
        if with_groups:
            pool_args["group_id"] = frame[Columns.User].values

        extra_pool_args = self.pool_kwargs
        if extra_pool_args is not None:
            # User-supplied values win over the ones prepared above.
            pool_args.update(extra_pool_args)

        return Pool(**pool_args)

    def fit(self, candidates_with_target: pd.DataFrame) -> None:
        """
        Fit the underlying CatBoost model on candidates with target.

        The training pool is created with `prepare_training_pool` and passed to ``model.fit``
        as ``X``, together with any user-supplied `fit_kwargs`.

        Parameters
        ----------
        candidates_with_target : pd.DataFrame
            Candidates with user and item identifiers, feature columns and the target column.

        Returns
        -------
        None
        """
        fit_args = {"X": self.prepare_training_pool(candidates_with_target)}

        extra_fit_args = self.fit_kwargs
        if extra_fit_args is not None:
            # User-supplied values win over the ones prepared above.
            fit_args.update(extra_fit_args)

        self.model.fit(**fit_args)