11"""Sklearn-style API for Biterm Topic Model."""
22
3- __all__ = [' BTMClassifier' ]
3+ __all__ = [" BTMClassifier" ]
44
55from typing import List , Union , Optional , Dict , Any
66import numpy as np
@@ -75,7 +75,7 @@ def __init__(
7575 window_size : int = 15 ,
7676 has_background : bool = False ,
7777 coherence_window : int = 20 ,
78- vectorizer_params : Optional [Dict [str , Any ]] = None
78+ vectorizer_params : Optional [Dict [str , Any ]] = None ,
7979 ):
8080 self .n_topics = n_topics
8181 self .beta = beta
@@ -110,11 +110,11 @@ def _validate_params(self):
110110 def _setup_vectorizer (self ):
111111 """Initialize the vectorizer with default parameters."""
112112 default_params = {
113- ' lowercase' : True ,
114- ' token_pattern' : r' \b[a-zA-Z][a-zA-Z0-9]*\b' ,
115- ' min_df' : 2 ,
116- ' max_df' : 0.95 ,
117- ' stop_words' : ' english'
113+ " lowercase" : True ,
114+ " token_pattern" : r" \b[a-zA-Z][a-zA-Z0-9]*\b" ,
115+ " min_df" : 2 ,
116+ " max_df" : 0.95 ,
117+ " stop_words" : " english" ,
118118 }
119119 default_params .update (self .vectorizer_params )
120120 return CountVectorizer (** default_params )
@@ -147,7 +147,7 @@ def fit(self, X: Union[List[str], pd.Series], y=None):
147147
148148 # Vectorize documents
149149 self .vectorizer_ = self ._setup_vectorizer ()
150- doc_term_matrix , vocabulary , vocab_dict = get_words_freqs (X , ** self .vectorizer_params )
150+ doc_term_matrix , vocabulary , _ = get_words_freqs (X , ** self .vectorizer_params )
151151
152152 # Store vocabulary information
153153 self .vocabulary_ = vocabulary
@@ -171,14 +171,16 @@ def fit(self, X: Union[List[str], pd.Series], y=None):
171171 beta = self .beta ,
172172 seed = self .random_state or 0 ,
173173 win = self .window_size ,
174- has_background = self .has_background
174+ has_background = self .has_background ,
175175 )
176176
177177 self .model_ .fit (biterms , iterations = self .max_iter , verbose = True )
178178
179179 return self
180180
181- def transform (self , X : Union [List [str ], pd .Series ], infer_type : str = 'sum_b' ) -> np .ndarray :
181+ def transform (
182+ self , X : Union [List [str ], pd .Series ], infer_type : str = "sum_b"
183+ ) -> np .ndarray :
182184 """Transform documents to topic distribution.
183185
184186 Parameters
@@ -193,7 +195,7 @@ def transform(self, X: Union[List[str], pd.Series], infer_type: str = 'sum_b') -
193195 doc_topic_matrix : np.ndarray of shape (n_documents, n_topics)
194196 Document-topic probability matrix.
195197 """
196- check_is_fitted (self , ' model_' )
198+ check_is_fitted (self , " model_" )
197199
198200 # Convert input to list of strings
199201 if isinstance (X , pd .Series ):
@@ -207,7 +209,9 @@ def transform(self, X: Union[List[str], pd.Series], infer_type: str = 'sum_b') -
207209 # Transform using BTM model
208210 return self .model_ .transform (docs_vec , infer_type = infer_type , verbose = False )
209211
210- def fit_transform (self , X : Union [List [str ], pd .Series ], y = None , infer_type : str = 'sum_b' ) -> np .ndarray :
212+ def fit_transform (
213+ self , X : Union [List [str ], pd .Series ], y = None , infer_type : str = "sum_b"
214+ ) -> np .ndarray :
211215 """Fit model and transform documents in one step.
212216
213217 Parameters
@@ -226,7 +230,9 @@ def fit_transform(self, X: Union[List[str], pd.Series], y=None, infer_type: str
226230 """
227231 return self .fit (X ).transform (X , infer_type = infer_type )
228232
229- def get_topic_words (self , topic_id : Optional [int ] = None , n_words : int = 10 ) -> Union [List [str ], Dict [int , List [str ]]]:
233+ def get_topic_words (
234+ self , topic_id : Optional [int ] = None , n_words : int = 10
235+ ) -> Union [List [str ], Dict [int , List [str ]]]:
230236 """Get top words for topics.
231237
232238 Parameters
@@ -243,7 +249,7 @@ def get_topic_words(self, topic_id: Optional[int] = None, n_words: int = 10) ->
243249 If topic_id is provided, returns list of top words for that topic.
244250 Otherwise, returns dict mapping topic_id to list of words.
245251 """
246- check_is_fitted (self , ' model_' )
252+ check_is_fitted (self , " model_" )
247253
248254 topic_word_matrix = self .model_ .matrix_topics_words_
249255
@@ -259,7 +265,9 @@ def get_topic_words(self, topic_id: Optional[int] = None, n_words: int = 10) ->
259265 result [t ] = self .vocabulary_ [word_indices ].tolist ()
260266 return result
261267
262- def get_document_topics (self , X : Union [List [str ], pd .Series ], threshold : float = 0.1 ) -> List [List [int ]]:
268+ def get_document_topics (
269+ self , X : Union [List [str ], pd .Series ], threshold : float = 0.1
270+ ) -> List [List [int ]]:
263271 """Get dominant topics for documents.
264272
265273 Parameters
@@ -286,19 +294,19 @@ def get_document_topics(self, X: Union[List[str], pd.Series], threshold: float =
286294 @property
287295 def coherence_ (self ) -> np .ndarray :
288296 """Topic coherence scores."""
289- check_is_fitted (self , ' model_' )
297+ check_is_fitted (self , " model_" )
290298 return self .model_ .coherence_
291299
292300 @property
293301 def perplexity_ (self ) -> float :
294302 """Model perplexity."""
295- check_is_fitted (self , ' model_' )
303+ check_is_fitted (self , " model_" )
296304 return self .model_ .perplexity_
297305
298306 @property
299307 def topic_word_matrix_ (self ) -> np .ndarray :
300308 """Topic-word probability matrix."""
301- check_is_fitted (self , ' model_' )
309+ check_is_fitted (self , " model_" )
302310 return self .model_ .matrix_topics_words_
303311
304312 def score (self , X : Union [List [str ], pd .Series ], y = None ) -> float :
@@ -316,5 +324,6 @@ def score(self, X: Union[List[str], pd.Series], y=None) -> float:
316324 score : float
317325 Mean coherence score across topics.
318326 """
319- check_is_fitted (self , 'model_' )
320- return float (np .mean (self .coherence_ ))
327+ check_is_fitted (self , "model_" )
328+ return float (np .mean (self .coherence_ ))
329+
0 commit comments