Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions onnxruntime/core/platform/posix/env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,11 @@ limitations under the License.
#include <utility> // for std::forward
#include <vector>

// We can not use CPUINFO if it is not supported and we do not want to used
// We can not use CPUINFO if it is not supported and we do not want to use
// it on certain platforms because of the binary size increase.
// We could use it to find out the number of physical cores for certain supported platforms
#if defined(CPUINFO_SUPPORTED) && !defined(__APPLE__) && !defined(__ANDROID__) && !defined(__wasm__) && !defined(_AIX)
#if defined(CPUINFO_SUPPORTED) && defined(__linux__)
#include <cpuinfo.h>
#define ORT_USE_CPUINFO
#endif
Comment on lines +41 to 46

#if defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__)
Expand Down Expand Up @@ -266,17 +265,17 @@ class PosixEnv : public Env {

// Return the number of physical cores
int GetNumPhysicalCpuCores() const override {
#ifdef ORT_USE_CPUINFO
#if defined(CPUINFO_SUPPORTED) && defined(__linux__)
if (cpuinfo_available_) {
return narrow<int>(cpuinfo_get_cores_count());
}
#endif // ORT_USE_CPUINFO
#endif // defined(CPUINFO_SUPPORTED) && defined(__linux__)
return DefaultNumCores();
}

std::vector<LogicalProcessors> GetDefaultThreadAffinities() const override {
std::vector<LogicalProcessors> ret;
#ifdef ORT_USE_CPUINFO
#if defined(CPUINFO_SUPPORTED) && defined(__linux__)
if (cpuinfo_available_) {
auto num_phys_cores = cpuinfo_get_cores_count();
ret.reserve(num_phys_cores);
Expand All @@ -292,7 +291,7 @@ class PosixEnv : public Env {
ret.push_back(std::move(th_aff));
}
}
#endif
#endif // defined(CPUINFO_SUPPORTED) && defined(__linux__)
// Just the size of the thread-pool
if (ret.empty()) {
ret.resize(GetNumPhysicalCpuCores());
Expand Down Expand Up @@ -614,15 +613,15 @@ class PosixEnv : public Env {

private:
Telemetry telemetry_provider_;
#ifdef ORT_USE_CPUINFO
#if defined(CPUINFO_SUPPORTED) && defined(__linux__)
PosixEnv() {
cpuinfo_available_ = cpuinfo_initialize();
if (!cpuinfo_available_) {
LOGS_DEFAULT(INFO) << "cpuinfo_initialize failed";
}
}
bool cpuinfo_available_{false};
#endif // ORT_USE_CPUINFO
#endif // defined(CPUINFO_SUPPORTED) && defined(__linux__)
};

} // namespace
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1072,6 +1072,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, ST
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, double, LayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, MLFloat16, LayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, SequenceMap);
// Opset 18
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 18, float, Resize);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 18, int32_t, Resize);
Expand Down Expand Up @@ -3068,6 +3069,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, MLFloat16,
LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, SequenceMap)>,

// Opset 18
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 18,
Expand Down
172 changes: 172 additions & 0 deletions onnxruntime/core/providers/cpu/sequence/sequence_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
#include "core/providers/op_kernel_type_control.h"
#include "core/util/math.h"
#include "core/util/math_cpuonly.h"
#include "core/framework/op_kernel_context_internal.h"
#include "core/framework/session_state.h"
#include "core/framework/utils.h"

using namespace onnxruntime::common;

Expand Down Expand Up @@ -579,4 +582,173 @@ Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& inpu

return Status::OK();
}

// SequenceMap (opset 17)
// Native CPU kernel: avoids the O(n^2) fallback that the context-dependent
// function body would otherwise generate via Loop + SequenceInsert.
ONNX_CPU_OPERATOR_KERNEL(
SequenceMap,
17,
KernelDefBuilder()
.TypeConstraint("S", DataTypeImpl::AllSequenceTensorTypes())
.TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypes()),
SequenceMap);

