Skip to content

Commit 45c15ef

Browse files
authored
Merge pull request #6 from andrinr/diffable_render
feat: Differentiable Rendering / Scene fiting
2 parents 5113de7 + df12b20 commit 45c15ef

56 files changed

Lines changed: 5159 additions & 1567 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.vscode/settings.json

Lines changed: 0 additions & 7 deletions
This file was deleted.

LICENSE

Lines changed: 64 additions & 635 deletions
Large diffs are not rendered by default.

README.md

Lines changed: 15 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,21 @@ Differentiable SDF primitives, transformations, and constraint system built with
55
> [!WARNING]
66
> The API is not stable. Expect breaking changes.
77
8-
![primitives](examples/ior.png)
8+
![primitives](examples/assets/primitives.png)
99

10+
---
11+
12+
## Features
13+
14+
- **SDF primitives** — sphere, box, capsule, cylinder, torus, and more
15+
- **Boolean ops** — union, intersection, subtraction with smooth blending
16+
- **Transforms** — translate, rotate, scale, mirror, repeat
17+
- **Raymarcher** — sphere-tracing renderer with materials, lighting, refraction, and anti-aliasing
18+
- **Differentiable rendering** — gradients flow through the full render pipeline via JAX
19+
- **Constraint system** — geometric constraints (distance, angle, coincident) with Riemannian gradient descent and Newton projection onto the constraint manifold
20+
- **JAX-native** — every scene is a pure function; `jit`, `grad`, and `vmap` work out of the box
1021

22+
![primitives](examples/assets/constrained_optim.png)
1123
---
1224

1325
## Development install
@@ -37,125 +49,12 @@ quarto preview # serve locally at localhost:4321
3749

3850
---
3951

