@@ -62,6 +62,59 @@ static void optimize_conv_permute(program_node& node) {
    }
}

static void optimize_permute_conv(program_node& node) {
    // Goal: eliminate the reorder before the convolution by aligning the connection to byxf
    if (node.get_dependencies().empty())
        return;

    auto& dep = node.get_dependency(0);

    // Dependency must be a permute node (and not a network output)
    if (!dep.is_type<permute>() || dep.is_output())
        return;
    //if (dep.get_users().size() != 2)
    //    return;
Contributor:
Please check that the node has only one user, to avoid breaking data for other nodes.

Contributor Author:
Hi @p-durandin,
in this case the dependency node, pnode (permute), has two users (permute -> conv and permute -> permute), while the node variable, which is the conv node, has one user (conv -> permute).

For which node do you want to add the check?

Contributor:
Please add the check for conv.
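A minimal sketch of the requested guard, assuming it goes near the top of optimize_permute_conv (the exact placement is a guess, not part of this diff):

    // Bail out unless the convolution node has exactly one user, so that
    // rewriting its preferred formats cannot break data consumed elsewhere.
    if (node.get_users().size() != 1)
        return;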

    if (node.get_output_layout().get_rank() != 4)
        return;

    auto& pnode = dep.as<permute>();

    if (pnode.get_output_layout().data_type != node.get_output_layout().data_type)
        return;

    // NHWC <-> NCHW (ensures the reverse-rotation pattern)
    if (!pnode.is_reverse_rotating_except_batch())
        return;
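    // (Illustration, assumption only: for rank 4 this is expected to match an order
    // such as {0, 2, 3, 1}, the inverse of the {0, 3, 1, 2} rotation, with the batch
    // axis left in place.)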

    auto pnode_upstream_fmt = pnode.get_dependency(0).get_preferred_output_fmt();

    bool is_compatible_format = (pnode_upstream_fmt == format::bfyx
                                 || pnode_upstream_fmt == format::any);

    if (!is_compatible_format)
        return;

    // Set the layouts so that the memory buffer is re-interpreted rather than physically shuffled.
    node.set_preferred_input_fmt(0, format::byxf);

    // Set the permute input to match the upstream format.
    pnode.set_preferred_input_fmt(0, format::bfyx);
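    // Resulting preference (sketch of the intended assignment):
    //   upstream --(bfyx)--> permute --(byxf)--> conv
    // With the formats fixed on both sides of the permute, the buffer can be
    // re-interpreted in place instead of being copied.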

    // This aligns with the pre-transpose memory layout, allowing the permute to become a
    // zero-copy optimization.
    // An alternative approach is to force the permute node's output fmt to match the
    // convolution's input fmt, which also eliminates the reorder before the convolution:
    //     pnode.set_preferred_output_fmt(0, node.get_preferred_input_fmt())
    // However, for a non-planar blocked format (b_fs_yx_fsv16) the kernel has to calculate
    // complex offsets to pack 16 channels together into a block, which may degrade
    // performance. The penalty may not be visible for small inputs like 56x56, since a
    // small working set fits largely in the GPU L2 cache and the cost of the complex
    // address calculation can be masked by the high L2 bandwidth.
    pnode.set_preferred_output_fmt(0, format::byxf);

    if (!pnode.has_fused_primitives()) {
        pnode.can_be_optimized(true);
    }
}
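// For reference, a minimal sketch (hypothetical helpers, not part of the change above) of
// the addressing cost described in the comment: byxf needs one plain row-major offset
// computation per element, while b_fs_yx_fsv16 adds a divide/modulo pair to locate the
// 16-channel block and the position inside it.
static size_t offset_byxf(size_t b, size_t f, size_t y, size_t x,
                          size_t F, size_t Y, size_t X) {
    // Physical order: b, y, x, f (f is innermost).
    return ((b * Y + y) * X + x) * F + f;
}

static size_t offset_b_fs_yx_fsv16(size_t b, size_t f, size_t y, size_t x,
                                   size_t F, size_t Y, size_t X) {
    // Physical order: b, f/16, y, x, f%16 (16 channels packed per block).
    const size_t fsv = 16;
    const size_t fs = f / fsv;                    // which 16-channel block
    const size_t fv = f % fsv;                    // channel index inside the block
    const size_t f_blocks = (F + fsv - 1) / fsv;  // block count (F padded up to 16)
    return (((b * f_blocks + fs) * Y + y) * X + x) * fsv + fv;
}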

} // namespace

void select_preferred_formats::run(program& p) {
@@ -119,6 +172,7 @@ void select_preferred_formats::run(program& p) {
        }
        if (factory->get_impl_type() == impl_types::onednn && (n->is_type<convolution>() || n->is_type<deconvolution>())) {
            optimize_conv_permute(*n);
            optimize_permute_conv(*n);
        }
    } catch (std::exception& exception) {
        GPU_DEBUG_LOG << "WARNING(select_preferred_formats): " << exception.what() << std::endl;