Skip to content

Commit 4d0949e

Browse files
authored
Merge pull request #600 from CUQI-DTU/add_conditioning_RV_578
Add support for conditioning on random variables
2 parents 4ef35e3 + 3d3bdb4 commit 4d0949e

File tree

6 files changed

+187
-16
lines changed

6 files changed

+187
-16
lines changed

cuqi/experimental/algebra/_ast.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ def __call__(self, **kwargs):
5151
"""Evaluate node at a given parameter value. This will traverse the sub-tree originated at this node and evaluate it given the recorded operations."""
5252
pass
5353

54+
@abstractmethod
55+
def condition(self, **kwargs):
56+
""" Conditions the tree by replacing any VariableNode with a ValueNode if the variable is in the kwargs dictionary. """
57+
pass
58+
5459
@abstractmethod
5560
def __repr__(self):
5661
"""String representation of the node. Used for printing the AST."""
@@ -129,6 +134,9 @@ class UnaryNode(Node, ABC):
129134
def __init__(self, child: Node):
130135
self.child = child
131136

137+
def condition(self, **kwargs):
138+
return self.__class__(self.child.condition(**kwargs))
139+
132140

133141
class BinaryNode(Node, ABC):
134142
"""Base class for all binary nodes in the abstract syntax tree.
@@ -155,6 +163,9 @@ def __init__(self, left: Node, right: Node):
155163
self.left = left
156164
self.right = right
157165

166+
def condition(self, **kwargs):
167+
return self.__class__(self.left.condition(**kwargs), self.right.condition(**kwargs))
168+
158169
def __repr__(self):
159170
return f"{self.left} {self.op_symbol} {self.right}"
160171

@@ -205,6 +216,11 @@ def __call__(self, **kwargs):
205216
)
206217
return kwargs[self.name]
207218

219+
def condition(self, **kwargs):
220+
if self.name in kwargs:
221+
return ValueNode(kwargs[self.name])
222+
return self
223+
208224
def __repr__(self):
209225
return self.name
210226

@@ -226,6 +242,9 @@ def __call__(self, **kwargs):
226242
"""Returns the value of the node."""
227243
return self.value
228244

245+
def condition(self, **kwargs):
246+
return self
247+
229248
def __repr__(self):
230249
return str(self.value)
231250

cuqi/experimental/algebra/_orderedset.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ def add(self, item):
1919
"""
2020
self.dict[item] = None
2121

22+
def remove(self, item):
23+
"""Remove an item from the set.
24+
25+
If the item is not in the set, it raises a KeyError.
26+
"""
27+
del self.dict[item]
28+
2229
def __contains__(self, item):
2330
"""Check if an item is in the set.
2431
@@ -47,6 +54,18 @@ def extend(self, other):
4754
for item in other:
4855
self.add(item)
4956

57+
def replace(self, old_item, new_item):
58+
"""Replace old_item with new_item at the same position, preserving order."""
59+
if old_item not in self.dict:
60+
raise KeyError(f"{old_item} not in set")
61+
62+
items = list(self.dict.keys()) # Preserve order
63+
index = items.index(old_item) # Find position
64+
items[index] = new_item # Replace at the same position
65+
66+
# Reconstruct the ordered set with the new item in place
67+
self.dict = dict.fromkeys(items)
68+
5069
def __or__(self, other):
5170
"""Return a new set that is the union of this set and another set.
5271
@@ -57,3 +76,7 @@ def __or__(self, other):
5776
new_set = _OrderedSet(self.dict.keys())
5877
new_set.extend(other)
5978
return new_set
79+
80+
def __repr__(self):
81+
"""Return a string representation of the set."""
82+
return "_OrderedSet({})".format(list(self.dict.keys()))

cuqi/experimental/algebra/_randomvariable.py

Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
import operator
66
import cuqi
77
from cuqi.distribution import Distribution
8-
from copy import copy
9-
8+
from copy import copy, deepcopy
109

