Skip to content

Commit c2424f6

Browse files
committed
Fix exp function
1 parent fadb547 commit c2424f6

3 files changed

Lines changed: 5 additions & 5 deletions

File tree

legateboost/metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,8 +343,8 @@ def metric(self, y: cn.ndarray, pred: cn.ndarray, w: cn.ndarray) -> cn.ndarray:
343343
f = cn.log(pred) * (K - self.one) # undo softmax
344344
y_k = cn.full((y.size, K), -self.one / (K - self.one))
345345

346-
set_col_by_idx(y_k, y.astype(cn.int32), self.one)
347-
# y_k[cn.arange(y.size), y.astype(cn.int32)] = 1.0
346+
# y_k[cn.arange(y.size), y.astype(cn.int32)] = self.one
347+
y_k = set_col_by_idx(y_k, y.astype(cn.int32), self.one)
348348

349349
exp = cn.exp(-1 / K * cn.sum(y_k * f, axis=1))
350350
return (exp * w).sum() / w.sum()

legateboost/objectives.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ def gradient(self, y: cn.ndarray, pred: cn.ndarray) -> GradPair:
598598
f = cn.log(pred) * (K - 1) # undo softmax
599599
y_k = cn.full((y.size, K), -self.one / (K - self.one))
600600
labels = y.astype(cn.int32).squeeze()
601-
set_col_by_idx(y_k, labels, self.one)
601+
y_k = set_col_by_idx(y_k, labels, self.one)
602602
# y_k[cn.arange(y.size), labels] = 1.0
603603
exp = cn.exp(-1 / K * cn.sum(y_k * f, axis=1))
604604

legateboost/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def pick_col_by_idx(a: cn.ndarray, b: cn.ndarray) -> cn.ndarray:
105105
return result.sum(axis=1)
106106

107107

108-
def set_col_by_idx(a: cn.ndarray, b: cn.ndarray, delta: float) -> None:
108+
def set_col_by_idx(a: cn.ndarray, b: cn.ndarray, delta: float) -> cn.ndarray:
109109
"""Alternative implementation for a[cn.arange(b.size), b] = delta."""
110110

111111
assert a.ndim == 2
@@ -116,7 +116,7 @@ def set_col_by_idx(a: cn.ndarray, b: cn.ndarray, delta: float) -> None:
116116
bools = b[:, cn.newaxis] == range[cn.newaxis, :]
117117
a -= a * bools
118118
a += delta * bools
119-
return
119+
return a
120120

121121

122122
def mod_col_by_idx(a: cn.ndarray, b: cn.ndarray, delta: float) -> None:

0 commit comments

Comments
 (0)