Color Adaptation using Monotone Gradient Networks: Transporting the color palette from Source to Target.
This repository implements Monotone Gradient Networks (MGN) for solving optimal transport and generative modeling problems, based on the paper:
Chaudhari, S., Pranav, S., & Moura, J. M. (2023). Learning Gradients of Convex Functions with Monotone Gradient Networks. arXiv preprint arXiv:2301.10862.
The core problem addressed is learning the gradient of a convex function, i.e., a map $g(x) = \nabla \psi(x)$ where the potential $\psi$ is convex.
This mathematical property is fundamental in:
- Optimal Transport (OT): By Brenier's Theorem, the optimal transport map between two continuous probability measures (with quadratic cost) is the gradient of a strictly convex function.
- Generative Modeling: We can model a generative mapping pushing a latent distribution $\mu$ (e.g., Gaussian) to a target data distribution $\nu$, i.e., $g_* \mu = \nu$.
Standard neural networks do not guarantee that the learned map is the gradient of a convex function. MGN architectures are specifically designed to parameterize such maps.
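Concretely, if $g = \nabla \psi$ for a convex potential $\psi$, then $g$ is a monotone map:

$$\big(\nabla\psi(x) - \nabla\psi(y)\big)^\top (x - y) \ge 0 \quad \text{for all } x, y,$$

equivalently (for twice-differentiable $\psi$), the Hessian $\nabla^2 \psi(x)$ is positive semidefinite everywhere. This is the property the MGN architectures enforce by construction.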
We implement two variants of MGN in `src/monotone_grad_nets/models/`:
- M_MGN (Modular MGN): Uses a modular sum-structure with `log-cosh` smoothing and a PSD term ($V^T V$) to ensure monotonicity.
- C_MGN (Cascade MGN): Uses a cascading layer structure with shared weights to efficiently parameterize the map.
We also include Input Convex Neural Networks (ICNN) (I_CNN and I_CGN) for comparison.
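As a minimal sketch of the M_MGN idea (the class and parameter names below are illustrative and do not mirror the exact API in `src/monotone_grad_nets/models/`), one way to build such a map is as the gradient of the convex potential $\psi(x) = a^\top x + \tfrac{1}{2}\|Vx\|^2 + \sum_k \mathbf{1}^\top \log\cosh(W_k x + b_k)$:

```python
import torch
import torch.nn as nn

class MGNSketch(nn.Module):
    """Illustrative monotone gradient map (not the repository's exact M_MGN class).

    Implements g(x) = a + V^T V x + sum_k W_k^T tanh(W_k x + b_k), which is the
    gradient of psi(x) = a^T x + 0.5 * ||V x||^2 + sum_k 1^T log cosh(W_k x + b_k).
    """

    def __init__(self, dim: int, hidden: int = 32, n_modules: int = 3):
        super().__init__()
        self.a = nn.Parameter(torch.zeros(dim))
        self.V = nn.Parameter(torch.randn(hidden, dim) / dim**0.5)
        self.W = nn.ParameterList(
            [nn.Parameter(torch.randn(hidden, dim) / dim**0.5) for _ in range(n_modules)]
        )
        self.b = nn.ParameterList(
            [nn.Parameter(torch.zeros(hidden)) for _ in range(n_modules)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., dim) -> (..., dim)
        out = self.a + x @ (self.V.T @ self.V)                 # PSD linear term V^T V x
        for W_k, b_k in zip(self.W, self.b):
            out = out + torch.tanh(x @ W_k.T + b_k) @ W_k      # gradient of 1^T log cosh(W_k x + b_k)
        return out
```

Because the Jacobian of this map is $V^\top V + \sum_k W_k^\top \mathrm{diag}(\mathrm{sech}^2(W_k x + b_k))\, W_k \succeq 0$, monotonicity holds for any parameter values, with no constraints needed during training.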
We train these networks using various loss functions depending on the task (Gradient Fitting vs. Optimal Transport). These are implemented in `src/monotone_grad_nets/trainers/trainer.py`.
When the true gradient $\nabla f$ is known (Gradient Fitting), we regress on it directly:
- L1 Loss: $\mathbb{E} [\| g(x) - \nabla f(x) \|]$
- MSE Loss: $\mathbb{E} [\| g(x) - \nabla f(x) \|^2]$
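For example, these gradient-fitting losses are only a few lines of PyTorch (a sketch; the actual trainer in `src/monotone_grad_nets/trainers/trainer.py` may be organized differently):

```python
import torch

def gradient_fitting_loss(g, grad_f, x, kind="mse"):
    """Regress g(x) onto a known gradient grad_f(x) (illustrative helper, not the repo's API)."""
    residual = g(x) - grad_f(x)                  # (batch, dim)
    if kind == "l1":
        return residual.norm(dim=-1).mean()      # E[ ||g(x) - grad_f(x)|| ]
    return residual.pow(2).sum(dim=-1).mean()    # E[ ||g(x) - grad_f(x)||^2 ]
```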
When we only have samples from the source $\mu$ and target $\nu$ distributions (Optimal Transport), we use sample-based objectives:
- Sinkhorn Divergence: A regularized approximation of the Wasserstein distance, computed using the `geomloss` library:
  $$\mathcal{L} = S_\epsilon(g_* \mu, \nu) + \lambda\, \mathcal{C}_{\text{transport}}(g),$$
  where
  $$\mathcal{C}_{\text{transport}}(g) = \mathbb{E}[\|x - g(x)\|^2]$$
  penalizes the displacement.
- Negative Log Likelihood (NLL) / KL Divergence: If fitting a map to a known target density (e.g., Gaussian), we maximize the likelihood of the mapped samples:
  $$\mathcal{L} = - \mathbb{E}_{x \sim \mu} \left[\log \nu(g(x)) + \log \det (\nabla g(x)) \right].$$
  Here, $\nabla g(x) = \nabla^2 \psi(x)$ is the Hessian of the potential. We compute $\log \det(\nabla g(x))$ efficiently using `torch.linalg.slogdet`.
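A hedged sketch of both sample-based objectives (the function names, `lam`, and the `blur` value are illustrative; it assumes `g` accepts unbatched points of shape `(dim,)` so that `torch.func.jacrev` can be vmapped over the batch):

```python
import torch
from geomloss import SamplesLoss   # pip install geomloss

# Debiased Sinkhorn divergence between point clouds (quadratic cost).
sinkhorn = SamplesLoss(loss="sinkhorn", p=2, blur=0.05)

def ot_loss(g, x_src, y_tgt, lam=0.1):
    """Sinkhorn divergence between mapped source and target, plus a displacement penalty."""
    gx = g(x_src)
    transport_cost = (x_src - gx).pow(2).sum(dim=-1).mean()   # E[ ||x - g(x)||^2 ]
    return sinkhorn(gx, y_tgt) + lam * transport_cost

def nll_loss(g, x_src):
    """NLL against a standard Gaussian target nu = N(0, I)."""
    jac = torch.vmap(torch.func.jacrev(g))(x_src)             # (batch, dim, dim): Hessians of the potential
    _, logdet = torch.linalg.slogdet(jac)                     # monotone map => det >= 0
    log_nu = -0.5 * g(x_src).pow(2).sum(dim=-1)               # log N(0, I), dropping the constant
    return -(log_nu + logdet).mean()
```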
The `notebooks/` directory contains experiments demonstrating the capabilities of MGN:
- `01_sanity_check_trivial.ipynb`: Basic validation of the models and training loop.
- `02_toy_2d_mapping.ipynb`: Visualization of 2D Optimal Transport maps (e.g., mapping a Gaussian to a Mixture of Gaussians). Displays the warped grid and particle transport.
- `03_color_adaptation_cmgn.ipynb` & `03_color_adaptation_mmgn.ipynb`: Color Transfer application. MGN learns to map the color palette of a source image to match a target image while preserving content structure.
- `04_mnist_generation.ipynb`: Generative Modeling. Training an MGN to map Gaussian noise to the MNIST digits manifold.
Here are some examples of Color Adaptation using Monotone Gradient Networks. The model learns to transport the color palette of a source image to a target image.
Figure 2: Color Transfer results using Modular MGN (M_MGN).
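In outline, the color-transfer recipe treats the RGB values of each image as a 3D point cloud, fits an MGN map from the source pixel distribution to the target pixel distribution (e.g., with the Sinkhorn objective above), and then applies the learned map to every source pixel. A hedged sketch (helper names are illustrative; see the `03_color_adaptation_*` notebooks for the actual pipeline):

```python
import torch

def image_to_points(img: torch.Tensor) -> torch.Tensor:
    """Flatten an (H, W, 3) image with values in [0, 1] into an (H*W, 3) point cloud."""
    return img.reshape(-1, 3)

def apply_color_map(g, src_img: torch.Tensor) -> torch.Tensor:
    """Push every source pixel's RGB value through the learned monotone map g."""
    points = image_to_points(src_img)
    with torch.no_grad():
        mapped = g(points).clamp(0.0, 1.0)
    return mapped.reshape(src_img.shape)
```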
@article{chaudhari2023learning,
title={Learning Gradients of Convex Functions with Monotone Gradient Networks},
author={Chaudhari, Shreyas and Pranav, Srinivasa and Moura, José MF},
journal={arXiv preprint arXiv:2301.10862},
year={2023}
}