This code accompanies the paper Stability Enhanced Gaussian Process Variational Autoencoders.
- Carl R Richardson (carl.richardson@eng.ox.ac.uk)
- Jichen Zhang (jichen.zhang@eng.ox.ac.uk)
- Ethan King (ethan.king@pnnl.gov)
- Ján Drgoňa (jdrgona1@jh.edu)
All code is written in Python, predominantly using PyTorch. PyTorch must be installed along with several other standard libraries such as NumPy and Matplotlib. For ease of use, please keep the same file structure as below.
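A quick way to confirm that the libraries named above are available is sketched below. It is a minimal sketch, not part of the repository: the helper name `missing_packages` is hypothetical, and the package list covers only what this README mentions, so the scripts may need further packages.

```python
import importlib.util

# Packages mentioned in this README (module name -> display name).
# Assumption: the actual scripts may import additional packages.
REQUIRED = {"torch": "PyTorch", "numpy": "NumPy", "matplotlib": "Matplotlib"}


def missing_packages(required=REQUIRED):
    """Return the display names of packages that cannot be found."""
    return [
        name
        for module, name in required.items()
        if importlib.util.find_spec(module) is None
    ]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All listed packages are installed.")
```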
The repository is organised as follows:
- Code
  - DataGeneration
    - DataGeneration_Particle.py: Script for generating video and trajectory data of the particle spiralling in the plane.
  - SEGP
    - SEGP.py: SEGP model class.
    - Train_GP.py: Script for training the SEGP.
    - Evaluate_GP.py: Script for plotting the figures used to evaluate the GP.
  - SEGP_VAE
    - VAE.py: VAE encoder and decoder classes.
    - Train_VAE.py: Script for training the VAE.
  - Utils
    - Utils.py: Script containing miscellaneous utility functions.
- Data: Directory for storing datasets.
- Models: Directory for storing models and associated files. Includes a pre-trained SEGP (trained directly on the trajectory data) and a pre-trained SEGP-VAE.
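If you are assembling the layout by hand, the directories can be created in one go. This is a minimal sketch assuming that Data and Models sit at the repository root next to Code; the helper `make_layout` is hypothetical and only creates empty folders, it does not fetch any of the files listed above.

```python
from pathlib import Path

# Directory layout as listed above (assumption: Data and Models live
# at the repository root, alongside Code).
LAYOUT = [
    "Code/DataGeneration",
    "Code/SEGP",
    "Code/SEGP_VAE",
    "Code/Utils",
    "Data",
    "Models",
]


def make_layout(root):
    """Create the expected directories under `root`; existing ones are kept."""
    root = Path(root)
    created = []
    for rel in LAYOUT:
        path = root / rel
        path.mkdir(parents=True, exist_ok=True)  # idempotent
        created.append(path)
    return created
```

Because `exist_ok=True` is used, running it on an existing clone leaves the current contents untouched.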
To run the code:
- Download the dependencies.
- Add the root directory to the relevant scripts.
- Generate data using DataGeneration_Particle.py, copying the data parameters specified in the paper. The data will be stored in the Data directory.
- Run Train_VAE.py with the parameter settings from the paper. The trained model will be stored in Models.
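One plausible way to handle the "add the root directory" step is sketched below. It is an assumption about how the scripts locate files, not the repository's own code: the helper `configure_root` is hypothetical, and it simply puts Code on `sys.path` and returns the Data and Models paths derived from your local root.

```python
import sys
from pathlib import Path


def configure_root(root):
    """Make modules under root/Code findable and return (Data, Models) paths.

    `root` is the local repository directory containing Code, Data and Models
    (hypothetical helper; adapt to how each script actually builds its paths).
    """
    root = Path(root).resolve()
    code_dir = str(root / "Code")
    if code_dir not in sys.path:
        sys.path.insert(0, code_dir)  # search Code first when importing
    return root / "Data", root / "Models"
```

For example, near the top of a script one might write `DATA_DIR, MODELS_DIR = configure_root("/path/to/repository")`, replacing the placeholder path with your own clone.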