conda create -n latte python=3.10 -y
conda activate latte
pip install -r requirements.txt
pip install torchvision
pip install -e .

MNIST, FashionMNIST, SVHN, and CIFAR-10 are downloaded automatically to ./data on first use.
ImageNet must be provided manually in ./imagenet/ with the following layout:

imagenet/
  train/<class>/*.JPEG
  val/<class>/*.JPEG

Download it from https://image-net.org/ (registration required) and arrange val/ into per-class folders using the official devkit.
Each experiment: train the VQ-VAE once, train the classifier(s), run LATTE, evaluate.
python train_vqvae.py --config configs/mnist_vqvae.yaml
python train_classifier.py --config configs/mnist_lenet5_single.yaml --target a
python run_latte.py --config configs/mnist_lenet5_single.yaml
python evaluate_results.py --config configs/mnist_lenet5_single.yaml \
  --failures results/mnist_lenet5_single/failures_single.pt

To use LeNet-4 instead, replace lenet5 with lenet4 in the commands above (config: configs/mnist_lenet4_single.yaml).
python train_vqvae.py --config configs/cifar10_vqvae.yaml
python train_classifier.py --config configs/cifar10_vgg16_single.yaml --target a
python run_latte.py --config configs/cifar10_vgg16_single.yaml
python evaluate_results.py --config configs/cifar10_vgg16_single.yaml \
  --failures results/cifar10_vgg16_single/failures_single.pt

To use ResNet-18 instead, replace vgg16 with resnet18 in the commands above (config: configs/cifar10_resnet18_single.yaml).
python train_vqvae.py --config configs/imagenet_vqvae.yaml
python train_classifier.py --config configs/imagenet_vgg19_single.yaml --target a
python run_latte.py --config configs/imagenet_vgg19_single.yaml
python evaluate_results.py --config configs/imagenet_vgg19_single.yaml \
  --failures results/imagenet_vgg19_single/failures_single.pt

To use ResNet-50 instead, replace vgg19 with resnet50 in the commands above (config: configs/imagenet_resnet50_single.yaml).
python train_vqvae.py --config configs/mnist_vqvae.yaml
python train_classifier.py --config configs/mnist_multi.yaml --target a
python train_classifier.py --config configs/mnist_multi.yaml --target b
python run_latte.py --config configs/mnist_multi.yaml
python evaluate_results.py --config configs/mnist_multi.yaml \
  --failures results/mnist_multi/failures_multi.pt

To run on another dataset, replace mnist in the commands above with fashionmnist (use configs/fashionmnist_multi.yaml and configs/fashionmnist_vqvae.yaml) or svhn (use configs/svhn_multi.yaml and configs/svhn_vqvae.yaml).