Commit 2f0ddb6

reviews & refactor
1 parent 1018a3a commit 2f0ddb6

2 files changed: +45 -15 lines changed


src/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py

Lines changed: 30 additions & 12 deletions
@@ -21,7 +21,6 @@

 from ....utils import logging
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...utils.rbln_quantization import RBLNQuantizationConfig
 from .configuration_lora import RBLNLoRAConfig
 from .lora_architecture import LoRALinear

@@ -622,7 +621,6 @@ def __init__(
         self.head_dim = self._original_mod.head_dim
         self._phase = "prefill"
         self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale()))
-        self.quantization = rbln_config.quantization

         if hasattr(self._original_mod, "num_key_value_heads"):
             self.num_key_value_heads = self._original_mod.num_key_value_heads

@@ -689,7 +687,6 @@ def create_attention_op(self):
                 self.use_attention_mask,
                 self.num_key_value_heads,
                 self.kvcache_partition_len,
-                self.quantization,
                 rbln_config=self.rbln_config,
             )
         elif self.attn_impl == "eager":

@@ -698,7 +695,6 @@ def create_attention_op(self):
                 self.head_dim,
                 self.use_attention_mask,
                 self.num_key_value_heads,
-                self.quantization,
                 rbln_config=self.rbln_config,
             )
         else:

@@ -830,24 +826,27 @@ def __init__(
         head_dim: int,
         use_attention_mask: bool,
         num_key_value_heads: int,
-        quantization: Optional[RBLNQuantizationConfig] = None,
         rbln_config: Optional["RBLNDecoderOnlyModelConfig"] = None,
     ):
         super().__init__()
         self.num_heads = num_heads
         self.head_dim = head_dim
         self.num_key_value_heads = num_key_value_heads
         self.phase = "prefill"
-        self.quantization = quantization
         self.rbln_config = rbln_config
         self.use_attention_mask = use_attention_mask
         self.attn_mask_type = rbln_config.attn_mask_type
         self.use_position_ids = rbln_config.use_position_ids
+        self.quantization = rbln_config.quantization

     def get_attn_op_name(self):
         phase = "decode" if self.phase == "decode" else "prefill"
-        if self.use_attention_mask and not self.attn_mask_type == "2D":
-            attn_op_name = "paged_attn_"
+
+        if self.use_attention_mask:
+            if self.attn_mask_type == "2D":
+                attn_op_name = "paged_causal_attn_"
+            else:
+                attn_op_name = "paged_attn_"
         else:
             attn_op_name = "paged_causal_attn_"

@@ -964,23 +963,25 @@ def __init__(
         use_attention_mask: bool,
         num_key_value_heads: int,
         kvcache_partition_len: int,
-        quantization: Optional[RBLNQuantizationConfig] = None,
         rbln_config: Optional["RBLNDecoderOnlyModelConfig"] = None,
     ):
         super().__init__(
             num_heads=num_heads,
             head_dim=head_dim,
             use_attention_mask=use_attention_mask,
             num_key_value_heads=num_key_value_heads,
-            quantization=quantization,
             rbln_config=rbln_config,
         )
         self.kvcache_partition_size = kvcache_partition_len

     def get_attn_op_name(self):
         phase = "decode" if self.phase == "decode" else "prefill"
-        if self.use_attention_mask and not self.attn_mask_type == "2D":
-            attn_op_name = "paged_flash_attn_"
+
+        if self.use_attention_mask:
+            if self.attn_mask_type == "2D":
+                attn_op_name = "paged_flash_causal_attn_"
+            else:
+                attn_op_name = "paged_flash_attn_"
         else:
             attn_op_name = "paged_flash_causal_attn_"

@@ -1071,6 +1072,23 @@ def forward(


 class SlidingWindowAttentionOp(AttentionOp):
+    def __init__(
+        self,
+        num_heads: int,
+        head_dim: int,
+        use_attention_mask: bool,
+        num_key_value_heads: int,
+        rbln_config: Optional["RBLNDecoderOnlyModelConfig"] = None,
+    ):
+        super().__init__(
+            num_heads=num_heads,
+            head_dim=head_dim,
+            use_attention_mask=use_attention_mask,
+            num_key_value_heads=num_key_value_heads,
+            rbln_config=rbln_config,
+        )
+        self.quantization = None  # Sliding window attention does not support quantization
+
     def get_attn_op_name(self):
         phase = "decode" if self.phase == "decode" else "prefill"
         if not self.use_attention_mask:
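Taken together, this file's changes do two things: the quantization setting is no longer threaded through the attention-op constructors but is read from rbln_config inside the base op (with SlidingWindowAttentionOp pinning it to None), and get_attn_op_name is rewritten as an explicit nested branch, making it easier to see that a 2D attention mask selects the same causal kernel as the no-mask path. A minimal standalone sketch of that name-selection rule, for illustration only (the helper name and the flash flag are not part of the diff):

def resolve_attn_op_name(use_attention_mask: bool, attn_mask_type: str, flash: bool = False) -> str:
    # Mirrors the branching in AttentionOp / FlashAttentionOp after this commit:
    # a 2D mask and the no-mask case both map to the causal paged kernel.
    prefix = "paged_flash_" if flash else "paged_"
    if use_attention_mask:
        if attn_mask_type == "2D":
            return prefix + "causal_attn_"
        return prefix + "attn_"
    return prefix + "causal_attn_"

assert resolve_attn_op_name(True, "2D") == "paged_causal_attn_"
assert resolve_attn_op_name(False, "2D", flash=True) == "paged_flash_causal_attn_"

Behaviorally this is the same predicate as the old single-line condition (use_attention_mask and attn_mask_type != "2D" selects the non-causal op); the nesting only makes the 2D case explicit.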

src/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

Lines changed: 15 additions & 3 deletions
@@ -445,10 +445,22 @@ def _update_sliding_window_config(
         # Returns:
         #     RBLNDecoderOnlyModelConfig: The updated RBLN model configuration.

-        raise NotImplementedError(
-            "Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
-            "See method docstring for required configuration details."
+        rbln_config.sliding_window = model_config.sliding_window
+        sliding_window_layers = []
+
+        for i in range(model_config.num_hidden_layers):
+            if hasattr(model_config, "layer_types"):
+                if model_config.layer_types[i] == "sliding_attention":
+                    sliding_window_layers.append(i)
+            else:
+                sliding_window_layers.append(i)
+
+        rbln_config.sliding_window_layers = sliding_window_layers
+
+        rbln_config.cache_impl = (
+            "sliding_window" if len(sliding_window_layers) == model_config.num_hidden_layers else "hybrid"
         )
+        return rbln_config

     @classmethod
     def _update_attention_config(
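The new default implementation replaces the NotImplementedError: it copies sliding_window from the model config, collects the indices of sliding-window layers (every layer when the config has no layer_types, otherwise only layers tagged "sliding_attention"), and sets cache_impl to "sliding_window" when all layers slide or "hybrid" otherwise. A small self-contained sketch of that selection rule under a toy config (the _ToyConfig class, the helper name, and the "full_attention" tag are illustrative assumptions, not from the diff):

class _ToyConfig:
    # Stand-in for the Hugging Face model config fields the default reads.
    sliding_window = 4096
    num_hidden_layers = 4
    layer_types = ["sliding_attention", "full_attention", "sliding_attention", "full_attention"]

def collect_sliding_window_layers(model_config):
    # Same rule as the new default: with layer_types present, only
    # "sliding_attention" layers are collected; without it, every layer is.
    layers = []
    for i in range(model_config.num_hidden_layers):
        if hasattr(model_config, "layer_types"):
            if model_config.layer_types[i] == "sliding_attention":
                layers.append(i)
        else:
            layers.append(i)
    return layers

cfg = _ToyConfig()
layers = collect_sliding_window_layers(cfg)
cache_impl = "sliding_window" if len(layers) == cfg.num_hidden_layers else "hybrid"
print(layers, cache_impl)  # [0, 2] hybrid

Subclasses that need a different policy can still override _update_sliding_window_config; the default only covers the common layer_types-based case.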
