Skip to content

Commit 25fd218

Browse files
Merge pull request #207 from Ciela-Institute/dev
Dev
2 parents 0d895ab + a2bee45 commit 25fd218

31 files changed

+3487
-826
lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ ci:
99

1010
repos:
1111
- repo: https://github.com/psf/black
12-
rev: "24.2.0"
12+
rev: "24.4.0"
1313
hooks:
1414
- id: black-jupyter
1515

@@ -20,7 +20,7 @@ repos:
2020
additional_dependencies: [black==23.7.0]
2121

2222
- repo: https://github.com/pre-commit/pre-commit-hooks
23-
rev: "v4.5.0"
23+
rev: "v4.6.0"
2424
hooks:
2525
- id: check-added-large-files
2626
- id: check-case-conflict
@@ -50,7 +50,7 @@ repos:
5050
args: [--prose-wrap=always]
5151

5252
- repo: https://github.com/astral-sh/ruff-pre-commit
53-
rev: "v0.3.2"
53+
rev: "v0.4.1"
5454
hooks:
5555
- id: ruff
5656
args: ["--fix", "--show-fixes"]

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
ipywidgets
22
jupyter-book
33
matplotlib
4+
pydantic>=2.6.1,<3
45
pyro-ppl
56
sphinx
67
sphinx_rtd_theme

docs/source/tutorials/BasicIntroduction.ipynb

Lines changed: 271 additions & 13 deletions
Large diffs are not rendered by default.

docs/source/tutorials/example.yml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
cosmology: &cosmo
2+
name: cosmo
3+
kind: FlatLambdaCDM
4+
5+
lens: &lens
6+
name: lens
7+
kind: SIE
8+
init_kwargs:
9+
cosmology: *cosmo
10+
11+
src: &src
12+
name: source
13+
kind: Sersic
14+
15+
lnslt: &lnslt
16+
name: lenslight
17+
kind: Sersic
18+
19+
simulator:
20+
name: minisim
21+
kind: Lens_Source
22+
init_kwargs:
23+
# Single lense
24+
lens: *lens
25+
source: *src
26+
lens_light: *lnslt
27+
pixelscale: 0.05
28+
pixels_x: 100

src/caustics/func.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from .lenses.func import (
2+
forward_raytrace,
3+
physical_from_reduced_deflection_angle,
4+
reduced_from_physical_deflection_angle,
5+
reduced_deflection_angle_sie,
6+
potential_sie,
7+
convergence_sie,
8+
reduced_deflection_angle_point,
9+
potential_point,
10+
convergence_point,
11+
reduced_deflection_angle_mass_sheet,
12+
potential_mass_sheet,
13+
convergence_mass_sheet,
14+
reduced_deflection_angle_epl,
15+
potential_epl,
16+
convergence_epl,
17+
reduced_deflection_angle_external_shear,
18+
potential_external_shear,
19+
physical_deflection_angle_nfw,
20+
potential_nfw,
21+
convergence_nfw,
22+
_f_batchable_nfw,
23+
_f_differentiable_nfw,
24+
_g_batchable_nfw,
25+
_g_differentiable_nfw,
26+
_h_batchable_nfw,
27+
_h_differentiable_nfw,
28+
reduced_deflection_angle_pixelated_convergence,
29+
potential_pixelated_convergence,
30+
_fft2_padded,
31+
build_kernels_pixelated_convergence,
32+
convergence_0_pseudo_jaffe,
33+
potential_pseudo_jaffe,
34+
reduced_deflection_angle_pseudo_jaffe,
35+
mass_enclosed_2d_pseudo_jaffe,
36+
convergence_pseudo_jaffe,
37+
reduced_deflection_angle_sis,
38+
potential_sis,
39+
convergence_sis,
40+
mass_enclosed_2d_tnfw,
41+
physical_deflection_angle_tnfw,
42+
potential_tnfw,
43+
convergence_tnfw,
44+
scale_density_tnfw,
45+
M0_scalemass_tnfw,
46+
M0_totmass_tnfw,
47+
concentration_tnfw,
48+
)
49+
50+
from .light.func import brightness_sersic, k_lenstronomy, k_sersic
51+
52+
__all__ = (
53+
"forward_raytrace",
54+
"physical_from_reduced_deflection_angle",
55+
"reduced_from_physical_deflection_angle",
56+
"reduced_deflection_angle_sie",
57+
"potential_sie",
58+
"convergence_sie",
59+
"reduced_deflection_angle_point",
60+
"potential_point",
61+
"convergence_point",
62+
"reduced_deflection_angle_mass_sheet",
63+
"potential_mass_sheet",
64+
"convergence_mass_sheet",
65+
"reduced_deflection_angle_epl",
66+
"potential_epl",
67+
"convergence_epl",
68+
"reduced_deflection_angle_external_shear",
69+
"potential_external_shear",
70+
"physical_deflection_angle_nfw",
71+
"potential_nfw",
72+
"convergence_nfw",
73+
"_f_batchable_nfw",
74+
"_f_differentiable_nfw",
75+
"_g_batchable_nfw",
76+
"_g_differentiable_nfw",
77+
"_h_batchable_nfw",
78+
"_h_differentiable_nfw",
79+
"reduced_deflection_angle_pixelated_convergence",
80+
"potential_pixelated_convergence",
81+
"_fft2_padded",
82+
"build_kernels_pixelated_convergence",
83+
"convergence_0_pseudo_jaffe",
84+
"potential_pseudo_jaffe",
85+
"reduced_deflection_angle_pseudo_jaffe",
86+
"mass_enclosed_2d_pseudo_jaffe",
87+
"convergence_pseudo_jaffe",
88+
"reduced_deflection_angle_sis",
89+
"potential_sis",
90+
"convergence_sis",
91+
"mass_enclosed_2d_tnfw",
92+
"physical_deflection_angle_tnfw",
93+
"potential_tnfw",
94+
"convergence_tnfw",
95+
"scale_density_tnfw",
96+
"M0_scalemass_tnfw",
97+
"M0_totmass_tnfw",
98+
"concentration_tnfw",
99+
"brightness_sersic",
100+
"k_lenstronomy",
101+
"k_sersic",
102+
)

