Skip to content

Commit 46c4638

Browse files
update free support
1 parent 9e25e80 commit 46c4638

File tree

4 files changed

+136
-20
lines changed

4 files changed

+136
-20
lines changed

ot/solvers.py

Lines changed: 90 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1949,6 +1949,9 @@ def _bary_sample_bcd(
19491949
metric,
19501950
inner_solver,
19511951
update_masses,
1952+
warmstart_plan,
1953+
warmstart_potentials,
1954+
stopping_criterion,
19521955
max_iter_bary,
19531956
tol_bary,
19541957
verbose,
@@ -1975,6 +1978,12 @@ def _bary_sample_bcd(
19751978
Function to solve the inner OT problem
19761979
update_masses : bool
19771980
Update the masses of the barycenter, depending on whether balanced or unbalanced OT is used.
1981+
warmstart_plan : bool
1982+
Use the previous plan as initialization for the inner solver. Set based on inner solver type in ot.bary_sample
1983+
warmstart_potentials : bool
1984+
Use the previous potentials as initialization for the inner solver. Set based on inner solver type in ot.bary_sample
1985+
stopping_criterion : str
1986+
Stopping criterion for the BCD algorithm. Can be "loss" or "bary".
19781987
max_iter_bary : int
19791988
Maximum number of iterations for the barycenter
19801989
tol_bary : float
@@ -1994,22 +2003,41 @@ def _bary_sample_bcd(
19942003
b = b_init
19952004
inv_b = 1.0 / b
19962005

1997-
prev_loss = np.inf
2006+
prev_criterion = np.inf
19982007
n_samples = len(X_s)
19992008

20002009
if log:
2001-
log_ = {"loss": []}
2010+
log_ = {"stopping_criterion": []}
20022011
else:
20032012
log_ = None
2013+
20042014
# Compute the barycenter using BCD
20052015
for it in range(max_iter_bary):
20062016
# Solve the inner OT problem for each source distribution
2007-
list_res = [inner_solver(X_s[k], X, a_s[k], b) for k in range(n_samples)]
2017+
if it == 0:
2018+
list_res = [
2019+
inner_solver(X_s[k], X, a_s[k], b, None, None) for k in range(n_samples)
2020+
]
2021+
elif warmstart_plan:
2022+
list_res = [
2023+
inner_solver(X_s[k], X, a_s[k], b, list_res[k].plan, None)
2024+
for k in range(n_samples)
2025+
]
2026+
elif warmstart_potentials:
2027+
list_res = [
2028+
inner_solver(X_s[k], X, a_s[k], b, None, list_res[k].potentials)
2029+
for k in range(n_samples)
2030+
]
2031+
else:
2032+
list_res = [
2033+
inner_solver(X_s[k], X, a_s[k], b, None, None) for k in range(n_samples)
2034+
]
20082035

20092036
# Update the estimated barycenter weights in unbalanced cases
20102037
if update_masses:
20112038
b = sum([w_s[k] * list_res[k].plan.sum(axis=0) for k in range(n_samples)])
20122039
inv_b = 1.0 / b
2040+
20132041
# Update the barycenter samples
20142042
if metric in ["sqeuclidean", "euclidean"]:
20152043
X_new = (
@@ -2019,30 +2047,40 @@ def _bary_sample_bcd(
20192047
else:
20202048
raise NotImplementedError('Not implemented metric="{}"'.format(metric))
20212049

2022-
# compute loss
2023-
new_loss = sum([w_s[k] * list_res[k].value for k in range(n_samples)])
2050+
# compute criterion
2051+
if stopping_criterion == "loss":
2052+
new_criterion = sum([w_s[k] * list_res[k].value for k in range(n_samples)])
2053+
else: # stopping_criterion = "bary"
2054+
new_criterion = nx.norm(X_new - X, ord=2)
20242055

20252056
if verbose:
20262057
if it % 1 == 0:
2027-
print(f"BCD iteration {it}: loss = {new_loss:.4f}")
2058+
print(
2059+
f"BCD iteration {it}: criterion {stopping_criterion} = {new_criterion:.4f}"
2060+
)
20282061

20292062
if log:
2030-
log_["loss"].append(new_loss)
2063+
log_["stopping_criterion"].append(new_criterion)
20312064
# Check convergence
2032-
if abs(new_loss - prev_loss) / abs(prev_loss) < tol_bary:
2065+
if abs(new_criterion - prev_criterion) / abs(prev_criterion) < tol_bary:
20332066
print(f"BCD converged in {it} iterations")
20342067
break
20352068

20362069
X = X_new
2037-
prev_loss = new_loss
2070+
prev_criterion = new_criterion
2071+
2072+
# compute loss values
20382073

2039-
# compute value_linear
20402074
value_linear = sum([w_s[k] * list_res[k].value_linear for k in range(n_samples)])
2075+
if stopping_criterion == "loss":
2076+
value = new_criterion
2077+
else:
2078+
value = sum([w_s[k] * list_res[k].value for k in range(n_samples)])
20412079
# update BaryResult
20422080
bary_res = BaryResult(
20432081
X=X_new,
20442082
b=b,
2045-
value=new_loss,
2083+
value=value,
20462084
value_linear=value_linear,
20472085
log=log_,
20482086
list_res=list_res,
@@ -2070,6 +2108,8 @@ def bary_sample(
20702108
batch_size=None,
20712109
method=None,
20722110
n_threads=1,
2111+
warmstart=False,
2112+
stopping_criterion="loss",
20732113
max_iter_bary=1000,
20742114
max_iter=None,
20752115
rank=100,
@@ -2154,6 +2194,11 @@ def bary_sample(
21542194
large scale solver.
21552195
n_threads : int, optional
21562196
Number of OMP threads for exact OT solver, by default 1
2197+
warmstart : bool, optional
2198+
Use the previous OT or potentials as initialization for the next inner solver iteration, by default False.
2199+
stopping_criterion : str, optional
2200+
Stopping criterion for the outer loop of the BCD solver, by default 'loss'.
2201+
Either 'loss' to use the optimize objective or 'bary' for variations of the barycenter w.r.t the Frobenius norm.
21572202
max_iter_bary : int, optional
21582203
Maximum number of iteration for the BCD solver, by default 1000.
21592204
max_iter : int, optional
@@ -2398,6 +2443,13 @@ def bary_sample(
23982443
if method is not None and method.lower() in lst_method_lazy:
23992444
raise NotImplementedError("Barycenter with Lazy tensors not implemented yet")
24002445

2446+
if stopping_criterion not in ["loss", "bary"]:
2447+
raise ValueError(
2448+
"stopping_criterion must be either 'loss' or 'bary', got {}".format(
2449+
stopping_criterion
2450+
)
2451+
)
2452+
24012453
n_samples = len(X_s)
24022454

24032455
if (
@@ -2449,7 +2501,28 @@ def bary_sample(
24492501
if b_init is None:
24502502
b_init = nx.ones((n,), type_as=X_s[0]) / n
24512503

2452-
def inner_solver(X_a, X, a, b):
2504+
if warmstart:
2505+
if reg is None: # exact OT
2506+
warmstart_plan = True
2507+
warmstart_potentials = False
2508+
else: # regularized OT
2509+
# unbalanced AND regularized OT
2510+
if (
2511+
not isinstance(reg_type, tuple)
2512+
and reg_type.lower() in ["kl"]
2513+
and unbalanced_type.lower() == "kl"
2514+
):
2515+
warmstart_plan = False
2516+
warmstart_potentials = True
2517+
2518+
else:
2519+
warmstart_plan = True
2520+
warmstart_potentials = False
2521+
else:
2522+
warmstart_plan = False
2523+
warmstart_potentials = False
2524+
2525+
def inner_solver(X_a, X, a, b, plan_init, potentials_init):
24532526
return solve_sample(
24542527
X_a=X_a,
24552528
X_b=X,
@@ -2465,6 +2538,8 @@ def inner_solver(X_a, X, a, b):
24652538
n_threads=n_threads,
24662539
max_iter=max_iter,
24672540
tol=tol,
2541+
plan_init=plan_init,
2542+
potentials_init=potentials_init,
24682543
verbose=False,
24692544
)
24702545

@@ -2479,6 +2554,9 @@ def inner_solver(X_a, X, a, b):
24792554
metric,
24802555
inner_solver,
24812556
update_masses,
2557+
warmstart_plan,
2558+
warmstart_potentials,
2559+
stopping_criterion,
24822560
max_iter_bary,
24832561
tol_bary,
24842562
verbose,

ot/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1334,6 +1334,8 @@ class BaryResult:
13341334
Dictionary containing potential information about the solver.
13351335
list_res: list of OTResult
13361336
List of results for the individual OT matching.
1337+
status : int or str
1338+
Status of the solver.
13371339
13381340
Attributes
13391341
----------
@@ -1357,6 +1359,8 @@ class BaryResult:
13571359
Dictionary containing potential information about the solver.
13581360
list_res: list of OTResult
13591361
List of results for the individual OT matching.
1362+
status : int or str
1363+
Status of the solver.
13601364
backend : Backend
13611365
Backend used to compute the results.
13621366
"""
@@ -1371,6 +1375,7 @@ def __init__(
13711375
value_quad=None,
13721376
log=None,
13731377
list_res=None,
1378+
status=None,
13741379
backend=None,
13751380
):
13761381
self._X = X
@@ -1381,6 +1386,7 @@ def __init__(
13811386
self._value_quad = value_quad
13821387
self._log = log
13831388
self._list_res = list_res
1389+
self._status = status
13841390
self._backend = backend if backend is not None else NumpyBackend()
13851391

13861392
def __repr__(self):

test/test_solvers.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -741,12 +741,16 @@ def assert_allclose_bary_sol(sol1, sol2):
741741
@pytest.skip_backend("jax", reason="test very slow with jax backend")
742742
@pytest.skip_backend("tf", reason="test very slow with tf backend")
743743
@pytest.mark.parametrize(
744-
"reg,reg_type,unbalanced,unbalanced_type",
745-
itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type),
744+
"reg,reg_type,unbalanced,unbalanced_type,warmstart",
745+
itertools.product(
746+
lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type, [True, False]
747+
),
746748
)
747-
def test_bary_sample(nx, reg, reg_type, unbalanced, unbalanced_type):
749+
def test_bary_sample_free_support(
750+
nx, reg, reg_type, unbalanced, unbalanced_type, warmstart
751+
):
748752
# test bary_sample when is_Lazy = False
749-
rng = np.random.RandomState(0)
753+
rng = np.random.RandomState()
750754

751755
K = 3 # number of distributions
752756
ns = rng.randint(10, 20, K) # number of samples within each distribution
@@ -781,7 +785,8 @@ def df(G):
781785
reg_type=reg_type,
782786
unbalanced=unbalanced,
783787
unbalanced_type=unbalanced_type,
784-
max_iter_bary=4,
788+
warmstart=warmstart,
789+
max_iter_bary=3,
785790
tol_bary=1e-3,
786791
verbose=True,
787792
)
@@ -798,7 +803,8 @@ def df(G):
798803
reg_type=reg_type,
799804
unbalanced=unbalanced,
800805
unbalanced_type=unbalanced_type,
801-
max_iter_bary=4,
806+
warmstart=warmstart,
807+
max_iter_bary=3,
802808
tol_bary=1e-3,
803809
verbose=True,
804810
)
@@ -831,7 +837,8 @@ def df(G):
831837
reg_type=reg_type,
832838
unbalanced=unbalanced,
833839
unbalanced_type=unbalanced_type,
834-
max_iter_bary=4,
840+
warmstart=warmstart,
841+
max_iter_bary=3,
835842
tol_bary=1e-3,
836843
verbose=True,
837844
)

test/test_utils.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ def test_OTResult():
456456
# test print
457457
print(res)
458458

459-
# tets get citation
459+
# test get citation
460460
print(res.citation)
461461

462462
lst_attributes = [
@@ -486,6 +486,31 @@ def test_OTResult():
486486
getattr(res, at)
487487

488488

489+
def test_BaryResult():
490+
res = ot.utils.BaryResult()
491+
492+
# test print
493+
print(res)
494+
495+
# test get citation
496+
print(res.citation)
497+
498+
lst_attributes = [
499+
"X",
500+
"C",
501+
"b",
502+
"value",
503+
"value_linear",
504+
"value_quad",
505+
"list_res",
506+
"status",
507+
"log",
508+
]
509+
for at in lst_attributes:
510+
print(at)
511+
assert getattr(res, at) is None
512+
513+
489514
def test_get_coordinate_circle():
490515
rng = np.random.RandomState(42)
491516
u = rng.rand(1, 100)

0 commit comments

Comments
 (0)