From 7b4dd2e5d13c3dcb3afc9376bc85343385635fbb Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Thu, 14 Nov 2024 18:07:49 +0100 Subject: [PATCH] fix: Some rebase issues --- python/nutpie/compiled_pyfunc.py | 5 ++--- src/pyfunc.rs | 8 ++++---- src/pymc.rs | 5 ++--- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/python/nutpie/compiled_pyfunc.py b/python/nutpie/compiled_pyfunc.py index 7298534..d0f9732 100644 --- a/python/nutpie/compiled_pyfunc.py +++ b/python/nutpie/compiled_pyfunc.py @@ -82,9 +82,8 @@ def make_expand_func(seed1, seed2, chain): make_expand_func, self._variables, self.n_dim, - self._make_initial_points, - make_transform_adapter, - make_adapter, + init_point_func=self._make_initial_points, + transform_adapter=make_adapter, ) diff --git a/src/pyfunc.rs b/src/pyfunc.rs index 37a25d6..a6220a2 100644 --- a/src/pyfunc.rs +++ b/src/pyfunc.rs @@ -17,7 +17,7 @@ use pyo3::{ Bound, Py, PyAny, PyErr, Python, }; use rand::Rng; -use rand_distr::{Distribution, StandardNormal, Uniform}; +use rand_distr::{Distribution, Uniform}; use smallvec::SmallVec; use thiserror::Error; @@ -76,7 +76,7 @@ impl PyVariable { pub struct PyModel { make_logp_func: Arc>, make_expand_func: Arc>, - init_point_func: Arc>>, + init_point_func: Option>>, variables: Arc>, transform_adapter: Option, ndim: usize, @@ -85,7 +85,7 @@ pub struct PyModel { #[pymethods] impl PyModel { #[new] - #[pyo3(signature = (make_logp_func, make_expand_func, variables, ndim, transform_adapter=None))] + #[pyo3(signature = (make_logp_func, make_expand_func, variables, ndim, *, init_point_func=None, transform_adapter=None))] fn new<'py>( make_logp_func: Py, make_expand_func: Py, @@ -97,7 +97,7 @@ impl PyModel { Self { make_logp_func: Arc::new(make_logp_func), make_expand_func: Arc::new(make_expand_func), - init_point_func: Arc::new(init_point_func), + init_point_func: init_point_func.map(|x| x.into()), variables: Arc::new(variables), ndim, transform_adapter: transform_adapter.map(PyTransformAdapt::new), diff --git a/src/pymc.rs b/src/pymc.rs index 39dfc76..526e18b 100644 --- a/src/pymc.rs +++ b/src/pymc.rs @@ -13,7 +13,6 @@ use pyo3::{ types::{PyAnyMethods, PyList}, Bound, Py, PyAny, PyObject, PyResult, Python, }; -use rand::{distributions::Uniform, prelude::Distribution}; use thiserror::Error; @@ -232,7 +231,7 @@ pub(crate) struct PyMcModel { dim: usize, density: LogpFunc, expand: ExpandFunc, - init_func: Py, + init_func: Arc>, var_sizes: Vec, var_names: Vec, } @@ -252,7 +251,7 @@ impl PyMcModel { dim, density, expand, - init_func, + init_func: init_func.into(), var_names: var_names.extract()?, var_sizes: var_sizes.extract()?, })