[ONNX](https://onnx.ai) import and export for [Flux](https://github.com/FluxML/Flux.jl).
Models are imported as [NaiveNASflux](https://github.com/DrChainsaw/NaiveNASflux.jl) graphs, meaning that things like removing/inserting layers and pruning pre-trained models are a breeze.
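For instance (a hedged sketch: the file name is illustrative, and `vertices` and `remove!` are assumed here to come from NaiveNASlib, re-exported through NaiveNASflux), a layer can be removed from an imported model like this:

```julia
using ONNXNaiveNASflux

# Import a pre-trained model as a mutable NaiveNASflux graph.
graph = load("model.onnx")

# Remove the third vertex; the library reconnects its inputs and outputs
# and resizes neighbouring layers so the graph stays size-consistent.
remove!(vertices(graph)[3])
```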
Model export does not require the model to have any particular format. Almost any julia function can be exported as long as the primitives are recognized by ONNXNaiveNASflux.
Exporting is done using the `save` function, which accepts a filename `String` or an `IO` as its first argument:
```julia
# Save model as model.onnx, where inputshapes are tuples with the size of each input.
save("model.onnx", model, inputshapes...)
# Load model as a CompGraph
graph = load("model.onnx", inputshapes...)
```
Input shapes can be omitted, in which case an attempt will be made to infer them. If supplied, each input shape is expected to be a tuple with the dimensions of the corresponding input array (including the batch dimension).
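For example (a hedged sketch; the model and file name are illustrative), a convolutional model taking two-channel images can be saved either way, with `Symbol`s marking dimensions of unknown size:

```julia
using ONNXNaiveNASflux, Flux

model = Chain(Conv((3, 3), 2 => 3, relu))

# Explicit input shape, including the batch dimension.
save("convmodel.onnx", model, (:W, :H, 2, :Batch))

# Shape omitted: an attempt is made to infer it from the model.
save("convmodel.onnx", model)
```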
Names can be attached to inputs by providing a `Pair` where the first element is the name.
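A hedged sketch (the name, shape, and `Pair` order are illustrative assumptions):

```julia
# Attach the name "imageinput" to the input with the given shape.
save("model.onnx", model, "imageinput" => (:W, :H, 2, :Batch))
```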
A more elaborate example with a model defined as a plain Julia function:
```julia
using ONNXNaiveNASflux, Test, Statistics
l1 = Conv((3,3), 2=>3, relu)
l2 = Dense(3, 4, elu)
io = PipeBuffer()
x_shape = (:W, :H, 2, :Batch)
y_shape = (4, :Batch)
save(io, f, x_shape, y_shape)
# Deserialize as a NaiveNASflux CompGraph
g = load(io)
x = ones(Float32, 5,4,2,3)
y = ones(Float32, 4, 3)
@test g(x,y) ≈ f(x,y)
# Serialization of CompGraphs does not require input shapes to be provided as they can be inferred.
io = PipeBuffer()
save(io, g)
g = load(io)
@test g(x,y) ≈ f(x,y)
```
To map the function `myfun(args::SomeType...)` to an ONNX operation one just defines a method of `myfun` which accepts `AbstractProbe`s.
This function typically looks something like this:
```julia
import ONNXNaiveNASflux: AbstractProbe, recursename, nextname, newfrom, add!, name
function myfun(probes::AbstractProbe...)
    p = probes[1] # select any probe
    optype = "MyOpType"
    # ...
end
```
See [serialize.jl](src/serialize/serialize.jl) for existing operations.
Deserialization is done by simply mapping operation types to functions in a dictionary. This allows both easy extension and overriding of existing mappings with custom implementations:
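As a hedged sketch (the signature of the mapped function and the attribute name `alpha` are assumptions, and `:Celu` is just an illustrative optype), registering a deserialization for an activation might look like:

```julia
import ONNXNaiveNASflux: actfuns

# Map the ONNX optype :Celu to a function which builds the activation
# from the node's attributes.
actfuns[:Celu] = function(params)
    α = get(params, :alpha, 1f0)
    return x -> max(zero(x), x) + min(zero(x), α * (exp(x / α) - 1))
end
```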
```julia
import ONNXNaiveNASflux: actfuns
# All inputs which are not output from another node in the graph are provided in the method call
```