Skip to content

Commit

Permalink
Merge pull request #70 from scipp/add-seed
Browse files Browse the repository at this point in the history
Add seed argument to source
  • Loading branch information
nvaytet authored Jan 13, 2025
2 parents e8985a3 + 1038f96 commit c42a2f8
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 4 deletions.
20 changes: 16 additions & 4 deletions src/tof/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def _make_pulses(
p_time: sc.DataArray,
p_wav: sc.DataArray,
sampling: int,
seed: Optional[int],
wmin: Optional[sc.Variable] = None,
wmax: Optional[sc.Variable] = None,
):
Expand All @@ -66,6 +67,8 @@ def _make_pulses(
Wavelength probability distribution for a single pulse.
sampling:
Number of points used to sample the probability distributions.
seed:
Seed for the random number generator.
wmin:
Minimum neutron wavelength.
wmax:
Expand Down Expand Up @@ -127,14 +130,15 @@ def _make_pulses(
times = []
wavs = []
ntot = pulses * neutrons
rng = np.random.default_rng(seed)
while n < ntot:
size = ntot - n
t = np.random.choice(
t = rng.choice(
p_time.coords[t_dim].values, size=size, p=p_time.values
) + np.random.normal(scale=dt, size=size)
w = np.random.choice(
) + rng.normal(scale=dt, size=size)
w = rng.choice(
p_wav.coords[w_dim].values, size=size, p=p_wav.values
) + np.random.normal(scale=dw, size=size)
) + rng.normal(scale=dw, size=size)
mask = (
(t >= tmin.value)
& (t <= tmax.value)
Expand Down Expand Up @@ -191,6 +195,8 @@ class Source:
Minimum neutron wavelength.
wmax:
Maximum neutron wavelength.
seed:
Seed for the random number generator.
"""

def __init__(
Expand All @@ -201,6 +207,7 @@ def __init__(
sampling: int = 1000,
wmin: Optional[sc.Variable] = None,
wmax: Optional[sc.Variable] = None,
seed: Optional[int] = None,
):
self.facility = facility
self.neutrons = int(neutrons)
Expand All @@ -219,6 +226,7 @@ def __init__(
pulses=pulses,
wmin=wmin,
wmax=wmax,
seed=seed,
)
self.data = sc.DataArray(
data=sc.ones(sizes=pulse_params["time"].sizes, unit="counts"),
Expand Down Expand Up @@ -290,6 +298,7 @@ def from_distribution(
pulses: int = 1,
frequency: Optional[sc.Variable] = None,
sampling: Optional[int] = 1000,
seed: Optional[int] = None,
):
"""
Create source pulses from time a wavelength probability distributions.
Expand All @@ -314,6 +323,8 @@ def from_distribution(
Frequency of the pulse.
sampling:
Number of points used to interpolate the probability distributions.
seed:
Seed for the random number generator.
"""

source = cls(facility=None, neutrons=neutrons, pulses=pulses)
Expand All @@ -325,6 +336,7 @@ def from_distribution(
frequency=source.frequency,
pulses=pulses,
sampling=sampling,
seed=seed,
)
source.data = sc.DataArray(
data=sc.ones(sizes=pulse_params["time"].sizes, unit="counts"),
Expand Down
38 changes: 38 additions & 0 deletions tests/source_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,41 @@ def test_multiple_pulses_from_neutrons_no_frequency_raises():
def test_source_repr_does_not_raise():
assert repr(tof.Source(facility='ess', neutrons=100_000)) is not None
assert repr(tof.Source(facility='ess', neutrons=100_000, pulses=3)) is not None


def test_seed_ess_pulse():
a = tof.Source(facility='ess', neutrons=100_000, seed=1234)
b = tof.Source(facility='ess', neutrons=100_000, seed=1234)
assert sc.identical(a.data, b.data)
c = tof.Source(facility='ess', neutrons=100_000, seed=0)
assert not sc.identical(a.data, c.data)


def test_seed_from_distribution():
v = np.ones(9) * 0.1
v[3:6] = 1.0

p_time = sc.DataArray(
data=sc.array(dims=['time'], values=v),
coords={'time': sc.linspace('time', 0.0, 8000.0, len(v), unit='us')},
)
p_wav = sc.DataArray(
data=sc.array(dims=['wavelength'], values=[1.0, 2.0, 3.0, 4.0]),
coords={
'wavelength': sc.array(
dims=['wavelength'], values=[1.0, 2.0, 3.0, 4.0], unit='angstrom'
)
},
)

a = tof.Source.from_distribution(
neutrons=100_000, p_time=p_time, p_wav=p_wav, seed=12
)
b = tof.Source.from_distribution(
neutrons=100_000, p_time=p_time, p_wav=p_wav, seed=12
)
assert sc.identical(a.data, b.data)
c = tof.Source.from_distribution(
neutrons=100_000, p_time=p_time, p_wav=p_wav, seed=1
)
assert not sc.identical(a.data, c.data)

0 comments on commit c42a2f8

Please sign in to comment.