Support for Sparse Dirichlet Priors (0 < α < 1) in GLM-HMM #448

@BalzaniEdoardo

Description

Background

The current analytical M-step implementation for GLM-HMM transition and initial probabilities requires Dirichlet prior parameters α >= 1. This limitation is documented in the code (see _analytical_m_step_log_initial_prob and _analytical_m_step_log_transition_prob).

Sparse Dirichlet priors with 0 < α < 1 are valuable for encouraging sparse state transitions and initial distributions, but they push solutions toward the boundary of the probability simplex, where the log-space analytical formulation becomes problematic.

Proposed Solution

Implement support for sparse priors using proximal gradient descent as an alternative M-step update method.

Architecture

Make the M-step update methods for initial and transition probabilities configurable callables:

from typing import Callable, Tuple

from jax import Array  # or the project's own Array type alias

def run_m_step(
    X: Array,
    y: Array,
    log_posteriors: Array,
    log_joint_posterior: Array,
    glm_params: Tuple[Array, Array],
    is_new_session: Array,
    solver_run: Callable,
    dirichlet_prior_alphas_init_prob: Array | None = None,
    dirichlet_prior_alphas_transition: Array | None = None,
    # NEW PARAMETERS:
    initial_prob_update_fn: Callable | None = None,
    transition_prob_update_fn: Callable | None = None,
) -> Tuple[...]:
    """
    Perform the M-step of the EM algorithm for GLM-HMM.

    Parameters
    ----------
    ...
    initial_prob_update_fn : Callable, optional
        Function to update initial probabilities. Should have signature:
        `f(log_posteriors, is_new_session, dirichlet_prior_alphas) -> log_initial_prob`
        If None, uses analytical M-step (requires α >= 1).
    transition_prob_update_fn : Callable, optional
        Function to update transition probabilities. Should have signature:
        `f(log_joint_posterior, dirichlet_prior_alphas) -> log_transition_prob`
        If None, uses analytical M-step (requires α >= 1).
    """
    # Use analytical or provided update functions
    init_update_fn = initial_prob_update_fn or _analytical_m_step_log_initial_prob
    trans_update_fn = transition_prob_update_fn or _analytical_m_step_log_transition_prob

    log_initial_prob = init_update_fn(
        log_posteriors, is_new_session, dirichlet_prior_alphas_init_prob
    )
    log_transition_prob = trans_update_fn(
        log_joint_posterior, dirichlet_prior_alphas_transition
    )
    ...
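
For illustration, any user-supplied update function only needs to match the documented callable signatures. A hypothetical prior-free transition update is sketched below; the assumed layout of log_joint_posterior (time along the first axis, states along the last two) is a guess made for illustration, not the library's documented convention.

from jax.scipy.special import logsumexp

def mle_transition_update(log_joint_posterior, dirichlet_prior_alphas):
    # Prior-free maximum-likelihood update: accumulate expected transition
    # counts over time in log-space, then normalize each row of the matrix.
    # dirichlet_prior_alphas is accepted to satisfy the signature but unused here.
    log_counts = logsumexp(log_joint_posterior, axis=0)
    return log_counts - logsumexp(log_counts, axis=1, keepdims=True)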

GLMHMM Class Integration

The GLMHMM class will:

  1. Validate prior parameters during initialization/fitting
  2. Select appropriate update functions based on α values
  3. Create partial functions to pass to em_glm_hmm

from functools import partial
import jax.numpy as jnp

class GLMHMM:
    def fit(self, X, y, ...):
        # Validate and select M-step methods
        init_update_fn, trans_update_fn = self._select_m_step_methods(
            self.dirichlet_prior_alphas_init,
            self.dirichlet_prior_alphas_transition
        )

        # Create partial of run_m_step with selected update methods
        m_step_partial = partial(
            run_m_step,
            initial_prob_update_fn=init_update_fn,
            transition_prob_update_fn=trans_update_fn,
        )

        # Pass to EM algorithm
        results = em_glm_hmm(
            ...,
            solver_run=m_step_partial,  # Matches current signature
            ...
        )

    def _select_m_step_methods(self, alphas_init, alphas_trans):
        """Select M-step update methods based on prior parameters."""
        # Check if any α < 1
        needs_numerical_init = (
            alphas_init is not None and jnp.any(alphas_init < 1)
        )
        needs_numerical_trans = (
            alphas_trans is not None and jnp.any(alphas_trans < 1)
        )

        # Select appropriate update functions
        init_fn = (
            _proximal_gradient_m_step_log_initial_prob
            if needs_numerical_init
            else _analytical_m_step_log_initial_prob
        )
        trans_fn = (
            _proximal_gradient_m_step_log_transition_prob
            if needs_numerical_trans
            else _analytical_m_step_log_transition_prob
        )

        return init_fn, trans_fn

Implementation Tasks

1. Implement Proximal Gradient M-step Functions

Create new functions in expectation_maximization.py:

  • _proximal_gradient_m_step_log_initial_prob(log_posteriors, is_new_session, dirichlet_prior_alphas) -> log_initial_prob
  • _proximal_gradient_m_step_log_transition_prob(log_joint_posterior, dirichlet_prior_alphas) -> log_transition_prob

Key requirements:

  • Handle probability simplex constraints using projection
  • Support 0 < α < 1 (and α >= 1 for consistency)
  • Work in log-space where possible for numerical stability
  • Use proximal optimization methods
  • Match the signature of analytical M-step functions

Objective function (for initial probabilities):

min_p  -sum(γ_i * log(p_i)) - sum((α_i - 1) * log(p_i))
s.t.   sum(p_i) = 1, p_i >= 0

Where γ_i are the expected counts (sum of posteriors at session starts).
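
As a rough sketch of how a proximal (projected) gradient update for this objective could look (gamma, alphas, and the step size are placeholders, and project_onto_simplex is the helper sketched in the Simplex Projection section below; this is an illustration, not the proposed implementation):

import jax
import jax.numpy as jnp

def neg_objective(p, gamma, alphas, eps=1e-12):
    # Negative expected complete-data log-likelihood plus negative Dirichlet
    # log-prior, combined as -sum((gamma_i + alpha_i - 1) * log(p_i)).
    return -jnp.sum((gamma + alphas - 1.0) * jnp.log(p + eps))

def projected_gradient_step(p, gamma, alphas, step_size=1e-2):
    # One proximal step: gradient descent on the smooth term, followed by
    # projection back onto the probability simplex.
    grad = jax.grad(neg_objective)(p, gamma, alphas)
    return project_onto_simplex(p - step_size * grad)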

Proximal operator for probability simplex:

  • Projection onto the simplex can be done efficiently
  • A JAX implementation is available, or one can be written directly

2. Modify run_m_step Function

  • Add optional initial_prob_update_fn and transition_prob_update_fn parameters
  • Default to analytical M-step if not provided (backward compatible)
  • Update docstring with callable signatures

3. Integrate into GLMHMM Class

  • Add _select_m_step_methods() method to choose update functions based on α
  • Modify fit() to validate α values and create appropriate partials
  • Add validation to check α > 0 for all prior parameters (a minimal check is sketched after this list)
  • Update GLMHMM docstring to note sparse prior support
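
A minimal version of the α validation mentioned above, using a hypothetical helper name:

import jax.numpy as jnp

def _validate_dirichlet_alphas(alphas, name):
    # Dirichlet concentration parameters must be strictly positive (α > 0).
    if alphas is not None and bool(jnp.any(alphas <= 0)):
        raise ValueError(f"All entries of {name} must be strictly positive.")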

4. Testing

  • Unit tests for proximal gradient M-step functions

    • Test with 0 < α < 1 (sparse priors)
    • Test with α >= 1 (should work but may be slower than analytical)
    • Verify simplex constraints are satisfied
    • Check convergence to optimal solution
  • Integration tests for GLMHMM class

    • Test automatic selection of M-step methods
    • Compare analytical vs numerical for α >= 1 (should give similar results)
    • Test sparse priors with synthetic data where ground truth is known
  • Gradient-free optimality verification (similar to the approach discussed)

    • Use Nelder-Mead or similar to verify proximal gradient finds good solutions
    • Especially important for 0 < α < 1 where solution may be at simplex boundary
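
One possible shape for that gradient-free check, with made-up numbers: optimize the same objective over a softmax parameterization so Nelder-Mead never leaves the simplex, then compare against the proximal-gradient solution. Note that softmax cannot reach the simplex boundary exactly, so boundary solutions (common when 0 < α < 1) are only approximated.

import numpy as np
from scipy.optimize import minimize

gamma = np.array([5.0, 1.0, 0.2])      # expected counts (placeholder values)
alphas = np.array([0.5, 0.5, 0.5])     # sparse Dirichlet prior, 0 < α < 1

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def unconstrained_objective(z):
    p = softmax(z)
    return -np.sum((gamma + alphas - 1.0) * np.log(p + 1e-12))

res = minimize(unconstrained_objective, x0=np.zeros(3), method="Nelder-Mead")
p_reference = softmax(res.x)           # compare with the proximal-gradient result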

5. Documentation

  • Add example notebook demonstrating sparse priors
  • Document performance considerations (numerical slower than analytical)
  • Update API documentation for new parameters
  • Add theory/background section explaining when to use sparse priors

Technical Considerations

Numerical Stability

  • Work in log-space where possible
  • Use logsumexp for normalization (see the snippet after this list)
  • Handle edge cases where probabilities approach 0
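
For example, normalization can stay entirely in log-space:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

log_unnormalized = jnp.array([2.0, 0.5, -3.0])
log_prob = log_unnormalized - logsumexp(log_unnormalized)  # jnp.exp(log_prob) sums to 1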

Performance

  • Numerical optimization will be slower than the analytical M-step
  • Consider caching/JIT compilation for the optimization routine
  • Set a reasonable tolerance and max_iter for the optimizer

Simplex Projection

An efficient O(n log n) algorithm can be used for projection onto the probability simplex (Duchi et al., 2008; the same algorithm is derived with a simple proof in Wang & Carreira-Perpiñán, 2013, "Projection onto the probability simplex: An efficient algorithm with a simple proof, and an application"):

import jax.numpy as jnp

def project_onto_simplex(v):
    """Project vector v onto the probability simplex (Duchi et al., 2008)."""
    u = jnp.sort(v)[::-1]                # sort components in descending order
    cssv = jnp.cumsum(u) - 1.0           # cumulative sums minus the simplex target
    ind = jnp.arange(1, v.shape[0] + 1)
    rho = jnp.sum(u - cssv / ind > 0)    # number of components that stay positive
    theta = cssv[rho - 1] / rho          # shift that makes the result sum to one
    return jnp.maximum(v - theta, 0.0)
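
As a quick sanity check of the sketch above, project_onto_simplex(jnp.array([0.8, 0.5, -0.1])) returns [0.65, 0.35, 0.0]: the result sums to one and the negative component is clipped to zero.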

Alternative Approaches Considered

  1. Clamping in log-space (rejected): Creates discontinuities and doesn't find the true optimum
  2. Reparameterization (future consideration): Could use softmax reparameterization, but changes the prior interpretation
  3. EM with constrained optimization (current choice): Most principled approach

Success Criteria

  • GLMHMM supports 0 < α < 1 without errors or warnings
  • Proximal gradient finds solutions within 0.1 of the global optimum (verified by gradient-free optimization)
  • All existing tests continue to pass
  • Performance: numerical M-step completes in reasonable time (< 10x analytical for typical problems)
  • Documentation includes clear guidance on when to use sparse priors

References

  • Bishop, C. M. (2006). Pattern Recognition and Machine Learning. Springer. Chapter 13 (HMMs and EM).
  • Parikh, N., & Boyd, S. (2014). Proximal algorithms. Foundations and Trends in Optimization, 1(3), 127-239.
  • Duchi, J., et al. (2008). Efficient projections onto the ℓ1-ball for learning in high dimensions. ICML.
