Skip to content

Commit 9e3419b

Browse files
committed
2nd round
1 parent 78f6595 commit 9e3419b

File tree

3 files changed

+23
-23
lines changed

3 files changed

+23
-23
lines changed

include/onnxruntime/ep/adapter/op_kernel.h

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,12 @@ struct OpKernelContext {
7777
if (index < 0 || static_cast<size_t>(index) >= input_tensors_.size()) {
7878
return nullptr;
7979
}
80-
if (input_tensors_[index] != nullptr) {
81-
return static_cast<const T*>(input_tensors_[index].get());
80+
if (input_tensors_[index].DataRaw() != nullptr) {
81+
return &input_tensors_[index];
8282
}
8383

84-
if (index < constant_input_tensors_.size() && constant_input_tensors_[index] != nullptr) {
85-
return static_cast<const T*>(constant_input_tensors_[index].get());
84+
if (index < constant_input_tensors_.size() && constant_input_tensors_[index].DataRaw() != nullptr) {
85+
return &constant_input_tensors_[index];
8686
}
8787

8888
auto input = context_.GetInput(index);
@@ -91,14 +91,14 @@ struct OpKernelContext {
9191
}
9292

9393
input_tensors_[index] = CreateTensorFromApiValue(input);
94-
return static_cast<const T*>(input_tensors_[index].get());
94+
return &input_tensors_[index];
9595
}
9696
Tensor* Output(int index, const TensorShape& shape) {
9797
if (index < 0 || static_cast<size_t>(index) >= output_tensors_.size()) {
9898
return nullptr;
9999
}
100-
if (output_tensors_[index] != nullptr) {
101-
return output_tensors_[index].get();
100+
if (output_tensors_[index].DataRaw() != nullptr) {
101+
return &output_tensors_[index];
102102
}
103103

104104
auto output = context_.GetOutput(index, shape.GetDims().data(), shape.GetDims().size());
@@ -107,7 +107,7 @@ struct OpKernelContext {
107107
}
108108

109109
output_tensors_[index] = CreateTensorFromApiValue(output);
110-
return output_tensors_[index].get();
110+
return &output_tensors_[index];
111111
}
112112
Tensor* Output(int index, const std::vector<int64_t>& shape) {
113113
return Output(index, TensorShape{shape});
@@ -132,9 +132,9 @@ struct OpKernelContext {
132132
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernelContext);
133133
Ort::KernelContext context_;
134134
const OpKernel& op_kernel_;
135-
const std::vector<std::unique_ptr<Tensor>>& constant_input_tensors_;
136-
mutable std::vector<std::unique_ptr<Tensor>> input_tensors_;
137-
std::vector<std::unique_ptr<Tensor>> output_tensors_;
135+
const std::vector<Tensor>& constant_input_tensors_;
136+
mutable std::vector<Tensor> input_tensors_;
137+
std::vector<Tensor> output_tensors_;
138138
};
139139

140140
/// <summary>
@@ -181,10 +181,10 @@ struct KernelImpl : OrtKernelImpl {
181181
_In_opt_ OrtSharedPrePackedWeightCache* /* prepacked_weight_cache */,
182182
_Out_ bool* is_packed) noexcept {
183183
auto* kernel_impl = static_cast<KernelImpl*>(this_ptr)->impl_.get();
184-
const auto tensor = CreateTensorFromApiValue(Ort::ConstValue{weight});
184+
const auto tensor = CreateTensorFromApiValue(weight);
185185
Status status;
186186
ORT_TRY {
187-
status = kernel_impl->PrePack(*tensor.get(), input_index, AllocatorPtr{}, *is_packed, nullptr);
187+
status = kernel_impl->PrePack(tensor, input_index, AllocatorPtr{}, *is_packed, nullptr);
188188
}
189189
ORT_CATCH(const std::exception& ex) {
190190
ORT_HANDLE_EXCEPTION([&]() {

include/onnxruntime/ep/adapter/op_kernel_info.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ struct OpKernelInfo {
5252
}
5353
}
5454
const OrtKernelInfo* kernel_info_;
55-
std::vector<std::unique_ptr<Tensor>> constant_input_tensors;
55+
std::vector<Tensor> constant_input_tensors;
5656
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(KernelInfoCache);
5757
};
5858

@@ -81,17 +81,17 @@ struct OpKernelInfo {
8181
return info_.GetInputCount();
8282
}
8383

84-
const std::vector<std::unique_ptr<Tensor>>& GetConstantInputTensors() const noexcept {
84+
const std::vector<Tensor>& GetConstantInputTensors() const noexcept {
8585
return cache_->constant_input_tensors;
8686
}
8787

8888
bool TryGetConstantInput(int input_index, const Tensor** constant_input_value) const {
8989
if (input_index < 0 || static_cast<size_t>(input_index) >= cache_->constant_input_tensors.size()) {
9090
return false;
9191
}
92-
const Tensor* tensor = cache_->constant_input_tensors[input_index].get();
93-
if (tensor != nullptr) {
94-
*constant_input_value = tensor;
92+
const Tensor& tensor = cache_->constant_input_tensors[input_index];
93+
if (tensor.DataRaw() != nullptr) {
94+
*constant_input_value = &tensor;
9595
return true;
9696
}
9797
return false;

include/onnxruntime/ep/adapter/tensor_helper.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace adapter {
1919
/// <summary>
2020
/// Create an unowned onnxruntime::Tensor from a tensor OrtValue from C API.
2121
/// </summary>
22-
inline std::unique_ptr<onnxruntime::Tensor> CreateTensorFromApiValue(const OrtValue* ort_value) {
22+
inline onnxruntime::Tensor CreateTensorFromApiValue(const OrtValue* ort_value) {
2323
Ort::ConstValue value{ort_value};
2424
EP_ENFORCE(value.IsTensor(), "Only tensor OrtValue is supported.");
2525

@@ -41,10 +41,10 @@ inline std::unique_ptr<onnxruntime::Tensor> CreateTensorFromApiValue(const OrtVa
4141
},
4242
memory_info.GetMemoryType()};
4343

44-
return std::make_unique<Tensor>(data_type,
45-
TensorShape{shape.shape, shape.shape_len},
46-
const_cast<void*>(value.GetTensorRawData()),
47-
tensor_memory_info);
44+
return Tensor(data_type,
45+
TensorShape{shape.shape, shape.shape_len},
46+
const_cast<void*>(value.GetTensorRawData()),
47+
tensor_memory_info);
4848
}
4949

5050
} // namespace adapter

0 commit comments

Comments
 (0)