Skip to content

Commit 30ad350

Browse files
authored
Add LpNormalization-22 and update the implementation to respect ONNX spec (#27164)
I missed the operator since it didn't have the corresponding tests at the time. With onnx/onnx#7618, the disabled test should be able to pass. --- This pull request updates the ONNX Runtime CPU execution provider to add support for the `LpNormalization` operator for opset version 22, in addition to clarifying and correcting the registration for earlier versions. It also updates the backend test filters to reflect this new support. **ONNX Operator Kernel Registration:** * Added new kernel registrations for `LpNormalization` with opset version 22 for both `float` and `double` data types in `cpu_execution_provider.cc`. [[1]](diffhunk://#diff-054ffdd679ada14ebb4b1db27a60b2881e2db48f9dc3f0b948c784cdcdaf4908R1328-R1329) [[2]](diffhunk://#diff-054ffdd679ada14ebb4b1db27a60b2881e2db48f9dc3f0b948c784cdcdaf4908R3389-R3392) * Updated the registration for `LpNormalization` for opset versions 1 through 21 to use the correct versioned kernel macro, ensuring correct kernel selection and compatibility. [[1]](diffhunk://#diff-054ffdd679ada14ebb4b1db27a60b2881e2db48f9dc3f0b948c784cdcdaf4908L197-R198) [[2]](diffhunk://#diff-054ffdd679ada14ebb4b1db27a60b2881e2db48f9dc3f0b948c784cdcdaf4908L1731-R1735) **Test Filters Update:** * Updated `onnx_backend_test_series_filters.jsonc` to remove the exclusion of `test_l1normalization*`, `test_lpnormalization*`, and `test_l2normalization*` now that `LpNormalization` opset 22 is implemented, and added a TODO comment referencing ONNX 1.21 for a known zero-norm issue. [[1]](diffhunk://#diff-abc0f78c2314f9e7648c8081125d0ce9f33b12399520d92d811d73e3c795ed59R32-R33) [[2]](diffhunk://#diff-abc0f78c2314f9e7648c8081125d0ce9f33b12399520d92d811d73e3c795ed59L42) [[3]](diffhunk://#diff-abc0f78c2314f9e7648c8081125d0ce9f33b12399520d92d811d73e3c795ed59L70-L71)
1 parent f83d4d0 commit 30ad350

File tree

4 files changed

+26
-12
lines changed

4 files changed

+26
-12
lines changed

docs/OperatorKernels.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,8 @@ Do not modify directly.*
240240
|||[13, 15]|**B** = tensor(bool)<br/> **I** = tensor(int64)<br/> **V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
241241
|||[11, 12]|**B** = tensor(bool)<br/> **I** = tensor(int64)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
242242
|||[1, 10]|**B** = tensor(bool)<br/> **I** = tensor(int64)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
243-
|LpNormalization|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(double), tensor(float)|
243+
|LpNormalization|*in* input:**T**<br> *out* output:**T**|22+|**T** = tensor(double), tensor(float)|
244+
|||[1, 21]|**T** = tensor(double), tensor(float)|
244245
|LpPool|*in* X:**T**<br> *out* Y:**T**|22+|**T** = tensor(float)|
245246
|||[18, 21]|**T** = tensor(float)|
246247
|||[11, 17]|**T** = tensor(float)|

onnxruntime/core/providers/cpu/cpu_execution_provider.cc

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma
194194
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose);
195195
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, Flatten);
196196
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 21, InstanceNormalization);
197-
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, LpNormalization);
198-
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, LpNormalization);
197+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 21, float, LpNormalization);
198+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 21, double, LpNormalization);
199199
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 12, LRN);
200200
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, AveragePool);
201201
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 7, MaxPool);
@@ -1325,6 +1325,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
13251325
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Softsign);
13261326
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, ThresholdedRelu);
13271327
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, AveragePool);
1328+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, float, LpNormalization);
1329+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, double, LpNormalization);
13281330

