@@ -69,6 +69,8 @@ def scaled_dot_product_attention(
         is_causal = True
     # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
     use_fp32_acc = kwargs.get("use_fp32_acc", False)
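+    # experimental: insert FP8 quantize/dequantize ops around Q, K, softmax, V, and the output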
+    use_fp8_quantize = kwargs.get("use_fp8_quantize", True)
     query_dtype = query.dtype
 
     if scale is None:
@@ -97,6 +99,32 @@ def scaled_dot_product_attention(
             key,
             scale,
         )
+    # amax is hard-coded to a fixed value for testing
+    amax = torch.tensor([0.6562])
+    if use_fp8_quantize:
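+        # each call passes amax plus num_bits=8 / exponent_bits=4, which selects
+        # the FP8 E4M3 format in impl.quantize.quantize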
+        key = impl.quantize.quantize(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name + "_quantize_key",
+            key,
+            amax,
+            8,
+            4,
+        )
+
+        query = impl.quantize.quantize(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name + "_quantize_query",
+            query,
+            amax,
+            8,
+            4,
+        )
 
     if use_fp32_acc and query_dtype == trt.float16:
         query = cast_trt_tensor(
@@ -173,6 +201,30 @@ def scaled_dot_product_attention(
     softmax = impl.normalization.softmax(
         ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False
    )
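+    # quantize the attention probabilities and V ahead of the softmax @ V matmul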
+    if use_fp8_quantize:
+        softmax = impl.quantize.quantize(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name + "_quantize_softmax",
+            softmax,
+            amax,
+            8,
+            4,
+        )
+
+        value = impl.quantize.quantize(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name + "_quantize_value",
+            value,
+            amax,
+            8,
+            4,
+        )
+
     if use_fp32_acc:
         softmax = cast_trt_tensor(
             ctx, softmax, trt.float32, name + "_softmax_cast_to_fp32", target, source_ir
@@ -188,9 +240,22 @@ def scaled_dot_product_attention(
         softmax,
         value,
     )
+
     if use_fp32_acc:
         out = cast_trt_tensor(
             ctx, out, query_dtype, name + "_out_cast_to_fp16", target, source_ir
         )
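+    # quantize the final attention output as well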
+    if use_fp8_quantize:
+        out = impl.quantize.quantize(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name + "_quantize_out",
+            out,
+            amax,
+            8,
+            4,
+        )
 
     return out