@@ -119,15 +119,15 @@ def _check_input_shape(self, input: Tensor):
119119class ConditionalDensityEstimator (ConditionalEstimator ):
120120 r"""Base class for density estimators.
121121
122- The density estimator class is a wrapper around neural networks that
123- allows to evaluate the `log_prob`, `sample`, and provide the `loss` of $\theta,x$
124- pairs. Here $\theta$ would be the `input` and $x$ would be the `condition`.
122+ The density estimator class is a wrapper around neural networks that allows to
123+ evaluate the `log_prob`, `sample`, and provide the `loss` of $\theta,x$ pairs. Here
124+ $\theta$ would be the `input` and $x$ would be the `condition`.
125125
126126 Note:
127127 We assume that the input to the density estimator is a tensor of shape
128- (batch_size, input_size ), where input_size is the dimensionality of the input.
129- The condition is a tensor of shape (batch_size, *condition_shape), where
130- condition_shape is the shape of the condition tensor.
128+ (sample_dim, batch_dim, *input_shape ), where input_shape is the dimensionality
129+ of the input. The condition is a tensor of shape (batch_size, *condition_shape),
130+ where condition_shape is the shape of the condition tensor.
131131
132132 """
133133
@@ -226,15 +226,15 @@ def sample_and_log_prob(
226226class ConditionalVectorFieldEstimator (ConditionalEstimator ):
227227 r"""Base class for vector field (e.g., score and ODE flow) estimators.
228228
229- The density estimator class is a wrapper around neural networks that
230- allows to evaluate the `vector_field`, and provide the `loss` of $\theta,x$
231- pairs. Here $\theta$ would be the `input` and $x$ would be the `condition`.
229+ The density estimator class is a wrapper around neural networks that allows to
230+ evaluate the `vector_field`, and provide the `loss` of $\theta,x$ pairs. Here
231+ $\theta$ would be the `input` and $x$ would be the `condition`.
232232
233233 Note:
234234 We assume that the input to the density estimator is a tensor of shape
235- (batch_size, input_size ), where input_size is the dimensionality of the input.
236- The condition is a tensor of shape (batch_size , *condition_shape), where
237- condition_shape is the shape of the condition tensor.
235+ (sample_dim, batch_dim, *input_shape ), where input_shape is the dimensionality
236+ of the input. The condition is a tensor of shape (batch_dim , *condition_shape),
237+ where condition_shape is the shape of the condition tensor.
238238
239239 """
240240
0 commit comments