Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions onnxruntime/contrib_ops/webgpu/bert/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
return context.RunProgram(program);
};

void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k, bool is_first_prompt) {
void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k) {
if (seqlen_k != nullptr) {
ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n";
ss << "var past_sequence_length: u32 = " << (is_first_prompt ? "0" : "total_sequence_length - sequence_length") << ";\n";
ss << "var past_sequence_length: u32 = select(total_sequence_length - sequence_length, 0u, uniforms.is_first_prompt > 0);\n";
} else {
ss << "let past_sequence_length = uniforms.past_sequence_length;\n";
}
Expand Down Expand Up @@ -106,7 +106,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "let sequence_length = uniforms.M;\n"
<< "var total_sequence_length = uniforms.N;\n";
std::ostringstream oss;
InitVarStub(oss, seqlen_k_, is_first_prompt_);
InitVarStub(oss, seqlen_k_);
shader.MainFunctionBody() << oss.str();
shader.MainFunctionBody() << "let kOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.kv_sequence_length * uniforms.K;\n";
if (has_present_key_) {
Expand All @@ -121,7 +121,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
" if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n"
" var idx = TILE_SIZE * local_id.y + local_id.x;\n";

if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) {
if ((feed_past_key_ && has_present_key_) || (past_present_share_buffer_ && !is_first_prompt_)) {
shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n"
<< " let pastKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.past_sequence_length * uniforms.K;\n"
<< " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
Expand Down Expand Up @@ -213,7 +213,8 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
{static_cast<uint32_t>(past_sequence_length)},
{static_cast<uint32_t>(parameters.kv_sequence_length_)},
{static_cast<uint32_t>(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
{static_cast<uint32_t>(parameters.n_reps)}})
{static_cast<uint32_t>(parameters.n_reps)},
{static_cast<uint32_t>(parameters.is_first_prompt_ ? 1 : 0)}})
.SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});

return context.RunProgram(program);
Expand All @@ -231,7 +232,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "let sequence_length = uniforms.sequence_length;\n"
<< "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n";
std::ostringstream oss;
InitVarStub(oss, seqlen_k_, is_first_prompt_);
InitVarStub(oss, seqlen_k_);
shader.MainFunctionBody() << oss.str()
<< "let local_offset = local_idx * uniforms.elements_per_thread;\n"
<< "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n"
Expand Down Expand Up @@ -285,20 +286,21 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
}
const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1) / work_group_size;

InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, is_first_prompt, seqlen_k};
InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, seqlen_k};
if (seqlen_k != nullptr) {
program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
}
program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
.CacheHint(work_group_size, is_first_prompt)
.CacheHint(work_group_size)
.SetDispatchGroupSize(1, sequence_length, batch_size * num_heads)
.SetWorkgroupSize(work_group_size)
.AddUniformVariables({{static_cast<uint32_t>(batch_size)},
{static_cast<uint32_t>(num_heads)},
{static_cast<uint32_t>(past_sequence_length)},
{static_cast<uint32_t>(sequence_length)},
{static_cast<uint32_t>(total_sequence_length_comp)},
{static_cast<uint32_t>(elementsPerThread)}});
{static_cast<uint32_t>(elementsPerThread)},
{static_cast<uint32_t>(is_first_prompt ? 1 : 0)}});

return context.RunProgram(program);
}
Expand Down Expand Up @@ -327,7 +329,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "let sequence_length = uniforms.M;\n"
<< "var total_sequence_length = uniforms.K;\n";
std::ostringstream oss;
InitVarStub(oss, seqlen_k_, is_first_prompt_);
InitVarStub(oss, seqlen_k_);
shader.MainFunctionBody() << oss.str();
shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.kv_sequence_length + n;\n";
if (has_present_value_) {
Expand All @@ -342,12 +344,12 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n"
<< " var idx = TILE_SIZE * local_id.y + local_id.x;\n";

if ((feed_past_value_ && has_present_value_) || past_present_share_buffer_) {
if ((feed_past_value_ && has_present_value_) || (past_present_share_buffer_ && !is_first_prompt_)) {
shader.MainFunctionBody() << " if (w + local_id.y < past_sequence_length) {\n"
<< " let pastValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.past_sequence_length + n;\n"
<< " tileK[idx] = " << (past_present_share_buffer_ ? "present_value" : "past_value") << "[pastValueOffset + (w + local_id.y) * uniforms.N];\n"
<< " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
<< " tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];\n"
<< " tileK[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N];\n"
<< " }\n";
} else {
shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length) {\n"
Expand Down Expand Up @@ -425,7 +427,8 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
{static_cast<uint32_t>(past_sequence_length)},
{static_cast<uint32_t>(parameters.kv_sequence_length_)},
{static_cast<uint32_t>(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
{static_cast<uint32_t>(parameters.n_reps)}})
{static_cast<uint32_t>(parameters.n_reps)},
{static_cast<uint32_t>(parameters.is_first_prompt_)}})
.SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});

return context.RunProgram(program);
Expand Down
14 changes: 8 additions & 6 deletions onnxruntime/contrib_ops/webgpu/bert/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32});
{"n_reps", ProgramUniformVariableDataType::Uint32},
{"is_first_prompt", ProgramUniformVariableDataType::Uint32});

WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});

Expand All @@ -67,8 +68,8 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {

class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
public:
InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, bool is_first_prompt, const Tensor* seqlen_k = nullptr)
: Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k), is_first_prompt_(is_first_prompt) {
InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, const Tensor* seqlen_k = nullptr)
: Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
Expand All @@ -78,13 +79,13 @@ class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
{"sequence_length", ProgramUniformVariableDataType::Uint32},
{"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32},
{"elements_per_thread", ProgramUniformVariableDataType::Uint32});
{"elements_per_thread", ProgramUniformVariableDataType::Uint32},
{"is_first_prompt", ProgramUniformVariableDataType::Uint32});

private:
int work_group_size_;
int components_;
const Tensor* seqlen_k_;
bool is_first_prompt_;
};

class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
Expand All @@ -104,7 +105,8 @@ class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32});
{"n_reps", ProgramUniformVariableDataType::Uint32},
{"is_first_prompt", ProgramUniformVariableDataType::Uint32});

WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});

Expand Down
Loading