Skip to content

Commit 642e774

Browse files
authored
refactor: merge WeightedDataGenerator into DataGenerator (#458)
* chore: upgrade Jupyter notebook kernels * docs: add link to TR-018 * feat: embed weights as key to `DataSample` * feat: implement phase space weights in `UnbinnedNLL`
1 parent 811c17b commit 642e774

15 files changed

+128
-52
lines changed

docs/amplitude-analysis.ipynb

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,12 @@
299299
"cell_type": "markdown",
300300
"metadata": {},
301301
"source": [
302+
"::::{margin}\n",
303+
":::{tip}\n",
304+
"{doc}`TR-018<compwa-org:report/018>` explains some of the mechanisms behind the phase space generator as well as how to do {ref}`importance sampling<compwa-org:report/018:Intensity distribution>`.\n",
305+
":::\n",
306+
"::::\n",
307+
"\n",
302308
"In this section, we use the {class}`~ampform.helicity.HelicityModel` that we created with {mod}`ampform` in {ref}`the previous step <compwa-step-1>` to generate a data sample via hit & miss Monte Carlo. We do this with the {mod}`.data` module.\n",
303309
"\n",
304310
"First, we {func}`~pickle.load` the {class}`~ampform.helicity.HelicityModel` that was created in the previous step. This does not have to be done if the model has been generated in the same script or notebook, but can be useful if the model was generated elsewhere."
@@ -353,7 +359,7 @@
353359
"cell_type": "markdown",
354360
"metadata": {},
355361
"source": [
356-
"The {class}`~qrules.transition.ReactionInfo` class defines the constraints of the phase space. As such, we have enough information to generate a **phase-space sample** for this particle reaction. We do this with a {class}`.TFPhaseSpaceGenerator` class, which is an implementation of the {class}`.DataGenerator` for a {obj}`.DataSample` of **four-momenta** arrays (using {obj}`tensorflow <tf.Tensor>` and the [`phasespace`](https://phasespace.readthedocs.io) package as a back-end). We also need to construct a {class}`.RealNumberGenerator` that can generate random numbers. {class}`.TFUniformRealNumberGenerator` is the natural choice here.\n",
362+
"The {class}`~qrules.transition.ReactionInfo` class defines the constraints of the phase space. As such, we have enough information to generate a **phase-space sample** for this particle reaction. We do this with a {class}`.TFPhaseSpaceGenerator` class, which is a {class}`.DataGenerator` for a {obj}`.DataSample` of **four-momenta** arrays (using {obj}`tensorflow <tf.Tensor>` and the [`phasespace`](https://phasespace.readthedocs.io) package as a back-end). We also need to construct a {class}`.RealNumberGenerator` that can generate random numbers. {class}`.TFUniformRealNumberGenerator` is the natural choice here.\n",
357363
"\n",
358364
"As opposed to the main {ref}`amplitude-analysis:Step 2: Generate data` of the main usage example page, we will generate a **deterministic** data sample. This can be done by feeding a {class}`.RealNumberGenerator` with a specific {attr}`~.RealNumberGenerator.seed` and giving that generator to the {meth}`.TFPhaseSpaceGenerator.generate` method:"
359365
]
@@ -1935,8 +1941,16 @@
19351941
"name": "python3"
19361942
},
19371943
"language_info": {
1944+
"codemirror_mode": {
1945+
"name": "ipython",
1946+
"version": 3
1947+
},
1948+
"file_extension": ".py",
1949+
"mimetype": "text/x-python",
19381950
"name": "python",
1939-
"version": "3.8.12"
1951+
"nbconvert_exporter": "python",
1952+
"pygments_lexer": "ipython3",
1953+
"version": "3.8.13"
19401954
}
19411955
},
19421956
"nbformat": 4,

docs/amplitude-analysis/analytic-continuation.ipynb

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,15 @@
290290
"name": "python3"
291291
},
292292
"language_info": {
293+
"codemirror_mode": {
294+
"name": "ipython",
295+
"version": 3
296+
},
297+
"file_extension": ".py",
298+
"mimetype": "text/x-python",
293299
"name": "python",
300+
"nbconvert_exporter": "python",
301+
"pygments_lexer": "ipython3",
294302
"version": "3.8.13"
295303
}
296304
},

