|
15 | 15 | from vllm.utils import LazyDict |
16 | 16 |
|
17 | 17 |
|
| 18 | +@CustomOp.register("xielu") |
| 19 | +class XIELU(CustomOp): |
| 20 | + """ |
| 21 | + Applies the xIELU activation function |
| 22 | +
|
| 23 | + Shapes: |
| 24 | + x: (num_tokens, d) or (batch_size, seq_len, d) |
| 25 | + return: (num_tokens, d) or (batch_size, seq_len, d) |
| 26 | + """ |
| 27 | + |
| 28 | + def __init__(self, alpha_p_init=0.8, alpha_n_init=0.8, beta=0.5, eps=-1e-6): |
| 29 | + super().__init__() |
| 30 | + self.alpha_p = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_p_init)) - 1.0).unsqueeze(0)) |
| 31 | + self.alpha_n = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_n_init - beta)) - 1.0).unsqueeze(0)) |
| 32 | + self.beta = beta |
| 33 | + self.eps = torch.tensor(eps, dtype=torch.bfloat16, device='cuda') |
| 34 | + |
| 35 | + if current_platform.is_cuda_alike(): |
| 36 | + # TODO CUDA implementation under development, using forward_native for now |
| 37 | + self._forward_method = self.forward_native |
| 38 | + elif current_platform.is_cpu(): |
| 39 | + self._forward_method = self.forward_native |
| 40 | + |
| 41 | + def forward_native(self, x: torch.Tensor) -> torch.Tensor: |
| 42 | + # TODO optimize to precompute |
| 43 | + alpha_p = F.softplus(self.alpha_p) |
| 44 | + alpha_n = self.beta + F.softplus(self.alpha_n) |
| 45 | + return torch.where( |
| 46 | + x > 0, |
| 47 | + alpha_p * x * x + self.beta * x, |
| 48 | + alpha_n * torch.expm1(torch.min(x, self.eps)) - alpha_n * x + self.beta * x |
| 49 | + ) |
| 50 | + |
| 51 | + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: |
| 52 | + return |
| 53 | + |
| 54 | + |
18 | 55 | @CustomOp.register("fatrelu_and_mul") |
19 | 56 | class FatreluAndMul(CustomOp): |
20 | 57 | """An activation function for FATReLU. |
|
0 commit comments