Skip to content

Commit 296a627

Browse files
authored
Adding support for Abs (#109)
1 parent b5b9f52 commit 296a627

File tree

4 files changed

+16
-0
lines changed

4 files changed

+16
-0
lines changed

src/load.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ function load_node!(tape::Tape, ::OpConfig{:ONNX, :Cos}, args::VarVec, attrs::At
5555
return push_call!(tape, _cos, args[1])
5656
end
5757

58+
function load_node!(tape::Tape, ::OpConfig{:ONNX, :Abs}, args::VarVec, attrs::AttrDict)
59+
return push_call!(tape, _abs, args[1])
60+
end
61+
62+
5863
function load_node!(tape::Tape, nd::NodeProto, backend::Symbol)
5964
args = [tape.c.name2var[name] for name in nd.input]
6065
attrs = convert(Dict{Symbol, Any}, Dict(nd.attribute))

src/ops.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ add(xs...) = .+(xs...)
4949
sub(xs...) = .-(xs...)
5050
_sin(x) = sin.(x)
5151
_cos(x) = cos.(x)
52+
_abs(x) = abs.(x)
5253
mul(xs...) = .*(xs...)
5354
relu(x) = NNlib.relu.(x)
5455
leakyrelu(x;a = 0.01) = NNlib.leakyrelu.(x,a)

src/save.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,11 @@ function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_cos)}, op::Umlaut.C
121121
push!(g.node, nd)
122122
end
123123

124+
function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_abs)}, op::Umlaut.Call)
125+
nd = NodeProto("Abs", op)
126+
push!(g.node, nd)
127+
end
128+
124129
function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(*)}, op::Umlaut.Call)
125130
nd = NodeProto(
126131
input=[onnx_name(v) for v in reverse(op.args)],

test/saveload.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ import ONNX: NodeProto, ValueInfoProto, AttributeProto, onnx_name
3131
ort_test(ONNX._cos, A)
3232
end
3333

34+
@testset "Abs" begin
35+
A = rand(3, 4)
36+
ort_test(ONNX._abs, A)
37+
end
38+
3439
@testset "Gemm" begin
3540
A, B, C = (rand(3, 4), rand(3, 4), rand(3, 3))
3641
ort_test(ONNX.onnx_gemm, A, B')

0 commit comments

Comments
 (0)