1110
class RandomVariable:
1211
""" Random variable defined by a distribution with the option to apply algebraic operations on it.
@@ -210,7 +209,7 @@ def distributions(self) -> set:
210209
def parameter_names(self) -> str:
211210
""" Name of the parameter that the random variable can be evaluated at. """
212211
self._inject_name_into_distribution()
213-
return [distribution.name for distribution in self.distributions] # Consider renaming .name to .par_name for distributions
212+
return [distribution._name for distribution in self.distributions] # Consider renaming .name to .par_name for distributions
214213

215214
@property
216215
def dim(self):
@@ -239,21 +238,89 @@ def expression(self):
239238
def is_transformed(self):
240239
""" Returns True if the random variable is transformed. """
241240
return not isinstance(self.tree, VariableNode)
242-
241+
242+
@property
243+
def is_cond(self):
244+
""" Returns True if the random variable is a conditional random variable. """
245+
return any(dist.is_cond for dist in self.distributions)
246+
247+
def condition(self, *args, **kwargs):
248+
"""Condition the random variable on a given value. Only one of either positional or keyword arguments can be passed.
249+
250+
Parameters
251+
----------
252+
*args : Any
253+
Positional arguments to condition the random variable on. The order of the arguments must match the order of the parameter names.
254+
255+
**kwargs : Any
256+
Keyword arguments to condition the random variable on. The keys must match the parameter names.
257+
258+
"""
259+
260+
# Before conditioning, capture repr to ensure all variable names are injected
261+
self.__repr__()
262+
263+
if args and kwargs:
264+
raise ValueError("Cannot pass both positional and keyword arguments to RandomVariable")
265+
266+
if args:
267+
kwargs = self._parse_args_add_to_kwargs(args, kwargs)
268+
269+
# Create a deep copy of the random variable to ensure the original tree is not modified
270+
new_variable = self._make_copy(deep=True)
271+
272+
for kwargs_name in list(kwargs.keys()):
273+
value = kwargs.pop(kwargs_name)
274+
275+
# Condition the tree turning the variable into a constant
276+
if kwargs_name in self.parameter_names:
277+
new_variable._tree = new_variable.tree.condition(**{kwargs_name: value})
278+
279+
# Condition the random variable on both the distribution parameter name and distribution conditioning variables
280+
for dist in self.distributions:
281+
if kwargs_name == dist.name:
282+
new_variable._remove_distribution(dist.name)
283+
elif kwargs_name in dist.get_conditioning_variables():
284+
new_variable._replace_distribution(dist.name, dist(**{kwargs_name: value}))
285+
286+
# Check if any kwargs are left unprocessed
287+
if kwargs:
288+
raise ValueError(f"Conditioning variables {list(kwargs.keys())} not found in the random variable {self}")
289+
290+
return new_variable
291+
243292
@property
244293
def _non_default_args(self) -> List[str]:
245294
"""List of non-default arguments to distribution. This is used to return the correct
246295
arguments when evaluating the random variable.
247296
"""
248297
return self.parameter_names
249298

299+
def _replace_distribution(self, name, new_distribution):
300+
""" Replace distribution with a given name with a new distribution in the same position of the ordered set. """
301+
for dist in self.distributions:
302+
if dist._name == name:
303+
self._distributions.replace(dist, new_distribution)
304+
break
305+
306+
def _remove_distribution(self, name):
307+
""" Remove distribution with a given name from the set of distributions. """
308+
for dist in self.distributions:
309+
if dist._name == name:
310+
self._distributions.remove(dist)
311+
break
312+
250313
def _inject_name_into_distribution(self, name=None):
251314
if len(self._distributions) == 1:
252315
dist = next(iter(self._distributions))
316+
317+
if dist._is_copy:
318+
dist = dist._original_density
319+
253320
if dist._name is None:
254321
if name is None:
255322
name = self.name
256-
dist._name = name
323+
dist.name = name # Inject using setter
257324

258325
def _parse_args_add_to_kwargs(self, args, kwargs) -> dict:
259326
""" Parse args and add to kwargs if any. Arguments follow self.parameter_names order. """
@@ -293,8 +360,12 @@ def _is_copy(self):
293360
""" Returns True if this is a copy of another random variable, e.g. by conditioning. """
294361
return hasattr(self, '_original_variable') and self._original_variable is not None
295362

296-
def _make_copy(self):
297-
""" Returns a shallow copy of the density keeping a pointer to the original. """
363+
def _make_copy(self, deep=False) -> 'RandomVariable':
364+
""" Returns a copy of the density keeping a pointer to the original. """
365+
if deep:
366+
new_variable = deepcopy(self)
367+
new_variable._original_variable = self
368+
return new_variable
298369
new_variable = copy(self)
299370
new_variable._distributions = copy(self.distributions)
300371
new_variable._tree = copy(self._tree)

cuqi/likelihood/_likelihood.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ def name(self):
4343
def name(self, value):
4444
self.distribution.name = value
4545

46+
@property
47+
def _name(self):
48+
return self.distribution._name
49+
50+
@_name.setter
51+
def _name(self, value):
52+
self.distribution._name = value
53+
4654
@property
4755
def FD_enabled(self):
4856
""" Return FD_enabled of the likelihood from the underlying distribution """

demos/howtos/algebra.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,41 @@
8383
BP.UQ(exact={"x": info.exactSolution})
8484

8585
# %%
86+
# Conditioning on random variables (example 1)
87+
from cuqi.distribution import Gaussian
88+
from cuqi.experimental.algebra import RandomVariable
89+
90+
x = RandomVariable(Gaussian(0, lambda s: s))
91+
y = RandomVariable(Gaussian(0, lambda d: d))
92+
93+
z = x+y
94+
95+
z.condition(x=1)
96+
97+
# %%
98+
# Or conditioning on the variables s, or d
99+
z.condition(s=1)
100+
101+
# %%
102+
# Conditioning on random variables (example 2)
103+
from cuqi.testproblem import Deconvolution1D
104+
from cuqi.distribution import Gaussian, Gamma, GMRF
105+
from cuqi.experimental.algebra import RandomVariable
106+
from cuqi.problem import BayesianProblem
107+
import numpy as np
108+
109+
# Forward model
110+
A, y_obs, info = Deconvolution1D(dim=4).get_components()
111+
112+
# Bayesian Problem (defined using Random Variables)
113+
d = RandomVariable(Gamma(1, 1e-4))
114+
s = RandomVariable(Gamma(1, 1e-4))
115+
x = RandomVariable(GMRF(np.zeros(A.domain_dim), d))
116+
y = RandomVariable(Gaussian(A @ x, 1/s))
117+
118+
119+
z = x+y
120+
121+
z.condition(x=np.zeros(A.domain_dim))
122+
123+
# %%

tests/zexperimental/test_randomvariable.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ def test_randomvariable_returns_correct_parameter_name():
7272
z = RandomVariable(cuqi.distribution.Gaussian(0, 1))
7373
assert cuqi.utilities.get_non_default_args(z) == ["z"]
7474

75-
@pytest.mark.xfail(reason="Conditional random variables are not yet implemented")
7675
@pytest.mark.parametrize("operations", [
7776
lambda x: x+1,
7877
lambda x: x**2,
@@ -258,7 +257,6 @@ def recursive_return_rv(rv, recursions):
258257

259258
recursive_return_rv(h, 10)
260259

261-
@pytest.mark.xfail(reason="Conditional random variables are not yet implemented")
262260
def test_rv_name_consistency():
263261

264262
x = RandomVariable(cuqi.distribution.Gaussian(geometry=1))
@@ -306,13 +304,27 @@ def test_ensure_that_RV_evaluation_requires_all_parameters():
306304
with pytest.raises(ValueError, match=r"Expected arguments \['x', 'y'\], got arguments \{'x': 1\}"):
307305
z(x=1)
308306

309-
@pytest.mark.xfail(reason="Conditional random variables are not yet implemented")
310307
def test_RV_sets_name_of_internal_conditional_density_if_par_name_not_set_and_does_not_set_original_density():
311-
Z_s = cuqi.distribution.Gaussian(0, lambda s: s)
312-
Z = Z_s(s=3)
313-
z = RandomVariable(Z)
308+
# Case 1
309+
z = RandomVariable(cuqi.distribution.Gaussian(0, lambda s: s))
314310

315311
assert z.name == 'z'
316-
assert z.distribution.par_name == "z"
317-
assert Z_s.par_name is None # Should not be set for the original density.
318-
assert Z.par_name is None # Should not be set for the conditioned density.
312+
assert z.distribution.name == "z"
313+
314+
# Case 2 (conditioned density. Should not be able to set name here)
315+
z = RandomVariable(cuqi.distribution.Gaussian(0, lambda s: s)(s=3))
316+
317+
assert z.name == 'z'
318+
assert z.distribution.name == "z"
319+
320+
def test_RV_condition_maintains_parameter_name_order():
321+
322+
x = RandomVariable(cuqi.distribution.Gaussian(0, lambda s: s))
323+
y = RandomVariable(cuqi.distribution.Gaussian(0, lambda d: d))
324+
325+
z = x+y
326+
327+
assert z.parameter_names == ['x', 'y']
328+
assert z.condition(s=1).parameter_names == ['x', 'y']
329+
assert z.condition(d=1).parameter_names == ['x', 'y']
330+
assert z.condition(d=1, s=1).parameter_names == ['x', 'y']

0 commit comments

Comments
 (0)