13291331
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
13301332
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, MLFloat16, Conv);
@@ -1728,10 +1730,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
17281730
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, Flatten)>,
17291731
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 21,
17301732
InstanceNormalization)>,
1731-
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float,
1732-
LpNormalization)>,
1733-
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double,
1734-
LpNormalization)>,
1733+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 21, float,
1734+
LpNormalization)>,
1735+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 21, double,
1736+
LpNormalization)>,
17351737
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 12, LRN)>,
17361738
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9,
17371739
AveragePool)>,
@@ -3384,6 +3386,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
33843386
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Elu)>,
33853387
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, HardSigmoid)>,
33863388
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, InstanceNormalization)>,
3389+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, float,
3390+
LpNormalization)>,
3391+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, double,
3392+
LpNormalization)>,
33873393
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, LpPool)>,
33883394
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, MaxPool)>,
33893395
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, MaxUnpool)>,

onnxruntime/core/providers/cpu/nn/lp_norm.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,22 @@
77
#include "core/providers/common.h"
88

99
namespace onnxruntime {
10+
#define REGISTER_LPNORMALISATION_VERSIONED_KERNEL(type, sinceVersion, endVersion) \
11+
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \
12+
LpNormalization, sinceVersion, endVersion, type, \
13+
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<type>()), \
14+
LpNorm<type>);
15+
1016
#define REGISTER_LPNORMALISATION_KERNEL(type, sinceVersion) \
1117
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
1218
LpNormalization, sinceVersion, type, \
1319
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<type>()), \
1420
LpNorm<type>);
1521

16-
REGISTER_LPNORMALISATION_KERNEL(float, 1)
17-
REGISTER_LPNORMALISATION_KERNEL(double, 1)
22+
REGISTER_LPNORMALISATION_VERSIONED_KERNEL(float, 1, 21)
23+
REGISTER_LPNORMALISATION_VERSIONED_KERNEL(double, 1, 21)
24+
REGISTER_LPNORMALISATION_KERNEL(float, 22)
25+
REGISTER_LPNORMALISATION_KERNEL(double, 22)
1826

1927
using InnerStride = Eigen::InnerStride<Eigen::Dynamic>;
2028

onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929

3030
// Tests that are failing temporarily and should be fixed
3131
"current_failing_tests": [
32+
// TODO(titaiwang): onnx 1.21 should fix lpnorm zero norm issue
33+
"^test_l2normalization*", // LpNormalization(22) not implemented
3234
"^test_adagrad",
3335
"^test_adagrad_multiple",
3436
"^test_attention_4d_fp16*", // precision issue: 1 / 192 mismatched elements
@@ -39,7 +41,6 @@
3941
"^test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal*", // location of infinities
4042
"^test_attention_4d_attn_mask_3d_causal_expanded*", // webgpu
4143
"^test_attention_4d_diff_heads_mask4d_padded_kv*", // Need nonpad_kv_seqlen
42-
"^test_l2normalization*", // LpNormalization(22) not implemented
4344
// TODO: support the following tests in Attention-cuda
4445
"^test_attention_3d_gqa.*_cuda", // GQA not supported in Attention-cuda
4546
"^test_attention_4d_gqa.*_cuda", // GQA not supported in Attention-cuda
@@ -67,8 +68,6 @@
6768
"^test_attention_4d_attn_mask_4d_causal_cuda",
6869
"^test_attention_4d_causal_cuda",
6970
"^test_attention_4d_diff_heads_sizes_causal_cuda",
70-
"^test_l1normalization*", // LpNormalization(22) not implemented
71-
"^test_lpnormalization*", // LpNormalization(22) not implemented
7271
"^test_tensorscatter*", // TensorScatter(24) not implemented
7372
"^test_castlike_no_saturate_FLOAT_to_FLOAT8*", // ORT does not support ml_dtypes
7473
"^test_castlike_UINT4_to*", // ORT does not support ml_dtypes

0 commit comments

Comments
 (0)