Skip to content

Commit a2167f2

Browse files
authored
docs: stop manual specification of precision config (#1536)
1 parent 6958227 commit a2167f2

File tree

12 files changed

+58
-139
lines changed

12 files changed

+58
-139
lines changed

examples/CIFAR10/common.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,7 @@ function train_model(
9898

9999
x_ra = rand(rng, prec_jl, size(first(trainloader)[1])) |> dev
100100
@printf "[Info] Compiling model with Reactant.jl\n"
101-
model_compiled = Reactant.with_config(;
102-
dot_general_precision=PrecisionConfig.HIGH,
103-
convolution_precision=PrecisionConfig.HIGH,
104-
) do
105-
@compile model(x_ra, ps, Lux.testmode(st))
106-
end
101+
model_compiled = @compile model(x_ra, ps, Lux.testmode(st))
107102
@printf "[Info] Model compiled!\n"
108103

109104
loss_fn = CrossEntropyLoss(; logits=Val(true))

examples/ConvolutionalVAE/main.jl

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -259,19 +259,9 @@ function main(;
259259
ps, st = xdev(Lux.setup(rng, cvae))
260260
261261
z = xdev(randn(Float32, num_latent_dims, num_samples))
262-
decode_compiled = Reactant.with_config(;
263-
dot_general_precision=PrecisionConfig.HIGH,
264-
convolution_precision=PrecisionConfig.HIGH,
265-
) do
266-
@compile decode(cvae, z, ps, Lux.testmode(st))
267-
end
262+
decode_compiled = @compile decode(cvae, z, ps, Lux.testmode(st))
268263
x = xdev(randn(Float32, image_size..., 1, batchsize))
269-
cvae_compiled = Reactant.with_config(;
270-
dot_general_precision=PrecisionConfig.HIGH,
271-
convolution_precision=PrecisionConfig.HIGH,
272-
) do
273-
@compile cvae(x, ps, Lux.testmode(st))
274-
end
264+
cvae_compiled = @compile cvae(x, ps, Lux.testmode(st))
275265
276266
train_dataloader = xdev(loadmnist(batchsize, image_size))
277267

examples/GCN_Cora/main.jl

Lines changed: 25 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -101,25 +101,12 @@ function main(;
101101

102102
@printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)
103103

104-
val_loss_compiled = Reactant.with_config(;
105-
dot_general_precision=PrecisionConfig.HIGH,
106-
convolution_precision=PrecisionConfig.HIGH,
107-
) do
108-
@compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
109-
end
104+
val_loss_compiled = @compile loss_function(
105+
gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx)
106+
)
110107

111-
train_model_compiled = Reactant.with_config(;
112-
dot_general_precision=PrecisionConfig.HIGH,
113-
convolution_precision=PrecisionConfig.HIGH,
114-
) do
115-
@compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
116-
end
117-
val_model_compiled = Reactant.with_config(;
118-
dot_general_precision=PrecisionConfig.HIGH,
119-
convolution_precision=PrecisionConfig.HIGH,
120-
) do
121-
@compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
122-
end
108+
train_model_compiled = @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
109+
val_model_compiled = @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
123110

124111
best_loss_val = Inf
125112
cnt = 0
@@ -177,33 +164,28 @@ function main(;
177164
end
178165
end
179166

180-
Reactant.with_config(;
181-
dot_general_precision=PrecisionConfig.HIGH,
182-
convolution_precision=PrecisionConfig.HIGH,
183-
) do
184-
test_loss = @jit(
185-
loss_function(
186-
gcn,
187-
train_state.parameters,
188-
Lux.testmode(train_state.states),
189-
(features, targets, adj, test_idx),
190-
)
191-
)[1]
192-
test_acc = accuracy(
193-
Array(
194-
@jit(
195-
gcn(
196-
(features, adj, test_idx),
197-
train_state.parameters,
198-
Lux.testmode(train_state.states),
199-
)
200-
)[1],
201-
),
202-
Array(targets)[:, test_idx],
167+
test_loss = @jit(
168+
loss_function(
169+
gcn,
170+
train_state.parameters,
171+
Lux.testmode(train_state.states),
172+
(features, targets, adj, test_idx),
203173
)
174+
)[1]
175+
test_acc = accuracy(
176+
Array(
177+
@jit(
178+
gcn(
179+
(features, adj, test_idx),
180+
train_state.parameters,
181+
Lux.testmode(train_state.states),
182+
)
183+
)[1],
184+
),
185+
Array(targets)[:, test_idx],
186+
)
204187

205-
@printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
206-
end
188+
@printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
207189
return nothing
208190
end
209191

examples/HyperNet/main.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,7 @@ function train()
120120

121121
x = first(first(dataloaders[1][1]))
122122
data_idx = ConcreteRNumber(1)
123-
model_compiled = Reactant.with_config(;
124-
dot_general_precision=PrecisionConfig.HIGH,
125-
convolution_precision=PrecisionConfig.HIGH,
126-
) do
127-
@compile model((data_idx, x), ps, Lux.testmode(st))
128-
end
123+
model_compiled = @compile model((data_idx, x), ps, Lux.testmode(st))
129124

130125
### Let's train the model
131126
nepochs = 50

examples/LSTMEncoderDecoder/main.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -270,12 +270,7 @@ function train(
270270
train_state = Training.TrainState(model, ps, st, Optimisers.Adam(learning_rate))
271271

272272
stime = time()
273-
model_compiled = Reactant.with_config(;
274-
dot_general_precision=PrecisionConfig.HIGH,
275-
convolution_precision=PrecisionConfig.HIGH,
276-
) do
277-
@compile model((X_test, target_len, nothing), ps, Lux.testmode(st))
278-
end
273+
model_compiled = @compile model((X_test, target_len, nothing), ps, Lux.testmode(st))
279274
ttime = time() - stime
280275
@printf "Compilation time: %.4f seconds\n\n" ttime
281276

examples/PolynomialFitting/main.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,9 @@ tstate = main(tstate, vjp_rule, (x, y), 250)
100100

101101
# Since we are using Reactant, we need to compile the model before we can use it.
102102

103-
forward_pass = Reactant.with_config(;
104-
dot_general_precision=PrecisionConfig.HIGH,
105-
convolution_precision=PrecisionConfig.HIGH,
106-
) do
107-
@compile Lux.apply(tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states))
108-
end
103+
forward_pass = @compile Lux.apply(
104+
tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states)
105+
)
109106

110107
y_pred = cdev(
111108
first(

examples/SimpleChains/main.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,7 @@ function train(model, dev=cpu_device(); rng=Random.default_rng(), kwargs...)
8383

8484
if dev isa ReactantDevice
8585
x_ra = first(test_dataloader)[1]
86-
model_compiled = Reactant.with_config(;
87-
dot_general_precision=PrecisionConfig.HIGH,
88-
convolution_precision=PrecisionConfig.HIGH,
89-
) do
90-
@compile model(x_ra, ps, Lux.testmode(st))
91-
end
86+
model_compiled = @compile model(x_ra, ps, Lux.testmode(st))
9287
else
9388
model_compiled = model
9489
end

examples/SimpleRNN/main.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,12 +160,7 @@ function main(model_type)
160160

161161
train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
162162
model_compiled = if dev isa ReactantDevice
163-
Reactant.with_config(;
164-
dot_general_precision=PrecisionConfig.HIGH,
165-
convolution_precision=PrecisionConfig.HIGH,
166-
) do
167-
@compile model(first(train_loader)[1], ps, Lux.testmode(st))
168-
end
163+
@compile model(first(train_loader)[1], ps, Lux.testmode(st))
169164
else
170165
model
171166
end

ext/LuxReactantExt/saved_model.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
11
function Lux.Serialization.export_as_tf_saved_model_internal(
22
saved_model_path::String, model::AbstractLuxLayer, x, ps, st
33
)
4-
compiled_model = Reactant.with_config(;
5-
dot_general_precision=Reactant.PrecisionConfig.HIGH,
6-
convolution_precision=Reactant.PrecisionConfig.HIGH,
7-
) do
8-
@compile serializable = true model(x, ps, st)
9-
end
4+
compiled_model = @compile serializable = true model(x, ps, st)
105

116
# get the locations of the model inputs, parameters and states
127
input_locations = Union{String,Int}[]

perf/resnet/reactant.jl

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,9 @@ Comonicon.@main function main(;
3333
x = rand(Float32, 224, 224, 3, b) |> dev
3434
y = rand(Float32, 1000, b) |> dev
3535

36-
model_compiled = Reactant.with_config(;
37-
dot_general_precision=PrecisionConfig.DEFAULT,
38-
convolution_precision=PrecisionConfig.DEFAULT,
39-
) do
40-
Reactant.compile(
41-
model,
42-
(x, ps, Lux.testmode(st));
43-
sync=true,
44-
optimize=Symbol(optimize),
45-
)
46-
end
36+
model_compiled = Reactant.compile(
37+
model, (x, ps, Lux.testmode(st)); sync=true, optimize=Symbol(optimize)
38+
)
4739

4840
fwd_time = @belapsed begin
4941
$(model_compiled)($(x), $(ps), $(Lux.testmode(st)))
@@ -54,25 +46,21 @@ Comonicon.@main function main(;
5446
if b == 1
5547
bwd_time = -1.0 # batchnorm cannot support batch size 1
5648
else
57-
grad_compiled = Reactant.with_config(;
58-
dot_general_precision=PrecisionConfig.DEFAULT,
59-
convolution_precision=PrecisionConfig.DEFAULT,
60-
) do
61-
Reactant.compile(
62-
Enzyme.gradient,
63-
(
64-
Reverse,
65-
toy_loss_function,
66-
Const(model),
67-
ps,
68-
Const(st),
69-
Const(x),
70-
Const(y),
71-
);
72-
sync=true,
73-
optimize=Symbol(optimize),
74-
)
75-
end
49+
grad_compiled = Reactant.compile(
50+
Enzyme.gradient,
51+
(
52+
Reverse,
53+
toy_loss_function,
54+
Const(model),
55+
ps,
56+
Const(st),
57+
Const(x),
58+
Const(y),
59+
);
60+
sync=true,
61+
optimize=Symbol(optimize),
62+
)
63+
7664
bwd_time = @belapsed $(grad_compiled)(
7765
$Reverse,
7866
$(toy_loss_function),

0 commit comments

Comments
 (0)