
Commit 7c585ff

fix: torch.nn.functional.gumbel_softmax frontend
1 parent 6ab5aa1 commit 7c585ff

File tree

2 files changed: +15 -18 lines


ivy/functional/frontends/torch/nn/functional/non_linear_activation_functions.py

Lines changed: 13 additions & 15 deletions
@@ -67,22 +67,20 @@ def glu(input, dim=-1):
 
 @to_ivy_arrays_and_back
 @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch")
-def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
-    gumbels = -ivy.empty_like(logits).exponential().log()
-    gumbels = (logits + gumbels) / tau
-    y_soft = ivy.softmax(gumbels, axis=dim)
-
+def gumbel_softmax(logits, tau=1., hard=False, eps=1e-10, dim=-1):
+    if logits.ndim == 0:
+        return ivy.ones_like(logits)
+    gumbel_noise = -ivy.log(
+        -ivy.log(ivy.random_uniform(low=0, high=1, shape=logits.shape) + eps) + eps
+    )
+    y = (logits + gumbel_noise) / tau
+    y_soft = ivy.softmax(y, axis=dim)
     if hard:
-        indices = y_soft.max(axis=dim, keepdims=True)[1]
-        y_hard = ivy.zeros_like(logits)
-        updates = ivy.ones_like(indices)
-        y_hard = ivy.scatter_nd(indices, updates, reduction="replace", out=y_hard)
-
-        ret = y_hard - y_soft.stop_gradient(preserve_type=True) + y_soft
-    else:
-        ret = y_soft
-
-    return ret
+        index = ivy.argmax(y_soft, axis=dim)
+        y_hard = ivy.one_hot(index, logits.shape[dim], axis=dim).astype(y_soft.dtype)
+        ret = y_hard - ivy.stop_gradient(y_soft) + y_soft
+        return ret.astype(logits.dtype)
+    return y_soft.astype(logits.dtype)
 
 
 @to_ivy_arrays_and_back
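For context: the updated frontend samples Gumbel(0, 1) noise from uniform draws via the inverse-CDF identity g = -log(-log(u)), scales the perturbed logits by 1/tau, and, when hard=True, returns a one-hot of the argmax while routing the gradient through the soft probabilities (the straight-through estimator). Below is a minimal NumPy sketch of the forward computation, mirroring the structure of the new implementation; the function name and the toy logits are illustrative only.

import numpy as np

def gumbel_softmax_forward(logits, tau=1.0, hard=False, eps=1e-10, axis=-1):
    # Gumbel(0, 1) noise from uniform samples: g = -log(-log(u + eps) + eps)
    u = np.random.uniform(0.0, 1.0, size=logits.shape)
    gumbel_noise = -np.log(-np.log(u + eps) + eps)
    # temperature-scaled softmax of the perturbed logits
    y = (logits + gumbel_noise) / tau
    y = y - y.max(axis=axis, keepdims=True)  # shift for numerical stability
    y_soft = np.exp(y) / np.exp(y).sum(axis=axis, keepdims=True)
    if hard:
        # forward value is a one-hot of the argmax; the autograd version adds
        # (y_hard - stop_gradient(y_soft)) + y_soft so gradients follow y_soft
        index = y_soft.argmax(axis=axis)
        return np.eye(logits.shape[axis])[index]
    return y_soft

# toy example (illustrative values)
print(gumbel_softmax_forward(np.array([0.5, 1.0, -2.0]), tau=0.5, hard=True))

In the actual frontend the same steps are expressed with ivy ops (ivy.random_uniform, ivy.softmax, ivy.argmax, ivy.one_hot, ivy.stop_gradient), so gradients flow through y_soft when hard=True.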

ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_non_linear_activation_functions.py

Lines changed: 2 additions & 3 deletions
@@ -307,12 +307,11 @@ def test_torch_glu(
     dtype_and_x=helpers.dtype_and_values(
         available_dtypes=helpers.get_dtypes("float"),
     ),
-    tau=st.floats(min_value=0),
+    tau=st.floats(min_value=1e-6, max_value=10.0),
     hard=st.booleans(),
     eps=st.floats(min_value=0, max_value=1),
-    dim=st.integers(),
+    dim=st.integers(min_value=-1, max_value=0),
     test_with_out=st.just(False),
-    test_inplace=st.booleans(),
 )
 def test_torch_gumbel_softmax(
     *,
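For reference, a quick check against the PyTorch API this frontend mirrors, torch.nn.functional.gumbel_softmax. The tau strategy is now bounded away from zero because the perturbed logits are divided by tau, so zero has to be excluded; the shapes and printed values below are illustrative only.

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)

# hard=True yields one-hot samples in the forward pass (straight-through gradients)
y_hard = F.gumbel_softmax(logits, tau=0.5, hard=True, dim=-1)
print(y_hard.sum(dim=-1))  # tensor([1., 1., 1., 1.])

# a tiny but strictly positive tau makes the soft sample nearly one-hot
y_soft = F.gumbel_softmax(logits, tau=1e-6, hard=False, dim=-1)
print(y_soft.max(dim=-1).values)  # values close to 1.0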

0 commit comments