docs/usage.ipynb

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -857,8 +857,16 @@
857857
"name": "python3"
858858
},
859859
"language_info": {
860+
"codemirror_mode": {
861+
"name": "ipython",
862+
"version": 3
863+
},
864+
"file_extension": ".py",
865+
"mimetype": "text/x-python",
860866
"name": "python",
861-
"version": "3.8.12"
867+
"nbconvert_exporter": "python",
868+
"pygments_lexer": "ipython3",
869+
"version": "3.8.13"
862870
}
863871
},
864872
"nbformat": 4,

docs/usage/basics.ipynb

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1291,8 +1291,16 @@
12911291
"name": "python3"
12921292
},
12931293
"language_info": {
1294+
"codemirror_mode": {
1295+
"name": "ipython",
1296+
"version": 3
1297+
},
1298+
"file_extension": ".py",
1299+
"mimetype": "text/x-python",
12941300
"name": "python",
1295-
"version": "3.8.12"
1301+
"nbconvert_exporter": "python",
1302+
"pygments_lexer": "ipython3",
1303+
"version": "3.8.13"
12961304
}
12971305
},
12981306
"nbformat": 4,

docs/usage/binned-fit.ipynb

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,16 @@
328328
"name": "python3"
329329
},
330330
"language_info": {
331+
"codemirror_mode": {
332+
"name": "ipython",
333+
"version": 3
334+
},
335+
"file_extension": ".py",
336+
"mimetype": "text/x-python",
331337
"name": "python",
332-
"version": "3.8.12"
338+
"nbconvert_exporter": "python",
339+
"pygments_lexer": "ipython3",
340+
"version": "3.8.13"
333341
}
334342
},
335343
"nbformat": 4,

docs/usage/caching.ipynb

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,8 +540,16 @@
540540
"name": "python3"
541541
},
542542
"language_info": {
543+
"codemirror_mode": {
544+
"name": "ipython",
545+
"version": 3
546+
},
547+
"file_extension": ".py",
548+
"mimetype": "text/x-python",
543549
"name": "python",
544-
"version": "3.8.12"
550+
"nbconvert_exporter": "python",
551+
"pygments_lexer": "ipython3",
552+
"version": "3.8.13"
545553
}
546554
},
547555
"nbformat": 4,

docs/usage/chi-squared.ipynb

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,16 @@
274274
"name": "python3"
275275
},
276276
"language_info": {
277+
"codemirror_mode": {
278+
"name": "ipython",
279+
"version": 3
280+
},
281+
"file_extension": ".py",
282+
"mimetype": "text/x-python",
277283
"name": "python",
278-
"version": "3.8.12"
284+
"nbconvert_exporter": "python",
285+
"pygments_lexer": "ipython3",
286+
"version": "3.8.13"
279287
}
280288
},
281289
"nbformat": 4,

docs/usage/faster-lambdify.ipynb

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,8 +413,16 @@
413413
"name": "python3"
414414
},
415415
"language_info": {
416+
"codemirror_mode": {
417+
"name": "ipython",
418+
"version": 3
419+
},
420+
"file_extension": ".py",
421+
"mimetype": "text/x-python",
416422
"name": "python",
417-
"version": "3.8.12"
423+
"nbconvert_exporter": "python",
424+
"pygments_lexer": "ipython3",
425+
"version": "3.8.13"
418426
}
419427
},
420428
"nbformat": 4,

docs/usage/unbinned-fit.ipynb

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,8 +351,16 @@
351351
"name": "python3"
352352
},
353353
"language_info": {
354+
"codemirror_mode": {
355+
"name": "ipython",
356+
"version": 3
357+
},
358+
"file_extension": ".py",
359+
"mimetype": "text/x-python",
354360
"name": "python",
355-
"version": "3.8.12"
361+
"nbconvert_exporter": "python",
362+
"pygments_lexer": "ipython3",
363+
"version": "3.8.13"
356364
}
357365
},
358366
"nbformat": 4,

src/tensorwaves/data/__init__.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
DataTransformer,
1414
Function,
1515
RealNumberGenerator,
16-
WeightedDataGenerator,
1716
)
1817

