-
Notifications
You must be signed in to change notification settings - Fork 27
Open
Labels
enhancement (New feature or request), good first issue (Good for newcomers)
Description
Add `jaxley.optimize.features` and `jaxley.optimize.distributions` modules, or alternatively add both pieces of functionality to `optimize.utils`.
Here is an example of what this could look like:
# Multidimensional uniform distribution in jax
class Uniform:
    """Uniform distribution offering ``sample`` and ``log_prob``.

    Args:
        lower: Lower bound of the distribution.
        upper: Upper bound of the distribution.
    """

    def __init__(self, lower: float, upper: float) -> None:
        self.lower, self.upper = lower, upper

    def sample(self, key: jnp.ndarray, shape: Tuple[int, ...] = (1,)) -> jnp.ndarray:
        """Draw samples between ``lower`` and ``upper``.

        Args:
            key: A JAX random key.
            shape: Shape of the returned sample array.

        Returns:
            The array produced by ``jax.random.uniform`` for the given key,
            shape and bounds.
        """
        draws = jax.random.uniform(
            key,
            shape=shape,
            minval=self.lower,
            maxval=self.upper,
        )
        return draws

    def log_prob(self, x: jnp.ndarray) -> jnp.ndarray:
        """Evaluate the log density at ``x``.

        Args:
            x: Points at which to evaluate the log probability.

        Returns:
            ``-log(upper - lower)`` where ``lower <= x <= upper`` and ``-inf``
            everywhere else.
        """
        # Constant log density inside the support.
        log_density = -jnp.log(self.upper - self.lower)
        above_lower = x >= self.lower
        below_upper = x <= self.upper
        return jnp.where(above_lower & below_upper, log_density, -jnp.inf)
class ProductDistribution:
    """Product distribution of multiple distributions; p(x,y,z) = p(x)p(y)p(z).

    Args:
        dists: Dictionary mapping a name to the distribution of that
            component. Each value is used through its ``sample(key, shape)``
            and ``log_prob(x)`` methods.
    """

    def __init__(self, dists: Dict[str, "Uniform"]) -> None:
        self.dists = dists

    def sample(self, key: jnp.ndarray, shape: Tuple[int, ...] = (1,)) -> Dict[str, jnp.ndarray]:
        """Sample every component distribution with its own random key.

        Args:
            key: A JAX random key; it is split into one key per component.
            shape: Sample shape passed to each component's ``sample``.

        Returns:
            A dictionary with the same names as ``dists`` holding the samples
            of each component.
        """
        split_keys = jax.random.split(key, len(self.dists))
        # Pair each component name with its own key so tree_map can line the
        # keys up with the distributions by dictionary entry.
        keys_by_name = dict(zip(self.dists.keys(), split_keys))
        return jax.tree_util.tree_map(
            lambda k, d: d.sample(k, shape), keys_by_name, self.dists
        )

    def log_prob(self, x: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """Compute the joint log probability as the sum over components.

        Args:
            x: Dictionary of values, keyed like ``dists``.

        Returns:
            The component log probabilities stacked and summed along axis 0.
        """
        log_probs = jax.tree_util.tree_map(
            lambda value, d: d.log_prob(value), x, self.dists
        )
        # `jax.tree_leaves` is no longer available at the top level of the
        # jax namespace; `jax.tree_util.tree_leaves` is the stable spelling.
        leaves = jax.tree_util.tree_leaves(log_probs)
        return jnp.sum(jnp.array(leaves), axis=0)
class BoxUniform(ProductDistribution):
    """Multi-dimensional uniform distribution.

    Builds one ``Uniform`` per entry of ``bounds`` and hands the resulting
    dictionary to ``ProductDistribution``.

    Args:
        bounds: A dictionary of parameter names and their bounds. Each value
            is unpacked into the arguments of ``Uniform``.
    """

    def __init__(self, bounds: Dict[str, Tuple[float, float]]) -> None:
        dists = {name: Uniform(*limits) for name, limits in bounds.items()}
        super().__init__(dists)


# Example usage.
# NOTE(review): `bounds` is not defined in this snippet; presumably a
# Dict[str, Tuple[float, float]] provided by the caller. Confirm.
prior = BoxUniform(bounds)
key = jax.random.PRNGKey(23)
obs_params = prior.sample(key, (1,))
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
enhancement (New feature or request), good first issue (Good for newcomers)