Skip to content

Commit ae180e7

Browse files
committed
Adding support for Acosh
1 parent 4abb65e commit ae180e7

File tree

4 files changed

+15
-0
lines changed

4 files changed

+15
-0
lines changed

src/load.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ function load_node!(tape::Tape, ::OpConfig{:ONNX, :Acos}, args::VarVec, attrs::A
6363
return push_call!(tape, _acos, args[1])
6464
end
6565

66+
function load_node!(tape::Tape, ::OpConfig{:ONNX, :Acosh}, args::VarVec, attrs::AttrDict)
67+
return push_call!(tape, _acos, args[1])
68+
end
69+
6670
function load_node!(tape::Tape, nd::NodeProto, backend::Symbol)
6771
args = [tape.c.name2var[name] for name in nd.input]
6872
attrs = convert(Dict{Symbol, Any}, Dict(nd.attribute))

src/ops.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ _sin(x) = sin.(x)
5151
_cos(x) = cos.(x)
5252
_abs(x) = abs.(x)
5353
_acos(x) = acos.(x)
54+
_acosh(x) = acosh.(x)
5455
mul(xs...) = .*(xs...)
5556
relu(x) = NNlib.relu.(x)
5657
leakyrelu(x;a = 0.01) = NNlib.leakyrelu.(x,a)

src/save.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,11 @@ function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_acos)}, op::Umlaut.
131131
push!(g.node, nd)
132132
end
133133

134+
function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_acosh)}, op::Umlaut.Call)
135+
nd = NodeProto("Acosh", op)
136+
push!(g.node, nd)
137+
end
138+
134139
function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(*)}, op::Umlaut.Call)
135140
nd = NodeProto(
136141
input=[onnx_name(v) for v in reverse(op.args)],

test/saveload.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ import ONNX: NodeProto, ValueInfoProto, AttributeProto, onnx_name
4242
ort_test(ONNX._acos, A)
4343
end
4444

45+
@testset "Acosh" begin
46+
A = rand(3, 4)
47+
ort_test(ONNX._acosh, A)
48+
end
49+
4550
@testset "Gemm" begin
4651
A, B, C = (rand(3, 4), rand(3, 4), rand(3, 3))
4752
ort_test(ONNX.onnx_gemm, A, B')

0 commit comments

Comments
 (0)