# Backend-Agnostic Inverse 1D Burgers Solver via Tesseract (Viscosity Estimation with JAX and PyTorch PINNs)
## Overview

This project demonstrates pipeline-level automatic differentiation across frameworks using Tesseract. An inverse 1D viscous Burgers equation solver (inferring the viscosity coefficient via physics-informed neural networks) runs with either a JAX or a PyTorch PINN backend while the optimization code stays identical. When the PyTorch backend is selected, JAX gradients are computed through the Tesseract VJP interface, enabling cross-framework automatic differentiation.
Key implementations:

- `apply`, `vector_jacobian_product`, and `jacobian_vector_product` endpoints for both the JAX and PyTorch PINNs
- A JAX optimizer computing gradients through PyTorch models via Tesseract's VJP endpoint
- A backend-agnostic inverse problem pipeline with swappable implementations (PyTorch/JAX)
Given noisy observations of the solution of the 1D viscous Burgers equation, infer the unknown viscosity parameter $\nu$:

$$\frac{\partial u}{\partial t} + u \frac{\partial u}{\partial x} = \nu \frac{\partial^2 u}{\partial x^2}$$

where:

- $u(x, t)$ is the velocity field on $[0, 1] \times [0, T]$
- $\nu$ is the kinematic viscosity (the inferred parameter)
- Initial condition: $u(x, 0) = \sin(2\pi x)$
- Boundary conditions: periodic on $[0, 1]$
Synthetic observations are generated using a heat-equation approximation of Burgers, valid for small viscosity.
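The generator itself is not shown in this README; the following is a minimal sketch of how such observations could be produced, assuming the approximation drops the nonlinear advection term so that $u_t = \nu u_{xx}$ with the sine initial condition has a closed-form periodic solution. The helper names (`heat_solution`, `make_observations`) and the choice $T = 1$ are illustrative, not the demo's actual code:

```python
import numpy as np

def heat_solution(x, t, nu):
    """Closed-form solution of u_t = nu * u_xx with u(x, 0) = sin(2*pi*x)
    and periodic BCs: the small-viscosity approximation of Burgers
    obtained by dropping the nonlinear advection term."""
    return np.sin(2 * np.pi * x) * np.exp(-4 * np.pi**2 * nu * t)

def make_observations(n=256, nu_true=0.01, noise_std=0.05, T=1.0, seed=0):
    """Sample noisy observations on [0, 1] x [0, T]."""
    rng = np.random.default_rng(seed)
    x = rng.uniform(0.0, 1.0, n)
    t = rng.uniform(0.0, T, n)
    u = heat_solution(x, t, nu_true)
    return x, t, u + noise_std * rng.normal(size=n)
```

Because the exact solution is available in closed form, the inferred viscosity can later be checked against `nu_true` directly.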
A physics-informed neural network (PINN) minimizes a combined loss function:

$$\mathcal{L} = \mathcal{L}_{\text{data}} + \mathcal{L}_{\text{physics}} + \mathcal{L}_{\text{IC}} + \mathcal{L}_{\text{BC}}$$

where:

- $\mathcal{L}_{\text{data}}$: mean squared error between predictions and observations
- $\mathcal{L}_{\text{physics}}$: PDE residual at collocation points
- $\mathcal{L}_{\text{IC}}$: initial condition violation
- $\mathcal{L}_{\text{BC}}$: boundary condition violation
Note: In this demo, synthetic data are generated from a heat‑equation approximation of Burgers for small viscosity, which remains smooth and does not exhibit shock formation. The focus is on inverse viscosity estimation and pipeline‑level autodiff, not on resolving nonlinear shock dynamics.
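The four loss terms above can be sketched as a weighted sum of mean-squared errors. The function name and unit weights below are illustrative, not the demo's actual values:

```python
import numpy as np

def pinn_loss(u_pred, u_obs, pde_residual, ic_residual, bc_residual,
              w_phys=1.0, w_ic=1.0, w_bc=1.0):
    """Combined PINN loss; each *_residual is already (prediction - target)."""
    l_data = np.mean((u_pred - u_obs) ** 2)
    l_phys = np.mean(pde_residual ** 2)  # u_t + u*u_x - nu*u_xx at collocation points
    l_ic = np.mean(ic_residual ** 2)     # u(x, 0) - sin(2*pi*x)
    l_bc = np.mean(bc_residual ** 2)     # u(0, t) - u(1, t) for periodic BCs
    return l_data + w_phys * l_phys + w_ic * l_ic + w_bc * l_bc
```

With all residuals at zero the loss vanishes, which is the fixed point the optimizer drives toward.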
The PINN uses Fourier feature encoding to mitigate spectral bias:

```
Input (x, t) ∈ ℝ²
        ↓
Fourier encoding: [x, t, sin(x·B_x), cos(x·B_x), sin(t·B_t), cos(t·B_t)]
        ↓
MLP: 130 → 64 → 64 → 64 → 1 (tanh activations)
        ↓
Output: u(x, t)
```
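A minimal NumPy sketch of the encoding, assuming 32 random frequencies per coordinate (which is what makes the MLP input 130 = 2 + 4·32 wide); the frequency scale of 10.0 is an illustrative choice, not necessarily the demo's:

```python
import numpy as np

def fourier_encode(x, t, B_x, B_t):
    """[x, t, sin(x*B_x), cos(x*B_x), sin(t*B_t), cos(t*B_t)] for scalar x, t."""
    return np.concatenate([
        np.atleast_1d(x), np.atleast_1d(t),
        np.sin(x * B_x), np.cos(x * B_x),
        np.sin(t * B_t), np.cos(t * B_t),
    ])

# 32 fixed random frequencies per coordinate: 2 + 4 * 32 = 130 features,
# matching the 130-wide MLP input layer above.
rng = np.random.default_rng(0)
B_x = rng.normal(size=32) * 10.0
B_t = rng.normal(size=32) * 10.0
z = fourier_encode(0.3, 0.5, B_x, B_t)
```

The frequency matrices are drawn once at initialization and kept fixed during training, following the Fourier-features recipe cited below.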
Derivatives are computed with `jax.grad` for the JAX backend and `torch.autograd.grad` for the PyTorch backend.
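For the JAX backend, the derivative stack can be sketched with nested `jax.grad` calls on a scalar function. The closed-form `u` below is only a stand-in for the PINN forward pass, and `pde_residual` is a hypothetical helper name:

```python
import jax
import jax.numpy as jnp

def u(x, t):
    # Stand-in for the PINN forward pass (scalar in, scalar out);
    # the real model is the Fourier-feature MLP described above.
    return jnp.sin(2 * jnp.pi * x) * jnp.exp(-4 * jnp.pi**2 * 0.01 * t)

u_x = jax.grad(u, argnums=0)     # du/dx
u_t = jax.grad(u, argnums=1)     # du/dt
u_xx = jax.grad(u_x, argnums=0)  # d2u/dx2 via nested reverse-mode grad

def pde_residual(x, t, nu):
    # Viscous Burgers residual: u_t + u * u_x - nu * u_xx
    return u_t(x, t) + u(x, t) * u_x(x, t) - nu * u_xx(x, t)
```

The PyTorch backend computes the same quantities with `torch.autograd.grad(..., create_graph=True)` so that second derivatives remain differentiable.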
Both `pinn_jax` and `pinn_pytorch` implement:

- `apply(inputs)`: forward pass returning `u_pred`, `u_x`, `u_t`, `u_xx`
- `vector_jacobian_product(...)`: reverse-mode AD for gradient computation
- `jacobian_vector_product(...)`: forward-mode AD for sensitivity analysis (not used in this project)

Input/output schemas use Tesseract's `Differentiable[Array[...]]` annotations to declare which fields participate in autodiff.
The optimization loop in `inverse_problem.py` demonstrates Tesseract-mediated gradient computation:

```python
# Load backend (JAX or PyTorch)
pinn = Tesseract.from_image("pinn_jax")  # or "pinn_pytorch"

# Define loss using Tesseract apply
def compute_loss(viscosity, params, ...):
    result = apply_tesseract(pinn, {"x": x, "t": t, "params_flat": params})
    # ... compute data + physics + IC + BC losses
    return total_loss

# System-level gradients computed via Tesseract VJP (regardless of backend)
grad_visc = jax.grad(compute_loss, argnums=0)    # ∂L/∂ν
grad_params = jax.grad(compute_loss, argnums=1)  # ∂L/∂params

# When jax.grad is called, it triggers Tesseract's VJP endpoint
# For PyTorch backend: Tesseract VJP internally uses torch.autograd.grad
# For JAX backend: Tesseract VJP internally uses jax.grad
v_grad = grad_visc(viscosity, params, ...)
p_grad = grad_params(viscosity, params, ...)
```

Key point: the system-level gradients ($\partial \mathcal{L}/\partial \nu$ and $\partial \mathcal{L}/\partial \text{params}$) are computed through Tesseract's `vector_jacobian_product` endpoint for both backends. The backend selection only determines which autograd implementation Tesseract uses internally for its VJP computation.
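Conceptually, this bridge behaves like a `jax.custom_vjp` whose backward rule delegates to the container's VJP endpoint. The sketch below is a simplified illustration of that mechanism, not the actual `tesseract-jax` internals; `tanh` stands in for the model and the backward rule is written out where the real system would call the endpoint:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def apply_model(x):
    # Stands in for Tesseract's `apply` endpoint
    return jnp.tanh(x)

def apply_fwd(x):
    # Save the primal input as residual for the backward pass
    return apply_model(x), x

def apply_bwd(x, cotangent):
    # In the demo this computation happens inside the container
    # (torch.autograd.grad or jax.grad); here it is written directly.
    return ((1.0 - jnp.tanh(x) ** 2) * cotangent,)

apply_model.defvjp(apply_fwd, apply_bwd)

# jax.grad now routes the backward pass through apply_bwd,
# just as it routes through Tesseract's VJP endpoint in the demo.
g = jax.grad(lambda x: jnp.sum(apply_model(x)))(jnp.array([0.0, 1.0]))
```

Because the backward rule is an opaque callable from JAX's point of view, it makes no difference whether the cotangents are produced by JAX, PyTorch, or an HTTP call into a container.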
```
tesseract-pinn-inverse-burgers/
├── inverse_problem.py    # CLI demo comparing JAX/PyTorch backends
├── app.py                # Streamlit interactive interface
├── buildall.sh           # Builds Docker containers for both backends
├── pyproject.toml
└── tesseracts/
    ├── pinn_jax/
    │   ├── tesseract_api.py    # JAX/Equinox PINN with Tesseract endpoints
    │   ├── tesseract_config.yaml
    │   └── tesseract_requirements.txt
    └── pinn_pytorch/
        ├── tesseract_api.py    # PyTorch PINN with Tesseract endpoints
        ├── tesseract_config.yaml
        └── tesseract_requirements.txt
```
Requirements: Python ≥3.10 and Docker. Optionally, uv for environment management.
```shell
# Clone repository
git clone https://github.com/julian-8897/tesseract-pinn-inverse-burgers.git
cd tesseract-pinn-inverse-burgers

# Option A — using uv (recommended if you use the uv workflow)
# Install uv if missing: pip install uv
uv venv
source .venv/bin/activate
uv pip install -e .

# Option B — python venv
python -m venv .venv
source .venv/bin/activate
pip install -e .

# Build Tesseract containers (requires Docker running)
./buildall.sh

# Verify built images
docker images | grep pinn
```

```shell
# Compare both backends
python inverse_problem.py --backend both --epochs 100

# Single backend
python inverse_problem.py --backend jax --epochs 50
python inverse_problem.py --backend pytorch --epochs 50

# Launch the interactive interface
streamlit run app.py
```

The Streamlit app provides:
- Adjustable hyperparameters (viscosity, noise, learning rate)
- Real-time training visualization
- Gradient flow inspector (Tesseract API call statistics)
- Solution comparison (PINN vs analytical)
The PINN-inferred viscosity converges close to the ground truth for both backends. Figures (images in the repository):

- PINN solution comparison (JAX vs PyTorch)
- PINN vs analytical solution (JAX)
- PINN vs analytical solution (PyTorch)
Tesseract Documentation:
- Tesseract Core — Main repository and CLI
- Tesseract-JAX — JAX integration layer
- Creating Tesseracts — Implementation guide
- Differentiable Programming — VJP/JVP concepts
Related publications on PINNs:
- Raissi, M., Perdikaris, P., & Karniadakis, G. E., "Physics-informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear partial differential equations", Journal of Computational Physics 378 (2019): 686-707
- Tancik, M., Srinivasan, P. P., Mildenhall, B., Fridovich-Keil, S., Raghavan, N., Singhal, U., Ramamoorthi, R., & Ng, R., "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains", NeurIPS 2020
| Component | Version | Notes |
|---|---|---|
| tesseract-core | 1.2.0 | Runtime and CLI |
| tesseract-jax | 0.2.3 | JAX integration |
| Python | ≥3.10 | Type hints required |
| Docker | latest | Container execution |
| JAX | 0.8.2 | CPU backend |
| PyTorch | 2.9.1 | PyTorch backend |
| Equinox | 0.13.2 | JAX PINN modules |
| Optax | 0.2.6 | JAX optimizer |
Tested platform: macOS (Apple Silicon)
Licensed under Apache License 2.0.


