1212#include < chrono>
1313#include < cstdint>
1414#include < cstdlib>
15+ #include < cstring>
1516#include < ctime>
1617#include < fstream>
1718#include < iomanip>
@@ -59,9 +60,10 @@ int main(int argc, const char *argv[]) {
5960 int K = vm[" size_k" ].as <int >();
6061 int N = vm[" size_n" ].as <int >();
6162
62- // x: [M, K], w_combined : [K, 2* N], out: [M, N]
63+ // x: [M, K], w_gate: [K, N], w_up: [K, N], out: [M, N]
6364 size_t X_SIZE = (size_t )M * K * sizeof (DATATYPE);
64- size_t W_COMBINED_SIZE = (size_t )K * 2 * N * sizeof (DATATYPE);
65+ size_t WGATE_SIZE = (size_t )K * N * sizeof (DATATYPE);
66+ size_t WUP_SIZE = (size_t )K * N * sizeof (DATATYPE);
6567 size_t OUT_SIZE = (size_t )M * N * sizeof (DATATYPE);
6668
6769 srand (time (NULL ));
@@ -82,23 +84,29 @@ int main(int argc, const char *argv[]) {
8284
8385 // Use xrt::ext::bo (no group_id needed for ELF)
8486 xrt::bo bo_x = xrt::ext::bo (device, X_SIZE);
85- xrt::bo bo_w = xrt::ext::bo (device, W_COMBINED_SIZE);
87+ xrt::bo bo_wgate = xrt::ext::bo (device, WGATE_SIZE);
88+ xrt::bo bo_wup = xrt::ext::bo (device, WUP_SIZE);
8689 xrt::bo bo_out = xrt::ext::bo (device, OUT_SIZE);
8790
8891 // Fill inputs with random data
8992 DATATYPE *bufX = bo_x.map <DATATYPE *>();
9093 for (size_t i = 0 ; i < (size_t )M * K; i++)
9194 bufX[i] = random_bfloat16_t ();
9295
93- DATATYPE *bufW = bo_w.map <DATATYPE *>();
94- for (size_t i = 0 ; i < (size_t )K * 2 * N; i++)
95- bufW[i] = random_bfloat16_t ();
96+ DATATYPE *bufWgate = bo_wgate.map <DATATYPE *>();
97+ for (size_t i = 0 ; i < (size_t )K * N; i++)
98+ bufWgate[i] = random_bfloat16_t ();
99+
100+ DATATYPE *bufWup = bo_wup.map <DATATYPE *>();
101+ for (size_t i = 0 ; i < (size_t )K * N; i++)
102+ bufWup[i] = random_bfloat16_t ();
96103
97104 DATATYPE *bufOut = bo_out.map <DATATYPE *>();
98- memset (bufOut, 0 , OUT_SIZE);
105+ std:: memset (bufOut, 0 , OUT_SIZE);
99106
100107 bo_x.sync (XCL_BO_SYNC_BO_TO_DEVICE);
101- bo_w.sync (XCL_BO_SYNC_BO_TO_DEVICE);
108+ bo_wgate.sync (XCL_BO_SYNC_BO_TO_DEVICE);
109+ bo_wup.sync (XCL_BO_SYNC_BO_TO_DEVICE);
102110 bo_out.sync (XCL_BO_SYNC_BO_TO_DEVICE);
103111
104112 unsigned n_iterations = vm[" iterations" ].as <int >();
@@ -117,8 +125,10 @@ int main(int argc, const char *argv[]) {
117125 std::cout << " M=" << M << " , K=" << K << " , N=" << N << std::endl;
118126 std::cout << " x: [" << M << " x" << K << " ] (" << X_SIZE << " bytes)"
119127 << std::endl;
120- std::cout << " w_combined: [" << K << " x" << 2 * N << " ] ("
121- << W_COMBINED_SIZE << " bytes)" << std::endl;
128+ std::cout << " w_gate: [" << K << " x" << N << " ] (" << WGATE_SIZE
129+ << " bytes)" << std::endl;
130+ std::cout << " w_up: [" << K << " x" << N << " ] (" << WUP_SIZE << " bytes)"
131+ << std::endl;
122132 std::cout << " output: [" << M << " x" << N << " ] (" << OUT_SIZE << " bytes)"
123133 << std::endl;
124134 std::cout << " warmup=" << n_warmup_iterations
@@ -129,11 +139,12 @@ int main(int argc, const char *argv[]) {
129139 std::cout << " Running Kernel (iteration " << iter << " ).\n " ;
130140
131141 auto start = std::chrono::high_resolution_clock::now ();
132- // ELF path: use xrt::run with set_arg
142+ // ELF path: use xrt::run with set_arg (4 args: x, w_gate, w_up, out)
133143 auto run = xrt::run (kernel);
134144 run.set_arg (0 , bo_x);
135- run.set_arg (1 , bo_w);
136- run.set_arg (2 , bo_out);
145+ run.set_arg (1 , bo_wgate);
146+ run.set_arg (2 , bo_wup);
147+ run.set_arg (3 , bo_out);
137148 run.start ();
138149 run.wait2 ();
139150 auto stop = std::chrono::high_resolution_clock::now ();
0 commit comments