Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 58 additions & 11 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,23 +322,24 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte
updateOutputShape(ctx, 2, present_shape);
}
} else if (use_max_past_present_buffer == -1) {
// shape of present key/value is (batch_size, kv_num_heads, present_sequence_length, head_size)
ONNX_NAMESPACE::TensorShapeProto present_shape;
*present_shape.add_dim() = past_dims[0]; // batch_size
*present_shape.add_dim() = past_dims[1]; // kv_num_heads
if (total_sequence_length_value > 0 && past_dims[2].has_dim_value()) {
// present_sequence_length = max(past_sequence_length, total_sequence_length)
const int64_t present_sequence_length = total_sequence_length_value > past_dims[2].dim_value()
? total_sequence_length_value
: past_dims[2].dim_value();

ONNX_NAMESPACE::TensorShapeProto present_shape;
for (auto& dim : past_dims) {
*present_shape.add_dim() = dim;
}

// shape of present key/value is (batch_size, kv_num_heads, present_sequence_length, head_size)
present_shape.mutable_dim(2)->set_dim_value(present_sequence_length);

updateOutputShape(ctx, 1, present_shape);
updateOutputShape(ctx, 2, present_shape);
present_shape.add_dim()->set_dim_value(present_sequence_length);
} else {
// Cannot compute exact present_sequence_length, copy from past_key (may be dynamic)
*present_shape.add_dim() = past_dims[2];
}
*present_shape.add_dim() = past_dims[3]; // head_size

updateOutputShape(ctx, 1, present_shape);
updateOutputShape(ctx, 2, present_shape);
}

if (output_qk_index >= 0) {
Expand Down Expand Up @@ -370,6 +371,52 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte
}
}
}
} else if (hasInputShape(ctx, 0)) {
// Handle the case when past_key/past_value is not provided (first token/prefill mode).
// We still need to infer present_key/present_value output shapes from query and attributes.
auto& query_shape = getInputShape(ctx, 0);
auto& query_dims = query_shape.dim();

int64_t num_heads = getAttribute(ctx, "num_heads", 0);
int64_t kv_num_heads = getAttribute(ctx, "kv_num_heads", 0);

// Infer the shapes of outputs 1 and 2 from the query shape and the head-count attributes.
// Runs only when both attributes are positive and query is rank 3 with a known last dimension;
// otherwise outputs 1 and 2 are left untouched by this block.
if (num_heads > 0 && kv_num_heads > 0 && query_dims.size() == 3 && query_dims[2].has_dim_value()) {
int64_t hidden_size = query_dims[2].dim_value();
int64_t head_size = 0;

// The presence of a shape on input 2 selects which layout formula is used for head_size.
// NOTE(review): an input 2 that is supplied but whose shape is not yet known also takes the
// packed-QKV branch below -- confirm that this is intended.
if (hasInputShape(ctx, 2)) {
// query shape is (batch_size, sequence_length, num_heads * head_size)
// NOTE(review): integer division truncates; a hidden_size that is not a multiple of the divisor
// is not detected here (same for the packed branch) -- confirm whether it is validated elsewhere.
head_size = hidden_size / num_heads;
} else {
// Packed QKV: query shape is (batch_size, sequence_length, (num_heads + 2 * kv_num_heads) * head_size)
head_size = hidden_size / (num_heads + 2 * kv_num_heads);
}

// A head_size of zero (hidden_size smaller than the divisor) means nothing can be inferred,
// so no output shape is set in that case.
if (head_size > 0) {
// Determine present_sequence_length from total_sequence_length or kv_sequence_length
// Both variables are declared outside this block; values <= 0 are treated as "unknown" and
// total_sequence_length_value takes precedence when both are positive.
int64_t present_sequence_length = 0;
if (total_sequence_length_value > 0) {
present_sequence_length = total_sequence_length_value;
} else if (kv_sequence_length > 0) {
present_sequence_length = kv_sequence_length;
}

// present key/value shape is (batch_size, kv_num_heads, present_sequence_length, head_size)
ONNX_NAMESPACE::TensorShapeProto present_shape;
// The batch dimension is copied verbatim from the query, whether or not it carries a concrete value.
*present_shape.add_dim() = query_dims[0];  // batch_size
present_shape.add_dim()->set_dim_value(kv_num_heads);
if (present_sequence_length > 0) {
present_shape.add_dim()->set_dim_value(present_sequence_length);
} else {
// Fallback: use query sequence_length (dim 1) as present_sequence_length for prefill
*present_shape.add_dim() = query_dims[1];
}
present_shape.add_dim()->set_dim_value(head_size);

// The same shape is applied to both output 1 and output 2.
updateOutputShape(ctx, 1, present_shape);
updateOutputShape(ctx, 2, present_shape);
}
}
}
}
}
Expand Down
Loading