Skip to content

Commit 3b38e05

Browse files
authored
Merge pull request #692 from CUQI-DTU/enable_FD_JointDistribution
enable FD for JointDistribution
2 parents f08eb2a + a2eaf13 commit 3b38e05

File tree

4 files changed

+324
-14
lines changed

4 files changed

+324
-14
lines changed

cuqi/density/_density.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,15 @@ def __call__(self, *args, **kwargs):
143143
def enable_FD(self, epsilon=1e-8):
    """ Switch the logd gradient to a finite difference approximation.

    Once enabled, the FD approximation is used even if the _gradient
    method is implemented.

    Parameters
    ----------
    epsilon : float
        Spacing (step size) to use for the finite difference approximation
        of the logd gradient for each variable. Default is 1e-8.
    """
    self._FD_enabled, self._FD_epsilon = True, epsilon
149157

cuqi/distribution/_joint_distribution.py

Lines changed: 96 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def __init__(self, *densities: [Density, cuqi.experimental.algebra.RandomVariabl
8484
cond_vars = self._get_conditioning_variables()
8585
if len(cond_vars) > 0:
8686
raise ValueError(f"Every density parameter must have a distribution (prior). Missing prior for {cond_vars}.")
87+
# Initialize finite difference gradient approximation settings
88+
self.disable_FD()
8789

8890
# --------- Public properties ---------
8991
@property
@@ -96,6 +98,38 @@ def geometry(self) -> List[Geometry]:
9698
""" Returns the geometries of the joint distribution. """
9799
return [dist.geometry for dist in self._distributions]
98100

101+
@property
def FD_enabled(self):
    """ Per-parameter flags telling whether finite difference approximation
    of the logd gradient is enabled.

    Returns
    -------
    dict
        Maps each parameter name to True if its FD_epsilon entry holds a
        step size, and to False if that entry is None.
    """
    return {
        name: self.FD_epsilon[name] is not None
        for name in self.get_parameter_names()
    }
111+
112+
@property
def FD_epsilon(self):
    """ Returns a dictionary indicating for each parameter name the
    spacing for the finite difference approximation of the logd gradient.
    An entry of None means FD is disabled for that parameter. """
    return self._FD_epsilon

@FD_epsilon.setter
def FD_epsilon(self, value):
    """ Set the spacing for the finite difference approximation of the
    logd gradient as a dictionary. The keys are the parameter names.
    The value for each key is either None (no FD approximation) or a float
    representing the FD step size.

    Passing None instead of a dictionary disables FD for every parameter.

    Raises
    ------
    TypeError
        If value is neither None nor convertible to a dictionary.
    ValueError
        If the keys of value do not match the parameter names.
    """
    par_names = self.get_parameter_names()
    if value is None:
        self._FD_epsilon = dict.fromkeys(par_names)
        return

    # Store a copy: keeping the caller's dict by reference would let later
    # mutations of that dict change the settings without key validation.
    try:
        epsilon = dict(value)
    except (TypeError, ValueError):
        raise TypeError(
            "FD_epsilon must be None or a dictionary mapping parameter "
            f"names to step sizes, got {type(value).__name__}."
        ) from None

    if set(epsilon) != set(par_names):
        raise ValueError("Keys of FD_epsilon must match the parameter names of the distribution "+f" {par_names}")
    self._FD_epsilon = epsilon
132+
99133
# --------- Public methods ---------
100134
def logd(self, *args, **kwargs):
101135
""" Evaluate the un-normalized log density function. """
@@ -136,6 +170,33 @@ def _condition(self, *args, **kwargs): # Public through __call__
136170
# Can reduce to Posterior, Likelihood or Distribution.
137171
return new_joint._reduce_to_single_density()
138172

173+
def enable_FD(self, epsilon=None):
    """ Turn on finite difference approximation of the logd gradient.

    When enabled for a parameter, the FD approximation is used even if
    the _gradient method is implemented. Calling this method without
    arguments enables FD for all parameters with a step size of 1e-8.

    Parameters
    ----------
    epsilon : dict, *optional*
        Spacing (step size) of the finite difference approximation of the
        logd gradient, given per variable.
        Keys are variable names.
        Values are either a float, which enables FD with that step size,
        or None, which disables FD for that variable. Default is 1e-8 for
        all variables.
    """
    default_step = 1e-8
    if epsilon is None:
        epsilon = dict.fromkeys(self.get_parameter_names(), default_step)
    self.FD_epsilon = epsilon
194+
195+
def disable_FD(self):
    """ Turn off finite difference approximation of the logd gradient for
    every parameter by setting each FD_epsilon entry to None. """
    self.FD_epsilon = dict.fromkeys(self.get_parameter_names())
199+
139200
def get_parameter_names(self) -> List[str]:
    """ List the name of each distribution in the joint distribution,
    in the order the distributions are stored. """
    return [distribution.name for distribution in self._distributions]
@@ -202,34 +263,58 @@ def _reduce_to_single_density(self):
202263
# Count number of distributions and likelihoods
203264
n_dist = len(self._distributions)
204265
n_likelihood = len(self._likelihoods)
266+
reduced_FD_epsilon = {par_name:self.FD_epsilon[par_name] for par_name in self.get_parameter_names()}
267+
self.enable_FD(epsilon=reduced_FD_epsilon)
205268

206269
# Can't reduce if there are multiple distributions or likelihoods
207270
if n_dist > 1:
208271
return self
209272

273+
# If only evaluated densities left return joint to ensure logd method is available
274+
if n_dist == 0 and n_likelihood == 0:
275+
return self
276+
277+
# Extract the parameter name of the distribution
278+
if n_dist == 1:
279+
par_name = self._distributions[0].name
280+
elif n_likelihood == 1:
281+
par_name = self._likelihoods[0].name
282+
else:
283+
par_name = None
284+
210285
# If exactly one distribution and multiple likelihoods reduce
211286
if n_dist == 1 and n_likelihood > 1:
212-
return MultipleLikelihoodPosterior(*self._densities)
213-
287+
reduced_distribution = MultipleLikelihoodPosterior(*self._densities)
288+
reduced_FD_epsilon = {par_name:self.FD_epsilon[par_name]}
289+
214290
# If exactly one distribution and one likelihood it's a Posterior
215291
if n_dist == 1 and n_likelihood == 1:
216292
# Ensure parameter names match, otherwise return the joint distribution
217293
if set(self._likelihoods[0].get_parameter_names()) != set(self._distributions[0].get_parameter_names()):
218294
return self
219-
return self._add_constants_to_density(Posterior(self._likelihoods[0], self._distributions[0]))
295+
reduced_distribution = Posterior(self._likelihoods[0], self._distributions[0])
296+
reduced_distribution = self._add_constants_to_density(reduced_distribution)
297+
reduced_FD_epsilon = self.FD_epsilon[par_name]
220298

221299
# If exactly one distribution and no likelihoods it's a Distribution
222300
if n_dist == 1 and n_likelihood == 0:
223-
return self._add_constants_to_density(self._distributions[0])
224-
301+
# Intentionally skip enabling FD here. If the user wants FD, they
302+
# can enable it for this particular distribution before forming
303+
# the joint distribution.
304+
return self._add_constants_to_density(self._distributions[0])
305+
225306
# If no distributions and exactly one likelihood it's a Likelihood
226307
if n_likelihood == 1 and n_dist == 0:
227-
return self._likelihoods[0]
308+
# This case seems to not happen in practice, but we include it for
309+
# completeness.
310+
reduced_distribution = self._likelihoods[0]
311+
reduced_FD_epsilon = self.FD_epsilon[par_name]
312+
313+
if self.FD_enabled[par_name]:
314+
reduced_distribution.enable_FD(epsilon=reduced_FD_epsilon)
315+
316+
return reduced_distribution
228317

229-
# If only evaluated densities left return joint to ensure logd method is available
230-
if n_dist == 0 and n_likelihood == 0:
231-
return self
232-
233318
def _add_constants_to_density(self, density: Density):
234319
""" Add the constants (evaluated densities) to a single density. Used when reducing to single density. """
235320

@@ -274,7 +359,7 @@ def __repr__(self):
274359
if len(cond_vars) > 0:
275360
msg += f"|{cond_vars}"
276361
msg += ")"
277-
362+
278363
msg += "\n"
279364
msg += " Densities: \n"
280365

tests/test_joint_distribution.py

Lines changed: 173 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,4 +484,176 @@ def test_joint_distribution_with_multiple_inputs_model_has_correct_parameter_nam
484484

485485
assert joint_dist(x_dist=x_val, y_dist=y_val, data_dist=np.array([2,2,3])).likelihood.get_parameter_names() == ['z_dist']
486486
assert joint_dist(x_dist=x_val, z_dist=z_val, data_dist=np.array([2,2,3])).likelihood.get_parameter_names() == ['y_dist']
487-
assert joint_dist(y_dist=y_val, z_dist=z_val, data_dist=np.array([2,2,3])).likelihood.get_parameter_names() == ['x_dist']
487+
assert joint_dist(y_dist=y_val, z_dist=z_val, data_dist=np.array([2,2,3])).likelihood.get_parameter_names() == ['x_dist']
488+
489+
490+
def test_FD_enabled_is_set_correctly():
    """ Test that FD_enabled and FD_epsilon are set correctly in
    JointDistribution, both on the joint itself and on the densities it
    reduces to after conditioning. """

    # Create a joint distribution with two distributions
    d1 = cuqi.distribution.Normal(0, 1, name="x")
    d2 = cuqi.distribution.Gamma(lambda x: x**2, 1, name="y")
    J = cuqi.distribution.JointDistribution(d1, d2)

    # Initially FD should be disabled for both
    assert J.FD_enabled == {"x": False, "y": False}

    # Enable FD for x
    J.enable_FD(epsilon={"x": 1e-6, "y": None})
    assert J.FD_enabled == {"x": True, "y": False}
    assert J.FD_epsilon == {"x": 1e-6, "y": None}

    # Enable FD for y as well
    J.enable_FD(epsilon={"x": 1e-6, "y": 1e-5})
    assert J.FD_enabled == {"x": True, "y": True}
    assert J.FD_epsilon == {"x": 1e-6, "y": 1e-5}

    # Disable FD for x
    J.enable_FD(epsilon={"x": None, "y": 1e-5})
    assert J.FD_enabled == {"x": False, "y": True}
    assert J.FD_epsilon == {"x": None, "y": 1e-5}

    # Disable FD for all
    J.disable_FD()
    assert J.FD_enabled == {"x": False, "y": False}
    assert J.FD_epsilon == {"x": None, "y": None}

    # Enable FD and reduce to single density
    J.enable_FD()  # Enable FD for all
    J_given_x = J(x=0)
    J_given_y = J(y=1)

    # Check types and FD_enabled status of J_given_x
    assert isinstance(J_given_x, cuqi.distribution.Gamma)
    # FD is intentionally left disabled for a single remaining distribution
    assert not J_given_x.FD_enabled
    assert J_given_x.FD_epsilon is None

    # Check types and FD_enabled status of J_given_y
    assert isinstance(J_given_y, cuqi.distribution.Posterior)
    assert J_given_y.FD_enabled
    assert J_given_y.FD_epsilon == 1e-8  # Default epsilon for remaining density

    # Catch error if epsilon keys do not match parameter names
    with pytest.raises(ValueError, match=r"Keys of FD_epsilon must match"):
        J.enable_FD(epsilon={"x": 1e-6})  # Missing "y" key
540+
541+
def test_FD_enabled_is_set_correctly_for_stacked_joint_distribution():
    """ Test that FD_enabled and FD_epsilon are set correctly in
    _StackedJointDistribution, and that FD stays disabled on the single
    distribution left after conditioning. """

    # Create a stacked joint distribution with two distributions
    x = cuqi.distribution.Normal(0, 1, name="x")
    y = cuqi.distribution.Uniform(1, 2, name="y")
    J = cuqi.distribution._StackedJointDistribution(x, y)
    J.enable_FD(epsilon={"x": 1e-6, "y": None})

    assert J.FD_enabled == {"x": True, "y": False}
    assert J.FD_epsilon == {"x": 1e-6, "y": None}

    # Reduce to single density (substitute y)
    J_given_y = J(y=1.5)
    assert isinstance(J_given_y, cuqi.distribution.Normal)
    # FD is intentionally left disabled for a single remaining distribution
    assert not J_given_y.FD_enabled
    assert J_given_y.FD_epsilon is None

    # Reduce to single density (substitute x)
    J_given_x = J(x=0)
    assert isinstance(J_given_x, cuqi.distribution.Uniform)
    assert not J_given_x.FD_enabled
    assert J_given_x.FD_epsilon is None
566+
567+
568+
569+
@pytest.mark.parametrize(
    "densities,kwargs,fd_epsilon,expected_type,expected_fd_enabled",
    [
        # Case 0: Single Distribution, FD enabled
        (
            [cuqi.distribution.Normal(np.zeros(3), 1, name="x")],
            {},
            {"x": 1e-5},
            cuqi.distribution.Normal,
            False,  # Intentionally disabled for single remaining distribution
        ),
        # Case 1: Single Distribution, FD disabled
        (
            [cuqi.distribution.Normal(np.zeros(3), 1, name="x")],
            {},
            {"x": None},
            cuqi.distribution.Normal,
            False,
        ),
        # Case 2: Distribution + Data distribution, substitute y
        (
            [
                cuqi.distribution.Normal(np.zeros(3), 1, name="x"),
                cuqi.distribution.Gaussian(lambda x: x**2, np.ones(3), name="y"),
            ],
            {"y": np.ones(3)},
            {"x": 1e-6, "y": 1e-7},
            cuqi.distribution.Posterior,
            True,
        ),
        # Case 3: Distribution + data distribution, substitute x
        (
            [
                cuqi.distribution.Normal(np.zeros(3), 1, name="x"),
                cuqi.distribution.Gaussian(lambda x: x**2, np.ones(3), name="y"),
            ],
            {"x": np.ones(3)},
            {"x": 1e-5, "y": 1e-6},
            cuqi.distribution.Distribution,
            False,  # Intentionally disabled for single remaining distribution
        ),
        # Case 4: Multiple data distributions + prior (MultipleLikelihoodPosterior)
        (
            [
                cuqi.distribution.Normal(np.zeros(3), 1, name="x"),
                cuqi.distribution.Gaussian(lambda x: x, np.ones(3), name="y1"),
                cuqi.distribution.Gaussian(lambda x: x + 1, np.ones(3), name="y2"),
            ],
            {"y1": np.ones(3), "y2": np.ones(3)},
            {"x": 1e-5, "y1": 1e-6, "y2": 1e-7},
            cuqi.distribution.MultipleLikelihoodPosterior,
            {"x": True},
        ),
        # Case 5: Distribution, substitute x
        (
            [cuqi.distribution.Normal(np.zeros(3), 1, name="x")],
            {"x": np.ones(3)},
            {"x": 1e-8},
            cuqi.distribution.JointDistribution,
            {},
        ),
    ],
)
def test_fd_enabled_of_joint_distribution_after_substitution_is_correct(
    densities, kwargs, fd_epsilon, expected_type, expected_fd_enabled
):
    """ Test that FD_enabled and FD_epsilon properties are set correctly in JointDistribution even after substitution.

    expected_fd_enabled is a bool in the cases expecting a single
    distribution or a Posterior, and a dict keyed by parameter name in the
    cases expecting a MultipleLikelihoodPosterior or a JointDistribution.
    """
    joint = cuqi.distribution.JointDistribution(*densities)
    joint.enable_FD(epsilon=fd_epsilon)

    # Assert FD_epsilon is set correctly
    assert joint.FD_epsilon == fd_epsilon

    # Substitute parameters (if any), which reduces the joint distribution
    reduced = joint(**kwargs)

    # Assert the type and FD_enabled status of the reduced distribution
    assert isinstance(reduced, expected_type)
    assert reduced.FD_enabled == expected_fd_enabled

    # Assert FD_epsilon is set correctly in the reduced distribution.
    # Skipped when FD is expected to be disabled (expected_fd_enabled is
    # exactly False); an empty dict still goes through the check.
    if expected_fd_enabled is not False:
        fd_epsilon_reduced = {
            k: v for k, v in fd_epsilon.items() if k not in kwargs
        }
        if len(fd_epsilon_reduced) == 1 and not isinstance(
            reduced, cuqi.distribution.MultipleLikelihoodPosterior
        ):
            # Single value instead of dict in this case
            fd_epsilon_reduced = next(iter(fd_epsilon_reduced.values()))
        assert reduced.FD_epsilon == fd_epsilon_reduced

0 commit comments

Comments
 (0)