
Commit 8d190e4

Laurentdeltheil authored and committed

(fluxion/layers/activations) replace ApproximateGeLU by GeLUApproximation

1 parent: 2bdb42e

3 files changed (+18 lines, −18 lines)

src/refiners/fluxion/layers/__init__.py (2 additions, 2 deletions)

@@ -1,8 +1,8 @@
 from refiners.fluxion.layers.activations import (
     GLU,
     Activation,
-    ApproximateGeLU,
     GeLU,
+    GeLUApproximation,
     ReLU,
     Sigmoid,
     SiLU,
@@ -64,10 +64,10 @@
     "InstanceNorm2d",
     "Activation",
     "GeLU",
+    "GeLUApproximation",
     "GLU",
     "SiLU",
     "ReLU",
-    "ApproximateGeLU",
     "Sigmoid",
     "Attention",
     "ScaledDotProductAttention",

src/refiners/fluxion/layers/activations.py (12 additions, 15 deletions)

@@ -97,24 +97,21 @@ class GeLU(Activation):
     ```
     """
 
-    def __init__(self) -> None:
-        super().__init__()
-
-    def forward(self, x: Tensor) -> Tensor:
-        return gelu(x)  # type: ignore
-
-
-class ApproximateGeLU(Activation):
-    """
-    The approximate form of Gaussian Error Linear Unit (GELU)
-    For more details, see section 2: https://arxiv.org/abs/1606.08415
-    """
-
-    def __init__(self) -> None:
+    def __init__(
+        self,
+        approximation: GeLUApproximation = GeLUApproximation.NONE,
+    ) -> None:
         super().__init__()
+        self.approximation = approximation
 
     def forward(self, x: Tensor) -> Tensor:
-        return x * sigmoid(1.702 * x)
+        match self.approximation:
+            case GeLUApproximation.NONE:
+                return gelu(x, approximate="none")
+            case GeLUApproximation.TANH:
+                return gelu(x, approximate="tanh")
+            case GeLUApproximation.SIGMOID:
+                return x * sigmoid(1.702 * x)
 
 
 class Sigmoid(Activation):
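
The rewritten forward dispatches on a GeLUApproximation enum whose NONE, TANH, and SIGMOID members are implied by this diff but defined elsewhere in the module. To show how the three branches relate numerically, here is a standalone sketch in plain PyTorch rather than the refiners API (torch.nn.functional.gelu accepts approximate="none" or "tanh" since PyTorch 1.12):

import torch

x = torch.linspace(-3.0, 3.0, steps=13)

exact = torch.nn.functional.gelu(x, approximate="none")        # NONE branch
tanh_approx = torch.nn.functional.gelu(x, approximate="tanh")  # TANH branch
quick = x * torch.sigmoid(1.702 * x)                           # SIGMOID branch

# On this range the tanh form stays within about 1e-3 of the exact GELU,
# the sigmoid ("quick GELU") form within about 2e-2.
print((tanh_approx - exact).abs().max().item())
print((quick - exact).abs().max().item())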

src/refiners/foundationals/clip/text_encoder.py (4 additions, 1 deletion)

@@ -146,7 +146,10 @@ def __init__(
         )
         if use_quick_gelu:
             for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
-                parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())
+                parent.replace(
+                    old_module=gelu,
+                    new_module=fl.GeLU(approximation=fl.GeLUApproximation.SIGMOID),
+                )
 
 
 class CLIPTextEncoderL(CLIPTextEncoder):
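
For reference, x * sigmoid(1.702 * x) is the "quick GELU" used by the original OpenAI CLIP weights, which is why the encoder swaps it in when use_quick_gelu is set. A hedged sketch of the resulting behavior, assuming the constructor's remaining parameters keep their repository defaults:

import refiners.fluxion.layers as fl
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder

# With use_quick_gelu=True, every GeLU in the module tree should carry
# the sigmoid approximation after the replacement walk above.
encoder = CLIPTextEncoder(use_quick_gelu=True)
for gelu_layer, _ in encoder.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
    assert gelu_layer.approximation is fl.GeLUApproximation.SIGMOID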