src/caustics/lenses/base.py

Lines changed: 12 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from ..cosmology import Cosmology
1212
from ..parametrized import Parametrized, unpack
1313
from .utils import get_magnification
14-
from ..utils import batch_lm
1514
from ..packed import Packed
15+
from . import func
1616

1717
__all__ = ("ThinLens", "ThickLens")
1818

@@ -185,40 +185,19 @@ def forward_raytrace(
185185
*Unit: arcsec*
186186
"""
187187

188-
bxy = torch.stack((bx, by)).repeat(n_init, 1) # has shape (n_init, Dout:2)
189-
190188
# TODO make FOV more general so that it doesn't have to be centered on zero,zero
191189
if fov is None:
192190
raise ValueError("fov must be given to generate initial guesses")
193191

194-
# Random starting points in image plane
195-
guesses = (torch.as_tensor(fov) * (torch.rand(n_init, 2) - 0.5)).to(
196-
device=bxy.device
197-
) # Has shape (n_init, Din:2)
198-
199-
# Optimize guesses in image plane
200-
x, l, c = batch_lm( # noqa: E741 Unused `l` variable
201-
guesses,
202-
bxy,
203-
lambda *a, **k: torch.stack(
204-
self.raytrace(a[0][..., 0], a[0][..., 1], *a[1:], **k), dim=-1
205-
),
206-
f_args=(z_s, params),
192+
return func.forward_raytrace(
193+
bx,
194+
by,
195+
partial(self.raytrace, params=params, z_s=z_s),
196+
epsilon,
197+
n_init,
198+
fov,
207199
)
208200

209-
# Clip points that didn't converge
210-
x = x[c < 1e-2 * epsilon**2]
211-
212-
# Cluster results into n-images
213-
res = []
214-
while len(x) > 0:
215-
res.append(x[0])
216-
d = torch.linalg.norm(x - x[0], dim=-1)
217-
x = x[d > epsilon]
218-
219-
res = torch.stack(res, dim=0)
220-
return res[..., 0], res[..., 1]
221-
222201

223202
class ThickLens(Lens):
224203
"""
@@ -782,9 +761,8 @@ def reduced_deflection_angle(
782761
deflection_angle_x, deflection_angle_y = self.physical_deflection_angle(
783762
x, y, z_s, params
784763
)
785-
return (
786-
(d_ls / d_s) * deflection_angle_x,
787-
(d_ls / d_s) * deflection_angle_y,
764+
return func.reduced_from_physical_deflection_angle(
765+
deflection_angle_x, deflection_angle_y, d_s, d_ls
788766
)
789767

790768
@unpack
@@ -839,9 +817,8 @@ def physical_deflection_angle(
839817
deflection_angle_x, deflection_angle_y = self.reduced_deflection_angle(
840818
x, y, z_s, params
841819
)
842-
return (
843-
(d_s / d_ls) * deflection_angle_x,
844-
(d_s / d_ls) * deflection_angle_y,
820+
return func.physical_from_reduced_deflection_angle(
821+
deflection_angle_x, deflection_angle_y, d_s, d_ls
845822
)
846823

847824
@abstractmethod

src/caustics/lenses/epl.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
import torch
55
from torch import Tensor
66

7-
from ..utils import derotate, translate_rotate
87
from .base import ThinLens, CosmologyType, NameType, ZLType
98
from ..parametrized import unpack
109
from ..packed import Packed
10+
from . import func
1111

1212
__all__ = ("EPL",)
1313

@@ -243,19 +243,9 @@ def reduced_deflection_angle(
243243
*Unit: arcsec*
244244
245245
"""
246-
x, y = translate_rotate(x, y, x0, y0, phi)
247-
248-
# follow Tessore et al 2015 (eq. 5)
249-
z = q * x + y * 1j
250-
r = torch.abs(z)
251-
252-
# Tessore et al 2015 (eq. 23)
253-
r_omega = self._r_omega(z, t, q)
254-
alpha_c = 2.0 / (1.0 + q) * (b / r) ** t * r_omega # fmt: skip
255-
256-
alpha_real = torch.nan_to_num(alpha_c.real, posinf=10**10, neginf=-(10**10))
257-
alpha_imag = torch.nan_to_num(alpha_c.imag, posinf=10**10, neginf=-(10**10))
258-
return derotate(alpha_real, alpha_imag, phi)
246+
return func.reduced_deflection_angle_epl(
247+
x0, y0, q, phi, b, t, x, y, self.n_iter
248+
)
259249

260250
def _r_omega(self, z, t, q):
261251
"""
@@ -349,10 +339,7 @@ def potential(
349339
*Unit: arcsec^2*
350340
351341
"""
352-
ax, ay = self.reduced_deflection_angle(x, y, z_s, params)
353-
ax, ay = derotate(ax, ay, -phi)
354-
x, y = translate_rotate(x, y, x0, y0, phi)
355-
return (x * ax + y * ay) / (2 - t) # fmt: skip
342+
return func.potential_epl(x0, y0, q, phi, b, t, x, y, self.n_iter)
356343

357344
@unpack
358345
def convergence(
@@ -402,6 +389,4 @@ def convergence(
402389
*Unit: unitless*
403390
404391
"""
405-
x, y = translate_rotate(x, y, x0, y0, phi)
406-
psi = (q**2 * (x**2 + self.s**2) + y**2).sqrt() # fmt: skip
407-
return (2 - t) / 2 * (b / psi) ** t # fmt: skip
392+
return func.convergence_epl(x0, y0, q, phi, b, t, x, y, self.s)

src/caustics/lenses/external_shear.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
from typing import Optional, Union, Annotated
33

44
from torch import Tensor
5+
import torch
56

6-
from ..utils import translate_rotate
77
from .base import ThinLens, CosmologyType, NameType, ZLType
88
from ..parametrized import unpack
99
from ..packed import Packed
10+
from . import func
1011

1112
__all__ = ("ExternalShear",)
1213

@@ -135,16 +136,9 @@ def reduced_deflection_angle(
135136
*Unit: arcsec*
136137
137138
"""
138-
x, y = translate_rotate(x, y, x0, y0)
139-
# Meneghetti eq 3.83
140-
# TODO, why is it not:
141-
# th = (x**2 + y**2).sqrt() + self.s
142-
# a1 = x/th + x * gamma_1 + y * gamma_2
143-
# a2 = y/th + x * gamma_2 - y * gamma_1
144-
a1 = x * gamma_1 + y * gamma_2
145-
a2 = x * gamma_2 - y * gamma_1
146-
147-
return a1, a2 # I'm not sure but I think no derotation necessary
139+
return func.reduced_deflection_angle_external_shear(
140+
x0, y0, gamma_1, gamma_2, x, y
141+
)
148142

149143
@unpack
150144
def potential(
@@ -192,9 +186,7 @@ def potential(
192186
*Unit: arcsec^2*
193187
194188
"""
195-
ax, ay = self.reduced_deflection_angle(x, y, z_s, params)
196-
x, y = translate_rotate(x, y, x0, y0)
197-
return 0.5 * (x * ax + y * ay)
189+
return func.potential_external_shear(x0, y0, gamma_1, gamma_2, x, y)
198190

199191
@unpack
200192
def convergence(
@@ -247,4 +239,4 @@ def convergence(
247239
This method is not implemented as the convergence is not defined
248240
for an external shear.
249241
"""
250-
raise NotImplementedError("convergence undefined for external shear")
242+
return torch.zeros_like(x)

0 commit comments

Comments
 (0)