Skip to content

Commit 9d9fba3

Browse files
Fix division by zero issues in divergence calculations
1 parent 2fb9f93 commit 9d9fba3

1 file changed

Lines changed: 7 additions & 3 deletions

File tree

src/sdialog/evaluation/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ def _cs_divergence(p1, p2, resolution=100, bw_method=1):
109109
p2_vals = p2_kernel(r)
110110
numerator = np.sum(p1_vals * p2_vals)
111111
denominator = sqrt(np.sum(p1_vals ** 2) * np.sum(p2_vals ** 2))
112-
# Avoid log(0) by ensuring numerator has minimum value
113-
return -log(max(numerator, 1e-12) / denominator)
112+
# Avoid log(0) and division by zero by ensuring minimum values
113+
return -log(max(numerator, 1e-12) / max(denominator, 1e-12))
114114

115115

116116
def _kl_divergence(p1, p2, resolution=100, bw_method=1e-1):
@@ -157,7 +157,11 @@ def _kl_divergence(p1, p2, resolution=100, bw_method=1e-1):
157157
p1_vals = np.clip(p1_vals, eps, None)
158158
p2_vals = np.clip(p2_vals, eps, None)
159159

160-
return float(np.sum(p1_vals * np.log(p1_vals / p2_vals)) / np.sum(p1_vals))
160+
sum_p1_vals = np.sum(p1_vals)
161+
# Protect against division by zero
162+
if sum_p1_vals == 0:
163+
return 0.0
164+
return float(np.sum(p1_vals * np.log(p1_vals / p2_vals)) / sum_p1_vals)
161165

162166

163167
class ConversationalFeatures(BaseDialogScore):

0 commit comments

Comments
 (0)