summaryrefslogtreecommitdiff
path: root/test/test_bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r--test/test_bregman.py231
1 files changed, 184 insertions, 47 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index ce15642..f01bb14 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -59,10 +59,12 @@ def test_convergence_warning(method):
with pytest.warns(UserWarning):
ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1)
with pytest.warns(UserWarning):
- ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=True)
+ ot.sinkhorn2(a1, a2, M, 1, method=method,
+ stopThr=0, numItermax=1, warn=True)
with warnings.catch_warnings():
warnings.simplefilter("error")
- ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=False)
+ ot.sinkhorn2(a1, a2, M, 1, method=method,
+ stopThr=0, numItermax=1, warn=False)
def test_not_implemented_method():
@@ -266,12 +268,16 @@ def test_sinkhorn_variants(nx):
ub, M_nx = nx.from_numpy(u, M)
G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
- Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
- G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10))
- Gs = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Gl = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
Ges = nx.to_numpy(ot.sinkhorn(
ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10))
- G_green = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10))
+ G_green = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10))
# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
@@ -371,9 +377,12 @@ def test_sinkhorn_variants_multi_b(nx):
ub, bb, M_nx = nx.from_numpy(u, b, M)
G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
- Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
- G0 = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
- Gs = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Gl = nx.to_numpy(ot.sinkhorn(
+ ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(
+ ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(
+ ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
@@ -399,9 +408,12 @@ def test_sinkhorn2_variants_multi_b(nx):
ub, bb, M_nx = nx.from_numpy(u, b, M)
G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
- Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
- G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
- Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Gl = nx.to_numpy(ot.sinkhorn2(
+ ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn2(
+ ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn2(
+ ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
@@ -419,12 +431,16 @@ def test_sinkhorn_variants_log():
M = ot.dist(x, x)
- G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
- Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
- Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
+ G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn',
+ stopThr=1e-10, log=True)
+ Gl, logl = ot.sinkhorn(
+ u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
+ Gs, logs = ot.sinkhorn(
+ u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
Ges, loges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True,)
- G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
+ G_green, loggreen = ot.sinkhorn(
+ u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
@@ -446,7 +462,8 @@ def test_sinkhorn_variants_log_multib(verbose, warn):
M = ot.dist(x, x)
- G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+ G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn',
+ stopThr=1e-10, log=True)
Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True,
verbose=verbose, warn=warn)
Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True,
@@ -485,8 +502,10 @@ def test_barycenter(nx, method, verbose, warn):
ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method)
else:
# wasserstein
- bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method, verbose=verbose, warn=warn)
- bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass_np = ot.bregman.barycenter(
+ A, M, reg, weights, method=method, verbose=verbose, warn=warn)
+ bary_wass, _ = ot.bregman.barycenter(
+ A_nx, M_nx, reg, weights_nx, method=method, log=True)
bary_wass = nx.to_numpy(bary_wass)
np.testing.assert_allclose(1, np.sum(bary_wass))
@@ -514,7 +533,8 @@ def test_free_support_sinkhorn_barycenter():
# Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization
# term to 1, but this should be, in general, fine-tuned to the problem.
- X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg=1)
+ X = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations, measures_weights, X_init, reg=1)
# Verifies if calculated barycenter matches ground-truth
np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
@@ -545,8 +565,10 @@ def test_barycenter_assymetric_cost(nx, method, verbose, warn):
ot.bregman.barycenter(A_nx, M_nx, reg, method=method)
else:
# wasserstein
- bary_wass_np = ot.bregman.barycenter(A, M, reg, method=method, verbose=verbose, warn=warn)
- bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, method=method, log=True)
+ bary_wass_np = ot.bregman.barycenter(
+ A, M, reg, method=method, verbose=verbose, warn=warn)
+ bary_wass, _ = ot.bregman.barycenter(
+ A_nx, M_nx, reg, method=method, log=True)
bary_wass = nx.to_numpy(bary_wass)
np.testing.assert_allclose(1, np.sum(bary_wass))
@@ -581,17 +603,20 @@ def test_barycenter_debiased(nx, method, verbose, warn):
reg = 1e-2
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
- ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method)
+ ot.bregman.barycenter_debiased(
+ A_nx, M_nx, reg, weights, method=method)
else:
bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method,
verbose=verbose, warn=warn)
- bary_wass, _ = ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass, _ = ot.bregman.barycenter_debiased(
+ A_nx, M_nx, reg, weights_nx, method=method, log=True)
bary_wass = nx.to_numpy(bary_wass)
np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5)
- ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False)
+ ot.bregman.barycenter_debiased(
+ A_nx, M_nx, reg, log=True, verbose=False)
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
@@ -616,7 +641,8 @@ def test_convergence_warning_barycenters(method):
weights = np.array([1 - alpha, alpha])
reg = 0.1
with pytest.warns(UserWarning):
- ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1)
+ ot.bregman.barycenter_debiased(
+ A, M, reg, weights, method=method, numItermax=1)
with pytest.warns(UserWarning):
ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1)
with pytest.warns(UserWarning):
@@ -648,7 +674,8 @@ def test_barycenter_stabilization(nx):
# wasserstein
reg = 1e-2
- bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True)
+ bar_np = ot.bregman.barycenter(
+ A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True)
bar_stable = nx.to_numpy(ot.bregman.barycenter(
A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized",
stopThr=1e-8, verbose=True
@@ -683,8 +710,10 @@ def test_wasserstein_bary_2d(nx, method):
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
else:
- bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method, verbose=True, log=True)
- bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method))
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
+ A, reg, method=method, verbose=True, log=True)
+ bary_wass = nx.to_numpy(
+ ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method))
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
@@ -713,10 +742,13 @@ def test_wasserstein_bary_2d_debiased(nx, method):
reg = 1e-2
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
- ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
+ ot.bregman.convolutional_barycenter2d_debiased(
+ A_nx, reg, method=method)
else:
- bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method, verbose=True, log=True)
- bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
+ A, reg, method=method, verbose=True, log=True)
+ bary_wass = nx.to_numpy(
+ ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
@@ -750,7 +782,8 @@ def test_unmix(nx):
# wasserstein
reg = 1e-3
um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01)
- um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01))
+ um = nx.to_numpy(ot.bregman.unmix(
+ ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01))
np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
@@ -781,10 +814,12 @@ def test_empirical_sinkhorn(nx):
sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
sinkhorn_log = nx.to_numpy(sinkhorn_log)
- G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean'))
+ G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(
+ X_sb, X_tb, 1, metric='euclidean'))
sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
- loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
+ loss_emp_sinkhorn = nx.to_numpy(
+ ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
# check constraints
@@ -817,23 +852,27 @@ def test_lazy_empirical_sinkhorn(nx):
ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m)
- f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
+ f, g = ot.bregman.empirical_sinkhorn(
+ X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))
- f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ f, g, log_es = ot.bregman.empirical_sinkhorn(
+ X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_log = np.exp(f[:, None] + g[None, :] - M / 0.1)
sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
sinkhorn_log = nx.to_numpy(sinkhorn_log)
- f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1)
+ f, g = ot.bregman.empirical_sinkhorn(
+ X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_m = np.exp(f[:, None] + g[None, :] - M_m / 1)
sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
- loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(
+ X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn)
loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
@@ -865,22 +904,27 @@ def test_empirical_sinkhorn_divergence(nx):
M_s = ot.dist(X_s, X_s)
M_t = ot.dist(X_t, X_t)
- ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(a, b, X_s, X_t, M, M_s, M_t)
+ ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(
+ a, b, X_s, X_t, M, M_s, M_t)
- emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
+ emp_sinkhorn_div = nx.to_numpy(
+ ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
sinkhorn_div = nx.to_numpy(
ot.sinkhorn2(ab, bb, M_nx, 1)
- 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1)
- 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1)
)
- emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b)
+ emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(
+ X_s, X_t, 1, a=a, b=b)
# check constraints
- np.testing.assert_allclose(emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05)
+ np.testing.assert_allclose(
+ emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05)
np.testing.assert_allclose(
emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
- ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True)
+ ot.bregman.empirical_sinkhorn_divergence(
+ X_sb, X_tb, 1, a=ab, b=bb, log=True)
@pytest.mark.skipif(not torch, reason="No torch available")
@@ -902,7 +946,8 @@ def test_empirical_sinkhorn_divergence_gradient():
X_sb.requires_grad = True
X_tb.requires_grad = True
- emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)
+ emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(
+ X_sb, X_tb, 1, a=ab, b=bb)
emp_sinkhorn_div.backward()
@@ -931,7 +976,8 @@ def test_stabilized_vs_sinkhorn_multidim(nx):
ab, bb, M_nx = nx.from_numpy(a, b, M)
- G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True)
+ G_np, _ = ot.bregman.sinkhorn(
+ a, b, M, reg=epsilon, method="sinkhorn", log=True)
G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon,
method="sinkhorn_stabilized",
log=True)
@@ -996,7 +1042,8 @@ def test_screenkhorn(nx):
# sinkhorn
G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1))
# screenkhorn
- G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True))
+ G_screen = nx.to_numpy(ot.bregman.screenkhorn(
+ ab, bb, M_nx, 1e-1, uniform=True, verbose=True))
# check marginals
np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
@@ -1013,3 +1060,93 @@ def test_convolutional_barycenter_non_square(nx):
np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
np.testing.assert_allclose(b, b_np)
+
+
+def test_sinkhorn_warmstart():
+ m, n = 10, 20
+ a = ot.unif(m)
+ b = ot.unif(n)
+
+ Xs = np.arange(m) * 1.0
+ Xt = np.arange(n) * 1.0
+ M = ot.dist(Xs.reshape(-1, 1), Xt.reshape(-1, 1))
+
+ # Generate warmstart from dual vectors of unregularized OT
+ _, log = ot.lp.emd(a, b, M, log=True)
+ warmstart = (log["u"], log["v"])
+
+ reg = 1
+
+ # Optimal plan with uniform warmstart
+ pi_unif, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn", log=True, warmstart=None)
+ # Optimal plan with warmstart generated from unregularized OT
+ pi_sh, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn", log=True, warmstart=warmstart)
+ pi_sh_log, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn_log", log=True, warmstart=warmstart)
+ pi_sh_stab, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn_stabilized", log=True, warmstart=warmstart)
+ pi_sh_sc, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn_epsilon_scaling", log=True, warmstart=warmstart)
+
+ np.testing.assert_allclose(pi_unif, pi_sh, atol=1e-05)
+ np.testing.assert_allclose(pi_unif, pi_sh_log, atol=1e-05)
+ np.testing.assert_allclose(pi_unif, pi_sh_stab, atol=1e-05)
+ np.testing.assert_allclose(pi_unif, pi_sh_sc, atol=1e-05)
+
+
+def test_empirical_sinkhorn_warmstart():
+ m, n = 10, 20
+ Xs = np.arange(m).reshape(-1, 1) * 1.0
+ Xt = np.arange(n).reshape(-1, 1) * 1.0
+ M = ot.dist(Xs, Xt)
+
+ # Generate warmstart from dual vectors of unregularized OT
+ a = ot.unif(m)
+ b = ot.unif(n)
+ _, log = ot.lp.emd(a, b, M, log=True)
+ warmstart = (log["u"], log["v"])
+
+ reg = 1
+
+ # Optimal plan with uniform warmstart
+ f, g, _ = ot.bregman.empirical_sinkhorn(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None)
+ pi_unif = np.exp(f[:, None] + g[None, :] - M / reg)
+ # Optimal plan with warmstart generated from unregularized OT
+ f, g, _ = ot.bregman.empirical_sinkhorn(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart)
+ pi_ws_lazy = np.exp(f[:, None] + g[None, :] - M / reg)
+ pi_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart)
+
+ np.testing.assert_allclose(pi_unif, pi_ws_lazy, atol=1e-05)
+ np.testing.assert_allclose(pi_unif, pi_ws_not_lazy, atol=1e-05)
+
+
+def test_empirical_sinkhorn_divergence_warmstart():
+ m, n = 10, 20
+ Xs = np.arange(m).reshape(-1, 1) * 1.0
+ Xt = np.arange(n).reshape(-1, 1) * 1.0
+ M = ot.dist(Xs, Xt)
+
+ # Generate warmstart from dual vectors of unregularized OT
+ a = ot.unif(m)
+ b = ot.unif(n)
+ _, log = ot.lp.emd(a, b, M, log=True)
+ warmstart = (log["u"], log["v"])
+
+ reg = 1
+
+ # Optimal plan with uniform warmstart
+ sd_unif, _ = ot.bregman.empirical_sinkhorn_divergence(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None)
+ # Optimal plan with warmstart generated from unregularized OT
+ sd_ws_lazy, _ = ot.bregman.empirical_sinkhorn_divergence(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart)
+ sd_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn_divergence(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart)
+
+ np.testing.assert_allclose(sd_unif, sd_ws_lazy, atol=1e-05)
+ np.testing.assert_allclose(sd_unif, sd_ws_not_lazy, atol=1e-05)