Skip to content
8 changes: 5 additions & 3 deletions paddle/phi/api/include/compat/ATen/core/TensorBody.h
Original file line number Diff line number Diff line change
Expand Up @@ -551,9 +551,11 @@ class Tensor : public TensorBase {
::std::optional<int64_t> end = ::std::nullopt,
int64_t step = 1);

// TODO(wangyanpeng04): modify the api to
// Tensor index(ArrayRef<at::indexing::TensorIndex> indices) const;
at::Tensor index(const std::vector<at::indexing::Slice>& indices) const;
at::Tensor index(ArrayRef<at::indexing::TensorIndex> indices) const;
inline at::Tensor index(
std::initializer_list<at::indexing::TensorIndex> indices) const {
return index(ArrayRef<at::indexing::TensorIndex>(indices));
}

at::Tensor& floor_divide_(const at::Scalar& other) const {
paddle::experimental::floor_divide_(
Expand Down
55 changes: 53 additions & 2 deletions paddle/phi/api/include/compat/ATen/indexing.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,17 @@
#include <c10/core/SymInt.h>

#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <optional>
#include <stdexcept>
#include <type_traits>
#include <utility>

namespace at {
class Tensor;
}

namespace at::indexing {

Expand Down Expand Up @@ -58,9 +68,7 @@ struct Slice final {
}

inline c10::SymInt start() const { return start_; }

inline c10::SymInt stop() const { return stop_; }

inline c10::SymInt step() const { return step_; }

private:
Expand All @@ -69,4 +77,47 @@ struct Slice final {
c10::SymInt step_;
};

struct TensorIndex final {
TensorIndex(std::nullopt_t /*unused*/) : type_(TensorIndexType::None) {}

TensorIndex(at::indexing::EllipsisIndexType /*unused*/)
: type_(TensorIndexType::Ellipsis) {}
TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) {
if (std::strcmp(str, "...") != 0) {
throw std::invalid_argument(
"Expected \"...\" to represent an ellipsis index.");
}
}

TensorIndex(c10::SymInt integer)
: integer_(std::move(integer)), type_(TensorIndexType::SymInt) {}
TensorIndex(int integer) : TensorIndex(c10::SymInt(integer)) {}

template <class T, class = std::enable_if_t<std::is_same_v<bool, T>>>
TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {}

TensorIndex(Slice slice)
: slice_(std::move(slice)), type_(TensorIndexType::Slice) {}

TensorIndex(const at::Tensor& tensor);

inline bool is_none() const { return type_ == TensorIndexType::None; }
inline bool is_ellipsis() const { return type_ == TensorIndexType::Ellipsis; }
inline bool is_integer() const { return type_ == TensorIndexType::SymInt; }
inline c10::SymInt integer() const { return integer_; }
inline bool is_boolean() const { return type_ == TensorIndexType::Boolean; }
inline bool boolean() const { return boolean_; }
inline bool is_slice() const { return type_ == TensorIndexType::Slice; }
inline const Slice& slice() const { return slice_; }
inline bool is_tensor() const { return type_ == TensorIndexType::Tensor; }
const at::Tensor& tensor() const;

private:
c10::SymInt integer_ = 0;
bool boolean_ = false;
Slice slice_;
std::shared_ptr<at::Tensor> tensor_;
TensorIndexType type_;
};

} // namespace at::indexing
77 changes: 65 additions & 12 deletions paddle/phi/api/include/compat/ATen/ops/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,82 @@
#include <ATen/core/Tensor.h>
#include <ATen/indexing.h>

namespace at::indexing {

inline TensorIndex::TensorIndex(const at::Tensor& tensor)
: tensor_(std::make_shared<at::Tensor>(tensor)),
type_(TensorIndexType::Tensor) {}

inline const at::Tensor& TensorIndex::tensor() const { return *tensor_; }

} // namespace at::indexing

namespace at {

// TODO(wangyanpeng04): modify the api to
// Tensor index(ArrayRef<at::indexing::TensorIndex> indices) const;
inline at::Tensor index(const at::Tensor& self,
const std::vector<at::indexing::Slice>& indices) {
std::vector<int64_t> starts(indices.size());
std::vector<int64_t> ends(indices.size());
for (size_t i = 0; i < indices.size(); ++i) {
starts[i] = indices[i].start();
ends[i] = indices[i].stop();
ArrayRef<at::indexing::TensorIndex> indices) {
if (indices.size() == 0) {
return self;
}

bool has_slice = false;
bool has_tensor_like = false;
for (const auto& index : indices) {
has_slice = has_slice || index.is_slice();
has_tensor_like = has_tensor_like || index.is_tensor() || index.is_none();
PD_CHECK(!index.is_ellipsis(), "Ellipsis index is not supported yet.");
PD_CHECK(!index.is_integer(), "Integer index is not supported yet.");
PD_CHECK(!index.is_boolean(), "Boolean index is not supported yet.");
}

if (has_slice && !has_tensor_like) {
std::vector<int64_t> axes;
std::vector<int64_t> starts;
std::vector<int64_t> ends;
std::vector<int64_t> strides;
axes.reserve(indices.size());
starts.reserve(indices.size());
ends.reserve(indices.size());
strides.reserve(indices.size());

int64_t dim = 0;
for (const auto& index : indices) {
const auto& slice = index.slice();
axes.push_back(dim++);
starts.push_back(static_cast<int64_t>(slice.start()));
ends.push_back(static_cast<int64_t>(slice.stop()));
strides.push_back(static_cast<int64_t>(slice.step()));
}

return paddle::experimental::slice(
self._PD_GetInner(), axes, starts, ends, strides, {});
}
return paddle::experimental::slice(
self._PD_GetInner(), {0, 1}, starts, ends, {1}, {})
.contiguous();

PD_CHECK(!has_slice,
"Mixed slice and tensor/None indexing is not supported yet.");
c10::List<::std::optional<at::Tensor>> tensor_indices;
for (const auto& index : indices) {
if (index.is_none()) {
tensor_indices.push_back(::std::nullopt);
} else if (index.is_tensor()) {
tensor_indices.push_back(index.tensor());
}
}
return self.index(tensor_indices);
}

inline at::Tensor index(
const at::Tensor& self,
std::initializer_list<at::indexing::TensorIndex> indices) {
return at::index(self, ArrayRef<at::indexing::TensorIndex>(indices));
}

} // namespace at

namespace at {

inline at::Tensor Tensor::index(
const std::vector<at::indexing::Slice>& indices) const {
ArrayRef<at::indexing::TensorIndex> indices) const {
return at::index(*this, indices);
}

Expand Down
16 changes: 16 additions & 0 deletions test/cpp/compat/ATen_index_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <ATen/Functions.h>
#include <ATen/core/TensorBody.h>
#include <ATen/indexing.h>
#include <ATen/ops/tensor.h>
#include <c10/core/List.h>
#include <c10/core/ScalarType.h>
Expand Down Expand Up @@ -51,6 +52,21 @@ TEST(TensorIndexTest, IndexWithSingleTensor) {
ASSERT_FLOAT_EQ(result_data[2], 40.0f);
}

TEST(TensorIndexTest, SliceKeepsStrideWithoutContiguousCopy) {
at::Tensor base = at::arange(24, at::kFloat).reshape({4, 6});
at::Tensor transposed = base.t(); // shape: [6, 4], strides: [1, 6]
ASSERT_FALSE(transposed.is_contiguous());

at::Tensor sliced =
transposed.index({at::indexing::Slice(1, 5), at::indexing::Slice(0, 3)});

ASSERT_EQ(sliced.sizes(), c10::IntArrayRef({4, 3}));
ASSERT_EQ(sliced.strides(), c10::IntArrayRef({1, 6}));
ASSERT_EQ(sliced.stride(0), transposed.stride(0));
ASSERT_EQ(sliced.stride(1), transposed.stride(1));
ASSERT_FALSE(sliced.is_contiguous());
}

// ======================== index_put_ tests ========================

TEST(TensorIndexPutTest, IndexPutInplaceWithTensor) {
Expand Down
Loading