This repository contains the flowtrain package, which provides tools for stochastic interpolation and for managing machine learning models. The source code is located in src/flowtrain.
First, make sure Python 3.12 is installed, for example through an environment manager such as conda:
conda create -n flowtrain-test python=3.12 -y
conda activate flowtrain-test
To install the package in editable mode, navigate to the project root (where setup.py is located) and run:
pip install -e .
This also installs StructuralGeo, a dependency for synthetic geological data generation, automatically via pip.
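A quick sanity check can confirm that the interpreter meets the Python 3.12 requirement and, after installation, that the package is importable. This is a minimal standard-library sketch; it assumes the import name is `flowtrain` (matching src/flowtrain), and `meets_min_python` / `is_importable` are hypothetical helpers, not part of the package.

```python
import importlib.util
import sys


def meets_min_python(version_info=sys.version_info, minimum=(3, 12)):
    """Return True if the given (or running) interpreter version is at least `minimum`."""
    return tuple(version_info[: len(minimum)]) >= tuple(minimum)


def is_importable(module_name):
    """Return True if `module_name` can be located on the current import path."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    print("Python >= 3.12:", meets_min_python())
    # Assumption: the editable install exposes the package as `flowtrain`.
    print("flowtrain importable:", is_importable("flowtrain"))
```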
The project/ directory contains code and supporting files for training and evaluating flow-based models on 3D StructuralGeo data. These models are designed for stochastic interpolation using flow-matching techniques.
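This README does not spell out the exact interpolant, so the following only illustrates the general flow-matching idea. It assumes the common linear path between a Gaussian noise sample x0 (at t = 0) and a data sample x1 (at t = 1): a network is trained to regress the path's velocity, and samples are generated by integrating the learned velocity field starting from noise. Plain Python lists stand in for 3D tensors, and all function names are hypothetical. The default time range mirrors the t ∈ [1e-4, 0.9999] training setting listed in the model configuration.

```python
import random


def interpolate(x0, x1, t):
    """Linear interpolant x_t = (1 - t) * x0 + t * x1, applied elementwise."""
    return [(1.0 - t) * a + t * b for a, b in zip(x0, x1)]


def target_velocity(x0, x1):
    """Velocity of the linear path, d x_t / dt = x1 - x0 (the regression target)."""
    return [b - a for a, b in zip(x0, x1)]


def training_pair(x1, t_min=1e-4, t_max=0.9999, rng=random):
    """Draw noise x0 and a time t; return (x_t, t, velocity target) for one training example."""
    x0 = [rng.gauss(0.0, 1.0) for _ in x1]
    t = rng.uniform(t_min, t_max)
    return interpolate(x0, x1, t), t, target_velocity(x0, x1)


def euler_sample(velocity_fn, x0, n_steps=100):
    """Integrate dx/dt = velocity_fn(x, t) from t = 0 to t = 1 with forward Euler steps."""
    x = list(x0)
    dt = 1.0 / n_steps
    for i in range(n_steps):
        v = velocity_fn(x, i * dt)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x
```

In a real model, `velocity_fn` would be the trained network and the lists would be 64³ voxel tensors; the loss would be a mean-squared error between the network output at (x_t, t) and the target velocity.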
The codebase is built on:
- StructuralGeo for synthetic geological data generation
- PyTorch for deep learning
- PyTorch Lightning for cleaner training loops
- Weights & Biases (wandb) for experiment tracking
Pretrained models for both unconditional and conditional generation at 64³ resolution are available. The weights are downloaded automatically on first use if they are not found locally, and are stored in the project/*/demo_model/ directory.
If desired, the weights can also be downloaded manually from the v1.0.0 GitHub release:
- conditional-weights.ckpt
- unconditional-weights.ckpt
The training run is available on WandB.
Model configuration:
- Base channels: 48
- Channel multipliers: (1, 2, 2, 3, 4)
- Time embeddings: Learned Fourier (1024 dim, bandwidth 1000)
- Attention: Enabled at all scales, 4 heads, dim_head = 32
- Conditioning: ATb embedding with ATb mixing at every resolution
- Training: LR = 1e-3, EMA = 0.9995, t ∈ [1e-4, 0.9999], batch=8
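To make the configuration concrete, the sketch below derives per-level sizes from the listed values. It assumes a standard U-Net in which each level's width is the base channel count times its multiplier, the spatial grid is halved between consecutive levels, and attention projects to heads × dim_head features (as in many common implementations). It also shows the standard exponential moving average (EMA) weight update with decay 0.9995. The class and function names are hypothetical, not taken from flowtrain.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class UNetSketchConfig:
    """Values copied from the configuration list; the derived sizes rest on stated assumptions."""
    base_channels: int = 48
    channel_mults: tuple = (1, 2, 2, 3, 4)
    resolution: int = 64
    heads: int = 4
    dim_head: int = 32

    def channels_per_level(self):
        # Each level uses base_channels * multiplier feature maps.
        return [self.base_channels * m for m in self.channel_mults]

    def resolutions(self):
        # Assumption: the 3D grid is halved between consecutive levels (64 -> 32 -> ...).
        return [self.resolution // 2**i for i in range(len(self.channel_mults))]

    def attention_inner_dim(self):
        # Assumption: multi-head attention projects to heads * dim_head features.
        return self.heads * self.dim_head


def ema_update(ema_params, params, decay=0.9995):
    """One EMA step: ema <- decay * ema + (1 - decay) * current weights."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, params)]
```

Under these assumptions the five levels have 48, 96, 96, 144 and 192 channels at 64³, 32³, 16³, 8³ and 4³, and each attention layer works in a 4 × 32 = 128-dimensional space. A decay of 0.9995 averages the weights over roughly the last 1 / (1 - 0.9995) = 2000 optimizer steps.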
- Training:
project/geodata-3d-unconditional/train_unconditional.py
Training parameters can be edited via the get_config() function in the script; they are currently set to the values used to train the saved demo model. To train on multiple GPUs, use the --train-devices flag.
cd project/geodata-3d-unconditional
python model_train_inference.py --mode train --train-devices 0,1
- Inference demo: Use the main() function in the same script to run inference with pretrained weights. Optional flags include:
  - --mode: Set to train, inference, or both (default: inference)
  - --n-samples: Number of samples to generate (default: 8)
  - --batch-size: Batch size for inference (default: 1)
  - --seed: Random seed for reproducibility (default: 100)
  - --no-save-images: Disable saving visualization images (default: save images)
  - --infer-device: Device for inference, e.g., cuda or cpu (default: cpu)
  - --checkpoint_path: Path to a custom checkpoint file that overrides the pretrained weights (default: use pretrained weights)
The pretrained weights load automatically and are downloaded if not found locally; to load a custom training checkpoint instead, use the --checkpoint_path flag.
cd project/geodata-3d-unconditional
# Saves tensors + PNGs to project/samples/<project_name>/
python model_train_inference.py --mode inference --n-samples 8 --batch-size 2 --seed 100 --infer-device cuda
Conditional training and inference require an additional step: setting up the surface and borehole data from a randomly generated StructuralGeo streaming data sample.
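Conceptually, the conditioning data are sparse observations of a full 3D label volume: complete vertical columns (boreholes) plus the top layer (surface). The sketch below illustrates that extraction on a nested-list grid indexed grid[x][y][z]. The axis order, the choice of the last z index as the surface, and all function names are assumptions for illustration, not StructuralGeo's or flowtrain's API.

```python
def extract_borehole(grid, i, j):
    """Return the full vertical column of labels at horizontal position (i, j)."""
    return list(grid[i][j])


def extract_surface(grid, z_top=-1):
    """Return the 2D map of labels on the top layer (assumed to be the last z index)."""
    return [[column[z_top] for column in row] for row in grid]


def observation_mask(shape, boreholes, include_surface=True):
    """Build a 0/1 mask marking voxels observed by boreholes and, optionally, the surface."""
    nx, ny, nz = shape
    mask = [[[0] * nz for _ in range(ny)] for _ in range(nx)]
    for i, j in boreholes:
        for k in range(nz):
            mask[i][j][k] = 1  # a borehole observes every depth at (i, j)
    if include_surface:
        for i in range(nx):
            for j in range(ny):
                mask[i][j][nz - 1] = 1  # the surface observes the top layer everywhere
    return mask
```

The mask, together with the observed labels, is the kind of sparse input a conditional model can use to reconstruct the unobserved voxels.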
- Training:
model_train_sh_inference_cond.py
Training parameters can be adjusted via the get_config() function inside the script, which is set to the same hyperparameters that were used to train the provided pretrained conditional model.
cd project/geodata-3d-conditional
python model_train_sh_inference_cond.py
- Inference:
A Jupyter notebook, project/geodata-3d-conditional/inference_demo.ipynb, is provided to demonstrate generating conditional data, loading the saved weights, and running inference with the pretrained model. It also includes a probabilistic analysis using an ensemble of models, which uses the compressed data in the dikes_ptpack.tar.gz archive.
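The notebook's probabilistic analysis is not reproduced here, but the general idea of summarizing an ensemble of reconstructions can be sketched: count how often each label occurs at each voxel, then report the most likely label and a per-voxel uncertainty such as Shannon entropy. This illustration works on flattened label lists; the function names are hypothetical, and the notebook may use different statistics.

```python
import math
from collections import Counter


def voxel_probabilities(samples):
    """Per-voxel label frequencies across an ensemble of flattened label volumes."""
    n = len(samples)
    probs = []
    for labels in zip(*samples):  # iterate voxel by voxel across all ensemble members
        counts = Counter(labels)
        probs.append({label: c / n for label, c in counts.items()})
    return probs


def voxel_entropy(prob):
    """Shannon entropy (in nats) of one voxel's label distribution; 0 means full agreement."""
    return -sum(p * math.log(p) for p in prob.values() if p > 0)


def most_likely(probs):
    """Most frequent label per voxel, breaking ties toward the smallest label."""
    return [min(p, key=lambda lab: (-p[lab], lab)) for p in probs]
```

High-entropy voxels mark regions where the ensemble disagrees, typically far from the boreholes and surface observations.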
An automated Python script is also provided that generates synthetic geology, extracts borehole data, and produces reconstructions:
cd project/geodata-3d-conditional
python model_inference_experiments.py --n-samples 4 --n-scenarios 1
Available flags include:
- --device: Device for inference, e.g., cuda or cpu (default: cpu)
- --n-samples: Number of samples to generate per scenario (default: 1)
- --n-scenarios: Number of different geological scenarios to generate for sample reconstruction (default: 4)
- --use-ema: Use EMA weights for inference (default: True)
- --no-display: Disable displaying images during inference (default: display images)
- --checkpoint_path: Path to a custom checkpoint file that overrides the pretrained weights (default: use pretrained weights)
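For scripting experiments, it can help to see the documented flags as a parser definition. The sketch below is a hypothetical reconstruction using only the flags and defaults listed above, not the script's actual parser. Using argparse.BooleanOptionalAction for --use-ema is an assumption: it keeps the default of True and adds a --no-use-ema switch that the README does not document.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented flags of model_inference_experiments.py."""
    parser = argparse.ArgumentParser(description="Conditional inference experiments (sketch)")
    parser.add_argument("--device", default="cpu", help="Inference device, e.g. cuda or cpu")
    parser.add_argument("--n-samples", type=int, default=1, help="Samples per scenario")
    parser.add_argument("--n-scenarios", type=int, default=4, help="Number of geological scenarios")
    # Assumption: BooleanOptionalAction (Python 3.9+) keeps the True default and adds --no-use-ema.
    parser.add_argument("--use-ema", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--no-display", action="store_true", help="Do not display images")
    parser.add_argument("--checkpoint_path", default=None, help="Custom checkpoint to load")
    return parser
```

Under this reading, the example command with --n-samples 4 --n-scenarios 1 produces four reconstructions of a single generated scenario.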