diff --git a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp index e49a2b909225e5..51ac026c7c5060 100644 --- a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp +++ b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp @@ -699,8 +699,10 @@ void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format auto reshape_axis = crop_axis; if (reshape_mode == reshape::reshape_mode::base) { if (crop_axis == 0 && !crop_layout.get_partial_shape()[0].is_dynamic() && - crop_layout.get_partial_shape()[0].get_length() == 1) { - // The crop produces exactly batch=1 per slice. + crop_layout.get_partial_shape()[0].get_length() == 1 && + !(user_info.second.get_partial_shape()[0].is_static() && + user_info.second.get_partial_shape()[0].get_length() == 1)) { + // The crop produces exactly batch=1 per slice and the reshape squeezes that dim. // The reshape absorbs that dim, so the padding axis in the output remains 0. reshape_axis = 0; } else { @@ -764,9 +766,10 @@ void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format std::vector<int32_t> reshape_upper_sizes(output_rank, 0); padding::DynamicDimsMask reshape_dyn_pad_mask; - if (crop_axis == 0 && crop_dim_val == 1) { - // The crop splits on the batch axis with exactly batch=1 per slice. - // The reshape squeezes that batch=1 dim: [1, f, y, x] -> [f, y, x]. + if (crop_axis == 0 && crop_dim_val == 1 && + !(reshape_ps[0].is_static() && reshape_ps[0].get_length() == 1)) { + // The crop splits on the batch axis with exactly batch=1 per slice + // and the reshape squeezes that batch=1 dim: [1, f, y, x] -> [f, y, x]. // Padding offsets are in units of one 4D batch slice (= f*y*x elements), // but the 3D output counts elements at axis 0 directly, so multiply by f. 
const auto batch_stride_factor = reshape_ps[0].get_length(); diff --git a/src/plugins/intel_gpu/src/graph/include/reshape_inst.h b/src/plugins/intel_gpu/src/graph/include/reshape_inst.h index dbeb07ccedfd71..36a5ffcf72960a 100644 --- a/src/plugins/intel_gpu/src/graph/include/reshape_inst.h +++ b/src/plugins/intel_gpu/src/graph/include/reshape_inst.h @@ -72,7 +72,14 @@ struct typed_program_node<reshape> : public typed_program_node_base<reshape> { if (axis == 0 && !input_pshape[0].is_dynamic()) { if (prim->output_pattern.empty()) return false; - return input_pshape[0].get_length() == 1; + if (input_pshape[0].get_length() != 1) + return false; + // Reject if the reshape just flattens spatial dims while keeping batch=1 + // (e.g. [1,C,H,W] -> [1,C,H*W]). Only allow when the batch dim is truly squeezed. + auto& out_ps = prim->output_partial_shape; + if (!out_ps[0].is_dynamic() && out_ps[0].get_length() == 1) + return false; + return true; } auto input_rank = input_pshape.size(); diff --git a/src/plugins/intel_gpu/tests/unit/passes/prepare_buffer_fusing_test.cpp b/src/plugins/intel_gpu/tests/unit/passes/prepare_buffer_fusing_test.cpp index bfdc024bb778ae..c48040eb8bc329 100644 --- a/src/plugins/intel_gpu/tests/unit/passes/prepare_buffer_fusing_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/passes/prepare_buffer_fusing_test.cpp @@ -1842,3 +1842,72 @@ TEST(prepare_buffer_fusing, in_place_crop_dynamic_batch_axis_split_with_reshape) for (size_t i = 0; i < slice_elems; i++) ASSERT_FLOAT_EQ(v_out[i], input_data[2 * slice_elems + i]) << "V mismatch at " << i; } + +// RAFT-like pattern: VariadicSplit on batch axis → crop [1,C,H,W] → Reshape [1,C,H*W] +// The reshape flattens spatial dims while keeping batch=1. The in-place crop +// optimisation must NOT treat this as a batch-squeeze and must produce correct data. 
+TEST(prepare_buffer_fusing, in_place_crop_dynamic_batch_axis_split_with_spatial_flatten_reshape) { + auto& engine = get_test_engine(); + tests::random_generator rg(GET_SUITE_NAME); + + const size_t dim_b = 2, dim_c = 4, dim_h = 3, dim_w = 5; + const size_t slice_elems = dim_c * dim_h * dim_w; + + auto in_layout = layout{ov::PartialShape{static_cast<int64_t>(dim_b), -1, static_cast<int64_t>(dim_h), static_cast<int64_t>(dim_w)}, + data_types::f32, format::bfyx}; + auto input_mem = engine.allocate_memory({{static_cast<int64_t>(dim_b), static_cast<int64_t>(dim_c), + static_cast<int64_t>(dim_h), static_cast<int64_t>(dim_w)}, + data_types::f32, format::bfyx}); + auto axis_mem = engine.allocate_memory({{}, data_types::i64, format::bfyx}); + auto splits_length_mem = engine.allocate_memory({{2}, data_types::i64, format::bfyx}); + + const int64_t axis = 0; + + auto input_data = rg.generate_random_1d<float>(input_mem->count(), -1.f, 1.f); + set_values(input_mem, input_data); + set_values(axis_mem, {axis}); + set_values<int64_t>(splits_length_mem, {1, 1}); + + // reshape [1, C, H, W] → [1, C, H*W] (spatial flatten, batch preserved) + const int64_t hw = static_cast<int64_t>(dim_h * dim_w); + const std::vector<int64_t> flatten_pattern = {1, -1, hw}; + const ov::PartialShape flatten_out_shape = {1, -1, hw}; + + cldnn::crop_ngraph_op_mode op_mode = cldnn::crop_ngraph_op_mode::variadic_split; + topology topology( + input_layout("input", in_layout), + data("axis", axis_mem), + data("splits_length", splits_length_mem), + // Branch 0: crop → [1, C, H, W] → reshape → [1, C, H*W] + crop("crop_0", {input_info("input"), input_info("axis"), input_info("splits_length")}, cldnn::tensor(1), cldnn::tensor(0), op_mode, 0, axis), + reshape("reshape_0", input_info("crop_0"), false, flatten_pattern, flatten_out_shape, cldnn::reshape::reshape_mode::base), + reorder("output_0", input_info("reshape_0"), format::bfyx, data_types::f32, std::vector<float>(), reorder_mean_mode::subtract, padding(), true), + // Branch 1: crop → [1, C, H, W] → reshape → [1, C, H*W] + crop("crop_1", {input_info("input"),
 input_info("axis"), input_info("splits_length")}, cldnn::tensor(1), cldnn::tensor(0), op_mode, 1, axis), + reshape("reshape_1", input_info("crop_1"), false, flatten_pattern, flatten_out_shape, cldnn::reshape::reshape_mode::base), + reorder("output_1", input_info("reshape_1"), format::bfyx, data_types::f32, std::vector<float>(), reorder_mean_mode::subtract, padding(), true) + ); + + auto config = get_test_default_config(engine); + config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); + config.set_property(ov::intel_gpu::optimize_data(true)); + network network(engine, topology, config); + network.set_input_data("input", input_mem); + + auto outputs = network.execute(); + + auto out0_mem = outputs.at("output_0").get_memory(); + cldnn::mem_lock<float> out0(out0_mem, get_test_stream()); + auto out1_mem = outputs.at("output_1").get_memory(); + cldnn::mem_lock<float> out1(out1_mem, get_test_stream()); + + ASSERT_EQ(out0.size(), slice_elems); + ASSERT_EQ(out1.size(), slice_elems); + + // Branch 0 must read the first batch slice + for (size_t i = 0; i < slice_elems; i++) + ASSERT_FLOAT_EQ(out0[i], input_data[0 * slice_elems + i]) << "Branch 0 mismatch at " << i; + // Branch 1 must read the second batch slice + for (size_t i = 0; i < slice_elems; i++) + ASSERT_FLOAT_EQ(out1[i], input_data[1 * slice_elems + i]) << "Branch 1 mismatch at " << i; +}