@@ -122,27 +122,23 @@ end
122122@concrete struct TensorDataset
123123 dataset
124124 transform
125+ total_samples:: Int
125126end
126127
127- Base. length(ds:: TensorDataset ) = length( ds. dataset)
128+ Base. length(ds:: TensorDataset ) = ds. total_samples
128129
129130function Base. getindex(ds:: TensorDataset , idxs:: Union{Vector{<:Integer}, AbstractRange} )
130131 img = Image.(eachslice(convert2image(ds. dataset, idxs); dims= 3 ))
131132 return stack(parent ∘ itemdata ∘ Base. Fix1(apply, ds. transform), img)
132133end
133134
134135function loadmnist(batchsize, image_size:: Dims{2} )
135- # # Load MNIST: Only 1500 for demonstration purposes
136- N = parse(Bool, get(ENV , " CI" , " false" )) ? 1500 : nothing
136+ # # Load MNIST: Only 1500 for demonstration purposes on CI
137137 train_dataset = MNIST(; split= :train)
138- test_dataset = MNIST(; split= :test)
139- if N != = nothing
140- train_dataset = train_dataset[1 : N]
141- test_dataset = test_dataset[1 : N]
142- end
138+ N = parse(Bool, get(ENV , " CI" , " false" )) ? 1500 : length(train_dataset)
143139
144140 train_transform = ScaleKeepAspect(image_size) |> ImageToTensor()
145- trainset = TensorDataset(train_dataset, train_transform)
141+ trainset = TensorDataset(train_dataset, train_transform, N )
146142 trainloader = DataLoader(trainset; batchsize, shuffle= true , partial= false )
147143
148144 return trainloader
@@ -247,7 +243,7 @@ function main(; batchsize=128, image_size=(64, 64), num_latent_dims=8, max_num_f
247243
248244 for (i, X) in enumerate(train_dataloader)
249245 throughput_tic = time()
250- (_, loss, stats , train_state) = Training.single_train_step!(
246+ (_, loss, _ , train_state) = Training.single_train_step!(
251247 AutoEnzyme(), loss_function, X, train_state)
252248 throughput_toc = time()
253249
0 commit comments