Skip to content

Commit d5772b8

Browse files
authored
Adding support for ConstantOfShape (#121)
1 parent 3f30c4e commit d5772b8

File tree

4 files changed

+36
-0
lines changed

4 files changed

+36
-0
lines changed

src/load.jl

+5
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ function load_node!(tape::Tape, ::OpConfig{:ONNX, :Neg}, args::VarVec, attrs::At
9595
return push_call!(tape, neg, args[1])
9696
end
9797

98+
function load_node!(tape::Tape, ::OpConfig{:ONNX, :ConstantOfShape}, args::VarVec, attrs::AttrDict)
99+
return push_call!(tape, makeshape, args...; attrs...)
100+
end
101+
102+
98103
function load_node!(tape::Tape, nd::NodeProto, backend::Symbol)
99104
args = [tape.c.name2var[name] for name in nd.input]
100105
attrs = convert(Dict{Symbol, Any}, Dict(nd.attribute))

src/ops.jl

+4
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ function _equal(x, y)
7676
return x .== y
7777
end
7878

79+
function makeshape(x; value = zeros(Float32, 1))
80+
return fill(value..., x...)
81+
end
82+
7983
add(xs...) = .+(xs...)
8084
sub(xs...) = .-(xs...)
8185
neg(x) = .-(x)

src/save.jl

+5
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,11 @@ function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(neg)}, op::Umlaut.Ca
171171
push!(g.node, nd)
172172
end
173173

174+
function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(makeshape)}, op::Umlaut.Call)
175+
nd = NodeProto("ConstantOfShape", op)
176+
push!(g.node, nd)
177+
end
178+
174179
function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(*)}, op::Umlaut.Call)
175180
nd = NodeProto(
176181
input=[onnx_name(v) for v in reverse(op.args)],

test/saveload.jl

+22
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,28 @@ import ONNX: NodeProto, ValueInfoProto, AttributeProto, onnx_name
131131
ort_test(ONNX.neg, A)
132132
end
133133

134+
@testset "ConstantOfShape" begin
135+
# ort_test() checks for expected output of functions, errors on ConstantOfShape
136+
# because of array shape; manually testing!
137+
138+
# Testing expansion of shape in ConstantOfShape
139+
args = [2, 3]
140+
attrs = randn(Float32, 1)
141+
tape = Tape(ONNXCtx())
142+
inp = push!(tape, Input(args))
143+
res = push_call!(tape, ONNX.makeshape, inp; value = attrs)
144+
tape.result = res
145+
146+
# Make sure size is desired shape
147+
@test size(play!(tape, args)) == (2, 3)
148+
149+
# Make sure elements are of the desired datatype
150+
@test eltype(attrs) == eltype(play!(tape, args))
151+
152+
# Make sure the output is filled with the correct value
153+
@test attrs[1] == play!(tape, args)[1]
154+
end
155+
134156
@testset "Gemm" begin
135157
A, B, C = (rand(3, 4), rand(3, 4), rand(3, 3))
136158
ort_test(ONNX.onnx_gemm, A, B')

0 commit comments

Comments
 (0)