Skip to content

Commit 991a9f0

Browse files
authored
Fix BiasDataTypeList
1 parent feef2aa commit 991a9f0

File tree

1 file changed

+8
-13
lines changed

1 file changed

+8
-13
lines changed

tensilelite/Tensile/Utilities/tensile_generator/tensile_config_generator.py

+8-13
Original file line numberDiff line numberDiff line change
@@ -236,19 +236,13 @@ def trans_map(trans):
236236
else:
237237
return None
238238

239-
def bias_datatype_map(dtype):
240-
if dtype == "f16_r":
241-
return [datatype_map('f32_r'), datatype_map('f16_r')]
242-
elif dtype == "f32_r":
243-
return [datatype_map('f32_r')]
244-
elif dtype == "xf32_r":
245-
return [datatype_map('xf32_r')]
246-
elif dtype == "bf16_r":
247-
return [datatype_map('f32_r'), datatype_map('bf16_r')]
248-
elif dtype == "f8_r":
249-
return [datatype_map('f32_r'), datatype_map('f8_r')]
250-
else:
239+
def bias_datatype_map(bias_type, data_type, compute_type, dest_type):
240+
bias_list = [datatype_map(data_type), datatype_map(compute_type), datatype_map(dest_type)]
241+
bias_map = datatype_map(bias_type)
242+
if bias_map in bias_list:
251243
return []
244+
bias_list.append(bias_map)
245+
return bias_list
252246

253247
def get_high_precision_accumulate(DataType):
254248
if DataType in ["H", "B", "F8"]:
@@ -282,7 +276,8 @@ def extract_dtype(match):
282276
if bias_source:
283277
res["UseBias"] = 1
284278
res["BiasSrc"] = bias_source
285-
res["BiasDataTypeList"] = list(bias_datatype_map(gdict.get("BIAS_TYPE", '').strip()))
279+
bias_type = gdict.get("BIAS_TYPE", '').strip()
280+
res["BiasDataTypeList"] = bias_datatype_map(bias_type, DataType, ComputeDataType, DestDataType)
286281
if activation_type != "none":
287282
res["Activation"] = True
288283
res["ActivationType"] = "hipblaslt_all"

0 commit comments

Comments
 (0)