1918
from ._data_sample import (
@@ -71,7 +70,7 @@ class IntensityDistributionGenerator(DataGenerator):
7170

7271
def __init__(
7372
self,
74-
domain_generator: DataGenerator | WeightedDataGenerator,
73+
domain_generator: DataGenerator,
7574
function: Function,
7675
domain_transformer: DataTransformer | None = None,
7776
bunch_size: int = 50_000,
@@ -115,18 +114,14 @@ def generate(self, size: int, rng: RealNumberGenerator) -> DataSample:
115114
return select_events(returned_data, selector=slice(None, size))
116115

117116
def _generate_bunch(self, rng: RealNumberGenerator) -> tuple[DataSample, float]:
118-
domain_generator = self.__domain_generator
119-
if isinstance(domain_generator, WeightedDataGenerator):
120-
domain, weights = domain_generator.generate(self.__bunch_size, rng)
121-
else:
122-
domain = _generate_without_progress_bar(
123-
domain_generator, self.__bunch_size, rng
124-
)
125-
weights = 1 # type: ignore[assignment]
117+
domain = _generate_without_progress_bar(
118+
self.__domain_generator, self.__bunch_size, rng
119+
)
126120
transformed_domain = self.__domain_transformer(domain)
127121
computed_intensities = self.__function(transformed_domain)
128122
max_intensity: float = np.max(computed_intensities)
129123
random_intensities = rng(size=self.__bunch_size, max_value=max_intensity)
124+
weights = domain.get("weights", 1)
130125
hit_and_miss_sample = select_events(
131126
domain,
132127
selector=weights * computed_intensities > random_intensities,
@@ -139,9 +134,9 @@ def _generate_without_progress_bar(
139134
) -> DataSample:
140135
# https://github.com/ComPWA/tensorwaves/issues/395
141136
show_progress = getattr(domain_generator, "show_progress", None)
142-
if show_progress:
137+
if show_progress is not None:
143138
domain_generator.show_progress = False # type: ignore[attr-defined]
144139
domain = domain_generator.generate(bunch_size, rng)
145-
if show_progress:
140+
if show_progress is not None:
146141
domain_generator.show_progress = show_progress # type: ignore[attr-defined]
147142
return domain

src/tensorwaves/data/phasespace.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,14 @@
11
# pylint: disable=import-outside-toplevel
2-
"""Implementations of `.DataGenerator` and `.WeightedDataGenerator`."""
2+
"""Implementations of a `.DataGenerator` for four-momentum samples."""
33
from __future__ import annotations
44

55
import logging
66
from typing import Mapping
77

8-
import numpy as np
98
from tqdm.auto import tqdm
109

1110
from tensorwaves.function._backend import raise_missing_module_error
12-
from tensorwaves.interface import (
13-
DataGenerator,
14-
DataSample,
15-
RealNumberGenerator,
16-
WeightedDataGenerator,
17-
)
11+
from tensorwaves.interface import DataGenerator, DataSample, RealNumberGenerator
1812

1913
from ._data_sample import (
2014
finalize_progress_bar,
@@ -64,20 +58,30 @@ def generate(self, size: int, rng: RealNumberGenerator) -> DataSample:
6458
)
6559
momentum_pool: DataSample = {}
6660
while get_number_of_events(momentum_pool) < size:
67-
phsp_momenta, weights = self.__phsp_generator.generate(
68-
self.__bunch_size, rng
69-
)
61+
phsp_momenta = self.__phsp_generator.generate(self.__bunch_size, rng)
62+
weights = phsp_momenta.get("weights")
63+
if weights is None:
64+
raise ValueError(
65+
"DataSample returned by"
66+
f" {type(self.__phsp_generator).__name__} doesn't contain"
67+
' "weights"'
68+
)
7069
hit_and_miss_randoms = rng(self.__bunch_size)
7170
bunch = select_events(phsp_momenta, selector=weights > hit_and_miss_randoms)
7271
momentum_pool = merge_events(momentum_pool, bunch)
7372
progress_bar.update(n=get_number_of_events(bunch))
7473
finalize_progress_bar(progress_bar)
75-
return select_events(momentum_pool, selector=slice(None, size))
74+
phsp = select_events(momentum_pool, selector=slice(None, size))
75+
del phsp["weights"]
76+
return phsp
7677

7778

78-
class TFWeightedPhaseSpaceGenerator(WeightedDataGenerator):
79+
class TFWeightedPhaseSpaceGenerator(DataGenerator):
7980
"""Implements a phase space generator **with weights** using tensorflow.
8081
82+
The weights are provided in the returned `.DataSample` under the key
83+
:code:`"weights"`.
84+
8185
Args:
8286
initial_state_mass: Mass of the decaying state.
8387
final_state_masses: A mapping of final state IDs to the corresponding masses.
@@ -102,9 +106,7 @@ def __init__(
102106
names=list(map(str, sorted_ids)),
103107
)
104108

105-
def generate(
106-
self, size: int, rng: RealNumberGenerator
107-
) -> tuple[DataSample, np.ndarray]:
109+
def generate(self, size: int, rng: RealNumberGenerator) -> DataSample:
108110
r"""Generate a `.DataSample` of phase space four-momenta with weights.
109111
110112
Returns:
@@ -122,4 +124,7 @@ def generate(
122124
f"p{label}": momenta.numpy()[:, [3, 0, 1, 2]]
123125
for label, momenta in particles.items()
124126
}
125-
return phsp_momenta, weights.numpy()
127+
return {
128+
"weights": weights.numpy(),
129+
**phsp_momenta,
130+
}

src/tensorwaves/estimator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,8 @@ def __init__( # pylint: disable=too-many-arguments
193193
backend: str = "numpy",
194194
) -> None:
195195
self.__data = {k: np.array(v) for k, v in data.items()}
196-
self.__phsp = {k: np.array(v) for k, v in phsp.items()}
196+
self.__phsp = {k: np.array(v) for k, v in phsp.items() if k != "weights"}
197+
self.__phsp_weights = phsp.get("weights")
197198
self.__function = function
198199
self.__gradient = gradient_creator(self.__call__, backend)
199200

@@ -207,6 +208,8 @@ def __call__(self, parameters: Mapping[str, ParameterValue]) -> float:
207208
self.__function.update_parameters(parameters)
208209
bare_intensities = self.__function(self.__data)
209210
phsp_intensities = self.__function(self.__phsp)
211+
if self.__phsp_weights is not None:
212+
phsp_intensities *= self.__phsp_weights
210213
normalization_factor = 1.0 / (
211214
self.__phsp_volume * self.__mean_function(phsp_intensities)
212215
)

src/tensorwaves/interface.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -231,17 +231,7 @@ class DataGenerator(ABC):
231231

232232
@abstractmethod
233233
def generate(self, size: int, rng: RealNumberGenerator) -> DataSample:
234-
...
235-
236-
237-
class WeightedDataGenerator(ABC):
238-
"""Abstract class for generating a `.DataSample` with weights."""
239-
240-
@abstractmethod
241-
def generate(
242-
self, size: int, rng: RealNumberGenerator
243-
) -> tuple[DataSample, np.ndarray]:
244-
r"""Generate `.DataSample` with weights.
234+
r"""Generate a `.DataSample` with :code:`size` events.
245235
246236
Returns:
247237
A `tuple` of a `.DataSample` with an array of weights.

tests/data/test_data.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# pylint: disable=import-outside-toplevel
2+
from __future__ import annotations
3+
24
from typing import TYPE_CHECKING
35

46
import numpy as np
@@ -96,7 +98,7 @@ def test_generate_four_momenta_on_flat_distribution(self):
9698
assert pytest.approx(phsp[i]) == data[i]
9799

98100

99-
def test_generate_without_progress_bar(capsys: "CaptureFixture"):
101+
def test_generate_without_progress_bar(capsys: CaptureFixture):
100102
class SilentGenerator(DataGenerator):
101103
def generate(self, size: int, rng: RealNumberGenerator) -> DataSample:
102104
return {"x": 1} # type: ignore[dict-item]

tests/data/test_phasespace.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,10 @@ def test_generate_deterministic(self, pdg: "ParticleCollection"):
137137
i: pdg[name].mass for i, name in enumerate(final_state_names)
138138
},
139139
)
140-
phsp_momenta, weights = phsp_generator.generate(sample_size, rng)
140+
phsp_momenta = phsp_generator.generate(sample_size, rng)
141+
assert list(phsp_momenta) == ["weights", "p0", "p1", "p2"]
142+
weights = phsp_momenta.get("weights", [])
143+
del phsp_momenta["weights"]
141144
print("Expected values, get by running pytest with the -s flag")
142145
pprint(
143146
{

0 commit comments

Comments
 (0)