Skip to content

Commit f6a73b0

Browse files
authored
Merge pull request #664 from CUQI-DTU/multiple_input_bayesian_modeling
Multiple input bayesian modeling
2 parents 2071a1f + 00f6ec4 commit f6a73b0

File tree

9 files changed

+1012
-101
lines changed

9 files changed

+1012
-101
lines changed

cuqi/distribution/_distribution.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def geometry(self):
105105
f"Inconsistent distribution geometry attribute {self._geometry} and inferred "
106106
f"dimension from distribution variables {inferred_dim}."
107107
)
108-
108+
109109
# If Geometry dimension is None, update it with the inferred dimension
110110
if inferred_dim and self._geometry.par_dim is None:
111111
self.geometry = inferred_dim
@@ -117,7 +117,7 @@ def geometry(self):
117117
# We do not use self.name to potentially infer it from python stack.
118118
if self._name:
119119
self._geometry._variable_name = self._name
120-
120+
121121
return self._geometry
122122

123123
@geometry.setter
@@ -160,15 +160,15 @@ def logd(self, *args, **kwargs):
160160
f"{self.logd.__qualname__}: To evaluate the log density all conditioning variables and main"
161161
f" parameter must be specified. Conditioning variables are: {cond_vars}"
162162
)
163-
163+
164164
# Check if all conditioning variables are specified
165165
all_cond_vars_specified = all([key in kwargs for key in cond_vars])
166166
if not all_cond_vars_specified:
167167
raise ValueError(
168168
f"{self.logd.__qualname__}: To evaluate the log density all conditioning variables must be"
169169
f" specified. Conditioning variables are: {cond_vars}"
170170
)
171-
171+
172172
# Extract exactly the conditioning variables from kwargs
173173
cond_kwargs = {key: kwargs[key] for key in cond_vars}
174174

@@ -186,7 +186,7 @@ def logd(self, *args, **kwargs):
186186
# Not conditional distribution, simply evaluate log density directly
187187
else:
188188
return super().logd(*args, **kwargs)
189-
189+
190190
def _logd(self, *args):
191191
return self.logpdf(*args) # Currently all distributions implement logpdf so we simply call this method.
192192

@@ -216,7 +216,7 @@ def sample(self,N=1,*args,**kwargs):
216216
# Get samples from the distribution sample method
217217
s = self._sample(N,*args,**kwargs)
218218

219-
#Store samples in cuqi samples object if more than 1 sample
219+
# Store samples in cuqi samples object if more than 1 sample
220220
if N==1:
221221
if len(s) == 1 and isinstance(s,np.ndarray): #Extract single value from numpy array
222222
s = s.ravel()[0]
@@ -264,7 +264,7 @@ def _condition(self, *args, **kwargs):
264264
# Go through every mutable variable and assign value from kwargs if present
265265
for var_key in mutable_vars:
266266

267-
#If keyword directly specifies new value of variable we simply reassign
267+
# If keyword directly specifies new value of variable we simply reassign
268268
if var_key in kwargs:
269269
setattr(new_dist, var_key, kwargs.get(var_key))
270270
processed_kwargs.add(var_key)
@@ -291,9 +291,18 @@ def _condition(self, *args, **kwargs):
291291

292292
elif len(var_args)>0: #Some keywords found
293293
# Define new partial function with partially defined args
294-
func = partial(var_val, **var_args)
294+
if (
295+
hasattr(var_val, "_supports_partial_eval")
296+
and var_val._supports_partial_eval
297+
):
298+
func = var_val(**var_args)
299+
else:
300+
# If the callable does not support partial evaluation,
301+
# we use the partial function to set the variable
302+
func = partial(var_val, **var_args)
303+
295304
setattr(new_dist, var_key, func)
296-
305+
297306
# Store processed keywords
298307
processed_kwargs.update(var_args.keys())
299308

@@ -329,7 +338,7 @@ def __call__(self, *args, **kwargs) -> Union[Distribution, Likelihood, Evaluated
329338

330339
def get_conditioning_variables(self):
331340
"""Return the conditioning variables of this distribution (if any)."""
332-
341+
333342
# Get all mutable variables
334343
mutable_vars = self.get_mutable_variables()
335344

@@ -338,7 +347,7 @@ def get_conditioning_variables(self):
338347

339348
# Add any variables defined through callable functions
340349
cond_vars += get_indirect_variables(self)
341-
350+
342351
return cond_vars
343352

344353
def get_mutable_variables(self):
@@ -347,10 +356,10 @@ def get_mutable_variables(self):
347356
# If mutable variables are already cached, return them
348357
if hasattr(self, '_mutable_vars'):
349358
return self._mutable_vars
350-
359+
351360
# Define list of ignored attributes and properties
352361
ignore_vars = ['name', 'is_symmetric', 'geometry', 'dim']
353-
362+
354363
# Get public attributes
355364
attributes = get_writeable_attributes(self)
356365

@@ -396,7 +405,7 @@ def _parse_args_add_to_kwargs(self, cond_vars, *args, **kwargs):
396405
raise ValueError(f"{self._condition.__qualname__}: {ordered_keys[index]} passed as both argument and keyword argument.\nArguments follow the listed conditioning variable order: {self.get_conditioning_variables()}")
397406
kwargs[ordered_keys[index]] = arg
398407
return kwargs
399-
408+
400409
def _check_geometry_consistency(self):
401410
""" Checks that the geometry of the distribution is consistent by calling the geometry property. Should be called at the end of __init__ of subclasses. """
402411
self.geometry
@@ -411,4 +420,4 @@ def __repr__(self) -> str:
411420
def rv(self):
412421
""" Return a random variable object representing the distribution. """
413422
from cuqi.experimental.algebra import RandomVariable
414-
return RandomVariable(self)
423+
return RandomVariable(self)

cuqi/likelihood/_likelihood.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,4 +212,4 @@ def get_parameter_names(self):
212212
return get_non_default_args(self.logpdf_func)
213213

214214
def __repr__(self) -> str:
215-
return "CUQI {} function. Parameters {}.".format(self.__class__.__name__,self.get_parameter_names())
215+
return "CUQI {} function. Parameters {}.".format(self.__class__.__name__,self.get_parameter_names())

0 commit comments

Comments
 (0)