common::Status SequenceMap::SetupSubgraphExecutionInfo(const SessionState& session_state,
const std::string& attribute_name,
const SessionState& subgraph_session_state) {
ORT_ENFORCE(feeds_fetches_manager_ == nullptr,
"SetupSubgraphExecutionInfo should only be called once for each subgraph.");
ORT_UNUSED_PARAMETER(attribute_name);

const auto& node = Node();
const auto& subgraph = subgraph_session_state.GetGraphViewer();

// The subgraph inputs match the outer node's explicit inputs (one element
// per sequence input, pass-through for tensor inputs).
const auto& subgraph_inputs = subgraph.GetInputs();
const auto& subgraph_outputs = subgraph.GetOutputs();

const auto& implicit_defs = node.ImplicitInputDefs();

std::vector<std::string> feed_names;
feed_names.reserve(subgraph_inputs.size() + implicit_defs.size());
for (const auto* input : subgraph_inputs) {
feed_names.push_back(input->Name());
}
for (const auto* implicit_def : implicit_defs) {
feed_names.push_back(implicit_def->Name());
}

std::vector<std::string> fetch_names;
fetch_names.reserve(subgraph_outputs.size());
for (const auto* output : subgraph_outputs) {
fetch_names.push_back(output->Name());
}

const auto& subgraph_map = subgraph_session_state.GetOrtValueNameIdxMap();
std::unique_ptr<FeedsFetchesManager> ffm;
ORT_RETURN_IF_ERROR(FeedsFetchesManager::Create(feed_names, fetch_names, subgraph_map, ffm));
ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(subgraph_session_state, *ffm));
Comment on lines +629 to +632

// Feeds come from the outer node's explicit inputs and implicit inputs (outer-scope captures).
const auto& node_inputs = node.InputDefs();
std::vector<std::string> outer_feed_names;
outer_feed_names.reserve(node_inputs.size() + implicit_defs.size());
for (const auto* input_def : node_inputs) {
outer_feed_names.push_back(input_def->Name());
}
Comment on lines +636 to +640
for (const auto* implicit_def : implicit_defs) {
outer_feed_names.push_back(implicit_def->Name());
}

std::vector<OrtDevice> feed_locations;
ORT_RETURN_IF_ERROR(
controlflow::detail::FindDevicesForValues(session_state, outer_feed_names, feed_locations));

// Outputs go to the SequenceMap node's output sequence slots.
std::vector<const OrtDevice*> fetch_locations;
const auto& node_outputs = node.OutputDefs();
fetch_locations.reserve(node_outputs.size());
for (const auto* output_def : node_outputs) {
fetch_locations.push_back(&utils::FindDeviceForValue(session_state, output_def->Name()));
}

utils::FinalizeFeedFetchCopyInfo(*ffm, feed_locations, fetch_locations);
feeds_fetches_manager_ = std::move(ffm);

return Status::OK();
}

Status SequenceMap::Compute(OpKernelContext* ctx) const {
ORT_ENFORCE(feeds_fetches_manager_,
"SetupSubgraphExecutionInfo must be called prior to SequenceMap Compute.");

auto* ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
const auto* session_state = ctx_internal->SubgraphSessionState("body");
ORT_ENFORCE(session_state, "Subgraph SessionState was not found for 'body' attribute.");

const auto* input_seq = ctx->Input<TensorSeq>(0);
ORT_ENFORCE(input_seq != nullptr, "SequenceMap: first input (input_sequence) must be a sequence.");
const auto seq_len = input_seq->Size();

const int num_outer_inputs = ctx->InputCount();
const int num_outputs = ctx->OutputCount();

// Initialise each output TensorSeq with its element type before the iteration loop so that
// an empty input sequence still produces correctly-typed outputs (change D).
std::vector<TensorSeq*> output_seqs(num_outputs, nullptr);
for (int j = 0; j < num_outputs; ++j) {
output_seqs[j] = ctx->Output<TensorSeq>(j);
ORT_ENFORCE(output_seqs[j] != nullptr, "SequenceMap: failed to get output TensorSeq slot ", j);
const auto* out_type = ctx->OutputType(j);
ORT_ENFORCE(out_type != nullptr, "SequenceMap: could not determine type for output ", j);
output_seqs[j]->SetType(out_type->AsSequenceTensorType()->GetElementType());
output_seqs[j]->Reserve(seq_len);
}
Comment on lines +681 to +688

// Validate that all additional sequence inputs have the same length as input_sequence.
for (int k = 1; k < num_outer_inputs; ++k) {
const auto* seq = ctx->Input<TensorSeq>(k);
if (seq != nullptr && seq->Size() != seq_len) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"SequenceMap: additional sequence input ", k,
" has length ", seq->Size(),
" but input_sequence has length ", seq_len);
}
}

