@@ -6,8 +6,9 @@ map_ptx_to_jl_frag = Dict(
6
6
" u32" => UInt32 (42 ),
7
7
" s32" => Int32 (42 ),
8
8
" f16" => ntuple (i -> VecElement {Float16} (42 ), 2 ),
9
- " f32" => Float32 (42 )
10
- )
9
+ " f32" => Float32 (42 ),
10
+ " tf32" => Float32 (42 )
11
+ )
11
12
# Return specific matrix shape given operation configuration
12
13
function get_array_shape (mat, mnk, layout)
13
14
if ! (mat in [" a" ," b" ," c" ," d" ])
46
47
# Type-dependent variables
47
48
array_ty = CUDA. WMMA. map_ptx_to_jl_array[elem_type]
48
49
expected = map_ptx_to_jl_frag[elem_type]
49
-
50
+
50
51
# Address-space dependent variables
51
52
do_shared_test = (addr_space == " _shared" )
52
53
53
54
# Get the function name
54
55
func = Symbol (" llvm_wmma_load_$(mat) _$(layout) _$(shape)$(addr_space) _stride_$(elem_type) " )
55
-
56
+
56
57
input_shape = get_array_shape (mat, mnk, layout)
57
58
input = array_ty (42 ) * ones (array_ty, input_shape)
58
59
input_dev = CuArray (input)
96
97
elem_type in ops[3 ],
97
98
addr_space in [" " , " _global" , " _shared" ],
98
99
stride in [" stride" ]
99
-
100
+
100
101
# Skip all but d matrices
101
102
if mat != " d"
102
103
continue
169
170
ldc_func = getfield (Main, Symbol (" llvm_wmma_load_c_col_$(shape) _global_stride_$(c_elem_type) " ))
170
171
# Account for half and int/subint mma different naming conventions
171
172
# Int/subint mma functions are distinguished by the a/b element type
172
- mma_sym = d_ty == Int32 ? Symbol (" llvm_wmma_mma_$(a_layout) _$(b_layout) _$(shape) _$(ab_elem_type) " ) :
173
+ mma_sym = ( d_ty == Int32 || ab_elem_type == " tf32 " ) ? Symbol (" llvm_wmma_mma_$(a_layout) _$(b_layout) _$(shape) _$(ab_elem_type) " ) :
173
174
Symbol (" llvm_wmma_mma_$(a_layout) _$(b_layout) _$(shape) _$(d_elem_type) _$(c_elem_type) " )
174
- mma_func = getfield (Main, mma_sym)
175
+ mma_func = getfield (Main, mma_sym)
175
176
std_func = getfield (Main, Symbol (" llvm_wmma_store_d_col_$(shape) _global_stride_$(d_elem_type) " ))
176
177
177
178
a_shape = get_array_shape (" a" , mnk, a_layout)
205
206
new_a = (a_layout == " col" ? a : transpose (a))
206
207
new_b = (b_layout == " col" ? b : transpose (b))
207
208
# Alter test depending on a/b element Type
208
- if ab_ty == Float16
209
+ if ab_ty == Float16 || ab_elem_type == " tf32 "
209
210
@test new_a * new_b + c ≈ Array (d_dev) rtol= Base. rtoldefault (Float16)
210
- else # Cast a and b to prevent UInt8 rollover of resultant data
211
+ else # Cast a and b to prevent UInt8 rollover of resultant data
211
212
@test Int32 .(new_a) * Int32 .(new_b) + c == Array (d_dev)
212
213
end
213
214
end
344
345
@test ! occursin (r" wmma.store.d.sync(.aligned)?.col.m16n16k16.f32" , ptx)
345
346
@test occursin (r" wmma.store.d.sync(.aligned)?.col.m16n16k16.shared.f32" , ptx)
346
347
end
347
- end
348
+ end
0 commit comments