Skip to content

[GPU] Fix accuracy issues for mvn and sdpa_micro #30698

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,21 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
#if WITH_ATTN_MASK
/* Load mask. No remainder handling needed assuming k block size is a power of 2. */
mask_tile_type mask_tile;
tile_load_t(&mask_tile, msk, q, k, sg_j0_kq + wg_j0, k0 + sg_i0_kq);

// Check if attention mask has a single Query dimension (e.g., [batch, num_heads, 1, sequence_length])
if (MSK_D2 == 1) {
// Define mask dimensions for single Query dimension
uint mask_m = MSK_D1; // num_heads
uint mask_n = MSK_D3; // sequence_length
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why num_heads? masks are usually [q, seqlen]..


tile_load_t(&mask_tile, msk, mask_m, mask_n, 0, k0 + sg_i0_kq);
} else {
// General case: attention mask matches Q*K^T shape
uint mask_m = q; // Q sequence length
uint mask_n = k; // K sequence length

tile_load_t(&mask_tile, msk, mask_m, mask_n, sg_j0_kq + wg_j0, k0 + sg_i0_kq);
}
#endif

#if REMAINDER_K
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,22 @@ JitConstants MVNKernelBfyxOpt::GetJitConstants(const mvn_params& params, MVNKern
"((in_data_set_idx + iteration_in_data_set_offset) % OUTPUT_SIZE_X)" };
}
}
auto conf = FusedOpsConfiguration("", idx_order, "result", activation_dt, 1, LoadType::LT_UNALIGNED, BoundaryCheck::DISABLED);
// Calculate total work items and maximum addressable range
size_t total_work_items = dispatchData.gws[0] * dispatchData.gws[1] * dispatchData.gws[2];
size_t max_addressable_range = dispatchData.dataSetSize * dispatchData.dataSetsCount;

// Determine if Boundary Check is needed
bool exceeds_boundary = total_work_items > max_addressable_range;

// Dynamic Shape: Always enable Boundary Check
BoundaryCheck boundary_check_mode = params.has_dynamic_tensors() || exceeds_boundary
? BoundaryCheck::ENABLED
: BoundaryCheck::DISABLED;

// Configure FusedOps with the determined BoundaryCheck mode
auto conf = FusedOpsConfiguration(
"", idx_order, "result", activation_dt, 1, LoadType::LT_UNALIGNED, boundary_check_mode);

jit.Merge(MakeFusedOpsJitConstants(params, { conf }));
}

Expand Down
69 changes: 69 additions & 0 deletions src/plugins/intel_gpu/tests/unit/test_cases/mvn_gpu_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -974,3 +974,72 @@ TEST_P(mvn_random_test_bsv32, random_cached) {
this->execute(GetParam(), true);
}
#endif

TEST(mvn_gpu_test, fusion_mul_add) {

auto run_network = [&](bool use_opt_kernel = false) -> cldnn::memory::ptr {
auto& engine = get_test_engine();
const size_t b_length = 2;
const size_t f_length = 10;
const size_t y_length = 32;
cldnn::layout input0_dyn_layout({-1,-1,y_length}, data_types::f16, format::bfyx);
cldnn::layout input1_dyn_layout({-1,-1,y_length}, data_types::f16, format::bfyx);
cldnn::layout input2_dyn_layout({-1,-1,y_length}, data_types::f16, format::bfyx);

cldnn::layout input0_static_layout({b_length, f_length, y_length}, data_types::f16, format::bfyx);
cldnn::layout input1_static_layout({b_length, 1, y_length}, data_types::f16, format::bfyx);
cldnn::layout input2_static_layout({b_length, 1, y_length}, data_types::f16, format::bfyx);

auto input0 = engine.allocate_memory(input0_static_layout);
auto input1 = engine.allocate_memory(input1_static_layout);
auto input2 = engine.allocate_memory(input2_static_layout);

std::vector<ov::float16> input0_values(input0->count(), 0.f);
std::vector<ov::float16> input1_values(input1->count(), 2.f);
std::vector<ov::float16> input2_values(input2->count(), 0.1f);
for (size_t i = 0; i < input0_values.size(); i++) {
input0_values[i] = ov::float16(i * 0.03f);
}
set_values(input0, input0_values);
set_values(input1, input1_values);
set_values(input0, input2_values);

topology topo;
topo.add(input_layout("input0", input0_dyn_layout));
topo.add(input_layout("input1", input1_dyn_layout));//Gather, ADD_1
topo.add(input_layout("input2", input2_dyn_layout));//Gather_1
topo.add(mvn("mvn", input_info("input0"), true, 1e-06f, true, {2}));
topo.add(eltwise("mul", {input_info("mvn"), input_info("input1")}, eltwise_mode::prod, {}, data_types::f16));
topo.add(eltwise("add", {input_info("mul"), input_info("input2")}, eltwise_mode::sum, {}, data_types::f16));
topo.add(reorder("result",input_info("add"), format::bfyx, data_types::f32));

ExecutionConfig config = get_test_default_config(engine);
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));

