Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified ptx/lib/zluda_ptx_impl.bc
Binary file not shown.
288 changes: 221 additions & 67 deletions ptx/lib/zluda_ptx_impl.cpp

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions ptx/src/pass/replace_instructions_with_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -356,13 +356,13 @@ fn run_instruction<'input>(
ast::MmaDetails {
alayout,
blayout,
dtype_scalar,
atype_scalar,
btype_scalar,
ctype_scalar,
cd_type_scalar,
ab_type_scalar,
},
..
} => {
let cd_type_name = scalar_to_ptx_name(cd_type_scalar);
let ab_type_name = scalar_to_ptx_name(ab_type_scalar);
let name = format!(
"mma_sync_aligned_m16n8k16_{}_{}_{}_{}_{}_{}",
match alayout {
Expand All @@ -373,10 +373,10 @@ fn run_instruction<'input>(
ast::MatrixLayout::Row => "row",
ast::MatrixLayout::Col => "col",
},
scalar_to_ptx_name(dtype_scalar),
scalar_to_ptx_name(atype_scalar),
scalar_to_ptx_name(btype_scalar),
scalar_to_ptx_name(ctype_scalar),
cd_type_name,
ab_type_name,
ab_type_name,
cd_type_name,
);
to_call(resolver, fn_declarations, name.into(), i)?
}
Expand Down
24 changes: 24 additions & 0 deletions ptx/src/test/spirv_run/cvt_f16x2_f32.ptx
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
.version 7.0
.target sm_80
.address_size 64

.visible .entry cvt_f16x2_f32(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .b32 temp1;
.reg .b32 temp2;
.reg .f16x2 result;

ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];

ld.f32 temp1, [in_addr];
ld.f32 temp2, [in_addr+4];
cvt.rn.f16x2.f32 result, temp1, temp2;
st.global.b32 [out_addr], result;
ret;
}
55 changes: 55 additions & 0 deletions ptx/src/test/spirv_run/mma_m16n8k16_f32_f16_f16_f32.ptx
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
.version 7.0
.target sm_80
.address_size 64

.visible .entry mma_m16n8k16_f32_f16_f16_f32(
.param .u64 output
)
{
.reg .u64 out_addr;
.reg .u64 out_index;
.reg .u32 thread_id;

.reg .f32 in<8>;

.reg .b32 a0a1, a2a3, a4a5, a6a7;
.reg .b32 b0b1, b2b3;
.reg .f32 d<4>;

ld.param.u64 out_addr, [output];
mov.u32 thread_id, %tid.x;

cvt.rn.f32.u32 in0, thread_id;
mul.f32 in0, in0, 0f41000000; // 8.0
add.f32 in1, in0, 0f3f800000; // 1.0
add.f32 in2, in0, 0f40000000; // 2.0
add.f32 in3, in0, 0f40400000; // 3.0
add.f32 in4, in0, 0f40800000; // 4.0
add.f32 in5, in0, 0f40a00000; // 5.0
add.f32 in6, in0, 0f40c00000; // 6.0
add.f32 in7, in0, 0f40e00000; // 7.0

cvt.rn.f16x2.f32 a0a1, in0, in1;
cvt.rn.f16x2.f32 a2a3, in2, in3;
cvt.rn.f16x2.f32 a4a5, in4, in5;
cvt.rn.f16x2.f32 a6a7, in6, in7;

cvt.rn.f16x2.f32 b0b1, in0, in1;
cvt.rn.f16x2.f32 b2b3, in2, in3;

mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
{d0, d1, d2, d3},
{a0a1, a2a3, a4a5, a6a7},
{b0b1, b2b3},
{in0, in1, in2, in3};

cvt.u64.u32 out_index, thread_id;
mul.lo.u64 out_index, out_index, 16;
add.u64 out_addr, out_addr, out_index;
st.f32 [out_addr], d0;
st.f32 [out_addr+4], d1;
st.f32 [out_addr+8], d2;
st.f32 [out_addr+12], d3;

ret;
}
195 changes: 66 additions & 129 deletions ptx/src/test/spirv_run/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ test_ptx!(param_is_addressable, [0xDEAD], [0u64]);
// [0xce16728dead1ceb1u64, 0xe7728e3c390b7fb8]
//);
test_ptx!(copysign, [0x0BDA2A2Cu32, 0xe31a8fd7u32], [0x631A8FD7u32]);
test_ptx!(cvt_f16x2_f32, [1.0f32, 2.0f32], [0x3C004000u32]);

