
Extend support for JIT Backward Convolution Operators to Arm SVE 128-bit #2165

Open
@snadampal

Description


Summary

On the aarch64 platform, backward convolution operators are supported via JIT-generated SVE kernels. Today that support exists only for the 512-bit and 256-bit SVE vector widths, not for 128-bit SVE processors such as AWS Graviton4. This issue requests extending the existing SVE JIT kernels to support the 128-bit width.
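To make the width gap concrete: each SVE vector register holds `vector_length_bits / 32` single-precision floats, so a 128-bit machine has 4 f32 lanes versus 8 (256-bit) and 16 (512-bit). The sketch below encodes the coverage described above using a hypothetical `SveWidth` enum and helper functions; these names are illustrative and are not oneDNN's `cpu_isa_t` API.

```cpp
#include <cassert>

// Hypothetical ISA tags for illustration only; oneDNN's own enum may differ.
enum class SveWidth { none, sve_128, sve_256, sve_512 };

// Map a hardware SVE vector length (bits) to a tag. SVE implementations pick
// a vector length between 128 and 2048 bits; only the three widths discussed
// in this issue are modelled here.
constexpr SveWidth sve_width_from_bits(int vl_bits) {
    switch (vl_bits) {
        case 128: return SveWidth::sve_128;
        case 256: return SveWidth::sve_256;
        case 512: return SveWidth::sve_512;
        default:  return SveWidth::none;
    }
}

// Number of 32-bit float lanes in one SVE vector register.
constexpr int f32_lanes(int vl_bits) { return vl_bits / 32; }

// Whether a backward-convolution JIT kernel exists today for this width,
// per the issue text: only 512-bit and 256-bit are covered.
constexpr bool has_bwd_conv_jit_today(SveWidth w) {
    return w == SveWidth::sve_512 || w == SveWidth::sve_256;
}
```

For example, `has_bwd_conv_jit_today(sve_width_from_bits(128))` is `false`, which is exactly the Graviton4 case this issue targets.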

Problem statement

ResNet-50 training requires backward convolution operators, and on SVE 128-bit machines these currently execute with the reference C kernels. Extending the oneDNN operators listed below to SVE 128-bit would accelerate them with SIMD and improve their performance by several orders of magnitude.
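For context on what the two backward passes compute, here is a minimal sketch of naive backward-data and backward-weights loops for a single-channel 1-D convolution (stride 1, no padding). It is not oneDNN's reference implementation; the function names `conv1d_bwd_data` and `conv1d_bwd_weights` are hypothetical. The scalar multiply-accumulate inner loops are the kind of work a SIMD kernel vectorizes.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Forward pass assumed here (cross-correlation, stride 1, no padding):
//   y[o] = sum_k w[k] * x[o + k],  output length = N - K + 1

// Backward data: gradient w.r.t. the input x.
//   dx[i] = sum over (o, k) with o + k == i of dy[o] * w[k]
std::vector<float> conv1d_bwd_data(const std::vector<float>& dy,
                                   const std::vector<float>& w,
                                   std::size_t n_in) {
    std::vector<float> dx(n_in, 0.0f);
    for (std::size_t o = 0; o < dy.size(); ++o)
        for (std::size_t k = 0; k < w.size(); ++k)
            dx[o + k] += dy[o] * w[k];  // scatter each output gradient back
    return dx;
}

// Backward weights: gradient w.r.t. the filter w.
//   dw[k] = sum_o dy[o] * x[o + k]
std::vector<float> conv1d_bwd_weights(const std::vector<float>& dy,
                                      const std::vector<float>& x,
                                      std::size_t k_size) {
    std::vector<float> dw(k_size, 0.0f);
    for (std::size_t k = 0; k < k_size; ++k)
        for (std::size_t o = 0; o < dy.size(); ++o)
            dw[k] += dy[o] * x[o + k];  // correlate input with output gradient
    return dw;
}
```

With `x = {1, 2, 3, 4}`, `w = {1, -1}` and `dy = {1, 1, 1}`, backward data yields `{1, 0, 0, -1}` and backward weights yields `{6, 9}`.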
Here are the details of the existing oneDNN JIT kernels for the backward-pass operators:


https://github.com/oneapi-src/oneDNN/blob/main/src/cpu/cpu_convolution_list.cpp#L262
 
{{backward_data, f32, f32, f32}, REG_BWD_D_PK({
…..
            CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_bwd_data_t<sve_512,data_type::f32>)
            CPU_INSTANCE_AARCH64(jit_sve_1x1_convolution_bwd_data_t<f32,f32,f32,sve_512>)
            CPU_INSTANCE_AARCH64(jit_sve_convolution_bwd_data_t<f32,f32,f32,sve_512>)
            CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_bwd_data_t<sve_256,data_type::f32>)
            CPU_INSTANCE_AARCH64(jit_sve_1x1_convolution_bwd_data_t<f32,f32,f32,sve_256>)
            CPU_INSTANCE_AARCH64(jit_sve_convolution_bwd_data_t<f32,f32,f32,sve_256>)
….
}
 
{{backward_weights, f32, f32, f32}, REG_BWD_PK({
……
            CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_bwd_weights_t<sve_512,data_type::f32>)
            CPU_INSTANCE_AARCH64(jit_sve_1x1_convolution_bwd_weights_t<f32,f32,f32,sve_512>)
            CPU_INSTANCE_AARCH64(jit_sve_convolution_bwd_weights_t<f32,f32,f32,sve_512>)
            CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_bwd_weights_t<sve_256,data_type::f32>)
            CPU_INSTANCE_AARCH64(jit_sve_1x1_convolution_bwd_weights_t<f32,f32,f32,sve_256>)
            CPU_INSTANCE_AARCH64(jit_sve_convolution_bwd_weights_t<f32,f32,f32,sve_256>)
….
}
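As a simplified model (not oneDNN's actual dispatch code), an ordered implementation list like the one above is walked top to bottom and the first entry that accepts the current machine is used, with a reference implementation as the last resort. The sketch assumes, for simplicity, that each JIT entry accepts only its exact vector length; the `Candidate` struct, `pick_impl`, and the `ref_convolution_bwd_data_t` label are illustrative.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// One candidate implementation: a name and a predicate saying whether it can
// run on this CPU (reduced here to the SVE vector length in bits).
struct Candidate {
    std::string name;
    std::function<bool(int)> supports;
};

// Return the first candidate, in list order, that accepts the machine.
std::string pick_impl(const std::vector<Candidate>& list, int sve_bits) {
    for (const auto& c : list)
        if (c.supports(sve_bits)) return c.name;
    return "none";
}

// Simplified backward_data list: 512-bit and 256-bit JIT entries (assumed to
// require an exact vector-length match) followed by a reference fallback.
std::vector<Candidate> bwd_data_list() {
    auto exactly = [](int bits) { return [bits](int vl) { return vl == bits; }; };
    return {
        {"jit_sve_convolution_bwd_data_t<sve_512>", exactly(512)},
        {"jit_sve_convolution_bwd_data_t<sve_256>", exactly(256)},
        {"ref_convolution_bwd_data_t", [](int) { return true; }},
    };
}
```

In this model a 128-bit machine falls through both JIT entries and lands on the reference fallback; adding a 128-bit entry ahead of the fallback is what the request amounts to.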

The kernel sources are here:

https://github.com/oneapi-src/oneDNN/tree/main/src/cpu/aarch64
 
jit_sve_conv_kernel.cpp/hpp
jit_sve_convolution.cpp/hpp
a few other files in the same folder

