Skip to content

Commit b5b9f52

Browse files
authored
Adding support for Cos (#108)
* Added support for Cos * Adding test case for Cos
1 parent ed396c2 commit b5b9f52

File tree

4 files changed

+16
-0
lines changed

4 files changed

+16
-0
lines changed

src/load.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ function load_node!(tape::Tape, ::OpConfig{:ONNX, :Sin}, args::VarVec, attrs::At
5151
return push_call!(tape, _sin, args[1])
5252
end
5353

54+
function load_node!(tape::Tape, ::OpConfig{:ONNX, :Cos}, args::VarVec, attrs::AttrDict)
55+
return push_call!(tape, _cos, args[1])
56+
end
57+
5458
function load_node!(tape::Tape, nd::NodeProto, backend::Symbol)
5559
args = [tape.c.name2var[name] for name in nd.input]
5660
attrs = convert(Dict{Symbol, Any}, Dict(nd.attribute))

src/ops.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ end
4848
add(xs...) = .+(xs...)
4949
sub(xs...) = .-(xs...)
5050
_sin(x) = sin.(x)
51+
_cos(x) = cos.(x)
5152
mul(xs...) = .*(xs...)
5253
relu(x) = NNlib.relu.(x)
5354
leakyrelu(x;a = 0.01) = NNlib.leakyrelu.(x,a)

src/save.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@ function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_sin)}, op::Umlaut.C
116116
push!(g.node, nd)
117117
end
118118

119+
function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_cos)}, op::Umlaut.Call)
120+
nd = NodeProto("Cos", op)
121+
push!(g.node, nd)
122+
end
123+
119124
function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(*)}, op::Umlaut.Call)
120125
nd = NodeProto(
121126
input=[onnx_name(v) for v in reverse(op.args)],

test/saveload.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ import ONNX: NodeProto, ValueInfoProto, AttributeProto, onnx_name
2525
ort_test(ONNX._sin, A)
2626
end
2727

28+
@testset "Cos" begin
29+
# ONNXRunTime has no implementation for Cos(x::Float64), using Float32
30+
A = rand(Float32, 3, 4)
31+
ort_test(ONNX._cos, A)
32+
end
33+
2834
@testset "Gemm" begin
2935
A, B, C = (rand(3, 4), rand(3, 4), rand(3, 3))
3036
ort_test(ONNX.onnx_gemm, A, B')

0 commit comments

Comments
 (0)