Skip to content

Commit 8792097

Browse files
authored
fix: ConditionalVAE on CI (#1159)
1 parent 6341b3d commit 8792097

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

examples/ConditionalVAE/main.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,27 +122,23 @@ end
122122
@concrete struct TensorDataset
123123
dataset
124124
transform
125+
total_samples::Int
125126
end
126127

127-
Base.length(ds::TensorDataset) = length(ds.dataset)
128+
Base.length(ds::TensorDataset) = ds.total_samples
128129

129130
function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, AbstractRange})
130131
img = Image.(eachslice(convert2image(ds.dataset, idxs); dims=3))
131132
return stack(parent itemdata Base.Fix1(apply, ds.transform), img)
132133
end
133134

134135
function loadmnist(batchsize, image_size::Dims{2})
135-
## Load MNIST: Only 1500 for demonstration purposes
136-
N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
136+
## Load MNIST: Only 1500 for demonstration purposes on CI
137137
train_dataset = MNIST(; split=:train)
138-
test_dataset = MNIST(; split=:test)
139-
if N !== nothing
140-
train_dataset = train_dataset[1:N]
141-
test_dataset = test_dataset[1:N]
142-
end
138+
N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : length(train_dataset)
143139

144140
train_transform = ScaleKeepAspect(image_size) |> ImageToTensor()
145-
trainset = TensorDataset(train_dataset, train_transform)
141+
trainset = TensorDataset(train_dataset, train_transform, N)
146142
trainloader = DataLoader(trainset; batchsize, shuffle=true, partial=false)
147143

148144
return trainloader
@@ -247,7 +243,7 @@ function main(; batchsize=128, image_size=(64, 64), num_latent_dims=8, max_num_f
247243
248244
for (i, X) in enumerate(train_dataloader)
249245
throughput_tic = time()
250-
(_, loss, stats, train_state) = Training.single_train_step!(
246+
(_, loss, _, train_state) = Training.single_train_step!(
251247
AutoEnzyme(), loss_function, X, train_state)
252248
throughput_toc = time()
253249

0 commit comments

Comments
 (0)