Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 169 additions & 0 deletions paddle2onnx/mapper/tensor/index_put.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle2onnx/mapper/tensor/index_put.h"

namespace paddle2onnx {
// Register this mapper for the legacy program-desc parser and the PIR parser.
REGISTER_MAPPER(index_put, IndexPutMapper)
REGISTER_PIR_MAPPER(index_put, IndexPutMapper)

// Returns the minimum ONNX opset this op can be exported at.
// ScatterND itself exists from opset 11; the `reduction` attribute used for
// accumulate mode is only available from opset 16.
int32_t IndexPutMapper::GetMinOpsetVersion(bool verbose) {
  const int32_t min_opset = accumulate_ ? 16 : 11;
  Logger(verbose, min_opset) << RequireOpset(min_opset) << std::endl;
  return min_opset;
}

// Lowers Paddle's index_put to ONNX.
//
// Two strategies are used, selected by the dtype of the index tensors:
//   * A single boolean index tensor (a mask) is lowered to Where, with the
//     value broadcast to x's full shape first.
//   * Integer index tensors (one per indexed dimension) are stacked along a
//     new trailing axis and lowered to ScatterND.
void IndexPutMapper::Opset11() {
  // Get inputs:
  // - x: the tensor to be updated
  // - indices: a vector of index tensors (one per dimension, can be boolean
  //   masks)
  // - value: the values to put
  auto x_info = GetInput("x");
  auto indices_info = GetInput("indices");  // This is a vector of tensors
  auto value_info = GetInput("value");
  auto output_info = GetOutput("out");

  // Boolean-mask indexing is only recognized when there is exactly one index
  // tensor and it is BOOL; any other combination goes down the ScatterND path.
  bool is_boolean_mask = false;
  if (indices_info.size() == 1 && indices_info[0].dtype == P2ODataType::BOOL) {
    is_boolean_mask = true;
  }

  if (is_boolean_mask) {
    // Boolean indexing: x[mask] = value
    // Use ONNX Where operator: out = Where(mask, value_broadcast, x)
    std::string mask = indices_info[0].name;

    // Cast value to match x's dtype if needed
    std::string value_name = value_info[0].name;
    if (value_info[0].dtype != x_info[0].dtype) {
      value_name = helper_->AutoCast(
          value_info[0].name, value_info[0].dtype, x_info[0].dtype);
    }

    // Expand value to x's full shape so Where can select element-wise.
    // NOTE(review): this assumes `value` broadcasts against x's shape (e.g. a
    // scalar or a shape-compatible tensor), not a packed per-selected-element
    // value list — TODO confirm against Paddle's index_put semantics.
    auto x_shape_node = helper_->MakeNode("Shape", {x_info[0].name});
    std::string value_broadcast =
        helper_->MakeNode("Expand", {value_name, x_shape_node->output(0)})
            ->output(0);

    if (accumulate_) {
      // When accumulate is true: x[mask] += value
      // out = Where(mask, x + value_broadcast, x)
      std::string add_result =
          helper_->MakeNode("Add", {x_info[0].name, value_broadcast})
              ->output(0);
      helper_->MakeNode(
          "Where", {mask, add_result, x_info[0].name}, {output_info[0].name});
    } else {
      // out = Where(mask, value_broadcast, x)
      helper_->MakeNode("Where",
                        {mask, value_broadcast, x_info[0].name},
                        {output_info[0].name});
    }
  } else {
    // Integer indexing: use ScatterND
    std::vector<std::string> indices_names;
    for (size_t i = 0; i < indices_info.size(); ++i) {
      // Cast indices to INT64 if needed (ScatterND indices must be int64)
      std::string idx_name = helper_->AutoCast(
          indices_info[i].name, indices_info[i].dtype, P2ODataType::INT64);
      // Unsqueeze each index tensor to add a dimension at the end, so the
      // per-dimension indices can be concatenated into ScatterND's
      // [..., num_dims] indices layout.
      std::string axes_node = helper_->Constant(
          ONNX_NAMESPACE::TensorProto::INT64, std::vector<int64_t>{-1});
      auto unsqueeze_node =
          helper_->MakeNode("Unsqueeze", {idx_name, axes_node});
      indices_names.push_back(unsqueeze_node->output(0));
    }

    // Concat all indices along the last dimension
    std::string indices_concat;
    if (indices_names.size() == 1) {
      indices_concat = indices_names[0];
    } else {
      auto concat_node = helper_->MakeNode("Concat", indices_names);
      AddAttribute(concat_node, "axis", static_cast<int64_t>(-1));
      indices_concat = concat_node->output(0);
    }

    // Cast value to match x's dtype if needed
    std::string value_name = value_info[0].name;
    if (value_info[0].dtype != x_info[0].dtype) {
      value_name = helper_->AutoCast(
          value_info[0].name, value_info[0].dtype, x_info[0].dtype);
    }

    // For ScatterND, updates shape should be:
    //   indices.shape[:-1] + data.shape[num_dims:]
    // where num_dims = indices.shape[-1] = number of index tensors

    // Get indices shape (without the last dim we added via Unsqueeze).
    // NOTE(review): only the FIRST index tensor's shape is used here — this
    // assumes all index tensors share that shape; broadcastable index inputs
    // of differing shapes would produce a wrong updates shape. TODO confirm.
    auto indices_shape_node =
        helper_->MakeNode("Shape", {indices_info[0].name});

    // Get data shape and slice from num_dims onwards (the trailing dims of x
    // that are not consumed by the index tensors). INT64_MAX as the slice end
    // means "to the end of the shape".
    auto data_shape_node = helper_->MakeNode("Shape", {x_info[0].name});
    int64_t num_dims = static_cast<int64_t>(indices_info.size());
    auto start_const = helper_->Constant(ONNX_NAMESPACE::TensorProto::INT64,
                                         std::vector<int64_t>{num_dims});
    auto end_const = helper_->Constant(ONNX_NAMESPACE::TensorProto::INT64,
                                       std::vector<int64_t>{INT64_MAX});
    auto axes_const = helper_->Constant(ONNX_NAMESPACE::TensorProto::INT64,
                                        std::vector<int64_t>{0});
    auto data_shape_suffix = helper_
                                 ->MakeNode("Slice",
                                            {data_shape_node->output(0),
                                             start_const,
                                             end_const,
                                             axes_const})
                                 ->output(0);

    // Concat to get the target updates shape
    auto target_shape_node = helper_->MakeNode(
        "Concat", {indices_shape_node->output(0), data_shape_suffix});
    AddAttribute(target_shape_node, "axis", static_cast<int64_t>(0));

    // Expand value to match target shape
    value_name =
        helper_->MakeNode("Expand", {value_name, target_shape_node->output(0)})
            ->output(0);

    if (accumulate_) {
      // Accumulate: scatter the updates into an all-zeros tensor with
      // reduction="add" (so duplicate indices sum correctly), then add the
      // result to x. The `reduction` attribute requires opset 16, which is
      // enforced by GetMinOpsetVersion.
      auto shape_node = helper_->MakeNode("Shape", {x_info[0].name});
      std::string zeros_node =
          helper_->ConstOfShape(shape_node->output(0),
                                GetOnnxDtype(x_info[0].dtype),
                                static_cast<float>(0));

      auto scatter_node = helper_->MakeNode(
          "ScatterND", {zeros_node, indices_concat, value_name});
      AddAttribute(scatter_node, "reduction", std::string("add"));

      helper_->MakeNode("Add",
                        {x_info[0].name, scatter_node->output(0)},
                        {output_info[0].name});
    } else {
      // Plain assignment: scatter values directly into x.
      // NOTE(review): ONNX leaves ScatterND's behavior for duplicate indices
      // unspecified when no reduction is set — TODO confirm this matches
      // Paddle's last-write-wins expectation on the target runtimes.
      helper_->MakeNode("ScatterND",
                        {x_info[0].name, indices_concat, value_name},
                        {output_info[0].name});
    }
  }
}

} // namespace paddle2onnx
46 changes: 46 additions & 0 deletions paddle2onnx/mapper/tensor/index_put.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <vector>

