
Commit c3641b6

[hannk] augment L2NormOp to allow specifying axis (#6335)
1 parent: d80bb23

5 files changed: +31 −6 lines

apps/hannk/delegate/hannk_delegate.cpp (+2 −1)

@@ -755,7 +755,8 @@ class HannkDelegateKernel final {
     OpPtr BuildL2Normalization(TfLiteContext *context, TfLiteNode *node) {
         auto input = GetTensorById(context, node->inputs->data[0]);
         auto output = GetTensorById(context, node->outputs->data[0]);
-        return make_op<L2NormalizationOp>(input, output);
+        const int axis = 0;  // In TFLite, normalization is always against the first axis.
+        return make_op<L2NormalizationOp>(input, output, axis);
     }

     OpPtr BuildUnary(TfLiteContext *context, TfLiteNode *node, UnaryOp::Operator type) {
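
For reference (not part of this commit): the op being built here performs L2 normalization along a single axis, i.e. each 1-D slice taken along that axis is divided by its own L2 norm. A minimal float sketch of the axis-0 case, assuming a dense 2-D layout with dimension 0 innermost (hannk's actual kernel is quantized uint8 and generated by Halide; the function name and epsilon guard below are illustrative):

#include <cmath>
#include <vector>

// Illustrative only: divide each slice x(:, j) of a dense extent0 x extent1
// array by its own L2 norm.
void l2_normalize_axis0(std::vector<float> &x, int extent0, int extent1) {
    for (int j = 0; j < extent1; j++) {
        float sum_sq = 0.0f;
        for (int i = 0; i < extent0; i++) {
            const float v = x[i + j * extent0];
            sum_sq += v * v;
        }
        const float inv_norm = 1.0f / std::sqrt(sum_sq + 1e-6f);  // epsilon avoids divide-by-zero
        for (int i = 0; i < extent0; i++) {
            x[i + j * extent0] *= inv_norm;
        }
    }
}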

apps/hannk/halide/normalizations_generator.cpp (+8 −0)

@@ -48,6 +48,14 @@ class L2Normalization : public Generator<L2Normalization> {
             .update()
             .atomic()
             .vectorize(rx, vector_size);
+
+        // Normally we'd expect both buffers to be planar, but in unusual
+        // cases, Hannk can transpose the buffers (to normalize along another
+        // dimension), so for those cases, we'll just fall back to less-efficient
+        // code.
+        input_.dim(0).set_stride(Expr());
+        output_.dim(0).set_stride(Expr());
+        output_.specialize(input_.dim(0).stride() == 1 && output_.dim(0).stride() == 1);
     }
 };
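
The schedule change above uses a general Halide pattern: dimension 0 of a buffer argument is constrained to stride 1 by default, set_stride(Expr()) drops that constraint so a transposed (non-planar) view can be passed in, and specialize() keeps a fast code path for the common dense case. A self-contained sketch of the pattern under those assumptions (a toy generator, not hannk's L2Normalization; the generator name and the vectorize factor are illustrative):

#include "Halide.h"

using namespace Halide;

class Halve : public Generator<Halve> {
public:
    Input<Buffer<uint8_t>> input_{"input", 2};
    Output<Buffer<uint8_t>> output_{"output", 2};

    void generate() {
        Var x("x"), y("y");
        output_(x, y) = input_(x, y) / 2;

        // Allow arbitrary strides along dimension 0 of both buffers...
        input_.dim(0).set_stride(Expr());
        output_.dim(0).set_stride(Expr());

        // ...but emit a separate, vectorized path when both are dense.
        output_.specialize(input_.dim(0).stride() == 1 &&
                           output_.dim(0).stride() == 1)
            .vectorize(x, natural_vector_size<uint8_t>());
    }
};

HALIDE_REGISTER_GENERATOR(Halve, halve)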

apps/hannk/interpreter/ops.cpp (+15 −2)

@@ -1137,10 +1137,23 @@ void L2NormalizationOp::execute() {
     const TensorPtr &in = input();
     const TensorPtr &out = output();

+    // Negative values for axis_ must be normalized by the parser
+    assert(axis_ >= 0 && axis_ < in->rank());
+
     if (in->type() == halide_type_of<uint8_t>() &&
         out->type() == halide_type_of<uint8_t>()) {
-        const auto &in_buf = in->buffer();
-        const auto &out_buf = out->buffer();
+        // Make local copies in case we need to transpose them
+        HalideBuffer<void> in_buf = in->buffer();
+        HalideBuffer<void> out_buf = out->buffer();
+
+        // TODO: we currently assume that the axis-is-0 case is the most common
+        // and most important, and optimize for it; the other cases, we just transpose,
+        // which currently requires less-efficient specializations in the Halide code.
+        // Revisit if this proves too slow in practice.
+        if (axis_ != 0) {
+            in_buf.transpose(0, axis_);
+            out_buf.transpose(0, axis_);
+        }

         const int input_zero = in->quantization().uniform_zero();
         assert(input_zero >= 0 && input_zero <= 255);
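
The transpose above works because Halide::Runtime::Buffer::transpose() is zero-copy: it only swaps the two dimensions' extents and strides in the buffer's metadata, so a kernel hard-coded to reduce along dimension 0 can be pointed at any axis. A standalone sketch (not hannk code) showing the effect:

#include <cstdio>

#include "HalideBuffer.h"

int main() {
    Halide::Runtime::Buffer<float> buf(4, 3);  // dim 0: extent 4, stride 1; dim 1: extent 3, stride 4
    buf.transpose(0, 1);                       // swap dims 0 and 1; no data is moved
    printf("extent0=%d stride0=%d\n", buf.dim(0).extent(), buf.dim(0).stride());
    // Prints "extent0=3 stride0=4": dimension 0 is no longer dense, which is
    // why the generator above must tolerate stride != 1 in dimension 0.
    return 0;
}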

apps/hannk/interpreter/ops.h (+4 −2)

@@ -228,9 +228,11 @@ class GatherOp : public Op {
 };

 class L2NormalizationOp : public Op {
+    const int axis_;
+
 public:
-    L2NormalizationOp(const TensorPtr &input, const TensorPtr &output)
-        : Op({input}, {output}) {
+    L2NormalizationOp(const TensorPtr &input, const TensorPtr &output, int axis)
+        : Op({input}, {output}), axis_(axis) {
     }

     void accept(OpVisitor *v) override;
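
Note that the constructor stores the axis unchecked; the assert added in ops.cpp above requires 0 <= axis < rank, and its comment says negative values must already have been normalized by the parser. A hypothetical helper for a caller that accepted TFLite-style negative axes might look like the following (no such helper exists in this commit; hannk's call sites always pass 0):

// Hypothetical: fold a possibly-negative axis into [0, rank) before
// constructing L2NormalizationOp.
inline int fold_negative_axis(int axis, int rank) {
    return axis < 0 ? axis + rank : axis;
}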

apps/hannk/tflite/tflite_parser.cpp (+2 −1)

@@ -365,7 +365,8 @@ class Parser {
     OpPtr parse_l2_normalization(const tflite::Operator *op) {
         TensorPtr input = tensors_[op->inputs()->Get(0)];
         TensorPtr output = tensors_[op->outputs()->Get(0)];
-        return make_op<L2NormalizationOp>(input, output);
+        const int axis = 0;  // In TFLite, normalization is always against the first axis.
+        return make_op<L2NormalizationOp>(input, output, axis);
     }

     OpPtr parse_reduction(const tflite::Operator *op, ReductionOp::Operator reduction_op) {
