Skip to content

Commit 63a2129

Browse files
committed
fixed errors in saliency calculation method
1 parent 50ae7b1 commit 63a2129

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/tmplot/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55
from ._vis import * # noqa: F401, F403
66
from ._metrics import * # noqa: F401, F403
77

8-
__version__ = '0.1.3'
8+
__version__ = '0.2.0'

src/tmplot/_helpers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -319,10 +319,10 @@ def get_salient_terms(phi: ndarray, theta: ndarray) -> ndarray:
319319
p_w = calc_terms_marg_probs(phi, p_t)
320320

321321
def _p_tw(phi, w, t):
322-
return phi[w, t] * p_t[t] / p_w[w]
322+
return array(phi)[w, t] * p_t[t] / p_w[w]
323323

324324
saliency = array(
325-
(
325+
[
326326
p_w[w]
327327
* sum(
328328
(
@@ -331,7 +331,7 @@ def _p_tw(phi, w, t):
331331
)
332332
)
333333
for w in range(phi.shape[0])
334-
)
334+
]
335335
)
336336
# saliency(term w) = frequency(w)
337337
# * [sum_t p(t | w) * log(p(t | w)/p(t))] for topics t

0 commit comments

Comments
 (0)