Commit 78d1f22

fix: enhance numerical stability and robustness
- Add configurable epsilon parameter for division-by-zero protection
- Fix vectorizer parameter consistency in sklearn API
- Improve normalization stability with edge case handling
- Add robust input validation for empty documents/biterms
- Unify random seed handling to prevent timing issues
- Enhance error messages for better user experience
1 parent 5585cda

File tree (4 files changed: +80, -17 lines)

- pyproject.toml
- src/bitermplus/_api.py
- src/bitermplus/_btm.pyx
- src/bitermplus/_util.py

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 [build-system]
-requires = ["setuptools>=64", "wheel", "cython>=0.29.0", "numpy>=1.19.0"]
+requires = ["setuptools>=77", "wheel", "cython>=0.29.0", "numpy>=1.19.0"]
 build-backend = "setuptools.build_meta"
 
 [project]
@@ -8,7 +8,7 @@ dynamic = ["version"]
 description = "Biterm Topic Model with sklearn-compatible API"
 readme = "README.md"
 requires-python = ">=3.8"
-license.file = "LICENSE"
+license-files = ["LICENSE"]
 authors = [
     { name = "Maksim Terpilovskii", email = "[email protected]" },
 ]
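
Note on the packaging change: `license.file` is the older PEP 621 table form, while `license-files` is its PEP 639 replacement; setuptools gained PEP 639 support in the 77 series, which is presumably why the build requirement was bumped from `setuptools>=64` to `setuptools>=77` in the same hunk.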

src/bitermplus/_api.py

Lines changed: 12 additions & 3 deletions
@@ -41,6 +41,8 @@ class BTMClassifier(BaseEstimator, TransformerMixin):
         Number of top words for coherence calculation.
     vectorizer_params : dict, default=None
         Parameters to pass to CountVectorizer for preprocessing.
+    epsilon : float, default=1e-10
+        Small constant to prevent numerical issues (division by zero, etc.).
 
     Attributes
     ----------
@@ -76,6 +78,7 @@ def __init__(
         has_background: bool = False,
         coherence_window: int = 20,
         vectorizer_params: Optional[Dict[str, Any]] = None,
+        epsilon: float = 1e-10,
     ):
         self.n_topics = n_topics
         self.beta = beta
@@ -85,6 +88,7 @@
         self.has_background = has_background
         self.coherence_window = coherence_window
         self.vectorizer_params = vectorizer_params or {}
+        self.epsilon = epsilon
 
         # Validate parameters before calculating alpha
         self._validate_params()
@@ -106,13 +110,15 @@ def _validate_params(self):
             raise ValueError("window_size must be positive")
         if self.coherence_window <= 0:
             raise ValueError("coherence_window must be positive")
+        if self.epsilon <= 0:
+            raise ValueError("epsilon must be positive")
 
     def _setup_vectorizer(self):
         """Initialize the vectorizer with default parameters."""
         default_params = {
             "lowercase": True,
             "token_pattern": r"\b[a-zA-Z][a-zA-Z0-9]*\b",
-            "min_df": 2,
+            "min_df": 1,  # Changed from 2 to work with small datasets
             "max_df": 0.95,
             "stop_words": "english",
         }
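
How user-supplied `vectorizer_params` interact with these defaults is not visible in the hunk; presumably `_setup_vectorizer` merges the user dict over `default_params` along these lines (a sketch under that assumption, not the repository's code):

from sklearn.feature_extraction.text import CountVectorizer

def setup_vectorizer(user_params=None):
    # Defaults mirror the hunk above; user-supplied values override them.
    default_params = {
        "lowercase": True,
        "token_pattern": r"\b[a-zA-Z][a-zA-Z0-9]*\b",
        "min_df": 1,
        "max_df": 0.95,
        "stop_words": "english",
    }
    default_params.update(user_params or {})
    return CountVectorizer(**default_params)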
@@ -145,9 +151,11 @@ def fit(self, X: Union[List[str], pd.Series], y=None):
         if len(X) == 0:
             raise ValueError("Input documents cannot be empty")
 
-        # Vectorize documents
+        # Vectorize documents using the configured vectorizer
         self.vectorizer_ = self._setup_vectorizer()
-        doc_term_matrix, vocabulary, _ = get_words_freqs(X, **self.vectorizer_params)
+        doc_term_matrix = self.vectorizer_.fit_transform(X)
+        vocabulary = np.array(self.vectorizer_.get_feature_names_out())
+        vocab_dict = self.vectorizer_.vocabulary_
 
         # Store vocabulary information
         self.vocabulary_ = vocabulary
@@ -172,6 +180,7 @@
             seed=self.random_state or 0,
             win=self.window_size,
             has_background=self.has_background,
+            epsilon=self.epsilon,
         )
 
         self.model_.fit(biterms, iterations=self.max_iter, verbose=True)
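
A minimal end-to-end sketch of the updated sklearn API. Only epsilon, vectorizer_params, coherence_window, and has_background appear in this diff; the import path and the n_topics/random_state parameters are assumptions inferred from the attribute names above:

from bitermplus import BTMClassifier  # assumed export path

docs = [
    "cats and dogs are common pets",
    "dogs often chase cats",
    "stock markets fell sharply today",
    "investors sold their stocks",
]

model = BTMClassifier(
    n_topics=2,                       # assumed, from self.n_topics
    random_state=42,                  # assumed, forwarded as seed= above
    epsilon=1e-10,                    # new: guards divisions in the sampler
    vectorizer_params={"min_df": 1},  # forwarded to CountVectorizer
)
model.fit(docs)                       # raises ValueError on an empty corpus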

src/bitermplus/_btm.pyx

Lines changed: 54 additions & 11 deletions
@@ -55,6 +55,8 @@ cdef class BTM:
         Biterms generation window.
     has_background : bool = False
         Use a background topic to accumulate highly frequent words.
+    epsilon : double = 1e-10
+        Small constant to prevent numerical issues (division by zero, etc.).
     """
     cdef:
         n_dw
@@ -75,13 +77,15 @@
         int[:, :] B
         int iters
         unsigned int seed
+        object rng  # Numpy random generator
+        double epsilon  # Small constant to prevent numerical issues
 
     # cdef dict __dict__
 
     def __init__(
             self, n_dw, vocabulary, int T, int M=20,
             double alpha=1., double beta=0.01, unsigned int seed=0,
-            int win=15, bint has_background=False):
+            int win=15, bint has_background=False, double epsilon=1e-10):
         self.n_dw = n_dw
         self.vocabulary = vocabulary
         self.T = T
@@ -91,6 +95,9 @@
         self.beta = beta
         self.win = win
         self.seed = seed
+        self.epsilon = epsilon
+        # Initialize RNG once to avoid time-based seed issues
+        self.rng = np.random.default_rng(self.seed if self.seed else time(NULL))
         self.p_wb = np.asarray(n_dw.sum(axis=0) / n_dw.sum())[0]
         self.p_z = array(
             shape=(self.T, ), itemsize=sizeof(double), format="d",
@@ -133,7 +140,9 @@
             'p_zd': np.asarray(self.p_zd),
             'p_wz': np.asarray(self.p_wz),
             'p_wb': np.asarray(self.p_wb),
-            'p_z': np.asarray(self.p_z)
+            'p_z': np.asarray(self.p_z),
+            'seed': self.seed,
+            'epsilon': self.epsilon
         }
 
     def __setstate__(self, state):
@@ -154,11 +163,14 @@
         self.p_wz = state.get('p_wz')
         self.p_wb = state.get('p_wb')
         self.p_z = state.get('p_z')
+        self.seed = state.get('seed', 0)
+        self.epsilon = state.get('epsilon', 1e-10)
+        # Reinitialize RNG after unpickling
+        self.rng = np.random.default_rng(self.seed if self.seed else time(NULL))
 
     cdef int[:, :] _biterms_to_array(self, list B):
-        rng = np.random.default_rng(self.seed if self.seed else time(NULL))
         arr = np.asarray(list(chain(*B)), dtype=np.int32)
-        random_topics = rng.integers(
+        random_topics = self.rng.integers(
             low=0, high=self.T, size=(arr.shape[0], 1), dtype=np.int32)
         arr = np.append(arr, random_topics, axis=1)
         return arr
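
Why the single `self.rng` matters: the old code built a fresh generator inside each method, and with `seed=0` the `time(NULL)` fallback has one-second resolution, so calls landing in the same second replayed an identical stream while calls a second apart diverged. A plain-Python sketch of the contrast (illustrative only, not the Cython code):

import numpy as np

seed = 42

# Old pattern: a new generator per call restarts the stream every time.
a = np.random.default_rng(seed).integers(0, 10, 5)
b = np.random.default_rng(seed).integers(0, 10, 5)
assert (a == b).all()       # identical draws: the stream never advances

# New pattern: one generator, initialized once, drives the whole sampler.
rng = np.random.default_rng(seed)
c = rng.integers(0, 10, 5)
d = rng.integers(0, 10, 5)  # continues the stream: reproducible, not repeated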
@@ -172,7 +184,7 @@
         for k in range(self.T):
             for w in range(self.W):
                 self.p_wz[k][w] = (self.n_wz[k][w] + self.beta) / \
-                    (self.n_bz[k] * 2. + self.W * self.beta)
+                    max(self.n_bz[k] * 2. + self.W * self.beta, self.epsilon)
 
     @boundscheck(False)
     @cdivision(True)
@@ -190,11 +202,11 @@
                 pw2k = self.p_wb[w2]
             else:
                 pw1k = (self.n_wz[k][w1] + self.beta) / \
-                    (2. * self.n_bz[k] + self.W * self.beta)
+                    max(2. * self.n_bz[k] + self.W * self.beta, self.epsilon)
                 pw2k = (self.n_wz[k][w2] + self.beta) / \
-                    (2. * self.n_bz[k] + 1. + self.W * self.beta)
+                    max(2. * self.n_bz[k] + 1. + self.W * self.beta, self.epsilon)
             pk = (self.n_bz[k] + self.alpha) / \
-                (self.B.shape[0] + self.T * self.alpha)
+                max(self.B.shape[0] + self.T * self.alpha, self.epsilon)
             p_z[k] = pk * pw1k * pw2k
 
         # return p_z # self._normalize(p_z)
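
The `max(denominator, self.epsilon)` guard is a no-op whenever the counts are healthy, since any nonzero beta/alpha term already keeps the denominator well above 1e-10; it only bites in degenerate states. A quick plain-Python illustration with hypothetical values:

epsilon = 1e-10
n_bz_k, W, beta = 12.0, 1000, 0.01
healthy = max(2. * n_bz_k + W * beta, epsilon)  # 34.0, guard unused
degenerate = max(2. * 0.0 + 0 * 0.0, epsilon)   # 1e-10 instead of 0.0
print(0.0 / degenerate)                         # 0.0, no ZeroDivisionError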
@@ -213,8 +225,19 @@
         for i in range(num):
             p_sum += p[i]
 
+        # Handle edge cases where sum is zero or very small
+        # Uniform distribution if all probabilities are zero/tiny
+        if p_sum <= self.epsilon:
+            for i in range(num):
+                p[i] = 1.0 / num
+            return
+
+        cdef double denominator = p_sum + num * smoother
+        if denominator <= self.epsilon:
+            denominator = self.epsilon
+
         for i in range(num):
-            p[i] = (p[i] + smoother) / (p_sum + num * smoother)
+            p[i] = (p[i] + smoother) / denominator
 
     @initializedcheck(False)
     @boundscheck(False)
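
A pure-Python rendering of the new `_normalize` logic, to make the fallback behavior explicit (a sketch mirroring the Cython above, not a drop-in replacement):

def normalize(p, smoother=0.0, epsilon=1e-10):
    """Smooth and normalize p in place, with a uniform fallback."""
    p_sum = sum(p)
    if p_sum <= epsilon:            # zero/tiny total mass
        for i in range(len(p)):
            p[i] = 1.0 / len(p)     # fall back to a uniform distribution
        return
    denominator = max(p_sum + len(p) * smoother, epsilon)
    for i in range(len(p)):
        p[i] = (p[i] + smoother) / denominator

p = [0.0, 0.0, 0.0]
normalize(p)
print(p)  # [0.333..., 0.333..., 0.333...] rather than a division by zero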
@@ -231,6 +254,22 @@
         verbose : bool = True
             Show progress bar.
         """
+        # Validate that we have biterms to work with
+        if not Bs:
+            raise ValueError("Cannot fit model: no biterms available. "
+                             "Check that documents have sufficient vocabulary overlap and length.")
+
+        # Check if all biterm lists are empty
+        cdef bint has_biterms = False
+        for doc_biterms in Bs:
+            if len(doc_biterms) > 0:
+                has_biterms = True
+                break
+
+        if not has_biterms:
+            raise ValueError("Cannot fit model: no biterms available. "
+                             "Check that documents have sufficient vocabulary overlap and length.")
+
         self.B = self._biterms_to_array(Bs)
         # rng = np.random.default_rng(self.seed if self.seed else time(NULL))
         # random_factors = rng.random(
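
The two checks guard distinct failure modes: `if not Bs` catches an empty outer list, while the loop catches the subtler case of a non-empty list whose per-document biterm lists are all empty (e.g. a corpus of one-word documents). In plain Python the second condition reduces to:

Bs = [[], [], []]      # documents too short to form any biterm
has_biterms = any(len(doc_biterms) > 0 for doc_biterms in Bs)
assert not has_biterms  # BTM.fit would raise ValueError here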
@@ -247,7 +286,6 @@
             shape=(B_len, ), itemsize=sizeof(double), format="d",
             allocate_buffer=True)
 
-        rng = np.random.default_rng(self.seed if self.seed else time(NULL))
         trange = tqdm.trange if verbose else range
 
         for i in range(B_len):
@@ -259,7 +297,7 @@
                 self.n_wz[topic][w2] += 1
 
         for j in trange(iterations):
-            rnd_uniform = rng.uniform(0, 1, B_len)
+            rnd_uniform = self.rng.uniform(0, 1, B_len)
             for i in range(B_len):
                 w1 = self.B[i, 0]
                 w2 = self.B[i, 1]
@@ -616,3 +654,8 @@
     def labels_(self) -> np.ndarray:
         """Model document labels (most probable topic for each document)."""
         return np.asarray(self.p_zd).argmax(axis=1)
+
+    @property
+    def epsilon_(self) -> float:
+        """Numerical stability constant (epsilon) used to prevent division by zero."""
+        return self.epsilon
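
Reading the new read-only property back after construction (a sketch; the top-level `btm.BTM` export is an assumption, while the constructor signature matches the diff above):

import scipy.sparse as sp
import numpy as np
import bitermplus as btm  # assumed top-level export of BTM

n_dw = sp.csr_matrix([[1, 1, 0], [0, 1, 1]])       # toy document-term counts
model = btm.BTM(n_dw, np.array(["a", "b", "c"]), T=2)
print(model.epsilon_)                              # 1e-10, the default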

src/bitermplus/_util.py

Lines changed: 12 additions & 1 deletion
@@ -90,7 +90,11 @@ def _parse_words(w):
 
     result = []
     for doc in docs:
-        word_ids = [vocab_idx[word] for word in doc.split() if word in vocab_idx]
+        # Handle potential None/empty doc and filter out empty strings
+        if doc is None:
+            doc = ""
+        words = [word.strip() for word in doc.split() if word.strip()]
+        word_ids = [vocab_idx[word] for word in words if word in vocab_idx]
         result.append(np.array(word_ids, dtype=np.int32))
     return result
 
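
Worth noting: `str.split()` with no separator already discards whitespace runs, so the extra `strip()` filter is defensive rather than strictly required; the substantive fix is the `None` guard. A self-contained replay of the new behavior:

import numpy as np

vocab_idx = {"cats": 0, "dogs": 1}
result = []
for doc in [None, "  cats   dogs  ", "", "birds"]:
    if doc is None:
        doc = ""                    # None no longer crashes .split()
    words = [w.strip() for w in doc.split() if w.strip()]
    result.append(np.array([vocab_idx[w] for w in words if w in vocab_idx],
                           dtype=np.int32))
# result: [array([]), array([0, 1]), array([]), array([])] as int32 word ids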

@@ -139,6 +143,13 @@ def get_biterms(
                 wj = max(doc[i], doc[j])
                 doc_biterms.append([wi, wj])
         biterms.append(doc_biterms)
+
+    # Check if we have any biterms at all
+    total_biterms = sum(len(doc_biterms) for doc_biterms in biterms)
+    if total_biterms == 0:
+        raise ValueError("No biterms could be generated from the documents. "
+                         "Documents may be too short or have insufficient vocabulary overlap.")
+
     return biterms
 
 
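
With this check, a corpus of one-word documents now fails fast in preprocessing instead of producing an empty biterm set that only surfaces later inside BTM.fit. A sketch of the failure path (the list-of-word-ids input follows this module's conventions; the exact exported signature is assumed):

import bitermplus as btm

docs_ids = [[0], [1], [2]]   # one word per document: no pairs possible
try:
    btm.get_biterms(docs_ids)
except ValueError as err:
    print(err)  # "No biterms could be generated from the documents. ..."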
