
Commit 491c6db

Merge branch 'main' into rajprinc/unit-test-new-api-changes

2 parents: f06d70a + 71fa330

12 files changed: +677 additions, −75 deletions

.github/copilot-instructions.md

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ ninja
 
 Build / Test / Lint summary
 ---------------------------
-- **Bootstrap**: No special bootstrap required. Python dependencies in `pyproject.toml` (`networkx`, `numpy`, `pydot`, `scipy`, `treelib`) are needed for Python tests. Install with `pip install -e .` in project root.
+- **Bootstrap**: No special bootstrap required. Python dependencies in `pyproject.toml` (`networkx`, `numpy`, `pydot`, `scipy`, `treelib`, `ml_dtypes`) are needed for Python tests. Install with `pip install -e .` in project root.
 - **Build**: Use CMake 3.22+ and Ninja (see commands above). **ALWAYS** run from clean build directory to avoid stale state.
 - **C++ Unit Tests**: After build, run `cmake --build . --target test_unit` (runs all unit tests in `test/unit/`).
 - **C++ Examples**: `cmake --build . --target test_examples` (builds and validates examples in `examples/`).

examples/11_xe20_cutlass_library/xe20_cutlass_library_b16.cpp

Lines changed: 89 additions & 32 deletions
@@ -30,7 +30,6 @@
 ***************************************************************************************************/
 
 
-
 #include <exception>
 #include <iostream>
 #include <memory>
@@ -95,31 +94,29 @@ struct identity_op {
   T operator()(T val) const { return val; }
 };
 
-
-
-using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_epilogue =
+using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8_epilogue =
   typename cutlass::epilogue::collective::CollectiveBuilder<
     cutlass::arch::Xe20, cutlass::arch::OpClassTensorOp,
     cute::Shape<cute::_256, cute::_256, cute::_32>,
     cute::Shape<cute::_1, cute::_1, cute::_1>,
     cutlass::epilogue::collective::EpilogueTileAuto,
     float, float,
-    float, cutlass::layout::RowMajor, 4,
-    float, cutlass::layout::RowMajor, 4,
+    cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 8, // Bias
+    cutlass::bfloat16_t, cutlass::layout::RowMajor, 8,    // Output
     cutlass::epilogue::collective::EpilogueScheduleAuto,
     cutlass::epilogue::fusion::LinearCombination<
+      cutlass::bfloat16_t,
       float,
-      float,
-      float,
+      cutlass::bfloat16_t,
       float
     >
   >::CollectiveOp;
 
-using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_mainloop =
+using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8_mainloop =
   typename cutlass::gemm::collective::CollectiveBuilder<
     cutlass::arch::Xe20, cutlass::arch::OpClassTensorOp,
-    cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 8,
-    cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 8,
+    cutlass::bfloat16_t, cutlass::layout::RowMajor, 8, // A
+    cutlass::bfloat16_t, cutlass::layout::RowMajor, 8, // B
     float,
     cute::Shape<cute::_256, cute::_256, cute::_32>,
     cute::Shape<cute::_1, cute::_1, cute::_1>,
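The reordered `LinearCombination` parameters in the hunk above follow the CUTLASS 3.x (output, compute, source, scalar) ordering: the fusion now writes bf16 output from a float accumulator with a bf16 source (C) operand. For reference, below is a minimal host-side sketch of the arithmetic this epilogue performs, D = alpha * (A @ B) + beta * C; the function name and the float-only types are illustrative, not part of the commit.

// Illustrative host reference (not from this commit): LinearCombination computes
// D = alpha * accumulator + beta * C, elementwise over the M x N output tile.
// Everything is float here for clarity; the kernel above uses a float
// accumulator with bf16 C and D operands.
void linear_combination_ref(float alpha, float beta,
                            const float* acc, const float* C, float* D,
                            int M, int N, int ld /* row-major leading dim */) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      D[m * ld + n] = alpha * acc[m * ld + n] + beta * C[m * ld + n];
    }
  }
}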
@@ -128,34 +125,88 @@ using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_mainloop =
   >::CollectiveOp;
 
 // Gemm operator cutlass3x_xe11_tensorop_gemm_bf16_128x256_16x0_tn_align2
-using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_base = cutlass::gemm::kernel::GemmUniversal<
+using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8_base = cutlass::gemm::kernel::GemmUniversal<
   cute::Shape<int,int,int,int>,
-  cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_mainloop,
-  cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_epilogue,
+  cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8_mainloop,
+  cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8_epilogue,
   cutlass::gemm::PersistentScheduler>;
 
 // Define named type
-struct cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8 :
-  public cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_base { };
-
+struct cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8 :
+  public cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8_base { };
 
-using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8>;
+using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8>;
 
 // When workspace_size is not a nullptr, populates requested workspace_size and returns.
 // Otherwise, computes the Gemm kernel using the given workspace ptr.
 extern "C" {
-PT_EXPORT int sycl_tla_gemm_xe20_bf16(const uint16_t* X, const uint16_t* W, uint16_t* Y, const int M, const int N, const int K, const int B, const int lda, const int ldb, const int ldc, const int ldd, const int X_offset, const int W_offset, const int Y_offset, const uint8_t swizzle, size_t* workspace_size, uint8_t* workspace, sycl::queue* stream) {
+PT_EXPORT int sycl_tla_gemm_xe20_bf16(const cutlass::bfloat16_t* X, const cutlass::bfloat16_t* W, const cutlass::bfloat16_t* Bias, cutlass::bfloat16_t* Y, const int M, const int N, const int K, const int B, const int lda, const int ldb, const int ldc, const int ldd, const int X_offset, const int W_offset, const int Bias_offset, const int Y_offset, const uint8_t swizzle, size_t* workspace_size, uint8_t* workspace, sycl::queue* stream) {
   try {
-    using ElementComputeEpilogue = cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_device_type::ElementAccumulator;
+    using ElementComputeEpilogue = cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8_device_type::ElementAccumulator;
     using coord_t = cutlass::gemm::GemmCoord::Index;
     static cutlass::KernelHardwareInfo hw_info;
     if (hw_info.sm_count == 0) {
       hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
       CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count);
     }
 
+    cutlass::DeviceAllocation<cutlass::bfloat16_t> block_A;
+    cutlass::DeviceAllocation<cutlass::bfloat16_t> block_B;
+    cutlass::DeviceAllocation<cutlass::bfloat16_t> block_C;
+    cutlass::DeviceAllocation<cutlass::bfloat16_t> block_D;
+
+    if (!workspace_size) {
+      if (!X || !W) {
+        std::cerr << "Input host pointers null!" << std::endl;
+        return -1;
+      }
+      else {
+        block_A.reset(static_cast<std::size_t>(M) * K * B);
+        block_B.reset(static_cast<std::size_t>(K) * N * B);
+        if (!block_A.get() || !block_B.get()) {
+          std::cerr << "Device allocation of inputs failed!" << std::endl;
+          return -1;
+        }
+        compat::wait();
+        compat::memcpy(block_A.get(), (X + X_offset), (M * K * B) * sizeof(cutlass::bfloat16_t));
+        compat::wait();
+        compat::memcpy(block_B.get(), (W + W_offset), (K * N * B) * sizeof(cutlass::bfloat16_t));
+        compat::wait();
+      }
+
+      if (!Bias) {
+        std::cerr << "Bias host pointer null!" << std::endl;
+        return -1;
+      }
+      else {
+        block_C.reset(static_cast<std::size_t>(M) * N * B);
+        if (!block_C.get()) {
+          std::cerr << "Device allocation of bias failed!" << std::endl;
+          return -1;
+        }
+        compat::wait();
+        compat::memcpy(block_C.get(), (Bias + Bias_offset), (M * N * B) * sizeof(cutlass::bfloat16_t));
+        compat::wait();
+      }
+
+      if (!Y) {
+        std::cerr << "Output host pointer null!" << std::endl;
+        return -1;
+      }
+      else {
+        block_D.reset(static_cast<std::size_t>(M) * N * B);
+        if (!block_D.get()) {
+          std::cerr << "Device allocation of output failed!" << std::endl;
+          return -1;
+        }
+        compat::wait();
+        compat::memset(block_D.get(), 0, (M * N * B) * sizeof(cutlass::bfloat16_t));
+        compat::wait();
+      }
+    }
+
     // Initialize GemmUniversal3xInstance arguments using constructor
-    cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_device_type::Arguments arguments{
+    cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8_device_type::Arguments arguments{
       cutlass::gemm::GemmUniversalMode::kGemm,  // GemmUniversalMode mode
       {
         static_cast<coord_t>(M),
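The new block above stages the host X/W/Bias pointers into device buffers before launch, bracketing each copy with `compat::wait()`. Below is a condensed standalone sketch of that pattern; `stage_to_device` is a hypothetical helper name, and the `compat` namespace is assumed to be the SYCL compatibility layer this example already pulls in through its existing includes.

// Condensed sketch of the host-to-device staging pattern used above
// (stage_to_device is a hypothetical helper, not part of this commit).
#include <cstddef>
#include "cutlass/bfloat16.h"
#include "cutlass/util/device_memory.h"

bool stage_to_device(cutlass::DeviceAllocation<cutlass::bfloat16_t>& block,
                     const cutlass::bfloat16_t* host_ptr, std::size_t count) {
  block.reset(count);               // allocate `count` elements on the device
  if (!block.get()) return false;   // allocation failed
  compat::wait();                   // drain in-flight work before copying
  compat::memcpy(block.get(), host_ptr, count * sizeof(cutlass::bfloat16_t));
  compat::wait();                   // ensure the H2D copy has completed
  return true;
}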
@@ -164,28 +215,30 @@ PT_EXPORT int sycl_tla_gemm_xe20_bf16(const uint16_t* X, const uint16_t* W, uint
         static_cast<coord_t>(B)
       },  // ProblemShape problem_shape
       {
-        (cutlass::bfloat16_t*)(X + X_offset),  // ElementA const* ptr_A
-        cute::make_tuple(cute::Int<1>{}, int64_t(lda), int64_t(0)),  // StrideA dA (column-major: stride_m=1, stride_n=lda, batch=0)
-        (cutlass::bfloat16_t*)(W + W_offset),  // ElementB const* ptr_B
-        cute::make_tuple(int64_t(ldb), cute::Int<1>{}, int64_t(0)),  // StrideB dB (column-major: stride_m=ldb, stride_n=1, batch=0)
+        (cutlass::bfloat16_t*)(block_A.get()),  // ElementA const* ptr_A
+        {int64_t(lda), cute::Int<1>{}, int64_t(0)},
+        (cutlass::bfloat16_t*)(block_B.get()),  // ElementB const* ptr_B
+        {cute::Int<1>{}, int64_t(ldb), int64_t(0)},
       },  // MainloopArguments mainloop
 
       // see https://tinyurl.com/4rk89z48
       {
-        {ElementComputeEpilogue(1), ElementComputeEpilogue(0)},  // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT )
-        nullptr,  // ElementC const* ptr_C
-        cute::make_tuple(int64_t(0), cute::Int<1>{}, int64_t(0)),  // StrideC dC (row-major: stride_m, stride_n=1, batch=0)
-        (float*)(Y + Y_offset),  // ElementD ptr_D (output is float, not bfloat16)
-        cute::make_tuple(int64_t(ldd), cute::Int<1>{}, int64_t(0)),  // StrideD dD (row-major: stride_m=ldd, stride_n=1, batch=0)
+        {1.f, 1.f},  // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT )
+        (cutlass::bfloat16_t*)(block_C.get()),  // ElementC const* ptr_C
+        {cute::Int<1>{}, int64_t(ldc), int64_t(0)},
+        (cutlass::bfloat16_t*)(block_D.get()),  // ElementD const* ptr_D
+        {int64_t(ldd), cute::Int<1>{}, int64_t(0)},
       },  // EpilogueArguments epilogue,
       hw_info
     };
+
     arguments.scheduler.max_swizzle_size = swizzle;
-    cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_device_type gemm_op;
+    cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8_device_type gemm_op;
     if (workspace_size) {
       *workspace_size = gemm_op.get_workspace_size(arguments);
       return 0;
     }
+
     // check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers
 #ifndef CUTLASS_BACKEND_DISABLE_CHECKS
     {
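The brace-initialized strides above replace the earlier `cute::make_tuple` calls and encode the new tt convention: in CUTLASS 3.x the stride modes are ordered (M, K, batch) for A, (N, K, batch) for B, and (M, N, batch) for C/D, so `{lda, _1, 0}` is row-major A and `{_1, ldb, 0}` is row-major B, consistent with the RowMajor tags in the mainloop builder. Below is a small standalone sketch of pairing such a stride with a shape in CuTe; all values are illustrative, not from the commit.

// Standalone sketch (illustrative values, not from this commit): a CUTLASS 3.x
// StrideA is ordered (M, K, batch); for row-major A, rows are lda elements
// apart and columns are contiguous, i.e. (lda, 1, 0).
#include <cute/tensor.hpp>
#include <cstdint>

int main() {
  int64_t M = 4, K = 8, lda = 8;  // row-major A: lda == K here
  auto dA = cute::make_stride(lda, cute::Int<1>{}, int64_t(0));
  auto layoutA = cute::make_layout(cute::make_shape(M, K, int64_t(1)), dA);
  // Element (m, k) of batch 0 lives at offset m * lda + k.
  cute::print(layoutA);  // prints shape:stride, e.g. (4,8,1):(8,_1,0)
  return 0;
}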
@@ -209,6 +262,10 @@ PT_EXPORT int sycl_tla_gemm_xe20_bf16(const uint16_t* X, const uint16_t* W, uint
     {
       auto status = gemm_op(stream);
       CUTLASS_CHECK(status);
+
+      compat::wait();
+      compat::memcpy((Y + Y_offset), block_D.get(), (M * N * B) * sizeof(cutlass::bfloat16_t));
+      compat::wait();
     }
   }
   catch (std::exception& e) {
@@ -222,4 +279,4 @@ PT_EXPORT int sycl_tla_gemm_xe20_bf16(const uint16_t* X, const uint16_t* W, uint
   }
 }
 
-// configuration name: cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8
+// configuration name: cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8
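For context, the exported entry point keeps its two-phase contract: call once with a non-null `workspace_size` to query the workspace, then again with `workspace_size == nullptr` to execute. Below is a hypothetical caller sketch; the function name `run_xe20_gemm`, the leading dimensions, and the device-memory workspace allocation are assumptions, not part of the commit.

// Hypothetical caller (not from this commit) exercising the two-phase contract
// of sycl_tla_gemm_xe20_bf16: query the workspace size first, then execute.
// Assumes the extern "C" declaration above is visible to this translation unit.
#include <sycl/sycl.hpp>
#include <cstdint>
#include "cutlass/bfloat16.h"

int run_xe20_gemm(sycl::queue& q, int M, int N, int K, int B,
                  const cutlass::bfloat16_t* X, const cutlass::bfloat16_t* W,
                  const cutlass::bfloat16_t* Bias, cutlass::bfloat16_t* Y) {
  size_t ws_bytes = 0;
  // Phase 1: non-null workspace_size -> the function only reports the size.
  int rc = sycl_tla_gemm_xe20_bf16(X, W, Bias, Y, M, N, K, B,
                                   /*lda=*/K, /*ldb=*/N, /*ldc=*/M /* column-major bias */, /*ldd=*/N,
                                   0, 0, 0, 0, /*swizzle=*/1,
                                   &ws_bytes, nullptr, &q);
  if (rc != 0) return rc;
  // Assumption: the workspace lives in device memory; the diff does not pin this down.
  uint8_t* ws = ws_bytes ? sycl::malloc_device<uint8_t>(ws_bytes, q) : nullptr;
  // Phase 2: null workspace_size -> stage inputs, run the GEMM, copy Y back.
  rc = sycl_tla_gemm_xe20_bf16(X, W, Bias, Y, M, N, K, B,
                               K, N, M, N, 0, 0, 0, 0, 1,
                               nullptr, ws, &q);
  if (ws) sycl::free(ws, q);
  return rc;
}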
