Skip to content

Commit 6325eb7

Browse files
authored
Merge pull request #597 from CUQI-DTU/add_577_random_variable_class
Add random variable class
2 parents 6f80078 + 7062b88 commit 6325eb7

File tree

9 files changed

+857
-49
lines changed

9 files changed

+857
-49
lines changed

cuqi/distribution/_joint_distribution.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from cuqi.distribution import Distribution, Posterior
66
from cuqi.likelihood import Likelihood
77
from cuqi.geometry import Geometry, _DefaultGeometry1D
8+
import cuqi
89
import numpy as np # for splitting array. Can avoid.
910

1011
class JointDistribution:
@@ -13,9 +14,11 @@ class JointDistribution:
1314
1415
Parameters
1516
----------
16-
densities : Density
17+
densities : RandomVariable or Density
1718
The densities to include in the joint distribution.
18-
Each density is passed as comma-separated arguments.
19+
Each density is passed as comma-separated arguments,
20+
and can be either a :class:'Density' such as :class:'Distribution'
21+
or :class:`RandomVariable`.
1922
2023
Notes
2124
-----
@@ -59,7 +62,16 @@ class JointDistribution:
5962
posterior = joint(y=y_obs)
6063
6164
"""
62-
def __init__(self, *densities: Density):
65+
def __init__(self, *densities: [Density, cuqi.experimental.algebra.RandomVariable]):
66+
""" Create a joint distribution from the given densities. """
67+
68+
# Check if all RandomVariables are simple (not-transformed)
69+
for density in densities:
70+
if isinstance(density, cuqi.experimental.algebra.RandomVariable) and density.is_transformed:
71+
raise ValueError(f"To be used in {self.__class__.__name__}, all RandomVariables must be untransformed.")
72+
73+
# Convert potential random variables to their underlying distribution
74+
densities = [density.distribution if isinstance(density, cuqi.experimental.algebra.RandomVariable) else density for density in densities]
6375

6476
# Ensure all densities have unique names
6577
names = [density.name for density in densities]
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from ._ast import VariableNode
1+
from ._ast import VariableNode, Node
2+
from ._randomvariable import RandomVariable

cuqi/experimental/algebra/_ast.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,20 @@ def __repr__(self):
5656
"""String representation of the node. Used for printing the AST."""
5757
pass
5858

59+
def get_variables(self, variables=None):
60+
"""Returns a set with the names of all variables in the sub-tree originated at this node."""
61+
if variables is None:
62+
variables = set()
63+
if isinstance(self, VariableNode):
64+
variables.add(self.name)
65+
if hasattr(self, "child"):
66+
self.child.get_variables(variables)
67+
if hasattr(self, "left"):
68+
self.left.get_variables(variables)
69+
if hasattr(self, "right"):
70+
self.right.get_variables(variables)
71+
return variables
72+
5973
def __add__(self, other):
6074
return AddNode(self, convert_to_node(other))
6175

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
class _OrderedSet:
2+
"""A set (i.e. unique elements) that keeps its elements in the order they were added.
3+
4+
This is a minimal implementation of an ordered set, using a dictionary for storage.
5+
"""
6+
7+
def __init__(self, iterable=None):
8+
"""Initialize the OrderedSet.
9+
10+
If an iterable is provided, add all its elements to the set.
11+
"""
12+
self.dict = dict.fromkeys(iterable if iterable else [])
13+
14+
def add(self, item):
15+
"""Add an item to the set.
16+
17+
If the item is already in the set, it does nothing.
18+
Otherwise, the item is stored as a key in the dictionary, with None as its value.
19+
"""
20+
self.dict[item] = None
21+
22+
def __contains__(self, item):
23+
"""Check if an item is in the set.
24+
25+
This is equivalent to checking if the item is a key in the dictionary.
26+
"""
27+
return item in self.dict
28+
29+
def __iter__(self):
30+
"""Return an iterator over the set.
31+
32+
This iterates over the keys in the dictionary.
33+
"""
34+
return iter(self.dict)
35+
36+
def __len__(self):
37+
"""Return the number of items in the set."""
38+
return len(self.dict)
39+
40+
def extend(self, other):
41+
"""Extend the set with the items in another set.
42+
43+
Raises a TypeError if the other object is not an _OrderedSet.
44+
"""
45+
if not isinstance(other, _OrderedSet):
46+
raise TypeError("unsupported operand type(s) for extend: '_OrderedSet' and '{}'".format(type(other).__name__))
47+
for item in other:
48+
self.add(item)
49+
50+
def __or__(self, other):
51+
"""Return a new set that is the union of this set and another set.
52+
53+
Raises a TypeError if the other object is not an _OrderedSet.
54+
"""
55+
if not isinstance(other, _OrderedSet):
56+
raise TypeError("unsupported operand type(s) for |: '_OrderedSet' and '{}'".format(type(other).__name__))
57+
new_set = _OrderedSet(self.dict.keys())
58+
new_set.extend(other)
59+
return new_set

0 commit comments

Comments
 (0)