|
5 | 5 | import operator |
6 | 6 | import cuqi |
7 | 7 | from cuqi.distribution import Distribution |
8 | | -from copy import copy |
9 | | - |
| 8 | +from copy import copy, deepcopy |
10 | 9 |
|
11 | 10 | class RandomVariable: |
12 | 11 | """ Random variable defined by a distribution with the option to apply algebraic operations on it. |
@@ -210,7 +209,7 @@ def distributions(self) -> set: |
210 | 209 | def parameter_names(self) -> str: |
211 | 210 | """ Name of the parameter that the random variable can be evaluated at. """ |
212 | 211 | self._inject_name_into_distribution() |
213 | | - return [distribution.name for distribution in self.distributions] # Consider renaming .name to .par_name for distributions |
| 212 | + return [distribution._name for distribution in self.distributions] # Consider renaming .name to .par_name for distributions |
214 | 213 |
|
215 | 214 | @property |
216 | 215 | def dim(self): |
@@ -239,21 +238,89 @@ def expression(self): |
239 | 238 | def is_transformed(self): |
240 | 239 | """ Returns True if the random variable is transformed. """ |
241 | 240 | return not isinstance(self.tree, VariableNode) |
242 | | - |
| 241 | + |
| 242 | + @property |
| 243 | + def is_cond(self): |
| 244 | + """ Returns True if the random variable is a conditional random variable. """ |
| 245 | + return any(dist.is_cond for dist in self.distributions) |
| 246 | + |
| 247 | + def condition(self, *args, **kwargs): |
| 248 | + """Condition the random variable on a given value. Only one of either positional or keyword arguments can be passed. |
| 249 | + |
| 250 | + Parameters |
| 251 | + ---------- |
| 252 | + *args : Any |
| 253 | + Positional arguments to condition the random variable on. The order of the arguments must match the order of the parameter names. |
| 254 | +
|
| 255 | + **kwargs : Any |
| 256 | + Keyword arguments to condition the random variable on. The keys must match the parameter names. |
| 257 | + |
| 258 | + """ |
| 259 | + |
| 260 | + # Before conditioning, capture repr to ensure all variable names are injected |
| 261 | + self.__repr__() |
| 262 | + |
| 263 | + if args and kwargs: |
| 264 | + raise ValueError("Cannot pass both positional and keyword arguments to RandomVariable") |
| 265 | + |
| 266 | + if args: |
| 267 | + kwargs = self._parse_args_add_to_kwargs(args, kwargs) |
| 268 | + |
| 269 | + # Create a deep copy of the random variable to ensure the original tree is not modified |
| 270 | + new_variable = self._make_copy(deep=True) |
| 271 | + |
| 272 | + for kwargs_name in list(kwargs.keys()): |
| 273 | + value = kwargs.pop(kwargs_name) |
| 274 | + |
| 275 | + # Condition the tree turning the variable into a constant |
| 276 | + if kwargs_name in self.parameter_names: |
| 277 | + new_variable._tree = new_variable.tree.condition(**{kwargs_name: value}) |
| 278 | + |
| 279 | + # Condition the random variable on both the distribution parameter name and distribution conditioning variables |
| 280 | + for dist in self.distributions: |
| 281 | + if kwargs_name == dist.name: |
| 282 | + new_variable._remove_distribution(dist.name) |
| 283 | + elif kwargs_name in dist.get_conditioning_variables(): |
| 284 | + new_variable._replace_distribution(dist.name, dist(**{kwargs_name: value})) |
| 285 | + |
| 286 | + # Check if any kwargs are left unprocessed |
| 287 | + if kwargs: |
| 288 | + raise ValueError(f"Conditioning variables {list(kwargs.keys())} not found in the random variable {self}") |
| 289 | + |
| 290 | + return new_variable |
| 291 | + |
243 | 292 | @property |
244 | 293 | def _non_default_args(self) -> List[str]: |
245 | 294 | """List of non-default arguments to distribution. This is used to return the correct |
246 | 295 | arguments when evaluating the random variable. |
247 | 296 | """ |
248 | 297 | return self.parameter_names |
249 | 298 |
|
| 299 | + def _replace_distribution(self, name, new_distribution): |
| 300 | + """ Replace distribution with a given name with a new distribution in the same position of the ordered set. """ |
| 301 | + for dist in self.distributions: |
| 302 | + if dist._name == name: |
| 303 | + self._distributions.replace(dist, new_distribution) |
| 304 | + break |
| 305 | + |
| 306 | + def _remove_distribution(self, name): |
| 307 | + """ Remove distribution with a given name from the set of distributions. """ |
| 308 | + for dist in self.distributions: |
| 309 | + if dist._name == name: |
| 310 | + self._distributions.remove(dist) |
| 311 | + break |
| 312 | + |
250 | 313 | def _inject_name_into_distribution(self, name=None): |
251 | 314 | if len(self._distributions) == 1: |
252 | 315 | dist = next(iter(self._distributions)) |
| 316 | + |
| 317 | + if dist._is_copy: |
| 318 | + dist = dist._original_density |
| 319 | + |
253 | 320 | if dist._name is None: |
254 | 321 | if name is None: |
255 | 322 | name = self.name |
256 | | - dist._name = name |
| 323 | + dist.name = name # Inject using setter |
257 | 324 |
|
258 | 325 | def _parse_args_add_to_kwargs(self, args, kwargs) -> dict: |
259 | 326 | """ Parse args and add to kwargs if any. Arguments follow self.parameter_names order. """ |
@@ -293,8 +360,12 @@ def _is_copy(self): |
293 | 360 | """ Returns True if this is a copy of another random variable, e.g. by conditioning. """ |
294 | 361 | return hasattr(self, '_original_variable') and self._original_variable is not None |
295 | 362 |
|
296 | | - def _make_copy(self): |
297 | | - """ Returns a shallow copy of the density keeping a pointer to the original. """ |
| 363 | + def _make_copy(self, deep=False) -> 'RandomVariable': |
| 364 | + """ Returns a copy of the density keeping a pointer to the original. """ |
| 365 | + if deep: |
| 366 | + new_variable = deepcopy(self) |
| 367 | + new_variable._original_variable = self |
| 368 | + return new_variable |
298 | 369 | new_variable = copy(self) |
299 | 370 | new_variable._distributions = copy(self.distributions) |
300 | 371 | new_variable._tree = copy(self._tree) |
|
0 commit comments