Skip to content

Commit 9e25e80

Browse files
handle masses in unbalanced cases
1 parent a93c60c commit 9e25e80

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

ot/solvers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1948,6 +1948,7 @@ def _bary_sample_bcd(
19481948
w_s,
19491949
metric,
19501950
inner_solver,
1951+
update_masses,
19511952
max_iter_bary,
19521953
tol_bary,
19531954
verbose,
@@ -1972,6 +1973,8 @@ def _bary_sample_bcd(
19721973
Metric to use for the cost matrix, by default "sqeuclidean"
19731974
inner_solver : callable
19741975
Function to solve the inner OT problem
1976+
update_masses : bool
1977+
Update the masses of the barycenter, depending on whether balanced or unbalanced OT is used.
19751978
max_iter_bary : int
19761979
Maximum number of iterations for the barycenter
19771980
tol_bary : float
@@ -2003,6 +2006,10 @@ def _bary_sample_bcd(
20032006
# Solve the inner OT problem for each source distribution
20042007
list_res = [inner_solver(X_s[k], X, a_s[k], b) for k in range(n_samples)]
20052008

2009+
# Update the estimated barycenter weights in unbalanced cases
2010+
if update_masses:
2011+
b = sum([w_s[k] * list_res[k].plan.sum(axis=0) for k in range(n_samples)])
2012+
inv_b = 1.0 / b
20062013
# Update the barycenter samples
20072014
if metric in ["sqeuclidean", "euclidean"]:
20082015
X_new = (
@@ -2461,6 +2468,8 @@ def inner_solver(X_a, X, a, b):
24612468
verbose=False,
24622469
)
24632470

2471+
# compute the barycenter using BCD
2472+
update_masses = unbalanced is not None
24642473
res = _bary_sample_bcd(
24652474
X_s,
24662475
X_init,
@@ -2469,6 +2478,7 @@ def inner_solver(X_a, X, a, b):
24692478
w_s,
24702479
metric,
24712480
inner_solver,
2481+
update_masses,
24722482
max_iter_bary,
24732483
tol_bary,
24742484
verbose,

0 commit comments

Comments
 (0)