Skip to content

Commit cf95e02

Browse files
committed
Patch dtype mismatch when JIT tracing: cast expert projection weights and biases to the hidden states' dtype
1 parent d6ca980 commit cf95e02

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,13 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
120120
hidden_states = hidden_states.repeat(num_experts, 1)
121121
hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
122122

123-
gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
123+
gate_up = torch.bmm(hidden_states, self.gate_up_proj.to(hidden_states.dtype)) + self.gate_up_proj_bias[..., None, :].to(hidden_states.dtype)
124124
gate, up = gate_up[..., ::2], gate_up[..., 1::2]
125125
gate = gate.clamp(min=None, max=self.limit)
126126
up = up.clamp(min=-self.limit, max=self.limit)
127127
glu = gate * torch.sigmoid(gate * self.alpha)
128-
next_states = torch.bmm(((up + 1.0) * glu), self.down_proj)
129-
next_states = next_states + self.down_proj_bias[..., None, :]
128+
next_states = torch.bmm(((up + 1.0) * glu), self.down_proj.to(hidden_states.dtype))
129+
next_states = next_states + self.down_proj_bias[..., None, :].to(hidden_states.dtype)
130130
next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
131131
next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
132132
next_states = next_states.sum(dim=0)

0 commit comments

Comments
 (0)