Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 66 additions & 3 deletions paddle/phi/kernels/set_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,26 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/set_kernel.h"
#include <cstring>
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
namespace phi {

// Compute the minimum number of elements required in storage to hold
// a strided view described by dims, stride and offset.
// Compute the minimum number of elements a storage buffer must hold to
// back a strided view described by `dims`, `stride` and `offset`.
//
// The highest element index the view can touch is
//   offset + sum_i (dims[i] - 1) * stride[i], counting only non-negative
// strides: a negative stride walks backwards from `offset`, so its maximum
// reachable index for that dimension is at coordinate 0 and it must not
// shrink the bound. The buffer therefore needs (highest index + 1) elements.
// If any dimension has length 0 the view contains no elements at all and
// requires no storage (mirrors PyTorch's computeStorageNbytes).
static int64_t ComputeRequiredStorageSize(const std::vector<int64_t>& dims,
                                          const std::vector<int64_t>& stride,
                                          int64_t offset) {
  int64_t required = offset;
  for (size_t i = 0; i < dims.size(); ++i) {
    if (dims[i] == 0) {
      return 0;  // Empty view: no element is ever addressed.
    }
    // Clamp negative strides at 0 so they cannot under-estimate the bound.
    required += (dims[i] - 1) * std::max<int64_t>(stride[i], 0);
  }
  return required + 1;  // +1 for the last element itself
}

template <typename T, typename Context>
void SetKernel(const Context& dev_ctx,
const DenseTensor& x,
Expand All @@ -29,11 +45,58 @@ void SetKernel(const Context& dev_ctx,
meta.strides = DDim(stride.data(), static_cast<int>(stride.size()));
meta.offset = offset;
if (x.numel() == 0 || source.numel() == 0) {
if (source.numel() != 0) {
int64_t out_numel = 1;
for (auto d : dims) {
out_numel *= d;
}
if (source.numel() == 0 && x.numel() != 0) {
// Source is empty but x has storage. Reuse x's storage and apply
// the user-specified meta, matching PyTorch behavior.
if (out_numel == 0) {
// Output has 0 elements — no storage needed, just set meta.
out->set_meta(meta);
out->ShareInplaceVersionCounterWith(x);
return;
}
// If the strided view requires more storage than x provides,
// allocate a larger zero-filled buffer and copy x's data into it
// to avoid out-of-bounds reads on elements beyond x's allocation.
int64_t required_size = ComputeRequiredStorageSize(dims, stride, offset);
if (required_size > x.numel()) {
DenseTensor tmp;
std::vector<int64_t> alloc_shape = {required_size};
Full<T, Context>(dev_ctx, alloc_shape, 0, &tmp);
if (dev_ctx.GetPlace().GetType() == phi::AllocationType::CPU) {
std::memcpy(tmp.data<T>(), x.data<T>(), x.numel() * sizeof(T));
} else {
memory_utils::Copy(dev_ctx.GetPlace(),
tmp.data<T>(),
dev_ctx.GetPlace(),
x.data<T>(),
x.numel() * sizeof(T),
nullptr);
}
out->clear();
*out = DenseTensor{tmp.Holder(), meta};
} else {
out->set_meta(meta);
}
} else if (source.numel() == 0 && x.numel() == 0 && out_numel != 0) {
// Both x and source are 0-size but user wants non-zero shape.
// Allocate zero-filled storage to avoid null pointer access.
int64_t required_size = ComputeRequiredStorageSize(dims, stride, offset);
DenseTensor tmp;
std::vector<int64_t> alloc_shape = {required_size};
Full<T, Context>(dev_ctx, alloc_shape, 0, &tmp);
out->clear();
*out = DenseTensor{tmp.Holder(), meta};
} else if (source.numel() != 0) {
out->clear();
*out = DenseTensor{source.Holder(), meta};
} else {
// Both 0-size, output also 0-size
out->clear();
*out = DenseTensor{source.Holder(), meta};
} else if (x.numel() == 0) {
Full<T, Context>(dev_ctx, out->dims(), 0, out);
}
out->ShareInplaceVersionCounterWith(x);
return;
Expand Down
110 changes: 107 additions & 3 deletions test/legacy_test/test_inplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2603,11 +2603,115 @@ class TestSet_API_ZeroSize(unittest.TestCase):
def setUp(self):
    # Collect every available execution place (CPU and, when built with
    # device support, accelerators) so each test runs on all of them.
    # NOTE(review): get_places() is a project helper defined elsewhere in
    # this test module — presumably returns paddle Place objects; confirm.
    self.places = get_places()

def test_zero_size_source_with_nonzero_shape(self):
    """When source is 0-size but the caller specifies non-zero dims/stride,
    the output must adopt the user-specified shape (matching PyTorch
    behavior); storage is expanded as needed so that materializing the
    strided view never reads out of bounds.
    """
    for place in self.places:
        with paddle.base.dygraph.guard(place):
            source = paddle.randn([0, 3])
            x = paddle.randn([20])
            out = x.set_(source, [20], [2])
            self.assertEqual(list(out.shape), [20])
            # contiguous() must not read past the backing allocation.
            c = out.contiguous()
            self.assertEqual(list(c.shape), [20])

def test_zero_size_source_default_args(self):
    """Calling set_ with a 0-size source and no explicit shape/stride:
    the target becomes empty and adopts the source's shape, in place."""
    for place in self.places:
        with paddle.base.dygraph.guard(place):
            src = paddle.randn([0, 5])
            target = paddle.randn([10])
            result = target.set_(src)
            # The view is now empty and mirrors the source's shape.
            self.assertEqual(result.numel().item(), 0)
            self.assertEqual(list(result.shape), [0, 5])
            # set_ mutates in place: the returned object is the target itself.
            self.assertTrue(id(target) == id(result))

def test_zero_size_x_nonzero_source(self):
    """set_ on a 0-size target with a non-empty source behaves normally:
    the target adopts the source's shape and shares its buffer."""
    for place in self.places:
        with paddle.base.dygraph.guard(place):
            src = paddle.to_tensor([1.0, 2.0, 3.0])
            empty = paddle.randn([0])
            result = empty.set_(src)
            self.assertEqual(list(result.shape), [3])
            # Storage is shared with the source rather than copied.
            self.assertTrue(empty._is_shared_buffer_with(src))

def test_both_zero_size(self):
    """set_ where both the target and the source are 0-size: the result
    stays 0-size and is still the target object itself."""
    for place in self.places:
        with paddle.base.dygraph.guard(place):
            src = paddle.randn([0])
            target = paddle.randn([0])
            result = target.set_(src)
            self.assertEqual(result.numel().item(), 0)
            # In-place semantics hold even for the all-empty case.
            self.assertTrue(id(target) == id(result))

def test_both_zero_size_with_nonzero_shape(self):
    """Both target and source are 0-size but a non-zero shape/stride is
    requested. Exercises the kernel branch that allocates zero-filled
    storage when both inputs are empty yet the output must hold elements."""
    for place in self.places:
        with paddle.base.dygraph.guard(place):
            src = paddle.randn([0])
            target = paddle.randn([0])
            result = target.set_(src, [4], [1])
            self.assertEqual(list(result.shape), [4])
            self.assertTrue(id(target) == id(result))
            # The freshly allocated storage must be readable and all zeros.
            dense = result.contiguous()
            self.assertEqual(list(dense.shape), [4])
            np.testing.assert_array_equal(
                dense.numpy(), np.zeros([4], dtype='float32')
            )

def test_both_zero_size_with_nonzero_shape_and_offset(self):
    """Both target and source are 0-size and the caller requests a
    non-zero shape together with a non-zero storage offset; the allocation
    must be large enough that reading through the offset stays in bounds."""
    for place in self.places:
        with paddle.base.dygraph.guard(place):
            src = paddle.randn([0])
            target = paddle.randn([0])
            # The offset must be a multiple of the element size (4 bytes
            # for float32) to avoid misaligned GPU memory access.
            result = target.set_(src, [3], [2], 4)
            self.assertEqual(list(result.shape), [3])
            self.assertTrue(id(target) == id(result))
            dense = result.contiguous()
            self.assertEqual(list(dense.shape), [3])
            np.testing.assert_array_equal(
                dense.numpy(), np.zeros([3], dtype='float32')
            )

def test_both_zero_size_with_nonzero_2d_shape(self):
    """0-size target and source with a 2-D non-zero requested shape: the
    multi-dimensional strided view must get a correctly sized, zero-filled
    allocation."""
    for place in self.places:
        with paddle.base.dygraph.guard(place):
            src = paddle.randn([0, 0])
            target = paddle.randn([0])
            result = target.set_(src, [2, 3], [3, 1])
            self.assertEqual(list(result.shape), [2, 3])
            self.assertTrue(id(target) == id(result))
            dense = result.contiguous()
            self.assertEqual(list(dense.shape), [2, 3])
            np.testing.assert_array_equal(
                dense.numpy(), np.zeros([2, 3], dtype='float32')
            )

def test_zero_size_source_no_crash_on_contiguous(self):
    """contiguous() on a tensor produced by set_ with a 0-size source but
    an explicit non-empty shape must yield a valid tensor of that shape."""
    for place in self.places:
        with paddle.base.dygraph.guard(place):
            src = paddle.randn([0, 3])
            target = paddle.randn([20])
            result = target.set_(src, [20], [2])
            # Materializing the strided view must not crash or read OOB.
            dense = result.contiguous()
            self.assertEqual(list(dense.shape), [20])


if __name__ == '__main__':
Expand Down
Loading