Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -699,8 +699,10 @@ void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format
auto reshape_axis = crop_axis;
if (reshape_mode == reshape::reshape_mode::base) {
if (crop_axis == 0 && !crop_layout.get_partial_shape()[0].is_dynamic() &&
crop_layout.get_partial_shape()[0].get_length() == 1) {
// The crop produces exactly batch=1 per slice.
crop_layout.get_partial_shape()[0].get_length() == 1 &&
!(user_info.second.get_partial_shape()[0].is_static() &&
user_info.second.get_partial_shape()[0].get_length() == 1)) {
// The crop produces exactly batch=1 per slice and the reshape squeezes that dim.
// The reshape absorbs that dim, so the padding axis in the output remains 0.
reshape_axis = 0;
} else {
Expand Down Expand Up @@ -764,9 +766,10 @@ void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format
std::vector<ov::Dimension::value_type> reshape_upper_sizes(output_rank, 0);
padding::DynamicDimsMask reshape_dyn_pad_mask;

if (crop_axis == 0 && crop_dim_val == 1) {
// The crop splits on the batch axis with exactly batch=1 per slice.
// The reshape squeezes that batch=1 dim: [1, f, y, x] -> [f, y, x].
if (crop_axis == 0 && crop_dim_val == 1 &&
!(reshape_ps[0].is_static() && reshape_ps[0].get_length() == 1)) {
// The crop splits on the batch axis with exactly batch=1 per slice
// and the reshape squeezes that batch=1 dim: [1, f, y, x] -> [f, y, x].
// Padding offsets are in units of one 4D batch slice (= f*y*x elements),
// but the 3D output counts elements at axis 0 directly, so multiply by f.
const auto batch_stride_factor = reshape_ps[0].get_length();
Expand Down
9 changes: 8 additions & 1 deletion src/plugins/intel_gpu/src/graph/include/reshape_inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,14 @@ struct typed_program_node<reshape> : public typed_program_node_base<reshape> {
if (axis == 0 && !input_pshape[0].is_dynamic()) {
if (prim->output_pattern.empty())
return false;
return input_pshape[0].get_length() == 1;
if (input_pshape[0].get_length() != 1)
return false;
// Reject if the reshape just flattens spatial dims while keeping batch=1
// (e.g. [1,C,H,W] -> [1,C,H*W]). Only allow when the batch dim is truly squeezed.
auto& out_ps = prim->output_partial_shape;
if (!out_ps[0].is_dynamic() && out_ps[0].get_length() == 1)
return false;
return true;
}

auto input_rank = input_pshape.size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1842,3 +1842,72 @@ TEST(prepare_buffer_fusing, in_place_crop_dynamic_batch_axis_split_with_reshape)
for (size_t i = 0; i < slice_elems; i++)
ASSERT_FLOAT_EQ(v_out[i], input_data[2 * slice_elems + i]) << "V mismatch at " << i;
}

// RAFT-like pattern: VariadicSplit on the batch axis produces crops of shape
// [1,C,H,W], each followed by a Reshape to [1,C,H*W]. The reshape only flattens
// the spatial dims and keeps batch=1, so the in-place crop optimisation must NOT
// treat it as a batch-squeeze; each branch must still read its own batch slice.
TEST(prepare_buffer_fusing, in_place_crop_dynamic_batch_axis_split_with_spatial_flatten_reshape) {
    auto& engine = get_test_engine();
    tests::random_generator rg(GET_SUITE_NAME);

    // Static source tensor [2, 4, 3, 5]; one batch slice holds C*H*W elements.
    constexpr size_t batch = 2, channels = 4, height = 3, width = 5;
    constexpr size_t elems_per_slice = channels * height * width;

    const int64_t split_axis = 0;

    // Network input layout is dynamic on the channel dimension only.
    auto in_layout = layout{ov::PartialShape{static_cast<int64_t>(batch), -1, static_cast<int64_t>(height), static_cast<int64_t>(width)},
                            data_types::f32, format::bfyx};
    auto input_mem = engine.allocate_memory({{static_cast<int64_t>(batch), static_cast<int64_t>(channels),
                                              static_cast<int64_t>(height), static_cast<int64_t>(width)},
                                             data_types::f32, format::bfyx});
    auto axis_mem = engine.allocate_memory({{}, data_types::i64, format::bfyx});
    auto splits_length_mem = engine.allocate_memory({{2}, data_types::i64, format::bfyx});

    auto input_data = rg.generate_random_1d<float>(input_mem->count(), -1.f, 1.f);
    set_values(input_mem, input_data);
    set_values<int64_t>(axis_mem, {split_axis});
    // Split batch=2 into two slices of batch=1 each.
    set_values<int64_t>(splits_length_mem, {1, 1});

    // Reshape pattern [1, C, H, W] -> [1, C, H*W]: spatial flatten, batch preserved.
    const int64_t flat_spatial = static_cast<int64_t>(height * width);
    const std::vector<int64_t> flatten_pattern = {1, -1, flat_spatial};
    const ov::PartialShape flatten_out_shape = {1, -1, flat_spatial};

    cldnn::crop_ngraph_op_mode op_mode = cldnn::crop_ngraph_op_mode::variadic_split;
    topology topology;
    topology.add(input_layout("input", in_layout));
    topology.add(data("axis", axis_mem));
    topology.add(data("splits_length", splits_length_mem));
    // Per split output: crop -> [1, C, H, W], reshape -> [1, C, H*W], reorder out.
    for (int branch = 0; branch < 2; ++branch) {
        const auto id = std::to_string(branch);
        topology.add(crop("crop_" + id, {input_info("input"), input_info("axis"), input_info("splits_length")},
                          cldnn::tensor(1), cldnn::tensor(0), op_mode, branch, split_axis));
        topology.add(reshape("reshape_" + id, input_info("crop_" + id), false, flatten_pattern,
                             flatten_out_shape, cldnn::reshape::reshape_mode::base));
        topology.add(reorder("output_" + id, input_info("reshape_" + id), format::bfyx, data_types::f32,
                             std::vector<float>(), reorder_mean_mode::subtract, padding(), true));
    }

    auto config = get_test_default_config(engine);
    config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
    config.set_property(ov::intel_gpu::optimize_data(true));
    network network(engine, topology, config);
    network.set_input_data("input", input_mem);

    auto outputs = network.execute();

    // Each branch must reproduce exactly its own batch slice of the input.
    for (size_t branch = 0; branch < 2; ++branch) {
        auto out_mem = outputs.at("output_" + std::to_string(branch)).get_memory();
        cldnn::mem_lock<float> out(out_mem, get_test_stream());
        ASSERT_EQ(out.size(), elems_per_slice);
        for (size_t i = 0; i < elems_per_slice; i++)
            ASSERT_FLOAT_EQ(out[i], input_data[branch * elems_per_slice + i])
                << "Branch " << branch << " mismatch at " << i;
    }
}
Loading