@@ -236,19 +236,13 @@ def trans_map(trans):
236
236
else :
237
237
return None
238
238
239
- def bias_datatype_map (dtype ):
240
- if dtype == "f16_r" :
241
- return [datatype_map ('f32_r' ), datatype_map ('f16_r' )]
242
- elif dtype == "f32_r" :
243
- return [datatype_map ('f32_r' )]
244
- elif dtype == "xf32_r" :
245
- return [datatype_map ('xf32_r' )]
246
- elif dtype == "bf16_r" :
247
- return [datatype_map ('f32_r' ), datatype_map ('bf16_r' )]
248
- elif dtype == "f8_r" :
249
- return [datatype_map ('f32_r' ), datatype_map ('f8_r' )]
250
- else :
239
+ def bias_datatype_map (bias_type , data_type , compute_type , dest_type ):
240
+ bias_list = [datatype_map (data_type ), datatype_map (compute_type ), datatype_map (dest_type )]
241
+ bias_map = datatype_map (bias_type )
242
+ if bias_map in bias_list :
251
243
return []
244
+ bias_list .append (bias_map )
245
+ return bias_list
252
246
253
247
def get_high_precision_accumulate (DataType ):
254
248
if DataType in ["H" , "B" , "F8" ]:
@@ -282,7 +276,8 @@ def extract_dtype(match):
282
276
if bias_source :
283
277
res ["UseBias" ] = 1
284
278
res ["BiasSrc" ] = bias_source
285
- res ["BiasDataTypeList" ] = list (bias_datatype_map (gdict .get ("BIAS_TYPE" , '' ).strip ()))
279
+ bias_type = gdict .get ("BIAS_TYPE" , '' ).strip ()
280
+ res ["BiasDataTypeList" ] = bias_datatype_map (bias_type , DataType , ComputeDataType , DestDataType )
286
281
if activation_type != "none" :
287
282
res ["Activation" ] = True
288
283
res ["ActivationType" ] = "hipblaslt_all"
0 commit comments