Skip to content

Add optimization utils / voltage trace features #707

@jnsbck

Description

@jnsbck

Add `jaxley.optimize.features` and/or `jaxley.optimize.distributions` modules, or alternatively add both of these to `jaxley.optimize.utils`.

Here is an example of what the distributions part could look like:

# Simple parameter distributions written in JAX.
from typing import Dict, Tuple

import jax
import jax.numpy as jnp


class Uniform:
    """Uniform distribution on ``[lower, upper]`` with sample and log_prob methods.

    ``lower`` and ``upper`` are passed straight to ``jax.random.uniform`` and
    used in elementwise comparisons, so they may be scalars or arrays that
    broadcast against the sample shape.
    """

    def __init__(self, lower: float, upper: float) -> None:
        self.lower = lower
        self.upper = upper

    def sample(self, key: jnp.ndarray, shape: Tuple[int, ...] = (1,)) -> jnp.ndarray:
        """Samples from the uniform distribution.

        Args:
            key: A JAX random key.
            shape: Sample shape.

        Returns:
            Samples drawn from the half-open interval ``[lower, upper)``, as
            produced by ``jax.random.uniform``.
        """
        return jax.random.uniform(key, shape=shape, minval=self.lower, maxval=self.upper)

    def log_prob(self, x: jnp.ndarray) -> jnp.ndarray:
        """Computes the elementwise log probability density of ``x``.

        Args:
            x: The input to compute the log probability for.

        Returns:
            ``-log(upper - lower)`` where ``lower <= x <= upper`` and ``-inf``
            elsewhere. The closed interval here differs from the half-open one
            used by ``sample`` only on a set of measure zero.
        """
        in_bounds = (x >= self.lower) & (x <= self.upper)
        # NOTE(review): if upper <= lower this density is nan/inf; callers are
        # presumably expected to pass lower < upper.
        return jnp.where(in_bounds, -jnp.log(self.upper - self.lower), -jnp.inf)

class ProductDistribution:
    """Product of independent distributions; p(x,y,z) = p(x)p(y)p(z).

    Args:
        dists: Mapping from parameter name to a distribution object that
            exposes ``sample(key, shape)`` and ``log_prob(x)`` (e.g. ``Uniform``).
    """

    def __init__(self, dists: Dict[str, "Uniform"]) -> None:
        self.dists = dists

    def sample(self, key: jnp.ndarray, shape: Tuple[int, ...] = (1,)) -> Dict[str, jnp.ndarray]:
        """Draws one sample array per named component.

        Args:
            key: A JAX random key.
            shape: Sample shape passed to every component's ``sample``.

        Returns:
            Dict with the same keys as ``dists``, mapping each name to its samples.
        """
        # Give every component its own key so the draws are independent.
        split_keys = jax.random.split(key, len(self.dists))
        keys_by_name = dict(zip(self.dists.keys(), split_keys))
        return jax.tree_util.tree_map(
            lambda k, d: d.sample(k, shape), keys_by_name, self.dists
        )

    def log_prob(self, x: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """Computes the joint log probability as the sum of component log probs.

        Args:
            x: Dict with the same keys as ``dists``. The per-component log
                probabilities must share a shape so they can be stacked.

        Returns:
            Elementwise sum of the component log probabilities.
        """
        log_probs = jax.tree_util.tree_map(
            lambda value, d: d.log_prob(value), x, self.dists
        )
        # Use jax.tree_util.tree_leaves: the top-level jax.tree_leaves alias is
        # deprecated and no longer available in newer JAX releases.
        return jnp.sum(jnp.array(jax.tree_util.tree_leaves(log_probs)), axis=0)

class BoxUniform(ProductDistribution):
    """Axis-aligned box prior: an independent ``Uniform`` per named parameter.

    Args:
        bounds: Maps each parameter name to its ``(lower, upper)`` limits.
    """

    def __init__(self, bounds: Dict[str, Tuple[float, float]]) -> None:
        per_param = {}
        for name, limits in bounds.items():
            per_param[name] = Uniform(*limits)
        super().__init__(per_param)
# Example usage. `bounds` was previously undefined here (NameError); the
# names and ranges below are illustrative placeholders only.
bounds = {"param_a": (0.0, 1.0), "param_b": (-1.0, 1.0)}
prior = BoxUniform(bounds)
key = jax.random.PRNGKey(23)
obs_params = prior.sample(key, (1,))

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions