@@ -17,7 +17,7 @@ use pyo3::{
17
17
Bound , Py , PyAny , PyErr , Python ,
18
18
} ;
19
19
use rand:: Rng ;
20
- use rand_distr:: { Distribution , StandardNormal , Uniform } ;
20
+ use rand_distr:: { Distribution , Uniform } ;
21
21
use smallvec:: SmallVec ;
22
22
use thiserror:: Error ;
23
23
@@ -76,7 +76,7 @@ impl PyVariable {
76
76
pub struct PyModel {
77
77
make_logp_func : Arc < Py < PyAny > > ,
78
78
make_expand_func : Arc < Py < PyAny > > ,
79
- init_point_func : Arc < Option < Py < PyAny > > > ,
79
+ init_point_func : Option < Arc < Py < PyAny > > > ,
80
80
variables : Arc < Vec < PyVariable > > ,
81
81
transform_adapter : Option < PyTransformAdapt > ,
82
82
ndim : usize ,
@@ -85,7 +85,7 @@ pub struct PyModel {
85
85
#[ pymethods]
86
86
impl PyModel {
87
87
#[ new]
88
- #[ pyo3( signature = ( make_logp_func, make_expand_func, variables, ndim, transform_adapter=None ) ) ]
88
+ #[ pyo3( signature = ( make_logp_func, make_expand_func, variables, ndim, * , init_point_func= None , transform_adapter=None ) ) ]
89
89
fn new < ' py > (
90
90
make_logp_func : Py < PyAny > ,
91
91
make_expand_func : Py < PyAny > ,
@@ -97,7 +97,7 @@ impl PyModel {
97
97
Self {
98
98
make_logp_func : Arc :: new ( make_logp_func) ,
99
99
make_expand_func : Arc :: new ( make_expand_func) ,
100
- init_point_func : Arc :: new ( init_point_func) ,
100
+ init_point_func : init_point_func. map ( |x| x . into ( ) ) ,
101
101
variables : Arc :: new ( variables) ,
102
102
ndim,
103
103
transform_adapter : transform_adapter. map ( PyTransformAdapt :: new) ,
0 commit comments