
Commit 54043b2

attn_mask for sliding_window_decode
1 parent 8adfbdc commit 54043b2

File tree

8 files changed: +25 −16 lines changed

src/optimum/rbln/ops/sliding_window_attn.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -13,9 +13,10 @@
 # limitations under the License.


+from typing import Optional
+
 import torch
 from torch import Tensor
-from typing import Optional


 @torch.library.custom_op(
@@ -95,6 +96,7 @@ def paged_sliding_window_attn_decode(
     scale: Tensor,
     block_table: Tensor,
     block_size: int,
+    attn_mask: Tensor,
     s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
@@ -112,6 +114,7 @@ def paged_sliding_window_attn_decode_fake(
     scale: Tensor,
     block_table: Tensor,
     block_size: int,
+    attn_mask: Tensor,
     s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
```
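For context on the file's structure: the decode op is registered through `torch.library.custom_op`, and the `..._fake` variant with the identical signature supplies the shape-only (meta) implementation used during tracing, which is why both must gain the new `attn_mask` parameter. Below is a minimal sketch of that registration pattern; the op name `mylib::toy_sliding_window_decode` and the reduced argument list are hypothetical (the real op also takes the paged-cache arguments such as `block_table` and `block_size` shown above), and the exact fake-registration call in the repository may differ.

```python
from typing import Optional

import torch
from torch import Tensor


# Hypothetical, reduced version of the registration pattern: the real kernel
# runs on the RBLN device, so the Python bodies only describe output shapes.
@torch.library.custom_op("mylib::toy_sliding_window_decode", mutates_args=())
def toy_sliding_window_decode(
    q: Tensor,
    scale: float,
    attn_mask: Tensor,  # new in this commit: per-slot validity mask for decode
    s_aux: Optional[Tensor] = None,
) -> Tensor:
    # Placeholder eager body; output shape matches the query.
    return torch.empty_like(q)


@toy_sliding_window_decode.register_fake
def _(q, scale, attn_mask, s_aux=None) -> Tensor:
    # Fake (meta) implementation: shape/dtype only, no computation.
    return torch.empty_like(q)


# Calling the op dispatches to the registered implementation; under tracing,
# the fake implementation lets the graph be built without a device kernel.
q = torch.randn(1, 8, 1, 64)
mask = torch.ones(1, 1, 128, 1)
out = toy_sliding_window_decode(q, 0.125, mask)
```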

src/optimum/rbln/transformers/models/colpali/modeling_colpali.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -14,8 +14,7 @@

 import bisect
 from pathlib import Path
-from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union

 import torch
 from transformers import PretrainedConfig, PreTrainedModel
```

src/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py

Lines changed: 9 additions & 2 deletions

```diff
@@ -372,7 +372,11 @@ def get_local_cache_positions(self, position_ids, query_position):
             torch.clamp(position_ids, max=max_cache_len)[:, :1] + valid_input_len
         ) # cache offset for next steps

-        return cache_seq_len, cache_offset
+        # Causal mask for sliding window attention
+        attn_mask = torch.arange(max_cache_len)[None, :] - cache_seq_len
+        attn_mask = torch.where(attn_mask > 0, 0.0, 1.0)[:, None, :, None]
+
+        return cache_seq_len, cache_offset, attn_mask

     def get_last_layernorm(self) -> nn.LayerNorm:
         return self._original_mod.norm
@@ -458,7 +462,7 @@ def forward(

         # Get local cache positions for sliding window layers
         if len(self.sliding_window_layers) > 0:
-            sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
+            sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position, hidden_states)

         for layer_idx, layer in enumerate(self.layers):
             is_sliding = True if layer_idx in self.sliding_window_layers else False
@@ -1128,6 +1132,9 @@ def forward(
         if self.phase == "prefill" or self.phase == "image_prefill":
             op_args["is_bidirectional"] = self.phase == "image_prefill" # FIXME, Hard-coded for Gemma3.

+        if self.phase == "decode":
+            op_args["attn_mask"] = attn_mask
+
         if s_aux is not None:
             op_args["s_aux"] = s_aux

```
src/optimum/rbln/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 1 addition & 3 deletions

```diff
@@ -97,11 +97,9 @@ def redirect(func):

     def can_generate(self):
         return True
-
-
+
     @classmethod
     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
-
         with no_init_weights():
             model_cls_name = model.model.language_model.__class__.__name__
             causal_model_cls_name = model_cls_name.replace("TextModel", "ForCausalLM")
```

src/optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -24,7 +24,6 @@
     DecoderOnlyAttention,
     DecoderOnlyLayer,
     DecoderOnlyWrapper,
-    DecoderOnlyAttention,
 )


@@ -120,7 +119,9 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
         hidden_states = hidden_states.repeat(num_experts, 1)
         hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)

-        gate_up = torch.bmm(hidden_states, self.gate_up_proj.to(hidden_states.dtype)) + self.gate_up_proj_bias[..., None, :].to(hidden_states.dtype)
+        gate_up = torch.bmm(hidden_states, self.gate_up_proj.to(hidden_states.dtype)) + self.gate_up_proj_bias[
+            ..., None, :
+        ].to(hidden_states.dtype)
         gate, up = gate_up[..., ::2], gate_up[..., 1::2]
         gate = gate.clamp(min=None, max=self.limit)
         up = up.clamp(min=-self.limit, max=self.limit)
```
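The re-wrapped expression above is behavior-preserving. For context, the fused `gate_up_proj` stores gate and up channels interleaved along the last dimension, which is why the code splits them with `[..., ::2]` and `[..., 1::2]` before clamping. A minimal sketch with made-up shapes (the sizes and standalone tensors below are illustrative, not the module's real parameters):

```python
import torch

# Hypothetical sizes, just to show the interleaved gate/up layout used above.
num_experts, tokens, hidden, inter = 2, 3, 4, 5
limit = 7.0

hidden_states = torch.randn(num_experts, tokens, hidden)
gate_up_proj = torch.randn(num_experts, hidden, 2 * inter)  # gate/up fused, interleaved
gate_up_proj_bias = torch.randn(num_experts, 2 * inter)

# Batched per-expert projection; the bias is broadcast over the token dimension.
gate_up = torch.bmm(hidden_states, gate_up_proj) + gate_up_proj_bias[..., None, :]

# Even channels are the gate, odd channels the up projection.
gate, up = gate_up[..., ::2], gate_up[..., 1::2]
gate = gate.clamp(min=None, max=limit)    # cap gate activations from above only
up = up.clamp(min=-limit, max=limit)      # symmetric clamp for the up projection
print(gate.shape, up.shape)               # torch.Size([2, 3, 5]) torch.Size([2, 3, 5])
```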

src/optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -12,16 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional, Union, TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional, Union

 from transformers import PretrainedConfig

 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyModelForCausalLMConfig
 from .gpt_oss_architecture import RBLNGptOssWrapper

+
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
-    from transformers import PreTrainedModel
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel


 class RBLNGptOssForCausalLM(RBLNDecoderOnlyModelForCausalLM):
```

src/optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -204,8 +204,7 @@ def save_torch_artifacts(
         save_dict["bbox_embed"] = model.bbox_embed.state_dict()

         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
-
-
+
     @classmethod
     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
         model.encoder = model.model.encoder
```

src/optimum/rbln/transformers/models/siglip/modeling_siglip.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -66,7 +66,9 @@ class RBLNSiglipVisionModel(RBLNModel):
     _tp_support = False

     @classmethod
-    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNSiglipVisionModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(
+        cls, model: torch.nn.Module, rbln_config: RBLNSiglipVisionModelConfig
+    ) -> torch.nn.Module:
         wrapper_cfg = {
             "interpolate_pos_encoding": rbln_config.interpolate_pos_encoding,
             "output_hidden_states": rbln_config.output_hidden_states,
```
