@@ -1949,6 +1949,9 @@ def _bary_sample_bcd(
19491949 metric ,
19501950 inner_solver ,
19511951 update_masses ,
1952+ warmstart_plan ,
1953+ warmstart_potentials ,
1954+ stopping_criterion ,
19521955 max_iter_bary ,
19531956 tol_bary ,
19541957 verbose ,
@@ -1975,6 +1978,12 @@ def _bary_sample_bcd(
19751978 Function to solve the inner OT problem
19761979 update_masses : bool
19771980 Update the masses of the barycenter, depending on whether balanced or unbalanced OT is used.
1981+ warmstart_plan : bool
1982+ Use the previous plan as initialization for the inner solver. Set based on inner solver type in ot.bary_sample
1983+ warmstart_potentials : bool
1984+ Use the previous potentials as initialization for the inner solver. Set based on inner solver type in ot.bary_sample
1985+ stopping_criterion : str
1986+ Stopping criterion for the BCD algorithm. Can be "loss" or "bary".
19781987 max_iter_bary : int
19791988 Maximum number of iterations for the barycenter
19801989 tol_bary : float
@@ -1994,22 +2003,41 @@ def _bary_sample_bcd(
19942003 b = b_init
19952004 inv_b = 1.0 / b
19962005
1997- prev_loss = np .inf
2006+ prev_criterion = np .inf
19982007 n_samples = len (X_s )
19992008
20002009 if log :
2001- log_ = {"loss " : []}
2010+ log_ = {"stopping_criterion " : []}
20022011 else :
20032012 log_ = None
2013+
20042014 # Compute the barycenter using BCD
20052015 for it in range (max_iter_bary ):
20062016 # Solve the inner OT problem for each source distribution
2007- list_res = [inner_solver (X_s [k ], X , a_s [k ], b ) for k in range (n_samples )]
2017+ if it == 0 :
2018+ list_res = [
2019+ inner_solver (X_s [k ], X , a_s [k ], b , None , None ) for k in range (n_samples )
2020+ ]
2021+ elif warmstart_plan :
2022+ list_res = [
2023+ inner_solver (X_s [k ], X , a_s [k ], b , list_res [k ].plan , None )
2024+ for k in range (n_samples )
2025+ ]
2026+ elif warmstart_potentials :
2027+ list_res = [
2028+ inner_solver (X_s [k ], X , a_s [k ], b , None , list_res [k ].potentials )
2029+ for k in range (n_samples )
2030+ ]
2031+ else :
2032+ list_res = [
2033+ inner_solver (X_s [k ], X , a_s [k ], b , None , None ) for k in range (n_samples )
2034+ ]
20082035
20092036 # Update the estimated barycenter weights in unbalanced cases
20102037 if update_masses :
20112038 b = sum ([w_s [k ] * list_res [k ].plan .sum (axis = 0 ) for k in range (n_samples )])
20122039 inv_b = 1.0 / b
2040+
20132041 # Update the barycenter samples
20142042 if metric in ["sqeuclidean" , "euclidean" ]:
20152043 X_new = (
@@ -2019,30 +2047,40 @@ def _bary_sample_bcd(
20192047 else :
20202048 raise NotImplementedError ('Not implemented metric="{}"' .format (metric ))
20212049
2022- # compute loss
2023- new_loss = sum ([w_s [k ] * list_res [k ].value for k in range (n_samples )])
2050+ # compute criterion
2051+ if stopping_criterion == "loss" :
2052+ new_criterion = sum ([w_s [k ] * list_res [k ].value for k in range (n_samples )])
2053+ else : # stopping_criterion = "bary"
2054+ new_criterion = nx .norm (X_new - X , ord = 2 )
20242055
20252056 if verbose :
20262057 if it % 1 == 0 :
2027- print (f"BCD iteration { it } : loss = { new_loss :.4f} " )
2058+ print (
2059+ f"BCD iteration { it } : criterion { stopping_criterion } = { new_criterion :.4f} "
2060+ )
20282061
20292062 if log :
2030- log_ ["loss " ].append (new_loss )
2063+ log_ ["stopping_criterion " ].append (new_criterion )
20312064 # Check convergence
2032- if abs (new_loss - prev_loss ) / abs (prev_loss ) < tol_bary :
2065+ if abs (new_criterion - prev_criterion ) / abs (prev_criterion ) < tol_bary :
20332066 print (f"BCD converged in { it } iterations" )
20342067 break
20352068
20362069 X = X_new
2037- prev_loss = new_loss
2070+ prev_criterion = new_criterion
2071+
2072+ # compute loss values
20382073
2039- # compute value_linear
20402074 value_linear = sum ([w_s [k ] * list_res [k ].value_linear for k in range (n_samples )])
2075+ if stopping_criterion == "loss" :
2076+ value = new_criterion
2077+ else :
2078+ value = sum ([w_s [k ] * list_res [k ].value for k in range (n_samples )])
20412079 # update BaryResult
20422080 bary_res = BaryResult (
20432081 X = X_new ,
20442082 b = b ,
2045- value = new_loss ,
2083+ value = value ,
20462084 value_linear = value_linear ,
20472085 log = log_ ,
20482086 list_res = list_res ,
@@ -2070,6 +2108,8 @@ def bary_sample(
20702108 batch_size = None ,
20712109 method = None ,
20722110 n_threads = 1 ,
2111+ warmstart = False ,
2112+ stopping_criterion = "loss" ,
20732113 max_iter_bary = 1000 ,
20742114 max_iter = None ,
20752115 rank = 100 ,
@@ -2154,6 +2194,11 @@ def bary_sample(
21542194 large scale solver.
21552195 n_threads : int, optional
21562196 Number of OMP threads for exact OT solver, by default 1
2197+ warmstart : bool, optional
2198+ Use the previous OT plan or potentials as initialization for the inner solver at the next BCD iteration, by default False.
2199+ stopping_criterion : str, optional
2200+ Stopping criterion for the outer loop of the BCD solver, by default 'loss'.
2201+ Either 'loss' to use the optimized objective, or 'bary' to use the variation of the barycenter w.r.t. the Frobenius norm.
21572202 max_iter_bary : int, optional
21582203 Maximum number of iterations for the BCD solver, by default 1000.
21592204 max_iter : int, optional
@@ -2398,6 +2443,13 @@ def bary_sample(
23982443 if method is not None and method .lower () in lst_method_lazy :
23992444 raise NotImplementedError ("Barycenter with Lazy tensors not implemented yet" )
24002445
2446+ if stopping_criterion not in ["loss" , "bary" ]:
2447+ raise ValueError (
2448+ "stopping_criterion must be either 'loss' or 'bary', got {}" .format (
2449+ stopping_criterion
2450+ )
2451+ )
2452+
24012453 n_samples = len (X_s )
24022454
24032455 if (
@@ -2449,7 +2501,28 @@ def bary_sample(
24492501 if b_init is None :
24502502 b_init = nx .ones ((n ,), type_as = X_s [0 ]) / n
24512503
2452- def inner_solver (X_a , X , a , b ):
2504+ if warmstart :
2505+ if reg is None : # exact OT
2506+ warmstart_plan = True
2507+ warmstart_potentials = False
2508+ else : # regularized OT
2509+ # unbalanced AND regularized OT
2510+ if (
2511+ not isinstance (reg_type , tuple )
2512+ and reg_type .lower () in ["kl" ]
2513+ and unbalanced_type .lower () == "kl"
2514+ ):
2515+ warmstart_plan = False
2516+ warmstart_potentials = True
2517+
2518+ else :
2519+ warmstart_plan = True
2520+ warmstart_potentials = False
2521+ else :
2522+ warmstart_plan = False
2523+ warmstart_potentials = False
2524+
2525+ def inner_solver (X_a , X , a , b , plan_init , potentials_init ):
24532526 return solve_sample (
24542527 X_a = X_a ,
24552528 X_b = X ,
@@ -2465,6 +2538,8 @@ def inner_solver(X_a, X, a, b):
24652538 n_threads = n_threads ,
24662539 max_iter = max_iter ,
24672540 tol = tol ,
2541+ plan_init = plan_init ,
2542+ potentials_init = potentials_init ,
24682543 verbose = False ,
24692544 )
24702545
@@ -2479,6 +2554,9 @@ def inner_solver(X_a, X, a, b):
24792554 metric ,
24802555 inner_solver ,
24812556 update_masses ,
2557+ warmstart_plan ,
2558+ warmstart_potentials ,
2559+ stopping_criterion ,
24822560 max_iter_bary ,
24832561 tol_bary ,
24842562 verbose ,
0 commit comments