Skip to content

Commit 78f6595

Browse files
committed
first round
1 parent 16d2ed6 commit 78f6595

File tree

3 files changed

+17
-6
lines changed

3 files changed

+17
-6
lines changed

include/onnxruntime/ep/adapter/op_kernel.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ struct OpKernel {
6464
/// An adapter class partially implementing the facade of `onnxruntime::OpKernelContext`.
6565
/// </summary>
6666
struct OpKernelContext {
67-
explicit OpKernelContext(OrtKernelContext* context, const OpKernel& op_kernel) : context_{context}, op_kernel_{op_kernel} {
67+
explicit OpKernelContext(OrtKernelContext* context, const OpKernel& op_kernel) : context_{context},
68+
op_kernel_{op_kernel},
69+
constant_input_tensors_{op_kernel.Info().GetConstantInputTensors()} {
6870
input_tensors_.resize(InputCount());
6971
output_tensors_.resize(OutputCount());
7072
}
@@ -79,6 +81,10 @@ struct OpKernelContext {
7981
return static_cast<const T*>(input_tensors_[index].get());
8082
}
8183

84+
if (index < constant_input_tensors_.size() && constant_input_tensors_[index] != nullptr) {
85+
return static_cast<const T*>(constant_input_tensors_[index].get());
86+
}
87+
8288
auto input = context_.GetInput(index);
8389
if (input == nullptr || !input.IsTensor()) {
8490
return nullptr;
@@ -126,6 +132,7 @@ struct OpKernelContext {
126132
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernelContext);
127133
Ort::KernelContext context_;
128134
const OpKernel& op_kernel_;
135+
const std::vector<std::unique_ptr<Tensor>>& constant_input_tensors_;
129136
mutable std::vector<std::unique_ptr<Tensor>> input_tensors_;
130137
std::vector<std::unique_ptr<Tensor>> output_tensors_;
131138
};

include/onnxruntime/ep/adapter/op_kernel_info.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ struct OpKernelInfo {
8181
return info_.GetInputCount();
8282
}
8383

84+
const std::vector<std::unique_ptr<Tensor>>& GetConstantInputTensors() const noexcept {
85+
return cache_->constant_input_tensors;
86+
}
87+
8488
bool TryGetConstantInput(int input_index, const Tensor** constant_input_value) const {
8589
if (input_index < 0 || static_cast<size_t>(input_index) >= cache_->constant_input_tensors.size()) {
8690
return false;

include/onnxruntime/ep/adapter/tensor_helper.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ inline std::unique_ptr<onnxruntime::Tensor> CreateTensorFromApiValue(const OrtVa
2323
Ort::ConstValue value{ort_value};
2424
EP_ENFORCE(value.IsTensor(), "Only tensor OrtValue is supported.");
2525

26-
auto type_and_shape_info = value.GetTypeInfo().GetTensorTypeAndShapeInfo();
27-
auto type = type_and_shape_info.GetElementType();
28-
auto shape_vec = type_and_shape_info.GetShape();
26+
ONNXTensorElementDataType element_type;
27+
Ort::Value::Shape shape{};
28+
value.GetTensorElementTypeAndShapeDataReference(element_type, shape);
2929

3030
auto memory_info = value.GetTensorMemoryInfo();
31-
MLDataType data_type = DataTypeImpl::TensorTypeFromONNXEnum(type)->GetElementType();
31+
MLDataType data_type = DataTypeImpl::TensorTypeFromONNXEnum(element_type)->GetElementType();
3232

3333
OrtMemoryInfo tensor_memory_info{memory_info.GetAllocatorName(),
3434
memory_info.GetAllocatorType(),
@@ -42,7 +42,7 @@ inline std::unique_ptr<onnxruntime::Tensor> CreateTensorFromApiValue(const OrtVa
4242
memory_info.GetMemoryType()};
4343

4444
return std::make_unique<Tensor>(data_type,
45-
TensorShape{shape_vec},
45+
TensorShape{shape.shape, shape.shape_len},
4646
const_cast<void*>(value.GetTensorRawData()),
4747
tensor_memory_info);
4848
}

0 commit comments

Comments
 (0)