Skip to content

Commit 0073ea6

Browse files
authored
Update README.rst (#181)
1 parent 9b73fa4 commit 0073ea6

File tree

1 file changed

+93
-1
lines changed

1 file changed

+93
-1
lines changed

README.rst

+93-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,96 @@
11
Stream Likelihoods with ML
22
##########################
33

4-
Stuff
4+
This is the PyTorch implementation of the StreamMapper code, which can be used to model stellar streams.
5+
StreamMapper-PyTorch is a PyTorch framework for building Bayesian Mixture Density Networks, which can
6+
then be trained using the standard PyTorch tooling.
7+
Detailed explanations can be found in our paper (https://ui.adsabs.harvard.edu/abs/2023arXiv231116960S/abstract)
8+
and especially in the code repository for the paper (https://github.com/nstarman/stellar_stream_density_ml_paper).
9+
10+
As an illustruative example:
11+
12+
.. code-block:: python
13+
14+
bkg_phi2_model = sml.builtin.Uniform(
15+
data_scaler=scaler,
16+
indep_coord_names=("phi1",),
17+
coord_names=("phi2",),
18+
coord_bounds={"phi2": (lower, upper)},
19+
params=ModelParameters(),
20+
)
21+
22+
bkg_plx_model = sml.builtin.Exponential(
23+
net=sml.nn.sequential(
24+
data=1, hidden_features=32, layers=3, features=1, dropout=0.15
25+
),
26+
data_scaler=scaler,
27+
indep_coord_names=("phi1",),
28+
coord_names=("parallax",),
29+
coord_bounds={"parallax": (lower, upper)},
30+
params=ModelParameters(
31+
{"parallax": {"slope": ModelParameter(bounds=SigmoidBounds(15.0, 25.0))}}
32+
),
33+
)
34+
35+
36+
bkg_flow = sml.builtin.compat.ZukoFlowModel(
37+
net=zuko.flows.MAF(features=2, context=1, transforms=4, hidden_features=[4] * 4),
38+
jacobian_logdet=-xp.log(xp.prod(...)),
39+
data_scaler=scaler[("phi1", "g", "r")],
40+
coord_names=phot_names,
41+
coord_bounds=phot_bounds,
42+
params=ModelParameters(),
43+
)
44+
45+
background_model = sml.IndependentModels(
46+
{
47+
"astrometric": sml.IndependentModels(
48+
{"phi2": bkg_phi2_model, "parallax": bkg_plx_model}
49+
),
50+
"photometric": bkg_flow,
51+
}
52+
)
53+
54+
55+
stream_astrometric_model = sml.builtin.Normal(
56+
net=..., # PyTorch NN
57+
data_scaler=scaler,
58+
coord_names=coord_astrometric_names,
59+
coord_bounds=coord_astrometric_bounds,
60+
params=ModelParameters(
61+
{
62+
"phi2": {
63+
"mu": ModelParameter(bounds=..., scaler=...),
64+
"ln-sigma": ModelParameter(bounds=..., scaler=...),
65+
},
66+
"parallax": {
67+
"mu": ModelParameter(bounds=..., scaler=...),
68+
"ln-sigma": ModelParameter(bounds=..., scaler=...),
69+
},
70+
}
71+
),
72+
)
73+
74+
stream_isochrone_model = sml.builtin.IsochroneMVNorm(...)
75+
76+
stream_model = sml.IndependentModels(
77+
{"astrometric": stream_astrometric_model, "photometric": stream_isochrone_model},
78+
unpack_params_hooks=(
79+
Parallax2DistMod(
80+
astrometric_coord="astrometric.parallax",
81+
photometric_coord="photometric.distmod",
82+
),
83+
),
84+
)
85+
86+
model = sml.MixtureModel(
87+
{"stream": stream_model, "background": background_model},
88+
net=...,
89+
data_scaler=scaler,
90+
params=ModelParameters(
91+
{
92+
f"stream.ln-weight": ModelParameter(...),
93+
f"background.ln-weight": ModelParameter(...),
94+
}
95+
),
96+
)

0 commit comments

Comments
 (0)