Skip to content

Commit 7b3c4f9

Browse files
authored
Adding support for Equal (#118)
1 parent 6132a74 commit 7b3c4f9

File tree

4 files changed

+19
-0
lines changed

4 files changed

+19
-0
lines changed

src/load.jl

+4
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ function load_node!(tape::Tape, ::OpConfig{:ONNX, :Where}, args::VarVec, attrs::
8787
return push_call!(tape, _where, args...)
8888
end
8989

90+
function load_node!(tape::Tape, ::OpConfig{:ONNX, :Equal}, args::VarVec, attrs::AttrDict)
91+
return push_call!(tape, _equal, args...)
92+
end
93+
9094
function load_node!(tape::Tape, nd::NodeProto, backend::Symbol)
9195
args = [tape.c.name2var[name] for name in nd.input]
9296
attrs = convert(Dict{Symbol, Any}, Dict(nd.attribute))

src/ops.jl

+4
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ function _where(condition, x, y)
7272
return ifelse.(condition, x, y)
7373
end
7474

75+
function _equal(x, y)
76+
return x .== y
77+
end
78+
7579
add(xs...) = .+(xs...)
7680
sub(xs...) = .-(xs...)
7781
_sin(x) = sin.(x)

src/save.jl

+5
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,11 @@ function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_where)}, op::Umlaut
161161
push!(g.node, nd)
162162
end
163163

164+
function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_equal)}, op::Umlaut.Call)
165+
nd = NodeProto("Equal", op)
166+
push!(g.node, nd)
167+
end
168+
164169
function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(*)}, op::Umlaut.Call)
165170
nd = NodeProto(
166171
input=[onnx_name(v) for v in reverse(op.args)],

test/saveload.jl

+6
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,12 @@ import ONNX: NodeProto, ValueInfoProto, AttributeProto, onnx_name
120120
ort_test(ONNX._where, condition, A, B)
121121
end
122122

123+
@testset "Equal" begin
124+
A = rand(Bool, (1, 20))
125+
B = rand(Bool, (1, 20))
126+
ort_test(ONNX._equal, A, B)
127+
end
128+
123129
@testset "Gemm" begin
124130
A, B, C = (rand(3, 4), rand(3, 4), rand(3, 3))
125131
ort_test(ONNX.onnx_gemm, A, B')

0 commit comments

Comments
 (0)