Skip to content

Commit 854fb5b

Browse files
Pranav ChoudharyPranav Choudhary
authored andcommitted
Fix scikit-learn API compliance and MaskedDataset logic
1 parent d2b2777 commit 854fb5b

3 files changed

Lines changed: 34 additions & 12 deletions

File tree

pyaptamer/benchmarking/_base.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,13 @@ class Benchmarking:
7373
>>> summary = bench.run() # doctest: +SKIP
7474
"""
7575

76-
def __init__(self, estimators, metrics, X, y, cv=None):
76+
def __init__(self, estimators, metrics, X, y, cv=None, labels=None):
7777
self.estimators = estimators if isinstance(estimators, list) else [estimators]
7878
self.metrics = metrics if isinstance(metrics, list) else [metrics]
7979
self.X = X
8080
self.y = y
8181
self.cv = cv
82+
self.labels = labels
8283
self.results = None
8384

8485
def _to_scorers(self, metrics):
@@ -128,8 +129,27 @@ def run(self):
128129
self.scorers_ = self._to_scorers(self.metrics)
129130
results = {}
130131

131-
for estimator in self.estimators:
132-
est_name = estimator.__class__.__name__
132+
if self.labels is not None:
133+
if len(self.labels) != len(self.estimators):
134+
raise ValueError("Length of labels must match length of estimators.")
135+
names = self.labels
136+
else:
137+
counts = {}
138+
for est in self.estimators:
139+
name = est.__class__.__name__
140+
counts[name] = counts.get(name, 0) + 1
141+
142+
names = []
143+
seen = {}
144+
for est in self.estimators:
145+
name = est.__class__.__name__
146+
if counts[name] > 1:
147+
seen[name] = seen.get(name, 0) + 1
148+
names.append(f"{name}_{seen[name]}")
149+
else:
150+
names.append(name)
151+
152+
for estimator, est_name in zip(self.estimators, names):
133153

134154
cv_results = cross_validate(
135155
estimator,

pyaptamer/datasets/dataclasses/_masked.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(
8585
self.box = np.array(list(range(max_len)))
8686
self.len = len(self.x)
8787

88-
def _mask_rna(self, x_masked: Tensor, mask_positions: list[int]) -> Tensor:
88+
def _mask_rna(self, x_masked: Tensor, mask_positions: list[int]) -> tuple[Tensor, list[int]]:
8989
"""Mask adjacent nucleotides for RNA sequences.
9090
9191
Parameters
@@ -97,20 +97,20 @@ def _mask_rna(self, x_masked: Tensor, mask_positions: list[int]) -> Tensor:
9797
9898
Returns
9999
-------
100-
Tensor
101-
The tensor with adjacent nucleotides masked.
100+
tuple[Tensor, list[int]]
101+
The tensor with adjacent nucleotides masked and the list of adjacent positions.
102102
"""
103103
adjacent_positions = []
104104
for pos in mask_positions:
105105
# mask position + 1 (if within bounds)
106-
if pos < self.max_len - 1:
106+
if pos < self.max_len - 1 and x_masked[pos + 1] > 0:
107107
adjacent_positions.append(pos + 1)
108108
# mask position - 1 (if within bounds)
109-
if pos > 0:
109+
if pos > 0 and x_masked[pos - 1] > 0:
110110
adjacent_positions.append(pos - 1)
111111
x_masked[adjacent_positions] = self.mask_idx
112112

113-
return x_masked
113+
return x_masked, adjacent_positions
114114

115115
def __len__(self) -> int:
116116
"""
@@ -151,7 +151,7 @@ def __getitem__(self, index: int) -> tuple[Tensor, Tensor, Tensor, Tensor]:
151151
y = torch.tensor(self.y[index], dtype=torch.int64)
152152

153153
x_masked = x.clone().detach()
154-
y_masked = x.clone().detach()
154+
y_masked = y.clone().detach()
155155

156156
# non-padding positions (0 is padding)
157157
seq_len = torch.sum(x_masked > 0)
@@ -173,7 +173,8 @@ def __getitem__(self, index: int) -> tuple[Tensor, Tensor, Tensor, Tensor]:
173173

174174
# for RNA, also mask adjacent nucleotides for base pairing
175175
if self.is_rna:
176-
x_masked = self._mask_rna(x_masked, actual_mask_positions)
176+
x_masked, adjacent_positions = self._mask_rna(x_masked, actual_mask_positions)
177+
no_mask_positions = [pos for pos in no_mask_positions if pos not in adjacent_positions]
177178

178179
# zero out non-masked positions in target
179180
y_masked[no_mask_positions] = 0

pyaptamer/trafos/encode/_greedy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,8 @@ def _transform(self, X):
132132

133133
return result_df
134134

135-
def get_test_params(self):
135+
@classmethod
136+
def get_test_params(cls):
136137
"""Get test parameters for GreedyEncoder.
137138
138139
Returns

0 commit comments

Comments
 (0)