Skip to content

Commit 1c9d10e

Browse files
When xla::Array::Reshape is called with a dimension span that aliases the array's internal buffer, reallocating the buffer would invalidate the span before the copy was performed. This CL addresses the issue by using a temporary buffer for reallocation and using std::memmove for in-place updates. A regression test 'ReshapeWithLowerRank' is added to verify the fix.
PiperOrigin-RevId: 900935128
1 parent c598dd8 commit 1c9d10e

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

xla/array.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -542,10 +542,14 @@ class Array {
542542
std::multiplies<int64_t>());
543543
CHECK_EQ(new_num_elements, num_elements());
544544
if (sizes_.size != new_dimensions.size()) {
545-
sizes_ = OwnedBuffer<int64_t>(new_dimensions.size());
545+
OwnedBuffer<int64_t> new_sizes(new_dimensions.size());
546+
std::memcpy(new_sizes.data.get(), new_dimensions.data(),
547+
new_dimensions.size() * sizeof(int64_t));
548+
sizes_ = std::move(new_sizes);
549+
} else {
550+
std::memmove(sizes_.data.get(), new_dimensions.data(),
551+
new_dimensions.size() * sizeof(int64_t));
546552
}
547-
std::memcpy(sizes_.data.get(), new_dimensions.data(),
548-
new_dimensions.size() * sizeof(int64_t));
549553
}
550554

551555
// Performs a permutation of dimensions.

xla/array_test.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,5 +466,18 @@ TEST(ArrayTest, UpdateSlice) {
466466
EXPECT_EQ(expected, arr.ToString());
467467
}
468468

469+
TEST(ArrayTest, ReshapeWithLowerRank) {
470+
Array<int64_t> arr({1, 1, 24});
471+
EXPECT_EQ(arr.num_dimensions(), 3);
472+
// Reshape to 1D using a sub-span of its own dimensions.
473+
// The dropped dimensions are all 1, so the number of elements stays the same.
474+
// This triggers reallocation of sizes_ (3 -> 1) while new_dimensions aliases
475+
// it.
476+
arr.Reshape(arr.dimensions().subspan(2, 1));
477+
EXPECT_EQ(arr.num_dimensions(), 1);
478+
EXPECT_EQ(arr.dim(0), 24);
479+
EXPECT_EQ(arr.num_elements(), 24);
480+
}
481+
469482
} // namespace
470483
} // namespace xla

0 commit comments

Comments
 (0)