@@ -442,36 +442,50 @@ def Tk0k1(k0, k1):
442442 return nx .sum (mat , axis = (0 , 1 ))
443443
444444
445- def solve_gmm_barycenter_fixed_point (
446- means ,
447- covs ,
445+ def gmm_barycenter_fixed_point (
448446 means_list ,
449447 covs_list ,
450- b_list ,
448+ w_list ,
449+ means_init ,
450+ covs_init ,
451451 weights ,
452- max_its = 300 ,
452+ w_bar = None ,
453+ iterations = 100 ,
453454 log = False ,
454455 barycentric_proj_method = "euclidean" ,
455456):
456457 r"""
457- Solves the GMM OT barycenter problem using the fixed point algorithm.
458+ Solves the Gaussian Mixture Model OT barycenter problem (defined in [69])
459+ using the fixed point algorithm (proposed in [74]). The
460+ weights of the barycenter are not optimized, and stay the same as the input
461+ `w_list` or are initialized to uniform.
462+
463+ The algorithm uses barycentric projections of GMM-OT plans, and these can be
464+ computed either through Bures Barycenters (slow but accurate,
465+ barycentric_proj_method='bures') or by convex combination (fast,
466+ barycentric_proj_method='euclidean', default).
467+
468+ This is a special case of the generic free-support barycenter solver
469+ `ot.lp.free_support_barycenter_generic_costs`.
458470
459471 Parameters
460472 ----------
461- means : array-like
462- Initial (n, d) GMM means.
463- covs : array-like
464- Initial (n, d, d) GMM covariances.
465473 means_list : list of array-like
466474 List of K (m_k, d) GMM means.
467475 covs_list : list of array-like
468476 List of K (m_k, d, d) GMM covariances.
469- b_list : list of array-like
477+ w_list : list of array-like
470478 List of K (m_k) arrays of weights.
479+ means_init : array-like
480+ Initial (n, d) GMM means.
481+ covs_init : array-like
482+ Initial (n, d, d) GMM covariances.
471483 weights : array-like
472484 Array (K,) of the barycentre coefficients.
473- max_its : int, optional
474- Maximum number of iterations (default is 300).
485+ w_bar : array-like, optional
486+ Initial weights (n) of the barycentre GMM. If None, initialized to uniform.
487+ iterations : int, optional
488+ Number of iterations (default is 100).
475489 log : bool, optional
476490 Whether to return the list of iterations (default is False).
477491 barycentric_proj_method : str, optional
@@ -485,30 +499,46 @@ def solve_gmm_barycenter_fixed_point(
485499 (n, d, d) barycentre GMM covariances.
486500 log_dict : dict, optional
487501 Dictionary containing the list of iterations if log is True.
502+
503+ References
504+ ----------
505+ .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.
506+
507+ .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing barycenters of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024)
508+
509+ See Also
510+ --------
511+ ot.lp.free_support_barycenter_generic_costs : Compute barycenter of measures for generic transport costs.
488512 """
489- nx = get_backend (means , covs [0 ], means_list [0 ], covs_list [0 ])
513+ nx = get_backend (
514+ means_init , covs_init , means_list [0 ], covs_list [0 ], w_list [0 ], weights
515+ )
490516 K = len (means_list )
491- n = means .shape [0 ]
492- d = means .shape [1 ]
493- means_its = [means .copy ()]
494- covs_its = [covs .copy ()]
495- a = nx .ones (n , type_as = means ) / n
517+ n = means_init .shape [0 ]
518+ d = means_init .shape [1 ]
519+ means_its = [nx .copy (means_init )]
520+ covs_its = [nx .copy (covs_init )]
521+ means , covs = means_init , covs_init
522+
523+ if w_bar is None :
524+ w_bar = nx .ones (n , type_as = means ) / n
496525
497- for _ in range (max_its ):
526+ for _ in range (iterations ):
498527 pi_list = [
499- gmm_ot_plan (means , means_list [k ], covs , covs_list [k ], a , b_list [k ])
528+ gmm_ot_plan (means , means_list [k ], covs , covs_list [k ], w_bar , w_list [k ])
500529 for k in range (K )
501530 ]
502531
532+ # filled in the euclidean case
503533 means_selection , covs_selection = None , None
534+
504535 # in the euclidean case, the selection of Gaussians from each K sources
505- # comes from a barycentric projection is a convex combination of the
506- # selected means and covariances, which can be computed without a
507- # for loop on i
536+ # comes from a barycentric projection: it is a convex combination of the
537+ # selected means and covariances, which can be computed without a
538+ # for loop on i = 0, ..., n -1
508539 if barycentric_proj_method == "euclidean" :
509540 means_selection = nx .zeros ((n , K , d ), type_as = means )
510541 covs_selection = nx .zeros ((n , K , d , d ), type_as = means )
511-
512542 for k in range (K ):
513543 means_selection [:, k , :] = n * pi_list [k ] @ means_list [k ]
514544 covs_selection [:, k , :, :] = (
@@ -519,24 +549,27 @@ def solve_gmm_barycenter_fixed_point(
519549 # selected components of the K GMMs. In the 'bures' barycentric
520550 # projection option, the selected components are also Bures barycentres.
521551 for i in range (n ):
522- # means_slice_i (K, d) is the selected means, each comes from a
552+ # means_selection_i (K, d) is the selected means, each comes from a
523553 # Gaussian barycentre along the disintegration of pi_k at i
524- # covs_slice_i (K, d, d) are the selected covariances
525- means_selection_i = []
526- covs_selection_i = []
554+ # covs_selection_i (K, d, d) are the selected covariances
555+ means_selection_i = None
556+ covs_selection_i = None
527557
528558 # use previous computation (convex combination)
529559 if barycentric_proj_method == "euclidean" :
530560 means_selection_i = means_selection [i ]
531561 covs_selection_i = covs_selection [i ]
532562
533- # compute Bures barycentre of the selected components
563+ # compute Bures barycentre of certain components to get the
564+ # selection at i
534565 elif barycentric_proj_method == "bures" :
535- w = (1 / a [i ]) * pi_list [k ][i , :]
566+ means_selection_i = nx .zeros ((K , d ), type_as = means )
567+ covs_selection_i = nx .zeros ((K , d , d ), type_as = means )
536568 for k in range (K ):
569+ w = (1 / w_bar [i ]) * pi_list [k ][i , :]
537570 m , C = bures_wasserstein_barycenter (means_list [k ], covs_list [k ], w )
538- means_selection_i . append ( m )
539- covs_selection_i . append ( C )
571+ means_selection_i [ k ] = m
572+ covs_selection_i [ k ] = C
540573
541574 else :
542575 raise ValueError ("Unknown barycentric_proj_method" )
@@ -546,8 +579,8 @@ def solve_gmm_barycenter_fixed_point(
546579 )
547580
548581 if log :
549- means_its .append (means .copy ())
550- covs_its .append (covs .copy ())
582+ means_its .append (nx .copy (means ))
583+ covs_its .append (nx .copy (covs ))
551584
552585 if log :
553586 return means , covs , {"means_its" : means_its , "covs_its" : covs_its }
0 commit comments