
Commit 836f92e

topk casting

1 parent 6ef2004

File tree: 5 files changed (+17, -10 lines)

examples/text2text-generation/run_gpt_oss.py

Lines changed: 1 addition & 1 deletion

@@ -66,7 +66,7 @@ def main(
             rbln_tensor_parallel_size=tensor_parallel_size,
             rbln_kvcache_partition_len=kvcache_partition_len,
             config=target_config,
-            dtype=torch.float32,
+            # dtype=torch.float32,
         )
         model.save_pretrained(os.path.basename(model_id))
     else:

src/optimum/rbln/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 1 addition & 3 deletions

@@ -97,11 +97,9 @@ def redirect(func):
 
     def can_generate(self):
         return True
-
-
+
     @classmethod
     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
-
         with no_init_weights():
             model_cls_name = model.model.language_model.__class__.__name__
             causal_model_cls_name = model_cls_name.replace("TextModel", "ForCausalLM")

src/optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py

Lines changed: 11 additions & 3 deletions

@@ -68,11 +68,17 @@ def __init__(self, model):
         self.weight = model.weight
         self.bias = model.bias
 
+    def casted_top_K(self, router_logits, hidden_states):
+        logits = router_logits.to(torch.float32)
+        router_top_value, router_indices = torch.topk(logits, self.top_k, dim=-1)
+
+        return router_top_value.to(hidden_states.dtype), router_indices
+
     def forward(self, hidden_states):
         hidden_states = hidden_states.reshape(-1, self.hidden_dim)
         router_logits = F.linear(hidden_states, self.weight, self.bias)  # (seq_len, num_experts)
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
-        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
+        router_top_value, router_indices = self.casted_top_K(router_logits, hidden_states)
+        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=hidden_states.dtype)
         router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
 
         return router_scores, router_indices
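For readers skimming the diff, here is a minimal, self-contained sketch of the routing pattern the new casted_top_K helper implements. The sequence length, expert count, top_k value, and bfloat16 activations below are illustrative assumptions, not values taken from the model: the point is to rank the router logits in float32 so near-tied experts are selected at full precision, then cast the selected scores back to the activation dtype before the softmax and scatter.

import torch
import torch.nn.functional as F

# Illustrative sizes (assumptions, not the model's actual configuration).
seq_len, num_experts, top_k = 4, 8, 2
hidden_states = torch.randn(seq_len, 16, dtype=torch.bfloat16)            # stand-in activations
router_logits = torch.randn(seq_len, num_experts, dtype=torch.bfloat16)   # stand-in router output

# Rank in float32 so near-tied logits are ordered at full precision (the point of casted_top_K).
router_top_value, router_indices = torch.topk(router_logits.to(torch.float32), top_k, dim=-1)

# Cast back to the activation dtype, normalize, and scatter into a dense (seq_len, num_experts) score map.
router_top_value = router_top_value.to(hidden_states.dtype)
router_top_value = F.softmax(router_top_value, dim=1, dtype=hidden_states.dtype)
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
print(router_scores.shape, router_scores.dtype)  # torch.Size([4, 8]) torch.bfloat16

Compared with the removed lines, the only behavioural change appears to be the float32 ranking; the softmax still produces scores in the activation dtype.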
@@ -120,7 +126,9 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
         hidden_states = hidden_states.repeat(num_experts, 1)
         hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
 
-        gate_up = torch.bmm(hidden_states, self.gate_up_proj.to(hidden_states.dtype)) + self.gate_up_proj_bias[..., None, :].to(hidden_states.dtype)
+        gate_up = torch.bmm(hidden_states, self.gate_up_proj.to(hidden_states.dtype)) + self.gate_up_proj_bias[
+            ..., None, :
+        ].to(hidden_states.dtype)
         gate, up = gate_up[..., ::2], gate_up[..., 1::2]
         gate = gate.clamp(min=None, max=self.limit)
         up = up.clamp(min=-self.limit, max=self.limit)
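The second hunk only re-wraps the long gate_up line; because the broadcast in self.gate_up_proj_bias[..., None, :] is easy to misread across the wrap, here is a small runnable sketch with assumed toy shapes showing what the expression computes (a per-expert batched matmul plus a per-expert bias broadcast over the token dimension).

import torch

# Assumed toy shapes: E experts, T tokens, hidden size H, intermediate size I.
E, T, H, I = 4, 3, 8, 16
hidden_states = torch.randn(E, T, H, dtype=torch.bfloat16)
gate_up_proj = torch.randn(E, H, 2 * I)        # per-expert weight, stored in a wider dtype
gate_up_proj_bias = torch.randn(E, 2 * I)      # per-expert bias

# (E, T, H) @ (E, H, 2I) -> (E, T, 2I); bias[..., None, :] has shape (E, 1, 2I) and broadcasts over T.
gate_up = torch.bmm(hidden_states, gate_up_proj.to(hidden_states.dtype)) + gate_up_proj_bias[
    ..., None, :
].to(hidden_states.dtype)
gate, up = gate_up[..., ::2], gate_up[..., 1::2]  # interleaved gate/up split, as in the surrounding code
print(gate_up.shape, gate.shape, up.shape)
# torch.Size([4, 3, 32]) torch.Size([4, 3, 16]) torch.Size([4, 3, 16])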

src/optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py

Lines changed: 1 addition & 2 deletions

@@ -204,8 +204,7 @@ def save_torch_artifacts(
         save_dict["bbox_embed"] = model.bbox_embed.state_dict()
 
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
-
-
+
     @classmethod
     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
         model.encoder = model.model.encoder

src/optimum/rbln/transformers/models/siglip/modeling_siglip.py

Lines changed: 3 additions & 1 deletion

@@ -66,7 +66,9 @@ class RBLNSiglipVisionModel(RBLNModel):
     _tp_support = False
 
     @classmethod
-    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNSiglipVisionModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(
+        cls, model: torch.nn.Module, rbln_config: RBLNSiglipVisionModelConfig
+    ) -> torch.nn.Module:
         wrapper_cfg = {
             "interpolate_pos_encoding": rbln_config.interpolate_pos_encoding,
             "output_hidden_states": rbln_config.output_hidden_states,