if (use_opt_kernel) {
config.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ {"mvn", {format::type::bfyx, "mvn_gpu_bfyx_opt"}} }));
} else {
config.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ {"mvn", {format::type::bfyx, "mvn_gpu_ref"}} }));
}

cldnn::network::ptr net = get_network(engine, topo, config, get_test_stream_ptr(), false);

net->set_input_data("input0", input0);
net->set_input_data("input1", input1);
net->set_input_data("input2", input2);

auto outputs = net->execute();
auto output = outputs.at("result").get_memory();
return output;
};

auto mem_ref_ptr = run_network(false);
auto mem_opt_ptr = run_network(true);
cldnn::mem_lock<float> ref_data(mem_ref_ptr, get_test_stream());
cldnn::mem_lock<float> opt_data(mem_opt_ptr, get_test_stream());

ASSERT_EQ(ref_data.size(), opt_data.size());
for (size_t i = 0; i < ref_data.size(); i++) {
ASSERT_NEAR(static_cast<float>(ref_data[i]), static_cast<float>(opt_data[i]), 1.e-1f);
}
}
185 changes: 185 additions & 0 deletions src/plugins/intel_gpu/tests/unit/test_cases/sdpa_gpu_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "test_utils.h"
#include "random_generator.hpp"

#include <intel_gpu/primitives/input_layout.hpp>
#include <intel_gpu/primitives/reorder.hpp>
#include <intel_gpu/primitives/eltwise.hpp>
#include <intel_gpu/runtime/debug_configuration.hpp>

#include "openvino/util/file_util.hpp"
#include <iostream>
#include <vector>
#include <cmath>
#include <numeric>
#include <iostream>

#include <intel_gpu/primitives/input_layout.hpp>
#include <intel_gpu/primitives/scaled_dot_product_attention.hpp>

#include <cstddef>
#include <vector>

using namespace cldnn;
using namespace ::tests;

