Skip to content

Commit 7b4dd2e

Browse files
committed
fix: Some rebase issues
1 parent ea15908 commit 7b4dd2e

File tree

3 files changed

+8
-10
lines changed

3 files changed

+8
-10
lines changed

python/nutpie/compiled_pyfunc.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,8 @@ def make_expand_func(seed1, seed2, chain):
8282
make_expand_func,
8383
self._variables,
8484
self.n_dim,
85-
self._make_initial_points,
86-
make_transform_adapter,
87-
make_adapter,
85+
init_point_func=self._make_initial_points,
86+
transform_adapter=make_adapter,
8887
)
8988

9089

src/pyfunc.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use pyo3::{
1717
Bound, Py, PyAny, PyErr, Python,
1818
};
1919
use rand::Rng;
20-
use rand_distr::{Distribution, StandardNormal, Uniform};
20+
use rand_distr::{Distribution, Uniform};
2121
use smallvec::SmallVec;
2222
use thiserror::Error;
2323

@@ -76,7 +76,7 @@ impl PyVariable {
7676
pub struct PyModel {
7777
make_logp_func: Arc<Py<PyAny>>,
7878
make_expand_func: Arc<Py<PyAny>>,
79-
init_point_func: Arc<Option<Py<PyAny>>>,
79+
init_point_func: Option<Arc<Py<PyAny>>>,
8080
variables: Arc<Vec<PyVariable>>,
8181
transform_adapter: Option<PyTransformAdapt>,
8282
ndim: usize,
@@ -85,7 +85,7 @@ pub struct PyModel {
8585
#[pymethods]
8686
impl PyModel {
8787
#[new]
88-
#[pyo3(signature = (make_logp_func, make_expand_func, variables, ndim, transform_adapter=None))]
88+
#[pyo3(signature = (make_logp_func, make_expand_func, variables, ndim, *, init_point_func=None, transform_adapter=None))]
8989
fn new<'py>(
9090
make_logp_func: Py<PyAny>,
9191
make_expand_func: Py<PyAny>,
@@ -97,7 +97,7 @@ impl PyModel {
9797
Self {
9898
make_logp_func: Arc::new(make_logp_func),
9999
make_expand_func: Arc::new(make_expand_func),
100-
init_point_func: Arc::new(init_point_func),
100+
init_point_func: init_point_func.map(|x| x.into()),
101101
variables: Arc::new(variables),
102102
ndim,
103103
transform_adapter: transform_adapter.map(PyTransformAdapt::new),

src/pymc.rs

+2-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ use pyo3::{
1313
types::{PyAnyMethods, PyList},
1414
Bound, Py, PyAny, PyObject, PyResult, Python,
1515
};
16-
use rand::{distributions::Uniform, prelude::Distribution};
1716

1817
use thiserror::Error;
1918

@@ -232,7 +231,7 @@ pub(crate) struct PyMcModel {
232231
dim: usize,
233232
density: LogpFunc,
234233
expand: ExpandFunc,
235-
init_func: Py<PyAny>,
234+
init_func: Arc<Py<PyAny>>,
236235
var_sizes: Vec<usize>,
237236
var_names: Vec<String>,
238237
}
@@ -252,7 +251,7 @@ impl PyMcModel {
252251
dim,
253252
density,
254253
expand,
255-
init_func,
254+
init_func: init_func.into(),
256255
var_names: var_names.extract()?,
257256
var_sizes: var_sizes.extract()?,
258257
})

0 commit comments

Comments
 (0)