from packaging import version as pkg_version
from deepspeed.utils.logging import log_dist
from deepspeed.accelerator import get_accelerator
-from deepspeed.ops.op_builder import InferenceBuilder
+from deepspeed.ops.transformer.inference.op_binding.workspace import WorkspaceOp
+from deepspeed.ops.transformer.inference.op_binding.softmax_context import SoftmaxContextOp
+from deepspeed.ops.transformer.inference.op_binding import LinearOp
+from deepspeed.ops.transformer.inference.op_binding.pad_transform import PadTransformOp

-# Cuda modules will be imported if needed
-inference_module = None
minus_inf = -10000.0
triton_flash_attn = None

@@ -36,7 +37,8 @@ class DeepSpeedDiffusersAttentionFunction(Function):
    @staticmethod
    def forward(ctx, input, context, input_mask, config, attn_qkvw, attn_qw, attn_kw, attn_vw, attn_qkvb,
                num_attention_heads_per_partition, norm_factor, hidden_size_per_partition, attn_ow, attn_ob,
-                do_out_bias, score_context_func, linear_func, triton_flash_attn_kernel, rope_theta):
+                do_out_bias, score_context_func, linear_func, pad_transform_func, triton_flash_attn_kernel,
+                rope_theta):

        def _transpose_for_context(x):
            x = x.permute(0, 2, 1, 3)
@@ -77,7 +79,7 @@ def selfAttention_fp(input, context, input_mask):
                query = query.contiguous()
                key = key.contiguous()
                value = value.contiguous()
-                query, key, value = inference_module.pad_transform_fp16(query, key, value, config.heads, do_flash_attn)
+                query, key, value = pad_transform_func(query, key, value, config.heads, do_flash_attn)
            attention_scores = (torch.matmul(query, key.transpose(-1, -2)) * scale).softmax(dim=-1)
            context_layer = _transpose_for_context(torch.matmul(attention_scores, value))

@@ -117,10 +119,6 @@ def __init__(

        data_type = self.config.dtype
        data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype
-        global inference_module
-        if inference_module is None:
-            builder = InferenceBuilder()
-            inference_module = builder.load()

        if DeepSpeedDiffusersAttention.layer_id == 1:
            log_dist(f"DeepSpeed-Attention config: {self.config.__dict__}", [0])
@@ -171,26 +169,24 @@ def __init__(
            self.norm_factor *= math.sqrt(self.config.layer_id + 1)
            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/gpt2/modeling_gpt2.py#L191

-        if self.config.dtype in [torch.float16, torch.int8]:
-            self.score_context_func = inference_module.softmax_context_fp16
-            self.linear_func = inference_module.linear_layer_fp16
-            self.allocate_workspace = inference_module.allocate_workspace_fp16
-        else:
-            self.score_context_func = inference_module.softmax_context_fp32
-            self.linear_func = inference_module.linear_layer_fp32
-            self.allocate_workspace = inference_module.allocate_workspace_fp32
+        self.workspace = WorkspaceOp(self.config)
+        self.score_context_func = SoftmaxContextOp(self.config)
+        self.linear_func = LinearOp(self.config)
+        self.pad_transform_func = PadTransformOp(self.config)

-    def forward(self, input, context=None, input_mask=None):
+    def allocate_workspace(self, size):
+        # Allocate memory only on first layer forward
        if self.config.layer_id == 0:
-            self.allocate_workspace(self.config.hidden_size, self.config.heads,
-                                    input.size()[1],
-                                    input.size()[0], DeepSpeedDiffusersAttention.layer_id, self.config.mp_size, False,
-                                    0, self.config.max_out_tokens, self.config.min_out_tokens)
-        output = DeepSpeedDiffusersAttentionFunction.apply(input, context, input_mask, self.config, self.attn_qkvw,
-                                                           self.attn_qw, self.attn_kw, self.attn_vw, self.attn_qkvb,
-                                                           self.num_attention_heads_per_partition, self.norm_factor,
-                                                           self.hidden_size_per_partition, self.attn_ow, self.attn_ob,
-                                                           self.do_out_bias, self.score_context_func, self.linear_func,
-                                                           self.triton_flash_attn_kernel, self.config.rope_theta)
+            self.workspace.allocate_workspace(self.config.hidden_size, self.config.heads, size[1], size[0],
+                                              DeepSpeedDiffusersAttention.layer_id, self.config.mp_size, False, 0,
+                                              self.config.max_out_tokens, self.config.min_out_tokens)
+
+    def forward(self, input, context=None, input_mask=None):
+        self.allocate_workspace(input.size())
+        output = DeepSpeedDiffusersAttentionFunction.apply(
+            input, context, input_mask, self.config, self.attn_qkvw, self.attn_qw, self.attn_kw, self.attn_vw,
+            self.attn_qkvb, self.num_attention_heads_per_partition, self.norm_factor, self.hidden_size_per_partition,
+            self.attn_ow, self.attn_ob, self.do_out_bias, self.score_context_func, self.linear_func,
+            self.pad_transform_func, self.triton_flash_attn_kernel, self.config.rope_theta)

        return output
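
For context, here is a minimal sketch (not part of this diff) of the pattern the change introduces: the layer builds one op-binding wrapper per kernel from its config instead of calling the raw `inference_module` functions. The `build_attention_ops` helper below is hypothetical, and any behavior beyond what the diff itself shows is an assumption.

```python
# Illustrative sketch only; mirrors how the diff wires the op-binding wrappers
# in place of the raw inference_module kernels. The helper name is hypothetical.
from deepspeed.ops.transformer.inference.op_binding.workspace import WorkspaceOp
from deepspeed.ops.transformer.inference.op_binding.softmax_context import SoftmaxContextOp
from deepspeed.ops.transformer.inference.op_binding import LinearOp
from deepspeed.ops.transformer.inference.op_binding.pad_transform import PadTransformOp


def build_attention_ops(config):
    # One wrapper object per kernel, each constructed from the layer config.
    # The diff drops the explicit *_fp16 / *_fp32 branching, which suggests the
    # dtype dispatch now happens inside these wrappers (assumption).
    return {
        "workspace": WorkspaceOp(config),             # used as workspace.allocate_workspace(...)
        "softmax_context": SoftmaxContextOp(config),  # replaces softmax_context_fp16/fp32
        "linear": LinearOp(config),                   # replaces linear_layer_fp16/fp32
        "pad_transform": PadTransformOp(config),      # replaces pad_transform_fp16
    }
```

The net effect visible in the diff is that diffusers_attention.py no longer loads `InferenceBuilder` eagerly or keeps a module-level `inference_module` global; kernel access goes through the config-driven wrapper objects instead.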