namespace {
// #ifdef ENABLE_ONEDNN_FOR_GPU
// Disable onednn test because onednn does not support format_tag::cbda, format_tag::badc.


struct spda_gpu_test {

// const std::string opt_data_path = "/home/ahnyoung/cldnn/cvs_164660/dumps/outs/gpu.fp16.sdpa.ref.raw/";
const std::string opt_data_path = "/home/ahnyoung/cldnn/cvs_164660/dumps/outs/gpu.fp16.sdpa.micro.raw/";

void load_input(cldnn::memory::ptr mem, size_t idx) {
std::vector<std::string> bin_names = {
"program1_network1_0_sdpa___module.transformer_blocks.0.attn2_aten__scaled_dot_product_attention_ScaledDotProductAttention_dst0__f16__2_32_990_64__bfyx.bin",
"program1_network1_0_sdpa___module.transformer_blocks.0.attn2_aten__scaled_dot_product_attention_ScaledDotProductAttention_src0__f16__2_990_32_64__bfyx.bin",
"program1_network1_0_sdpa___module.transformer_blocks.0.attn2_aten__scaled_dot_product_attention_ScaledDotProductAttention_src1__f16__2_128_32_64__bfyx.bin",
"program1_network1_0_sdpa___module.transformer_blocks.0.attn2_aten__scaled_dot_product_attention_ScaledDotProductAttention_src2__f16__2_128_32_64__bfyx.bin",
"program1_network1_0_sdpa___module.transformer_blocks.0.attn2_aten__scaled_dot_product_attention_ScaledDotProductAttention_src3__f16__2_32_1_128__bfyx.bin"
};
std::string input_file_name = opt_data_path + bin_names[idx];
load_data_from_bin(mem, input_file_name);
}

void load_data_from_bin(cldnn::memory::ptr mem, const std::string filepath) {
GPU_DEBUG_COUT << "Load data from " << filepath << std::endl;
std::vector<uint8_t> bin = ov::util::load_binary(filepath);
mem->copy_from(get_test_stream(), static_cast<void *>(&bin[0]), true);
}

cldnn::memory::ptr run_network(bool is_caching_test, bool use_micro_sdpa = false) {
auto& engine = get_test_engine();
cldnn::layout input0_dyn_layout({-1, -1, 32, 64}, data_types::f16, format::bfyx);
cldnn::layout input1_dyn_layout({-1, -1, 32, 64}, data_types::f16, format::bfyx);
cldnn::layout input2_dyn_layout({-1, -1, 32, 64}, data_types::f16, format::bfyx);
cldnn::layout input3_dyn_layout({-1, 32, -1, -1}, data_types::f16, format::bfyx);

cldnn::layout input0_static_layout({2, 990, 32, 64}, data_types::f16, format::bfyx);
cldnn::layout input1_static_layout({2, 128, 32, 64}, data_types::f16, format::bfyx);
cldnn::layout input2_static_layout({2, 128, 32, 64}, data_types::f16, format::bfyx);
cldnn::layout input3_static_layout({2, 32, 1, 128}, data_types::f16, format::bfyx);

auto input0 = engine.allocate_memory(input0_static_layout);
auto input1 = engine.allocate_memory(input1_static_layout);
auto input2 = engine.allocate_memory(input2_static_layout);
auto input3 = engine.allocate_memory(input3_static_layout);

load_input(input0, 0);
load_input(input1, 1);
load_input(input2, 2);
load_input(input3, 3);

GPU_DEBUG_COUT << "Topology: SDPA kernel test " << std::endl;
GPU_DEBUG_COUT << "* use micro_sdpa : " << (use_micro_sdpa ? "Yes" : "No") << std::endl;
GPU_DEBUG_COUT << "* input0 : " << input0_static_layout.to_short_string() << ", " << input0_static_layout.count() << std::endl;
GPU_DEBUG_COUT << "* input1 : " << input1_static_layout.to_short_string() << ", " << input1_static_layout.count() << std::endl;
GPU_DEBUG_COUT << "* input2 : " << input2_static_layout.to_short_string() << ", " << input2_static_layout.count() << std::endl;
GPU_DEBUG_COUT << "* input3 : " << input3_static_layout.to_short_string() << ", " << input3_static_layout.count() << std::endl;

topology topo;
topo.add(input_layout("input0", input0_dyn_layout));
topo.add(input_layout("input1", input1_dyn_layout));
topo.add(input_layout("input2", input2_dyn_layout));
topo.add(input_layout("input3", input3_dyn_layout));
topo.add(scaled_dot_product_attention("sdpa", {input_info("input0"), input_info("input1"), input_info("input2"), input_info("input3")},
false, -1, {0,2,1,3}, {0,2,1,3}, {0,2,1,3}, {0,1,2,3}, {}, false));
topo.add(reorder("result",input_info("sdpa"), format::bfyx, data_types::f16));

ExecutionConfig config = get_test_default_config(engine);
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));

if (use_micro_sdpa) {
config.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ {"sdpa", {format::type::bfyx, "sdpa_micro"}} }));
config.set_property(ov::intel_gpu::dump_iterations(std::set<int64_t>{0, 1}));
config.set_property(ov::intel_gpu::dump_tensors("all"));
config.set_property(ov::intel_gpu::dump_tensors_path("/home/ahnyoung/cldnn/cvs_164660/dumps/outs/units/"));
} else {
config.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ {"sdpa", {format::type::bfyx, "sdpa_ref"}} }));
}

