Skip to content

Commit b93511f

Browse files
committed
wmma tf32 tests
1 parent bb03e1e commit b93511f

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

Diff for: test/device/intrinsics/wmma.jl

+11-10
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@ map_ptx_to_jl_frag = Dict(
66
"u32" => UInt32(42),
77
"s32" => Int32(42),
88
"f16" => ntuple(i -> VecElement{Float16}(42), 2),
9-
"f32" => Float32(42)
10-
)
9+
"f32" => Float32(42),
10+
"tf32" => Float32(42)
11+
)
1112
# Return specific matrix shape given operation configuration
1213
function get_array_shape(mat, mnk, layout)
1314
if !(mat in ["a","b","c","d"])
@@ -46,13 +47,13 @@ end
4647
# Type-dependent variables
4748
array_ty = CUDA.WMMA.map_ptx_to_jl_array[elem_type]
4849
expected = map_ptx_to_jl_frag[elem_type]
49-
50+
5051
# Address-space dependent variables
5152
do_shared_test = (addr_space == "_shared")
5253

5354
# Get the function name
5455
func = Symbol("llvm_wmma_load_$(mat)_$(layout)_$(shape)$(addr_space)_stride_$(elem_type)")
55-
56+
5657
input_shape = get_array_shape(mat, mnk, layout)
5758
input = array_ty(42) * ones(array_ty, input_shape)
5859
input_dev = CuArray(input)
@@ -96,7 +97,7 @@ end
9697
elem_type in ops[3],
9798
addr_space in ["", "_global", "_shared"],
9899
stride in ["stride"]
99-
100+
100101
# Skip all but d matrices
101102
if mat != "d"
102103
continue
@@ -169,9 +170,9 @@ end
169170
ldc_func = getfield(Main, Symbol("llvm_wmma_load_c_col_$(shape)_global_stride_$(c_elem_type)"))
170171
# Account for half and int/subint mma different naming conventions
171172
# Int/subint mma functions are distinguished by the a/b element type
172-
mma_sym = d_ty == Int32 ? Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(ab_elem_type)") :
173+
mma_sym = (d_ty == Int32 || ab_elem_type == "tf32") ? Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(ab_elem_type)") :
173174
Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(d_elem_type)_$(c_elem_type)")
174-
mma_func = getfield(Main, mma_sym)
175+
mma_func = getfield(Main, mma_sym)
175176
std_func = getfield(Main, Symbol("llvm_wmma_store_d_col_$(shape)_global_stride_$(d_elem_type)"))
176177

177178
a_shape = get_array_shape("a", mnk, a_layout)
@@ -205,9 +206,9 @@ end
205206
new_a = (a_layout == "col" ? a : transpose(a))
206207
new_b = (b_layout == "col" ? b : transpose(b))
207208
# Alter test depending on a/b element Type
208-
if ab_ty == Float16
209+
if ab_ty == Float16 || ab_elem_type == "tf32"
209210
@test new_a * new_b + c Array(d_dev) rtol=Base.rtoldefault(Float16)
210-
else # Cast a and b to prevent UInt8 rollover of resultant data
211+
else # Cast a and b to prevent UInt8 rollover of resultant data
211212
@test Int32.(new_a) * Int32.(new_b) + c == Array(d_dev)
212213
end
213214
end
@@ -344,4 +345,4 @@ end
344345
@test !occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.f32", ptx)
345346
@test occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.shared.f32", ptx)
346347
end
347-
end
348+
end

0 commit comments

Comments
 (0)