@@ -101,25 +101,12 @@ function main(;
101101
102102 @printf " Total Trainable Parameters: %0.4f M\n " (Lux. parameterlength(ps) / 1.0e6 )
103103
104- val_loss_compiled = Reactant. with_config(;
105- dot_general_precision= PrecisionConfig. HIGH,
106- convolution_precision= PrecisionConfig. HIGH,
107- ) do
108- @compile loss_function(gcn, ps, Lux. testmode(st), (features, targets, adj, val_idx))
109- end
104+ val_loss_compiled = @compile loss_function(
105+ gcn, ps, Lux. testmode(st), (features, targets, adj, val_idx)
106+ )
110107
111- train_model_compiled = Reactant. with_config(;
112- dot_general_precision= PrecisionConfig. HIGH,
113- convolution_precision= PrecisionConfig. HIGH,
114- ) do
115- @compile gcn((features, adj, train_idx), ps, Lux. testmode(st))
116- end
117- val_model_compiled = Reactant. with_config(;
118- dot_general_precision= PrecisionConfig. HIGH,
119- convolution_precision= PrecisionConfig. HIGH,
120- ) do
121- @compile gcn((features, adj, val_idx), ps, Lux. testmode(st))
122- end
108+ train_model_compiled = @compile gcn((features, adj, train_idx), ps, Lux. testmode(st))
109+ val_model_compiled = @compile gcn((features, adj, val_idx), ps, Lux. testmode(st))
123110
124111 best_loss_val = Inf
125112 cnt = 0
@@ -177,33 +164,28 @@ function main(;
177164 end
178165 end
179166
180- Reactant. with_config(;
181- dot_general_precision= PrecisionConfig. HIGH,
182- convolution_precision= PrecisionConfig. HIGH,
183- ) do
184- test_loss = @jit(
185- loss_function(
186- gcn,
187- train_state. parameters,
188- Lux. testmode(train_state. states),
189- (features, targets, adj, test_idx),
190- )
191- )[1 ]
192- test_acc = accuracy(
193- Array(
194- @jit(
195- gcn(
196- (features, adj, test_idx),
197- train_state. parameters,
198- Lux. testmode(train_state. states),
199- )
200- )[1 ],
201- ),
202- Array(targets)[:, test_idx],
167+ test_loss = @jit(
168+ loss_function(
169+ gcn,
170+ train_state. parameters,
171+ Lux. testmode(train_state. states),
172+ (features, targets, adj, test_idx),
203173 )
174+ )[1 ]
175+ test_acc = accuracy(
176+ Array(
177+ @jit(
178+ gcn(
179+ (features, adj, test_idx),
180+ train_state. parameters,
181+ Lux. testmode(train_state. states),
182+ )
183+ )[1 ],
184+ ),
185+ Array(targets)[:, test_idx],
186+ )
204187
205- @printf " Test Loss: %.6f\t Test Acc: %.4f%%\n " test_loss test_acc
206- end
188+ @printf " Test Loss: %.6f\t Test Acc: %.4f%%\n " test_loss test_acc
207189 return nothing
208190end
209191
0 commit comments