@@ -21,7 +21,6 @@
 
 from ....utils import logging
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...utils.rbln_quantization import RBLNQuantizationConfig
 from .configuration_lora import RBLNLoRAConfig
 from .lora_architecture import LoRALinear
 
@@ -622,7 +621,6 @@ def __init__(
         self.head_dim = self._original_mod.head_dim
         self._phase = "prefill"
         self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale()))
-        self.quantization = rbln_config.quantization
 
         if hasattr(self._original_mod, "num_key_value_heads"):
             self.num_key_value_heads = self._original_mod.num_key_value_heads
@@ -689,7 +687,6 @@ def create_attention_op(self):
                 self.use_attention_mask,
                 self.num_key_value_heads,
                 self.kvcache_partition_len,
-                self.quantization,
                 rbln_config=self.rbln_config,
             )
         elif self.attn_impl == "eager":
@@ -698,7 +695,6 @@ def create_attention_op(self):
                 self.head_dim,
                 self.use_attention_mask,
                 self.num_key_value_heads,
-                self.quantization,
                 rbln_config=self.rbln_config,
             )
         else:
@@ -830,24 +826,27 @@ def __init__(
         head_dim: int,
         use_attention_mask: bool,
         num_key_value_heads: int,
-        quantization: Optional[RBLNQuantizationConfig] = None,
         rbln_config: Optional["RBLNDecoderOnlyModelConfig"] = None,
     ):
         super().__init__()
         self.num_heads = num_heads
         self.head_dim = head_dim
         self.num_key_value_heads = num_key_value_heads
         self.phase = "prefill"
-        self.quantization = quantization
         self.rbln_config = rbln_config
         self.use_attention_mask = use_attention_mask
         self.attn_mask_type = rbln_config.attn_mask_type
         self.use_position_ids = rbln_config.use_position_ids
+        self.quantization = rbln_config.quantization
 
     def get_attn_op_name(self):
         phase = "decode" if self.phase == "decode" else "prefill"
-        if self.use_attention_mask and not self.attn_mask_type == "2D":
-            attn_op_name = "paged_attn_"
+
+        if self.use_attention_mask:
+            if self.attn_mask_type == "2D":
+                attn_op_name = "paged_causal_attn_"
+            else:
+                attn_op_name = "paged_attn_"
         else:
             attn_op_name = "paged_causal_attn_"
 
@@ -964,23 +963,25 @@ def __init__(
         use_attention_mask: bool,
         num_key_value_heads: int,
         kvcache_partition_len: int,
-        quantization: Optional[RBLNQuantizationConfig] = None,
         rbln_config: Optional["RBLNDecoderOnlyModelConfig"] = None,
     ):
         super().__init__(
             num_heads=num_heads,
             head_dim=head_dim,
             use_attention_mask=use_attention_mask,
             num_key_value_heads=num_key_value_heads,
-            quantization=quantization,
             rbln_config=rbln_config,
         )
         self.kvcache_partition_size = kvcache_partition_len
 
     def get_attn_op_name(self):
         phase = "decode" if self.phase == "decode" else "prefill"
-        if self.use_attention_mask and not self.attn_mask_type == "2D":
-            attn_op_name = "paged_flash_attn_"
+
+        if self.use_attention_mask:
+            if self.attn_mask_type == "2D":
+                attn_op_name = "paged_flash_causal_attn_"
+            else:
+                attn_op_name = "paged_flash_attn_"
         else:
             attn_op_name = "paged_flash_causal_attn_"
 
@@ -1071,6 +1072,23 @@ def forward(
 
 
 class SlidingWindowAttentionOp(AttentionOp):
+    def __init__(
+        self,
+        num_heads: int,
+        head_dim: int,
+        use_attention_mask: bool,
+        num_key_value_heads: int,
+        rbln_config: Optional["RBLNDecoderOnlyModelConfig"] = None,
+    ):
+        super().__init__(
+            num_heads=num_heads,
+            head_dim=head_dim,
+            use_attention_mask=use_attention_mask,
+            num_key_value_heads=num_key_value_heads,
+            rbln_config=rbln_config,
+        )
+        self.quantization = None  # Sliding window attention does not support quantization
+
     def get_attn_op_name(self):
         phase = "decode" if self.phase == "decode" else "prefill"
         if not self.use_attention_mask:
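
For context, a minimal sketch of the pattern this diff moves to: the attention op no longer takes a separate quantization argument and instead reads it from the rbln_config it already receives. The names below (SketchConfig, SketchAttentionOp) are hypothetical stand-ins, not the actual optimum-rbln classes.

from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchConfig:
    # Hypothetical stand-in for RBLNDecoderOnlyModelConfig
    attn_mask_type: str = "2D"
    use_position_ids: bool = False
    quantization: Optional[dict] = None


class SketchAttentionOp:
    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        use_attention_mask: bool,
        num_key_value_heads: int,
        rbln_config: Optional[SketchConfig] = None,
    ):
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.use_attention_mask = use_attention_mask
        self.num_key_value_heads = num_key_value_heads
        self.rbln_config = rbln_config
        # Quantization now travels with the config rather than as a dedicated argument.
        self.quantization = rbln_config.quantization if rbln_config else None


op = SketchAttentionOp(32, 128, True, 8, SketchConfig(quantization={"weights": "int8"}))
print(op.quantization)  # {'weights': 'int8'}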