
Commit 23efe1b

address spec in lpnorm

committed · 1 parent 78925bf · commit 23efe1b

File tree: 3 files changed, +8 -2 lines changed


onnx/backend/test/case/node/lpnormalization.py

Lines changed: 2 additions & 1 deletion
@@ -21,7 +21,8 @@ def export_l2normalization_axis_0() -> None:
             dtype=np.float32,
         )
         l2_norm_axis_0 = np.sqrt(np.sum(x**2, axis=0, keepdims=True))
-        y = x / l2_norm_axis_0
+        # When norm is 0, output is 0 (0/0 = 0)
+        y = np.where(l2_norm_axis_0 == 0, 0, x / l2_norm_axis_0)
         expect(node, inputs=[x], outputs=[y], name="test_l2normalization_axis_0")
 
     @staticmethod
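For context, here is a minimal standalone sketch of the behavior the updated test expects (the input values below are illustrative, not the actual test data): dividing by a zero L2 norm produces NaN, while the np.where guard maps the all-zero column to zeros.

import numpy as np

# Illustrative input whose second column is all zeros.
x = np.array([[3.0, 0.0],
              [4.0, 0.0]], dtype=np.float32)

l2_norm_axis_0 = np.sqrt(np.sum(x**2, axis=0, keepdims=True))     # [[5., 0.]]

y_plain = x / l2_norm_axis_0                                      # [[0.6, nan], [0.8, nan]]
y_guarded = np.where(l2_norm_axis_0 == 0, 0, x / l2_norm_axis_0)  # [[0.6, 0.],  [0.8, 0.]]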

onnx/defs/doc_strings.cc

Lines changed: 3 additions & 0 deletions
@@ -609,6 +609,9 @@ is applied to the tensor elementwise.
 
 const char kDoc_LpNormalization_ver1[] = R"DOC(
 Given a matrix, apply Lp-normalization along the provided axis.
+The output is computed as: `output = input / Lp_norm(input, axis)`.
+When the Lp norm is zero (i.e., all elements along the axis are zero),
+the output is defined to be zero to avoid division by zero.
 )DOC";
 
 const char kDoc_Erf_ver9[] = R"DOC(
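To make the documented formula concrete, here is a small worked example (values chosen purely for illustration; the operator's p attribute selects the norm order, with p = 1 and p = 2 being the L1 and L2 cases):

- p = 1, input row [1, 2, 5]: Lp_norm = 1 + 2 + 5 = 8, so output = [0.125, 0.25, 0.625].
- p = 2, input row [3, 4, 0]: Lp_norm = sqrt(9 + 16 + 0) = 5, so output = [0.6, 0.8, 0].
- Any all-zero row: Lp_norm = 0, so by the convention above the output row is all zeros rather than NaN.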

onnx/reference/ops/op_lp_normalization.py

Lines changed: 3 additions & 1 deletion
@@ -14,4 +14,6 @@ def _run(self, x, axis=None, p=None):
         p = p or self.p
         norm = np.power(np.power(x, p).sum(axis=axis), 1.0 / p)
         norm = np.expand_dims(norm, axis)
-        return ((x / norm).astype(x.dtype),)
+        # When norm is 0, return 0 instead of NaN (0/0 = 0)
+        result = np.where(norm == 0, 0, x / norm)
+        return (result.astype(x.dtype),)
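A quick way to exercise the patched reference implementation end to end is through onnx.reference.ReferenceEvaluator. The sketch below is illustrative only (graph name, shapes, and input values are made up), assuming an onnx build that includes this change.

import numpy as np
from onnx import TensorProto, helper
from onnx.reference import ReferenceEvaluator

# Single-node graph: L2 normalization along axis 0.
node = helper.make_node("LpNormalization", ["X"], ["Y"], axis=0, p=2)
graph = helper.make_graph(
    [node],
    "lpnorm_zero_check",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 2])],
)
model = helper.make_model(graph)

# The second column is all zeros; with this change it stays zero instead of NaN.
x = np.array([[1.0, 0.0], [2.0, 0.0], [2.0, 0.0]], dtype=np.float32)
(y,) = ReferenceEvaluator(model).run(None, {"X": x})
print(y)  # approximately [[0.333, 0.], [0.667, 0.], [0.667, 0.]]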
