This is the official implementation of "Improved Out-of-Distribution Detection with Additive Angular Margin Loss" (CVPRW 2025). The project uses PyTorch together with NumPy, PIL, torchvision, scikit-image, and scikit-learn to implement out-of-distribution (OOD) detection methods and training strategies. The repository contains scripts for training with an additive angular margin loss, for standard training, and for evaluating OOD detection performance.
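The additive angular margin idea can be illustrated in a few lines. The sketch below is a minimal, framework-free illustration in plain Python, not the repository's PyTorch code. Features and class weights are L2-normalized, so each class logit is the cosine of the angle between them. The margin is added to the angle of the target class (target logit s·cos(θ_y + m), other logits s·cos θ_j), and softmax cross-entropy is then applied. This follows the ArcFace-style formulation. The function names and the scale s = 30.0 are assumptions; only `margin=0.8` is taken from the `train.py` example.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def arc_margin_logits(feature, class_weights, label, margin=0.8, scale=30.0):
    """Scaled cosine logits with an additive angular margin on the target class.

    margin=0.8 mirrors --margin in the train.py example; scale=30.0 is an assumed value.
    """
    f = l2_normalize(feature)
    logits = []
    for j, w in enumerate(class_weights):
        cos_theta = sum(a * b for a, b in zip(f, l2_normalize(w)))
        cos_theta = max(-1.0, min(1.0, cos_theta))  # guard acos against rounding error
        if j == label:
            # Add the margin to the angle between the feature and the target class weight.
            cos_theta = math.cos(math.acos(cos_theta) + margin)
        logits.append(scale * cos_theta)
    return logits


def cross_entropy(logits, label):
    """Numerically stable softmax cross-entropy for a single sample."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


if __name__ == "__main__":
    feature = [0.9, 0.1, 0.3]
    weights = [[1.0, 0.0, 0.2], [0.0, 1.0, 0.0], [0.1, 0.2, 1.0]]
    plain = cross_entropy(arc_margin_logits(feature, weights, 0, margin=0.0), 0)
    arc = cross_entropy(arc_margin_logits(feature, weights, 0), 0)
    print(f"loss without margin: {plain:.4f}, with margin 0.8: {arc:.4f}")
```

Adding the margin lowers the target logit, so the loss stays higher until a feature lies well inside its class's angular region. This is how the margin encourages compact, well-separated classes on the hypersphere. Practical implementations usually also handle the case θ + m > π, which this sketch omits.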
To set up the project, first create and activate a conda environment (or use your preferred environment manager) and install the required libraries:
conda create -n ood_env python=3.8
conda activate ood_env
pip install torch torchvision numpy pillow scikit-image scikit-learn
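After installation, a quick way to confirm the environment is to check that each library is importable. Several pip package names differ from their import names: pillow imports as `PIL`, scikit-image as `skimage`, and scikit-learn as `sklearn`. The helper below is a hypothetical convenience script, not part of the repository.

```python
import importlib.util

# pip package name -> import name (these differ for pillow, scikit-image, scikit-learn)
REQUIRED = {
    "torch": "torch",
    "torchvision": "torchvision",
    "numpy": "numpy",
    "pillow": "PIL",
    "scikit-image": "skimage",
    "scikit-learn": "sklearn",
}


def missing_packages(required):
    """Return the pip names of packages whose import name cannot be found."""
    return [pip_name for pip_name, module in required.items()
            if importlib.util.find_spec(module) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    if missing:
        print("Missing:", " ".join(missing))
        print("Install with: pip install " + " ".join(missing))
    else:
        print("All required packages are importable.")
```

Run it inside the activated `ood_env` environment so that it checks the same interpreter the scripts will use.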
Each script can be run individually from the command line. For example, to train a model using the additive angular margin technique, you can run:
python train.py --epochs=250 --margin=0.8 --comment="--Arc Margin 1" --train_batch_size=768 --suffix=arc_512_m08_30_1
Similarly, use std_trn.py for standard training and ood.py for evaluating the OOD detection performance.
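Because every script takes `--key=value` flags, runs are easy to script from Python. The sketch below is a hypothetical launcher, not part of the repository (`run.sh` is the provided way to reproduce experiments). It rebuilds the three example commands from dictionaries whose values are copied from the examples in this README. It assumes it is run from the repository root, and it only prints the commands unless `dry_run=False` is passed.

```python
import shlex
import subprocess
import sys

# Flag values copied from the README examples; edit them to match your experiment.
EXPERIMENTS = {
    "arc": ("train.py", {"epochs": 250, "margin": 0.8, "comment": "--Arc Margin 1",
                         "train_batch_size": 768, "suffix": "arc_512_m08_30_1"}),
    "erm": ("std_trn.py", {"suffix": "erm_1", "layerWidth": 2048, "dataset": "tinyimagenet",
                           "train_batch_size": 128, "test_batch_size": 128, "epochs": 200}),
    "ood": ("ood.py", {"outfile": "./outputs/ood_msp.txt", "method": "msp",
                       "test_batch_size": 1024}),
}


def build_command(script, flags):
    """Turn a script name and a flag dict into an argv list using --key=value syntax."""
    return [sys.executable, script] + [f"--{key}={value}" for key, value in flags.items()]


def run(name, dry_run=True):
    """Print (and optionally execute) one of the example experiments."""
    script, flags = EXPERIMENTS[name]
    cmd = build_command(script, flags)
    print(shlex.join(cmd))  # shell-quoted preview of the command
    if not dry_run:
        subprocess.run(cmd, check=True)  # raise if the script exits with an error
    return cmd


if __name__ == "__main__":
    for name in ("arc", "erm", "ood"):
        run(name)  # dry run: only prints the commands
```

Passing the command as a list means `--comment=--Arc Margin 1` reaches the script as a single argument, with no extra shell quoting needed.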
Train with Additive Angular Margin
python train.py --epochs=250 --margin=0.8 --comment="--Arc Margin 1" --train_batch_size=768 --suffix=arc_512_m08_30_1
Standard Training
python std_trn.py --suffix=erm_1 --layerWidth=2048 --dataset=tinyimagenet --train_batch_size=128 --test_batch_size=128 --epochs=200
Evaluate OOD Performance
python ood.py --outfile=./outputs/ood_msp.txt --method=msp --test_batch_size=1024
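The `--method=msp` flag presumably selects the maximum softmax probability (MSP) baseline. MSP scores each test input by its largest softmax probability, and in-distribution inputs are expected to score higher. The sketch below shows the MSP score together with two metrics commonly used to summarize OOD detection: AUROC, and the false positive rate at 95% true positive rate (FPR@95TPR), with in-distribution treated as the positive class. It is a plain-Python illustration with hypothetical function names. This README does not specify whether `ood.py` reports exactly these metrics, or how it writes them to `--outfile`.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def msp_score(logits):
    """Maximum softmax probability: higher means 'more in-distribution'."""
    return max(softmax(logits))


def auroc(id_scores, ood_scores):
    """Probability that a random ID sample outscores a random OOD sample (ties count 1/2)."""
    wins = 0.0
    for s_in in id_scores:
        for s_out in ood_scores:
            if s_in > s_out:
                wins += 1.0
            elif s_in == s_out:
                wins += 0.5
    return wins / (len(id_scores) * len(ood_scores))


def fpr_at_tpr(id_scores, ood_scores, tpr=0.95):
    """Fraction of OOD samples accepted at the threshold that keeps `tpr` of ID samples."""
    ranked = sorted(id_scores, reverse=True)
    k = max(1, math.ceil(tpr * len(ranked)))  # number of ID samples to accept
    threshold = ranked[k - 1]
    return sum(s >= threshold for s in ood_scores) / len(ood_scores)


if __name__ == "__main__":
    id_logits = [[4.0, 0.5, 0.1], [3.0, 1.0, 0.2], [5.0, 0.0, 0.3]]
    ood_logits = [[1.0, 0.9, 0.8], [2.0, 1.8, 0.1]]
    id_s = [msp_score(z) for z in id_logits]
    ood_s = [msp_score(z) for z in ood_logits]
    print(f"AUROC: {auroc(id_s, ood_s):.3f}, FPR@95TPR: {fpr_at_tpr(id_s, ood_s):.3f}")
```

The pairwise AUROC loop is quadratic in the number of samples. For real score arrays, `sklearn.metrics.roc_auc_score` computes the same quantity from binary labels and scores. Threshold conventions for FPR@95TPR vary slightly between codebases.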
- The methods folder contains the core implementations of the various OOD detection approaches.
- The models folder provides the network architectures used during training.
- The run.sh script contains examples of how to reproduce the experiment results.
- All scripts can be run standalone. Make sure to provide the proper command-line arguments, as shown in the examples.
├── methods/ # Contains implementations of various OOD detection approaches.
├── models/ # Contains various model architectures such as:
├── utils/ # Contains utilities for data loading, inference, and
├── ood.py # Calculates the OOD detection performance of various methods.
├── std_trn.py # Implements standard training (ERM) of the model.
├── train.py # Trains a model using an additive angular margin loss.
└── run.sh # (Unix) example run scripts for reproducing experiment results.