Skip to content

Commit 51b4ce9

Browse files
authored
Merge pull request #52 from DrChainsaw/renamepackage
Rename package to ONNXNaiveNASflux
2 parents 81d84fa + 5f1c289 commit 51b4ce9

File tree

16 files changed

+147
-151
lines changed

16 files changed

+147
-151
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
name = "ONNXmutable"
2-
uuid = "cf2a63a0-f8ae-421c-82b7-306ecfceaf66"
1+
name = "ONNXNaiveNASflux"
2+
uuid = "2e935253-ba83-4645-9154-13ffeb13a688"
33
authors = ["DrChainsaw"]
44
version = "0.1.0"
55

README.md

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,29 @@
1-
# ONNXmutable
1+
# ONNXNaiveNASflux
22

3-
[![Build status](https://github.com/DrChainsaw/ONNXmutable.jl/workflows/CI/badge.svg?branch=master)](https://github.com/DrChainsaw/ONNXmutable.jl/actions)
4-
[![Build Status](https://ci.appveyor.com/api/projects/status/github/DrChainsaw/ONNXmutable.jl?svg=true)](https://ci.appveyor.com/project/DrChainsaw/ONNXmutable-jl)
5-
[![Codecov](https://codecov.io/gh/DrChainsaw/ONNXmutable.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/DrChainsaw/ONNXmutable.jl)
3+
[![Build status](https://github.com/DrChainsaw/ONNXNaiveNASflux.jl/workflows/CI/badge.svg?branch=master)](https://github.com/DrChainsaw/ONNXNaiveNASflux.jl/actions)
4+
[![Build Status](https://ci.appveyor.com/api/projects/status/github/DrChainsaw/ONNXNaiveNASflux.jl?svg=true)](https://ci.appveyor.com/project/DrChainsaw/ONNXNaiveNASflux-jl)
5+
[![Codecov](https://codecov.io/gh/DrChainsaw/ONNXNaiveNASflux.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/DrChainsaw/ONNXNaiveNASflux.jl)
66

77
[ONNX](https://onnx.ai) import and export for [Flux](https://github.com/FluxML/Flux.jl).
88

99
Models are imported as [NaiveNASflux](https://github.com/DrChainsaw/NaiveNASflux.jl) graphs, meaning that things like removing/inserting layers and pruning pre-trained models is a breeze.
1010

11-
Model export does not require the model to have any particular format. Almost any julia function can be exported as long as the primitives are recognized by ONNXmutable.
11+
Model export does not require the model to have any particular format. Almost any julia function can be exported as long as the primitives are recognized by ONNXNaiveNASflux.
1212

1313
## Basic usage
1414

1515
```julia
16-
Pkg.add(url="https://github.com/DrChainsaw/ONNXmutable.jl")
16+
Pkg.add(url="https://github.com/DrChainsaw/ONNXNaiveNASflux.jl")
1717
```
1818

1919
Exporting is done using the `onnx` function which accepts a filename `String` or an `IO` as first argument:
2020

2121
```julia
2222
# Save model as model.onnx where inputshapes are tuples with sizes of input.
23-
onnx("model.onnx", model, inputshapes...)
23+
save("model.onnx", model, inputshapes...)
2424

2525
# Load model as a CompGraph
26-
graph = CompGraph("model.onnx", inputshapes...)
26+
graph = load("model.onnx", inputshapes...)
2727
```
2828
Input shapes can be omitted in which case an attempt to infer the shapes will be made. If supplied, one tuple with size as the dimensions of the corresponding input array (including batch dimension) is expected.
2929

@@ -37,7 +37,7 @@ Names can be attached to inputs by providing a `Pair` where the first element is
3737
More elaborate example with a model defined as a plain Julia function:
3838

3939
```julia
40-
using ONNXmutable, Test, Statistics
40+
using ONNXNaiveNASflux, Test, Statistics
4141

4242
l1 = Conv((3,3), 2=>3, relu)
4343
l2 = Dense(3, 4, elu)
@@ -54,20 +54,20 @@ end
5454
io = PipeBuffer()
5555
x_shape = (:W, :H, 2, :Batch)
5656
y_shape = (4, :Batch)
57-
onnx(io, f, x_shape, y_shape)
57+
save(io, f, x_shape, y_shape)
5858

5959
# Deserialize as a NaiveNASflux CompGraph
60-
g = CompGraph(io)
60+
g = load(io)
6161

6262
x = ones(Float32, 5,4,2,3)
6363
y = ones(Float32, 4, 3)
6464
@test g(x,y) f(x,y)
6565

6666
# Serialization of CompGraphs does not require input shapes to be provided as they can be inferred.
6767
io = PipeBuffer()
68-
onnx(io, g)
68+
save(io, g)
6969

70-
g = CompGraph(io)
70+
g = load(io)
7171
@test g(x,y) f(x,y)
7272
```
7373

@@ -112,7 +112,7 @@ To map the function `myfun(args::SomeType....)` to an ONNX operation one just de
112112
This function typically looks something like this:
113113

114114
```julia
115-
import ONNXmutable: AbstractProbe, recursename, nextname, newfrom, add!, name
115+
import ONNXNaiveNASflux: AbstractProbe, recursename, nextname, newfrom, add!, name
116116
function myfun(probes::AbstractProbe...)
117117
p = probes[1] # select any probe
118118
optype = "MyOpType"
@@ -138,7 +138,7 @@ See [serialize.jl](src/serialize/serialize.jl) for existing operations.
138138
Deserialization is done by simply mapping operation types to functions in a dictionary. This allows for both easy extension as well as overwriting of existing mappings with own implementations:
139139

140140
```julia
141-
import ONNXmutable: actfuns
141+
import ONNXNaiveNASflux: actfuns
142142

143143
# All inputs which are not output from another node in the graph are provided in the method call
144144
actfuns[:SomeOp] = (params, α, β) -> x -> x^α + β
@@ -147,9 +147,9 @@ actfuns[:AnotherOp] = function(params)
147147
α = get(params, :alpha, 1)
148148
return x -> α / x
149149
end
150-
ONNXmutable.refresh()
150+
ONNXNaiveNASflux.refresh()
151151
```
152-
Note: After adding/changing an operation mapping one needs to call `ONNXmutable.refresh()` for it to take effect.
152+
Note: After adding/changing an operation mapping one needs to call `ONNXNaiveNASflux.refresh()` for it to take effect.
153153
See [ops.jl](src/deserialize/ops.jl) for existing operations.
154154

155155

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module ONNXmutable
1+
module ONNXNaiveNASflux
22

33
include("baseonnx/BaseOnnx.jl")
44

@@ -13,7 +13,7 @@ import Pkg
1313
import JuMP: @variable, @constraint
1414
import NaiveNASflux.NaiveNASlib: compconstraint!, all_in_Δsize_graph
1515

16-
export onnx, CompGraph
16+
export load, save
1717

1818
include("shapes.jl")
1919
include("validate.jl")

src/deserialize/deserialize.jl

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,10 @@ NaiveNASlib.name(n::ONNX.NodeProto) = n.name
1515
NaiveNASlib.name(vip::ONNX.ValueInfoProto) = vip.name
1616
NaiveNASlib.name(tp::ONNX.TensorProto) = tp.name
1717

18-
"""
19-
CompGraph(filename::String, insizes...)
20-
21-
Return a [`CompGraph`](@ref) loaded from the given file.
22-
23-
Argument insizes and be either size tuples or name => tuple pairs indicating the size for each model input.
24-
"""
25-
NaiveNASlib.CompGraph(filename::String, insizes...; vfun = create_vertex_default) = open(io -> CompGraph(io, insizes...; vfun), filename)
26-
NaiveNASlib.CompGraph(io::IO, insizes...; vfun = create_vertex_default) = CompGraph(extract(io), insizes...; vfun)
27-
NaiveNASlib.CompGraph(m::ONNX.ModelProto, insizes...; vfun = create_vertex_default) = CompGraph(m.graph, insizes...; vfun)
18+
load(filename::String, insizes...; vfun = create_vertex_default) = open(io -> load(io, insizes...; vfun), filename)
19+
load(io::IO, insizes...; vfun = create_vertex_default) = load(extract(io), insizes...; vfun)
20+
load(m::ONNX.ModelProto, insizes...; vfun = create_vertex_default) = load(m.graph, insizes...; vfun)
21+
load(g::ONNX.GraphProto, insizes...; vfun = create_vertex_default) = CompGraph(g, insizes...; vfun)
2822
NaiveNASlib.CompGraph(g::ONNX.GraphProto, insizes...; vfun = create_vertex_default) = CompGraph(CompGraphBuilder(g, insizes...); vfun)
2923
function NaiveNASlib.CompGraph(gb::CompGraphBuilder; vfun = create_vertex_default)
3024
outputs::Vector{AbstractVertex} = vertex.(gb, node.(name.(gb.g.output), gb), vfun)

src/serialize/serialize.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11

22
"""
3-
onnx(filename::AbstractString, f, args...; kwargs...)
4-
onnx(io::IO, f, args...; kwargs...)
3+
save(filename::AbstractString, f, args...; kwargs...)
4+
save(io::IO, f, args...; kwargs...)
55
66
Serialize the result of `modelproto(f, args...; kwargs...)` to a file with path `filename` or to `io`.
77
88
See [`modelproto`](@ref) for description of arguments.
99
"""
10-
onnx(filename::AbstractString, f, args...; modelname=filename, kwargs...) = onnx(filename, modelproto(f, args...; modelname=modelname, kwargs...))
11-
onnx(io::IO, f, args...; kwargs...) = onnx(io, modelproto(f, args...; kwargs...))
10+
save(filename::AbstractString, f, args...; modelname=filename, kwargs...) = save(filename, modelproto(f, args...; modelname=modelname, kwargs...))
11+
save(io::IO, f, args...; kwargs...) = save(io, modelproto(f, args...; kwargs...))
1212

1313
"""
14-
onnx(filename::AbstractString, mp::ONNX.ModelProto)
15-
onnx(io::IO, mp::ONNX.ModelProto)
14+
save(filename::AbstractString, mp::ONNX.ModelProto)
15+
save(io::IO, mp::ONNX.ModelProto)
1616
1717
Serialize the given [`ONNX.ModelProto`](@ref) to a file with path `filename` or to `io`.
1818
"""
19-
onnx(filename::AbstractString, mp::ONNX.ModelProto) = open(io -> onnx(io, mp), filename, "w")
20-
onnx(io::IO, mp::ONNX.ModelProto) = ONNX.writeproto(io, mp)
19+
save(filename::AbstractString, mp::ONNX.ModelProto) = open(io -> save(io, mp), filename, "w")
20+
save(io::IO, mp::ONNX.ModelProto) = ONNX.writeproto(io, mp)
2121

2222

2323
"""
@@ -92,7 +92,7 @@ infer_shape(::Type{<:AbstractArray{T,N}}) where {T,N} = ntuple(i -> missing, N)
9292
modelproto(;kwargs...) = ONNX.ModelProto(;
9393
ir_version=6,
9494
opset_import=[ONNX.OperatorSetIdProto(version=11)],
95-
producer_name="ONNXmutable.jl",
95+
producer_name="ONNXNaiveNASflux.jl",
9696
producer_version=string(Pkg.Types.Context().env.project.version), # TODO: Ugh....
9797
kwargs...)
9898

@@ -429,7 +429,7 @@ argpermswith(t, n::Integer, args...) = (a for a in argpermutations(n, t, args...
429429

430430
function gen_broadcastable_elemwise(f, optype, n=2)
431431
fs = Symbol(f)
432-
fm = which(ONNXmutable, fs)
432+
fm = which(ONNXNaiveNASflux, fs)
433433
generate_elemwise(fm, fs, optype, argpermswith(AbstractProbe, n, nothing))
434434
override_broadcast(f, argpermswith(Base.RefValue{<:AbstractProbe}, n, AbstractArray))
435435
end

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
[deps]
22
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
3+
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
4+
ONNXNaiveNASflux = "2e935253-ba83-4645-9154-13ffeb13a688"
35
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
46
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
57
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

test/baseonnx/readwrite.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@testset "Read and write" begin
2-
import ONNXmutable.BaseOnnx
2+
import ONNXNaiveNASflux.BaseOnnx
33

44
function serdeser(p::T) where T
55
iob = PipeBuffer();

test/deserialize/constraints.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11

22
@testset "Constraints" begin
3-
using ONNXmutable.NaiveNASflux
4-
import ONNXmutable: SizePseudoTransparent
3+
using ONNXNaiveNASflux.NaiveNASflux
4+
import ONNXNaiveNASflux: SizePseudoTransparent
55

66
dv(name, invertex, outsize) = mutable(name, Dense(nout(invertex), outsize), invertex)
77
cv(name, invertex, outsize) = mutable(name, Conv((3,3), nout(invertex) => outsize, pad=(1,1)), invertex)
88

99
@testset "Reshape" begin
10-
import ONNXmutable: Reshape
10+
import ONNXNaiveNASflux: Reshape
1111

1212
rv(name, invertex, outsize, dims) = absorbvertex(Reshape(dims), outsize, invertex; traitdecoration=t -> NamedTrait(SizePseudoTransparent(t), name))
1313

@@ -203,7 +203,7 @@
203203

204204

205205
@testset "Flatten" begin
206-
import ONNXmutable: Flatten
206+
import ONNXNaiveNASflux: Flatten
207207
fv(name, invertex, outsize, dim) = absorbvertex(Flatten(dim), outsize, invertex; traitdecoration=t -> NamedTrait(SizePseudoTransparent(t), name))
208208

209209
function tg(outsize, dim)

test/deserialize/deserialize.jl

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
import ONNXmutable: fluxlayers, sources, actfuns, invariantops, pseudotransparentops, optype, nodes
2-
using ONNXmutable.NaiveNASflux
1+
import ONNXNaiveNASflux: fluxlayers, sources, actfuns, invariantops, pseudotransparentops, optype, nodes
2+
using ONNXNaiveNASflux.NaiveNASflux
33

44
# Logging to avoid travis timeouts
55
@info " Test padding and sources"
66

77
@testset "Read padding" begin
8-
import ONNXmutable: prev
8+
import ONNXNaiveNASflux: prev
99

1010
@test prev(2) == 2
1111
@test prev([1,2]) == [1,2]
@@ -29,15 +29,15 @@ end
2929
end
3030

3131
@testset "$(tc.name) graph" begin
32-
cg = CompGraph(model)
32+
cg = load(model)
3333
res = cg()
3434
@test size(res) == size(outputs[1])
3535
@test res outputs[1]
3636

3737
# Also test that it we get the same thing by serializing and then deserializing
3838
io = PipeBuffer()
39-
onnx(io, cg)
40-
cg = CompGraph(io)
39+
save(io, cg)
40+
cg = load(io)
4141
res = cg()
4242
@test size(res) == size(outputs[1])
4343
@test res outputs[1]
@@ -116,15 +116,15 @@ end
116116
end
117117

118118
@testset "$(tc.name) graph" begin
119-
cg = CompGraph(model)
119+
cg = load(model)
120120
res = cg(Float32.(inputs[1]))
121121
@test size(res) == size(outputs[1])
122122
@test res outputs[1]
123123

124124
# Also test that it we get the same thing by serializing and then deserializing
125125
io = PipeBuffer()
126-
onnx(io, cg)
127-
cg = CompGraph(io)
126+
save(io, cg)
127+
cg = load(io)
128128
res = cg(Float32.(inputs[1]))
129129
@test size(res) == size(outputs[1])
130130
@test res outputs[1]
@@ -213,15 +213,15 @@ end
213213
end
214214

215215
@testset "$(tc.name) graph" begin
216-
cg = CompGraph(model)
216+
cg = load(model)
217217
res = cg(inputs[1])
218218
@test size(res) == size(outputs[1])
219219
@test res outputs[1]
220220

221221
# Also test that it we get the same thing by serializing and then deserializing
222222
io = PipeBuffer()
223-
onnx(io, cg)
224-
cg = CompGraph(io)
223+
save(io, cg)
224+
cg = load(io)
225225
res = cg(inputs[1])
226226
@test size(res) == size(outputs[1])
227227
@test res outputs[1]
@@ -250,15 +250,15 @@ end
250250
model, gb, inputs, outputs = prepare_node_test(tc.name, tc.ninputs, tc.noutputs)
251251

252252
@testset "$(tc.name) graph" begin
253-
cg = CompGraph(model)
253+
cg = load(model)
254254
res = cg(inputs[1:length(cg.inputs)]...)
255255
@test size(res) == size(outputs[1])
256256
@test res outputs[1]
257257

258258
# Also test that it we get the same thing by serializing and then deserializing
259259
io = PipeBuffer()
260-
onnx(io, cg)
261-
cg = CompGraph(io)
260+
save(io, cg)
261+
cg = load(io)
262262
res = cg(inputs[1:length(cg.inputs)]...)
263263
@test size(res) == size(outputs[1])
264264
@test res outputs[1]
@@ -271,29 +271,29 @@ end
271271
ivs = inputvertex.(["in1", "in2"], 4, Ref(FluxDense()))
272272
g_org = CompGraph(ivs, "out" >> ivs[1] + ivs[2])
273273
pb = PipeBuffer()
274-
onnx(pb, g_org, "in1" => missing, "in2" => missing)
274+
save(pb, g_org, "in1" => missing, "in2" => missing)
275275
return pb
276276
end
277277

278-
insize(t::Tuple) = ONNXmutable.int_size(t[NaiveNASflux.actdim(length(t))])
278+
insize(t::Tuple) = ONNXNaiveNASflux.int_size(t[NaiveNASflux.actdim(length(t))])
279279
insize(p::Pair) = p |> last |> insize
280280
@testset "Input format $inshapes" for inshapes in (
281281
((4,1), (4,1)),
282282
("in1" => (4,1), "in2" => (4,1)),
283283
((4,missing), (4, :B)),
284284
((:I, 3), (:I, 4))
285285
)
286-
g_new = CompGraph(sumgraph(), inshapes...)
286+
g_new = load(sumgraph(), inshapes...)
287287
@test nout.(g_new.inputs) == insize.(inshapes) |> collect
288288
@test layertype.(g_new.inputs) == [FluxDense(), FluxDense()]
289289
end
290290

291-
inshape(t::Tuple) = t |> length |> ONNXmutable.guess_layertype
291+
inshape(t::Tuple) = t |> length |> ONNXNaiveNASflux.guess_layertype
292292
@testset "Mixshape format $inshapes" for inshapes in (
293293
((1,1,5,1), (5,1)),
294294
((5,1), (1,1,5,1)),
295295
)
296-
g_new = CompGraph(sumgraph(), inshapes...)
296+
g_new = load(sumgraph(), inshapes...)
297297
@test nout.(g_new.inputs) == [5, 5]
298298
@test layertype.(g_new.inputs) == inshape.(inshapes |> collect)
299299
end
@@ -303,15 +303,15 @@ end
303303
("in1" => (4,1), "in2" => (4,1), "in2" => (4,1)),
304304
("in1" => (4,1), "notin2" => (4,1))
305305
)
306-
@test_throws AssertionError CompGraph(sumgraph(), inshapes...)
306+
@test_throws AssertionError load(sumgraph(), inshapes...)
307307
end
308308
end
309309

310310
@testset "Deserialize with merging" begin
311311
function remodel(f, args...)
312312
pb = PipeBuffer()
313-
onnx(pb, f, args...)
314-
return CompGraph(pb, args...)
313+
save(pb, f, args...)
314+
return load(pb, args...)
315315
end
316316

317317

@@ -322,8 +322,8 @@ end
322322
end
323323

324324
@testset "Merge Reshape and $gp" for gp in (
325-
ONNXmutable.globalmeanpool,
326-
ONNXmutable.globalmaxpool
325+
ONNXNaiveNASflux.globalmeanpool,
326+
ONNXNaiveNASflux.globalmaxpool
327327
)
328328
m = remodel(Chain(
329329
Conv((3,3), 3 => 3),

0 commit comments

Comments
 (0)