From a313e21f223af16cf21d3b7dd01bd0c6345d574c Mon Sep 17 00:00:00 2001 From: Huy Tran Date: Thu, 23 Feb 2023 15:31:20 +0100 Subject: [MRG] Added argument for warmstart of dual vectors in Sinkhorn-based methods in `ot.bregman` (#437) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Allow warmstart in sinkhorn and sinkhorn_log * Added argument for warmstart of dual vectors in Sinkhorn-based methods in * Add the number of the PR * [WIP] CO-Optimal Transport * Revert "[WIP] CO-Optimal Transport" This reverts commit f3d36b2705013409ac69b346585e311bc25fcfb7. * reformat with PEP8 * Fix W291 trailing whitespace error in pep8 test * Rearange position of warmstart argument and edit its description --------- Co-authored-by: RĂ©mi Flamary --- test/test_bregman.py | 231 ++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 184 insertions(+), 47 deletions(-) (limited to 'test/test_bregman.py') diff --git a/test/test_bregman.py b/test/test_bregman.py index ce15642..f01bb14 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -59,10 +59,12 @@ def test_convergence_warning(method): with pytest.warns(UserWarning): ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1) with pytest.warns(UserWarning): - ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=True) + ot.sinkhorn2(a1, a2, M, 1, method=method, + stopThr=0, numItermax=1, warn=True) with warnings.catch_warnings(): warnings.simplefilter("error") - ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=False) + ot.sinkhorn2(a1, a2, M, 1, method=method, + stopThr=0, numItermax=1, warn=False) def test_not_implemented_method(): @@ -266,12 +268,16 @@ def test_sinkhorn_variants(nx): ub, M_nx = nx.from_numpy(u, M) G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Gl = nx.to_numpy(ot.sinkhorn( + ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn( + ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn( + ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) Ges = nx.to_numpy(ot.sinkhorn( ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)) - G_green = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10)) + G_green = nx.to_numpy(ot.sinkhorn( + ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10)) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -371,9 +377,12 @@ def test_sinkhorn_variants_multi_b(nx): ub, bb, M_nx = nx.from_numpy(u, b, M) G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Gl = nx.to_numpy(ot.sinkhorn( + ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn( + ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn( + ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -399,9 +408,12 @@ def test_sinkhorn2_variants_multi_b(nx): ub, bb, M_nx = nx.from_numpy(u, b, M) G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Gl = nx.to_numpy(ot.sinkhorn2( + ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn2( + ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn2( + ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -419,12 +431,16 @@ def test_sinkhorn_variants_log(): M = ot.dist(x, x) - G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True) - Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) - Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) + G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', + stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn( + u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) + Gs, logs = ot.sinkhorn( + u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) Ges, loges = ot.sinkhorn( u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True,) - G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True) + G_green, loggreen = ot.sinkhorn( + u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True) # check values np.testing.assert_allclose(G0, Gs, atol=1e-05) @@ -446,7 +462,8 @@ def test_sinkhorn_variants_log_multib(verbose, warn): M = ot.dist(x, x) - G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True) + G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', + stopThr=1e-10, log=True) Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True, verbose=verbose, warn=warn) Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True, @@ -485,8 +502,10 @@ def test_barycenter(nx, method, verbose, warn): ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method) else: # wasserstein - bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method, verbose=verbose, warn=warn) - bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, weights_nx, method=method, log=True) + bary_wass_np = ot.bregman.barycenter( + A, M, reg, weights, method=method, verbose=verbose, warn=warn) + bary_wass, _ = ot.bregman.barycenter( + A_nx, M_nx, reg, weights_nx, method=method, log=True) bary_wass = nx.to_numpy(bary_wass) np.testing.assert_allclose(1, np.sum(bary_wass)) @@ -514,7 +533,8 @@ def test_free_support_sinkhorn_barycenter(): # Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization # term to 1, but this should be, in general, fine-tuned to the problem. - X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg=1) + X = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations, measures_weights, X_init, reg=1) # Verifies if calculated barycenter matches ground-truth np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) @@ -545,8 +565,10 @@ def test_barycenter_assymetric_cost(nx, method, verbose, warn): ot.bregman.barycenter(A_nx, M_nx, reg, method=method) else: # wasserstein - bary_wass_np = ot.bregman.barycenter(A, M, reg, method=method, verbose=verbose, warn=warn) - bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, method=method, log=True) + bary_wass_np = ot.bregman.barycenter( + A, M, reg, method=method, verbose=verbose, warn=warn) + bary_wass, _ = ot.bregman.barycenter( + A_nx, M_nx, reg, method=method, log=True) bary_wass = nx.to_numpy(bary_wass) np.testing.assert_allclose(1, np.sum(bary_wass)) @@ -581,17 +603,20 @@ def test_barycenter_debiased(nx, method, verbose, warn): reg = 1e-2 if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": with pytest.raises(NotImplementedError): - ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method) + ot.bregman.barycenter_debiased( + A_nx, M_nx, reg, weights, method=method) else: bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, verbose=verbose, warn=warn) - bary_wass, _ = ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights_nx, method=method, log=True) + bary_wass, _ = ot.bregman.barycenter_debiased( + A_nx, M_nx, reg, weights_nx, method=method, log=True) bary_wass = nx.to_numpy(bary_wass) np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3) np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5) - ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False) + ot.bregman.barycenter_debiased( + A_nx, M_nx, reg, log=True, verbose=False) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) @@ -616,7 +641,8 @@ def test_convergence_warning_barycenters(method): weights = np.array([1 - alpha, alpha]) reg = 0.1 with pytest.warns(UserWarning): - ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1) + ot.bregman.barycenter_debiased( + A, M, reg, weights, method=method, numItermax=1) with pytest.warns(UserWarning): ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1) with pytest.warns(UserWarning): @@ -648,7 +674,8 @@ def test_barycenter_stabilization(nx): # wasserstein reg = 1e-2 - bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True) + bar_np = ot.bregman.barycenter( + A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True) bar_stable = nx.to_numpy(ot.bregman.barycenter( A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized", stopThr=1e-8, verbose=True @@ -683,8 +710,10 @@ def test_wasserstein_bary_2d(nx, method): with pytest.raises(NotImplementedError): ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) else: - bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method, verbose=True, log=True) - bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)) + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( + A, reg, method=method, verbose=True, log=True) + bary_wass = nx.to_numpy( + ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)) np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) @@ -713,10 +742,13 @@ def test_wasserstein_bary_2d_debiased(nx, method): reg = 1e-2 if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": with pytest.raises(NotImplementedError): - ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) + ot.bregman.convolutional_barycenter2d_debiased( + A_nx, reg, method=method) else: - bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method, verbose=True, log=True) - bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)) + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( + A, reg, method=method, verbose=True, log=True) + bary_wass = nx.to_numpy( + ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)) np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) @@ -750,7 +782,8 @@ def test_unmix(nx): # wasserstein reg = 1e-3 um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01) - um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01)) + um = nx.to_numpy(ot.bregman.unmix( + ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01)) np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03) np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03) @@ -781,10 +814,12 @@ def test_empirical_sinkhorn(nx): sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) - G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean')) + G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn( + X_sb, X_tb, 1, metric='euclidean')) sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) - loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) + loss_emp_sinkhorn = nx.to_numpy( + ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) # check constraints @@ -817,23 +852,27 @@ def test_lazy_empirical_sinkhorn(nx): ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m) - f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) + f, g = ot.bregman.empirical_sinkhorn( + X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) f, g = nx.to_numpy(f), nx.to_numpy(g) G_sqe = np.exp(f[:, None] + g[None, :] - M / 1) sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1)) - f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + f, g, log_es = ot.bregman.empirical_sinkhorn( + X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) f, g = nx.to_numpy(f), nx.to_numpy(g) G_log = np.exp(f[:, None] + g[None, :] - M / 0.1) sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) - f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1) + f, g = ot.bregman.empirical_sinkhorn( + X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1) f, g = nx.to_numpy(f), nx.to_numpy(g) G_m = np.exp(f[:, None] + g[None, :] - M_m / 1) sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) - loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2( + X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn) loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) @@ -865,22 +904,27 @@ def test_empirical_sinkhorn_divergence(nx): M_s = ot.dist(X_s, X_s) M_t = ot.dist(X_t, X_t) - ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(a, b, X_s, X_t, M, M_s, M_t) + ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy( + a, b, X_s, X_t, M, M_s, M_t) - emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)) + emp_sinkhorn_div = nx.to_numpy( + ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)) sinkhorn_div = nx.to_numpy( ot.sinkhorn2(ab, bb, M_nx, 1) - 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1) - 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1) ) - emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b) + emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence( + X_s, X_t, 1, a=a, b=b) # check constraints - np.testing.assert_allclose(emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05) + np.testing.assert_allclose( + emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05) np.testing.assert_allclose( emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn - ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True) + ot.bregman.empirical_sinkhorn_divergence( + X_sb, X_tb, 1, a=ab, b=bb, log=True) @pytest.mark.skipif(not torch, reason="No torch available") @@ -902,7 +946,8 @@ def test_empirical_sinkhorn_divergence_gradient(): X_sb.requires_grad = True X_tb.requires_grad = True - emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb) + emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence( + X_sb, X_tb, 1, a=ab, b=bb) emp_sinkhorn_div.backward() @@ -931,7 +976,8 @@ def test_stabilized_vs_sinkhorn_multidim(nx): ab, bb, M_nx = nx.from_numpy(a, b, M) - G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True) + G_np, _ = ot.bregman.sinkhorn( + a, b, M, reg=epsilon, method="sinkhorn", log=True) G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon, method="sinkhorn_stabilized", log=True) @@ -996,7 +1042,8 @@ def test_screenkhorn(nx): # sinkhorn G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1)) # screenkhorn - G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True)) + G_screen = nx.to_numpy(ot.bregman.screenkhorn( + ab, bb, M_nx, 1e-1, uniform=True, verbose=True)) # check marginals np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02) np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02) @@ -1013,3 +1060,93 @@ def test_convolutional_barycenter_non_square(nx): np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02) np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02) np.testing.assert_allclose(b, b_np) + + +def test_sinkhorn_warmstart(): + m, n = 10, 20 + a = ot.unif(m) + b = ot.unif(n) + + Xs = np.arange(m) * 1.0 + Xt = np.arange(n) * 1.0 + M = ot.dist(Xs.reshape(-1, 1), Xt.reshape(-1, 1)) + + # Generate warmstart from dual vectors of unregularized OT + _, log = ot.lp.emd(a, b, M, log=True) + warmstart = (log["u"], log["v"]) + + reg = 1 + + # Optimal plan with uniform warmstart + pi_unif, _ = ot.bregman.sinkhorn( + a, b, M, reg, method="sinkhorn", log=True, warmstart=None) + # Optimal plan with warmstart generated from unregularized OT + pi_sh, _ = ot.bregman.sinkhorn( + a, b, M, reg, method="sinkhorn", log=True, warmstart=warmstart) + pi_sh_log, _ = ot.bregman.sinkhorn( + a, b, M, reg, method="sinkhorn_log", log=True, warmstart=warmstart) + pi_sh_stab, _ = ot.bregman.sinkhorn( + a, b, M, reg, method="sinkhorn_stabilized", log=True, warmstart=warmstart) + pi_sh_sc, _ = ot.bregman.sinkhorn( + a, b, M, reg, method="sinkhorn_epsilon_scaling", log=True, warmstart=warmstart) + + np.testing.assert_allclose(pi_unif, pi_sh, atol=1e-05) + np.testing.assert_allclose(pi_unif, pi_sh_log, atol=1e-05) + np.testing.assert_allclose(pi_unif, pi_sh_stab, atol=1e-05) + np.testing.assert_allclose(pi_unif, pi_sh_sc, atol=1e-05) + + +def test_empirical_sinkhorn_warmstart(): + m, n = 10, 20 + Xs = np.arange(m).reshape(-1, 1) * 1.0 + Xt = np.arange(n).reshape(-1, 1) * 1.0 + M = ot.dist(Xs, Xt) + + # Generate warmstart from dual vectors of unregularized OT + a = ot.unif(m) + b = ot.unif(n) + _, log = ot.lp.emd(a, b, M, log=True) + warmstart = (log["u"], log["v"]) + + reg = 1 + + # Optimal plan with uniform warmstart + f, g, _ = ot.bregman.empirical_sinkhorn( + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None) + pi_unif = np.exp(f[:, None] + g[None, :] - M / reg) + # Optimal plan with warmstart generated from unregularized OT + f, g, _ = ot.bregman.empirical_sinkhorn( + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart) + pi_ws_lazy = np.exp(f[:, None] + g[None, :] - M / reg) + pi_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn( + X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart) + + np.testing.assert_allclose(pi_unif, pi_ws_lazy, atol=1e-05) + np.testing.assert_allclose(pi_unif, pi_ws_not_lazy, atol=1e-05) + + +def test_empirical_sinkhorn_divergence_warmstart(): + m, n = 10, 20 + Xs = np.arange(m).reshape(-1, 1) * 1.0 + Xt = np.arange(n).reshape(-1, 1) * 1.0 + M = ot.dist(Xs, Xt) + + # Generate warmstart from dual vectors of unregularized OT + a = ot.unif(m) + b = ot.unif(n) + _, log = ot.lp.emd(a, b, M, log=True) + warmstart = (log["u"], log["v"]) + + reg = 1 + + # Optimal plan with uniform warmstart + sd_unif, _ = ot.bregman.empirical_sinkhorn_divergence( + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None) + # Optimal plan with warmstart generated from unregularized OT + sd_ws_lazy, _ = ot.bregman.empirical_sinkhorn_divergence( + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart) + sd_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn_divergence( + X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart) + + np.testing.assert_allclose(sd_unif, sd_ws_lazy, atol=1e-05) + np.testing.assert_allclose(sd_unif, sd_ws_not_lazy, atol=1e-05) -- cgit v1.2.3