Skip to content

Commit 85fa4a9

Browse files
committed
examples: graph: add int8 sdpa example
1 parent 8c8eb43 commit 85fa4a9

File tree

1 file changed

+379
-0
lines changed

1 file changed

+379
-0
lines changed

examples/graph/int8_sdpa.cpp

+379
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,379 @@
1+
/*******************************************************************************
2+
* Copyright 2025 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <limits>
#include <memory>
#include <random>
#include <string>
#include <vector>

#include "oneapi/dnnl/dnnl.hpp"
#include "oneapi/dnnl/dnnl_graph.hpp"

#include "graph_example_utils.hpp"
30+
31+
using namespace dnnl;
32+
33+
using namespace dnnl::graph;
34+
using layout_type = logical_tensor::layout_type;
35+
using data_type = logical_tensor::data_type;
36+
using dim = logical_tensor::dim;
37+
using dims = logical_tensor::dims;
38+
39+
// Problem dimensions for the scaled dot-product attention (SDPA) benchmark.
// Member order matters: callers use aggregate initialization {mb, seq_len,
// head_num, head_size}.
struct sdpa_dims_t {
    dim mb; // batch size (number of sequences)
    dim seq_len; // sequence length per sequence
    dim head_num; // number of attention heads
    dim head_size; // per-head feature size
};
45+
46+
// Lower bound on the number of timed iterations in the benchmark loop.
static const int min_runs = 4;
47+
48+
// Adapted from the fill_random() function in matmul_perf.cpp.
// Fill `out` by tiling a fixed pool of pseudo-random floats over it.
// The pool is built once with a default-seeded generator, so the output is
// deterministic across calls and runs.
void fill_random(std::vector<float> &out) {
    static std::vector<float> random_data_f;
    constexpr size_t nrand = 1037;

    if (random_data_f.empty()) {
        random_data_f.resize(nrand);
        std::mt19937 generator;
        std::uniform_real_distribution<float> dist_f(-1.0f, 1.0f);
        for (size_t idx = 0; idx < nrand; ++idx)
            random_data_f[idx] = dist_f(generator);
    }

    // Copy the pool repeatedly; the final chunk may be shorter than the pool.
    size_t offset = 0;
    while (offset < out.size()) {
        const size_t chunk = std::min(nrand, out.size() - offset);
        std::memcpy(&out[offset], random_data_f.data(), chunk * sizeof(float));
        offset += chunk;
    }
}
67+
68+
// Initialize the attention mask: within each row of `seq_len` elements, the
// first 3/4 positions are attended (0) and the trailing 1/4 are masked out
// with -infinity (so softmax assigns them zero probability).
void fill_mask(std::vector<float> &mask, size_t seq_len) {
    const size_t attend_len = seq_len * 3 / 4;
    const float neg_inf = -1 * std::numeric_limits<float>::infinity();
    for (size_t idx = 0; idx < mask.size(); ++idx)
        mask[idx] = (idx % seq_len < attend_len) ? 0.f : neg_inf;
}
79+
80+
// Map a logical tensor data type to a printable name; unrecognized types
// yield "unknown".
const char *get_type_string(logical_tensor::data_type dt) {
    switch (dt) {
        case logical_tensor::data_type::f16: return "f16";
        case logical_tensor::data_type::f32: return "f32";
        case logical_tensor::data_type::bf16: return "bf16";
        case logical_tensor::data_type::u8: return "u8";
        case logical_tensor::data_type::s8: return "s8";
        default: return "unknown";
    }
}
94+
95+
// Return the size in bytes of one element of the given data type.
// Covers the types this example actually uses: f32, bf16, f16, u8, and s8.
// (The original version asserted on u8/s8 even though this is an int8
// example; the 1-byte cases are handled here.)
size_t size_of(logical_tensor::data_type dt) {
    switch (dt) {
        case logical_tensor::data_type::f32: return 4;
        case logical_tensor::data_type::bf16:
        case logical_tensor::data_type::f16: return 2;
        case logical_tensor::data_type::u8:
        case logical_tensor::data_type::s8: return 1;
        default: assert(!"unknown data_type");
    }

    return (size_t)-1; /* not supposed to be reachable */
}
106+
107+
// Print the data type and the problem dimensions of one benchmark case,
// flushing so the label appears before the (possibly slow) run starts.
void print_test_case(logical_tensor::data_type dt, const sdpa_dims_t &p) {
    std::cout << '[' << std::setw(4) << get_type_string(dt) << " mb = " << p.mb
              << ", seq_len = " << p.seq_len << ", head_num = " << p.head_num
              << ", head_size = " << p.head_size << "] " << std::flush;
}
114+
115+
// Build, compile, and execute an int8 (u8) SDPA pattern:
//     out = quant(softmax(dequant(Q) x dequant(K)^T / scale + mask) x dequant(V))
// Every Quantize/Dequantize in the graph uses per-tensor scale 0.25 and zero
// point 128. With time_limit == 0 the pattern is executed once (quick test);
// otherwise it is re-executed for roughly time_limit milliseconds and the
// average execution time is printed.
void bench_int8_sdpa(
        engine::kind ekind, const sdpa_dims_t &p, double time_limit = 0.) {
    const bool quick_test = (time_limit == 0.);
    print_test_case(data_type::u8, p);

    allocator alloc = create_allocator(ekind);

    // Create execution dnnl::engine.
    dnnl::engine eng = make_engine_with_allocator(ekind, 0, alloc);
    // Create dnnl::stream.
    dnnl::stream strm(eng);

    // Prepare input and output shapes to construct the sdpa graph.
    const dims qkv_sz = {p.mb, p.head_num, p.seq_len, p.head_size};
    const dims score_sz = {p.mb, p.head_num, p.seq_len, p.seq_len};
    const dims scale_sz = {1};
    // The mask has size 1 on the head and query dims, so it broadcasts across
    // all heads and query positions.
    const dims mask_sz = {p.mb, 1, 1, p.seq_len};

    // Incremental IDs used to create logical tensors and operations. Each
    // logical tensor and op must get a unique ID within the graph.
    size_t id = 0;

    // insert the dequant for u8 query to f32 query
    auto q_u8
            = logical_tensor(id++, data_type::u8, qkv_sz, layout_type::strided);
    auto q_f32 = logical_tensor(
            id++, data_type::f32, qkv_sz, layout_type::strided);
    auto q_deq = op(id++, op::kind::Dequantize, "q_deq");
    q_deq.set_attr<std::string>(op::attr::qtype, "per_tensor");
    q_deq.set_attr<float>(op::attr::scales, 0.25f);
    q_deq.set_attr<int64_t>(op::attr::zps, 128);
    q_deq.add_input(q_u8);
    q_deq.add_output(q_f32);

    // insert the dequant for u8 key to f32 key
    auto k_u8
            = logical_tensor(id++, data_type::u8, qkv_sz, layout_type::strided);
    auto k_f32 = logical_tensor(
            id++, data_type::f32, qkv_sz, layout_type::strided);
    auto k_deq = op(id++, op::kind::Dequantize, "k_deq");
    k_deq.set_attr<std::string>(op::attr::qtype, "per_tensor");
    k_deq.set_attr<float>(op::attr::scales, 0.25f);
    k_deq.set_attr<int64_t>(op::attr::zps, 128);
    k_deq.add_input(k_u8);
    k_deq.add_output(k_f32);

    // score = query x key.T.
    auto score = logical_tensor(
            id++, data_type::f32, score_sz, layout_type::strided);
    auto bmm1 = op(id++, op::kind::MatMul, "bmm1");
    // transpose_b makes MatMul contract over head_size: Q x K^T.
    bmm1.set_attr<bool>(op::attr::transpose_b, true);
    bmm1.add_inputs({q_f32, k_f32});
    bmm1.add_output(score);

    // scaled_score = score / scale
    auto scale = logical_tensor(
            id++, data_type::f32, scale_sz, layout_type::strided);
    auto scaled_score = logical_tensor(
            id++, data_type::f32, score_sz, layout_type::strided);
    auto scale_div = op(id++, op::kind::Divide, "scale_div");
    scale_div.add_inputs({score, scale});
    scale_div.add_outputs({scaled_score});

    // masked_score = scaled_score + mask
    auto mask = logical_tensor(
            id++, data_type::f32, mask_sz, layout_type::strided);
    auto masked_score = logical_tensor(
            id++, data_type::f32, score_sz, layout_type::strided);
    auto mask_add = op(id++, op::kind::Add, "mask_add");
    mask_add.add_inputs({scaled_score, mask});
    mask_add.add_outputs({masked_score});

    // attention_probs = softmax(masked_score)
    auto probs = logical_tensor(
            id++, data_type::f32, score_sz, layout_type::strided);
    auto softmax = op(id++, op::kind::SoftMax, "softmax");
    // Softmax over the last axis (the key positions of each row).
    softmax.set_attr<int64_t>(op::attr::axis, -1);
    softmax.add_inputs({masked_score});
    softmax.add_outputs({probs});

    // quantize the probs from f32 to u8
    auto probs_u8 = logical_tensor(
            id++, data_type::u8, score_sz, layout_type::strided);
    auto p_quant = op(id++, op::kind::Quantize, "p_quant");
    p_quant.set_attr<std::string>(op::attr::qtype, "per_tensor");
    p_quant.set_attr<float>(op::attr::scales, 0.25f);
    p_quant.set_attr<int64_t>(op::attr::zps, 128);
    p_quant.add_input(probs);
    p_quant.add_output(probs_u8);

    // dequant the probs from u8 to f32
    auto probs_f32 = logical_tensor(
            id++, data_type::f32, score_sz, layout_type::strided);
    auto p_deq = op(id++, op::kind::Dequantize, "p_deq");
    p_deq.set_attr<std::string>(op::attr::qtype, "per_tensor");
    p_deq.set_attr<float>(op::attr::scales, 0.25f);
    p_deq.set_attr<int64_t>(op::attr::zps, 128);
    p_deq.add_input(probs_u8);
    p_deq.add_output(probs_f32);

    // dequant the value from u8 to f32
    auto v_u8
            = logical_tensor(id++, data_type::u8, qkv_sz, layout_type::strided);
    auto v_f32 = logical_tensor(
            id++, data_type::f32, qkv_sz, layout_type::strided);
    auto v_deq = op(id++, op::kind::Dequantize, "v_deq");
    v_deq.set_attr<std::string>(op::attr::qtype, "per_tensor");
    v_deq.set_attr<float>(op::attr::scales, 0.25f);
    v_deq.set_attr<int64_t>(op::attr::zps, 128);
    v_deq.add_input(v_u8);
    v_deq.add_output(v_f32);

    // attention_output = attention_probs x value.
    auto output = logical_tensor(
            id++, data_type::f32, qkv_sz, layout_type::strided);
    auto bmm2 = op(id++, op::kind::MatMul, "bmm2");
    bmm2.add_inputs({probs_f32, v_f32});
    bmm2.add_outputs({output});

    // quantize the output from f32 to u8
    auto output_u8
            = logical_tensor(id++, data_type::u8, qkv_sz, layout_type::strided);
    auto o_quant = op(id++, op::kind::Quantize, "o_quant");
    o_quant.set_attr<std::string>(op::attr::qtype, "per_tensor");
    o_quant.set_attr<float>(op::attr::scales, 0.25f);
    o_quant.set_attr<int64_t>(op::attr::zps, 128);
    o_quant.add_input(output);
    o_quant.add_output(output_u8);

    // Construct a sdpa graph with engine kind and operations.
    dnnl::graph::graph sdpa(ekind);
    sdpa.add_op(q_deq);
    sdpa.add_op(k_deq);
    sdpa.add_op(bmm1);
    sdpa.add_op(scale_div);
    sdpa.add_op(mask_add);
    sdpa.add_op(softmax);
    sdpa.add_op(p_quant);
    sdpa.add_op(p_deq);
    sdpa.add_op(v_deq);
    sdpa.add_op(bmm2);
    sdpa.add_op(o_quant);
    sdpa.finalize();

    // Get partitions from the sdpa graph.
    std::vector<partition> partitions = sdpa.get_partitions();
    // This is just for oneDNN testing purpose: the whole pattern is expected
    // to fuse into a single partition; bail out otherwise.
    if (partitions.size() != 1) {
        std::cout << "unsupported sdpa" << std::endl;
        return;
    }

    // Compile the partition with inputs, outputs, and an engine.
    compiled_partition cp = partitions[0].compile(
            {q_u8, k_u8, scale, mask, v_u8}, {output_u8}, eng);

    // Create tensor objects
    auto ts_query = tensor(q_u8, eng);
    auto ts_key = tensor(k_u8, eng);
    auto ts_scale = tensor(scale, eng);
    auto ts_mask = tensor(mask, eng);
    auto ts_value = tensor(v_u8, eng);
    auto ts_output = tensor(output_u8, eng);

    // Allocate user data.
    std::vector<float> query_data(product(qkv_sz));
    std::vector<float> key_data(product(qkv_sz));
    // Scale by sqrt(head_size); the graph divides the score by this value.
    std::vector<float> scale_data(product(scale_sz), std::sqrt(p.head_size));
    std::vector<float> mask_data(product(mask_sz));
    std::vector<float> value_data(product(qkv_sz));
    // NOTE(review): output_data is allocated but the result is never read
    // back in this benchmark.
    std::vector<float> output_data(product(qkv_sz));

    fill_random(query_data);
    fill_random(key_data);
    fill_random(value_data);
    fill_mask(mask_data, static_cast<size_t>(p.seq_len));

    // Write data to tensor object's handle.
    write_to_dnnl_tensor(query_data.data(), ts_query);
    write_to_dnnl_tensor(key_data.data(), ts_key);
    write_to_dnnl_tensor(scale_data.data(), ts_scale);
    write_to_dnnl_tensor(mask_data.data(), ts_mask);
    write_to_dnnl_tensor(value_data.data(), ts_value);

    // Warmup run.
    // Execute the compiled partition of sdpa.
    cp.execute(
            strm, {ts_query, ts_key, ts_scale, ts_mask, ts_value}, {ts_output});

    // Wait for the computation to finish.
    strm.wait();

    // First run: timed separately, and its duration also sizes the timing
    // loop below.
    auto start_first = std::chrono::steady_clock::now();
    cp.execute(
            strm, {ts_query, ts_key, ts_scale, ts_mask, ts_value}, {ts_output});
    strm.wait();
    auto end_first = std::chrono::steady_clock::now();
    std::chrono::duration<double, std::milli> dur_first
            = end_first - start_first;

    if (quick_test) return;

    // Timing runs: enough iterations to fill roughly time_limit ms, but at
    // least min_runs. The loop performs runs + 1 executions; the extra one is
    // compensated for below using dur_first as its estimated cost.
    const int runs = std::max(min_runs, int(time_limit / dur_first.count()));
    auto start = std::chrono::steady_clock::now();
    for (int i = 0; i <= runs; i++)
        cp.execute(strm, {ts_query, ts_key, ts_scale, ts_mask, ts_value},
                {ts_output});
    strm.wait();
    auto end = std::chrono::steady_clock::now();
    std::chrono::duration<double, std::milli> duration = end - start;

    // Display the results.
    double avg_time = (duration.count() - dur_first.count()) / runs;
    std::cout << "graph runs: " << runs + 1 << "; ";
    std::cout << "avg_time: " << avg_time << " ms" << std::endl;
}
332+
333+
// Print the expected command line to stderr and abort argument parsing by
// throwing; the exception is reported by the example's error handler.
void bad_args() {
    const char *usage
            = "Usage: graph-int8-sdpa-cpp [cpu|gpu]\n"
              "       graph-int8-sdpa-cpp [cpu|gpu] <mb> <seq_len> "
              "<head_num> <head_size>\n\n";
    std::cerr << usage;
    throw std::invalid_argument("Incorrect input arguments.");
}
339+
340+
// Run one benchmark case, reporting (rather than failing on) configurations
// the library does not implement; any other dnnl::error propagates.
void bench(engine::kind ekind, const sdpa_dims_t &p, double time_limit = 0.) {
    try {
        bench_int8_sdpa(ekind, p, time_limit);
        // Presumably releases memory cached by the example allocator between
        // cases — see graph_example_utils.hpp.
        get_mem_pool().clear();
    } catch (dnnl::error &e) {
        // Catch and report unimplemented cases.
        if (e.status == dnnl_unimplemented) {
            std::cout << "unsupported sdpa: " << std::endl;
        } else
            throw;
    }
}
352+
353+
// Parse the optional command-line dimensions and run the benchmark with a
// 2-second time budget. Usage:
//     <binary> [cpu|gpu]                                  (default shape)
//     <binary> [cpu|gpu] <mb> <seq_len> <head_num> <head_size>
void sdpa_perf(engine::kind ekind, int argc, char **argv) {
    // default testing parameters
    sdpa_dims_t params = {32, 384, 16, 64};

    if (argc > 2) {
        // Either all four dimensions are given, or the command line is wrong.
        if (argc != 6) bad_args();

        params.mb = std::atoi(argv[2]);
        params.seq_len = std::atoi(argv[3]);
        params.head_num = std::atoi(argv[4]);
        params.head_size = std::atoi(argv[5]);

        const bool valid = params.mb > 0 && params.seq_len > 0
                && params.head_num > 0 && params.head_size > 0;
        if (!valid) bad_args();
    }

    bench(ekind, params, 2000.0 /*ms*/);
}
375+
376+
int main(int argc, char **argv) {
    // handle_example_errors wraps sdpa_perf with the examples' common error
    // reporting; both helpers come from graph_example_utils.hpp.
    // NOTE(review): the trailing 4 passed to parse_engine_kind presumably
    // bounds the extra argument count — confirm against the helper.
    return handle_example_errors(
            sdpa_perf, parse_engine_kind(argc, argv, 4), argc, argv);
}

0 commit comments

Comments
 (0)