Skip to content

Commit ee28c4f

Browse files
committed
Update
1 parent 09567ec commit ee28c4f

File tree

3 files changed

+12
-2
lines changed

3 files changed

+12
-2
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ pip install git+https://github.com/r3v1/tf-SmeLU
2222

2323
````python
2424
import tensorflow as tf
25-
from smelu import smelu
25+
from tf_smelu import smelu
2626

2727
x = tf.range(-6, 6, 1, dtype=float) # <tf.Tensor: numpy=array([-6., -5., -4., -3., -2., -1., 0., 1., 2., 3., 4., 5.], dtype=float32)>
2828

examples/example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def main():
2525
axs[1].set_title("SmeLU gradients")
2626
axs[1].grid()
2727
plt.tight_layout()
28-
plt.savefig("example.jpg")
28+
# plt.savefig("example.jpg")
2929
plt.show()
3030

3131

tf_smelu/smelu.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,16 @@ def smelu(x: Union[list, tf.Tensor], beta: float = 1.):
1414
beta: float
1515
Half-width of a symmetric transition region around x = 0
1616
17+
Examples
18+
--------
19+
>>> import tensorflow as tf
20+
>>> from tf_smelu import smelu
21+
22+
>>> x = tf.range(-6, 6, 1, dtype=float)
23+
# <tf.Tensor: numpy=array([-6., -5., -4., -3., -2., -1., 0., 1., 2., 3., 4., 5.], dtype=float32)>
24+
>>> smelu(x, beta=0.1)
25+
# <tf.Tensor: numpy=array([0., 0., 0., 0., 0., 0., 0.025, 1., 2., 3., 4., 5.], dtype=float32)>
26+
1727
See Also
1828
--------
1929
- https://arxiv.org/pdf/2202.06499.pdf

0 commit comments

Comments
 (0)