Skip to content

Commit 6132a74

Browse files
authored
Adding support for Where (#117)
* Adding support for Where * Changes
1 parent dc655fb commit 6132a74

File tree

4 files changed

+21
-0
lines changed

4 files changed

+21
-0
lines changed

src/load.jl

+4
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ function load_node!(tape::Tape, ::OpConfig{:ONNX, :Expand}, args::VarVec, attrs:
8383
return push_call!(tape, expand, args...)
8484
end
8585

86+
function load_node!(tape::Tape, ::OpConfig{:ONNX, :Where}, args::VarVec, attrs::AttrDict)
87+
return push_call!(tape, _where, args...)
88+
end
89+
8690
function load_node!(tape::Tape, nd::NodeProto, backend::Symbol)
8791
args = [tape.c.name2var[name] for name in nd.input]
8892
attrs = convert(Dict{Symbol, Any}, Dict(nd.attribute))

src/ops.jl

+5
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ function expand(x, y)
6767
# note: order of arguments reversed due to row-major layout
6868
return shape .* x
6969
end
70+
71+
function _where(condition, x, y)
72+
return ifelse.(condition, x, y)
73+
end
74+
7075
add(xs...) = .+(xs...)
7176
sub(xs...) = .-(xs...)
7277
_sin(x) = sin.(x)

src/save.jl

+5
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,11 @@ function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(expand)}, op::Umlaut
156156
push!(g.node, nd)
157157
end
158158

159+
function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_where)}, op::Umlaut.Call)
160+
nd = NodeProto("Where", op)
161+
push!(g.node, nd)
162+
end
163+
159164
function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(*)}, op::Umlaut.Call)
160165
nd = NodeProto(
161166
input=[onnx_name(v) for v in reverse(op.args)],

test/saveload.jl

+7
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,13 @@ import ONNX: NodeProto, ValueInfoProto, AttributeProto, onnx_name
113113
@test size(play!(tape, args...)) == (5, 20)
114114
end
115115

116+
@testset "Where" begin
117+
condition = rand(Bool, (1,20))
118+
A = rand(1,20)
119+
B = rand(1,20)
120+
ort_test(ONNX._where, condition, A, B)
121+
end
122+
116123
@testset "Gemm" begin
117124
A, B, C = (rand(3, 4), rand(3, 4), rand(3, 3))
118125
ort_test(ONNX.onnx_gemm, A, B')

0 commit comments

Comments
 (0)