@@ -105,7 +105,7 @@ def geometry(self):
105105 f"Inconsistent distribution geometry attribute { self ._geometry } and inferred "
106106 f"dimension from distribution variables { inferred_dim } ."
107107 )
108-
108+
109109 # If Geometry dimension is None, update it with the inferred dimension
110110 if inferred_dim and self ._geometry .par_dim is None :
111111 self .geometry = inferred_dim
@@ -117,7 +117,7 @@ def geometry(self):
117117 # We do not use self.name to potentially infer it from python stack.
118118 if self ._name :
119119 self ._geometry ._variable_name = self ._name
120-
120+
121121 return self ._geometry
122122
123123 @geometry .setter
@@ -160,15 +160,15 @@ def logd(self, *args, **kwargs):
160160 f"{ self .logd .__qualname__ } : To evaluate the log density all conditioning variables and main"
161161 f" parameter must be specified. Conditioning variables are: { cond_vars } "
162162 )
163-
163+
164164 # Check if all conditioning variables are specified
165165 all_cond_vars_specified = all ([key in kwargs for key in cond_vars ])
166166 if not all_cond_vars_specified :
167167 raise ValueError (
168168 f"{ self .logd .__qualname__ } : To evaluate the log density all conditioning variables must be"
169169 f" specified. Conditioning variables are: { cond_vars } "
170170 )
171-
171+
172172 # Extract exactly the conditioning variables from kwargs
173173 cond_kwargs = {key : kwargs [key ] for key in cond_vars }
174174
@@ -186,7 +186,7 @@ def logd(self, *args, **kwargs):
186186 # Not conditional distribution, simply evaluate log density directly
187187 else :
188188 return super ().logd (* args , ** kwargs )
189-
189+
190190 def _logd (self , * args ):
191191 return self .logpdf (* args ) # Currently all distributions implement logpdf so we simply call this method.
192192
@@ -216,7 +216,7 @@ def sample(self,N=1,*args,**kwargs):
216216 # Get samples from the distribution sample method
217217 s = self ._sample (N ,* args ,** kwargs )
218218
219- #Store samples in cuqi samples object if more than 1 sample
219+ # Store samples in cuqi samples object if more than 1 sample
220220 if N == 1 :
221221 if len (s ) == 1 and isinstance (s ,np .ndarray ): #Extract single value from numpy array
222222 s = s .ravel ()[0 ]
@@ -264,7 +264,7 @@ def _condition(self, *args, **kwargs):
264264 # Go through every mutable variable and assign value from kwargs if present
265265 for var_key in mutable_vars :
266266
267- #If keyword directly specifies new value of variable we simply reassign
267+ # If keyword directly specifies new value of variable we simply reassign
268268 if var_key in kwargs :
269269 setattr (new_dist , var_key , kwargs .get (var_key ))
270270 processed_kwargs .add (var_key )
@@ -291,9 +291,18 @@ def _condition(self, *args, **kwargs):
291291
292292 elif len (var_args )> 0 : #Some keywords found
293293 # Define new partial function with partially defined args
294- func = partial (var_val , ** var_args )
294+ if (
295+ hasattr (var_val , "_supports_partial_eval" )
296+ and var_val ._supports_partial_eval
297+ ):
298+ func = var_val (** var_args )
299+ else :
300+ # If the callable does not support partial evaluation,
301+ # we use the partial function to set the variable
302+ func = partial (var_val , ** var_args )
303+
295304 setattr (new_dist , var_key , func )
296-
305+
297306 # Store processed keywords
298307 processed_kwargs .update (var_args .keys ())
299308
@@ -329,7 +338,7 @@ def __call__(self, *args, **kwargs) -> Union[Distribution, Likelihood, Evaluated
329338
330339 def get_conditioning_variables (self ):
331340 """Return the conditioning variables of this distribution (if any)."""
332-
341+
333342 # Get all mutable variables
334343 mutable_vars = self .get_mutable_variables ()
335344
@@ -338,7 +347,7 @@ def get_conditioning_variables(self):
338347
339348 # Add any variables defined through callable functions
340349 cond_vars += get_indirect_variables (self )
341-
350+
342351 return cond_vars
343352
344353 def get_mutable_variables (self ):
@@ -347,10 +356,10 @@ def get_mutable_variables(self):
347356 # If mutable variables are already cached, return them
348357 if hasattr (self , '_mutable_vars' ):
349358 return self ._mutable_vars
350-
359+
351360 # Define list of ignored attributes and properties
352361 ignore_vars = ['name' , 'is_symmetric' , 'geometry' , 'dim' ]
353-
362+
354363 # Get public attributes
355364 attributes = get_writeable_attributes (self )
356365
@@ -396,7 +405,7 @@ def _parse_args_add_to_kwargs(self, cond_vars, *args, **kwargs):
396405 raise ValueError (f"{ self ._condition .__qualname__ } : { ordered_keys [index ]} passed as both argument and keyword argument.\n Arguments follow the listed conditioning variable order: { self .get_conditioning_variables ()} " )
397406 kwargs [ordered_keys [index ]] = arg
398407 return kwargs
399-
408+
400409 def _check_geometry_consistency (self ):
401410 """ Checks that the geometry of the distribution is consistent by calling the geometry property. Should be called at the end of __init__ of subclasses. """
402411 self .geometry
@@ -411,4 +420,4 @@ def __repr__(self) -> str:
411420 def rv (self ):
412421 """ Return a random variable object representing the distribution. """
413422 from cuqi .experimental .algebra import RandomVariable
414- return RandomVariable (self )
423+ return RandomVariable (self )
0 commit comments