
Commit 2f4be9a

Add check that tensor sizes match in DataTransferManager::CopyTensors (#27008)
### Description

Add a check that tensor sizes match in DataTransferManager::CopyTensors before calling the IDataTransfer implementation, so the check is done in one place. DataTransferManager::CopyTensor[Async] already checks that the sizes match, so this makes a batched copy consistent with a single copy.

The check is not required for DataTransferManager::CopySparseTensors. The default implementation of IDataTransfer::CopySparseTensors is not overridden by any EP, so all sparse tensor copies (single or batched) end up going via SparseTensor::Copy, which has size checks.

### Motivation and Context

TRT RTX had a bug and was returning an output value of an incorrect size. When pre-allocated outputs on a different device were provided, we hit DataTransferManager::CopyTensors, which had no check that the sizes matched, leading to a heap checker violation.
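In outline, CopyTensors now validates every src/dst pair up front, after the single IDataTransfer lookup and before any copy is dispatched. The following is a minimal sketch condensed from the diff below, not the full implementation: the function name is hypothetical, the device lookup and batched dispatch are elided, and the pair accessors and ORT_MAKE_STATUS usage mirror the committed hunk.

```cpp
// Sketch only: condensed from the CopyTensors hunk in the diff below.
#include <vector>

#include "core/common/common.h"                    // ORT_MAKE_STATUS
#include "core/framework/data_transfer_manager.h"  // IDataTransfer, Tensor, Status

namespace onnxruntime {

common::Status CopyTensorsSizeCheckSketch(const std::vector<IDataTransfer::SrcDstPair>& src_dst_pairs) {
  // (Single IDataTransfer lookup for the shared src/dst devices elided.)

  // New in this commit: reject any pair whose element counts differ before doing any copy,
  // so a mis-sized pre-allocated output fails with a Status instead of overrunning the buffer.
  for (const auto& pair : src_dst_pairs) {
    const auto& src_shape = pair.src.get().Shape();
    const auto& dst_shape = pair.dst.get().Shape();
    if (src_shape.Size() != dst_shape.Size()) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Tensor size mismatch. src:", src_shape, " dst:", dst_shape);
    }
  }

  // (Dispatch of the now-validated copies, batched where possible, elided; see the full hunk below.)
  return common::Status::OK();
}

}  // namespace onnxruntime
```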
1 parent a3e477e commit 2f4be9a

File tree

2 files changed

+66
-27
lines changed

2 files changed

+66
-27
lines changed

onnxruntime/core/framework/data_transfer_manager.cc

Lines changed: 22 additions & 27 deletions
@@ -54,12 +54,9 @@ Status DataTransferManager::CopyTensor(const Tensor& src, Tensor& dst) const {
     return data_transfer->CopyTensor(src, dst);
   }
 
-  return ORT_MAKE_STATUS(ONNXRUNTIME,
-                         FAIL,
+  return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                          "There's no data transfer registered for copying tensors from ",
-                         src.Location().device.ToString(),
-                         " to ",
-                         dst.Location().device.ToString());
+                         src.Location().device.ToString(), " to ", dst.Location().device.ToString());
 }
 
 Status DataTransferManager::CopyTensorAsync(const Tensor& src, Tensor& dst, Stream& stream) const {
@@ -75,12 +72,9 @@ Status DataTransferManager::CopyTensorAsync(const Tensor& src, Tensor& dst, Stre
     return data_transfer->CopyTensorAsync(src, dst, stream);
   }
 
-  return ORT_MAKE_STATUS(ONNXRUNTIME,
-                         FAIL,
+  return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                          "There's no data transfer registered for copying tensors from ",
-                         src.Location().device.ToString(),
-                         " to ",
-                         dst.Location().device.ToString());
+                         src.Location().device.ToString(), " to ", dst.Location().device.ToString());
 }
 
 #if !defined(DISABLE_SPARSE_TENSORS)
@@ -97,12 +91,9 @@ Status DataTransferManager::CopySparseTensor(const SparseTensor& src, SparseTens
     return src.Copy(*data_transfer, dst);
   }
 
-  return ORT_MAKE_STATUS(ONNXRUNTIME,
-                         FAIL,
+  return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                          "There's no data transfer registered for copying tensors from ",
-                         src.Location().device.ToString(),
-                         " to ",
-                         dst.Location().device.ToString());
+                         src.Location().device.ToString(), " to ", dst.Location().device.ToString());
 }
 #endif
 
@@ -130,12 +121,17 @@ common::Status DataTransferManager::CopyTensors(const std::vector<IDataTransfer:
   }
 
   if (first_dt == nullptr) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME,
-                           FAIL,
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                            "There's no data transfer registered for copying tensors from ",
-                           src_device.ToString(),
-                           " to ",
-                           dst_device.ToString());
+                           src_device.ToString(), " to ", dst_device.ToString());
+  }
+
+  for (const auto& pair : src_dst_pairs) {
+    const auto& src_shape = pair.src.get().Shape();
+    const auto& dst_shape = pair.dst.get().Shape();
+    if (src_shape.Size() != dst_shape.Size()) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Tensor size mismatch. src:", src_shape, " dst:", dst_shape);
+    }
   }
 
   // all copies are between the same devices so we can do them all at once
@@ -148,11 +144,13 @@ common::Status DataTransferManager::CopyTensors(const std::vector<IDataTransfer:
   // batch as much as possible.
 
   // copy the first one as we already did the IDataTransfer lookup
-  ORT_RETURN_IF_ERROR(first_pair.src_stream ? first_dt->CopyTensorAsync(first_pair.src.get(), first_pair.dst.get(), *(first_pair.src_stream))
+  ORT_RETURN_IF_ERROR(first_pair.src_stream ? first_dt->CopyTensorAsync(first_pair.src.get(), first_pair.dst.get(),
+                                                                        *(first_pair.src_stream))
                                             : first_dt->CopyTensor(first_pair.src.get(), first_pair.dst.get()));
 
   for (auto cur_pair = src_dst_pairs.cbegin() + 1, end_pair = src_dst_pairs.cend(); cur_pair != end_pair; ++cur_pair) {
-    ORT_RETURN_IF_ERROR(!cur_pair->src_stream ? CopyTensor(cur_pair->src, cur_pair->dst) : CopyTensorAsync(cur_pair->src, cur_pair->dst, *(cur_pair->src_stream)));
+    ORT_RETURN_IF_ERROR(!cur_pair->src_stream ? CopyTensor(cur_pair->src, cur_pair->dst)
+                                              : CopyTensorAsync(cur_pair->src, cur_pair->dst, *(cur_pair->src_stream)));
   }
 
   return Status::OK();
@@ -183,12 +181,9 @@ common::Status DataTransferManager::CopySparseTensors(const std::vector<IDataTra
   }
 
   if (first_dt == nullptr) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME,
-                           FAIL,
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                            "There's no data transfer registered for copying tensors from ",
-                           src_device.ToString(),
-                           " to ",
-                           dst_device.ToString());
+                           src_device.ToString(), " to ", dst_device.ToString());
   }
 
   // all copies are between the same devices so we can do them all at once
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "gtest/gtest.h"
+#include "gmock/gmock.h"
+
+#include "core/common/inlined_containers.h"
+#include "core/framework/data_transfer_manager.h"
+#include "core/framework/ort_value.h"
+#include "test/unittest_util/framework_test_utils.h"
+#include "test/util/include/asserts.h"
+
+namespace onnxruntime {
+namespace test {
+
+// DataTransferManager::CopyTensors should validate sizes match before calling the IDataTransfer implementation
+TEST(DataTransferManagerTest, BatchedTensorCopyBadSize) {
+  auto allocator = TestCPUExecutionProvider()->CreatePreferredAllocators()[0];
+  std::vector<OrtValue> src_tensors{2};
+  InlinedVector<int64_t> shape_a{4}, shape_b{5}, shape_c{6};
+  std::vector<OrtValue> dst_tensors{2};
+
+  // first pair is matched
+  AllocateMLValue<float>(allocator, shape_a, &src_tensors[0]);
+  AllocateMLValue<float>(allocator, shape_a, &dst_tensors[0]);
+
+  // second pair has size mismatch
+  AllocateMLValue<float>(allocator, shape_c, &src_tensors[1]);
+  AllocateMLValue<float>(allocator, shape_b, &dst_tensors[1]);
+
+  DataTransferManager dtm;
+  ASSERT_STATUS_OK(dtm.RegisterDataTransfer(std::make_unique<CPUDataTransfer>()));
+
+  std::vector<IDataTransfer::SrcDstPair> src_dst_pairs;
+  src_dst_pairs.push_back({src_tensors[0].Get<Tensor>(), *dst_tensors[0].GetMutable<Tensor>(), nullptr});
+  src_dst_pairs.push_back({src_tensors[1].Get<Tensor>(), *dst_tensors[1].GetMutable<Tensor>(), nullptr});
+  auto status = dtm.CopyTensors(src_dst_pairs);
+
+  ASSERT_STATUS_NOT_OK(status);
+  ASSERT_THAT(status.ErrorMessage(), testing::HasSubstr("Tensor size mismatch"));
+}
+
+}  // namespace test
+}  // namespace onnxruntime
