Skip to content

Commit 0816102

Browse files
authored
Delete Extra Softplus (Iainmon#63)
This looks good.
2 parents 9f1585f + c52ef00 commit 0816102

2 files changed

Lines changed: 14 additions & 14 deletions

File tree

lib/NDArray.chpl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -986,19 +986,19 @@ inline proc ndarray.threshold(threshold: eltType, value: eltType) { // PyTorch h
986986
return rl;
987987
}
988988

989-
inline proc ndarray.softplus(beta: eltType=1.0, threshold: eltType=20.0) {
990-
const ref thisData = data;
991-
const dom = this.domain;
992-
var rl = new ndarray(dom, eltType);
993-
ref rld = rl.data;
994-
forall i in dom.every() {
995-
const x = thisData[i];
996-
const floatMax: eltType = Types.max(eltType);
997-
const xgbt: eltType = Math.ceil((x - threshold / beta) / floatMax); // x greater than beta * threshold: 1 if true, 0 otherwise
998-
rld[i] = x * xgbt + 1.0 / beta * Math.log(1 + Math.exp(beta * x)) * (1 - xgbt);
999-
}
1000-
return rl;
1001-
}
989+
// inline proc ndarray.softplus(beta: eltType=1.0, threshold: eltType=20.0) {
990+
// const ref thisData = data;
991+
// const dom = this.domain;
992+
// var rl = new ndarray(dom, eltType);
993+
// ref rld = rl.data;
994+
// forall i in dom.every() {
995+
// const x = thisData[i];
996+
// const floatMax: eltType = Types.max(eltType);
997+
// const xgbt: eltType = Math.ceil((x - threshold / beta) / floatMax); // x greater than beta * threshold: 1 if true, 0 otherwise
998+
// rld[i] = x * xgbt + 1.0 / beta * Math.log(1 + Math.exp(beta * x)) * (1 - xgbt);
999+
// }
1000+
// return rl;
1001+
// }
10021002

10031003
inline proc ndarray.celu(alpha: eltType=1.0) {
10041004
const ref thisData = data;

test/correspondence/activation/softplus/softplus.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def test(imports):
99
b = torch.nn.Softplus()(torch.zeros(2,3,4) - 60.0)
1010
print(b)
1111

12-
c = (torch.nn.Softplus()(torch.zeros(10).to(torch.float32) + 40.0)).to(torch.float32)
12+
c = torch.nn.Softplus()(torch.zeros(10) + 40.0)
1313
print(c)
1414

1515
# same values with alpha = -0.001

0 commit comments

Comments
 (0)