18 changes: 9 additions & 9 deletions onnxruntime/contrib_ops/cpu/bert/attention_base.h
@@ -32,22 +32,22 @@ class AttentionBase {
int& past_sequence_length) const;

protected:
AttentionBase(const OpKernelInfo& info, bool require_same_hidden_size) {
template <typename KernelInfoType>
Comment from @adrianlizarraga (Contributor), Feb 6, 2026:

We should probably have a comment or some documentation that explains the need for this templated type. Someone reading this code may find it confusing, because normally an OrtKernelInfo can be reinterpret_cast to an onnxruntime::OpKernelInfo. However, if I'm not mistaken, this templated type is meant to be either an OrtKernelInfo or an adapter::OpKernelInfo from #26919.

I'm not sure where this documentation should live, but I think it should exist somewhere.

Reply from @fs-eire (Contributor, Author), Feb 7, 2026:

KernelInfoType should be either an onnxruntime::OpKernelInfo or an onnxruntime::ep::adapter::OpKernelInfo. This is a little different from using OrtKernelInfo directly: onnxruntime::ep::adapter::OpKernelInfo is a wrapper class around OrtKernelInfo that implements a facade identical to onnxruntime::OpKernelInfo.

It can be documented in include/onnxruntime/ep/README.md.
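
To make the adapter idea in this thread concrete, here is a minimal, self-contained sketch of the facade pattern being described. Every name in it (`FakeOrtKernelInfo`, `FakeAdapterOpKernelInfo`) is a hypothetical stand-in for illustration, not the actual onnxruntime type:

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>

// Hypothetical stand-in for the opaque C-API handle; not the real OrtKernelInfo.
struct FakeOrtKernelInfo {
  std::map<std::string, int64_t> int_attrs;
};

// Wrapper exposing the same member-function surface as onnxruntime::OpKernelInfo,
// so kernel code templated on KernelInfoType compiles unchanged against either type.
class FakeAdapterOpKernelInfo {
 public:
  explicit FakeAdapterOpKernelInfo(const FakeOrtKernelInfo* handle) : handle_(handle) {}

  template <typename T>
  T GetAttrOrDefault(const std::string& name, T default_value) const {
    auto it = handle_->int_attrs.find(name);
    return it == handle_->int_attrs.end() ? default_value : static_cast<T>(it->second);
  }

 private:
  const FakeOrtKernelInfo* handle_;  // non-owning view of the C-API object
};

int main() {
  FakeOrtKernelInfo raw{{{"num_heads", 12}}};
  FakeAdapterOpKernelInfo info(&raw);
  std::cout << info.GetAttrOrDefault<int64_t>("num_heads", 1) << "\n";  // prints 12
}
```

The point of the design is that the wrapper mirrors the member names and signatures exactly, so a kernel base only needs to be templated on the info type instead of overloaded per provider.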

AttentionBase(const KernelInfoType& info, bool require_same_hidden_size) {
int64_t num_heads = 0;
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
num_heads_ = static_cast<int>(num_heads);

is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
rotary_embedding_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);

if (!info.GetAttrs<int64_t>("qkv_hidden_sizes", qkv_hidden_sizes_).IsOK()) {
is_unidirectional_ = info.template GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
do_rotary_ = info.template GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
rotary_embedding_ = static_cast<int>(info.template GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
mask_filter_value_ = info.template GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
scale_ = info.template GetAttrOrDefault<float>("scale", 0.0f);
if (!info.template GetAttrs<int64_t>("qkv_hidden_sizes", qkv_hidden_sizes_).IsOK()) {
qkv_hidden_sizes_.clear();
}

past_present_share_buffer_ = info.GetAttrOrDefault<int64_t>("past_present_share_buffer", 0LL);
past_present_share_buffer_ = info.template GetAttrOrDefault<int64_t>("past_present_share_buffer", 0LL);

require_same_hidden_size_ = require_same_hidden_size;
}
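
A note on why each accessor call in this hunk gained a `template` keyword: inside a constructor templated on `KernelInfoType`, the type of `info` is dependent, so the compiler cannot assume `GetAttrOrDefault` names a member template; without the disambiguator, the `<` in `GetAttrOrDefault<int64_t>` would parse as a less-than operator. A minimal sketch (the `Info` type is a simplified stand-in):

```cpp
#include <cstdint>
#include <iostream>

struct Info {
  template <typename T>
  T GetAttrOrDefault(const char* /*name*/, T default_value) const {
    return default_value;  // stand-in: always return the default
  }
};

template <typename KernelInfoType>
int64_t ReadUnidirectional(const KernelInfoType& info) {
  // Without '.template', 'info.GetAttrOrDefault<int64_t>(...)' fails to parse:
  // in a dependent context, '<' would be read as a comparison operator.
  return info.template GetAttrOrDefault<int64_t>("unidirectional", 0);
}

int main() {
  Info info;
  std::cout << ReadUnidirectional(info) << "\n";  // prints 0
}
```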
17 changes: 9 additions & 8 deletions onnxruntime/core/providers/cpu/nn/conv_attributes.h
@@ -18,9 +18,10 @@
struct ConvAttributes {
using ConvPadVector = InlinedVector<int64_t, kTensorShapeSmallBufferElementsSize * 2>;

explicit ConvAttributes(const OpKernelInfo& info) {
template <typename KernelInfoType>
explicit ConvAttributes(const KernelInfoType& info) {
std::string auto_pad_str;
auto status = info.GetAttr<std::string>("auto_pad", &auto_pad_str);
auto status = info.template GetAttr<std::string>("auto_pad", &auto_pad_str);
if (status.IsOK()) {
auto_pad = StringToAutoPadType(auto_pad_str);
}
@@ -32,8 +33,8 @@
strides.resize(kernel_shape_.size(), 1);
}

gsl::span<const int64_t> pads_span;
status = info.GetAttrsAsSpan("pads", pads_span);
std::vector<int64_t> pads_attr;

[GitHub Actions / Optional Lint C++ — cpplint via reviewdog] onnxruntime/core/providers/cpu/nn/conv_attributes.h:36: Add #include <vector> for vector<> [build/include_what_you_use] [4]
status = info.GetAttrs("pads", pads_attr);
if (!status.IsOK()) {
if (kernel_shape_specified) {
// If pads are not explicitly provided, fill the container with all zeros
@@ -44,15 +45,15 @@
// Pads are explicitly provided, make sure that auto_pad is NOTSET
ORT_ENFORCE(auto_pad == AutoPadType::NOTSET,
"A Conv/ConvTranspose node has both 'auto_pad' and 'pads' attributes");
pads.assign(pads_span.begin(), pads_span.end());
pads.assign(pads_attr.begin(), pads_attr.end());
}

status = info.GetAttrs("dilations", dilations);
if (kernel_shape_specified && (!status.IsOK() || dilations.empty())) {
dilations.resize(kernel_shape_.size(), 1);
}

status = info.GetAttr<int64_t>("group", &group);
status = info.template GetAttr<int64_t>("group", &group);
if (!status.IsOK()) {
group = 1;
}
@@ -61,9 +62,9 @@
// TODO: Re-enable when attributes values are guaranteed to be filled.
// Make sure empty strides or dilations are defaulted to 1 if necessary
std::string auto_pad_str;
ORT_ENFORCE(info.GetAttr<std::string>("auto_pad", &auto_pad_str).IsOK());
ORT_ENFORCE(info.template GetAttr<std::string>("auto_pad", &auto_pad_str).IsOK());
auto_pad = StringToAutoPadType(auto_pad_str);
ORT_ENFORCE(info.GetAttr<int64_t>("group", &group).IsOK());
ORT_ENFORCE(info.template GetAttr<int64_t>("group", &group).IsOK());
ORT_ENFORCE(info.GetAttrs("kernel_shape", kernel_shape_).IsOK());
ORT_ENFORCE(info.GetAttrs("strides", strides).IsOK());
ORT_ENFORCE(info.GetAttrs("pads", pads).IsOK());
onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h

@@ -23,7 +23,8 @@
namespace onnxruntime {

struct ConvTransposeAttributes : public ConvAttributes {
explicit ConvTransposeAttributes(const OpKernelInfo& info)
template <typename KernelInfoType>
explicit ConvTransposeAttributes(const KernelInfoType& info)
: ConvAttributes(info),
output_padding(info.GetAttrsOrDefault("output_padding")),
output_shape(info.GetAttrsOrDefault("output_shape")) {
9 changes: 5 additions & 4 deletions onnxruntime/core/providers/cpu/nn/pool_attributes.h
@@ -24,7 +24,8 @@
// Shared providers don't know about OpNodeProtoHelper
PoolAttributes(const OpKernelInfo& info,
#else
PoolAttributes(const OpNodeProtoHelper<ProtoHelperNodeContext>& info,
template <typename KernelInfoType>
PoolAttributes(const KernelInfoType& info,
#endif
const std::string& op_name, int start_version)
: global_pooling(IsGlobalPooling(op_name)) {
@@ -37,7 +38,7 @@

std::string auto_padding;
if (op_name != "MaxUnpool") {
ORT_ENFORCE(info.GetAttr<std::string>("auto_pad", &auto_padding).IsOK());
ORT_ENFORCE(info.template GetAttr<std::string>("auto_pad", &auto_padding).IsOK());

[GitHub Actions / Optional Lint C++ — cpplint via reviewdog] onnxruntime/core/providers/cpu/nn/pool_attributes.h:41: Add #include <string> for string [build/include_what_you_use] [4]
}
auto_pad = StringToAutoPadType(auto_padding);

@@ -49,7 +50,7 @@
strides.resize(kernel_shape.size(), 1);
}

if (!info.GetAttr<int64_t>("ceil_mode", &ceil_mode).IsOK()) {
if (!info.template GetAttr<int64_t>("ceil_mode", &ceil_mode).IsOK()) {
ceil_mode = 0;
}

@@ -63,7 +64,7 @@

if (op_name == "AveragePool") {
int64_t temp;
ORT_ENFORCE(info.GetAttr<int64_t>("count_include_pad", &temp).IsOK());
ORT_ENFORCE(info.template GetAttr<int64_t>("count_include_pad", &temp).IsOK());
count_include_pad = (temp != 0);
}
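
For readers tracing what `ceil_mode` and the kernel/stride attributes feed into: each output spatial length comes from the standard pooling formula, with `ceil_mode` switching floor to ceiling division. A sketch under simplifying assumptions (dilation omitted; this is an illustration, not the kernel's actual code path):

```cpp
#include <cmath>
#include <cstdint>
#include <iostream>

// Output length of pooling along one spatial axis (dilation omitted).
int64_t PoolOutLen(int64_t in, int64_t kernel, int64_t stride,
                   int64_t pad_total, bool ceil_mode) {
  double span = static_cast<double>(in + pad_total - kernel) / stride;
  return static_cast<int64_t>(ceil_mode ? std::ceil(span) : std::floor(span)) + 1;
}

int main() {
  // in=5, kernel=2, stride=2, no padding:
  std::cout << PoolOutLen(5, 2, 2, 0, false) << "\n";  // 2 with ceil_mode=0
  std::cout << PoolOutLen(5, 2, 2, 0, true) << "\n";   // 3 with ceil_mode=1
}
```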

8 changes: 5 additions & 3 deletions onnxruntime/core/providers/cpu/nn/pool_base.h
@@ -102,13 +102,15 @@

class PoolBase {
private:
static int GetStartVersion(const OpKernelInfo& info) {
template <typename KernelInfoType>
static int GetStartVersion(const KernelInfoType& info) {
return info.node().SinceVersion();
}

protected:
PoolBase(const OpKernelInfo& info)
: op_name_(info.GetKernelDef().OpName().rfind("QLinear", 0) != 0 ? info.GetKernelDef().OpName() : info.GetKernelDef().OpName().substr(7)),
template <typename KernelInfoType>
PoolBase(const KernelInfoType& info)

[GitHub Actions / Optional Lint C++ — cpplint via reviewdog] onnxruntime/core/providers/cpu/nn/pool_base.h:112: Single-parameter constructors should be marked explicit. [runtime/explicit] [4]
: op_name_(info.node().OpType().rfind("QLinear", 0) != 0 ? info.node().OpType() : info.node().OpType().substr(7)),
pool_attrs_(info, op_name_, GetStartVersion(info)) {
}
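
The initializer above normalizes the op name so that, for example, `QLinearAveragePool` and `AveragePool` share one attribute-parsing path: `rfind(prefix, 0) == 0` is the usual starts-with test, and 7 is the length of `"QLinear"`. A standalone illustration:

```cpp
#include <iostream>
#include <string>

// Mirrors the op-name normalization in PoolBase's initializer list.
std::string NormalizeOpName(const std::string& op_type) {
  // rfind(prefix, 0) only matches at position 0, i.e. a starts-with check.
  return op_type.rfind("QLinear", 0) != 0 ? op_type : op_type.substr(7);
}

int main() {
  std::cout << NormalizeOpName("QLinearAveragePool") << "\n";  // AveragePool
  std::cout << NormalizeOpName("MaxPool") << "\n";             // MaxPool
}
```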

11 changes: 6 additions & 5 deletions onnxruntime/core/providers/cpu/reduction/reduction_kernel_base.h
@@ -11,11 +11,12 @@
template <bool allow_multi_axes>
class ReduceKernelBase {
protected:
ReduceKernelBase(const OpKernelInfo& info, optional<int64_t> keepdims_override = {}) {
template <typename KernelInfoType>
ReduceKernelBase(const KernelInfoType& info, optional<int64_t> keepdims_override = {}) {

[GitHub Actions / Optional Lint C++ — cpplint via reviewdog] onnxruntime/core/providers/cpu/reduction/reduction_kernel_base.h:15: Constructors callable with one argument should be marked explicit. [runtime/explicit] [4]
if (allow_multi_axes) {
axes_ = ToShapeVector(info.GetAttrsOrDefault<int64_t>("axes"));
axes_ = ToShapeVector(info.template GetAttrsOrDefault<int64_t>("axes"));
} else {
auto v = info.GetAttrOrDefault<int64_t>("axis", 0);
auto v = info.template GetAttrOrDefault<int64_t>("axis", 0);
axes_.push_back(v);
}
int64_t keepdims = 1;
@@ -25,9 +26,9 @@
ORT_ENFORCE(info.GetAttr("keepdims", &keepdims).IsOK());
}
keepdims_ = (keepdims == 1);
int64_t noop_with_empty_axes = info.GetAttrOrDefault<int64_t>("noop_with_empty_axes", 0);
int64_t noop_with_empty_axes = info.template GetAttrOrDefault<int64_t>("noop_with_empty_axes", 0);
noop_with_empty_axes_ = (noop_with_empty_axes == 1);
int64_t select_last_index = info.GetAttrOrDefault<int64_t>("select_last_index", 0);
int64_t select_last_index = info.template GetAttrOrDefault<int64_t>("select_last_index", 0);
select_last_index_ = (select_last_index != 0);
}
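
For reference, the `keepdims` flag parsed above decides whether a reduced axis survives as a length-1 dimension. A small sketch of the resulting shapes (illustrative only, not the kernel's implementation):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Shape of a reduction over a single axis, per the ONNX 'keepdims' semantics.
std::vector<int64_t> ReducedShape(std::vector<int64_t> shape, size_t axis, bool keepdims) {
  if (keepdims) {
    shape[axis] = 1;                      // axis kept with length 1
  } else {
    shape.erase(shape.begin() + axis);    // axis dropped entirely
  }
  return shape;
}

int main() {
  for (bool keep : {true, false}) {
    for (int64_t d : ReducedShape({2, 3}, 1, keep)) std::cout << d << ' ';
    std::cout << '\n';  // prints "2 1" then "2"
  }
}
```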

175 changes: 1 addition & 174 deletions onnxruntime/core/providers/cpu/tensor/concat.cc
@@ -54,180 +54,7 @@ using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExec
Status ConcatBase::PrepareForCompute(OpKernelContext* ctx,
const InlinedTensorsVector& input_tensors,
Prepare& p) const {
size_t input_count = input_tensors.size();

// Must have at least one input to concat
ORT_RETURN_IF_NOT(input_count >= 1, "Must have 1 or more inputs");

TensorShapeVector reference_dims;
size_t reference_rank = 0;

int reference_tensor_index = 0;

InlinedVector<int64_t, Prepare::kExpectedNumberOfInputs> input_tensor_sizes;
input_tensor_sizes.reserve(input_count);

bool all_inputs_are_empty = true;

for (size_t index = 0; index < input_count; ++index) {
const auto* input = input_tensors[index];
ORT_ENFORCE(input != nullptr, "input count mismatch");

// find the first tensor that isn't empty
// to be used as a reference for all
// downstream shape/rank validations of other inputs
const auto& shape = input->Shape();
const auto num_elements = shape.Size();
if (num_elements > 0) {
reference_dims = shape.AsShapeVector();
reference_rank = reference_dims.size();
reference_tensor_index = onnxruntime::narrow<int>(index);
input_tensor_sizes.push_back(num_elements);
all_inputs_are_empty = false;
break;
} else {
input_tensor_sizes.push_back(0);
}
}

if (all_inputs_are_empty) {
// Reference dim and reference rank can just come from the first input
// No shape/rank validations will be done (as all inputs are empty).
// But the rest of the execution flow (filling in the Prepare instance - p)
// can use this info.
reference_dims = input_tensors[0]->Shape().AsShapeVector();
reference_rank = reference_dims.size();
}

// Cannot concatenate scalars (but they can be stacked)
if (!is_stack_)
ORT_RETURN_IF_NOT(reference_rank > 0, "Cannot concatenate scalars");

// Handle and fix negative axis
// In 'stack' mode, the accepted range depends on the output rank (which is one more than the input rank)
p.axis = static_cast<uint64_t>(HandleNegativeAxis(axis_, onnxruntime::narrow<int64_t>(!is_stack_
? reference_rank
: reference_rank + 1)));

// Ensure all of the non concatenated axes match each other
for (size_t index = static_cast<size_t>(reference_tensor_index) + 1; index < input_count; index++) {
const auto* input = input_tensors[index];
ORT_ENFORCE(input != nullptr, "input count mismatch");
const auto& input_shape = input->Shape();
const auto input_dims = input_shape.GetDims();

// Skip shape/rank validation for inputs that are empty.
// The ONNX spec states that all dim values along axes not concatenated on
// need to be the same for all inputs (empty inputs are not explicitly exempted).
// The model in GH issue 8020 has a bunch of Loop nodes all feeding into
// the 'Concat' node, and one of these Loops tends to have an iteration
// count of 0 for some inputs. If the iteration count for a Loop is zero,
// we don't execute its subgraph (since the outputs are going to be empty anyway)
// and we send an "empty" tensor(s) downstream and use ONNX shape inferred shape
// to "compose" the shape for these empty tensor(s).
// If we encounter symbolic dims in the ONNX shape inferred shape, we place a '0'
// in that position, and due to the "lossy" nature of this process, the inputs' shape
// validation for such empty inputs fails; hence we skip these validations for all
// empty inputs.
// This isn't too bad as we will never use empty inputs while concatenating anyway.
// We just loosen this check to unblock the model in GH issue 8020 so it can complete processing.
if (input_shape.Size() == 0) {
input_tensor_sizes.push_back(0);
} else {
const size_t input_rank = input_dims.size();

ORT_ENFORCE(input_rank == reference_rank,
"Ranks of input data are different, cannot concatenate them. expected rank: ",
reference_rank, " got: ", input_rank);

// Ensure all the other (non-concat) axes match
int64_t tensor_size = 1;
for (size_t axis_index = 0; axis_index < reference_rank; ++axis_index) {
auto dim_value = input_dims[axis_index];
tensor_size *= dim_value;

// In 'concat' mode, the axis to be concatenated may be different
// But in 'stack' mode, all input shapes must be the same and must be validated
if (!is_stack_ && axis_index == p.axis)
continue;

ORT_RETURN_IF_NOT(dim_value == reference_dims[axis_index],
"Non concat axis dimensions must match: Axis ",
axis_index, " has mismatched dimensions of ", dim_value,
" and ", reference_dims[axis_index]);
}

input_tensor_sizes.push_back(tensor_size); // assign the computed size of the input tensor
}
}

// Calculate the shape of the output tensor
auto output_dims = reference_dims;

if (!is_stack_) { // 'Concat' mode
// While concatenating, the rank of the output is the same as the input rank(s)

// Calculate the size of the concatenated axis
size_t concat_axis_size = 0;
for (size_t index = 0; index < input_count; index++) {
concat_axis_size += onnxruntime::narrow<size_t>(input_tensors[index]->Shape()[onnxruntime::narrow<size_t>(p.axis)]);
}

output_dims[onnxruntime::narrow<size_t>(p.axis)] = onnxruntime::narrow<int64_t>(concat_axis_size);
} else { // 'Stack' mode
// While stacking, the rank of the output is one more than the input rank(s).
// Stacking may be thought of as adding a unit dimension (of value 1) to the input tensors,
// and concatenating them on this new axis.
// The value in the corresponding axis of the output will be the number of inputs that are being stacked.
output_dims.insert(output_dims.begin() + p.axis, static_cast<int64_t>(input_count));
}

TensorShape output_shape(output_dims);

// Create output tensor
p.output_tensor = &(*ctx->Output(0, output_shape));

// Make note if output tensor is going to be empty
p.output_num_elements = output_shape.Size();

// No need to proceed further if output is going to be empty
if (p.output_num_elements == 0)
return Status::OK();

// The output_axis_pitch is the number of elements to add to move to the next split axis in the output.
// Can handle stacking as well.
p.output_axis_pitch = 1;
auto output_rank = !is_stack_ ? reference_rank : reference_rank + 1;
for (size_t i = output_rank; i-- > p.axis;) {
p.output_axis_pitch *= output_dims[i];
}

// Fill the 'Prepare' struct with available information
p.inputs.reserve(input_count);
for (size_t input_index = 0; input_index < input_count; input_index++) {
const Tensor* data_n_ptr = input_tensors[input_index];
auto& data_n = *data_n_ptr;

// Type sanity check (Make sure we are working on homogeneous types)
ORT_RETURN_IF_NOT(data_n.DataType() == p.output_tensor->DataType(), "Data type mismatch");

// The input_axis_pitch is the number of elements to add to move to the next split axis in the input
// Can handle stacking as well (as the "new dummy dimension" in the input is of unit value).
// TODO: Minor optimization possibility: this input_axis_pitch will be common across all inputs
// in 'ConcatFromSequence' (stack mode). It has to be computed per input only while concatenating.
int64_t input_axis_pitch = 1;
const auto& data_dims = data_n.Shape().GetDims();
for (size_t i = reference_rank; i-- > p.axis;) {
input_axis_pitch *= data_dims[i];
}

p.inputs.push_back({&data_n, input_axis_pitch, input_tensor_sizes[input_index]});
}

// Make note if the input Tensors are of type 'string'
p.is_string_type = p.inputs[0].tensor->IsDataTypeString();

return Status::OK();
return PrepareForComputeImpl(ctx, input_tensors, p);
}

namespace {
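
With the body above moved into `PrepareForComputeImpl`, the pitch arithmetic the deleted code documented is easiest to see in a worked example. The sketch below reproduces the deleted loops' computation (the shapes are illustrative):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Number of elements from 'axis' to the innermost dimension: how far to
// advance in a flat buffer to reach the next slice along the concat axis.
// This is the same loop the deleted PrepareForCompute body used.
int64_t AxisPitch(const std::vector<int64_t>& dims, size_t axis) {
  int64_t pitch = 1;
  for (size_t i = dims.size(); i-- > axis;) pitch *= dims[i];
  return pitch;
}

int main() {
  // Concatenating {2,3} and {2,4} on axis 1 yields an output of shape {2,7}.
  std::cout << AxisPitch({2, 3}, 1) << "\n";  // input pitch 3
  std::cout << AxisPitch({2, 4}, 1) << "\n";  // input pitch 4
  std::cout << AxisPitch({2, 7}, 1) << "\n";  // output pitch 7
}
```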