test_ptx!(assertfail);
// TODO: not yet supported
Expand Down Expand Up @@ -544,134 +545,71 @@ test_ptx_warp!(
test_ptx_warp!(
mma_m16n8k16_f32_bf16_bf16_f32,
[
4448.0f32,
11873.0f32,
4882.0f32,
13331.0f32,
19304.0f32,
26729.0f32,
21786.0f32,
30235.0f32,
34160.0f32,
41585.0f32,
38690.0f32,
47139.0f32,
49016.0f32,
56441.0f32,
55594.0f32,
64043.0f32,
11392.0f32,
35201.0f32,
11826.0f32,
36659.0f32,
59016.0f32,
82825.0f32,
61498.0f32,
86331.0f32,
106640.0f32,
130449.0f32,
111170.0f32,
136003.0f32,
154264.0f32,
178073.0f32,
160842.0f32,
185675.0f32,
18336.0f32,
58529.0f32,
18770.0f32,
59987.0f32,
98728.0f32,
138921.0f32,
101210.0f32,
142427.0f32,
179120.0f32,
219313.0f32,
183650.0f32,
224867.0f32,
259512.0f32,
299705.0f32,
266090.0f32,
307307.0f32,
25280.0f32,
81857.0f32,
25714.0f32,
83315.0f32,
138440.0f32,
195017.0f32,
140922.0f32,
198523.0f32,
251600.0f32,
308177.0f32,
256130.0f32,
313731.0f32,
364760.0f32,
421337.0f32,
371338.0f32,
428939.0f32,
32224.0f32,
105185.0f32,
32658.0f32,
106643.0f32,
178152.0f32,
251113.0f32,
180634.0f32,
254619.0f32,
324080.0f32,
397041.0f32,
328610.0f32,
402595.0f32,
470008.0f32,
542969.0f32,
476586.0f32,
550571.0f32,
39168.0f32,
128513.0f32,
39602.0f32,
129971.0f32,
217864.0f32,
307209.0f32,
220346.0f32,
310715.0f32,
396560.0f32,
485905.0f32,
401090.0f32,
491459.0f32,
575256.0f32,
664601.0f32,
581834.0f32,
672203.0f32,
46112.0f32,
151841.0f32,
46546.0f32,
153299.0f32,
257576.0f32,
363305.0f32,
260058.0f32,
366811.0f32,
469040.0f32,
574769.0f32,
473570.0f32,
580323.0f32,
680504.0f32,
786233.0f32,
687082.0f32,
793835.0f32,
53056.0f32,
175169.0f32,
53490.0f32,
176627.0f32,
297288.0f32,
419401.0f32,
299770.0f32,
422907.0f32,
541520.0f32,
663633.0f32,
546050.0f32,
669187.0f32,
785752.0f32,
907865.0f32,
792330.0f32,
915467.0f32
4448.0f32, 11873.0, 4882.0, 13331.0, 19304.0, 26729.0, 21786.0, 30235.0, 34160.0, 41585.0,
38690.0, 47139.0, 49016.0, 56441.0, 55594.0, 64043.0, 11392.0, 35201.0, 11826.0, 36659.0,
59016.0, 82825.0, 61498.0, 86331.0, 106640.0, 130449.0, 111170.0, 136003.0, 154264.0,
178073.0, 160842.0, 185675.0, 18336.0, 58529.0, 18770.0, 59987.0, 98728.0, 138921.0,
101210.0, 142427.0, 179120.0, 219313.0, 183650.0, 224867.0, 259512.0, 299705.0, 266090.0,
307307.0, 25280.0, 81857.0, 25714.0, 83315.0, 138440.0, 195017.0, 140922.0, 198523.0,
251600.0, 308177.0, 256130.0, 313731.0, 364760.0, 421337.0, 371338.0, 428939.0, 32224.0,
105185.0, 32658.0, 106643.0, 178152.0, 251113.0, 180634.0, 254619.0, 324080.0, 397041.0,
328610.0, 402595.0, 470008.0, 542969.0, 476586.0, 550571.0, 39168.0, 128513.0, 39602.0,
129971.0, 217864.0, 307209.0, 220346.0, 310715.0, 396560.0, 485905.0, 401090.0, 491459.0,
575256.0, 664601.0, 581834.0, 672203.0, 46112.0, 151841.0, 46546.0, 153299.0, 257576.0,
363305.0, 260058.0, 366811.0, 469040.0, 574769.0, 473570.0, 580323.0, 680504.0, 786233.0,
687082.0, 793835.0, 53056.0, 175169.0, 53490.0, 176627.0, 297288.0, 419401.0, 299770.0,
422907.0, 541520.0, 663633.0, 546050.0, 669187.0, 785752.0, 907865.0, 792330.0, 915467.0,
1165824.0, 1304065.0, 1178770.0, 1318547.0, 1442312.0, 1580553.0, 1458330.0, 1598107.0,
1718800.0, 1857041.0, 1737890.0, 1877667.0, 1995288.0, 2133529.0, 2017450.0, 2157227.0,
1303840.0, 1458465.0, 1316786.0, 1472947.0, 1613096.0, 1767721.0, 1629114.0, 1785275.0,
1922352.0, 2076977.0, 1941442.0, 2097603.0, 2231608.0, 2386233.0, 2253770.0, 2409931.0,
1441856.0, 1612865.0, 1454802.0, 1627347.0, 1783880.0, 1954889.0, 1799898.0, 1972443.0,
2125904.0, 2296913.0, 2144994.0, 2317539.0, 2467928.0, 2638937.0, 2490090.0, 2662635.0,
1579872.0, 1767265.0, 1592818.0, 1781747.0, 1954664.0, 2142057.0, 1970682.0, 2159611.0,
2329456.0, 2516849.0, 2348546.0, 2537475.0, 2704248.0, 2891641.0, 2726410.0, 2915339.0,
1717888.0, 1921665.0, 1730834.0, 1936147.0, 2125448.0, 2329225.0, 2141466.0, 2346779.0,
2533008.0, 2736785.0, 2552098.0, 2757411.0, 2940568.0, 3144345.0, 2962730.0, 3168043.0,
1855904.0, 2076065.0, 1868850.0, 2090547.0, 2296232.0, 2516393.0, 2312250.0, 2533947.0,
2736560.0, 2956721.0, 2755650.0, 2977347.0, 3176888.0, 3397049.0, 3199050.0, 3420747.0,
1993920.0, 2230465.0, 2006866.0, 2244947.0, 2467016.0, 2703561.0, 2483034.0, 2721115.0,
2940112.0, 3176657.0, 2959202.0, 3197283.0, 3413208.0, 3649753.0, 3435370.0, 3673451.0,
2131936.0, 2384865.0, 2144882.0, 2399347.0, 2637800.0, 2890729.0, 2653818.0, 2908283.0,
3143664.0, 3396593.0, 3162754.0, 3417219.0, 3649528.0, 3902457.0, 3671690.0, 3926155.0
]
);
test_ptx_warp!(
mma_m16n8k16_f32_f16_f16_f32,
[
4448.0f32, 11873.0, 4882.0, 13331.0, 19304.0, 26729.0, 21786.0, 30235.0, 34160.0, 41585.0,
38690.0, 47139.0, 49016.0, 56441.0, 55594.0, 64043.0, 11392.0, 35201.0, 11826.0, 36659.0,
59016.0, 82825.0, 61498.0, 86331.0, 106640.0, 130449.0, 111170.0, 136003.0, 154264.0,
178073.0, 160842.0, 185675.0, 18336.0, 58529.0, 18770.0, 59987.0, 98728.0, 138921.0,
101210.0, 142427.0, 179120.0, 219313.0, 183650.0, 224867.0, 259512.0, 299705.0, 266090.0,
307307.0, 25280.0, 81857.0, 25714.0, 83315.0, 138440.0, 195017.0, 140922.0, 198523.0,
251600.0, 308177.0, 256130.0, 313731.0, 364760.0, 421337.0, 371338.0, 428939.0, 32224.0,
105185.0, 32658.0, 106643.0, 178152.0, 251113.0, 180634.0, 254619.0, 324080.0, 397041.0,
328610.0, 402595.0, 470008.0, 542969.0, 476586.0, 550571.0, 39168.0, 128513.0, 39602.0,
129971.0, 217864.0, 307209.0, 220346.0, 310715.0, 396560.0, 485905.0, 401090.0, 491459.0,
575256.0, 664601.0, 581834.0, 672203.0, 46112.0, 151841.0, 46546.0, 153299.0, 257576.0,
363305.0, 260058.0, 366811.0, 469040.0, 574769.0, 473570.0, 580323.0, 680504.0, 786233.0,
687082.0, 793835.0, 53056.0, 175169.0, 53490.0, 176627.0, 297288.0, 419401.0, 299770.0,
422907.0, 541520.0, 663633.0, 546050.0, 669187.0, 785752.0, 907865.0, 792330.0, 915467.0,
1167968.0, 1306465.0, 1176594.0, 1316115.0, 1444968.0, 1583465.0, 1455642.0, 1595163.0,
1721968.0, 1860465.0, 1734690.0, 1874211.0, 1998968.0, 2137465.0, 2013738.0, 2153259.0,
1305984.0, 1460865.0, 1314610.0, 1470515.0, 1615752.0, 1770633.0, 1626426.0, 1782331.0,
1925520.0, 2080401.0, 1938242.0, 2094147.0, 2235288.0, 2390169.0, 2250058.0, 2405963.0,
1444000.0, 1615265.0, 1452626.0, 1624915.0, 1786536.0, 1957801.0, 1797210.0, 1969499.0,
2129072.0, 2300337.0, 2141794.0, 2314083.0, 2471608.0, 2642873.0, 2486378.0, 2658667.0,
1582016.0, 1769665.0, 1590642.0, 1779315.0, 1957320.0, 2144969.0, 1967994.0, 2156667.0,
2332624.0, 2520273.0, 2345346.0, 2534019.0, 2707928.0, 2895577.0, 2722698.0, 2911371.0,
1720032.0, 1924065.0, 1728658.0, 1933715.0, 2128104.0, 2332137.0, 2138778.0, 2343835.0,
2536176.0, 2740209.0, 2548898.0, 2753955.0, 2944248.0, 3148281.0, 2959018.0, 3164075.0,
1858048.0, 2078465.0, 1866674.0, 2088115.0, 2298888.0, 2519305.0, 2309562.0, 2531003.0,
2739728.0, 2960145.0, 2752450.0, 2973891.0, 3180568.0, 3400985.0, 3195338.0, 3416779.0,
1996064.0, 2232865.0, 2004690.0, 2242515.0, 2469672.0, 2706473.0, 2480346.0, 2718171.0,
2943280.0, 3180081.0, 2956002.0, 3193827.0, 3416888.0, 3653689.0, 3431658.0, 3669483.0,
2134080.0, 2387265.0, 2142706.0, 2396915.0, 2640456.0, 2893641.0, 2651130.0, 2905339.0,
3146832.0, 3400017.0, 3159554.0, 3413763.0, 3653208.0, 3906393.0, 3667978.0, 3922187.0
]
);
test_ptx_warp!(
Expand Down Expand Up @@ -935,7 +873,6 @@ test_ptx_warp!(
930863.0f32,
]
);

struct DisplayError<T: Debug> {
err: T,
}
Expand Down
6 changes: 2 additions & 4 deletions ptx_parser/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2418,10 +2418,8 @@ pub struct ReduxSyncData {
pub struct MmaDetails {
pub alayout: MatrixLayout,
pub blayout: MatrixLayout,
pub dtype_scalar: ScalarType,
pub atype_scalar: ScalarType,
pub btype_scalar: ScalarType,
pub ctype_scalar: ScalarType,
pub cd_type_scalar: ScalarType,
pub ab_type_scalar: ScalarType,
}

impl MmaDetails {
Expand Down
Loading
Loading