|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# pylint: disable=missing-docstring,unused-import,import-outside-toplevel,unused-variable |
| 3 | +import unittest |
| 4 | +from backend_test_base import Tf2OnnxBackendTestBase |
| 5 | + |
| 6 | + |
| 7 | +class TestIssue2025(Tf2OnnxBackendTestBase): |
| 8 | + def test_tanhgrad(self): |
| 9 | + |
| 10 | + import tensorflow as tf |
| 11 | + import tf2onnx |
| 12 | + from tf2onnx.handler import tf_op |
| 13 | + import numpy as np |
| 14 | + |
| 15 | + @tf_op("TanhGrad") |
| 16 | + class TanhGrad: |
| 17 | + @classmethod |
| 18 | + def version_1(cls, ctx, node, **_kwargs): |
| 19 | + tanh_output = node.input[0] |
| 20 | + grad = node.input[1] |
| 21 | + square = ctx.make_node("Mul", [tanh_output, tanh_output]) |
| 22 | + one = ctx.make_const( |
| 23 | + name=node.name + "_one", np_val=np.array(1, dtype=np.float32) |
| 24 | + ) |
| 25 | + derivative = ctx.make_node("Sub", [one.output[0], square.output[0]]) |
| 26 | + result = ctx.make_node("Mul", [derivative.output[0], grad]) |
| 27 | + ctx.replace_all_inputs(node.output[0], result.output[0]) |
| 28 | + return result.output |
| 29 | + |
| 30 | + class QFGrad(tf.keras.Model): |
| 31 | + def __init__(self): |
| 32 | + super().__init__() |
| 33 | + self.output_names = ["grad"] |
| 34 | + |
| 35 | + def calc_q_grad(self, x): |
| 36 | + with tf.GradientTape() as tape: |
| 37 | + tape.watch(x) |
| 38 | + y = tf.keras.activations.tanh(tf.abs(x)) |
| 39 | + # tf.raw_ops.TanhGrad |
| 40 | + x_grad = tape.gradient(y, x) |
| 41 | + return x_grad |
| 42 | + |
| 43 | + def call(self, x): |
| 44 | + q_grad = self.calc_q_grad(x) |
| 45 | + return q_grad |
| 46 | + |
| 47 | + model = QFGrad() |
| 48 | + x = tf.random.uniform((1, 1, 6)) |
| 49 | + model(x) |
| 50 | + |
| 51 | + save_path = "test_tanhgrad.onnx" |
| 52 | + model_proto, _ = tf2onnx.convert.from_keras( |
| 53 | + model=model, |
| 54 | + input_signature=(tf.TensorSpec((1, 1, 6), dtype=tf.float32, name="x"),), |
| 55 | + opset=13, |
| 56 | + output_path=save_path, |
| 57 | + ) |
| 58 | + node_types = [n.op_type for n in model_proto.graph.node] |
| 59 | + self.assertNotIn("TanhGrad", node_types) |
| 60 | + |
| 61 | + |
| 62 | + |
| 63 | +if __name__ == "__main__": |
| 64 | + unittest.main() |
0 commit comments