Diffstat (limited to 'test')
-rw-r--r-- | test/test_1d_solver.py | 127
-rw-r--r-- | test/test_backend.py | 52
-rw-r--r-- | test/test_bregman.py | 315
-rw-r--r-- | test/test_coot.py | 359
-rw-r--r-- | test/test_da.py | 79
-rw-r--r-- | test/test_gaussian.py | 98
-rw-r--r-- | test/test_gromov.py | 637
-rw-r--r-- | test/test_optim.py | 63
-rw-r--r-- | test/test_ot.py | 59
-rwxr-xr-x | test/test_partial.py | 124
-rw-r--r-- | test/test_sliced.py | 200
-rw-r--r-- | test/test_solvers.py | 133
-rw-r--r-- | test/test_unbalanced.py | 61
-rw-r--r-- | test/test_utils.py | 40
14 files changed, 2191 insertions, 156 deletions
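The largest addition below, test/test_1d_solver.py, exercises new solvers for the Wasserstein distance on the circle. Roughly, they are used as follows — a sketch restricted to calls that actually appear in these tests, with illustrative sizes and seeds:

import numpy as np
import ot

rng = np.random.RandomState(0)
u = rng.rand(20)  # positions of the first measure on the circle [0, 1)
v = rng.rand(30)  # positions of the second measure
w_u = ot.utils.unif(20)  # uniform weights; the tests also use random ones
w_v = ot.utils.unif(30)

w1 = ot.wasserstein_circle(u, v, w_u, w_v, p=1)    # circular W1
w2 = ot.binary_search_circle(u, v, w_u, w_v, p=2)  # circular W2 by binary search

# closed form for W2 against the uniform distribution on the circle
w2_unif = ot.semidiscrete_wasserstein2_unif_circle(u, w_u)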
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index 20f307a..21abd1d 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -218,3 +218,130 @@ def test_emd1d_device_tf():
nx.assert_same_dtype_device(xb, emd)
nx.assert_same_dtype_device(xb, emd2)
assert nx.dtype_device(emd)[1].startswith("GPU")
+
+
+def test_wasserstein_1d_circle():
+ # test binary_search_circle and wasserstein_circle give similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.rand(n,)
+ v = rng.rand(m,)
+
+ w_u = rng.uniform(0., 1., n)
+ w_u = w_u / w_u.sum()
+
+ w_v = rng.uniform(0., 1., m)
+ w_v = w_v / w_v.sum()
+
+ M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None]))
+
+ wass1 = ot.emd2(w_u, w_v, M1)
+
+ wass1_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=1)
+ w1_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=1)
+
+ M2 = M1**2
+ wass2 = ot.emd2(w_u, w_v, M2)
+ wass2_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=2)
+ w2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass1, wass1_bsc)
+ np.testing.assert_allclose(wass1, w1_circle, rtol=1e-2)
+ np.testing.assert_allclose(wass2, wass2_bsc)
+ np.testing.assert_allclose(wass2, w2_circle)
+
+
+@pytest.skip_backend("tf")
+def test_wasserstein1d_circle_devices(nx):
+ rng = np.random.RandomState(0)
+
+ n = 10
+ x = np.linspace(0, 1, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)
+
+ w1 = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=1)
+ w2_bsc = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=2)
+
+ nx.assert_same_dtype_device(xb, w1)
+ nx.assert_same_dtype_device(xb, w2_bsc)
+
+
+def test_wasserstein_1d_unif_circle():
+ # test semidiscrete_wasserstein2_unif_circle versus wasserstein_circle
+ n = 20
+ m = 50000
+
+ rng = np.random.RandomState(0)
+ u = rng.rand(n,)
+ v = rng.rand(m,)
+
+ # w_u = rng.uniform(0., 1., n)
+ # w_u = w_u / w_u.sum()
+
+ w_u = ot.utils.unif(n)
+ w_v = ot.utils.unif(m)
+
+ M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None]))
+ wass2 = ot.emd2(w_u, w_v, M1**2)
+
+ wass2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2, eps=1e-15)
+ wass2_unif_circle = ot.semidiscrete_wasserstein2_unif_circle(u, w_u)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-3)
+ np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-3)
+
+
+def test_wasserstein1d_unif_circle_devices(nx):
+ rng = np.random.RandomState(0)
+
+ n = 10
+ x = np.linspace(0, 1, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb, rho_ub = nx.from_numpy(x, rho_u, type_as=tp)
+
+ w2 = ot.semidiscrete_wasserstein2_unif_circle(xb, rho_ub)
+
+ nx.assert_same_dtype_device(xb, w2)
+
+
+def test_binary_search_circle_log():
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.rand(n,)
+ v = rng.rand(m,)
+
+ wass2_bsc, log = ot.binary_search_circle(u, v, p=2, log=True)
+ optimal_thetas = log["optimal_theta"]
+
+ assert optimal_thetas.shape[0] == 1
+
+
+def test_wasserstein_circle_bad_shape():
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.rand(n, 2)
+ v = rng.rand(m, 1)
+
+ with pytest.raises(ValueError):
+ _ = ot.wasserstein_circle(u, v, p=2)
+
+ with pytest.raises(ValueError):
+ _ = ot.wasserstein_circle(u, v, p=1)
diff --git a/test/test_backend.py b/test/test_backend.py
index 311c075..fd9a761 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -275,11 +275,27 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.sqrtm(M)
with pytest.raises(NotImplementedError):
+ nx.kl_div(M, M)
+ with pytest.raises(NotImplementedError):
nx.isfinite(M)
with pytest.raises(NotImplementedError):
nx.array_equal(M, M)
with pytest.raises(NotImplementedError):
nx.is_floating_point(M)
+ with pytest.raises(NotImplementedError):
+ nx.tile(M, (10, 1))
+ with pytest.raises(NotImplementedError):
+ nx.floor(M)
+ with pytest.raises(NotImplementedError):
+ nx.prod(M)
+ with pytest.raises(NotImplementedError):
+ nx.sort2(M)
+ with pytest.raises(NotImplementedError):
+ nx.qr(M)
+ with pytest.raises(NotImplementedError):
+ nx.atan2(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.transpose(M)


def test_func_backends(nx):
@@ -592,11 +608,47 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append("matrix square root")

+ A = nx.kl_div(nx.abs(Mb), nx.abs(Mb) + 1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("Kullback-Leibler divergence")
+
A = nx.concatenate([vb, nx.from_numpy(np.array([np.inf, np.nan]))], axis=0)
A = nx.isfinite(A)
lst_b.append(nx.to_numpy(A))
lst_name.append("isfinite")

+ A = nx.tile(vb, (10, 1))
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("tile")
+
+ A = nx.floor(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("floor")
+
+ A = nx.prod(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("prod")
+
+ A, B = nx.sort2(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("sort2 sort")
+ lst_b.append(nx.to_numpy(B))
+ lst_name.append("sort2 argsort")
+
+ A, B = nx.qr(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("QR Q")
+ lst_b.append(nx.to_numpy(B))
+ lst_name.append("QR R")
+
+ A = nx.atan2(vb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("atan2")
+
+ A = nx.transpose(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("transpose")
+
assert not nx.array_equal(Mb, vb), "array_equal (shape)"
assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
assert not nx.array_equal(
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 6c37984..f01bb14 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -3,9 +3,11 @@
# Author: Remi Flamary <remi.flamary@unice.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
+# Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr>
#
# License: MIT License

+import warnings
from itertools import product

import numpy as np
@@ -57,7 +59,12 @@ def test_convergence_warning(method):
with pytest.warns(UserWarning):
ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1)
with pytest.warns(UserWarning):
- ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1)
+ ot.sinkhorn2(a1, a2, M, 1, method=method,
+ stopThr=0, numItermax=1, warn=True)
+ with warnings.catch_warnings():
+ warnings.simplefilter("error")
+ ot.sinkhorn2(a1, a2, M, 1, method=method,
+ stopThr=0, numItermax=1, warn=False)


def test_not_implemented_method():
@@ -261,12 +268,16 @@ def test_sinkhorn_variants(nx):
ub, M_nx = nx.from_numpy(u, M)

G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
- Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
- G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10))
- Gs = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Gl = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
Ges = nx.to_numpy(ot.sinkhorn(
ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10))
- G_green = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10))
+ G_green = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10))

# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
@@ -366,9 +377,12 @@ def test_sinkhorn_variants_multi_b(nx):
ub, bb, M_nx = nx.from_numpy(u, b, M)

G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
- Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
- G0 = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
- Gs = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Gl = nx.to_numpy(ot.sinkhorn(
+ ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(
+ ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(
+ ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))

# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
@@ -394,9 +408,12 @@ def test_sinkhorn2_variants_multi_b(nx):
ub, bb, M_nx = nx.from_numpy(u, b, M)

G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
- Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
- G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
- Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Gl = nx.to_numpy(ot.sinkhorn2(
+ ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn2(
+ ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn2(
+ ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))

# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
@@ -414,12 +431,16 @@ def test_sinkhorn_variants_log():

M = ot.dist(x, x)

- G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
- Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
- Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
+ G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn',
+ stopThr=1e-10, log=True)
+ Gl, logl = ot.sinkhorn(
+ u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
+ Gs, logs = ot.sinkhorn(
+ u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
Ges, loges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True,)
- G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
+ G_green, loggreen = ot.sinkhorn(
+ u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)

# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
@@ -441,7 +462,8 @@ def test_sinkhorn_variants_log_multib(verbose, warn):

M = ot.dist(x, x)

- G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+ G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn',
+ stopThr=1e-10, log=True)
Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True,
verbose=verbose, warn=warn)
Gs, logs = ot.sinkhorn(u, b, M, 1,
method='sinkhorn_stabilized', stopThr=1e-10, log=True,
@@ -480,8 +502,73 @@ def test_barycenter(nx, method, verbose, warn):
ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method)
else:
# wasserstein
- bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method, verbose=verbose, warn=warn)
- bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass_np = ot.bregman.barycenter(
+ A, M, reg, weights, method=method, verbose=verbose, warn=warn)
+ bary_wass, _ = ot.bregman.barycenter(
+ A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass = nx.to_numpy(bary_wass)
+
+ np.testing.assert_allclose(1, np.sum(bary_wass))
+ np.testing.assert_allclose(bary_wass, bary_wass_np)
+
+ ot.bregman.barycenter(A_nx, M_nx, reg, log=True)
+
+
+def test_free_support_sinkhorn_barycenter():
+ measures_locations = [
+ np.array([-1.]).reshape((1, 1)), # First dirac support
+ np.array([1.]).reshape((1, 1)) # Second dirac support
+ ]
+
+ measures_weights = [
+ np.array([1.]), # First dirac sample weights
+ np.array([1.]) # Second dirac sample weights
+ ]
+
+ # Barycenter initialization
+ X_init = np.array([-12.]).reshape((1, 1))
+
+ # Obvious barycenter locations. Take a look on test_ot.py, test_free_support_barycenter
+ bar_locations = np.array([0.]).reshape((1, 1))
+
+ # Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization
+ # term to 1, but this should be, in general, fine-tuned to the problem.
+ X = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations, measures_weights, X_init, reg=1)
+
+ # Verifies if calculated barycenter matches ground-truth
+ np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
+
+
+@pytest.mark.parametrize("method, verbose, warn",
+ product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"],
+ [True, False], [True, False]))
+def test_barycenter_assymetric_cost(nx, method, verbose, warn):
+ n_bins = 20 # nb bins
+
+ # Gaussian distributions
+ A = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
+
+ # creating matrix A containing all distributions
+ A = A[:, None]
+
+ # assymetric loss matrix + normalization
+ rng = np.random.RandomState(42)
+ M = rng.randn(n_bins, n_bins) ** 2
+ M /= M.max()
+
+ A_nx, M_nx = nx.from_numpy(A, M)
+ reg = 1e-2
+
+ if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.barycenter(A_nx, M_nx, reg, method=method)
+ else:
+ # wasserstein
+ bary_wass_np = ot.bregman.barycenter(
+ A, M, reg, method=method, verbose=verbose, warn=warn)
+ bary_wass, _ = ot.bregman.barycenter(
+ A_nx, M_nx, reg, method=method, log=True)
bary_wass = nx.to_numpy(bary_wass)

np.testing.assert_allclose(1, np.sum(bary_wass))
np.testing.assert_allclose(bary_wass, bary_wass_np)
@@ -516,17 +603,20 @@ def test_barycenter_debiased(nx, method, verbose, warn):
reg = 1e-2
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
- ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method)
+ ot.bregman.barycenter_debiased(
+ A_nx, M_nx, reg, weights, method=method)
else:
bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method,
verbose=verbose, warn=warn)
- bary_wass, _ = ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass, _ = ot.bregman.barycenter_debiased(
+ A_nx, M_nx, reg, weights_nx, method=method, log=True)
bary_wass = nx.to_numpy(bary_wass)

np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5)

- ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False)
+ ot.bregman.barycenter_debiased(
+ A_nx, M_nx, reg, log=True, verbose=False)


@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
@@ -551,7 +641,8 @@ def test_convergence_warning_barycenters(method):
weights = np.array([1 - alpha, alpha])
reg = 0.1
with pytest.warns(UserWarning):
- ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1)
+ ot.bregman.barycenter_debiased(
+ A, M, reg, weights, method=method, numItermax=1)
with pytest.warns(UserWarning):
ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1)
with pytest.warns(UserWarning):
@@ -583,7 +674,8 @@ def test_barycenter_stabilization(nx):

# wasserstein
reg = 1e-2
- bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True)
+ bar_np = ot.bregman.barycenter(
+ A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True)
bar_stable = nx.to_numpy(ot.bregman.barycenter(
A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized",
stopThr=1e-8, verbose=True
@@ -618,8 +710,10 @@ def test_wasserstein_bary_2d(nx, method):
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
else:
- bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method, verbose=True, log=True)
- bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method))
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
+ A, reg, method=method, verbose=True, log=True)
+ bary_wass = nx.to_numpy(
+ ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method))

np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
@@ -648,10 +742,13 @@ def test_wasserstein_bary_2d_debiased(nx, method):
reg = 1e-2
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
- ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
+ ot.bregman.convolutional_barycenter2d_debiased(
+ A_nx, reg, method=method)
else:
- bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method, verbose=True, log=True)
- bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
+ A, reg, method=method, verbose=True, log=True)
+ bary_wass = nx.to_numpy(
+ ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))

np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
@@ -685,7 +782,8 @@ def test_unmix(nx):
# wasserstein
reg = 1e-3
um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01)
- um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01))
+ um = nx.to_numpy(ot.bregman.unmix(
+ ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01))

np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
@@ -716,10 +814,12 @@ def test_empirical_sinkhorn(nx):
sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
sinkhorn_log = nx.to_numpy(sinkhorn_log)

- G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean'))
+ G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(
+ X_sb, X_tb, 1, metric='euclidean'))
sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))

- loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
+ loss_emp_sinkhorn = nx.to_numpy(
+ ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))

# check constraints
@@ -752,23 +852,27 @@ def test_lazy_empirical_sinkhorn(nx):

ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m)

- f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
+ f, g = ot.bregman.empirical_sinkhorn(
+ X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))

- f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ f, g, log_es = ot.bregman.empirical_sinkhorn(
+ X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_log = np.exp(f[:, None] + g[None, :] - M / 0.1)
sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
sinkhorn_log = nx.to_numpy(sinkhorn_log)

- f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1)
+ f, g = ot.bregman.empirical_sinkhorn(
+ X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_m = np.exp(f[:, None] + g[None, :] - M_m / 1)
sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))

- loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(
+ X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn)
loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))

@@ -800,22 +904,57 @@ def test_empirical_sinkhorn_divergence(nx):
M_s = ot.dist(X_s, X_s)
M_t = ot.dist(X_t, X_t)

- ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(a, b, X_s, X_t, M, M_s, M_t)
+ ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(
+ a, b, X_s, X_t, M, M_s, M_t)

- emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
+ emp_sinkhorn_div = nx.to_numpy(
+ ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
sinkhorn_div = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1) - 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1) - 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1))
- emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b)
+ emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(
+ X_s, X_t, 1, a=a, b=b)

# check constraints
- np.testing.assert_allclose(emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05)
+ np.testing.assert_allclose(
+ emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05)
np.testing.assert_allclose(
emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn

- ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True)
+ ot.bregman.empirical_sinkhorn_divergence(
+ X_sb, X_tb, 1, a=ab, b=bb, log=True)
+
+
+@pytest.mark.skipif(not torch, reason="No torch available")
+def test_empirical_sinkhorn_divergence_gradient():
+ # Test sinkhorn divergence
+ n = 10
+ a = np.linspace(1, n, n)
+ a /= a.sum()
+ b = ot.unif(n)
+ X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1))
+ X_t = np.reshape(np.arange(0, n * 2, 2, dtype=np.float64), (n, 1))
+
+ nx = ot.backend.TorchBackend()
+
+ ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t)
+
+ ab.requires_grad = True
+ bb.requires_grad = True
+ X_sb.requires_grad = True
+ X_tb.requires_grad = True
+
+ emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(
+ X_sb, X_tb, 1, a=ab, b=bb)
+
+ emp_sinkhorn_div.backward()
+
+ assert ab.grad is not None
+ assert bb.grad is not None
+ assert X_sb.grad is not None
+ assert X_tb.grad is not None


def test_stabilized_vs_sinkhorn_multidim(nx):
@@ -837,7 +976,8 @@ def test_stabilized_vs_sinkhorn_multidim(nx):

ab, bb, M_nx = nx.from_numpy(a, b, M)

- G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True)
+ G_np, _ = ot.bregman.sinkhorn(
+ a, b, M, reg=epsilon, method="sinkhorn", log=True)
G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon,
method="sinkhorn_stabilized", log=True)
@@ -902,7 +1042,8 @@ def test_screenkhorn(nx):
# sinkhorn
G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1))
# screenkhorn
- G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True))
+ G_screen = nx.to_numpy(ot.bregman.screenkhorn(
+ ab, bb, M_nx, 1e-1, uniform=True, verbose=True))
# check marginals
np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
@@ -919,3 +1060,93 @@ def test_convolutional_barycenter_non_square(nx):
np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
np.testing.assert_allclose(b, b_np)
+
+
+def test_sinkhorn_warmstart():
+ m, n = 10, 20
+ a = ot.unif(m)
+ b = ot.unif(n)
+
+ Xs = np.arange(m) * 1.0
+ Xt = np.arange(n) * 1.0
+ M = ot.dist(Xs.reshape(-1, 1), Xt.reshape(-1, 1))
+
+ # Generate warmstart from dual vectors of unregularized OT
+ _, log = ot.lp.emd(a, b, M, log=True)
+ warmstart = (log["u"], log["v"])
+
+ reg = 1
+
+ # Optimal plan with uniform warmstart
+ pi_unif, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn", log=True, warmstart=None)
+ # Optimal plan with warmstart generated from unregularized OT
+ pi_sh, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn", log=True, warmstart=warmstart)
+ pi_sh_log, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn_log", log=True, warmstart=warmstart)
+ pi_sh_stab, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn_stabilized", log=True, warmstart=warmstart)
+ pi_sh_sc, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn_epsilon_scaling", log=True, warmstart=warmstart)
+
+ np.testing.assert_allclose(pi_unif, pi_sh, atol=1e-05)
+ np.testing.assert_allclose(pi_unif, pi_sh_log, atol=1e-05)
+ np.testing.assert_allclose(pi_unif, pi_sh_stab, atol=1e-05)
+ np.testing.assert_allclose(pi_unif, pi_sh_sc, atol=1e-05)
+
+
+def test_empirical_sinkhorn_warmstart():
+ m, n = 10, 20
+ Xs = np.arange(m).reshape(-1, 1) * 1.0
+ Xt = np.arange(n).reshape(-1, 1) * 1.0
+ M = ot.dist(Xs, Xt)
+
+ # Generate warmstart from dual vectors of unregularized OT
+ a = ot.unif(m)
+ b = ot.unif(n)
+ _, log = ot.lp.emd(a, b, M, log=True)
+ warmstart = (log["u"], log["v"])
+
+ reg = 1
+
+ # Optimal plan with uniform warmstart
+ f, g, _ = ot.bregman.empirical_sinkhorn(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None)
+ pi_unif = np.exp(f[:, None] + g[None, :] - M / reg)
+ # Optimal plan with warmstart generated from unregularized OT
+ f, g, _ = ot.bregman.empirical_sinkhorn(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart)
+ pi_ws_lazy = np.exp(f[:, None] + g[None, :] - M / reg)
+ pi_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart)
+
+ np.testing.assert_allclose(pi_unif, pi_ws_lazy, atol=1e-05)
+ np.testing.assert_allclose(pi_unif, pi_ws_not_lazy, atol=1e-05)
+
+
+def test_empirical_sinkhorn_divergence_warmstart():
+ m, n = 10, 20
+ Xs = np.arange(m).reshape(-1, 1) * 1.0
+ Xt = np.arange(n).reshape(-1, 1) * 1.0
+ M = ot.dist(Xs, Xt)
+
+ # Generate warmstart from dual vectors of unregularized OT
+ a = ot.unif(m)
+ b = ot.unif(n)
+ _, log = ot.lp.emd(a, b, M, log=True)
+ warmstart = (log["u"], log["v"])
+
+ reg = 1
+
+ # Optimal plan with uniform warmstart
+ sd_unif, _ = ot.bregman.empirical_sinkhorn_divergence(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None)
+ # Optimal plan with warmstart generated from unregularized OT
+ sd_ws_lazy, _ = ot.bregman.empirical_sinkhorn_divergence(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart)
+ sd_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn_divergence(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart)
+
+ np.testing.assert_allclose(sd_unif, sd_ws_lazy, atol=1e-05)
+ np.testing.assert_allclose(sd_unif, sd_ws_not_lazy, atol=1e-05)
diff --git a/test/test_coot.py b/test/test_coot.py
new file mode 100644
index 0000000..ef68a9b
--- /dev/null
+++ b/test/test_coot.py
@@ -0,0 +1,359 @@
+"""Tests for module COOT on OT """
+
+# Author: Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+from ot.coot import co_optimal_transport as coot
+from ot.coot import co_optimal_transport2 as coot2
+import pytest
+
+
+@pytest.mark.parametrize("verbose", [False, True, 1, 0])
+def test_coot(nx, verbose):
+ n_samples = 60 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=4)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ # test couplings
+ pi_sample, pi_feature = coot(X=xs, Y=xt, verbose=verbose)
+ pi_sample_nx, pi_feature_nx = coot(X=xs_nx, Y=xt_nx, verbose=verbose)
+ pi_sample_nx = nx.to_numpy(pi_sample_nx)
+ pi_feature_nx = nx.to_numpy(pi_feature_nx)
+
+ anti_id_sample = np.flipud(np.eye(n_samples, n_samples)) / n_samples
+ id_feature = np.eye(2, 2) / 2
+
+ np.testing.assert_allclose(pi_sample, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_sample_nx, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_feature, id_feature, atol=1e-04)
+ np.testing.assert_allclose(pi_feature_nx, id_feature, atol=1e-04)
+
+ # test marginal distributions
+ px_s, px_f = ot.unif(n_samples), ot.unif(2)
+ py_s, py_f = ot.unif(n_samples), ot.unif(2)
+
+ np.testing.assert_allclose(px_s, pi_sample_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample_nx.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature_nx.sum(1), atol=1e-04)
+
+ np.testing.assert_allclose(px_s, pi_sample.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature.sum(1), atol=1e-04)
+
+ # test COOT distance
+
+ coot_np = coot2(X=xs, Y=xt, verbose=verbose)
+ coot_nx = nx.to_numpy(coot2(X=xs_nx, Y=xt_nx, verbose=verbose))
+ np.testing.assert_allclose(coot_np, 0, atol=1e-08)
+ np.testing.assert_allclose(coot_nx, 0, atol=1e-08)
+
+
+def test_entropic_coot(nx):
+ n_samples = 60 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=4)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ epsilon = (1, 1e-1)
+ nits_ot = 2000
+
+ # test couplings
+ pi_sample, pi_feature = coot(X=xs, Y=xt, epsilon=epsilon, nits_ot=nits_ot)
+ pi_sample_nx, pi_feature_nx = coot(
+ X=xs_nx, Y=xt_nx, epsilon=epsilon, nits_ot=nits_ot)
+ pi_sample_nx = nx.to_numpy(pi_sample_nx)
+ pi_feature_nx = nx.to_numpy(pi_feature_nx)
+
+ np.testing.assert_allclose(pi_sample, pi_sample_nx, atol=1e-04)
+ np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-04)
+
+ # test marginal distributions
+ px_s, px_f = ot.unif(n_samples), ot.unif(2)
+ py_s, py_f = ot.unif(n_samples), ot.unif(2)
+
+ np.testing.assert_allclose(px_s, pi_sample_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample_nx.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature_nx.sum(1), atol=1e-04)
+
+ np.testing.assert_allclose(px_s, pi_sample.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature.sum(1), atol=1e-04)
+
+ # test entropic COOT distance
+
+ coot_np = coot2(X=xs, Y=xt, epsilon=epsilon, nits_ot=nits_ot)
+ coot_nx = nx.to_numpy(
+ coot2(X=xs_nx, Y=xt_nx, epsilon=epsilon, nits_ot=nits_ot))
+
+ np.testing.assert_allclose(coot_np, coot_nx, atol=1e-08)
+
+
+def test_coot_with_linear_terms(nx):
+ n_samples = 60 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=4)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ M_samp = np.ones((n_samples, n_samples))
+ np.fill_diagonal(np.fliplr(M_samp), 0)
+ M_feat = np.ones((2, 2))
+ np.fill_diagonal(M_feat, 0)
+ M_samp_nx, M_feat_nx = nx.from_numpy(M_samp), nx.from_numpy(M_feat)
+
+ alpha = (1, 2)
+
+ # test couplings
+ anti_id_sample = np.flipud(np.eye(n_samples, n_samples)) / n_samples
+ id_feature = np.eye(2, 2) / 2
+
+ pi_sample, pi_feature = coot(
+ X=xs, Y=xt, alpha=alpha, M_samp=M_samp, M_feat=M_feat)
+ pi_sample_nx, pi_feature_nx = coot(
+ X=xs_nx, Y=xt_nx, alpha=alpha, M_samp=M_samp_nx, M_feat=M_feat_nx)
+ pi_sample_nx = nx.to_numpy(pi_sample_nx)
+ pi_feature_nx = nx.to_numpy(pi_feature_nx)
+
+ np.testing.assert_allclose(pi_sample, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_sample_nx, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_feature, id_feature, atol=1e-04)
+ np.testing.assert_allclose(pi_feature_nx, id_feature, atol=1e-04)
+
+ # test marginal distributions
+ px_s, px_f = ot.unif(n_samples), ot.unif(2)
+ py_s, py_f = ot.unif(n_samples), ot.unif(2)
+
+ np.testing.assert_allclose(px_s, pi_sample_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample_nx.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature_nx.sum(1), atol=1e-04)
+
+ np.testing.assert_allclose(px_s, pi_sample.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature.sum(1), atol=1e-04)
+
+ # test COOT distance
+
+ coot_np = coot2(X=xs, Y=xt, alpha=alpha, M_samp=M_samp, M_feat=M_feat)
+ coot_nx = nx.to_numpy(
+ coot2(X=xs_nx, Y=xt_nx, alpha=alpha, M_samp=M_samp_nx, M_feat=M_feat_nx))
+ np.testing.assert_allclose(coot_np, 0, atol=1e-08)
+ np.testing.assert_allclose(coot_nx, 0, atol=1e-08)
+
+
+def test_coot_raise_value_error(nx):
+ n_samples = 80 # nb samples
+
+ mu_s = np.array([2, 4])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=43)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ # raise value error of method sinkhorn
+ def coot_sh(method_sinkhorn):
+ return coot(X=xs, Y=xt, method_sinkhorn=method_sinkhorn)
+
+ def coot_sh_nx(method_sinkhorn):
+ return coot(X=xs_nx, Y=xt_nx, method_sinkhorn=method_sinkhorn)
+
+ np.testing.assert_raises(ValueError, coot_sh, "not_sinkhorn")
+ np.testing.assert_raises(ValueError, coot_sh_nx, "not_sinkhorn")
+
+ # raise value error for epsilon
+ def coot_eps(epsilon):
+ return coot(X=xs, Y=xt, epsilon=epsilon)
+
+ def coot_eps_nx(epsilon):
+ return coot(X=xs_nx, Y=xt_nx, epsilon=epsilon)
+
+ np.testing.assert_raises(ValueError, coot_eps, (1, 2, 3))
+ np.testing.assert_raises(ValueError, coot_eps_nx, [1, 2, 3, 4])
+
+ # raise value error for alpha
+ def coot_alpha(alpha):
+ return coot(X=xs, Y=xt, alpha=alpha)
+
+ def coot_alpha_nx(alpha):
+ return coot(X=xs_nx, Y=xt_nx, alpha=alpha)
+
+ np.testing.assert_raises(ValueError, coot_alpha, [1])
+ np.testing.assert_raises(ValueError, coot_alpha_nx, np.arange(4))
+
+
+def test_coot_warmstart(nx):
+ n_samples = 80 # nb samples
+
+ mu_s = np.array([2, 3])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=125)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ # initialize warmstart
+ init_pi_sample = np.random.rand(n_samples, n_samples)
+ init_pi_sample = init_pi_sample / np.sum(init_pi_sample)
+ init_pi_sample_nx = nx.from_numpy(init_pi_sample)
+
+ init_pi_feature = np.random.rand(2, 2)
+ init_pi_feature = init_pi_feature / np.sum(init_pi_feature)
+ init_pi_feature_nx = nx.from_numpy(init_pi_feature)
+
+ init_duals_sample = (np.random.random(n_samples) * 2 - 1,
+ np.random.random(n_samples) * 2 - 1)
+ init_duals_sample_nx = (nx.from_numpy(init_duals_sample[0]),
+ nx.from_numpy(init_duals_sample[1]))
+
+ init_duals_feature = (np.random.random(2) * 2 - 1,
+ np.random.random(2) * 2 - 1)
+ init_duals_feature_nx = (nx.from_numpy(init_duals_feature[0]),
+ nx.from_numpy(init_duals_feature[1]))
+
+ warmstart = {
+ "pi_sample": init_pi_sample,
+ "pi_feature": init_pi_feature,
+ "duals_sample": init_duals_sample,
+ "duals_feature": init_duals_feature
+ }
+
+ warmstart_nx = {
+ "pi_sample": init_pi_sample_nx,
+ "pi_feature": init_pi_feature_nx,
+ "duals_sample": init_duals_sample_nx,
+ "duals_feature": init_duals_feature_nx
+ }
+
+ # test couplings
+ pi_sample, pi_feature = coot(X=xs, Y=xt, warmstart=warmstart)
+ pi_sample_nx, pi_feature_nx = coot(
+ X=xs_nx, Y=xt_nx, warmstart=warmstart_nx)
+ pi_sample_nx = nx.to_numpy(pi_sample_nx)
+ pi_feature_nx = nx.to_numpy(pi_feature_nx)
+
+ anti_id_sample = np.flipud(np.eye(n_samples, n_samples)) / n_samples
+ id_feature = np.eye(2, 2) / 2
+
+ np.testing.assert_allclose(pi_sample, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_sample_nx, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_feature, id_feature, atol=1e-04)
+ np.testing.assert_allclose(pi_feature_nx, id_feature, atol=1e-04)
+
+ # test marginal distributions
+ px_s, px_f = ot.unif(n_samples), ot.unif(2)
+ py_s, py_f = ot.unif(n_samples), ot.unif(2)
+
+ np.testing.assert_allclose(px_s, pi_sample_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample_nx.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature_nx.sum(1), atol=1e-04)
+
+ np.testing.assert_allclose(px_s, pi_sample.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature.sum(1), atol=1e-04)
+
+ # test COOT distance
+ coot_np = coot2(X=xs, Y=xt, warmstart=warmstart)
+ coot_nx = nx.to_numpy(coot2(X=xs_nx, Y=xt_nx, warmstart=warmstart_nx))
+ np.testing.assert_allclose(coot_np, 0, atol=1e-08)
+ np.testing.assert_allclose(coot_nx, 0, atol=1e-08)
+
+
+def test_coot_log(nx):
+ n_samples = 90 # nb samples
+
+ mu_s = np.array([-2, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=43)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ pi_sample, pi_feature, log = coot(X=xs, Y=xt, log=True)
+ pi_sample_nx, pi_feature_nx, log_nx = coot(X=xs_nx, Y=xt_nx, log=True)
+
+ duals_sample, duals_feature = log["duals_sample"], log["duals_feature"]
+ assert len(duals_sample) == 2
+ assert len(duals_feature) == 2
+ assert len(duals_sample[0]) == n_samples
+ assert len(duals_sample[1]) == n_samples
+ assert len(duals_feature[0]) == 2
+ assert len(duals_feature[1]) == 2
+
+ duals_sample_nx = log_nx["duals_sample"]
+ assert len(duals_sample_nx) == 2
+ assert len(duals_sample_nx[0]) == n_samples
+ assert len(duals_sample_nx[1]) == n_samples
+
+ duals_feature_nx = log_nx["duals_feature"]
+ assert len(duals_feature_nx) == 2
+ assert len(duals_feature_nx[0]) == 2
+ assert len(duals_feature_nx[1]) == 2
+
+ list_coot = log["distances"]
+ assert len(list_coot) >= 1
+
+ list_coot_nx = log_nx["distances"]
+ assert len(list_coot_nx) >= 1
+
+ # test with coot distance
+ coot_np, log = coot2(X=xs, Y=xt, log=True)
+ coot_nx, log_nx = coot2(X=xs_nx, Y=xt_nx, log=True)
+
+ duals_sample, duals_feature = log["duals_sample"], log["duals_feature"]
+ assert len(duals_sample) == 2
+ assert len(duals_feature) == 2
+ assert len(duals_sample[0]) == n_samples
+ assert len(duals_sample[1]) == n_samples
+ assert len(duals_feature[0]) == 2
+ assert len(duals_feature[1]) == 2
+
+ duals_sample_nx = log_nx["duals_sample"]
+ assert len(duals_sample_nx) == 2
+ assert len(duals_sample_nx[0]) == n_samples
+ assert len(duals_sample_nx[1]) == n_samples
+
+ duals_feature_nx = log_nx["duals_feature"]
+ assert len(duals_feature_nx) == 2
+ assert len(duals_feature_nx[0]) == 2
+ assert len(duals_feature_nx[1]) == 2
+
+ list_coot = log["distances"]
+ assert len(list_coot) >= 1
+
+ list_coot_nx = log_nx["distances"]
+ assert len(list_coot_nx) >= 1
diff --git a/test/test_da.py b/test/test_da.py
index 4bf0ab1..c5f08d6 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -44,12 +44,32 @@ def test_class_jax_tf():

@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
+@pytest.mark.parametrize("class_to_test", [ot.da.EMDTransport, ot.da.SinkhornTransport, ot.da.SinkhornLpl1Transport, ot.da.SinkhornL1l2Transport, ot.da.SinkhornL1l2Transport])
+def test_log_da(nx, class_to_test):
+
+ ns = 50
+ nt = 50
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
+ otda = class_to_test(log=True)
+
+ # test its computed
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert hasattr(otda, "log_")
+
+
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
def test_sinkhorn_lpl1_transport_class(nx):
"""test_sinkhorn_transport """

- ns = 150
- nt = 200
+ ns = 50
+ nt = 50

Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -136,7 +156,7 @@ def test_sinkhorn_l1l2_transport_class(nx):
"""

ns = 50
- nt = 100
+ nt = 50

Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -230,8 +250,8 @@ def test_sinkhorn_transport_class(nx):
"""test_sinkhorn_transport """

- ns = 150
- nt = 200
+ ns = 50
+ nt = 50

Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -323,8 +343,8 @@ def test_unbalanced_sinkhorn_transport_class(nx):
"""test_sinkhorn_transport """

- ns = 150
- nt = 200
+ ns = 50
+ nt = 50

Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -402,8 +422,8 @@ def test_emd_transport_class(nx):
"""test_sinkhorn_transport """

- ns = 150
- nt = 200
+ ns = 50
+ nt = 50

Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -557,30 +577,9 @@ def test_mapping_transport_class_specific_seed(nx):

@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
-def test_linear_mapping(nx):
- ns = 150
- nt = 200
-
- Xs, ys = make_data_classif('3gauss', ns)
- Xt, yt = make_data_classif('3gauss2', nt)
-
- Xsb, Xtb = nx.from_numpy(Xs, Xt)
-
- A, b = ot.da.OT_mapping_linear(Xsb, Xtb)
-
- Xst = nx.to_numpy(nx.dot(Xsb, A) + b)
-
- Ct = np.cov(Xt.T)
- Cst = np.cov(Xst.T)
-
- np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
-
-
-@pytest.skip_backend("jax")
-@pytest.skip_backend("tf")
def test_linear_mapping_class(nx):
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50

Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -609,9 +608,9 @@ def test_jcpot_transport_class(nx):
"""test_jcpot_transport """

- ns1 = 150
- ns2 = 150
- nt = 200
+ ns1 = 50
+ ns2 = 50
+ nt = 50

Xs1, ys1 = make_data_classif('3gauss', ns1)
Xs2, ys2 = make_data_classif('3gauss', ns2)
@@ -681,9 +680,9 @@ def test_jcpot_barycenter(nx):
"""test_jcpot_barycenter """

- ns1 = 150
- ns2 = 150
- nt = 200
+ ns1 = 50
+ ns2 = 50
+ nt = 50

sigma = 0.1
np.random.seed(1985)
@@ -713,8 +712,8 @@ def test_emd_laplace_class(nx):
"""test_emd_laplace_transport """
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50

Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
diff --git a/test/test_gaussian.py b/test/test_gaussian.py
new file mode 100644
index 0000000..be7a806
--- /dev/null
+++ b/test/test_gaussian.py
@@ -0,0 +1,98 @@
+"""Tests for module gaussian"""
+
+# Author: Theo Gnassounou <theo.gnassounou@inria.fr>
+# Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+import numpy as np
+
+import pytest
+
+import ot
+from ot.datasets import make_data_classif
+
+
+def test_bures_wasserstein_mapping(nx):
+ ns = 50
+ nt = 50
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+ ms = np.mean(Xs, axis=0)[None, :]
+ mt = np.mean(Xt, axis=0)[None, :]
+ Cs = np.cov(Xs.T)
+ Ct = np.cov(Xt.T)
+
+ Xsb, msb, mtb, Csb, Ctb = nx.from_numpy(Xs, ms, mt, Cs, Ct)
+
+ A_log, b_log, log = ot.gaussian.bures_wasserstein_mapping(msb, mtb, Csb, Ctb, log=True)
+ A, b = ot.gaussian.bures_wasserstein_mapping(msb, mtb, Csb, Ctb, log=False)
+
+ Xst = nx.to_numpy(nx.dot(Xsb, A) + b)
+ Xst_log = nx.to_numpy(nx.dot(Xsb, A_log) + b_log)
+
+ Cst = np.cov(Xst.T)
+ Cst_log = np.cov(Xst_log.T)
+
+ np.testing.assert_allclose(Cst_log, Cst, rtol=1e-2, atol=1e-2)
+ np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+
+
+@pytest.mark.parametrize("bias", [True, False])
+def test_empirical_bures_wasserstein_mapping(nx, bias):
+ ns = 50
+ nt = 50
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ if not bias:
+ ms = np.mean(Xs, axis=0)[None, :]
+ mt = np.mean(Xt, axis=0)[None, :]
+
+ Xs = Xs - ms
+ Xt = Xt - mt
+
+ Xsb, Xtb = nx.from_numpy(Xs, Xt)
+
+ A, b, log = ot.gaussian.empirical_bures_wasserstein_mapping(Xsb, Xtb, log=True, bias=bias)
+ A_log, b_log = ot.gaussian.empirical_bures_wasserstein_mapping(Xsb, Xtb, log=False, bias=bias)
+
+ Xst = nx.to_numpy(nx.dot(Xsb, A) + b)
+ Xst_log = nx.to_numpy(nx.dot(Xsb, A_log) + b_log)
+
+ Ct = np.cov(Xt.T)
+ Cst = np.cov(Xst.T)
+ Cst_log = np.cov(Xst_log.T)
+
+ np.testing.assert_allclose(Cst_log, Cst, rtol=1e-2, atol=1e-2)
+ np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+
+
+def test_bures_wasserstein_distance(nx):
+ ms, mt = np.array([0]), np.array([10])
+ Cs, Ct = np.array([[1]]).astype(np.float32), np.array([[1]]).astype(np.float32)
+ msb, mtb, Csb, Ctb = nx.from_numpy(ms, mt, Cs, Ct)
+ Wb_log, log = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=True)
+ Wb = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=False)
+
+ np.testing.assert_allclose(nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)
+ np.testing.assert_allclose(10, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)
+
+
+@pytest.mark.parametrize("bias", [True, False])
+def test_empirical_bures_wasserstein_distance(nx, bias):
+ ns = 400
+ nt = 400
+
+ rng = np.random.RandomState(10)
+ Xs = rng.normal(0, 1, ns)[:, np.newaxis]
+ Xt = rng.normal(10 * bias, 1, nt)[:, np.newaxis]
+
+ Xsb, Xtb = nx.from_numpy(Xs, Xt)
+ Wb_log, log = ot.gaussian.empirical_bures_wasserstein_distance(Xsb, Xtb, log=True, bias=bias)
+ Wb = ot.gaussian.empirical_bures_wasserstein_distance(Xsb, Xtb, log=False, bias=bias)
+
+ np.testing.assert_allclose(nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)
+ np.testing.assert_allclose(10 * bias, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 9c85b92..80b6df4 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -3,7 +3,7 @@
# Author: Erwan Vautier <erwan.vautier@gmail.com>
# Nicolas Courty <ncourty@irisa.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
-# Cédric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr>
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
#
# License: MIT License
@@ -11,18 +11,15 @@ import numpy as np
import ot
from ot.backend import NumpyBackend
from ot.backend import torch, tf
-
import pytest
def test_gromov(nx):
n_samples = 50 # nb samples
-
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
- xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
-
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1)
xt = xs[::-1].copy()
p = ot.unif(n_samples)
@@ -38,7 +35,7 @@ def test_gromov(nx):
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, verbose=True)
- Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True))
+ Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=G0b, verbose=True))
# check constraints
np.testing.assert_allclose(G, Gb, atol=1e-06)
@@ -51,13 +48,13 @@ def test_gromov(nx):
np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04)
- gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
- gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True)
+ gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, log=True)
+ gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, log=True)
gwb = nx.to_numpy(gwb)
- gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', G0=G0, log=False)
+ gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, G0=G0, log=False)
gw_valb = nx.to_numpy(
- ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
+ ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False)
)
G = log['T']
@@ -77,6 +74,49 @@ def test_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+def test_asymmetric_gromov(nx):
+ n_samples = 30 # nb samples
+ np.random.seed(0)
+ C1 = np.random.uniform(low=0., high=10, size=(n_samples, n_samples))
+ idx = np.arange(n_samples)
+ np.random.shuffle(idx)
+ C2 = C1[idx, :][:, idx]
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
+
+ G, log = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, log=True, symmetric=False, verbose=True)
+ Gb, logb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', log=True, symmetric=False, G0=G0b, verbose=True)
+ Gb = nx.to_numpy(Gb)
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['gw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04)
+
+ gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss', G0=G0, log=True, symmetric=False, verbose=True)
+ gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'square_loss', log=True, symmetric=False, G0=G0b, verbose=True)
+
+ G = log['T']
+ Gb = nx.to_numpy(logb['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['gw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04)
+
+
def test_gromov_dtype_device(nx):
# setup
n_samples = 50 # nb samples
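test_asymmetric_gromov above builds C2 as a row/column permutation of a non-symmetric C1, so the optimal plan is that permutation and the GW distance is zero. Condensed to its core (same calls as in the test, sizes illustrative):

import numpy as np
import ot

n = 30
rng = np.random.RandomState(0)
C1 = rng.uniform(low=0., high=10, size=(n, n))  # deliberately non-symmetric
idx = rng.permutation(n)
C2 = C1[idx, :][:, idx]                         # same structure up to relabeling

p, q = ot.unif(n), ot.unif(n)

# symmetric=False disables the symmetric-case shortcut; gw should be ~0
gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss',
                                        symmetric=False, log=True)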
@@ -104,7 +144,7 @@ def test_gromov_dtype_device(nx):
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0, type_as=tp)
Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False)
nx.assert_same_dtype_device(C1b, Gb)
nx.assert_same_dtype_device(C1b, gw_valb)
@@ -130,7 +170,7 @@ def test_gromov_device_tf():
with tf.device("/CPU:0"):
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False)
nx.assert_same_dtype_device(C1b, Gb)
nx.assert_same_dtype_device(C1b, gw_valb)
@@ -138,7 +178,7 @@ def test_gromov_device_tf():
# Check that everything happens on the GPU
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, log=False)
nx.assert_same_dtype_device(C1b, Gb)
nx.assert_same_dtype_device(C1b, gw_valb)
assert nx.dtype_device(Gb)[1].startswith("GPU")
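The next hunks extend test_gromov2_gradients to the armijo line search; what is being checked is simply that gromov_wasserstein2 backpropagates to all four inputs. A minimal sketch, assuming torch is installed:

import numpy as np
import torch
import ot

n = 10
rng = np.random.RandomState(0)
xs, xt = rng.randn(n, 2), rng.randn(n, 2)
C1 = torch.tensor(ot.dist(xs, xs), requires_grad=True)
C2 = torch.tensor(ot.dist(xt, xt), requires_grad=True)
p = torch.tensor(ot.unif(n), requires_grad=True)
q = torch.tensor(ot.unif(n), requires_grad=True)

val = ot.gromov_wasserstein2(C1, C2, p, q, armijo=True)  # exact search if armijo=False
val.backward()  # gradients land on C1, C2, p and q
assert C1.grad is not None and p.grad is not None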
@@ -174,6 +214,7 @@ def test_gromov2_gradients():
C11 = torch.tensor(C1, requires_grad=True, device=device)
C12 = torch.tensor(C2, requires_grad=True, device=device)
+ # Test with exact line-search
val = ot.gromov_wasserstein2(C11, C12, p1, q1)
val.backward()
@@ -184,6 +225,60 @@ def test_gromov2_gradients():
assert C11.shape == C11.grad.shape
assert C12.shape == C12.grad.shape
+ # Test with armijo line-search
+ q1.grad = None
+ p1.grad = None
+ C11.grad = None
+ C12.grad = None
+ val = ot.gromov_wasserstein2(C11, C12, p1, q1, armijo=True)
+
+ val.backward()
+
+ assert val.device == p1.device
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+
+
+def test_gw_helper_backend(nx):
+ n_samples = 20 # nb samples
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
+ Gb, logb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', armijo=False, symmetric=True, G0=G0b, log=True)
+
+ # calls with nx=None
+ constCb, hC1b, hC2b = ot.gromov.init_matrix(C1b, C2b, pb, qb, loss_fun='square_loss')
+
+ def f(G):
+ return ot.gromov.gwloss(constCb, hC1b, hC2b, G, None)
+
+ def df(G):
+ return ot.gromov.gwggrad(constCb, hC1b, hC2b, G, None)
+
+ def line_search(cost, G, deltaG, Mi, cost_G):
+ return ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1b, C2b, M=0., reg=1., nx=None)
+ # feed the precomputed local optimum Gb to cg
+ res, log = ot.optim.cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
+ # check constraints
+ np.testing.assert_allclose(res, Gb, atol=1e-06)
+
@pytest.skip_backend("jax", reason="test very slow with jax backend")
@pytest.skip_backend("tf", reason="test very slow with tf backend")
@@ -199,19 +294,21 @@ def test_entropic_gromov(nx):
p = ot.unif(n_samples)
q = ot.unif(n_samples)
-
+ G0 = p[:, None] * q[None, :]
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)
C1 /= C1.max()
C2 /= C2.max()
- C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q)
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
- G = ot.gromov.entropic_gromov_wasserstein(
- C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True)
+ G, log = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', symmetric=None, G0=G0,
+ epsilon=1e-2, verbose=True, log=True)
Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein(
- C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+ C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None,
+ epsilon=1e-2, verbose=True, log=False
))
# check constraints
@@ -222,9 +319,11 @@ def test_entropic_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
gw, log = ot.gromov.entropic_gromov_wasserstein2(
- C1, C2, p, q, 'kl_loss', max_iter=10, epsilon=1e-2, log=True)
+ C1, C2, p, q, 'kl_loss', symmetric=True, G0=None,
+ max_iter=10, epsilon=1e-2, log=True)
gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
- C1b, C2b, pb, qb, 'kl_loss', max_iter=10, epsilon=1e-2, log=True)
+ C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b,
+ max_iter=10, epsilon=1e-2, log=True)
gwb = nx.to_numpy(gwb)
G = log['T']
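As the hunks above show, the entropic solvers now also take symmetric and G0 arguments. A condensed call pattern, with epsilon and max_iter as in the tests and a toy setup:

import numpy as np
import ot

n = 10
rng = np.random.RandomState(0)
xs = rng.randn(n, 2)
xt = xs[::-1].copy()
C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
C1, C2 = C1 / C1.max(), C2 / C2.max()
p, q = ot.unif(n), ot.unif(n)

G, log = ot.gromov.entropic_gromov_wasserstein(
    C1, C2, p, q, 'square_loss', symmetric=None, G0=None,
    epsilon=1e-2, verbose=False, log=True)
gw = ot.gromov.entropic_gromov_wasserstein2(
    C1, C2, p, q, 'kl_loss', symmetric=True, G0=None,
    max_iter=10, epsilon=1e-2, log=False)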
@@ -241,6 +340,45 @@ def test_entropic_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+def test_asymmetric_entropic_gromov(nx):
+ n_samples = 10 # nb samples
+ np.random.seed(0)
+ C1 = np.random.uniform(low=0., high=10, size=(n_samples, n_samples))
+ idx = np.arange(n_samples)
+ np.random.shuffle(idx)
+ C2 = C1[idx, :][:, idx]
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
+ G = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', symmetric=None, G0=G0,
+ epsilon=1e-1, verbose=True, log=False)
+ Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein(
+ C1b, C2b, pb, qb, 'square_loss', symmetric=False, G0=None,
+ epsilon=1e-1, verbose=True, log=False
+ ))
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ gw = ot.gromov.entropic_gromov_wasserstein2(
+ C1, C2, p, q, 'kl_loss', symmetric=False, G0=None,
+ max_iter=10, epsilon=1e-1, log=False)
+ gwb = ot.gromov.entropic_gromov_wasserstein2(
+ C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b,
+ max_iter=10, epsilon=1e-1, log=False)
+ gwb = nx.to_numpy(gwb)
+
+ np.testing.assert_allclose(gw, gwb, atol=1e-06)
+ np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
+
+
@pytest.skip_backend("jax", reason="test very slow with jax backend")
@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_entropic_gromov_dtype_device(nx):
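The test_fgw hunks below switch on the armijo/symmetric options as well, and test_asymmetric_fgw (further down) repeats the permutation construction with an added feature cost M. Schematically, under the same assumptions as the tests:

import numpy as np
import ot

n = 50
rng = np.random.RandomState(0)
C1 = rng.uniform(low=0., high=10, size=(n, n))
idx = rng.permutation(n)
C2 = C1[idx, :][:, idx]
F1 = rng.uniform(low=0., high=10, size=(n, 1))  # node features
F2 = F1[idx, :]
M = ot.dist(F1, F2)                             # linear (feature) cost
p, q = ot.unif(n), ot.unif(n)

fgw, log = ot.gromov.fused_gromov_wasserstein2(
    M, C1, C2, p, q, 'square_loss', alpha=0.5, symmetric=False, log=True)
# C2/F2 are relabelings of C1/F1, so fgw should be ~0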
@@ -539,8 +677,8 @@ def test_fgw(nx):
Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
- G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True)
- Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, G0=G0b, log=True)
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, armijo=True, symmetric=None, G0=G0, log=True)
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, armijo=True, symmetric=True, G0=G0b, log=True)
Gb = nx.to_numpy(Gb)
# check constraints
@@ -555,8 +693,8 @@ def test_fgw(nx):
np.testing.assert_allclose(
Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov
- fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', G0=None, alpha=0.5, log=True)
- fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', G0=G0b, alpha=0.5, log=True)
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', armijo=True, symmetric=True, G0=None, alpha=0.5, log=True)
+ fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', armijo=True, symmetric=None, G0=G0b, alpha=0.5, log=True)
fgwb = nx.to_numpy(fgwb)
G = log['T']
@@ -573,6 +711,82 @@ def test_fgw(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+def test_asymmetric_fgw(nx):
+ n_samples = 50 # nb samples
+ np.random.seed(0)
+ C1 = np.random.uniform(low=0., high=10, size=(n_samples, n_samples))
+ idx = np.arange(n_samples)
+ np.random.shuffle(idx)
+ C2 = C1[idx, :][:, idx]
+
+ # add features
+ F1 = np.random.uniform(low=0., high=10, size=(n_samples, 1))
+ F2 = F1[idx, :]
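+ # features are permuted with the same indices as the structure, so the FGW distance should also be 0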
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ M = ot.dist(F1, F2)
+ Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
+
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, symmetric=False, verbose=True)
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, symmetric=None, G0=G0b, verbose=True)
+ Gb = nx.to_numpy(Gb)
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
+
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True)
+ fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True)
+
+ G = log['T']
+ Gb = nx.to_numpy(logb['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
+
+ # Tests with kl-loss:
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=False, verbose=True)
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=None, G0=G0b, verbose=True)
+ Gb = nx.to_numpy(Gb)
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
+
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True)
+ fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True)
+
+ G = log['T']
+ Gb = nx.to_numpy(logb['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
+
+
def test_fgw2_gradients():
n_samples = 20 # nb samples
@@ -617,6 +831,57 @@ def test_fgw2_gradients():
assert M1.shape == M1.grad.shape
+def test_fgw_helper_backend(nx):
+ n_samples = 20 # nb samples
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0)
+ ys = np.random.randn(xs.shape[0], 2)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1)
+ yt = np.random.randn(xt.shape[0], 2)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ M = ot.dist(ys, yt)
+ M /= M.max()
+
+ Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
+ alpha = 0.5
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, armijo=False, symmetric=True, G0=G0b, log=True)
+
+ # calls with nx=None
+ constCb, hC1b, hC2b = ot.gromov.init_matrix(C1b, C2b, pb, qb, loss_fun='square_loss')
+
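+ # f and df evaluate the GW loss and its gradient from the tensors precomputed by init_matrix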
+ def f(G):
+ return ot.gromov.gwloss(constCb, hC1b, hC2b, G, None)
+
+ def df(G):
+ return ot.gromov.gwggrad(constCb, hC1b, hC2b, G, None)
+
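+ # the FGW objective is quadratic in G, so the optimal step along [G, G + deltaG] has a closed form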
+ def line_search(cost, G, deltaG, Mi, cost_G):
+ return ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1b, C2b, M=(1 - alpha) * Mb, reg=alpha, nx=None)
+ # feed the precomputed local optimum Gb to cg
+ res, log = ot.optim.cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
+
+ def line_search(cost, G, deltaG, Mi, cost_G):
+ return ot.optim.line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=None)
+ # feed the precomputed local optimum Gb to cg
+ res_armijo, log_armijo = ot.optim.cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
+ # check constraints
+ np.testing.assert_allclose(res, Gb, atol=1e-06)
+ np.testing.assert_allclose(res_armijo, Gb, atol=1e-06)
+
+
def test_fgw_barycenter(nx):
np.random.seed(42)
@@ -1186,3 +1451,327 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
# > Compare results with/without backend
total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2)
np.testing.assert_allclose(total_reconstruction_bis2, total_reconstruction_b_bis2, atol=1e-05)
+
+
+def test_semirelaxed_gromov(nx):
+ np.random.seed(0)
+ # unbalanced proportions
+ list_n = [30, 15]
+ nt = 2
+ ns = np.sum(list_n)
+ # create directed sbm with C2 as connectivity matrix
+ C1 = np.zeros((ns, ns), dtype=np.float64)
+ C2 = np.array([[0.8, 0.05],
+ [0.05, 1.]], dtype=np.float64)
+ for i in range(nt):
+ for j in range(nt):
+ ni, nj = list_n[i], list_n[j]
+ xij = np.random.binomial(size=(ni, nj), n=1, p=C2[i, j])
+ C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij
+ p = ot.unif(ns, type_as=C1)
+ q0 = ot.unif(C2.shape[0], type_as=C1)
+ G0 = p[:, None] * q0[None, :]
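+ # the product coupling p q0^T serves as a feasible initial transport plan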
+ # asymmetric
+ C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0)
+
+ G, log = ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=G0)
+ Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun='square_loss', symmetric=False, log=True, G0=None)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04)
+ np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01)
+ np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01)
+
+ srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=False, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None)
+
+ G = log2['T']
+ Gb = nx.to_numpy(logb2['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07)
+
+ # symmetric
+ C1 = 0.5 * (C1 + C1.T)
+ C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0)
+
+ G, log = ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=None)
+ Gb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun='square_loss', symmetric=True, log=False, G0=G0b)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov
+
+ srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None)
+
+ srgw_ = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=False, G0=G0)
+
+ G = log2['T']
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, 1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01)
+ np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01)
+
+ np.testing.assert_allclose(log2['srgw_dist'], log['srgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(srgw, srgw_, atol=1e-07)
+
+
+def test_semirelaxed_gromov2_gradients():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5)
+
+ p = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ if torch:
+
+ devices = [torch.device("cpu")]
+ if torch.cuda.is_available():
+ devices.append(torch.device("cuda"))
+ for device in devices:
+ # semirelaxed solvers do not support gradients over masses yet.
+ p1 = torch.tensor(p, requires_grad=False, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
+
+ val = ot.gromov.semirelaxed_gromov_wasserstein2(C11, C12, p1)
+
+ val.backward()
+
+ assert val.device == p1.device
+ assert p1.grad is None
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+
+
+def test_srgw_helper_backend(nx):
+ n_samples = 20 # nb samples
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q)
+ Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, 'square_loss', armijo=False, symmetric=True, G0=None, log=True)
+
+ # calls with nx=None
+ constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun='square_loss')
+ ones_pb = nx.ones(pb.shape[0], type_as=pb)
+
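+ # in the semirelaxed problem the second marginal of G is free, so the constant term of the loss is recomputed from nx.sum(G, 0) at each evaluation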
+ def f(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb))
+ return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None)
+
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb))
+ return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None)
+
+ def line_search(cost, G, deltaG, Mi, cost_G):
+ return ot.gromov.solve_semirelaxed_gromov_linesearch(
+ G, deltaG, cost_G, C1b, C2b, ones_pb, 0., 1., nx=None)
+ # feed the precomputed local optimum Gb to semirelaxed_cg
+ res, log = ot.optim.semirelaxed_cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
+ # check constraints
+ np.testing.assert_allclose(res, Gb, atol=1e-06)
+
+
+def test_semirelaxed_fgw(nx):
+ np.random.seed(0)
+ list_n = [16, 8]
+ nt = 2
+ ns = 24
+ # create directed sbm with C2 as connectivity matrix
+ C1 = np.zeros((ns, ns))
+ C2 = np.array([[0.7, 0.05],
+ [0.05, 0.9]])
+ for i in range(nt):
+ for j in range(nt):
+ ni, nj = list_n[i], list_n[j]
+ xij = np.random.binomial(size=(ni, nj), n=1, p=C2[i, j])
+ C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij
+ F1 = np.zeros((ns, 1))
+ F1[:16] = np.random.normal(loc=0., scale=0.01, size=(16, 1))
+ F1[16:] = np.random.normal(loc=1., scale=0.01, size=(8, 1))
+ F2 = np.zeros((2, 1))
+ F2[1, :] = 1.
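+ # squared Euclidean feature distances, expanded as ||F1||^2 + ||F2||^2 - 2 F1 F2^T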
+ M = (F1 ** 2).dot(np.ones((1, nt))) + np.ones((ns, 1)).dot((F2 ** 2).T) - 2 * F1.dot(F2.T)
+
+ p = ot.unif(ns)
+ q0 = ot.unif(C2.shape[0])
+ G0 = p[:, None] * q0[None, :]
+
+ # asymmetric
+ Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0)
+ G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)
+ Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0b)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov
+
+ srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)
+
+ G = log2['T']
+ Gb = nx.to_numpy(logb2['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log2['srfgw_dist'], logb['srfgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07)
+
+ # symmetric
+ C1 = 0.5 * (C1 + C1.T)
+ Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0)
+
+ G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)
+ Gb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=True, log=False, G0=G0b)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov
+
+ srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=True, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)
+
+ srgw_ = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=True, log=False, G0=G0)
+
+ G = log2['T']
+ Gb = nx.to_numpy(logb2['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(srgw, srgw_, atol=1e-07)
+
+
+def test_semirelaxed_fgw2_gradients():
+ n_samples = 20 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5)
+
+ p = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+ M = ot.dist(xs, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ if torch:
+
+ devices = [torch.device("cpu")]
+ if torch.cuda.is_available():
+ devices.append(torch.device("cuda"))
+ for device in devices:
+ # semirelaxed solvers do not support gradients over masses yet.
+ p1 = torch.tensor(p, requires_grad=False, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
+ M1 = torch.tensor(M, requires_grad=True, device=device)
+
+ val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1)
+
+ val.backward()
+
+ assert val.device == p1.device
+ assert p1.grad is None
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+ assert M1.shape == M1.grad.shape
+
+
+def test_srfgw_helper_backend(nx):
+ n_samples = 20 # nb samples
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0)
+ ys = np.random.randn(xs.shape[0], 2)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1)
+ yt = np.random.randn(xt.shape[0], 2)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ M = ot.dist(ys, yt)
+ M /= M.max()
+
+ Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
+ alpha = 0.5
+ Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, 'square_loss', alpha=0.5, armijo=False, symmetric=True, G0=G0b, log=True)
+
+ # calls with nx=None
+ constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun='square_loss')
+ ones_pb = nx.ones(pb.shape[0], type_as=pb)
+
+ def f(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb))
+ return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None)
+
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb))
+ return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None)
+
+ def line_search(cost, G, deltaG, Mi, cost_G):
+ return ot.gromov.solve_semirelaxed_gromov_linesearch(
+ G, deltaG, cost_G, C1b, C2b, ones_pb, M=(1 - alpha) * Mb, reg=alpha, nx=None)
+ # feed the precomputed local optimum Gb to semirelaxed_cg
+ res, log = ot.optim.semirelaxed_cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
+ # check constraints
+ np.testing.assert_allclose(res, Gb, atol=1e-06)
diff --git a/test/test_optim.py b/test/test_optim.py
index 67e9d13..a43e704 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -120,31 +120,33 @@ def test_generalized_conditional_gradient(nx):
Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True)
Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(Gb, G, atol=1e-12)
np.testing.assert_allclose(a, Gb.sum(1), atol=1e-05)
np.testing.assert_allclose(b, Gb.sum(0), atol=1e-05)


def test_solve_1d_linesearch_quad_funct():
- np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1, 0), 0.5)
- np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5, 0), 0)
- np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1), 0.5)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5), 0)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5), 1)


def test_line_search_armijo(nx):
xk = np.array([[0.25, 0.25], [0.25, 0.25]])
pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
- old_fval = -123
+ old_fval = -123.
xkb, pkb, gfkb = nx.from_numpy(xk, pk, gfk)
+ def f(x):
+ return 1.
# Should not throw an exception and return 0. for alpha
alpha, a, b = ot.optim.line_search_armijo(
- lambda x: 1, xkb, pkb, gfkb, old_fval
+ f, xkb, pkb, gfkb, old_fval
)
alpha_np, anp, bnp = ot.optim.line_search_armijo(
- lambda x: 1, xk, pk, gfk, old_fval
+ f, xk, pk, gfk, old_fval
)
assert a == anp
assert b == bnp
@@ -182,3 +184,50 @@ def test_line_search_armijo(nx):
old_fval = f(xk)
alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
np.testing.assert_allclose(alpha, 0.1)
+
+
+def test_line_search_armijo_dtype_device(nx):
+ for tp in nx.__type_list__:
+ def f(x):
+ return nx.sum((x - 5.0) ** 2)
+
+ def grad(x):
+ return 2 * (x - 5.0)
+
+ xk = np.array([[[-5.0, -5.0]]])
+ pk = np.array([[[100.0, 100.0]]])
+ xkb, pkb = nx.from_numpy(xk, pk, type_as=tp)
+ gfkb = grad(xkb)
+ old_fval = f(xkb)
+
+ # check the case where the optimum lies along the search direction
+ alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval)
+ alpha = nx.to_numpy(alpha)
+ np.testing.assert_allclose(alpha, 0.1)
+ nx.assert_same_dtype_device(old_fval, fval)
+
+ # check the case where the direction does not go far enough
+ pk = np.array([[[3.0, 3.0]]])
+ pkb = nx.from_numpy(pk, type_as=tp)
+ alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval, alpha0=1.0)
+ alpha = nx.to_numpy(alpha)
+ np.testing.assert_allclose(alpha, 1.0)
+ nx.assert_same_dtype_device(old_fval, fval)
+
+ # check the case where the search direction is wrong
+ alpha, _, fval = ot.optim.line_search_armijo(f, xkb, -pkb, gfkb, old_fval)
+ alpha = nx.to_numpy(alpha)
+
+ assert alpha <= 0
+ nx.assert_same_dtype_device(old_fval, fval)
+
+ # check the case where the point is not a vector
+ xkb = nx.from_numpy(np.array(-5.0), type_as=tp)
+ pkb = nx.from_numpy(np.array(100), type_as=tp)
+ gfkb = grad(xkb)
+ old_fval = f(xkb)
+ alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval)
+ alpha = nx.to_numpy(alpha)
+
+ np.testing.assert_allclose(alpha, 0.1)
+ nx.assert_same_dtype_device(old_fval, fval)
diff --git a/test/test_ot.py b/test/test_ot.py
index bf832f6..f2338ac 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -29,9 +29,12 @@ def test_emd_dimension_and_mass_mismatch():
np.testing.assert_raises(AssertionError, ot.emd2, a, a, M)

+ # test emd and emd2 for mass mismatch
+ a = ot.utils.unif(n_samples)
b = a.copy()
a[0] = 100
np.testing.assert_raises(AssertionError, ot.emd, a, b, M)
+ np.testing.assert_raises(AssertionError, ot.emd2, a, b, M)


def test_emd_backends(nx):
@@ -201,6 +204,22 @@ def test_emd_emd2():
np.testing.assert_allclose(w, 0)
+def test_omp_emd2():
+ # test emd2 with and without openmp for simple identity
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ w = ot.emd2(u, u, M)
+ w2 = ot.emd2(u, u, M, numThreads=2)
+
+ np.testing.assert_allclose(w, w2)
+
+
def test_emd_empty():
# test emd and emd2 for simple identity
n = 100
@@ -320,6 +339,46 @@ def test_free_support_barycenter_backends(nx):
np.testing.assert_allclose(X, nx.to_numpy(X2))
+
+
+def test_generalised_free_support_barycenter():
+ np.random.seed(42) # random inits
+ X = [np.array([-1., -1.]).reshape((1, 2)), np.array([1., 1.]).reshape((1, 2))] # two 2D Diracs whose barycenter is obviously 0
+ a = [np.array([1.]), np.array([1.])]
+
+ P = [np.eye(2), np.eye(2)]
+
+ Y_init = np.array([-12., 7.]).reshape((1, 2))
+
+ # obvious barycenter location between two 2D diracs
+ Y_true = np.array([0., .0]).reshape((1, 2))
+
+ # test without log and no init
+ Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1)
+ np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7)
+
+ # test with log and init
+ Y, _ = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init, b=np.array([1.]), log=True)
+ np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7)
+
+
+def test_generalised_free_support_barycenter_backends(nx):
+ np.random.seed(42)
+ X = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
+ a = [np.array([1.]), np.array([1.])]
+ P = [np.array([1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
+ Y_init = np.array([-12.]).reshape((1, 1))
+
+ Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init)
+
+ X2 = nx.from_numpy(*X)
+ a2 = nx.from_numpy(*a)
+ P2 = nx.from_numpy(*P)
+ Y_init2 = nx.from_numpy(Y_init)
+
+ Y2 = ot.lp.generalized_free_support_barycenter(X2, a2, P2, 1, Y_init=Y_init2)
+
+ np.testing.assert_allclose(Y, nx.to_numpy(Y2))
+
+
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
def test_lp_barycenter_cvxopt():
a1 = np.array([1.0, 0, 0])[:, None]
diff --git a/test/test_partial.py b/test/test_partial.py
index 97c611b..86f9e62 100755
--- a/test/test_partial.py
+++ b/test/test_partial.py
@@ -8,6 +8,7 @@
import numpy as np
import scipy as sp
import ot
+from ot.backend import to_numpy, torch
import pytest
@@ -79,8 +80,10 @@ def test_partial_wasserstein_lagrange():
w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 1, log=True)
+ w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 100, log=True)

-def test_partial_wasserstein():
+
+def test_partial_wasserstein(nx):
n_samples = 20 # nb samples (gaussian)
n_noise = 20 # nb of samples (noise)
@@ -100,25 +103,20 @@ def test_partial_wasserstein():
m = 0.5

+ p, q, M = nx.from_numpy(p, q, M)
+
w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=m, log=True)
- w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m,
- log=True, verbose=True)
+ w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m, log=True, verbose=True)

# check constraints
- np.testing.assert_equal(
- w0.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
- np.testing.assert_equal(
- w0.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
- np.testing.assert_equal(
- w.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
- np.testing.assert_equal(
- w.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q))
+ np.testing.assert_equal(to_numpy(nx.sum(w, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w, axis=0) - q) <= 1e-5, [True] * len(q))

# check transported mass
- np.testing.assert_allclose(
- np.sum(w0), m, atol=1e-04)
- np.testing.assert_allclose(
- np.sum(w), m, atol=1e-04)
+ np.testing.assert_allclose(np.sum(to_numpy(w0)), m, atol=1e-04)
+ np.testing.assert_allclose(np.sum(to_numpy(w)), m, atol=1e-04)

w0, log0 = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True)
w0_val = ot.partial.partial_wasserstein2(p, q, M, m=m, log=False)
@@ -128,15 +126,91 @@ def test_partial_wasserstein():
np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1)

# check constraints
- np.testing.assert_equal(
- G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
- np.testing.assert_equal(
- G.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
- np.testing.assert_allclose(
- np.sum(G), m, atol=1e-04)
+ np.testing.assert_equal(to_numpy(nx.sum(G, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(G, axis=0) - q) <= 1e-5, [True] * len(q))
+ np.testing.assert_allclose(np.sum(to_numpy(G)), m, atol=1e-04)
+
+ empty_array = nx.zeros(0, type_as=M)
+ w = ot.partial.partial_wasserstein(empty_array, empty_array, M=M, m=None)
+
+ # check constraints
+ np.testing.assert_equal(to_numpy(nx.sum(w, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w, axis=0) - q) <= 1e-5, [True] * len(q))
+
+ # check transported mass
+ np.testing.assert_allclose(np.sum(to_numpy(w)), 1, atol=1e-04)
+
+ w0 = ot.partial.entropic_partial_wasserstein(empty_array, empty_array, M=M, reg=10, m=None)
+
+ # check constraints
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q))
+
+ # check transported mass
+ np.testing.assert_allclose(np.sum(to_numpy(w0)), 1, atol=1e-04)
+
+
+def test_partial_wasserstein2_gradient():
+ if torch:
+ n_samples = 40
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+
+ M = torch.tensor(ot.dist(xs, xt), requires_grad=True, dtype=torch.float64)
+
+ p = torch.tensor(ot.unif(n_samples), dtype=torch.float64)
+ q = torch.tensor(ot.unif(n_samples), dtype=torch.float64)
+
+ m = 0.5
+
+ w, log = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True)
+
+ w.backward()
+
+ assert M.grad is not None
+ assert M.grad.shape == M.shape
+
+
+def test_entropic_partial_wasserstein_gradient():
+ if torch:
+ n_samples = 40
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+
+ M = torch.tensor(ot.dist(xs, xt), requires_grad=True, dtype=torch.float64)
+
+ p = torch.tensor(ot.unif(n_samples), requires_grad=True, dtype=torch.float64)
+ q = torch.tensor(ot.unif(n_samples), requires_grad=True, dtype=torch.float64)
+
+ m = 0.5
+ reg = 1
+
+ _, log = ot.partial.entropic_partial_wasserstein(p, q, M, m=m, reg=reg, log=True)
+
+ log['partial_w_dist'].backward()
+
+ assert M.grad is not None
+ assert p.grad is not None
+ assert q.grad is not None
+ assert M.grad.shape == M.shape
+ assert p.grad.shape == p.shape
+ assert q.grad.shape == q.shape


def test_partial_gromov_wasserstein():
+ rng = np.random.RandomState(seed=42)
n_samples = 20 # nb samples
n_noise = 10 # nb of samples (noise)
@@ -149,11 +223,11 @@ def test_partial_gromov_wasserstein():
mu_t = np.array([0, 0, 0])
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
- xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
- xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) + 1) * 4)), axis=0)
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, rng)
+ xs = np.concatenate((xs, ((rng.rand(n_noise, 2) + 1) * 4)), axis=0)
P = sp.linalg.sqrtm(cov_t)
- xt = np.random.randn(n_samples, 3).dot(P) + mu_t
- xt = np.concatenate((xt, ((np.random.rand(n_noise, 3) + 1) * 10)), axis=0)
+ xt = rng.randn(n_samples, 3).dot(P) + mu_t
+ xt = np.concatenate((xt, ((rng.rand(n_noise, 3) + 1) * 10)), axis=0)
xt2 = xs[::-1].copy()
C1 = ot.dist(xs, xs)
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 08ab4fb..f54c799 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -110,6 +110,20 @@ def test_max_sliced_different_dists():
assert res > 0.
+
+
+def test_sliced_same_proj():
+ n_projections = 10
+ seed = 12
+ rng = np.random.RandomState(0)
+ X = rng.randn(8, 2)
+ Y = rng.randn(8, 2)
+ cost1, log1 = ot.sliced_wasserstein_distance(X, Y, seed=seed, n_projections=n_projections, log=True)
+ P = get_random_projections(X.shape[1], n_projections=10, seed=seed)
+ cost2, log2 = ot.sliced_wasserstein_distance(X, Y, projections=P, log=True)
+
+ assert np.allclose(log1['projections'], log2['projections'])
+ assert np.isclose(cost1, cost2)
+
+
def test_sliced_backend(nx):

n = 100
@@ -252,3 +266,189 @@ def test_max_sliced_backend_device_tf():
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
assert nx.dtype_device(valb)[1].startswith("GPU")
+
+
+def test_projections_stiefel():
+ rng = np.random.RandomState(0)
+
+ n_projs = 500
+ x = np.random.randn(100, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ ssw, log = ot.sliced_wasserstein_sphere(x, x, n_projections=n_projs,
+ seed=rng, log=True)
+
+ P = log["projections"]
+ P_T = np.transpose(P, [0, 2, 1])
+ np.testing.assert_almost_equal(np.matmul(P_T, P), np.array([np.eye(2) for k in range(n_projs)]))
+
+
+def test_sliced_sphere_same_dist():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ u = ot.utils.unif(n)
+
+ res = ot.sliced_wasserstein_sphere(x, x, u, u, 10, seed=rng)
+ np.testing.assert_almost_equal(res, 0.)
+
+
+def test_sliced_sphere_bad_shapes():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ y = rng.randn(n, 4)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng)
+
+
+def test_sliced_sphere_values_on_the_sphere():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ y = rng.randn(n, 4)
+
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng)
+
+
+def test_sliced_sphere_log():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 4)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ y = rng.randn(n, 4)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+ u = ot.utils.unif(n)
+
+ res, log = ot.sliced_wasserstein_sphere(x, y, u, u, 10, p=1, seed=rng, log=True)
+ assert len(log) == 2
+ projections = log["projections"]
+ projected_emds = log["projected_emds"]
+
+ assert projections.shape[0] == len(projected_emds) == 10
+ for emd in projected_emds:
+ assert emd > 0
+
+
+def test_sliced_sphere_different_dists():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ u = ot.utils.unif(n)
+ y = rng.randn(n, 3)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+
+ res = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng)
+ assert res > 0.
+
+
+def test_1d_sliced_sphere_equals_emd():
+ n = 100
+ m = 120
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ x_coords = (np.arctan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi)
+ a = rng.uniform(0, 1, n)
+ a /= a.sum()
+
+ y = rng.randn(m, 2)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+ y_coords = (np.arctan2(-y[:, 1], -y[:, 0]) + np.pi) / (2 * np.pi)
+ u = ot.utils.unif(m)
+
+ res = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=2)
+ expected = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=2)
+
+ res1 = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=1)
+ expected1 = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=1)
+
+ np.testing.assert_almost_equal(res ** 2, expected)
+ np.testing.assert_almost_equal(res1, expected1, decimal=3)
+
+
+@pytest.skip_backend("tf")
+def test_sliced_sphere_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ y = rng.randn(2 * n, 3)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb, yb = nx.from_numpy(x, y, type_as=tp)
+
+ valb = ot.sliced_wasserstein_sphere(xb, yb)
+
+ nx.assert_same_dtype_device(xb, valb)
+
+
+def test_sliced_sphere_unif_values_on_the_sphere():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng)
+
+
+def test_sliced_sphere_unif_log():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 4)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ u = ot.utils.unif(n)
+
+ res, log = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng, log=True)
+ assert len(log) == 2
+ projections = log["projections"]
+ projected_emds = log["projected_emds"]
+
+ assert projections.shape[0] == len(projected_emds) == 10
+ for emd in projected_emds:
+ assert emd > 0
+
+
+def test_sliced_sphere_unif_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb = nx.from_numpy(x, type_as=tp)
+
+ valb = ot.sliced_wasserstein_sphere_unif(xb)
+
+ nx.assert_same_dtype_device(xb, valb)
diff --git a/test/test_solvers.py b/test/test_solvers.py
new file mode 100644
index 0000000..b792aca
--- /dev/null
+++ b/test/test_solvers.py
@@ -0,0 +1,133 @@
+"""Tests for ot solvers"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+
+import itertools
+import numpy as np
+import pytest
+
+import ot
+
+
+lst_reg = [None, 1.0]
+lst_reg_type = ['KL', 'entropy', 'L2']
+lst_unbalanced = [None, 0.9]
+lst_unbalanced_type = ['KL', 'L2', 'TV']
+
+
+def assert_allclose_sol(sol1, sol2):
+
+ lst_attr = ['value', 'value_linear', 'plan',
+ 'potential_a', 'potential_b', 'marginal_a', 'marginal_b']
+
+ nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend()
+ nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend()
+
+ for attr in lst_attr:
+ try:
+ assert np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr)))
+ except NotImplementedError:
+ pass
+
+
+def test_solve(nx):
+ n_samples_s = 10
+ n_samples_t = 7
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples_s, n_features)
+ y = rng.randn(n_samples_t, n_features)
+ a = ot.utils.unif(n_samples_s)
+ b = ot.utils.unif(n_samples_t)
+
+ M = ot.dist(x, y)
+
+ # solve with uniform weights
+ sol0 = ot.solve(M)
+
+ print(sol0)
+
+ # solve with given weights
+ sol = ot.solve(M, a, b)
+
+ # check some attributes
+ sol.potentials
+ sol.sparse_plan
+ sol.marginals
+ sol.status
+
+ assert_allclose_sol(sol0, sol)
+
+ # solve in backend
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+ solb = ot.solve(Mb, ab, bb)
+
+ assert_allclose_sol(sol, solb)
+
+ # test not implemented unbalanced and check raise
+ with pytest.raises(NotImplementedError):
+ sol0 = ot.solve(M, unbalanced=1, unbalanced_type='cryptic divergence')
+
+ # test not implemented reg_type and check raise
+ with pytest.raises(NotImplementedError):
+ sol0 = ot.solve(M, reg=1, reg_type='cryptic divergence')
+
+
+@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type))
+def test_solve_grid(nx, reg, reg_type, unbalanced, unbalanced_type):
+ n_samples_s = 10
+ n_samples_t = 7
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples_s, n_features)
+ y = rng.randn(n_samples_t, n_features)
+ a = ot.utils.unif(n_samples_s)
+ b = ot.utils.unif(n_samples_t)
+
+ M = ot.dist(x, y)
+
+ try:
+
+ # solve with uniform weights
+ sol0 = ot.solve(M, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type)
+
+ # solve with given weights
+ sol = ot.solve(M, a, b, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type)
+
+ assert_allclose_sol(sol0, sol)
+
+ # solve in backend
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+ solb = ot.solve(Mb, ab, bb, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type)
+
+ assert_allclose_sol(sol, solb)
+ except NotImplementedError:
+ pass
+
+
+def test_solve_not_implemented(nx):
+
+ n_samples_s = 10
+ n_samples_t = 7
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples_s, n_features)
+ y = rng.randn(n_samples_t, n_features)
+
+ M = ot.dist(x, y)
+
+ # test not implemented and check raise
+ with pytest.raises(NotImplementedError):
+ ot.solve(M, reg=1.0, reg_type='cryptic divergence')
+ with pytest.raises(NotImplementedError):
+ ot.solve(M, unbalanced=1.0, unbalanced_type='cryptic divergence')
+
+ # pairs of incompatible divergences
+ with pytest.raises(NotImplementedError):
+ ot.solve(M, reg=1.0, reg_type='kl', unbalanced=1.0, unbalanced_type='tv')
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index 02b3fc3..b76d738 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -5,6 +5,7 @@
#
# License: MIT License
+import itertools
import numpy as np
import ot
import pytest
@@ -289,32 +290,55 @@ def test_implemented_methods(nx):
method=method)
+
+@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2']))
+def test_lbfgsb_unbalanced(nx, reg_div, regm_div):
+
+ np.random.seed(42)
+
+ xs = np.random.randn(5, 2)
+ xt = np.random.randn(6, 2)
+
+ M = ot.dist(xs, xt)
+
+ a = ot.unif(5)
+ b = ot.unif(6)
+
+ G, log = ot.unbalanced.lbfgsb_unbalanced(a, b, M, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False)
+
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+
+ Gb, log = ot.unbalanced.lbfgsb_unbalanced(ab, bb, Mb, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False)
+
+ np.testing.assert_allclose(G, nx.to_numpy(Gb))
+
+
def test_mm_convergence(nx):
n = 100
rng = np.random.RandomState(42)
x = rng.randn(n, 2)
rng = np.random.RandomState(75)
y = rng.randn(n, 2)
- a = ot.utils.unif(n)
- b = ot.utils.unif(n)
+ a_np = ot.utils.unif(n)
+ b_np = ot.utils.unif(n)

M = ot.dist(x, y)
M = M / M.max()
reg_m = 100
- a, b, M = nx.from_numpy(a, b, M)
+ a, b, M = nx.from_numpy(a_np, b_np, M)

G_kl, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl',
- verbose=True, log=True)
- loss_kl = nx.to_numpy(ot.unbalanced.mm_unbalanced2(
- a, b, M, reg_m, div='kl', verbose=True))
+ verbose=False, log=True)
+ loss_kl = nx.to_numpy(
+ ot.unbalanced.mm_unbalanced2(a, b, M, reg_m, div='kl', verbose=True)
+ )
G_l2, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2',
verbose=False, log=True)

# check if the marginals come close to the true ones when large reg
- np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a, atol=1e-03)
- np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b, atol=1e-03)
- np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a, atol=1e-03)
- np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a_np, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b_np, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a_np, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b_np, atol=1e-03)

# check if mm_unbalanced2 returns the correct loss
np.testing.assert_allclose(nx.to_numpy(nx.sum(G_kl * M)), loss_kl,
@@ -324,15 +348,16 @@ def test_mm_convergence(nx):
a_np, b_np = np.array([]), np.array([])
a, b = nx.from_numpy(a_np, b_np)

- G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl')
- G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2')
- np.testing.assert_allclose(G_kl_null, G_kl)
- np.testing.assert_allclose(G_l2_null, G_l2)
+ G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', verbose=False)
+ G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', verbose=False)
+ np.testing.assert_allclose(nx.to_numpy(G_kl_null), nx.to_numpy(G_kl))
+ np.testing.assert_allclose(nx.to_numpy(G_l2_null), nx.to_numpy(G_l2))

# test when G0 is given
G0 = ot.emd(a, b, M)
+ G0_np = nx.to_numpy(G0)
reg_m = 10000
- G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0)
- G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0)
- np.testing.assert_allclose(G0, G_kl, atol=1e-05)
- np.testing.assert_allclose(G0, G_l2, atol=1e-05)
+ G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0, verbose=False)
+ G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0, verbose=False)
+ np.testing.assert_allclose(G0_np, nx.to_numpy(G_kl), atol=1e-05)
+ np.testing.assert_allclose(G0_np, nx.to_numpy(G_l2), atol=1e-05)
diff --git a/test/test_utils.py b/test/test_utils.py
index 3cfd295..31b12ef 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -143,6 +143,7 @@ def test_dist():
for metric in metrics_w:
print(metric)
ot.dist(x, x, metric=metric, p=3, w=np.random.random((2, )))
+ ot.dist(x, x, metric=metric, p=3, w=None) # check that not having any weight does not cause issues
for metric in metrics:
print(metric)
ot.dist(x, x, metric=metric, p=3)
@@ -300,3 +301,42 @@ def test_BaseEstimator():
cl.set_params(bibi=10)

assert cl.first == 'spam again'
+
+
+def test_OTResult():
+
+ res = ot.utils.OTResult()
+
+ # test print
+ print(res)
+
+ # test get citation
+ print(res.citation)
+
+ lst_attributes = ['a_to_b',
+ 'b_to_a',
+ 'lazy_plan',
+ 'marginal_a',
+ 'marginal_b',
+ 'marginals',
+ 'plan',
+ 'potential_a',
+ 'potential_b',
+ 'potentials',
+ 'sparse_plan',
+ 'status',
+ 'value',
+ 'value_linear']
+ for at in lst_attributes:
+ with pytest.raises(NotImplementedError):
+ getattr(res, at)
+
+
+def test_get_coordinate_circle():
+
+ u = np.random.rand(1, 100)
+ x1, y1 = np.cos(u * (2 * np.pi)), np.sin(u * (2 * np.pi))
+ x = np.concatenate([x1, y1]).T
+ x_p = ot.utils.get_coordinate_circle(x)
+
+ np.testing.assert_allclose(u[0], x_p)