#include "paddle2onnx/mapper/mapper.h"

namespace paddle2onnx {

// Maps Paddle's index_put operator to ONNX.
// Boolean-mask indexing is exported with Where; integer indexing with
// ScatterND (see index_put.cc). Requires opset 11, or opset 16 when
// `accumulate` is true (ScatterND's `reduction` attribute).
class IndexPutMapper : public Mapper {
 public:
  // Constructor for the legacy program-desc parser.
  IndexPutMapper(const PaddleParser& p,
                 OnnxHelper* helper,
                 int64_t block_id,
                 int64_t op_id)
      : Mapper(p, helper, block_id, op_id) {
    GetAttr("accumulate", &accumulate_);
  }
  // Constructor for the PIR parser.
  IndexPutMapper(const PaddlePirParser& p,
                 OnnxHelper* helper,
                 int64_t op_id,
                 bool in_cf_block)
      : Mapper(p, helper, op_id, in_cf_block) {
    GetAttr("accumulate", &accumulate_);
  }
  int32_t GetMinOpsetVersion(bool verbose) override;
  void Opset11() override;

 private:
  // Paddle's `accumulate` attribute: false -> assign, true -> add in place.
  bool accumulate_ = false;
};

} // namespace paddle2onnx
26 changes: 25 additions & 1 deletion paddle2onnx/mapper/tensor/set_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@ int32_t SetValueMapper::GetMinOpsetVersion(bool verbose) {
void SetValueMapper::Opset17() {
auto input_info = GetInput("Input");
auto output_info = GetOutput("Out");

// Special case: if axes is empty, this is a full tensor assignment
// Just copy the value to output (for set_value_with_tensor_ with empty axes)
std::string op_type = OpType();
bool is_set_value_with_tensor =
(op_type.find("set_value_with_tensor") != std::string::npos);
if (in_pir_mode && is_set_value_with_tensor && axes_.empty()) {
auto value_info = GetInput(1);
helper_->MakeNode("Identity", {value_info[0].name}, {output_info[0].name});
return;
}

std::string starts = "";
if (HasInput("StartsTensorList")) {
    // if negative values exist, not supported
Expand Down Expand Up @@ -91,14 +103,26 @@ void SetValueMapper::Opset17() {
auto input_tensor = input_info[0].name;
std::string value = "";
int64_t value_rank = input_info[0].Rank();
if (HasInput("ValueTensor")) {

// Reuse op_type and is_set_value_with_tensor from earlier in function
if (in_pir_mode && is_set_value_with_tensor) {
// In PIR mode, set_value_with_tensor_ has value as second input (index 1)
auto value_info = GetInput(1);
value = value_info[0].name;
value_rank = value_info[0].Rank();
} else if (HasInput("ValueTensor")) {
auto value_info = GetInput("ValueTensor");
value = value_info[0].name;
value_rank = value_info[0].Rank();
} else if (HasInput("values")) {
auto value_info = GetInput("values");
value = value_info[0].name;
value_rank = value_info[0].Rank();
} else if (HasInput("value")) {
// PIR mode: set_value_with_tensor_ uses "value" as input name
auto value_info = GetInput("value");
value = value_info[0].name;
value_rank = value_info[0].Rank();
} else {
value_rank = shape_.size();
int in_dtype = input_info[0].dtype;
Expand Down
8 changes: 8 additions & 0 deletions paddle2onnx/mapper/tensor/slice.cc
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modifying this file from the version prior to this PR (3e77ec7) was unnecessary; the change has been reverted.

Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ std::vector<int64_t> SliceMapper::DecreaseAxis() {
bool has_attr = HasAttr("decrease_axis");
if (has_attr) {
GetAttr("decrease_axis", &decrease_axis);

// In PIR mode, if decrease_axis is not empty, we should use it directly
// The shape comparison logic may fail in PIR mode due to input name
// differences
if (in_pir_mode && !decrease_axis.empty()) {
return decrease_axis;
}

auto input_info = GetInput("Input");
auto output_info = GetOutput("Out");
if (output_info[0].shape.size() == 1 && output_info[0].shape[0] == 0) {
Expand Down
31 changes: 31 additions & 0 deletions paddle2onnx/mapper/tensor/squeeze2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,42 @@ void Squeeze2Mapper::Opset7() {
if (i > 1) ret.push_back(i);
}
if (ret.size() == input_info[0].Rank()) {
// All dimensions are > 1, nothing to squeeze
helper_->MakeNode("Identity", {input_info[0].name}, {output_info[0].name});
} else {
bool with_axis = in_pir_mode ? HasInput("axis") : IsAttrVar("axes");
if (helper_->GetOpsetVersion() >= 13 && with_axis) {
auto axes_info = in_pir_mode ? GetInput("axis") : GetAttrVar("axes");

// Check if we can get the axes values statically
std::vector<int64_t> axes_values;
bool axes_known = false;
if (in_pir_mode) {
axes_known = TryGetInputValue("axis", &axes_values);
}

// If axes are known, check if the dimensions at those axes are 1
if (axes_known && !axes_values.empty()) {
bool all_dims_not_one = true;
for (auto axis : axes_values) {
int64_t actual_axis = axis >= 0 ? axis : axis + input_info[0].Rank();
if (actual_axis >= 0 && actual_axis < input_info[0].Rank()) {
int64_t dim_size = input_info[0].shape[actual_axis];
if (dim_size == 1 || dim_size == -1) {
// -1 means dynamic, might be 1 at runtime
all_dims_not_one = false;
break;
}
}
}
if (all_dims_not_one) {
// None of the dimensions to squeeze have size 1, use Identity
helper_->MakeNode(
"Identity", {input_info[0].name}, {output_info[0].name});
return;
}
}

std::string axes_name;
if (axes_info.size() == 1U) {
axes_name = helper_->AutoCast(
Expand Down
64 changes: 61 additions & 3 deletions paddle2onnx/mapper/tensor/stack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,74 @@ void StackMapper::Opset7() {
int32_t out_dtype = 0;
std::vector<std::string> aligned_inputs =
helper_->DtypeAlignment(x_info, &out_dtype);

// Find the maximum rank among all inputs based on TensorInfo
int32_t max_rank = 0;
for (size_t i = 0; i < x_info.size(); ++i) {
int32_t rank = x_info[i].Rank();
if (rank > max_rank) {
max_rank = rank;
}
}

// Special case: if all inputs are scalars [] or single-element tensors [1]
// Check if all inputs have at most 1 element total
bool all_single_element = true;
for (size_t i = 0; i < x_info.size(); ++i) {
if (x_info[i].Rank() == 0) {
// Scalar, has 1 element - OK
continue;
} else if (x_info[i].Rank() == 1) {
// Check if it's exactly [1] not [4] or other sizes
if (x_info[i].shape[0] == 1) {
// Single element [1] - OK
continue;
} else {
// It's like [4] or [N] where N != 1 - NOT single element
all_single_element = false;
break;
}
} else {
// Rank > 1, definitely not single element
all_single_element = false;
break;
}
}

if (all_single_element && max_rank <= 1) {
// All inputs are scalars or [1], normalize to scalars []
for (size_t i = 0; i < aligned_inputs.size(); ++i) {
aligned_inputs[i] =
helper_->Reshape(aligned_inputs[i], std::vector<int64_t>{});
}
max_rank = 0; // All are now scalars
} else {
// Normal case: make all inputs have the same rank by unsqueezing lower-rank
// tensors
for (size_t i = 0; i < aligned_inputs.size(); ++i) {
int32_t rank_diff = max_rank - x_info[i].Rank();
if (rank_diff > 0) {
// Unsqueeze to match max_rank
std::vector<int64_t> axes_to_add;
for (int32_t j = 0; j < rank_diff; ++j) {
axes_to_add.push_back(j);
}
aligned_inputs[i] = helper_->Unsqueeze(aligned_inputs[i], axes_to_add);
}
}
}

auto axis = axis_;
if (axis < 0) {
axis = axis + x_info[0].Rank() + 1;
axis = axis + max_rank + 1;
}

// Now unsqueeze all inputs at the target axis for stacking
for (size_t i = 0; i < aligned_inputs.size(); ++i) {
aligned_inputs[i] =
helper_->Unsqueeze(aligned_inputs[i], std::vector<int64_t>(1, axis));
}
auto out = helper_->Concat(aligned_inputs, axis_);
auto out = helper_->Concat(aligned_inputs, axis);
helper_->AutoCast(out, y_info[0].name, out_dtype, y_info[0].dtype);
}

} // namespace paddle2onnx
Loading
Loading