Skip to content

Commit ed396c2

Browse files
authored
(Smaller PR) Adding support for Sin (#106)
* Adding support for Sin * Added support and relevant test for Sin * Sin testset * Sin testset * Sin test correction (v3) * I swear I removed that last commit (v4) * Update Project.toml
1 parent 9cd42b9 commit ed396c2

File tree

4 files changed

+13
-1
lines changed

4 files changed

+13
-1
lines changed

src/load.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ struct OpConfig{BE, Op} end
4747
const VarVec = Vector{Umlaut.Variable}
4848
const AttrDict = Dict{Symbol, Any}
4949

50+
function load_node!(tape::Tape, ::OpConfig{:ONNX, :Sin}, args::VarVec, attrs::AttrDict)
51+
return push_call!(tape, _sin, args[1])
52+
end
5053

5154
function load_node!(tape::Tape, nd::NodeProto, backend::Symbol)
5255
args = [tape.c.name2var[name] for name in nd.input]
@@ -68,7 +71,6 @@ function load_node!(tape::Tape, nd::NodeProto, backend::Symbol)
6871
end
6972
end
7073

71-
7274
function load_node!(tape::Tape, ::OpConfig{:ONNX, :Gemm}, args::VarVec, attrs::AttrDict)
7375
if (length(args) == 2 && get(attrs, :alpha, 1) == 1 &&
7476
get(attrs, :transA, 0) == 0 && get(attrs, :transB, 0) == 0)

src/ops.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ end
4747

4848
add(xs...) = .+(xs...)
4949
sub(xs...) = .-(xs...)
50+
_sin(x) = sin.(x)
5051
mul(xs...) = .*(xs...)
5152
relu(x) = NNlib.relu.(x)
5253
leakyrelu(x;a = 0.01) = NNlib.leakyrelu.(x,a)

src/save.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(getfield)}, op::Umla
111111
# Using getfield() for anything other then destructuring is thus a mistake.
112112
end
113113

114+
function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_sin)}, op::Umlaut.Call)
115+
nd = NodeProto("Sin", op)
116+
push!(g.node, nd)
117+
end
114118

115119
function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(*)}, op::Umlaut.Call)
116120
nd = NodeProto(

test/saveload.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ import ONNX: NodeProto, ValueInfoProto, AttributeProto, onnx_name
2020
ort_test(ONNX.mul, args...)
2121
end
2222

23+
@testset "Sin" begin
24+
A = rand(3, 4)
25+
ort_test(ONNX._sin, A)
26+
end
27+
2328
@testset "Gemm" begin
2429
A, B, C = (rand(3, 4), rand(3, 4), rand(3, 3))
2530
ort_test(ONNX.onnx_gemm, A, B')

0 commit comments

Comments
 (0)