Commit eb65fd6

cleaned code issues
1 parent 3bf0247 commit eb65fd6

4 files changed (+64, -60 lines)

docs/source/conf.py

Lines changed: 7 additions & 8 deletions
@@ -17,9 +17,8 @@
 
 # -- Project information -----------------------------------------------------
 
-project = 'bitermplus'
-copyright = '2021, Maksim Terpilowski'
-author = 'Maksim Terpilowski'
+project = "bitermplus"
+author = "Maksim Terpilovskii"
 
 
 # -- General configuration ---------------------------------------------------
@@ -28,12 +27,12 @@
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
 extensions = [
-    'sphinx.ext.autosummary',
-    'sphinx.ext.napoleon',
+    "sphinx.ext.autosummary",
+    "sphinx.ext.napoleon",
 ]
 
 # Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
 
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
@@ -46,9 +45,9 @@
 # The theme to use for HTML and HTML Help pages. See the documentation for
 # a list of builtin themes.
 #
-html_theme = 'sphinx_rtd_theme'
+html_theme = "sphinx_rtd_theme"
 
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]

src/bitermplus/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -1,6 +1,6 @@
 __version__ = "0.8.0"
 
-from ._btm import BTM
-from ._util import *
-from ._metrics import *
-from ._api import BTMClassifier
+from ._btm import BTM  # noqa: F401, F403
+from ._util import *  # noqa: F401, F403
+from ._metrics import *  # noqa: F401, F403
+from ._api import BTMClassifier  # noqa: F401, F403
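
For context, not part of the diff: the noqa markers silence flake8's F401 ("imported but unused") and F403 ("star import used") warnings, since these imports exist only to re-export names at the package level. A minimal sketch of that package-level usage follows; the constructor argument is illustrative only.

import bitermplus as btm

# BTM and BTMClassifier are re-exported from ._btm and ._api
model = btm.BTMClassifier(n_topics=5)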

src/bitermplus/_api.py

Lines changed: 29 additions & 20 deletions
@@ -1,6 +1,6 @@
 """Sklearn-style API for Biterm Topic Model."""
 
-__all__ = ['BTMClassifier']
+__all__ = ["BTMClassifier"]
 
 from typing import List, Union, Optional, Dict, Any
 import numpy as np
@@ -75,7 +75,7 @@ def __init__(
         window_size: int = 15,
         has_background: bool = False,
         coherence_window: int = 20,
-        vectorizer_params: Optional[Dict[str, Any]] = None
+        vectorizer_params: Optional[Dict[str, Any]] = None,
     ):
         self.n_topics = n_topics
         self.beta = beta
@@ -110,11 +110,11 @@ def _validate_params(self):
     def _setup_vectorizer(self):
         """Initialize the vectorizer with default parameters."""
         default_params = {
-            'lowercase': True,
-            'token_pattern': r'\b[a-zA-Z][a-zA-Z0-9]*\b',
-            'min_df': 2,
-            'max_df': 0.95,
-            'stop_words': 'english'
+            "lowercase": True,
+            "token_pattern": r"\b[a-zA-Z][a-zA-Z0-9]*\b",
+            "min_df": 2,
+            "max_df": 0.95,
+            "stop_words": "english",
         }
         default_params.update(self.vectorizer_params)
         return CountVectorizer(**default_params)
@@ -147,7 +147,7 @@ def fit(self, X: Union[List[str], pd.Series], y=None):
 
         # Vectorize documents
         self.vectorizer_ = self._setup_vectorizer()
-        doc_term_matrix, vocabulary, vocab_dict = get_words_freqs(X, **self.vectorizer_params)
+        doc_term_matrix, vocabulary, _ = get_words_freqs(X, **self.vectorizer_params)
 
         # Store vocabulary information
         self.vocabulary_ = vocabulary
@@ -171,14 +171,16 @@ def fit(self, X: Union[List[str], pd.Series], y=None):
             beta=self.beta,
             seed=self.random_state or 0,
             win=self.window_size,
-            has_background=self.has_background
+            has_background=self.has_background,
         )
 
         self.model_.fit(biterms, iterations=self.max_iter, verbose=True)
 
         return self
 
-    def transform(self, X: Union[List[str], pd.Series], infer_type: str = 'sum_b') -> np.ndarray:
+    def transform(
+        self, X: Union[List[str], pd.Series], infer_type: str = "sum_b"
+    ) -> np.ndarray:
         """Transform documents to topic distribution.
 
         Parameters
@@ -193,7 +195,7 @@ def transform(self, X: Union[List[str], pd.Series], infer_type: str = 'sum_b') -> np.ndarray:
         doc_topic_matrix : np.ndarray of shape (n_documents, n_topics)
             Document-topic probability matrix.
         """
-        check_is_fitted(self, 'model_')
+        check_is_fitted(self, "model_")
 
         # Convert input to list of strings
         if isinstance(X, pd.Series):
@@ -207,7 +209,9 @@ def transform(self, X: Union[List[str], pd.Series], infer_type: str = 'sum_b') -> np.ndarray:
         # Transform using BTM model
         return self.model_.transform(docs_vec, infer_type=infer_type, verbose=False)
 
-    def fit_transform(self, X: Union[List[str], pd.Series], y=None, infer_type: str = 'sum_b') -> np.ndarray:
+    def fit_transform(
+        self, X: Union[List[str], pd.Series], y=None, infer_type: str = "sum_b"
+    ) -> np.ndarray:
         """Fit model and transform documents in one step.
 
         Parameters
@@ -226,7 +230,9 @@ def fit_transform(self, X: Union[List[str], pd.Series], y=None, infer_type: str = 'sum_b') -> np.ndarray:
         """
         return self.fit(X).transform(X, infer_type=infer_type)
 
-    def get_topic_words(self, topic_id: Optional[int] = None, n_words: int = 10) -> Union[List[str], Dict[int, List[str]]]:
+    def get_topic_words(
+        self, topic_id: Optional[int] = None, n_words: int = 10
+    ) -> Union[List[str], Dict[int, List[str]]]:
         """Get top words for topics.
 
         Parameters
@@ -243,7 +249,7 @@ def get_topic_words(self, topic_id: Optional[int] = None, n_words: int = 10) -> Union[List[str], Dict[int, List[str]]]:
             If topic_id is provided, returns list of top words for that topic.
             Otherwise, returns dict mapping topic_id to list of words.
         """
-        check_is_fitted(self, 'model_')
+        check_is_fitted(self, "model_")
 
         topic_word_matrix = self.model_.matrix_topics_words_
 
@@ -259,7 +265,9 @@ def get_topic_words(self, topic_id: Optional[int] = None, n_words: int = 10) -> Union[List[str], Dict[int, List[str]]]:
                 result[t] = self.vocabulary_[word_indices].tolist()
             return result
 
-    def get_document_topics(self, X: Union[List[str], pd.Series], threshold: float = 0.1) -> List[List[int]]:
+    def get_document_topics(
+        self, X: Union[List[str], pd.Series], threshold: float = 0.1
+    ) -> List[List[int]]:
         """Get dominant topics for documents.
 
         Parameters
@@ -286,19 +294,19 @@ def get_document_topics(self, X: Union[List[str], pd.Series], threshold: float = 0.1) -> List[List[int]]:
     @property
     def coherence_(self) -> np.ndarray:
         """Topic coherence scores."""
-        check_is_fitted(self, 'model_')
+        check_is_fitted(self, "model_")
         return self.model_.coherence_
 
     @property
     def perplexity_(self) -> float:
         """Model perplexity."""
-        check_is_fitted(self, 'model_')
+        check_is_fitted(self, "model_")
         return self.model_.perplexity_
 
     @property
     def topic_word_matrix_(self) -> np.ndarray:
         """Topic-word probability matrix."""
-        check_is_fitted(self, 'model_')
+        check_is_fitted(self, "model_")
         return self.model_.matrix_topics_words_
 
     def score(self, X: Union[List[str], pd.Series], y=None) -> float:
@@ -316,5 +324,6 @@ def score(self, X: Union[List[str], pd.Series], y=None) -> float:
         score : float
             Mean coherence score across topics.
         """
-        check_is_fitted(self, 'model_')
-        return float(np.mean(self.coherence_))
+        check_is_fitted(self, "model_")
+        return float(np.mean(self.coherence_))
+
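
For orientation, not part of the commit: a minimal usage sketch of the sklearn-style API above. The documents and parameter values are illustrative assumptions, and vectorizer_params is relaxed only so this tiny corpus survives vectorization.

from bitermplus import BTMClassifier

docs = [
    "machine learning models need training data",
    "deep learning uses neural networks",
    "topic models uncover latent themes in text",
    "neural networks learn hierarchical representations",
]

# Illustrative settings; min_df=1 and stop_words=None keep the small corpus usable
model = BTMClassifier(
    n_topics=2,
    max_iter=50,
    random_state=42,
    vectorizer_params={"min_df": 1, "stop_words": None},
)

doc_topics = model.fit_transform(docs)        # (n_documents, n_topics) probabilities
top_words = model.get_topic_words(n_words=5)  # dict mapping topic id to top words
mean_coherence = model.score(docs)            # mean coherence across topics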

tests/test_sklearn_api.py

Lines changed: 24 additions & 28 deletions
@@ -28,7 +28,7 @@ def setUp(self):
             "reinforcement learning agents learn through trial and error",
             "supervised learning uses labeled training data",
             "unsupervised learning finds hidden patterns in data",
-            "feature engineering improves model performance significantly"
+            "feature engineering improves model performance significantly",
         ]
 
     def test_init_default_params(self):
@@ -43,11 +43,7 @@ def test_init_default_params(self):
     def test_init_custom_params(self):
         """Test initialization with custom parameters."""
         model = BTMClassifier(
-            n_topics=5,
-            alpha=0.1,
-            beta=0.05,
-            max_iter=100,
-            random_state=42
+            n_topics=5, alpha=0.1, beta=0.05, max_iter=100, random_state=42
         )
         self.assertEqual(model.n_topics, 5)
         self.assertEqual(model.alpha, 0.1)
@@ -74,9 +70,9 @@ def test_fit_basic(self):
         model = BTMClassifier(n_topics=3, random_state=42, max_iter=50)
         model.fit(self.sample_texts)
 
-        self.assertTrue(hasattr(model, 'model_'))
-        self.assertTrue(hasattr(model, 'vocabulary_'))
-        self.assertTrue(hasattr(model, 'n_features_in_'))
+        self.assertTrue(hasattr(model, "model_"))
+        self.assertTrue(hasattr(model, "vocabulary_"))
+        self.assertTrue(hasattr(model, "n_features_in_"))
         self.assertGreater(model.n_features_in_, 0)
 
     def test_fit_with_pandas_series(self):
@@ -85,8 +81,8 @@ def test_fit_with_pandas_series(self):
         model = BTMClassifier(n_topics=3, random_state=42, max_iter=50)
         model.fit(texts_series)
 
-        self.assertTrue(hasattr(model, 'model_'))
-        self.assertTrue(hasattr(model, 'vocabulary_'))
+        self.assertTrue(hasattr(model, "model_"))
+        self.assertTrue(hasattr(model, "vocabulary_"))
 
     def test_fit_empty_input(self):
         """Test fitting with empty input."""
@@ -110,7 +106,7 @@ def test_transform_different_inference_types(self):
         model = BTMClassifier(n_topics=3, random_state=42, max_iter=50)
         model.fit(self.sample_texts)
 
-        for infer_type in ['sum_b', 'sum_w', 'mix']:
+        for infer_type in ["sum_b", "sum_w", "mix"]:
             doc_topics = model.transform(self.sample_texts[:3], infer_type=infer_type)
             self.assertEqual(doc_topics.shape, (3, 3))
             self.assertTrue(np.all(doc_topics >= 0))
@@ -143,7 +139,7 @@ def test_get_topic_words_all_topics(self):
 
         self.assertIsInstance(words_dict, dict)
         self.assertEqual(len(words_dict), 3)
-        for topic_id, words in words_dict.items():
+        for _, words in words_dict.items():
             self.assertIsInstance(words, list)
             self.assertEqual(len(words), 5)
 
@@ -204,41 +200,40 @@ def test_sklearn_compatibility(self):
             # This tests that the estimator interface is correct
             scores = cross_val_score(model, self.sample_texts, cv=2, scoring=None)
             self.assertEqual(len(scores), 2)
-        except Exception as e:
+        except Exception:
             # Some sklearn versions might have issues, but the interface should be correct
-            self.assertIn('BTMClassifier', str(type(model)))
+            self.assertIn("BTMClassifier", str(type(model)))
 
     def test_pipeline_integration(self):
         """Test integration with sklearn Pipeline."""
+
         # Simple preprocessing function
         def preprocess_texts(texts):
             return [text.lower() for text in texts]
 
-        pipeline = Pipeline([
-            ('preprocess', FunctionTransformer(preprocess_texts)),
-            ('btm', BTMClassifier(n_topics=3, random_state=42, max_iter=50))
-        ])
+        pipeline = Pipeline(
+            [
+                ("preprocess", FunctionTransformer(preprocess_texts)),
+                ("btm", BTMClassifier(n_topics=3, random_state=42, max_iter=50)),
+            ]
+        )
 
         doc_topics = pipeline.fit_transform(self.sample_texts)
         self.assertEqual(doc_topics.shape, (len(self.sample_texts), 3))
 
     def test_vectorizer_params(self):
         """Test custom vectorizer parameters."""
-        vectorizer_params = {
-            'min_df': 1,
-            'max_df': 1.0,
-            'stop_words': None
-        }
+        vectorizer_params = {"min_df": 1, "max_df": 1.0, "stop_words": None}
 
         model = BTMClassifier(
             n_topics=3,
             random_state=42,
             max_iter=50,
-            vectorizer_params=vectorizer_params
+            vectorizer_params=vectorizer_params,
         )
         model.fit(self.sample_texts)
 
-        self.assertTrue(hasattr(model, 'model_'))
+        self.assertTrue(hasattr(model, "model_"))
 
     def test_transform_unseen_data(self):
         """Test transform on unseen data."""
@@ -247,7 +242,7 @@ def test_transform_unseen_data(self):
 
         new_texts = [
             "new machine learning algorithm",
-            "innovative data processing technique"
+            "innovative data processing technique",
        ]
 
         doc_topics = model.transform(new_texts)
@@ -256,4 +251,5 @@ def test_transform_unseen_data(self):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    unittest.main()
+
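
Also not part of the diff: a sketch of the sklearn interop that test_sklearn_compatibility exercises, assuming score() returns mean topic coherence as defined in _api.py above. As the test's own except branch notes, cross-validating a text-in estimator can still fail on some sklearn versions.

from sklearn.model_selection import cross_val_score
from bitermplus import BTMClassifier

texts = [
    "machine learning models need training data",
    "deep learning uses neural networks",
    "topic models uncover latent themes in text",
    "neural networks learn hierarchical representations",
]

model = BTMClassifier(
    n_topics=2,
    max_iter=50,
    random_state=42,
    vectorizer_params={"min_df": 1, "stop_words": None},
)

# With scoring=None, cross_val_score falls back to the estimator's own score()
scores = cross_val_score(model, texts, cv=2, scoring=None)
print(scores.mean())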
