Skip to content

Commit 0602655

Browse files
authored
Merge pull request #215 from FluxML/dev
For a 0.2.9 release
2 parents ac253f1 + 452c09d commit 0602655

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJFlux"
22
uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845"
33
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>", "Ayush Shridhar <ayush.shridhar1999@gmail.com>"]
4-
version = "0.2.8"
4+
version = "0.2.9"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

src/core.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function train!(loss, penalty, chain, optimiser, X, y)
3636
parameters = Flux.params(chain)
3737
gs = Flux.gradient(parameters) do
3838
yhat = chain(X[i])
39-
batch_loss = loss(yhat, y[i]) + penalty(parameters)
39+
batch_loss = loss(yhat, y[i]) + penalty(parameters)/n_batches
4040
training_loss += batch_loss
4141
return batch_loss
4242
end
@@ -96,7 +96,7 @@ function fit!(loss, penalty, chain, optimiser, epochs, verbosity, X, y)
9696

9797
parameters = Flux.params(chain)
9898
losses = (loss(chain(X[i]), y[i]) +
99-
penalty(parameters) for i in 1:n_batches)
99+
penalty(parameters)/n_batches for i in 1:n_batches)
100100
history = [mean(losses),]
101101

102102
for i in 1:epochs

0 commit comments

Comments
 (0)