/*******************************************************************************
* Copyright 2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <limits>
#include <memory>
#include <random>
#include <string>
#include <vector>

#include "oneapi/dnnl/dnnl.hpp"
#include "oneapi/dnnl/dnnl_graph.hpp"

#include "graph_example_utils.hpp"

using namespace dnnl;
using namespace dnnl::graph;

using layout_type = logical_tensor::layout_type;
using data_type = logical_tensor::data_type;
using dim = logical_tensor::dim;
using dims = logical_tensor::dims;

struct sdpa_dims_t {
    dim mb; // batch size
    dim seq_len; // sequence length
    dim head_num; // number of attention heads
    dim head_size; // size of each attention head
};

static const int min_runs = 4;

// Adapted from the fill_random() function in matmul_perf.cpp: fill the output
// buffer by tiling a small cached block of random values instead of drawing a
// fresh random value per element.
void fill_random(std::vector<float> &out) {
    static std::vector<float> random_data_f;
    constexpr size_t nrand = 1037;

    if (random_data_f.empty()) {
        std::mt19937 generator;
        std::uniform_real_distribution<float> dist_f(-1.0f, 1.0f);

        random_data_f.resize(nrand);
        for (auto &d : random_data_f)
            d = dist_f(generator);
    }

    for (size_t i = 0; i < out.size(); i += nrand) {
        size_t chunk = std::min(nrand, out.size() - i);
        std::memcpy(&out[i], random_data_f.data(), chunk * sizeof(float));
    }
}

// Initialize the mask: within each row of seq_len elements, the first 3/4 are
// 0 (attended positions) and the last 1/4 are -inf (masked out).
void fill_mask(std::vector<float> &mask, size_t seq_len) {
    const size_t pos = seq_len * 3 / 4;
    for (size_t i = 0; i < mask.size(); ++i) {
        if (i % seq_len < pos)
            mask[i] = 0.f;
        else
            mask[i] = -1 * std::numeric_limits<float>::infinity();
    }
}

const char *get_type_string(logical_tensor::data_type dt) {
    const char *type_string = "unknown";

#define TYPE_CASE(T) \
    if (dt == logical_tensor::data_type::T) type_string = #T;
    TYPE_CASE(f16);
    TYPE_CASE(f32);
    TYPE_CASE(bf16);
    TYPE_CASE(u8);
    TYPE_CASE(s8);
#undef TYPE_CASE

    return type_string;
}

size_t size_of(logical_tensor::data_type dt) {
    // This helper only needs to handle the f32, bf16, and f16 types used in
    // these examples.
    switch (dt) {
        case logical_tensor::data_type::f32: return 4;
        case logical_tensor::data_type::bf16:
        case logical_tensor::data_type::f16: return 2;
        default: assert(!"unknown data_type");
    }

    return (size_t)-1; /* not supposed to be reachable */
}

void print_test_case(logical_tensor::data_type dt, const sdpa_dims_t &p) {
    std::cout << '[' << std::setw(4) << get_type_string(dt);
    std::cout << " mb = " << p.mb << ", seq_len = " << p.seq_len
              << ", head_num = " << p.head_num
              << ", head_size = " << p.head_size;
    std::cout << "] " << std::flush;
}

void bench_int8_sdpa(
        engine::kind ekind, const sdpa_dims_t &p, double time_limit = 0.) {
    const bool quick_test = (time_limit == 0.);
    print_test_case(data_type::u8, p);

    allocator alloc = create_allocator(ekind);

    // Create execution dnnl::engine.
    dnnl::engine eng = make_engine_with_allocator(ekind, 0, alloc);
    // Create dnnl::stream.
    dnnl::stream strm(eng);
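
    // create_allocator() is an example helper (see graph_example_utils.hpp);
    // attaching the allocator to the engine lets the library allocate any
    // internal memory the compiled partition needs at execution time.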

    // Prepare input and output shapes to construct the sdpa graph.
    const dims qkv_sz = {p.mb, p.head_num, p.seq_len, p.head_size};
    const dims score_sz = {p.mb, p.head_num, p.seq_len, p.seq_len};
    const dims scale_sz = {1};
    const dims mask_sz = {p.mb, 1, 1, p.seq_len};

    // Incremental IDs used to create logical tensors and operations.
    size_t id = 0;

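    // The graph below implements the int8 SDPA pattern:
    //   Dequantize(Q), Dequantize(K) -> MatMul -> Divide -> Add -> SoftMax
    //   -> Quantize -> Dequantize -+
    //   Dequantize(V) -------------+-> MatMul -> Quantize
    // Every Quantize/Dequantize op uses per-tensor parameters with
    // scale = 0.25 and zero-point = 128, i.e. f32 = 0.25 * (u8 - 128), which
    // maps the u8 range [0, 255] onto f32 [-32, 31.75]. These values are
    // placeholders for benchmarking; a real workload would use calibrated
    // scales and zero-points.
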
    // Dequantize the u8 query to f32.
    auto q_u8
            = logical_tensor(id++, data_type::u8, qkv_sz, layout_type::strided);
    auto q_f32 = logical_tensor(
            id++, data_type::f32, qkv_sz, layout_type::strided);
    auto q_deq = op(id++, op::kind::Dequantize, "q_deq");
    q_deq.set_attr<std::string>(op::attr::qtype, "per_tensor");
    q_deq.set_attr<std::vector<float>>(op::attr::scales, {0.25f});
    q_deq.set_attr<std::vector<int64_t>>(op::attr::zps, {128});
    q_deq.add_input(q_u8);
    q_deq.add_output(q_f32);

    // Dequantize the u8 key to f32.
    auto k_u8
            = logical_tensor(id++, data_type::u8, qkv_sz, layout_type::strided);
    auto k_f32 = logical_tensor(
            id++, data_type::f32, qkv_sz, layout_type::strided);
    auto k_deq = op(id++, op::kind::Dequantize, "k_deq");
    k_deq.set_attr<std::string>(op::attr::qtype, "per_tensor");
    k_deq.set_attr<std::vector<float>>(op::attr::scales, {0.25f});
    k_deq.set_attr<std::vector<int64_t>>(op::attr::zps, {128});
    k_deq.add_input(k_u8);
    k_deq.add_output(k_f32);

    // score = query x key.T. The transpose_b attribute folds the key
    // transpose into the matmul.
    auto score = logical_tensor(
            id++, data_type::f32, score_sz, layout_type::strided);
    auto bmm1 = op(id++, op::kind::MatMul, "bmm1");
    bmm1.set_attr<bool>(op::attr::transpose_b, true);
    bmm1.add_inputs({q_f32, k_f32});
    bmm1.add_output(score);

    // scaled_score = score / scale
    auto scale = logical_tensor(
            id++, data_type::f32, scale_sz, layout_type::strided);
    auto scaled_score = logical_tensor(
            id++, data_type::f32, score_sz, layout_type::strided);
    auto scale_div = op(id++, op::kind::Divide, "scale_div");
    scale_div.add_inputs({score, scale});
    scale_div.add_outputs({scaled_score});

    // masked_score = scaled_score + mask
    auto mask = logical_tensor(
            id++, data_type::f32, mask_sz, layout_type::strided);
    auto masked_score = logical_tensor(
            id++, data_type::f32, score_sz, layout_type::strided);
    auto mask_add = op(id++, op::kind::Add, "mask_add");
    mask_add.add_inputs({scaled_score, mask});
    mask_add.add_outputs({masked_score});

    // attention_probs = softmax(masked_score), applied along the last axis.
    auto probs = logical_tensor(
            id++, data_type::f32, score_sz, layout_type::strided);
    auto softmax = op(id++, op::kind::SoftMax, "softmax");
    softmax.set_attr<int64_t>(op::attr::axis, -1);
    softmax.add_inputs({masked_score});
    softmax.add_outputs({probs});

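    // The Quantize/Dequantize pair below models storing the attention
    // probabilities in u8 between the two matmuls. A backend that supports
    // the fused int8 SDPA pattern can fold this pair away instead of
    // materializing the intermediate u8 tensor.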
    // Quantize the probs from f32 to u8.
    auto probs_u8 = logical_tensor(
            id++, data_type::u8, score_sz, layout_type::strided);
    auto p_quant = op(id++, op::kind::Quantize, "p_quant");
    p_quant.set_attr<std::string>(op::attr::qtype, "per_tensor");
    p_quant.set_attr<std::vector<float>>(op::attr::scales, {0.25f});
    p_quant.set_attr<std::vector<int64_t>>(op::attr::zps, {128});
    p_quant.add_input(probs);
    p_quant.add_output(probs_u8);

    // Dequantize the probs from u8 back to f32.
    auto probs_f32 = logical_tensor(
            id++, data_type::f32, score_sz, layout_type::strided);
    auto p_deq = op(id++, op::kind::Dequantize, "p_deq");
    p_deq.set_attr<std::string>(op::attr::qtype, "per_tensor");
    p_deq.set_attr<std::vector<float>>(op::attr::scales, {0.25f});
    p_deq.set_attr<std::vector<int64_t>>(op::attr::zps, {128});
    p_deq.add_input(probs_u8);
    p_deq.add_output(probs_f32);

    // Dequantize the value from u8 to f32.
    auto v_u8
            = logical_tensor(id++, data_type::u8, qkv_sz, layout_type::strided);
    auto v_f32 = logical_tensor(
            id++, data_type::f32, qkv_sz, layout_type::strided);
    auto v_deq = op(id++, op::kind::Dequantize, "v_deq");
    v_deq.set_attr<std::string>(op::attr::qtype, "per_tensor");
    v_deq.set_attr<std::vector<float>>(op::attr::scales, {0.25f});
    v_deq.set_attr<std::vector<int64_t>>(op::attr::zps, {128});
    v_deq.add_input(v_u8);
    v_deq.add_output(v_f32);

    // attention_output = attention_probs x value.
    auto output = logical_tensor(
            id++, data_type::f32, qkv_sz, layout_type::strided);
    auto bmm2 = op(id++, op::kind::MatMul, "bmm2");
    bmm2.add_inputs({probs_f32, v_f32});
    bmm2.add_outputs({output});

    // Quantize the output from f32 to u8.
    auto output_u8
            = logical_tensor(id++, data_type::u8, qkv_sz, layout_type::strided);
    auto o_quant = op(id++, op::kind::Quantize, "o_quant");
    o_quant.set_attr<std::string>(op::attr::qtype, "per_tensor");
    o_quant.set_attr<std::vector<float>>(op::attr::scales, {0.25f});
    o_quant.set_attr<std::vector<int64_t>>(op::attr::zps, {128});
    o_quant.add_input(output);
    o_quant.add_output(output_u8);

    // Construct an SDPA graph with the engine kind and operations.
    dnnl::graph::graph sdpa(ekind);
    sdpa.add_op(q_deq);
    sdpa.add_op(k_deq);
    sdpa.add_op(bmm1);
    sdpa.add_op(scale_div);
    sdpa.add_op(mask_add);
    sdpa.add_op(softmax);
    sdpa.add_op(p_quant);
    sdpa.add_op(p_deq);
    sdpa.add_op(v_deq);
    sdpa.add_op(bmm2);
    sdpa.add_op(o_quant);
    sdpa.finalize();

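    // finalize() makes the graph immutable; get_partitions() then runs the
    // backend's pattern matcher over it. If the backend recognizes the whole
    // chain as its int8 SDPA pattern, all eleven ops come back as a single
    // fusible partition.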
    // Get partitions from the sdpa graph.
    std::vector<partition> partitions = sdpa.get_partitions();
    // This check is only for oneDNN testing purposes.
    if (partitions.size() != 1) {
        std::cout << "unsupported sdpa" << std::endl;
        return;
    }

    // Compile the partition with inputs, outputs, and an engine.
    compiled_partition cp = partitions[0].compile(
            {q_u8, k_u8, scale, mask, v_u8}, {output_u8}, eng);

    // Create tensor objects.
    auto ts_query = tensor(q_u8, eng);
    auto ts_key = tensor(k_u8, eng);
    auto ts_scale = tensor(scale, eng);
    auto ts_mask = tensor(mask, eng);
    auto ts_value = tensor(v_u8, eng);
    auto ts_output = tensor(output_u8, eng);

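    // Constructing a tensor from a logical tensor and an engine lets the
    // library allocate the backing buffers itself; the user data generated
    // below is copied in via write_to_dnnl_tensor().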
    // Allocate user data.
    std::vector<float> query_data(product(qkv_sz));
    std::vector<float> key_data(product(qkv_sz));
    // The graph divides the score by this scale, so use sqrt(head_size) to
    // get the usual score / sqrt(d_k) scaling.
    std::vector<float> scale_data(product(scale_sz), std::sqrt(p.head_size));
    std::vector<float> mask_data(product(mask_sz));
    std::vector<float> value_data(product(qkv_sz));
    std::vector<float> output_data(product(qkv_sz));

    fill_random(query_data);
    fill_random(key_data);
    fill_random(value_data);
    fill_mask(mask_data, static_cast<size_t>(p.seq_len));

    // Write data to the tensor objects' handles.
    write_to_dnnl_tensor(query_data.data(), ts_query);
    write_to_dnnl_tensor(key_data.data(), ts_key);
    write_to_dnnl_tensor(scale_data.data(), ts_scale);
    write_to_dnnl_tensor(mask_data.data(), ts_mask);
    write_to_dnnl_tensor(value_data.data(), ts_value);

    // Warmup run.
    // Execute the compiled partition of sdpa.
    cp.execute(
            strm, {ts_query, ts_key, ts_scale, ts_mask, ts_value}, {ts_output});

    // Wait for the computation to finish.
    strm.wait();
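
    // The warmup run above absorbs one-time costs (e.g. kernel generation)
    // so that the timed runs below measure steady-state performance.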

    // First run.
    auto start_first = std::chrono::steady_clock::now();
    cp.execute(
            strm, {ts_query, ts_key, ts_scale, ts_mask, ts_value}, {ts_output});
    strm.wait();
    auto end_first = std::chrono::steady_clock::now();
    std::chrono::duration<double, std::milli> dur_first
            = end_first - start_first;

    if (quick_test) return;

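    // Choose the iteration count so the timed loop below takes roughly
    // time_limit ms, using the first-run duration as the per-iteration
    // estimate. The loop executes runs + 1 times; when averaging, one
    // iteration's cost is approximated by dur_first and subtracted out.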
    // Timing runs.
    const int runs = std::max(min_runs, int(time_limit / dur_first.count()));
    auto start = std::chrono::steady_clock::now();
    for (int i = 0; i <= runs; i++)
        cp.execute(strm, {ts_query, ts_key, ts_scale, ts_mask, ts_value},
                {ts_output});
    strm.wait();
    auto end = std::chrono::steady_clock::now();
    std::chrono::duration<double, std::milli> duration = end - start;

    // Display the results.
    double avg_time = (duration.count() - dur_first.count()) / runs;
    std::cout << "graph runs: " << runs + 1 << "; ";
    std::cout << "avg_time: " << avg_time << " ms" << std::endl;
}

void bad_args() {
    std::cerr << "Usage: graph-int8-sdpa-cpp [cpu|gpu]\n"
                 "       graph-int8-sdpa-cpp [cpu|gpu] <mb> <seq_len> "
                 "<head_num> <head_size>\n\n";
    throw std::invalid_argument("Incorrect input arguments.");
}

void bench(engine::kind ekind, const sdpa_dims_t &p, double time_limit = 0.) {
    try {
        bench_int8_sdpa(ekind, p, time_limit);
        get_mem_pool().clear();
    } catch (dnnl::error &e) {
        // Catch and report unimplemented cases; rethrow everything else.
        if (e.status == dnnl_unimplemented) {
            std::cout << "unsupported sdpa" << std::endl;
        } else
            throw;
    }
}

void sdpa_perf(engine::kind ekind, int argc, char **argv) {
    // Default testing parameters: mb = 32, seq_len = 384, head_num = 16,
    // head_size = 64.
    sdpa_dims_t params = {32, 384, 16, 64};

    if (argc > 2) {
        if (argc == 6) {
            params.mb = std::atoi(argv[2]);
            params.seq_len = std::atoi(argv[3]);
            params.head_num = std::atoi(argv[4]);
            params.head_size = std::atoi(argv[5]);
        } else {
            bad_args();
        }

        if (params.mb <= 0 || params.seq_len <= 0 || params.head_num <= 0
                || params.head_size <= 0) {
            bad_args();
        }
    }

    bench(ekind, params, 2000.0 /*ms*/);
}

int main(int argc, char **argv) {
    // Allow up to 4 extra command-line arguments:
    // <mb> <seq_len> <head_num> <head_size>.
    return handle_example_errors(
            sdpa_perf, parse_engine_kind(argc, argv, 4), argc, argv);
}