40-
## SDF example
41-
42-
Build a scene from primitives, boolean ops, and transforms:
43-
44-
```python
45-
import jax.numpy as jnp
46-
from jaxcad.sdf import Sphere, Box, Capsule, Cylinder, Torus, Union, Translate
47-
48-
sphere = Translate(Sphere(radius=0.6), offset=jnp.array([-1.0, 0.0, 0.0]))
49-
box = Translate(Box(size=[0.7, 0.7, 1.0]), offset=jnp.array([0.0, 0.0, 0.8]))
50-
capsule = Translate(Capsule(radius=0.3, height=1.3), offset=jnp.array([1.0, 0.0, 0.0]))
51-
52-
scene = Union((sphere, box, capsule), smoothness=0.1)
53-
54-
# Evaluate the SDF at any point
55-
p = jnp.array([0.5, 0.0, 0.0])
56-
print(scene(p)) # signed distance from p to the surface
57-
```
58-
59-
Every node is a pure JAX function — `jax.grad`, `jax.jit`, and `jax.vmap` work directly on the scene.
60-
61-
---
62-
63-
## Rendering example
64-
65-
Assign materials to primitives and render with the sphere-tracing raymarcher.
66-
`background_color`, `refract_steps`, and `ior` are all new in this release.
67-
68-
```python
69-
import jax.numpy as jnp
70-
from jaxcad.render import raymarch, Material
71-
from jaxcad.sdf.primitives import Sphere
72-
from jaxcad.sdf.boolean import Union
73-
from jaxcad.sdf.transforms import Translate
74-
75-
# Glass sphere (ior=1.5) in front of two coloured spheres
76-
glass = Sphere(
77-
radius=1.0,
78-
material=Material(color=[0.92, 0.97, 1.0], roughness=0.05, opacity=0.04, ior=1.5),
79-
)
80-
red = Translate(Sphere(radius=0.65, material=Material(color=[0.93, 0.26, 0.22])),
81-
offset=jnp.array([-1.1, 0.5, -3.0]))
82-
green = Translate(Sphere(radius=0.65, material=Material(color=[0.05, 0.72, 0.50])),
83-
offset=jnp.array([ 1.1, -0.5, -3.0]))
84-
scene = Union((glass, red, green), smoothness=0.0)
85-
86-
image = raymarch(
87-
scene,
88-
camera_pos=jnp.array([0.0, 0.5, 5.5]),
89-
resolution=(400, 400),
90-
background_color=jnp.array([0.07, 0.09, 0.16]), # dark night sky
91-
refract_steps=48, # two-bounce Snell's-law refraction
92-
aa_samples=2,
93-
)
94-
# image is a (400, 400, 3) float32 numpy array
95-
```
96-
97-
---
98-
99-
## Constraint example — Riemannian gradient descent
100-
101-
Move a point along a constraint manifold using Riemannian gradient descent: gradient steps stay on the tangent plane and a Newton projection snaps back to the manifold after each step.
102-
103-
```python
104-
import jax
105-
import jax.numpy as jnp
106-
import optax
107-
108-
from jaxcad.constraints import (
109-
DistanceConstraint, Vector,
110-
null_space, make_manifold_projection,
111-
)
112-
from jaxcad.extraction import extract_parameters
113-
114-
# Constrain p to lie on the sphere |p| = 2
115-
anchor = Vector(jnp.array([0.0, 0.0, 0.0]))
116-
p = Vector(jnp.array([2.0, 0.0, 0.0]), free=True, name="p")
117-
DistanceConstraint(anchor, p, distance=2.0)
118-
119-
free_params, _, metadata = extract_parameters(p)
120-
target = jnp.array([1.0, 1.5, 0.0])
121-
122-
def objective(params):
123-
return jnp.sum((params["p"] - target) ** 2)
124-
125-
value_and_grad = jax.value_and_grad(objective)
126-
127-
def riemannian_grad(params):
128-
"""Project gradient onto the tangent plane at the current point."""
129-
N = null_space(params, metadata) # tangent-space basis (relinearized)
130-
loss, g = value_and_grad(params)
131-
return loss, N @ (g @ N) # Riemannian gradient
132-
133-
# Riemannian GD: tangent-plane steps + manifold projection after each update
134-
optimizer = optax.chain(
135-
optax.sgd(0.15),
136-
make_manifold_projection(metadata), # Newton snap-back onto |p|=2
137-
)
138-
139-
params = free_params
140-
state = optimizer.init(params)
141-
for step in range(20):
142-
loss, g = riemannian_grad(params)
143-
updates, state = optimizer.update(g, state, params)
144-
params = optax.apply_updates(params, updates)
145-
146-
print(params["p"]) # [1.109, 1.664, 0.] — optimal point on |p|=2 closest to target
147-
```
148-
149-
`null_space` recomputes the constraint Jacobian at the current point each step, so the gradient is always projected onto the correct tangent plane. `make_manifold_projection` chains as a standard optax transform and works with any base optimizer.
150-
151-
---
152-
15352
Inspired by [Fidget](https://www.mattkeeter.com/projects/fidget/) and [Inigo Quilez's distance functions](https://iquilezles.org/articles/distfunctions/).
15453

15554
---
15655

157-
![primitives](examples/thingy.png)
56+
![primitives](examples/assets/thingy.png)
15857

15958
## License
16059

161-
[GNU Affero General Public License v3.0](LICENSE) — free for open source use; commercial use requires a separate license. Contact the authors if you want to use jaxcad in a proprietary product.
60+
[Elastic License 2.0](LICENSE) — free for personal, research, and internal business use. Offering jaxcad as a hosted or managed service requires a commercial license. Contact [andrin.rehmann@simulation.science](mailto:andrin.rehmann@simulation.science) for commercial enquiries.

_quarto.yml

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,13 @@ quartodoc:
4747
- extraction.extract_parameters
4848
- title: Functionalize
4949
contents:
50-
- sdf.functionalize.functionalize
50+
- functionalize.functionalize
51+
- render.functionalize.functionalize_render
52+
- title: Parametrization
53+
contents:
54+
- parametrization.to_normalized
55+
- parametrization.from_normalized
56+
- parametrization.compute_param_scales
5157
- title: Constraints — types
5258
contents:
5359
- constraints.types.base.Constraint
@@ -61,18 +67,12 @@ quartodoc:
6167
- constraints.solve.project_to_manifold
6268
- constraints.solve.constraint_residuals
6369
- constraints.solve.make_manifold_projection
64-
- title: Construction
65-
contents:
66-
- construction.from_point.from_point
67-
- construction.from_line.from_line
68-
- construction.from_circle.from_circle
69-
- construction.extrude.extrude
7070
- title: Geometry
7171
contents:
7272
- geometry.parameters.Vector
7373
- geometry.parameters.Scalar
7474
- title: Render
7575
contents:
76-
- render.raymarch.raymarch
77-
- render.raymarch.render_raymarched
76+
- render.raymarch.render.raymarch
77+
- render.raymarch.render.render_raymarched
7878
- render.marching_cubes.render_marching_cubes
88.4 KB
Loading

examples/assets/lada.png

1.43 MB
Loading

0 commit comments

Comments
 (0)