Skip to content

Commit 99aa29b

Browse files
committed
Adding support for Acos
1 parent 296a627 commit 99aa29b

File tree

4 files changed

+14
-0
lines changed

4 files changed

+14
-0
lines changed

src/load.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ function load_node!(tape::Tape, ::OpConfig{:ONNX, :Abs}, args::VarVec, attrs::At
5959
return push_call!(tape, _abs, args[1])
6060
end
6161

62+
function load_node!(tape::Tape, ::OpConfig{:ONNX, :Acos}, args::VarVec, attrs::AttrDict)
63+
return push_call!(tape, _acos, args[1])
64+
end
6265

6366
function load_node!(tape::Tape, nd::NodeProto, backend::Symbol)
6467
args = [tape.c.name2var[name] for name in nd.input]

src/ops.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ sub(xs...) = .-(xs...)
5050
_sin(x) = sin.(x)
5151
_cos(x) = cos.(x)
5252
_abs(x) = abs.(x)
53+
_acos(x) = acos.(x)
5354
mul(xs...) = .*(xs...)
5455
relu(x) = NNlib.relu.(x)
5556
leakyrelu(x;a = 0.01) = NNlib.leakyrelu.(x,a)

src/save.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@ function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_abs)}, op::Umlaut.C
126126
push!(g.node, nd)
127127
end
128128

129+
function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_acos)}, op::Umlaut.Call)
130+
nd = NodeProto("Acos", op)
131+
push!(g.node, nd)
132+
end
133+
129134
function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(*)}, op::Umlaut.Call)
130135
nd = NodeProto(
131136
input=[onnx_name(v) for v in reverse(op.args)],

test/saveload.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ import ONNX: NodeProto, ValueInfoProto, AttributeProto, onnx_name
3636
ort_test(ONNX._abs, A)
3737
end
3838

39+
@testset "Acos" begin
40+
A = rand(3, 4)
41+
ort_test(ONNX._acos, A)
42+
end
43+
3944
@testset "Gemm" begin
4045
A, B, C = (rand(3, 4), rand(3, 4), rand(3, 3))
4146
ort_test(ONNX.onnx_gemm, A, B')

0 commit comments

Comments
 (0)