Skip to content

Commit ca4930e

Browse files
[hip-kernel-provider] Implement RMSNorm backward kernels and RMSNorm channel last support (#7702)
**Caution**: This PR should be merged only after [this PR](#7494) is merged.

## Motivation
This PR implements the RMSNorm backward kernels and adds RMSNorm channel-last support for both forward and backward operations in the HIP kernel provider.

## Technical Details
- Adds the RMSNorm backward kernels and makes the relevant changes in `RMSnormBwdPlan` to compile and launch the kernels.
- Adds channel-last support for both the `RMSnormFwd` and `RMSnormBwd` operations.
- Adds/updates unit tests and integration tests to cover the changes introduced in this PR.

## Test Plan
Build the plugin and run the unit and integration tests with `ninja check`.

## Test Result
All unit and integration tests pass successfully on an MI210.

## Submission Checklist
- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
1 parent c7d49aa commit ca4930e

18 files changed

Lines changed: 1231 additions & 173 deletions

dnn-providers/hip-kernel-provider/kernels/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ set(KERNEL_FILES
8282
${CMAKE_CURRENT_SOURCE_DIR}/batchnorm/Configuration.hpp
8383
${CMAKE_CURRENT_SOURCE_DIR}/batchnorm/ReductionFunctions.hpp
8484
${CMAKE_CURRENT_SOURCE_DIR}/batchnorm/StaticUnroll.hpp
85+
${CMAKE_CURRENT_SOURCE_DIR}/rmsnorm/RMSNormCommon.hpp
8586
${CMAKE_CURRENT_SOURCE_DIR}/rmsnorm/RMSNormFwd.cpp
87+
${CMAKE_CURRENT_SOURCE_DIR}/rmsnorm/RMSNormBwd.cpp
8688
${CMAKE_CURRENT_SOURCE_DIR}/hip/vector_add.cpp
8789
${CMAKE_CURRENT_SOURCE_DIR}/layernorm/LayernormFwd.cpp
8890
${CMAKE_CURRENT_SOURCE_DIR}/types/FloatTypes.h
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#include <type_traits>
5+
6+
#include "RMSNormCommon.hpp"
7+
8+
constexpr unsigned int LOCAL_SIZE = HIP_PLUGIN_RMSNORM_LOCAL_SIZE;
9+
constexpr unsigned int INNER_SIZE = HIP_PLUGIN_RMSNORM_INNER_SIZE;
10+
constexpr unsigned int OUTER_SIZE = HIP_PLUGIN_RMSNORM_OUTER_SIZE;
11+
constexpr unsigned int STRIDE = HIP_PLUGIN_RMSNORM_STRIDE;
12+
13+
using XType = HIP_PLUGIN_RMSNORM_X_TYPE;
14+
using DyType = HIP_PLUGIN_RMSNORM_DY_TYPE;
15+
using DxType = HIP_PLUGIN_RMSNORM_DX_TYPE;
16+
using ScaleType = HIP_PLUGIN_RMSNORM_SCALE_TYPE;
17+
using ComputeType = HIP_PLUGIN_RMSNORM_COMPUTE_TYPE;
18+
19+
/// Accumulates the RMSNorm weight gradient and, optionally, the bias gradient.
///
/// Each thread owns one index `tidx` along the normalized (inner) dimension and
/// serially sums over every (outer, stride) position:
///   dweight[tidx] = sum_{o,s} dy[idx] * x[idx] * rstd[o * STRIDE + s]
///   dbias[tidx]   = sum_{o,s} dy[idx]
/// with idx = o * INNER_SIZE * STRIDE + tidx * STRIDE + s.
///
/// @param dy      Incoming gradient, indexed by idx.
/// @param x       Forward input, indexed by idx.
/// @param rstd    Per-(o, s) statistic indexed by o * STRIDE + s
///                (presumably the reciprocal RMS saved by the forward pass; confirm with caller).
/// @param dweight Output weight gradient, one element per inner index.
/// @param dbias   Output bias gradient, one element per inner index; may be null, in which
///                case it is not written.
///
/// NOTE(review): the global thread index assumes blockDim.x == LOCAL_SIZE; confirm at launch site.
extern "C" __global__ void RMSnormBwdWeightBias(const DyType* __restrict__ dy,
                                                const XType* __restrict__ x,
                                                const ComputeType* __restrict__ rstd,
                                                ScaleType* __restrict__ dweight,
                                                ScaleType* __restrict__ dbias)
{
    static_assert(std::is_same<ComputeType, float>::value,
                  "ComputeType must be float for the RMSnormBwdWeightBias kernel");

    const unsigned int tidx = threadIdx.x + blockIdx.x * LOCAL_SIZE;

    // Threads past the end of the inner dimension have nothing to accumulate.
    if(tidx >= INNER_SIZE)
    {
        return;
    }

    float sum_dw = 0.0f;
    float sum_db = 0.0f;

    // backward weight calculation
    for(unsigned int o = 0; o < OUTER_SIZE; ++o)
    {
        // Widen to size_t BEFORE multiplying: the product of the 32-bit extents can exceed
        // 2^32 for large tensors and would otherwise wrap before being stored in size_t.
        const size_t elem_base = (static_cast<size_t>(o) * INNER_SIZE + tidx) * STRIDE;
        const size_t rstd_base = static_cast<size_t>(o) * STRIDE;

        for(unsigned int s = 0; s < STRIDE; ++s)
        {
            const size_t idx = elem_base + s;

            const float prstd = rstd[rstd_base + s];
            const float pdy   = hip_kernel_provider::rmsnorm::to_float32<DyType>(dy[idx]);
            const float px    = hip_kernel_provider::rmsnorm::to_float32<XType>(x[idx]);

            sum_dw += pdy * px * prstd;
            sum_db += pdy;
        }
    }

    dweight[tidx] = hip_kernel_provider::rmsnorm::from_float32<ScaleType>(sum_dw);
    // The bias gradient is optional; skip the store when no buffer was provided.
    if(dbias)
    {
        dbias[tidx] = hip_kernel_provider::rmsnorm::from_float32<ScaleType>(sum_db);
    }
}
60+
61+
/// Computes the RMSNorm input gradient dx.
///
/// Each block handles one (o, s) pair, decoded from blockIdx.x as o = gid / STRIDE and
/// s = gid % STRIDE. Its threads cooperatively reduce, over the inner dimension,
///   mean = (1 / INNER_SIZE) * sum_i dy[idx] * weight[i] * x[idx]
/// and then write
///   dx[idx] = dy[idx] * weight[i] * rstd - mean * x[idx] * rstd^3
/// with idx = o * INNER_SIZE * STRIDE + i * STRIDE + s and rstd = rstd[gid].
///
/// @param dy     Incoming gradient, indexed by idx.
/// @param x      Forward input, indexed by idx.
/// @param weight Scale vector, indexed by the inner index i.
/// @param rstd   Per-block statistic read at rstd[blockIdx.x]
///               (presumably the reciprocal RMS saved by the forward pass; confirm with caller).
/// @param dx     Output input-gradient, indexed by idx.
///
/// NOTE(review): the strided loops and the shared-memory reduction assume
/// blockDim.x == LOCAL_SIZE; confirm at launch site.
extern "C" __global__ void RMSnormBwdData(const DyType* __restrict__ dy,
                                          const XType* __restrict__ x,
                                          const ScaleType* __restrict__ weight,
                                          const ComputeType* __restrict__ rstd,
                                          DxType* __restrict__ dx)
{
    static_assert(std::is_same<ComputeType, float>::value,
                  "ComputeType must be float for the RMSnormBwdData kernel");
    // The halving reduction below drops partial sums unless LOCAL_SIZE is a power of two.
    static_assert(LOCAL_SIZE > 0 && (LOCAL_SIZE & (LOCAL_SIZE - 1)) == 0,
                  "LOCAL_SIZE must be a power of two for the RMSnormBwdData reduction");

    const unsigned int gid = blockIdx.x;
    const unsigned int lid = threadIdx.x;
    const unsigned int o   = gid / STRIDE;
    const unsigned int s   = gid % STRIDE;

    // Widen to size_t BEFORE multiplying: the product of the 32-bit extents can exceed
    // 2^32 for large tensors and would otherwise wrap before being stored in size_t.
    const size_t block_base = static_cast<size_t>(o) * INNER_SIZE * STRIDE + s;

    __shared__ float ltmp[LOCAL_SIZE];
    float mean = 0.0f;

    // reduce sum: each thread accumulates its strided share of the inner dimension
    for(unsigned int i = lid; i < INNER_SIZE; i += LOCAL_SIZE)
    {
        const size_t idx = block_base + static_cast<size_t>(i) * STRIDE;

        const float pdy = hip_kernel_provider::rmsnorm::to_float32<DyType>(dy[idx]);
        const float px  = hip_kernel_provider::rmsnorm::to_float32<XType>(x[idx]);
        const float pw  = hip_kernel_provider::rmsnorm::to_float32<ScaleType>(weight[i]);

        mean += pdy * pw * px;
    }

    ltmp[lid] = mean;
    __syncthreads();

    // Tree reduction in shared memory; the barrier sits outside the branch so that every
    // thread in the block reaches it on each step.
    for(unsigned int i = LOCAL_SIZE >> 1; i > 0; i >>= 1)
    {
        if(lid < i)
        {
            ltmp[lid] += ltmp[lid + i];
        }
        __syncthreads();
    }

    mean              = ltmp[0] / INNER_SIZE;
    const float prstd = rstd[gid];

    // backward data calculation
    for(unsigned int i = lid; i < INNER_SIZE; i += LOCAL_SIZE)
    {
        const size_t idx = block_base + static_cast<size_t>(i) * STRIDE;

        const float pdy = hip_kernel_provider::rmsnorm::to_float32<DyType>(dy[idx]);
        const float px  = hip_kernel_provider::rmsnorm::to_float32<XType>(x[idx]);
        const float pw  = hip_kernel_provider::rmsnorm::to_float32<ScaleType>(weight[i]);

        const float dx_val = (pdy * pw * prstd) - (mean * px * prstd * prstd * prstd);
        dx[idx]            = hip_kernel_provider::rmsnorm::from_float32<DxType>(dx_val);
    }
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
// Include guard: this header defines full template specializations, so including it twice
// in one translation unit would otherwise be a redefinition error.
#ifndef HIP_KERNEL_PROVIDER_RMSNORM_RMSNORMCOMMON_HPP
#define HIP_KERNEL_PROVIDER_RMSNORM_RMSNORMCOMMON_HPP

#include "Bfloat16Dev.hpp"

namespace hip_kernel_provider::rmsnorm
{

/// Conversion traits between a storage type T and float.
/// Only the specializations below are defined; any other T fails to compile.
///   - to(value):   storage type -> float
///   - from(value): float -> storage type
template <typename T>
struct Cast;

/// float storage: both directions are the identity.
template <>
struct Cast<float>
{
    static __device__ __forceinline__ float to(float value)
    {
        return value;
    }
    static __device__ __forceinline__ float from(float value)
    {
        return value;
    }
};

/// half storage: converts through __half2float / __float2half.
template <>
struct Cast<half>
{
    static __device__ __forceinline__ float to(half value)
    {
        return __half2float(value);
    }
    static __device__ __forceinline__ half from(float value)
    {
        return __float2half(value);
    }
};

/// ushort storage: values are converted with the bfloat16 helpers
/// (bfloat16_to_float / float_to_bfloat16, presumably provided by Bfloat16Dev.hpp).
template <>
struct Cast<ushort>
{
    static __device__ __forceinline__ float to(ushort value)
    {
        return bfloat16_to_float(value);
    }
    static __device__ __forceinline__ ushort from(float value)
    {
        return float_to_bfloat16(value);
    }
};

/// Converts a stored value of type T to float via Cast<T>::to.
template <typename T>
__device__ __forceinline__ float to_float32(T value)
{
    return Cast<T>::to(value);
}

/// Converts a float to the storage type T via Cast<T>::from.
/// T cannot be deduced and must be given explicitly, e.g. from_float32<half>(v).
template <typename T>
__device__ __forceinline__ T from_float32(float value)
{
    return Cast<T>::from(value);
}

} // namespace hip_kernel_provider::rmsnorm

#endif // HIP_KERNEL_PROVIDER_RMSNORM_RMSNORMCOMMON_HPP

dnn-providers/hip-kernel-provider/kernels/rmsnorm/RMSNormFwd.cpp

Lines changed: 11 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -3,70 +3,17 @@
33

44
#include <type_traits>
55

6-
#include "Bfloat16Dev.hpp"
6+
#include "RMSNormCommon.hpp"
77

88
constexpr unsigned int LOCAL_SIZE = HIP_PLUGIN_RMSNORM_LOCAL_SIZE;
99
constexpr unsigned int INNER_SIZE = HIP_PLUGIN_RMSNORM_INNER_SIZE;
10+
constexpr unsigned int STRIDE = HIP_PLUGIN_RMSNORM_STRIDE;
1011

1112
using InputType = HIP_PLUGIN_RMSNORM_INPUT_TYPE;
1213
using OutputType = HIP_PLUGIN_RMSNORM_OUTPUT_TYPE;
1314
using ScaleType = HIP_PLUGIN_RMSNORM_SCALE_TYPE;
1415
using ComputeType = HIP_PLUGIN_RMSNORM_COMPUTE_TYPE;
1516

16-
template <typename T>
17-
struct Cast;
18-
19-
template <>
20-
struct Cast<float>
21-
{
22-
static __device__ __forceinline__ float to(float value)
23-
{
24-
return value;
25-
}
26-
static __device__ __forceinline__ float from(float value)
27-
{
28-
return value;
29-
}
30-
};
31-
32-
template <>
33-
struct Cast<half>
34-
{
35-
static __device__ __forceinline__ float to(half value)
36-
{
37-
return __half2float(value);
38-
}
39-
static __device__ __forceinline__ half from(float value)
40-
{
41-
return __float2half(value);
42-
}
43-
};
44-
45-
template <>
46-
struct Cast<ushort>
47-
{
48-
static __device__ __forceinline__ float to(ushort value)
49-
{
50-
return bfloat16_to_float(value);
51-
}
52-
static __device__ __forceinline__ ushort from(float value)
53-
{
54-
return float_to_bfloat16(value);
55-
}
56-
};
57-
58-
template <typename T>
59-
__device__ __forceinline__ float to_float32(T value)
60-
{
61-
return Cast<T>::to(value);
62-
}
63-
64-
template <typename T>
65-
__device__ __forceinline__ T from_float32(float value)
66-
{
67-
return Cast<T>::from(value);
68-
}
69-
7017
extern "C" __global__ void RMSnormFwd(const InputType* __restrict__ x,
7118
const ScaleType* __restrict__ weight,
7219
const ScaleType* __restrict__ bias,
@@ -80,15 +27,17 @@ extern "C" __global__ void RMSnormFwd(const InputType* __restrict__ x,
8027

8128
const unsigned int gid = blockIdx.x;
8229
const unsigned int lid = threadIdx.x;
30+
const unsigned int o = gid / STRIDE;
31+
const unsigned int s = gid % STRIDE;
8332

8433
float pvar = 0.0f;
8534
__shared__ float ltmp[LOCAL_SIZE];
8635

8736
// reduce sum
8837
for(unsigned int i = lid; i < INNER_SIZE; i += LOCAL_SIZE)
8938
{
90-
size_t idx = gid * INNER_SIZE + i;
91-
float tmp = to_float32<InputType>(x[idx]);
39+
size_t idx = o * INNER_SIZE * STRIDE + i * STRIDE + s;
40+
float tmp = hip_kernel_provider::rmsnorm::to_float32<InputType>(x[idx]);
9241
pvar += tmp * tmp;
9342
}
9443

@@ -114,12 +63,13 @@ extern "C" __global__ void RMSnormFwd(const InputType* __restrict__ x,
11463
// forward calculation
11564
for(unsigned int i = lid; i < INNER_SIZE; i += LOCAL_SIZE)
11665
{
117-
size_t idx = gid * INNER_SIZE + i;
118-
float y_val = to_float32<InputType>(x[idx]) * prstd * to_float32<ScaleType>(weight[i]);
66+
size_t idx = o * INNER_SIZE * STRIDE + i * STRIDE + s;
67+
float y_val = hip_kernel_provider::rmsnorm::to_float32<InputType>(x[idx]) * prstd
68+
* hip_kernel_provider::rmsnorm::to_float32<ScaleType>(weight[i]);
11969
if(bias != nullptr)
12070
{
121-
y_val += to_float32<ScaleType>(bias[i]);
71+
y_val += hip_kernel_provider::rmsnorm::to_float32<ScaleType>(bias[i]);
12272
}
123-
y[idx] = from_float32<OutputType>(y_val);
73+
y[idx] = hip_kernel_provider::rmsnorm::from_float32<OutputType>(y_val);
12474
}
12575
}

dnn-providers/hip-kernel-provider/src/engines/plans/RMSnorm/RMSnormApplicabilityChecks.cpp

Lines changed: 11 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -14,35 +14,6 @@
1414

1515
namespace hip_kernel_provider::rmsnorm
1616
{
17-
// --- Validation Utilities ---
18-
19-
void RMSnormValidator::validateSupportedLayout(const std::vector<int64_t>& strideOrder,
20-
size_t numDims)
21-
{
22-
if(numDims == 4)
23-
{
24-
const auto layoutNchw = hipdnn_data_sdk::utilities::TensorLayout::NCHW;
25-
26-
if(strideOrder != layoutNchw.strideOrder)
27-
{
28-
throw hipdnn_plugin_sdk::HipdnnPluginException(
29-
HIPDNN_PLUGIN_STATUS_BAD_PARAM,
30-
"RMSnorm implementation supports only NCHW layouts for 4D tensors.");
31-
}
32-
}
33-
else
34-
{
35-
const auto layoutNcdhw = hipdnn_data_sdk::utilities::TensorLayout::NCDHW;
36-
37-
if(strideOrder != layoutNcdhw.strideOrder)
38-
{
39-
throw hipdnn_plugin_sdk::HipdnnPluginException(
40-
HIPDNN_PLUGIN_STATUS_BAD_PARAM,
41-
"RMSnorm implementation supports only NCDHW layouts for 5D tensors.");
42-
}
43-
}
44-
}
45-
4617
// --- Component Validators ---
4718

4819
void RMSnormValidator::checkTensorLayoutsAndDimsSupported()
@@ -83,23 +54,24 @@ void RMSnormValidator::checkTensorDataTypesSupported(const std::vector<int64_t>&
8354
"BFLOAT16 data types for x and y tensors.");
8455
}
8556

86-
// Only fp32 compute type is supported for now
87-
const std::unordered_set<hipdnn_flatbuffers_sdk::data_objects::DataType> allowedComputeTypes{
88-
hipdnn_flatbuffers_sdk::data_objects::DataType::FLOAT
57+
const std::unordered_set<hipdnn_flatbuffers_sdk::data_objects::DataType> allowedAffineTypes{
58+
hipdnn_flatbuffers_sdk::data_objects::DataType::FLOAT,
59+
hipdnn_flatbuffers_sdk::data_objects::DataType::BFLOAT16,
60+
hipdnn_flatbuffers_sdk::data_objects::DataType::HALF};
8961

90-
};
9162
validateConsistentDataTypes(affineTensorIds,
92-
allowedComputeTypes,
63+
allowedAffineTypes,
9364
"RMSnorm affine tensors use unsupported data type.",
9465
"All affine tensors for RMSnorm must have the same data type.");
9566

96-
const std::unordered_set<hipdnn_flatbuffers_sdk::data_objects::DataType> allowedStatTypes{
97-
hipdnn_flatbuffers_sdk::data_objects::DataType::FLOAT,
98-
hipdnn_flatbuffers_sdk::data_objects::DataType::BFLOAT16,
99-
hipdnn_flatbuffers_sdk::data_objects::DataType::HALF};
67+
// Only fp32 compute type is supported for now
68+
const std::unordered_set<hipdnn_flatbuffers_sdk::data_objects::DataType> allowedComputeTypes{
69+
hipdnn_flatbuffers_sdk::data_objects::DataType::FLOAT
70+
71+
};
10072

10173
validateConsistentDataTypes(statTensorIds,
102-
allowedStatTypes,
74+
allowedComputeTypes,
10375
"RMSnorm stat tensors use unsupported data type.",
10476
"All stat tensors for RMSnorm must have the same data type.");
10577
}

dnn-providers/hip-kernel-provider/src/engines/plans/RMSnorm/RMSnormApplicabilityChecks.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@ namespace hip_kernel_provider::rmsnorm
1313
class RMSnormValidator : public IValidator
1414
{
1515
private:
16-
void validateSupportedLayout(const std::vector<int64_t>& strideOrder, size_t numDims) override;
17-
1816
void checkTensorLayoutsAndDimsSupported() override;
1917

2018
void checkTensorDataTypesSupported(const std::vector<int64_t>& ioTensorIds,

0 commit comments

Comments
 (0)