// Hoist feeds/fetches outside the iteration loop to avoid repeated allocation (change C).
const auto& implicit_inputs = ctx_internal->GetImplicitInputs();
const auto num_implicit = implicit_inputs.size();

std::vector<OrtValue> feeds;
std::vector<OrtValue> fetches;
feeds.reserve(static_cast<size_t>(num_outer_inputs) + num_implicit);

for (size_t i = 0; i < seq_len; ++i) {
feeds.clear();
fetches.clear();

// Build feeds: sequence inputs -> element i; tensor inputs -> pass-through OrtValue.
for (int k = 0; k < num_outer_inputs; ++k) {
const auto* seq_k = (k == 0) ? input_seq : ctx->Input<TensorSeq>(k);
if (seq_k != nullptr) {
feeds.push_back(seq_k->GetAt(i));
} else {
// Tensor input: shallow-copy the OrtValue (shared_ptr, safe) from the kernel context.
const auto* input_val = ctx_internal->GetInputMLValue(k);
ORT_ENFORCE(input_val != nullptr, "SequenceMap: input ", k, " is neither a sequence nor a tensor.");
feeds.push_back(*input_val);
}
}

// Append implicit inputs (outer-scope captures) in ImplicitInputDefs() order (change B).
for (size_t m = 0; m < num_implicit; ++m) {
feeds.push_back(*implicit_inputs[m]);
}

ORT_RETURN_IF_ERROR(utils::ExecuteSubgraph(*session_state, *feeds_fetches_manager_,
feeds, fetches, {},
ExecutionMode::ORT_SEQUENTIAL,
ctx->GetTerminateFlag(),
ctx->Logger(),
ctx->GetComputeStream(),
/*sync_subgraph_fetches=*/false,
ctx_internal->GetRunProfiler()));

ORT_ENFORCE(static_cast<int>(fetches.size()) == num_outputs,
"SequenceMap: body returned ", fetches.size(), " outputs but ", num_outputs, " were expected.");

for (int j = 0; j < num_outputs; ++j) {
ORT_ENFORCE(fetches[j].IsTensor(),
"SequenceMap: body output ", j, " must be a tensor.");
// SetType was called before the loop (change D); per-iteration call removed (change E).
output_seqs[j]->Add(std::move(fetches[j]));
}
}

return Status::OK();
}

} // namespace onnxruntime
16 changes: 16 additions & 0 deletions onnxruntime/core/providers/cpu/sequence/sequence_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
#pragma once

#include "core/common/common.h"
#include "core/framework/feeds_fetches_manager.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/controlflow/utils.h"

namespace onnxruntime {

Expand Down Expand Up @@ -70,4 +72,18 @@ class SplitToSequence final : public OpKernel {
int64_t keepdims_{1};
const int64_t DEFAULT_LENGTH_EACH_OUTPUT_ = 1;
};

class SequenceMap final : public controlflow::IControlFlowKernel {
public:
SequenceMap(const OpKernelInfo& info) : IControlFlowKernel(info) {}

Status Compute(OpKernelContext* ctx) const override;

Status SetupSubgraphExecutionInfo(const SessionState& session_state,
const std::string& attribute_name,
const SessionState& subgraph_session_state) override;

private:
std::unique_ptr<FeedsFetchesManager> feeds_fetches_manager_;
};
} // namespace onnxruntime
Loading