cldnn::network::ptr net = get_network(engine, topo, config, get_test_stream_ptr(), is_caching_test);

net->set_input_data("input0", input0);
net->set_input_data("input1", input1);
net->set_input_data("input2", input2);
net->set_input_data("input3", input3);

auto outputs = net->execute();
auto output = outputs.at("result").get_memory();
return output;
}

void execute(bool is_caching_test = false) {
GPU_DEBUG_COUT << "********************************************************************************" << std::endl;
GPU_DEBUG_COUT << "********************************************************************************" << std::endl;
auto mem_ref_ptr = run_network(is_caching_test, false);
GPU_DEBUG_COUT << "********************************************************************************" << std::endl;
GPU_DEBUG_COUT << "********************************************************************************" << std::endl;
auto mem_opt_ptr = run_network(is_caching_test, true);
GPU_DEBUG_COUT << "********************************************************************************" << std::endl;
GPU_DEBUG_COUT << "********************************************************************************" << std::endl;
cldnn::mem_lock<ov::float16, mem_lock_type::read> ref_data(mem_ref_ptr, get_test_stream());
cldnn::mem_lock<ov::float16, mem_lock_type::read> opt_data(mem_opt_ptr, get_test_stream());
// if (ret < 0.9f) {
{
std::vector<std::pair<size_t, ov::float16>> differences;
for (size_t idx = 0; idx < ref_data.size(); idx++) {
if (std::isnan(opt_data[idx])) {
GPU_DEBUG_COUT << "opt_data has nan " << opt_data[idx] << std::endl;
}
if (std::isnan(ref_data[idx])) {
GPU_DEBUG_COUT << "ref_data has nan " << ref_data[idx] << std::endl;
}
ASSERT_FALSE(std::isnan(opt_data[idx]));
float diff = std::abs(ref_data[idx] - opt_data[idx]);
differences.push_back({idx, diff});
}
// auto ret = cosineSimilarity(ref_data, opt_data);
// GPU_DEBUG_COUT << "Cosine Similarity : " << ret << std::endl;
std::sort(differences.begin(), differences.end(), [](std::pair<size_t, ov::float16> a, std::pair<size_t, ov::float16> b){
return a.second > b.second;
});
GPU_DEBUG_COUT << "Compare data] ref_data : act_data" << std::endl;
for (size_t i = 0; i < 10 && i < differences.size(); i++) {
size_t idx = differences[i].first;
GPU_DEBUG_COUT << std::setw(8) << std::fixed << idx << "] " << std::setw(12) << ref_data[idx] << " : "
<< std::setw(12) << opt_data[idx]
<< " (Difference: " << differences[i].second << ")" << std::endl;
}
// ASSERT_GE(ret, 0.9f);
}
}

// float cosineSimilarity(cldnn::mem_lock<ov::float16, mem_lock_type::read>& vec1, cldnn::mem_lock<ov::float16, mem_lock_type::read>& memLockVec2) {
// if (vec1.size() != memLockVec2.size()) {
// std::cerr << "Vectors must be of the same size." << std::endl;
// return -1.0f;
// }

// float dotProduct = std::inner_product(vec1.begin(), vec1.end(), memLockVec2.begin(), 0.0f);

// float magnitude1 = std::sqrt(std::inner_product(vec1.begin(), vec1.end(), vec1.begin(), 0.0f));
// float magnitude2 = std::sqrt(std::inner_product(memLockVec2.begin(), memLockVec2.end(), memLockVec2.begin(), 0.0f));

// if (magnitude1 == 0.0f || magnitude2 == 0.0f) {
// std::cerr << "One of the vectors is zero vector." << std::endl;
// return -1.0f;
// }

// return dotProduct / (magnitude1 * magnitude2);
// }
};

TEST(sdpa_gpu_test, basic) {
spda_gpu_test test;
test.execute();
}
// #endif
} // namespace
Loading