Skip to content

Commit 9f60f5b

Browse files
erwei-xilinx and claude committed
Address review comments on fused SwiGLU example
- Add missing <cstring> include for std::memset in test.cpp
- Update test.cpp to 4-arg signature (x, w_gate, w_up, out) matching the Python module's separate weight arguments
- Pass PEANO_INSTALL_DIR and OUTPUT_FORMAT=elf in LIT test RUN line, matching the convention used by decode/ and prefill/ examples
- Fix CHECK pattern to match "PASS!" (with exclamation mark)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b57dd97 commit 9f60f5b

2 files changed

Lines changed: 27 additions & 16 deletions

File tree

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: make -f %S/Makefile clean
2-
// RUN: make -f %S/Makefile run4x4 COMPILE_MODE=compile-and-run 2>&1 | FileCheck %s
3-
// CHECK: PASS
1+
// RUN: make -f %S/Makefile clean PEANO_INSTALL_DIR=%PEANO_INSTALL_DIR
2+
// RUN: make -f %S/Makefile run4x4 PEANO_INSTALL_DIR=%PEANO_INSTALL_DIR COMPILE_MODE=compile-and-run OUTPUT_FORMAT=elf 2>&1 | FileCheck %s
3+
// CHECK: PASS!
44

55
// REQUIRES: ryzen_ai, peano

programming_examples/ffn_swiglu/fused/test.cpp

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <chrono>
1313
#include <cstdint>
1414
#include <cstdlib>
15+
#include <cstring>
1516
#include <ctime>
1617
#include <fstream>
1718
#include <iomanip>
@@ -59,9 +60,10 @@ int main(int argc, const char *argv[]) {
5960
int K = vm["size_k"].as<int>();
6061
int N = vm["size_n"].as<int>();
6162

62-
// x: [M, K], w_combined: [K, 2*N], out: [M, N]
63+
// x: [M, K], w_gate: [K, N], w_up: [K, N], out: [M, N]
6364
size_t X_SIZE = (size_t)M * K * sizeof(DATATYPE);
64-
size_t W_COMBINED_SIZE = (size_t)K * 2 * N * sizeof(DATATYPE);
65+
size_t WGATE_SIZE = (size_t)K * N * sizeof(DATATYPE);
66+
size_t WUP_SIZE = (size_t)K * N * sizeof(DATATYPE);
6567
size_t OUT_SIZE = (size_t)M * N * sizeof(DATATYPE);
6668

6769
srand(time(NULL));
@@ -82,23 +84,29 @@ int main(int argc, const char *argv[]) {
8284

8385
// Use xrt::ext::bo (no group_id needed for ELF)
8486
xrt::bo bo_x = xrt::ext::bo(device, X_SIZE);
85-
xrt::bo bo_w = xrt::ext::bo(device, W_COMBINED_SIZE);
87+
xrt::bo bo_wgate = xrt::ext::bo(device, WGATE_SIZE);
88+
xrt::bo bo_wup = xrt::ext::bo(device, WUP_SIZE);
8689
xrt::bo bo_out = xrt::ext::bo(device, OUT_SIZE);
8790

8891
// Fill inputs with random data
8992
DATATYPE *bufX = bo_x.map<DATATYPE *>();
9093
for (size_t i = 0; i < (size_t)M * K; i++)
9194
bufX[i] = random_bfloat16_t();
9295

93-
DATATYPE *bufW = bo_w.map<DATATYPE *>();
94-
for (size_t i = 0; i < (size_t)K * 2 * N; i++)
95-
bufW[i] = random_bfloat16_t();
96+
DATATYPE *bufWgate = bo_wgate.map<DATATYPE *>();
97+
for (size_t i = 0; i < (size_t)K * N; i++)
98+
bufWgate[i] = random_bfloat16_t();
99+
100+
DATATYPE *bufWup = bo_wup.map<DATATYPE *>();
101+
for (size_t i = 0; i < (size_t)K * N; i++)
102+
bufWup[i] = random_bfloat16_t();
96103

97104
DATATYPE *bufOut = bo_out.map<DATATYPE *>();
98-
memset(bufOut, 0, OUT_SIZE);
105+
std::memset(bufOut, 0, OUT_SIZE);
99106

100107
bo_x.sync(XCL_BO_SYNC_BO_TO_DEVICE);
101-
bo_w.sync(XCL_BO_SYNC_BO_TO_DEVICE);
108+
bo_wgate.sync(XCL_BO_SYNC_BO_TO_DEVICE);
109+
bo_wup.sync(XCL_BO_SYNC_BO_TO_DEVICE);
102110
bo_out.sync(XCL_BO_SYNC_BO_TO_DEVICE);
103111

104112
unsigned n_iterations = vm["iterations"].as<int>();
@@ -117,8 +125,10 @@ int main(int argc, const char *argv[]) {
117125
std::cout << " M=" << M << ", K=" << K << ", N=" << N << std::endl;
118126
std::cout << " x: [" << M << "x" << K << "] (" << X_SIZE << " bytes)"
119127
<< std::endl;
120-
std::cout << " w_combined: [" << K << "x" << 2 * N << "] ("
121-
<< W_COMBINED_SIZE << " bytes)" << std::endl;
128+
std::cout << " w_gate: [" << K << "x" << N << "] (" << WGATE_SIZE
129+
<< " bytes)" << std::endl;
130+
std::cout << " w_up: [" << K << "x" << N << "] (" << WUP_SIZE << " bytes)"
131+
<< std::endl;
122132
std::cout << " output: [" << M << "x" << N << "] (" << OUT_SIZE << " bytes)"
123133
<< std::endl;
124134
std::cout << " warmup=" << n_warmup_iterations
@@ -129,11 +139,12 @@ int main(int argc, const char *argv[]) {
129139
std::cout << "Running Kernel (iteration " << iter << ").\n";
130140

131141
auto start = std::chrono::high_resolution_clock::now();
132-
// ELF path: use xrt::run with set_arg
142+
// ELF path: use xrt::run with set_arg (4 args: x, w_gate, w_up, out)
133143
auto run = xrt::run(kernel);
134144
run.set_arg(0, bo_x);
135-
run.set_arg(1, bo_w);
136-
run.set_arg(2, bo_out);
145+
run.set_arg(1, bo_wgate);
146+
run.set_arg(2, bo_wup);
147+
run.set_arg(3, bo_out);
137148
run.start();
138149
run.wait2();
139150
auto stop = std::chrono::high_resolution_clock::now();

0 commit comments

Comments
 (0)