Description
Background
The current analytical M-step implementation for GLM-HMM transition and initial probabilities requires Dirichlet prior parameters α >= 1. This limitation is documented in the code (see _analytical_m_step_log_initial_prob and _analytical_m_step_log_transition_prob).
Sparse Dirichlet priors with 0 < α < 1 are valuable for encouraging sparse state transitions and initial distributions, but they push the MAP solution toward the boundary of the probability simplex, where the closed-form update can produce zero or negative pseudo-counts and the log-space analytical formulation breaks down.
Proposed Solution
Implement support for sparse priors using proximal gradient descent as an alternative M-step update method.
Architecture
Make the M-step update methods for initial and transition probabilities configurable callables:
def run_m_step(
    X: Array,
    y: Array,
    log_posteriors: Array,
    log_joint_posterior: Array,
    glm_params: Tuple[Array, Array],
    is_new_session: Array,
    solver_run: Callable,
    dirichlet_prior_alphas_init_prob: Array | None = None,
    dirichlet_prior_alphas_transition: Array | None = None,
    # NEW PARAMETERS:
    initial_prob_update_fn: Callable | None = None,
    transition_prob_update_fn: Callable | None = None,
) -> Tuple[...]:
    """
    Perform the M-step of the EM algorithm for GLM-HMM.

    Parameters
    ----------
    ...
    initial_prob_update_fn : Callable, optional
        Function to update initial probabilities. Should have signature:
        `f(log_posteriors, is_new_session, dirichlet_prior_alphas) -> log_initial_prob`
        If None, uses the analytical M-step (requires α >= 1).
    transition_prob_update_fn : Callable, optional
        Function to update transition probabilities. Should have signature:
        `f(log_joint_posterior, dirichlet_prior_alphas) -> log_transition_prob`
        If None, uses the analytical M-step (requires α >= 1).
    """
    # Use the provided update functions, or fall back to the analytical M-step
    init_update_fn = initial_prob_update_fn or _analytical_m_step_log_initial_prob
    trans_update_fn = transition_prob_update_fn or _analytical_m_step_log_transition_prob

    log_initial_prob = init_update_fn(
        log_posteriors, is_new_session, dirichlet_prior_alphas_init_prob
    )
    log_transition_prob = trans_update_fn(
        log_joint_posterior, dirichlet_prior_alphas_transition
    )
    ...

GLMHMM Class Integration
The GLMHMM class will:
- Validate prior parameters during initialization/fitting
- Select appropriate update functions based on α values
- Create partial functions to pass to em_glm_hmm
class GLMHMM:
    def fit(self, X, y, ...):
        # Validate and select M-step methods
        init_update_fn, trans_update_fn = self._select_m_step_methods(
            self.dirichlet_prior_alphas_init,
            self.dirichlet_prior_alphas_transition,
        )
        # Create partial of run_m_step with the selected update methods
        m_step_partial = partial(
            run_m_step,
            initial_prob_update_fn=init_update_fn,
            transition_prob_update_fn=trans_update_fn,
        )
        # Pass to the EM algorithm
        results = em_glm_hmm(
            ...,
            solver_run=m_step_partial,  # Matches current signature
            ...
        )

    def _select_m_step_methods(self, alphas_init, alphas_trans):
        """Select M-step update methods based on prior parameters."""
        # Check if any α < 1
        needs_numerical_init = (
            alphas_init is not None and jnp.any(alphas_init < 1)
        )
        needs_numerical_trans = (
            alphas_trans is not None and jnp.any(alphas_trans < 1)
        )
        # Select appropriate update functions
        init_fn = (
            _proximal_gradient_m_step_log_initial_prob
            if needs_numerical_init
            else _analytical_m_step_log_initial_prob
        )
        trans_fn = (
            _proximal_gradient_m_step_log_transition_prob
            if needs_numerical_trans
            else _analytical_m_step_log_transition_prob
        )
        return init_fn, trans_fn

Implementation Tasks
1. Implement Proximal Gradient M-step Functions
Create new functions in expectation_maximization.py:
- _proximal_gradient_m_step_log_initial_prob(log_posteriors, is_new_session, dirichlet_prior_alphas) -> log_initial_prob
- _proximal_gradient_m_step_log_transition_prob(log_joint_posterior, dirichlet_prior_alphas) -> log_transition_prob
Key requirements:
- Handle probability simplex constraints using projection
- Support 0 < α < 1 (and α >= 1 for consistency)
- Work in log-space where possible for numerical stability
- Use proximal optimization methods
- Match the signature of analytical M-step functions
Objective function (for initial probabilities):
min_p -sum(γ_i * log(p_i)) - sum((α_i - 1) * log(p_i))
s.t. sum(p_i) = 1, p_i >= 0
Where γ_i are the expected counts (sum of posteriors at session starts).
Proximal operator for probability simplex:
- Projection onto the simplex can be done efficiently
- A JAX implementation is available or can be implemented (a sketch follows below)
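As a concrete starting point, here is a minimal sketch of the projected (proximal) gradient update for the initial-state distribution. It assumes log_posteriors has shape (n_time_bins, n_states), that is_new_session flags session-start bins, and that the project_onto_simplex helper sketched under Simplex Projection below is available; the function name follows the proposal above, but the step size, iteration count, and clipping floor are placeholders rather than a final implementation.

import jax
import jax.numpy as jnp

def _proximal_gradient_m_step_log_initial_prob(
    log_posteriors, is_new_session, dirichlet_prior_alphas,
    learning_rate=0.05, n_iter=500,
):
    """Sketch: projected-gradient MAP update for the initial state distribution."""
    # Expected counts γ_k: posteriors summed over session-start time bins.
    gamma = jnp.sum(
        jnp.where(is_new_session[:, None], jnp.exp(log_posteriors), 0.0), axis=0
    )
    # A flat prior (α = 1) if no Dirichlet parameters are supplied.
    alphas = jnp.ones_like(gamma) if dirichlet_prior_alphas is None else dirichlet_prior_alphas

    def objective(p):
        # Negative expected log-likelihood plus negative Dirichlet log-prior.
        return -jnp.sum((gamma + alphas - 1.0) * jnp.log(p))

    grad_fn = jax.grad(objective)
    p = jnp.full_like(gamma, 1.0 / gamma.shape[0])  # start from the uniform distribution
    for _ in range(n_iter):
        # Gradient step on the smooth term, then project back onto the simplex
        # (the proximal operator of the simplex indicator function).
        p = project_onto_simplex(p - learning_rate * grad_fn(p))
        p = jnp.clip(p, 1e-12, 1.0)  # keep log(p) finite at the simplex boundary
    return jnp.log(p)

The transition-probability counterpart would follow the same pattern, projecting each row of the transition matrix onto the simplex.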
2. Modify run_m_step Function
- Add optional initial_prob_update_fn and transition_prob_update_fn parameters
- Default to the analytical M-step if not provided (backward compatible)
- Update the docstring with the callable signatures
3. Integrate into GLMHMM Class
- Add a _select_m_step_methods() method to choose the update functions based on α
- Modify fit() to validate α values and create the appropriate partials
- Add validation to check α > 0 for all prior parameters (a sketch follows this list)
- Update the GLMHMM docstring to note sparse prior support
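A possible shape for the α > 0 check, using a hypothetical _validate_dirichlet_alphas helper (the name and error message are illustrative, not part of the current codebase):

import jax.numpy as jnp

def _validate_dirichlet_alphas(alphas, name):
    """Sketch: reject non-positive Dirichlet concentration parameters."""
    if alphas is None:
        return  # no prior supplied; the analytical M-step with a flat prior applies
    if bool(jnp.any(jnp.asarray(alphas) <= 0)):
        raise ValueError(f"All entries of {name} must satisfy α > 0.")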
4. Testing
- Unit tests for proximal gradient M-step functions
  - Test with 0 < α < 1 (sparse priors)
  - Test with α >= 1 (should work but may be slower than analytical)
  - Verify simplex constraints are satisfied
  - Check convergence to the optimal solution
- Integration tests for the GLMHMM class
  - Test automatic selection of M-step methods
  - Compare analytical vs. numerical for α >= 1 (should give similar results; a test sketch follows this list)
  - Test sparse priors with synthetic data where the ground truth is known
- Gradient-free optimality verification (similar to the approach discussed)
  - Use Nelder-Mead or similar to verify that proximal gradient finds good solutions
  - Especially important for 0 < α < 1, where the solution may lie at the simplex boundary
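A sketch of what the analytical-vs-numerical comparison could look like as a unit test; the function names are taken from the proposal above and the toy numbers are arbitrary:

import jax.numpy as jnp

def test_proximal_matches_analytical_for_alpha_geq_one():
    # Toy setup: 4 time bins, 3 states, every bin starts a new session.
    log_posteriors = jnp.log(
        jnp.array([[0.7, 0.2, 0.1],
                   [0.6, 0.3, 0.1],
                   [0.5, 0.4, 0.1],
                   [0.8, 0.1, 0.1]])
    )
    is_new_session = jnp.array([True, True, True, True])
    alphas = jnp.array([2.0, 2.0, 2.0])  # α >= 1: both methods should agree

    log_p_analytical = _analytical_m_step_log_initial_prob(
        log_posteriors, is_new_session, alphas
    )
    log_p_numerical = _proximal_gradient_m_step_log_initial_prob(
        log_posteriors, is_new_session, alphas
    )

    # The numerical result should lie on the simplex and match the closed form.
    assert jnp.allclose(jnp.sum(jnp.exp(log_p_numerical)), 1.0, atol=1e-6)
    assert jnp.allclose(log_p_analytical, log_p_numerical, atol=1e-3)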
5. Documentation
- Add example notebook demonstrating sparse priors
- Document performance considerations (numerical slower than analytical)
- Update API documentation for new parameters
- Add theory/background section explaining when to use sparse priors
Technical Considerations
Numerical Stability
- Work in log-space where possible
- Use logsumexp for normalization (see the example below)
- Handle edge cases where probabilities approach 0
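For example, renormalization can stay entirely in log-space (a small illustrative helper, not existing code):

from jax.scipy.special import logsumexp

def normalize_log_prob(log_p, axis=-1):
    # Subtracting the log of the total keeps everything in log-space,
    # so exp(result) sums to 1 along `axis` without under- or overflow.
    return log_p - logsumexp(log_p, axis=axis, keepdims=True)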
Performance
- Numerical optimization will be slower than analytical M-step
- Consider caching/JIT compilation for the optimization routine
- May want to set a reasonable tolerance/max_iter for the optimizer (see the sketch below)
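A sketch of a jit-compatible inner loop with both a tolerance and an iteration cap, assuming the project_onto_simplex helper from the next subsection; the defaults are placeholders:

import jax
import jax.numpy as jnp

def run_projected_gradient(objective, p0, lr=0.05, tol=1e-8, max_iter=1000):
    """Sketch: tolerance/max_iter-bounded loop that stays inside jit."""
    grad_fn = jax.grad(objective)

    def cond_fn(state):
        p, p_prev, it = state
        return (jnp.max(jnp.abs(p - p_prev)) > tol) & (it < max_iter)

    def body_fn(state):
        p, _, it = state
        p_new = project_onto_simplex(p - lr * grad_fn(p))
        return p_new, p, it + 1

    p_final, _, _ = jax.lax.while_loop(cond_fn, body_fn, (p0, p0 + 1.0, 0))
    return p_final

Wrapping this (or the M-step that calls it) in jax.jit avoids retracing the optimization routine on every EM iteration.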
Simplex Projection
An efficient O(n log n) algorithm can be used for projection onto the probability simplex:

import jax.numpy as jnp

def project_onto_simplex(v):
    """Project vector v onto the probability simplex (Duchi et al., 2008)."""
    u = jnp.sort(v)[::-1]                       # sort in descending order
    k = jnp.arange(1, v.shape[0] + 1)
    css = jnp.cumsum(u)
    rho = jnp.max(jnp.where(u - (css - 1.0) / k > 0, k, 0))  # pivot index
    theta = (css[rho - 1] - 1.0) / rho          # shift that enforces sum(p) = 1
    return jnp.maximum(v - theta, 0.0)          # clip negatives to zero
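Example usage with illustrative values:

import jax.numpy as jnp

p = project_onto_simplex(jnp.array([0.9, 0.6, -0.4]))
# p == [0.65, 0.35, 0.0]: non-negative and sums to 1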
Alternative Approaches Considered
- Clamping in log-space (rejected): Creates discontinuities and doesn't find true optimum
- Reparameterization (future consideration): Could use softmax reparameterization, but changes the prior interpretation
- EM with constrained optimization (current choice): Most principled approach
Success Criteria
- GLMHMM supports 0 < α < 1 without errors or warnings
- Proximal gradient finds solutions within 0.1 of the global optimum (verified by gradient-free optimization)
- All existing tests continue to pass
- Performance: numerical M-step completes in reasonable time (< 10x analytical for typical problems)
- Documentation includes clear guidance on when to use sparse priors
References
- Bishop, C. M. (2006). Pattern Recognition and Machine Learning. Chapter 13 (HMMs and EM)
- Parikh, N., & Boyd, S. (2014). Proximal algorithms. Foundations and Trends in Optimization, 1(3), 127-239.
- Duchi, J., et al. (2008). Efficient projections onto the ℓ1-ball for learning in high dimensions. ICML.