# 📈 Monotone Gradient Networks (MGN)

*Figure 1: Color adaptation with a Monotone Gradient Network, transporting the color palette of a source image to a target image.*

This repository implements Monotone Gradient Networks (MGN) for solving optimal transport and generative modeling problems, based on the paper:

> Chaudhari, S., Pranav, S., & Moura, J. M. (2023). *Learning Gradients of Convex Functions with Monotone Gradient Networks.* arXiv preprint [arXiv:2301.10862](https://arxiv.org/abs/2301.10862).

## 🧮 Problem Description & Math

The core problem addressed is learning the gradient of a convex function $\psi: \mathbb{R}^d \to \mathbb{R}$, denoted $g(x) = \nabla \psi(x)$. Since $\psi$ is convex, its gradient $g$ is a monotone map:

$$\langle g(x) - g(y), x - y \rangle \ge 0, \quad \forall x, y \in \mathbb{R}^d$$

This property is fundamental in two settings:

1. **Optimal Transport (OT):** by Brenier's theorem, the optimal transport map between two continuous probability measures (under quadratic cost) is the gradient of a strictly convex function.
2. **Generative Modeling:** a generative map that pushes a latent distribution $\mu$ (e.g., Gaussian) forward to a target data distribution $\nu$ can be modeled as $g_* \mu = \nu$.

## 🧠 Leveraging MGN

Standard neural networks do not guarantee that the learned map is the gradient of a convex function. MGN architectures are specifically designed to parameterize such maps $g(x) = \nabla \psi(x)$ by construction.

We implement two variants of MGN in `src/monotone_grad_nets/models/`:

- **M_MGN (Modular MGN):** uses a modular sum structure with log-cosh smoothing and a PSD term ($V^T V$) to ensure monotonicity.
- **C_MGN (Cascade MGN):** uses a cascading layer structure with shared weights to parameterize the map efficiently.

We also include Input Convex Neural Network (ICNN) baselines, `I_CNN` and `I_CGN`, for comparison.
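
To make the construction concrete, here is a minimal sketch of a modular monotone gradient map in plain PyTorch. This is illustrative, not the repository's exact `M_MGN` class: the class name, module count, and layer sizes are assumptions. It outputs the gradient of the implicit potential $\psi(x) = \tfrac{1}{2}\|Ax\|^2 + c^\top x + \sum_k \mathbf{1}^\top \log\cosh(V_k x + b_k)$, whose Jacobian $A^\top A + \sum_k V_k^\top \mathrm{diag}(\mathrm{sech}^2(V_k x + b_k))\, V_k$ is PSD, so the map is monotone by construction.

```python
import torch
import torch.nn as nn

class ModularMonotoneGradient(nn.Module):
    """Sketch of a modular monotone gradient map (illustrative, not the repo's M_MGN).

    Implicit potential: psi(x) = 0.5*||A x||^2 + c^T x + sum_k 1^T logcosh(V_k x + b_k)
    Output:             g(x)   = A^T A x + c + sum_k V_k^T tanh(V_k x + b_k)
    """

    def __init__(self, dim: int, n_modules: int = 4, hidden: int = 32):
        super().__init__()
        self.A = nn.Parameter(torch.randn(dim, dim) / dim ** 0.5)  # PSD term A^T A
        self.c = nn.Parameter(torch.zeros(dim))
        self.V = nn.ParameterList(
            nn.Parameter(torch.randn(hidden, dim) / dim ** 0.5) for _ in range(n_modules)
        )
        self.b = nn.ParameterList(
            nn.Parameter(torch.zeros(hidden)) for _ in range(n_modules)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (batch, dim)
        # Gradient of the quadratic part: A^T A x + c (A^T A is symmetric).
        g = x @ (self.A.T @ self.A) + self.c
        for V, b in zip(self.V, self.b):
            # tanh is the derivative of logcosh, so each module V^T tanh(V x + b)
            # is itself the gradient of the convex term 1^T logcosh(V x + b).
            g = g + torch.tanh(x @ V.T + b) @ V
        return g
```

A quick sanity check that the defining monotonicity property $\langle g(x) - g(y), x - y \rangle \ge 0$ holds:

```python
torch.manual_seed(0)
net = ModularMonotoneGradient(dim=2)
x, y = torch.randn(128, 2), torch.randn(128, 2)
inner = ((net(x) - net(y)) * (x - y)).sum(dim=1)
assert (inner >= -1e-6).all()  # non-negative up to float round-off
```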

## 📉 Losses

We train these networks using different loss functions depending on the task (gradient fitting vs. optimal transport). These are implemented in `src/monotone_grad_nets/trainers/trainer.py`.

### 1. 🎯 Gradient Fitting

When the true gradient $\nabla f(x)$ is known, we use regression:

- **L1 loss:** $\mathbb{E}\left[\|g(x) - \nabla f(x)\|_1\right]$
- **MSE loss:** $\mathbb{E}\left[\|g(x) - \nabla f(x)\|_2^2\right]$
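
As an illustration (not the repository's trainer), a minimal regression loop fitting the sketch network above to a known gradient; here $f(x) = \tfrac{1}{2}\|x\|^2$, so the target map is the identity, $\nabla f(x) = x$:

```python
import torch

net = ModularMonotoneGradient(dim=2)              # sketch class from above
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

for step in range(2000):
    x = torch.randn(256, 2)
    target = x                                    # nabla f(x) = x for f(x) = 0.5*||x||^2
    loss = (net(x) - target).pow(2).sum(dim=1).mean()  # the MSE loss above
    opt.zero_grad()
    loss.backward()
    opt.step()
```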

### 2. 🚚 Optimal Transport & Generative Modeling

When we only have samples from the source $\mu$ and target $\nu$ distributions, we minimize:

- **Sinkhorn divergence:** a regularized approximation of the Wasserstein distance, computed using the `geomloss` library (see the first sketch after this list):

  $$\mathcal{L} = S_\epsilon(g_* \mu, \nu) + \lambda\, \mathcal{C}_{\text{transport}}(g)$$

  where $\mathcal{C}_{\text{transport}}(g) = \mathbb{E}_{x \sim \mu}\left[\|x - g(x)\|^2\right]$ penalizes the displacement.
- **Negative log-likelihood (NLL) / KL divergence:** when fitting a map to a known target density (e.g., Gaussian), we maximize the likelihood of the mapped samples (see the second sketch below):

  $$\mathcal{L} = - \mathbb{E}_{x \sim \mu} \left[\log \nu(g(x)) + \log \det \nabla g(x) \right]$$

  Here, $\nabla g(x) = \nabla^2 \psi(x)$ is the Hessian of the potential, and $\log \det \nabla g(x)$ is computed efficiently using `torch.linalg.slogdet`.
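
For concreteness, a sketch of the Sinkhorn objective using `geomloss` (the `SamplesLoss` API is real; the weight `lam`, the `blur` value, and the function name `ot_loss` are illustrative assumptions):

```python
import torch
from geomloss import SamplesLoss

sinkhorn = SamplesLoss(loss="sinkhorn", p=2, blur=0.05)  # S_eps; eps set via blur

def ot_loss(net, x_src, y_tgt, lam=0.1):
    """Sinkhorn divergence between pushed-forward source samples and target
    samples, plus the displacement penalty E[||x - g(x)||^2]."""
    gx = net(x_src)
    transport_cost = (x_src - gx).pow(2).sum(dim=1).mean()
    return sinkhorn(gx, y_tgt) + lam * transport_cost
```

And a sketch of the NLL objective, assuming a standard Gaussian target density $\nu$ and PyTorch >= 2.0 for `torch.func`; the per-sample Jacobian strategy here is an assumption, only `torch.linalg.slogdet` is named by the repository:

```python
import math
import torch

def nll_loss(net, x):
    """NLL of mapped samples under a standard Gaussian target density (sketch)."""
    gx = net(x)
    d = gx.shape[1]
    # log nu(g(x)) for a standard Gaussian: -0.5*||g(x)||^2 - (d/2) log(2*pi)
    log_nu = -0.5 * gx.pow(2).sum(dim=1) - 0.5 * d * math.log(2 * math.pi)
    # Per-sample Jacobian nabla g(x) = Hessian of psi (PSD, so logdet is real).
    single = lambda xi: net(xi.unsqueeze(0)).squeeze(0)
    jac = torch.func.vmap(torch.func.jacrev(single))(x)  # (batch, d, d)
    _, logdet = torch.linalg.slogdet(jac)
    return -(log_nu + logdet).mean()
```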

## 📓 Notebooks & Results

The `notebooks/` directory contains experiments demonstrating the capabilities of MGN:

- `01_sanity_check_trivial.ipynb`: basic validation of the models and training loop.
- `02_toy_2d_mapping.ipynb`: visualization of 2D optimal transport maps (e.g., mapping a Gaussian to a mixture of Gaussians); displays the warped grid and particle transport.
- `03_color_adaptation_cmgn.ipynb` & `03_color_adaptation_mmgn.ipynb`: color transfer application; MGN learns to map the color palette of a source image to match a target image while preserving content structure.
- `04_mnist_generation.ipynb`: generative modeling; training an MGN to map Gaussian noise to the MNIST digit manifold.

## 🖼️ Visual Results

Here are some examples of Color Adaptation using Monotone Gradient Networks. The model learns to transport the color palette of a source image to a target image.

### M_MGN Results

*Figure 2: Color transfer results using the Modular MGN (M_MGN).*

## 📚 Citation

@article{chaudhari2023learning,
  title={Learning Gradients of Convex Functions with Monotone Gradient Networks},
  author={Chaudhari, Shreyas and Pranav, Srinivasa and Moura, José MF},
  journal={arXiv preprint arXiv:2301.10862},
  year={2023}
}
