Diffstat (limited to 'test')
-rw-r--r--  test/test_1d_solver.py   127
-rw-r--r--  test/test_backend.py      52
-rw-r--r--  test/test_bregman.py     315
-rw-r--r--  test/test_coot.py        359
-rw-r--r--  test/test_da.py           79
-rw-r--r--  test/test_gaussian.py     98
-rw-r--r--  test/test_gromov.py      637
-rw-r--r--  test/test_optim.py        63
-rw-r--r--  test/test_ot.py           59
-rwxr-xr-x  test/test_partial.py     124
-rw-r--r--  test/test_sliced.py      200
-rw-r--r--  test/test_solvers.py     133
-rw-r--r--  test/test_unbalanced.py   61
-rw-r--r--  test/test_utils.py        40
14 files changed, 2191 insertions, 156 deletions
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index 20f307a..21abd1d 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -218,3 +218,130 @@ def test_emd1d_device_tf():
nx.assert_same_dtype_device(xb, emd)
nx.assert_same_dtype_device(xb, emd2)
assert nx.dtype_device(emd)[1].startswith("GPU")
+
+
+def test_wasserstein_1d_circle():
+    # test that binary_search_circle and wasserstein_circle give results similar to emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.rand(n,)
+ v = rng.rand(m,)
+
+ w_u = rng.uniform(0., 1., n)
+ w_u = w_u / w_u.sum()
+
+ w_v = rng.uniform(0., 1., m)
+ w_v = w_v / w_v.sum()
+
+ M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None]))
+
+ wass1 = ot.emd2(w_u, w_v, M1)
+
+ wass1_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=1)
+ w1_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=1)
+
+ M2 = M1**2
+ wass2 = ot.emd2(w_u, w_v, M2)
+ wass2_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=2)
+ w2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass1, wass1_bsc)
+ np.testing.assert_allclose(wass1, w1_circle, rtol=1e-2)
+ np.testing.assert_allclose(wass2, wass2_bsc)
+ np.testing.assert_allclose(wass2, w2_circle)
+
+
+@pytest.skip_backend("tf")
+def test_wasserstein1d_circle_devices(nx):
+ rng = np.random.RandomState(0)
+
+ n = 10
+ x = np.linspace(0, 1, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)
+
+ w1 = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=1)
+ w2_bsc = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=2)
+
+ nx.assert_same_dtype_device(xb, w1)
+ nx.assert_same_dtype_device(xb, w2_bsc)
+
+
+def test_wasserstein_1d_unif_circle():
+ # test semidiscrete_wasserstein2_unif_circle versus wasserstein_circle
+ n = 20
+ m = 50000
+
+ rng = np.random.RandomState(0)
+ u = rng.rand(n,)
+ v = rng.rand(m,)
+
+ # w_u = rng.uniform(0., 1., n)
+ # w_u = w_u / w_u.sum()
+
+ w_u = ot.utils.unif(n)
+ w_v = ot.utils.unif(m)
+
+ M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None]))
+ wass2 = ot.emd2(w_u, w_v, M1**2)
+
+ wass2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2, eps=1e-15)
+ wass2_unif_circle = ot.semidiscrete_wasserstein2_unif_circle(u, w_u)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-3)
+ np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-3)
+
+
+def test_wasserstein1d_unif_circle_devices(nx):
+ rng = np.random.RandomState(0)
+
+ n = 10
+ x = np.linspace(0, 1, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb, rho_ub = nx.from_numpy(x, rho_u, type_as=tp)
+
+ w2 = ot.semidiscrete_wasserstein2_unif_circle(xb, rho_ub)
+
+ nx.assert_same_dtype_device(xb, w2)
+
+
+def test_binary_search_circle_log():
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.rand(n,)
+ v = rng.rand(m,)
+
+ wass2_bsc, log = ot.binary_search_circle(u, v, p=2, log=True)
+ optimal_thetas = log["optimal_theta"]
+
+ assert optimal_thetas.shape[0] == 1
+
+
+def test_wasserstein_circle_bad_shape():
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.rand(n, 2)
+ v = rng.rand(m, 1)
+
+ with pytest.raises(ValueError):
+ _ = ot.wasserstein_circle(u, v, p=2)
+
+ with pytest.raises(ValueError):
+ _ = ot.wasserstein_circle(u, v, p=1)
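
Note: the circle tests above compare the dedicated circle solvers against a plain ot.emd2 call with the circular ground cost min(|x - y|, 1 - |x - y|). A minimal standalone sketch of that comparison (sizes and seed are illustrative, not the exact values asserted in the tests):

import numpy as np
import ot

rng = np.random.RandomState(0)
u, v = rng.rand(20), rng.rand(30)                 # positions on the circle, coordinates in [0, 1)
w_u, w_v = ot.utils.unif(20), ot.utils.unif(30)   # uniform weights

# circular ground cost: geodesic distance on a circle of circumference 1
M = np.minimum(np.abs(u[:, None] - v[None, :]), 1 - np.abs(u[:, None] - v[None, :]))

w2_emd = ot.emd2(w_u, w_v, M ** 2)                        # generic LP solver on the squared cost
w2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2)    # dedicated 1D circle solver
w2_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=2)     # binary search over the cut point theta
print(w2_emd, w2_circle, w2_bsc)                          # expected to agree up to numerical tolerance
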
diff --git a/test/test_backend.py b/test/test_backend.py
index 311c075..fd9a761 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -275,11 +275,27 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.sqrtm(M)
with pytest.raises(NotImplementedError):
+ nx.kl_div(M, M)
+ with pytest.raises(NotImplementedError):
nx.isfinite(M)
with pytest.raises(NotImplementedError):
nx.array_equal(M, M)
with pytest.raises(NotImplementedError):
nx.is_floating_point(M)
+ with pytest.raises(NotImplementedError):
+ nx.tile(M, (10, 1))
+ with pytest.raises(NotImplementedError):
+ nx.floor(M)
+ with pytest.raises(NotImplementedError):
+ nx.prod(M)
+ with pytest.raises(NotImplementedError):
+ nx.sort2(M)
+ with pytest.raises(NotImplementedError):
+ nx.qr(M)
+ with pytest.raises(NotImplementedError):
+ nx.atan2(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.transpose(M)
def test_func_backends(nx):
@@ -592,11 +608,47 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append("matrix square root")
+ A = nx.kl_div(nx.abs(Mb), nx.abs(Mb) + 1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("Kullback-Leibler divergence")
+
A = nx.concatenate([vb, nx.from_numpy(np.array([np.inf, np.nan]))], axis=0)
A = nx.isfinite(A)
lst_b.append(nx.to_numpy(A))
lst_name.append("isfinite")
+ A = nx.tile(vb, (10, 1))
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("tile")
+
+ A = nx.floor(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("floor")
+
+ A = nx.prod(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("prod")
+
+ A, B = nx.sort2(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("sort2 sort")
+ lst_b.append(nx.to_numpy(B))
+ lst_name.append("sort2 argsort")
+
+ A, B = nx.qr(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("QR Q")
+ lst_b.append(nx.to_numpy(B))
+ lst_name.append("QR R")
+
+ A = nx.atan2(vb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("atan2")
+
+ A = nx.transpose(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("transpose")
+
assert not nx.array_equal(Mb, vb), "array_equal (shape)"
assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
assert not nx.array_equal(
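
Note: the new backend primitives added here (kl_div, tile, floor, prod, sort2, qr, atan2, transpose) are all checked with the same pattern as the rest of test_func_backends: evaluate through the backend, convert back with nx.to_numpy, and compare against the NumPy reference. A small sketch of that pattern for sort2, assuming (from the "sort2 sort" / "sort2 argsort" labels above) that it returns both the sorted values and the sorting indices:

import numpy as np
import ot

nx = ot.backend.NumpyBackend()                 # the same pattern applies to any backend
v = np.random.RandomState(0).randn(7)

vals, idx = nx.sort2(nx.from_numpy(v))         # backend call under test
np.testing.assert_allclose(nx.to_numpy(vals), np.sort(v))
np.testing.assert_allclose(nx.to_numpy(idx), np.argsort(v))
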
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 6c37984..f01bb14 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -3,9 +3,11 @@
# Author: Remi Flamary <remi.flamary@unice.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
+# Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr>
#
# License: MIT License
+import warnings
from itertools import product
import numpy as np
@@ -57,7 +59,12 @@ def test_convergence_warning(method):
with pytest.warns(UserWarning):
ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1)
with pytest.warns(UserWarning):
- ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1)
+ ot.sinkhorn2(a1, a2, M, 1, method=method,
+ stopThr=0, numItermax=1, warn=True)
+ with warnings.catch_warnings():
+ warnings.simplefilter("error")
+ ot.sinkhorn2(a1, a2, M, 1, method=method,
+ stopThr=0, numItermax=1, warn=False)
def test_not_implemented_method():
@@ -261,12 +268,16 @@ def test_sinkhorn_variants(nx):
ub, M_nx = nx.from_numpy(u, M)
G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
- Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
- G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10))
- Gs = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Gl = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
Ges = nx.to_numpy(ot.sinkhorn(
ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10))
- G_green = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10))
+ G_green = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10))
# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
@@ -366,9 +377,12 @@ def test_sinkhorn_variants_multi_b(nx):
ub, bb, M_nx = nx.from_numpy(u, b, M)
G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
- Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
- G0 = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
- Gs = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Gl = nx.to_numpy(ot.sinkhorn(
+ ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(
+ ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(
+ ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
@@ -394,9 +408,12 @@ def test_sinkhorn2_variants_multi_b(nx):
ub, bb, M_nx = nx.from_numpy(u, b, M)
G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
- Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
- G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
- Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Gl = nx.to_numpy(ot.sinkhorn2(
+ ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn2(
+ ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn2(
+ ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
@@ -414,12 +431,16 @@ def test_sinkhorn_variants_log():
M = ot.dist(x, x)
- G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
- Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
- Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
+ G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn',
+ stopThr=1e-10, log=True)
+ Gl, logl = ot.sinkhorn(
+ u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
+ Gs, logs = ot.sinkhorn(
+ u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
Ges, loges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True,)
- G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
+ G_green, loggreen = ot.sinkhorn(
+ u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
@@ -441,7 +462,8 @@ def test_sinkhorn_variants_log_multib(verbose, warn):
M = ot.dist(x, x)
- G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+ G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn',
+ stopThr=1e-10, log=True)
Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True,
verbose=verbose, warn=warn)
Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True,
@@ -480,8 +502,73 @@ def test_barycenter(nx, method, verbose, warn):
ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method)
else:
# wasserstein
- bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method, verbose=verbose, warn=warn)
- bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass_np = ot.bregman.barycenter(
+ A, M, reg, weights, method=method, verbose=verbose, warn=warn)
+ bary_wass, _ = ot.bregman.barycenter(
+ A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass = nx.to_numpy(bary_wass)
+
+ np.testing.assert_allclose(1, np.sum(bary_wass))
+ np.testing.assert_allclose(bary_wass, bary_wass_np)
+
+ ot.bregman.barycenter(A_nx, M_nx, reg, log=True)
+
+
+def test_free_support_sinkhorn_barycenter():
+ measures_locations = [
+ np.array([-1.]).reshape((1, 1)), # First dirac support
+ np.array([1.]).reshape((1, 1)) # Second dirac support
+ ]
+
+ measures_weights = [
+ np.array([1.]), # First dirac sample weights
+ np.array([1.]) # Second dirac sample weights
+ ]
+
+ # Barycenter initialization
+ X_init = np.array([-12.]).reshape((1, 1))
+
+    # Obvious barycenter locations. See test_free_support_barycenter in test_ot.py
+ bar_locations = np.array([0.]).reshape((1, 1))
+
+    # Compute the free-support barycenter with the Sinkhorn algorithm. The entropic
+    # regularization term is set to 1 here; in general it should be tuned to the problem.
+ X = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations, measures_weights, X_init, reg=1)
+
+    # Verify that the computed barycenter matches the ground truth
+ np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
+
+
+@pytest.mark.parametrize("method, verbose, warn",
+ product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"],
+ [True, False], [True, False]))
+def test_barycenter_assymetric_cost(nx, method, verbose, warn):
+ n_bins = 20 # nb bins
+
+ # Gaussian distributions
+ A = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
+
+ # creating matrix A containing all distributions
+ A = A[:, None]
+
+    # asymmetric loss matrix + normalization
+ rng = np.random.RandomState(42)
+ M = rng.randn(n_bins, n_bins) ** 2
+ M /= M.max()
+
+ A_nx, M_nx = nx.from_numpy(A, M)
+ reg = 1e-2
+
+ if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.barycenter(A_nx, M_nx, reg, method=method)
+ else:
+ # wasserstein
+ bary_wass_np = ot.bregman.barycenter(
+ A, M, reg, method=method, verbose=verbose, warn=warn)
+ bary_wass, _ = ot.bregman.barycenter(
+ A_nx, M_nx, reg, method=method, log=True)
bary_wass = nx.to_numpy(bary_wass)
np.testing.assert_allclose(1, np.sum(bary_wass))
@@ -516,17 +603,20 @@ def test_barycenter_debiased(nx, method, verbose, warn):
reg = 1e-2
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
- ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method)
+ ot.bregman.barycenter_debiased(
+ A_nx, M_nx, reg, weights, method=method)
else:
bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method,
verbose=verbose, warn=warn)
- bary_wass, _ = ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass, _ = ot.bregman.barycenter_debiased(
+ A_nx, M_nx, reg, weights_nx, method=method, log=True)
bary_wass = nx.to_numpy(bary_wass)
np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5)
- ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False)
+ ot.bregman.barycenter_debiased(
+ A_nx, M_nx, reg, log=True, verbose=False)
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
@@ -551,7 +641,8 @@ def test_convergence_warning_barycenters(method):
weights = np.array([1 - alpha, alpha])
reg = 0.1
with pytest.warns(UserWarning):
- ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1)
+ ot.bregman.barycenter_debiased(
+ A, M, reg, weights, method=method, numItermax=1)
with pytest.warns(UserWarning):
ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1)
with pytest.warns(UserWarning):
@@ -583,7 +674,8 @@ def test_barycenter_stabilization(nx):
# wasserstein
reg = 1e-2
- bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True)
+ bar_np = ot.bregman.barycenter(
+ A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True)
bar_stable = nx.to_numpy(ot.bregman.barycenter(
A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized",
stopThr=1e-8, verbose=True
@@ -618,8 +710,10 @@ def test_wasserstein_bary_2d(nx, method):
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
else:
- bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method, verbose=True, log=True)
- bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method))
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
+ A, reg, method=method, verbose=True, log=True)
+ bary_wass = nx.to_numpy(
+ ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method))
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
@@ -648,10 +742,13 @@ def test_wasserstein_bary_2d_debiased(nx, method):
reg = 1e-2
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
- ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
+ ot.bregman.convolutional_barycenter2d_debiased(
+ A_nx, reg, method=method)
else:
- bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method, verbose=True, log=True)
- bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
+ A, reg, method=method, verbose=True, log=True)
+ bary_wass = nx.to_numpy(
+ ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
@@ -685,7 +782,8 @@ def test_unmix(nx):
# wasserstein
reg = 1e-3
um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01)
- um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01))
+ um = nx.to_numpy(ot.bregman.unmix(
+ ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01))
np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
@@ -716,10 +814,12 @@ def test_empirical_sinkhorn(nx):
sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
sinkhorn_log = nx.to_numpy(sinkhorn_log)
- G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean'))
+ G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(
+ X_sb, X_tb, 1, metric='euclidean'))
sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
- loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
+ loss_emp_sinkhorn = nx.to_numpy(
+ ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
# check constraints
@@ -752,23 +852,27 @@ def test_lazy_empirical_sinkhorn(nx):
ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m)
- f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
+ f, g = ot.bregman.empirical_sinkhorn(
+ X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))
- f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ f, g, log_es = ot.bregman.empirical_sinkhorn(
+ X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_log = np.exp(f[:, None] + g[None, :] - M / 0.1)
sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
sinkhorn_log = nx.to_numpy(sinkhorn_log)
- f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1)
+ f, g = ot.bregman.empirical_sinkhorn(
+ X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_m = np.exp(f[:, None] + g[None, :] - M_m / 1)
sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
- loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(
+ X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn)
loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
@@ -800,22 +904,57 @@ def test_empirical_sinkhorn_divergence(nx):
M_s = ot.dist(X_s, X_s)
M_t = ot.dist(X_t, X_t)
- ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(a, b, X_s, X_t, M, M_s, M_t)
+ ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(
+ a, b, X_s, X_t, M, M_s, M_t)
- emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
+ emp_sinkhorn_div = nx.to_numpy(
+ ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
sinkhorn_div = nx.to_numpy(
ot.sinkhorn2(ab, bb, M_nx, 1)
- 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1)
- 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1)
)
- emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b)
+ emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(
+ X_s, X_t, 1, a=a, b=b)
# check constraints
- np.testing.assert_allclose(emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05)
+ np.testing.assert_allclose(
+ emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05)
np.testing.assert_allclose(
emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
- ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True)
+ ot.bregman.empirical_sinkhorn_divergence(
+ X_sb, X_tb, 1, a=ab, b=bb, log=True)
+
+
+@pytest.mark.skipif(not torch, reason="No torch available")
+def test_empirical_sinkhorn_divergence_gradient():
+    # Test gradients of the empirical Sinkhorn divergence (torch backend only)
+ n = 10
+ a = np.linspace(1, n, n)
+ a /= a.sum()
+ b = ot.unif(n)
+ X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1))
+ X_t = np.reshape(np.arange(0, n * 2, 2, dtype=np.float64), (n, 1))
+
+ nx = ot.backend.TorchBackend()
+
+ ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t)
+
+ ab.requires_grad = True
+ bb.requires_grad = True
+ X_sb.requires_grad = True
+ X_tb.requires_grad = True
+
+ emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(
+ X_sb, X_tb, 1, a=ab, b=bb)
+
+ emp_sinkhorn_div.backward()
+
+ assert ab.grad is not None
+ assert bb.grad is not None
+ assert X_sb.grad is not None
+ assert X_tb.grad is not None
def test_stabilized_vs_sinkhorn_multidim(nx):
@@ -837,7 +976,8 @@ def test_stabilized_vs_sinkhorn_multidim(nx):
ab, bb, M_nx = nx.from_numpy(a, b, M)
- G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True)
+ G_np, _ = ot.bregman.sinkhorn(
+ a, b, M, reg=epsilon, method="sinkhorn", log=True)
G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon,
method="sinkhorn_stabilized",
log=True)
@@ -902,7 +1042,8 @@ def test_screenkhorn(nx):
# sinkhorn
G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1))
# screenkhorn
- G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True))
+ G_screen = nx.to_numpy(ot.bregman.screenkhorn(
+ ab, bb, M_nx, 1e-1, uniform=True, verbose=True))
# check marginals
np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
@@ -919,3 +1060,93 @@ def test_convolutional_barycenter_non_square(nx):
np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
np.testing.assert_allclose(b, b_np)
+
+
+def test_sinkhorn_warmstart():
+ m, n = 10, 20
+ a = ot.unif(m)
+ b = ot.unif(n)
+
+ Xs = np.arange(m) * 1.0
+ Xt = np.arange(n) * 1.0
+ M = ot.dist(Xs.reshape(-1, 1), Xt.reshape(-1, 1))
+
+ # Generate warmstart from dual vectors of unregularized OT
+ _, log = ot.lp.emd(a, b, M, log=True)
+ warmstart = (log["u"], log["v"])
+
+ reg = 1
+
+ # Optimal plan with uniform warmstart
+ pi_unif, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn", log=True, warmstart=None)
+ # Optimal plan with warmstart generated from unregularized OT
+ pi_sh, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn", log=True, warmstart=warmstart)
+ pi_sh_log, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn_log", log=True, warmstart=warmstart)
+ pi_sh_stab, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn_stabilized", log=True, warmstart=warmstart)
+ pi_sh_sc, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn_epsilon_scaling", log=True, warmstart=warmstart)
+
+ np.testing.assert_allclose(pi_unif, pi_sh, atol=1e-05)
+ np.testing.assert_allclose(pi_unif, pi_sh_log, atol=1e-05)
+ np.testing.assert_allclose(pi_unif, pi_sh_stab, atol=1e-05)
+ np.testing.assert_allclose(pi_unif, pi_sh_sc, atol=1e-05)
+
+
+def test_empirical_sinkhorn_warmstart():
+ m, n = 10, 20
+ Xs = np.arange(m).reshape(-1, 1) * 1.0
+ Xt = np.arange(n).reshape(-1, 1) * 1.0
+ M = ot.dist(Xs, Xt)
+
+ # Generate warmstart from dual vectors of unregularized OT
+ a = ot.unif(m)
+ b = ot.unif(n)
+ _, log = ot.lp.emd(a, b, M, log=True)
+ warmstart = (log["u"], log["v"])
+
+ reg = 1
+
+ # Optimal plan with uniform warmstart
+ f, g, _ = ot.bregman.empirical_sinkhorn(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None)
+ pi_unif = np.exp(f[:, None] + g[None, :] - M / reg)
+ # Optimal plan with warmstart generated from unregularized OT
+ f, g, _ = ot.bregman.empirical_sinkhorn(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart)
+ pi_ws_lazy = np.exp(f[:, None] + g[None, :] - M / reg)
+ pi_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart)
+
+ np.testing.assert_allclose(pi_unif, pi_ws_lazy, atol=1e-05)
+ np.testing.assert_allclose(pi_unif, pi_ws_not_lazy, atol=1e-05)
+
+
+def test_empirical_sinkhorn_divergence_warmstart():
+ m, n = 10, 20
+ Xs = np.arange(m).reshape(-1, 1) * 1.0
+ Xt = np.arange(n).reshape(-1, 1) * 1.0
+ M = ot.dist(Xs, Xt)
+
+ # Generate warmstart from dual vectors of unregularized OT
+ a = ot.unif(m)
+ b = ot.unif(n)
+ _, log = ot.lp.emd(a, b, M, log=True)
+ warmstart = (log["u"], log["v"])
+
+ reg = 1
+
+ # Optimal plan with uniform warmstart
+ sd_unif, _ = ot.bregman.empirical_sinkhorn_divergence(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None)
+ # Optimal plan with warmstart generated from unregularized OT
+ sd_ws_lazy, _ = ot.bregman.empirical_sinkhorn_divergence(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart)
+ sd_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn_divergence(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart)
+
+ np.testing.assert_allclose(sd_unif, sd_ws_lazy, atol=1e-05)
+ np.testing.assert_allclose(sd_unif, sd_ws_not_lazy, atol=1e-05)
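
Note: the three warmstart tests above share one recipe: solve the unregularized problem once with ot.lp.emd, reuse its dual potentials as a warm start for the entropic solvers, and check that the resulting plan (or divergence) matches the one obtained with warmstart=None. A condensed sketch of that recipe (the regularization value is arbitrary):

import numpy as np
import ot

m, n = 10, 20
a, b = ot.unif(m), ot.unif(n)
M = ot.dist(np.arange(m, dtype=float).reshape(-1, 1),
            np.arange(n, dtype=float).reshape(-1, 1))

_, log = ot.lp.emd(a, b, M, log=True)        # exact OT also returns the dual potentials
warmstart = (log["u"], log["v"])             # reused to initialize the Sinkhorn scalings

pi_cold, _ = ot.bregman.sinkhorn(a, b, M, reg=1, log=True, warmstart=None)
pi_warm, _ = ot.bregman.sinkhorn(a, b, M, reg=1, log=True, warmstart=warmstart)
np.testing.assert_allclose(pi_cold, pi_warm, atol=1e-5)
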
diff --git a/test/test_coot.py b/test/test_coot.py
new file mode 100644
index 0000000..ef68a9b
--- /dev/null
+++ b/test/test_coot.py
@@ -0,0 +1,359 @@
+"""Tests for module COOT on OT """
+
+# Author: Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+from ot.coot import co_optimal_transport as coot
+from ot.coot import co_optimal_transport2 as coot2
+import pytest
+
+
+@pytest.mark.parametrize("verbose", [False, True, 1, 0])
+def test_coot(nx, verbose):
+ n_samples = 60 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=4)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ # test couplings
+ pi_sample, pi_feature = coot(X=xs, Y=xt, verbose=verbose)
+ pi_sample_nx, pi_feature_nx = coot(X=xs_nx, Y=xt_nx, verbose=verbose)
+ pi_sample_nx = nx.to_numpy(pi_sample_nx)
+ pi_feature_nx = nx.to_numpy(pi_feature_nx)
+
+ anti_id_sample = np.flipud(np.eye(n_samples, n_samples)) / n_samples
+ id_feature = np.eye(2, 2) / 2
+
+ np.testing.assert_allclose(pi_sample, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_sample_nx, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_feature, id_feature, atol=1e-04)
+ np.testing.assert_allclose(pi_feature_nx, id_feature, atol=1e-04)
+
+ # test marginal distributions
+ px_s, px_f = ot.unif(n_samples), ot.unif(2)
+ py_s, py_f = ot.unif(n_samples), ot.unif(2)
+
+ np.testing.assert_allclose(px_s, pi_sample_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample_nx.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature_nx.sum(1), atol=1e-04)
+
+ np.testing.assert_allclose(px_s, pi_sample.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature.sum(1), atol=1e-04)
+
+ # test COOT distance
+
+ coot_np = coot2(X=xs, Y=xt, verbose=verbose)
+ coot_nx = nx.to_numpy(coot2(X=xs_nx, Y=xt_nx, verbose=verbose))
+ np.testing.assert_allclose(coot_np, 0, atol=1e-08)
+ np.testing.assert_allclose(coot_nx, 0, atol=1e-08)
+
+
+def test_entropic_coot(nx):
+ n_samples = 60 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=4)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ epsilon = (1, 1e-1)
+ nits_ot = 2000
+
+ # test couplings
+ pi_sample, pi_feature = coot(X=xs, Y=xt, epsilon=epsilon, nits_ot=nits_ot)
+ pi_sample_nx, pi_feature_nx = coot(
+ X=xs_nx, Y=xt_nx, epsilon=epsilon, nits_ot=nits_ot)
+ pi_sample_nx = nx.to_numpy(pi_sample_nx)
+ pi_feature_nx = nx.to_numpy(pi_feature_nx)
+
+ np.testing.assert_allclose(pi_sample, pi_sample_nx, atol=1e-04)
+ np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-04)
+
+ # test marginal distributions
+ px_s, px_f = ot.unif(n_samples), ot.unif(2)
+ py_s, py_f = ot.unif(n_samples), ot.unif(2)
+
+ np.testing.assert_allclose(px_s, pi_sample_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample_nx.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature_nx.sum(1), atol=1e-04)
+
+ np.testing.assert_allclose(px_s, pi_sample.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature.sum(1), atol=1e-04)
+
+ # test entropic COOT distance
+
+ coot_np = coot2(X=xs, Y=xt, epsilon=epsilon, nits_ot=nits_ot)
+ coot_nx = nx.to_numpy(
+ coot2(X=xs_nx, Y=xt_nx, epsilon=epsilon, nits_ot=nits_ot))
+
+ np.testing.assert_allclose(coot_np, coot_nx, atol=1e-08)
+
+
+def test_coot_with_linear_terms(nx):
+ n_samples = 60 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=4)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ M_samp = np.ones((n_samples, n_samples))
+ np.fill_diagonal(np.fliplr(M_samp), 0)
+ M_feat = np.ones((2, 2))
+ np.fill_diagonal(M_feat, 0)
+ M_samp_nx, M_feat_nx = nx.from_numpy(M_samp), nx.from_numpy(M_feat)
+
+ alpha = (1, 2)
+
+ # test couplings
+ anti_id_sample = np.flipud(np.eye(n_samples, n_samples)) / n_samples
+ id_feature = np.eye(2, 2) / 2
+
+ pi_sample, pi_feature = coot(
+ X=xs, Y=xt, alpha=alpha, M_samp=M_samp, M_feat=M_feat)
+ pi_sample_nx, pi_feature_nx = coot(
+ X=xs_nx, Y=xt_nx, alpha=alpha, M_samp=M_samp_nx, M_feat=M_feat_nx)
+ pi_sample_nx = nx.to_numpy(pi_sample_nx)
+ pi_feature_nx = nx.to_numpy(pi_feature_nx)
+
+ np.testing.assert_allclose(pi_sample, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_sample_nx, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_feature, id_feature, atol=1e-04)
+ np.testing.assert_allclose(pi_feature_nx, id_feature, atol=1e-04)
+
+ # test marginal distributions
+ px_s, px_f = ot.unif(n_samples), ot.unif(2)
+ py_s, py_f = ot.unif(n_samples), ot.unif(2)
+
+ np.testing.assert_allclose(px_s, pi_sample_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample_nx.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature_nx.sum(1), atol=1e-04)
+
+ np.testing.assert_allclose(px_s, pi_sample.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature.sum(1), atol=1e-04)
+
+ # test COOT distance
+
+ coot_np = coot2(X=xs, Y=xt, alpha=alpha, M_samp=M_samp, M_feat=M_feat)
+ coot_nx = nx.to_numpy(
+ coot2(X=xs_nx, Y=xt_nx, alpha=alpha, M_samp=M_samp_nx, M_feat=M_feat_nx))
+ np.testing.assert_allclose(coot_np, 0, atol=1e-08)
+ np.testing.assert_allclose(coot_nx, 0, atol=1e-08)
+
+
+def test_coot_raise_value_error(nx):
+ n_samples = 80 # nb samples
+
+ mu_s = np.array([2, 4])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=43)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ # raise value error of method sinkhorn
+ def coot_sh(method_sinkhorn):
+ return coot(X=xs, Y=xt, method_sinkhorn=method_sinkhorn)
+
+ def coot_sh_nx(method_sinkhorn):
+ return coot(X=xs_nx, Y=xt_nx, method_sinkhorn=method_sinkhorn)
+
+ np.testing.assert_raises(ValueError, coot_sh, "not_sinkhorn")
+ np.testing.assert_raises(ValueError, coot_sh_nx, "not_sinkhorn")
+
+ # raise value error for epsilon
+ def coot_eps(epsilon):
+ return coot(X=xs, Y=xt, epsilon=epsilon)
+
+ def coot_eps_nx(epsilon):
+ return coot(X=xs_nx, Y=xt_nx, epsilon=epsilon)
+
+ np.testing.assert_raises(ValueError, coot_eps, (1, 2, 3))
+ np.testing.assert_raises(ValueError, coot_eps_nx, [1, 2, 3, 4])
+
+ # raise value error for alpha
+ def coot_alpha(alpha):
+ return coot(X=xs, Y=xt, alpha=alpha)
+
+ def coot_alpha_nx(alpha):
+ return coot(X=xs_nx, Y=xt_nx, alpha=alpha)
+
+ np.testing.assert_raises(ValueError, coot_alpha, [1])
+ np.testing.assert_raises(ValueError, coot_alpha_nx, np.arange(4))
+
+
+def test_coot_warmstart(nx):
+ n_samples = 80 # nb samples
+
+ mu_s = np.array([2, 3])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=125)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ # initialize warmstart
+ init_pi_sample = np.random.rand(n_samples, n_samples)
+ init_pi_sample = init_pi_sample / np.sum(init_pi_sample)
+ init_pi_sample_nx = nx.from_numpy(init_pi_sample)
+
+ init_pi_feature = np.random.rand(2, 2)
+    init_pi_feature = init_pi_feature / np.sum(init_pi_feature)
+ init_pi_feature_nx = nx.from_numpy(init_pi_feature)
+
+ init_duals_sample = (np.random.random(n_samples) * 2 - 1,
+ np.random.random(n_samples) * 2 - 1)
+ init_duals_sample_nx = (nx.from_numpy(init_duals_sample[0]),
+ nx.from_numpy(init_duals_sample[1]))
+
+ init_duals_feature = (np.random.random(2) * 2 - 1,
+ np.random.random(2) * 2 - 1)
+ init_duals_feature_nx = (nx.from_numpy(init_duals_feature[0]),
+ nx.from_numpy(init_duals_feature[1]))
+
+ warmstart = {
+ "pi_sample": init_pi_sample,
+ "pi_feature": init_pi_feature,
+ "duals_sample": init_duals_sample,
+ "duals_feature": init_duals_feature
+ }
+
+ warmstart_nx = {
+ "pi_sample": init_pi_sample_nx,
+ "pi_feature": init_pi_feature_nx,
+ "duals_sample": init_duals_sample_nx,
+ "duals_feature": init_duals_feature_nx
+ }
+
+ # test couplings
+ pi_sample, pi_feature = coot(X=xs, Y=xt, warmstart=warmstart)
+ pi_sample_nx, pi_feature_nx = coot(
+ X=xs_nx, Y=xt_nx, warmstart=warmstart_nx)
+ pi_sample_nx = nx.to_numpy(pi_sample_nx)
+ pi_feature_nx = nx.to_numpy(pi_feature_nx)
+
+ anti_id_sample = np.flipud(np.eye(n_samples, n_samples)) / n_samples
+ id_feature = np.eye(2, 2) / 2
+
+ np.testing.assert_allclose(pi_sample, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_sample_nx, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_feature, id_feature, atol=1e-04)
+ np.testing.assert_allclose(pi_feature_nx, id_feature, atol=1e-04)
+
+ # test marginal distributions
+ px_s, px_f = ot.unif(n_samples), ot.unif(2)
+ py_s, py_f = ot.unif(n_samples), ot.unif(2)
+
+ np.testing.assert_allclose(px_s, pi_sample_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample_nx.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature_nx.sum(1), atol=1e-04)
+
+ np.testing.assert_allclose(px_s, pi_sample.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature.sum(1), atol=1e-04)
+
+ # test COOT distance
+ coot_np = coot2(X=xs, Y=xt, warmstart=warmstart)
+ coot_nx = nx.to_numpy(coot2(X=xs_nx, Y=xt_nx, warmstart=warmstart_nx))
+ np.testing.assert_allclose(coot_np, 0, atol=1e-08)
+ np.testing.assert_allclose(coot_nx, 0, atol=1e-08)
+
+
+def test_coot_log(nx):
+ n_samples = 90 # nb samples
+
+ mu_s = np.array([-2, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=43)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ pi_sample, pi_feature, log = coot(X=xs, Y=xt, log=True)
+ pi_sample_nx, pi_feature_nx, log_nx = coot(X=xs_nx, Y=xt_nx, log=True)
+
+ duals_sample, duals_feature = log["duals_sample"], log["duals_feature"]
+ assert len(duals_sample) == 2
+ assert len(duals_feature) == 2
+ assert len(duals_sample[0]) == n_samples
+ assert len(duals_sample[1]) == n_samples
+ assert len(duals_feature[0]) == 2
+ assert len(duals_feature[1]) == 2
+
+ duals_sample_nx = log_nx["duals_sample"]
+ assert len(duals_sample_nx) == 2
+ assert len(duals_sample_nx[0]) == n_samples
+ assert len(duals_sample_nx[1]) == n_samples
+
+ duals_feature_nx = log_nx["duals_feature"]
+ assert len(duals_feature_nx) == 2
+ assert len(duals_feature_nx[0]) == 2
+ assert len(duals_feature_nx[1]) == 2
+
+ list_coot = log["distances"]
+ assert len(list_coot) >= 1
+
+ list_coot_nx = log_nx["distances"]
+ assert len(list_coot_nx) >= 1
+
+ # test with coot distance
+ coot_np, log = coot2(X=xs, Y=xt, log=True)
+ coot_nx, log_nx = coot2(X=xs_nx, Y=xt_nx, log=True)
+
+ duals_sample, duals_feature = log["duals_sample"], log["duals_feature"]
+ assert len(duals_sample) == 2
+ assert len(duals_feature) == 2
+ assert len(duals_sample[0]) == n_samples
+ assert len(duals_sample[1]) == n_samples
+ assert len(duals_feature[0]) == 2
+ assert len(duals_feature[1]) == 2
+
+ duals_sample_nx = log_nx["duals_sample"]
+ assert len(duals_sample_nx) == 2
+ assert len(duals_sample_nx[0]) == n_samples
+ assert len(duals_sample_nx[1]) == n_samples
+
+ duals_feature_nx = log_nx["duals_feature"]
+ assert len(duals_feature_nx) == 2
+ assert len(duals_feature_nx[0]) == 2
+ assert len(duals_feature_nx[1]) == 2
+
+ list_coot = log["distances"]
+ assert len(list_coot) >= 1
+
+ list_coot_nx = log_nx["distances"]
+ assert len(list_coot_nx) >= 1
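
Note: for context on what the file above checks, CO-Optimal Transport exposes two entry points: co_optimal_transport returns a coupling between samples and a coupling between features, while co_optimal_transport2 returns the COOT objective value. A minimal sketch using the same kind of data as the tests (a point cloud coupled with a reversed copy of itself, so the optimal value is close to 0):

import numpy as np
import ot
from ot.coot import co_optimal_transport as coot
from ot.coot import co_optimal_transport2 as coot2

xs = ot.datasets.make_2D_samples_gauss(30, np.array([0, 0]), np.eye(2), random_state=0)
xt = xs[::-1].copy()                            # same samples, reversed order

pi_sample, pi_feature = coot(X=xs, Y=xt)        # (30, 30) sample coupling, (2, 2) feature coupling
value = coot2(X=xs, Y=xt)                       # scalar objective, close to 0 here
print(pi_sample.shape, pi_feature.shape, value)
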
diff --git a/test/test_da.py b/test/test_da.py
index 4bf0ab1..c5f08d6 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -44,12 +44,32 @@ def test_class_jax_tf():
@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
+@pytest.mark.parametrize("class_to_test", [ot.da.EMDTransport, ot.da.SinkhornTransport, ot.da.SinkhornLpl1Transport, ot.da.SinkhornL1l2Transport, ot.da.SinkhornL1l2Transport])
+def test_log_da(nx, class_to_test):
+
+ ns = 50
+ nt = 50
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
+ otda = class_to_test(log=True)
+
+    # check that the log is computed
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert hasattr(otda, "log_")
+
+
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
def test_sinkhorn_lpl1_transport_class(nx):
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -136,7 +156,7 @@ def test_sinkhorn_l1l2_transport_class(nx):
"""
ns = 50
- nt = 100
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -230,8 +250,8 @@ def test_sinkhorn_transport_class(nx):
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -323,8 +343,8 @@ def test_unbalanced_sinkhorn_transport_class(nx):
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -402,8 +422,8 @@ def test_emd_transport_class(nx):
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -557,30 +577,9 @@ def test_mapping_transport_class_specific_seed(nx):
@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
-def test_linear_mapping(nx):
- ns = 150
- nt = 200
-
- Xs, ys = make_data_classif('3gauss', ns)
- Xt, yt = make_data_classif('3gauss2', nt)
-
- Xsb, Xtb = nx.from_numpy(Xs, Xt)
-
- A, b = ot.da.OT_mapping_linear(Xsb, Xtb)
-
- Xst = nx.to_numpy(nx.dot(Xsb, A) + b)
-
- Ct = np.cov(Xt.T)
- Cst = np.cov(Xst.T)
-
- np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
-
-
-@pytest.skip_backend("jax")
-@pytest.skip_backend("tf")
def test_linear_mapping_class(nx):
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -609,9 +608,9 @@ def test_jcpot_transport_class(nx):
"""test_jcpot_transport
"""
- ns1 = 150
- ns2 = 150
- nt = 200
+ ns1 = 50
+ ns2 = 50
+ nt = 50
Xs1, ys1 = make_data_classif('3gauss', ns1)
Xs2, ys2 = make_data_classif('3gauss', ns2)
@@ -681,9 +680,9 @@ def test_jcpot_barycenter(nx):
"""test_jcpot_barycenter
"""
- ns1 = 150
- ns2 = 150
- nt = 200
+ ns1 = 50
+ ns2 = 50
+ nt = 50
sigma = 0.1
np.random.seed(1985)
@@ -713,8 +712,8 @@ def test_jcpot_barycenter(nx):
def test_emd_laplace_class(nx):
"""test_emd_laplace_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
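
Note: test_log_da above relies on the estimator-style API shared by the ot.da transport classes: build the object with log=True, call fit with source/target samples, and the fitted object exposes a log_ attribute. A minimal sketch (the final transform call is the library's usual follow-up step, not something this particular test asserts):

import ot
from ot.datasets import make_data_classif

Xs, ys = make_data_classif('3gauss', 50)
Xt, yt = make_data_classif('3gauss2', 50)

otda = ot.da.SinkhornTransport(log=True)     # any class from the parametrize list above
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
assert hasattr(otda, "log_")                 # what test_log_da checks

Xs_mapped = otda.transform(Xs=Xs)            # map source samples onto the target domain
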
diff --git a/test/test_gaussian.py b/test/test_gaussian.py
new file mode 100644
index 0000000..be7a806
--- /dev/null
+++ b/test/test_gaussian.py
@@ -0,0 +1,98 @@
+"""Tests for module gaussian"""
+
+# Author: Theo Gnassounou <theo.gnassounou@inria.fr>
+#         Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+import numpy as np
+
+import pytest
+
+import ot
+from ot.datasets import make_data_classif
+
+
+def test_bures_wasserstein_mapping(nx):
+ ns = 50
+ nt = 50
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+ ms = np.mean(Xs, axis=0)[None, :]
+ mt = np.mean(Xt, axis=0)[None, :]
+ Cs = np.cov(Xs.T)
+ Ct = np.cov(Xt.T)
+
+ Xsb, msb, mtb, Csb, Ctb = nx.from_numpy(Xs, ms, mt, Cs, Ct)
+
+ A_log, b_log, log = ot.gaussian.bures_wasserstein_mapping(msb, mtb, Csb, Ctb, log=True)
+ A, b = ot.gaussian.bures_wasserstein_mapping(msb, mtb, Csb, Ctb, log=False)
+
+ Xst = nx.to_numpy(nx.dot(Xsb, A) + b)
+ Xst_log = nx.to_numpy(nx.dot(Xsb, A_log) + b_log)
+
+ Cst = np.cov(Xst.T)
+ Cst_log = np.cov(Xst_log.T)
+
+ np.testing.assert_allclose(Cst_log, Cst, rtol=1e-2, atol=1e-2)
+ np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+
+
+@pytest.mark.parametrize("bias", [True, False])
+def test_empirical_bures_wasserstein_mapping(nx, bias):
+ ns = 50
+ nt = 50
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ if not bias:
+ ms = np.mean(Xs, axis=0)[None, :]
+ mt = np.mean(Xt, axis=0)[None, :]
+
+ Xs = Xs - ms
+ Xt = Xt - mt
+
+ Xsb, Xtb = nx.from_numpy(Xs, Xt)
+
+ A, b, log = ot.gaussian.empirical_bures_wasserstein_mapping(Xsb, Xtb, log=True, bias=bias)
+ A_log, b_log = ot.gaussian.empirical_bures_wasserstein_mapping(Xsb, Xtb, log=False, bias=bias)
+
+ Xst = nx.to_numpy(nx.dot(Xsb, A) + b)
+ Xst_log = nx.to_numpy(nx.dot(Xsb, A_log) + b_log)
+
+ Ct = np.cov(Xt.T)
+ Cst = np.cov(Xst.T)
+ Cst_log = np.cov(Xst_log.T)
+
+ np.testing.assert_allclose(Cst_log, Cst, rtol=1e-2, atol=1e-2)
+ np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+
+
+def test_bures_wasserstein_distance(nx):
+ ms, mt = np.array([0]), np.array([10])
+ Cs, Ct = np.array([[1]]).astype(np.float32), np.array([[1]]).astype(np.float32)
+ msb, mtb, Csb, Ctb = nx.from_numpy(ms, mt, Cs, Ct)
+ Wb_log, log = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=True)
+ Wb = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=False)
+
+ np.testing.assert_allclose(nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)
+ np.testing.assert_allclose(10, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)
+
+
+@pytest.mark.parametrize("bias", [True, False])
+def test_empirical_bures_wasserstein_distance(nx, bias):
+ ns = 400
+ nt = 400
+
+ rng = np.random.RandomState(10)
+ Xs = rng.normal(0, 1, ns)[:, np.newaxis]
+ Xt = rng.normal(10 * bias, 1, nt)[:, np.newaxis]
+
+ Xsb, Xtb = nx.from_numpy(Xs, Xt)
+ Wb_log, log = ot.gaussian.empirical_bures_wasserstein_distance(Xsb, Xtb, log=True, bias=bias)
+ Wb = ot.gaussian.empirical_bures_wasserstein_distance(Xsb, Xtb, log=False, bias=bias)
+
+ np.testing.assert_allclose(nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)
+ np.testing.assert_allclose(10 * bias, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)
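
Note: the new ot.gaussian tests rely on the closed-form OT between Gaussians: bures_wasserstein_mapping returns the affine map (A, b) pushing N(ms, Cs) onto N(mt, Ct), and empirical_bures_wasserstein_distance estimates the Bures-Wasserstein distance directly from samples. A small sketch of the 1D setting used in test_empirical_bures_wasserstein_distance (two unit-variance Gaussians with means 0 and 10, so the distance should be close to 10):

import numpy as np
import ot

rng = np.random.RandomState(0)
Xs = rng.normal(0, 1, 400)[:, None]
Xt = rng.normal(10, 1, 400)[:, None]

W = ot.gaussian.empirical_bures_wasserstein_distance(Xs, Xt, bias=True, log=False)

# closed-form map between the Gaussians fitted to the two samples
ms, mt = Xs.mean(axis=0, keepdims=True), Xt.mean(axis=0, keepdims=True)
Cs, Ct = np.cov(Xs.T).reshape(1, 1), np.cov(Xt.T).reshape(1, 1)
A, b = ot.gaussian.bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False)

print(float(W), (Xs.dot(A) + b).mean())      # ~10 and ~10 respectively
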
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 9c85b92..80b6df4 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -3,7 +3,7 @@
# Author: Erwan Vautier <erwan.vautier@gmail.com>
# Nicolas Courty <ncourty@irisa.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
-# Cédric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr>
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
#
# License: MIT License
@@ -11,18 +11,15 @@ import numpy as np
import ot
from ot.backend import NumpyBackend
from ot.backend import torch, tf
-
import pytest
def test_gromov(nx):
n_samples = 50 # nb samples
-
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
- xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
-
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1)
xt = xs[::-1].copy()
p = ot.unif(n_samples)
@@ -38,7 +35,7 @@ def test_gromov(nx):
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, verbose=True)
- Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True))
+ Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=G0b, verbose=True))
# check constraints
np.testing.assert_allclose(G, Gb, atol=1e-06)
@@ -51,13 +48,13 @@ def test_gromov(nx):
np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04)
- gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
- gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True)
+ gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, log=True)
+ gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, log=True)
gwb = nx.to_numpy(gwb)
- gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', G0=G0, log=False)
+ gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, G0=G0, log=False)
gw_valb = nx.to_numpy(
- ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
+ ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False)
)
G = log['T']
@@ -77,6 +74,49 @@ def test_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+def test_asymmetric_gromov(nx):
+ n_samples = 30 # nb samples
+ np.random.seed(0)
+ C1 = np.random.uniform(low=0., high=10, size=(n_samples, n_samples))
+ idx = np.arange(n_samples)
+ np.random.shuffle(idx)
+ C2 = C1[idx, :][:, idx]
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
+
+ G, log = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, log=True, symmetric=False, verbose=True)
+ Gb, logb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', log=True, symmetric=False, G0=G0b, verbose=True)
+ Gb = nx.to_numpy(Gb)
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['gw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04)
+
+ gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss', G0=G0, log=True, symmetric=False, verbose=True)
+ gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'square_loss', log=True, symmetric=False, G0=G0b, verbose=True)
+
+ G = log['T']
+ Gb = nx.to_numpy(logb['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['gw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04)
+
+
def test_gromov_dtype_device(nx):
# setup
n_samples = 50 # nb samples
@@ -104,7 +144,7 @@ def test_gromov_dtype_device(nx):
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0, type_as=tp)
Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False)
nx.assert_same_dtype_device(C1b, Gb)
nx.assert_same_dtype_device(C1b, gw_valb)
@@ -130,7 +170,7 @@ def test_gromov_device_tf():
with tf.device("/CPU:0"):
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False)
nx.assert_same_dtype_device(C1b, Gb)
nx.assert_same_dtype_device(C1b, gw_valb)
@@ -138,7 +178,7 @@ def test_gromov_device_tf():
# Check that everything happens on the GPU
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, log=False)
nx.assert_same_dtype_device(C1b, Gb)
nx.assert_same_dtype_device(C1b, gw_valb)
assert nx.dtype_device(Gb)[1].startswith("GPU")
@@ -174,6 +214,7 @@ def test_gromov2_gradients():
C11 = torch.tensor(C1, requires_grad=True, device=device)
C12 = torch.tensor(C2, requires_grad=True, device=device)
+ # Test with exact line-search
val = ot.gromov_wasserstein2(C11, C12, p1, q1)
val.backward()
@@ -184,6 +225,60 @@ def test_gromov2_gradients():
assert C11.shape == C11.grad.shape
assert C12.shape == C12.grad.shape
+ # Test with armijo line-search
+ q1.grad = None
+ p1.grad = None
+ C11.grad = None
+ C12.grad = None
+ val = ot.gromov_wasserstein2(C11, C12, p1, q1, armijo=True)
+
+ val.backward()
+
+ assert val.device == p1.device
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+
+
+def test_gw_helper_backend(nx):
+ n_samples = 20 # nb samples
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
+ Gb, logb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', armijo=False, symmetric=True, G0=G0b, log=True)
+
+ # calls with nx=None
+ constCb, hC1b, hC2b = ot.gromov.init_matrix(C1b, C2b, pb, qb, loss_fun='square_loss')
+
+ def f(G):
+ return ot.gromov.gwloss(constCb, hC1b, hC2b, G, None)
+
+ def df(G):
+ return ot.gromov.gwggrad(constCb, hC1b, hC2b, G, None)
+
+ def line_search(cost, G, deltaG, Mi, cost_G):
+ return ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1b, C2b, M=0., reg=1., nx=None)
+ # feed the precomputed local optimum Gb to cg
+ res, log = ot.optim.cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
+ # check constraints
+ np.testing.assert_allclose(res, Gb, atol=1e-06)
+
@pytest.skip_backend("jax", reason="test very slow with jax backend")
@pytest.skip_backend("tf", reason="test very slow with tf backend")
@@ -199,19 +294,21 @@ def test_entropic_gromov(nx):
p = ot.unif(n_samples)
q = ot.unif(n_samples)
-
+ G0 = p[:, None] * q[None, :]
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)
C1 /= C1.max()
C2 /= C2.max()
- C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q)
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
- G = ot.gromov.entropic_gromov_wasserstein(
- C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True)
+ G, log = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', symmetric=None, G0=G0,
+ epsilon=1e-2, verbose=True, log=True)
Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein(
- C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+ C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None,
+ epsilon=1e-2, verbose=True, log=False
))
# check constraints
@@ -222,9 +319,11 @@ def test_entropic_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
gw, log = ot.gromov.entropic_gromov_wasserstein2(
- C1, C2, p, q, 'kl_loss', max_iter=10, epsilon=1e-2, log=True)
+ C1, C2, p, q, 'kl_loss', symmetric=True, G0=None,
+ max_iter=10, epsilon=1e-2, log=True)
gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
- C1b, C2b, pb, qb, 'kl_loss', max_iter=10, epsilon=1e-2, log=True)
+ C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b,
+ max_iter=10, epsilon=1e-2, log=True)
gwb = nx.to_numpy(gwb)
G = log['T']
@@ -241,6 +340,45 @@ def test_entropic_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+def test_asymmetric_entropic_gromov(nx):
+ n_samples = 10 # nb samples
+ np.random.seed(0)
+ C1 = np.random.uniform(low=0., high=10, size=(n_samples, n_samples))
+ idx = np.arange(n_samples)
+ np.random.shuffle(idx)
+ C2 = C1[idx, :][:, idx]
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
+ G = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', symmetric=None, G0=G0,
+ epsilon=1e-1, verbose=True, log=False)
+ Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein(
+ C1b, C2b, pb, qb, 'square_loss', symmetric=False, G0=None,
+ epsilon=1e-1, verbose=True, log=False
+ ))
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ gw = ot.gromov.entropic_gromov_wasserstein2(
+ C1, C2, p, q, 'kl_loss', symmetric=False, G0=None,
+ max_iter=10, epsilon=1e-1, log=False)
+ gwb = ot.gromov.entropic_gromov_wasserstein2(
+ C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b,
+ max_iter=10, epsilon=1e-1, log=False)
+ gwb = nx.to_numpy(gwb)
+
+ np.testing.assert_allclose(gw, gwb, atol=1e-06)
+ np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
+
+
@pytest.skip_backend("jax", reason="test very slow with jax backend")
@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_entropic_gromov_dtype_device(nx):
@@ -539,8 +677,8 @@ def test_fgw(nx):
Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
- G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True)
- Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, G0=G0b, log=True)
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, armijo=True, symmetric=None, G0=G0, log=True)
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, armijo=True, symmetric=True, G0=G0b, log=True)
Gb = nx.to_numpy(Gb)
# check constraints
@@ -555,8 +693,8 @@ def test_fgw(nx):
np.testing.assert_allclose(
Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov
- fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', G0=None, alpha=0.5, log=True)
- fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', G0=G0b, alpha=0.5, log=True)
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', armijo=True, symmetric=True, G0=None, alpha=0.5, log=True)
+ fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', armijo=True, symmetric=None, G0=G0b, alpha=0.5, log=True)
fgwb = nx.to_numpy(fgwb)
G = log['T']
@@ -573,6 +711,82 @@ def test_fgw(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+def test_asymmetric_fgw(nx):
+ n_samples = 50 # nb samples
+ np.random.seed(0)
+ C1 = np.random.uniform(low=0., high=10, size=(n_samples, n_samples))
+ idx = np.arange(n_samples)
+ np.random.shuffle(idx)
+ C2 = C1[idx, :][:, idx]
+
+ # add features
+ F1 = np.random.uniform(low=0., high=10, size=(n_samples, 1))
+ F2 = F1[idx, :]
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ M = ot.dist(F1, F2)
+ Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
+
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, symmetric=False, verbose=True)
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, symmetric=None, G0=G0b, verbose=True)
+ Gb = nx.to_numpy(Gb)
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
+
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True)
+ fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True)
+
+ G = log['T']
+ Gb = nx.to_numpy(logb['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
+
+ # Tests with kl-loss:
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=False, verbose=True)
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=None, G0=G0b, verbose=True)
+ Gb = nx.to_numpy(Gb)
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
+
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True)
+ fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True)
+
+ G = log['T']
+ Gb = nx.to_numpy(logb['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
+
+
def test_fgw2_gradients():
n_samples = 20 # nb samples
@@ -617,6 +831,57 @@ def test_fgw2_gradients():
assert M1.shape == M1.grad.shape
+def test_fgw_helper_backend(nx):
+ n_samples = 20 # nb samples
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0)
+ ys = np.random.randn(xs.shape[0], 2)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1)
+ yt = np.random.randn(xt.shape[0], 2)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ M = ot.dist(ys, yt)
+ M /= M.max()
+
+ Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
+ alpha = 0.5
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, armijo=False, symmetric=True, G0=G0b, log=True)
+
+ # calls with nx=None
+ constCb, hC1b, hC2b = ot.gromov.init_matrix(C1b, C2b, pb, qb, loss_fun='square_loss')
+
+ def f(G):
+ return ot.gromov.gwloss(constCb, hC1b, hC2b, G, None)
+
+ def df(G):
+ return ot.gromov.gwggrad(constCb, hC1b, hC2b, G, None)
+
+ def line_search(cost, G, deltaG, Mi, cost_G):
+ return ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1b, C2b, M=(1 - alpha) * Mb, reg=alpha, nx=None)
+ # feed the precomputed local optimum Gb to cg
+ res, log = ot.optim.cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
+
+ def line_search(cost, G, deltaG, Mi, cost_G):
+ return ot.optim.line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=None)
+ # feed the precomputed local optimum Gb to cg
+ res_armijo, log_armijo = ot.optim.cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
+ # check constraints
+ np.testing.assert_allclose(res, Gb, atol=1e-06)
+ np.testing.assert_allclose(res_armijo, Gb, atol=1e-06)
+
+
def test_fgw_barycenter(nx):
np.random.seed(42)
@@ -1186,3 +1451,327 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
# > Compare results with/without backend
total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2)
np.testing.assert_allclose(total_reconstruction_bis2, total_reconstruction_b_bis2, atol=1e-05)
+
+
+def test_semirelaxed_gromov(nx):
+ np.random.seed(0)
+ # unbalanced proportions
+ list_n = [30, 15]
+ nt = 2
+ ns = np.sum(list_n)
+ # create directed sbm with C2 as connectivity matrix
+ C1 = np.zeros((ns, ns), dtype=np.float64)
+ C2 = np.array([[0.8, 0.05],
+ [0.05, 1.]], dtype=np.float64)
+ for i in range(nt):
+ for j in range(nt):
+ ni, nj = list_n[i], list_n[j]
+ xij = np.random.binomial(size=(ni, nj), n=1, p=C2[i, j])
+ C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij
+ p = ot.unif(ns, type_as=C1)
+ q0 = ot.unif(C2.shape[0], type_as=C1)
+ G0 = p[:, None] * q0[None, :]
+ # asymmetric
+ C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0)
+
+ G, log = ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=G0)
+ Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun='square_loss', symmetric=False, log=True, G0=None)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04)
+ np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01)
+ np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01)
+
+ srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=False, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None)
+
+ G = log2['T']
+ Gb = nx.to_numpy(logb2['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07)
+
+ # symmetric
+ C1 = 0.5 * (C1 + C1.T)
+ C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0)
+
+ G, log = ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=None)
+ Gb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun='square_loss', symmetric=True, log=False, G0=G0b)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov
+
+ srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None)
+
+ srgw_ = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=False, G0=G0)
+
+ G = log2['T']
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, 1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01)
+ np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01)
+
+ np.testing.assert_allclose(log2['srgw_dist'], log['srgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(srgw, srgw_, atol=1e-07)
+
+
+def test_semirelaxed_gromov2_gradients():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5)
+
+ p = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ if torch:
+
+ devices = [torch.device("cpu")]
+ if torch.cuda.is_available():
+ devices.append(torch.device("cuda"))
+ for device in devices:
+ # semirelaxed solvers do not support gradients over masses yet.
+ p1 = torch.tensor(p, requires_grad=False, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
+
+ val = ot.gromov.semirelaxed_gromov_wasserstein2(C11, C12, p1)
+
+ val.backward()
+
+ assert val.device == p1.device
+ assert p1.grad is None
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+
+
+def test_srgw_helper_backend(nx):
+ n_samples = 20 # nb samples
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q)
+ Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, 'square_loss', armijo=False, symmetric=True, G0=None, log=True)
+
+ # calls with nx=None
+ constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun='square_loss')
+ ones_pb = nx.ones(pb.shape[0], type_as=pb)
+
+ def f(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb))
+ return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None)
+
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb))
+ return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None)
+
+ def line_search(cost, G, deltaG, Mi, cost_G):
+ return ot.gromov.solve_semirelaxed_gromov_linesearch(
+ G, deltaG, cost_G, C1b, C2b, ones_pb, 0., 1., nx=None)
+ # feed the precomputed local optimum Gb to semirelaxed_cg
+ res, log = ot.optim.semirelaxed_cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
+ # check constraints
+ np.testing.assert_allclose(res, Gb, atol=1e-06)
+
+
+def test_semirelaxed_fgw(nx):
+ np.random.seed(0)
+ list_n = [16, 8]
+ nt = 2
+ ns = 24
+ # create directed sbm with C2 as connectivity matrix
+ C1 = np.zeros((ns, ns))
+ C2 = np.array([[0.7, 0.05],
+ [0.05, 0.9]])
+ for i in range(nt):
+ for j in range(nt):
+ ni, nj = list_n[i], list_n[j]
+ xij = np.random.binomial(size=(ni, nj), n=1, p=C2[i, j])
+ C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij
+ F1 = np.zeros((ns, 1))
+ F1[:16] = np.random.normal(loc=0., scale=0.01, size=(16, 1))
+ F1[16:] = np.random.normal(loc=1., scale=0.01, size=(8, 1))
+ F2 = np.zeros((2, 1))
+ F2[1, :] = 1.
+ M = (F1 ** 2).dot(np.ones((1, nt))) + np.ones((ns, 1)).dot((F2 ** 2).T) - 2 * F1.dot(F2.T)
+
+ p = ot.unif(ns)
+ q0 = ot.unif(C2.shape[0])
+ G0 = p[:, None] * q0[None, :]
+
+ # asymmetric
+ Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0)
+ G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)
+ Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0b)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov
+
+ srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)
+
+ G = log2['T']
+ Gb = nx.to_numpy(logb2['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log2['srfgw_dist'], logb['srfgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07)
+
+ # symmetric
+ C1 = 0.5 * (C1 + C1.T)
+ Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0)
+
+ G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)
+ Gb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=True, log=False, G0=G0b)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov
+
+ srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=True, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)
+
+ srgw_ = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=True, log=False, G0=G0)
+
+ G = log2['T']
+ Gb = nx.to_numpy(logb2['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(srgw, srgw_, atol=1e-07)
+
+
+def test_semirelaxed_fgw2_gradients():
+ n_samples = 20 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5)
+
+ p = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+ M = ot.dist(xs, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ if torch:
+
+ devices = [torch.device("cpu")]
+ if torch.cuda.is_available():
+ devices.append(torch.device("cuda"))
+ for device in devices:
+ # semirelaxed solvers do not support gradients over masses yet.
+ p1 = torch.tensor(p, requires_grad=False, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
+ M1 = torch.tensor(M, requires_grad=True, device=device)
+
+ val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1)
+
+ val.backward()
+
+ assert val.device == p1.device
+ assert p1.grad is None
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+ assert M1.shape == M1.grad.shape
+
+
+def test_srfgw_helper_backend(nx):
+ n_samples = 20 # nb samples
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0)
+ ys = np.random.randn(xs.shape[0], 2)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1)
+ yt = np.random.randn(xt.shape[0], 2)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ M = ot.dist(ys, yt)
+ M /= M.max()
+
+ Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
+ alpha = 0.5
+ Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, 'square_loss', alpha=0.5, armijo=False, symmetric=True, G0=G0b, log=True)
+
+ # calls with nx=None
+ constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun='square_loss')
+ ones_pb = nx.ones(pb.shape[0], type_as=pb)
+
+ def f(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb))
+ return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None)
+
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb))
+ return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None)
+
+ def line_search(cost, G, deltaG, Mi, cost_G):
+ return ot.gromov.solve_semirelaxed_gromov_linesearch(
+ G, deltaG, cost_G, C1b, C2b, ones_pb, M=(1 - alpha) * Mb, reg=alpha, nx=None)
+ # feed the precomputed local optimum Gb to semirelaxed_cg
+ res, log = ot.optim.semirelaxed_cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
+ # check constraints
+ np.testing.assert_allclose(res, Gb, atol=1e-06)
diff --git a/test/test_optim.py b/test/test_optim.py
index 67e9d13..a43e704 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -120,31 +120,33 @@ def test_generalized_conditional_gradient(nx):
Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True)
Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(Gb, G, atol=1e-12)
np.testing.assert_allclose(a, Gb.sum(1), atol=1e-05)
np.testing.assert_allclose(b, Gb.sum(0), atol=1e-05)
def test_solve_1d_linesearch_quad_funct():
- np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1, 0), 0.5)
- np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5, 0), 0)
- np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1), 0.5)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5), 0)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5), 1)
def test_line_search_armijo(nx):
xk = np.array([[0.25, 0.25], [0.25, 0.25]])
pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
- old_fval = -123
+ old_fval = -123.
xkb, pkb, gfkb = nx.from_numpy(xk, pk, gfk)
+ def f(x):
+ return 1.
# Should not throw an exception and return 0. for alpha
alpha, a, b = ot.optim.line_search_armijo(
- lambda x: 1, xkb, pkb, gfkb, old_fval
+ f, xkb, pkb, gfkb, old_fval
)
alpha_np, anp, bnp = ot.optim.line_search_armijo(
- lambda x: 1, xk, pk, gfk, old_fval
+ f, xk, pk, gfk, old_fval
)
assert a == anp
assert b == bnp
@@ -182,3 +184,50 @@ def test_line_search_armijo(nx):
old_fval = f(xk)
alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
np.testing.assert_allclose(alpha, 0.1)
+
+
+def test_line_search_armijo_dtype_device(nx):
+ for tp in nx.__type_list__:
+ def f(x):
+ return nx.sum((x - 5.0) ** 2)
+
+ def grad(x):
+ return 2 * (x - 5.0)
+
+ xk = np.array([[[-5.0, -5.0]]])
+ pk = np.array([[[100.0, 100.0]]])
+ xkb, pkb = nx.from_numpy(xk, pk, type_as=tp)
+ gfkb = grad(xkb)
+ old_fval = f(xkb)
+
+ # check the case where the optimum lies along the search direction
+ alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval)
+ alpha = nx.to_numpy(alpha)
+ np.testing.assert_allclose(alpha, 0.1)
+ nx.assert_same_dtype_device(old_fval, fval)
+
+ # check the case where a full step along the direction does not reach the optimum (alpha stays at alpha0)
+ pk = np.array([[[3.0, 3.0]]])
+ pkb = nx.from_numpy(pk, type_as=tp)
+ alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval, alpha0=1.0)
+ alpha = nx.to_numpy(alpha)
+ np.testing.assert_allclose(alpha, 1.0)
+ nx.assert_same_dtype_device(old_fval, fval)
+
+ # check the case where the search direction is not a descent direction
+ alpha, _, fval = ot.optim.line_search_armijo(f, xkb, -pkb, gfkb, old_fval)
+ alpha = nx.to_numpy(alpha)
+
+ assert alpha <= 0
+ nx.assert_same_dtype_device(old_fval, fval)
+
+ # check the case where the point is not a vector
+ xkb = nx.from_numpy(np.array(-5.0), type_as=tp)
+ pkb = nx.from_numpy(np.array(100), type_as=tp)
+ gfkb = grad(xkb)
+ old_fval = f(xkb)
+ alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval)
+ alpha = nx.to_numpy(alpha)
+
+ np.testing.assert_allclose(alpha, 0.1)
+ nx.assert_same_dtype_device(old_fval, fval)
diff --git a/test/test_ot.py b/test/test_ot.py
index bf832f6..f2338ac 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -29,9 +29,12 @@ def test_emd_dimension_and_mass_mismatch():
np.testing.assert_raises(AssertionError, ot.emd2, a, a, M)
+ # test emd and emd2 for mass mismatch
+ a = ot.utils.unif(n_samples)
b = a.copy()
a[0] = 100
np.testing.assert_raises(AssertionError, ot.emd, a, b, M)
+ np.testing.assert_raises(AssertionError, ot.emd2, a, b, M)
def test_emd_backends(nx):
@@ -201,6 +204,22 @@ def test_emd_emd2():
np.testing.assert_allclose(w, 0)
+def test_omp_emd2():
+ # test that emd2 with and without openmp gives the same result for simple identity
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ w = ot.emd2(u, u, M)
+ w2 = ot.emd2(u, u, M, numThreads=2)
+
+ np.testing.assert_allclose(w, w2)
+
+
def test_emd_empty():
# test emd and emd2 for simple identity
n = 100
@@ -320,6 +339,46 @@ def test_free_support_barycenter_backends(nx):
np.testing.assert_allclose(X, nx.to_numpy(X2))
+def test_generalised_free_support_barycenter():
+ np.random.seed(42) # random inits
+ X = [np.array([-1., -1.]).reshape((1, 2)), np.array([1., 1.]).reshape((1, 2))] # two 2D Diracs whose barycenter is obviously 0
+ a = [np.array([1.]), np.array([1.])]
+
+ P = [np.eye(2), np.eye(2)]
+
+ Y_init = np.array([-12., 7.]).reshape((1, 2))
+
+ # obvious barycenter location between two 2D diracs
+ Y_true = np.array([0., .0]).reshape((1, 2))
+
+ # test without log and no init
+ Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1)
+ np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7)
+
+ # test with log and init
+ Y, _ = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init, b=np.array([1.]), log=True)
+ np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7)
+
+
+def test_generalised_free_support_barycenter_backends(nx):
+ np.random.seed(42)
+ X = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
+ a = [np.array([1.]), np.array([1.])]
+ P = [np.array([1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
+ Y_init = np.array([-12.]).reshape((1, 1))
+
+ Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init)
+
+ X2 = nx.from_numpy(*X)
+ a2 = nx.from_numpy(*a)
+ P2 = nx.from_numpy(*P)
+ Y_init2 = nx.from_numpy(Y_init)
+
+ Y2 = ot.lp.generalized_free_support_barycenter(X2, a2, P2, 1, Y_init=Y_init2)
+
+ np.testing.assert_allclose(Y, nx.to_numpy(Y2))
+
+
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
def test_lp_barycenter_cvxopt():
a1 = np.array([1.0, 0, 0])[:, None]
diff --git a/test/test_partial.py b/test/test_partial.py
index 97c611b..86f9e62 100755
--- a/test/test_partial.py
+++ b/test/test_partial.py
@@ -8,6 +8,7 @@
import numpy as np
import scipy as sp
import ot
+from ot.backend import to_numpy, torch
import pytest
@@ -79,8 +80,10 @@ def test_partial_wasserstein_lagrange():
w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 1, log=True)
+ w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 100, log=True)
-def test_partial_wasserstein():
+
+def test_partial_wasserstein(nx):
n_samples = 20 # nb samples (gaussian)
n_noise = 20 # nb of samples (noise)
@@ -100,25 +103,20 @@ def test_partial_wasserstein():
m = 0.5
+ p, q, M = nx.from_numpy(p, q, M)
+
w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=m, log=True)
- w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m,
- log=True, verbose=True)
+ w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m, log=True, verbose=True)
# check constraints
- np.testing.assert_equal(
- w0.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
- np.testing.assert_equal(
- w0.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
- np.testing.assert_equal(
- w.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
- np.testing.assert_equal(
- w.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q))
+ np.testing.assert_equal(to_numpy(nx.sum(w, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w, axis=0) - q) <= 1e-5, [True] * len(q))
# check transported mass
- np.testing.assert_allclose(
- np.sum(w0), m, atol=1e-04)
- np.testing.assert_allclose(
- np.sum(w), m, atol=1e-04)
+ np.testing.assert_allclose(np.sum(to_numpy(w0)), m, atol=1e-04)
+ np.testing.assert_allclose(np.sum(to_numpy(w)), m, atol=1e-04)
w0, log0 = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True)
w0_val = ot.partial.partial_wasserstein2(p, q, M, m=m, log=False)
@@ -128,15 +126,91 @@ def test_partial_wasserstein():
np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1)
# check constraints
- np.testing.assert_equal(
- G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
- np.testing.assert_equal(
- G.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
- np.testing.assert_allclose(
- np.sum(G), m, atol=1e-04)
+ np.testing.assert_equal(to_numpy(nx.sum(G, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(G, axis=0) - q) <= 1e-5, [True] * len(q))
+ np.testing.assert_allclose(np.sum(to_numpy(G)), m, atol=1e-04)
+
+ empty_array = nx.zeros(0, type_as=M)
+ w = ot.partial.partial_wasserstein(empty_array, empty_array, M=M, m=None)
+
+ # check constraints
+ np.testing.assert_equal(to_numpy(nx.sum(w, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w, axis=0) - q) <= 1e-5, [True] * len(q))
+
+ # check transported mass
+ np.testing.assert_allclose(np.sum(to_numpy(w)), 1, atol=1e-04)
+
+ w0 = ot.partial.entropic_partial_wasserstein(empty_array, empty_array, M=M, reg=10, m=None)
+
+ # check constraints
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q))
+
+ # check transported mass
+ np.testing.assert_allclose(np.sum(to_numpy(w0)), 1, atol=1e-04)
+
+
+def test_partial_wasserstein2_gradient():
+ if torch:
+ n_samples = 40
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+
+ M = torch.tensor(ot.dist(xs, xt), requires_grad=True, dtype=torch.float64)
+
+ p = torch.tensor(ot.unif(n_samples), dtype=torch.float64)
+ q = torch.tensor(ot.unif(n_samples), dtype=torch.float64)
+
+ m = 0.5
+
+ w, log = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True)
+
+ w.backward()
+
+ assert M.grad is not None
+ assert M.grad.shape == M.shape
+
+
+def test_entropic_partial_wasserstein_gradient():
+ if torch:
+ n_samples = 40
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+
+ M = torch.tensor(ot.dist(xs, xt), requires_grad=True, dtype=torch.float64)
+
+ p = torch.tensor(ot.unif(n_samples), requires_grad=True, dtype=torch.float64)
+ q = torch.tensor(ot.unif(n_samples), requires_grad=True, dtype=torch.float64)
+
+ m = 0.5
+ reg = 1
+
+ _, log = ot.partial.entropic_partial_wasserstein(p, q, M, m=m, reg=reg, log=True)
+
+ log['partial_w_dist'].backward()
+
+ assert M.grad is not None
+ assert p.grad is not None
+ assert q.grad is not None
+ assert M.grad.shape == M.shape
+ assert p.grad.shape == p.shape
+ assert q.grad.shape == q.shape
def test_partial_gromov_wasserstein():
+ rng = np.random.RandomState(seed=42)
n_samples = 20 # nb samples
n_noise = 10 # nb of samples (noise)
@@ -149,11 +223,11 @@ def test_partial_gromov_wasserstein():
mu_t = np.array([0, 0, 0])
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
- xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
- xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) + 1) * 4)), axis=0)
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, rng)
+ xs = np.concatenate((xs, ((rng.rand(n_noise, 2) + 1) * 4)), axis=0)
P = sp.linalg.sqrtm(cov_t)
- xt = np.random.randn(n_samples, 3).dot(P) + mu_t
- xt = np.concatenate((xt, ((np.random.rand(n_noise, 3) + 1) * 10)), axis=0)
+ xt = rng.randn(n_samples, 3).dot(P) + mu_t
+ xt = np.concatenate((xt, ((rng.rand(n_noise, 3) + 1) * 10)), axis=0)
xt2 = xs[::-1].copy()
C1 = ot.dist(xs, xs)
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 08ab4fb..f54c799 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -110,6 +110,20 @@ def test_max_sliced_different_dists():
assert res > 0.
+def test_sliced_same_proj():
+ n_projections = 10
+ seed = 12
+ rng = np.random.RandomState(0)
+ X = rng.randn(8, 2)
+ Y = rng.randn(8, 2)
+ cost1, log1 = ot.sliced_wasserstein_distance(X, Y, seed=seed, n_projections=n_projections, log=True)
+ P = get_random_projections(X.shape[1], n_projections=10, seed=seed)
+ cost2, log2 = ot.sliced_wasserstein_distance(X, Y, projections=P, log=True)
+
+ assert np.allclose(log1['projections'], log2['projections'])
+ assert np.isclose(cost1, cost2)
+
+
def test_sliced_backend(nx):
n = 100
@@ -252,3 +266,189 @@ def test_max_sliced_backend_device_tf():
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
assert nx.dtype_device(valb)[1].startswith("GPU")
+
+
+def test_projections_stiefel():
+ rng = np.random.RandomState(0)
+
+ n_projs = 500
+ x = rng.randn(100, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ ssw, log = ot.sliced_wasserstein_sphere(x, x, n_projections=n_projs,
+ seed=rng, log=True)
+
+ P = log["projections"]
+ P_T = np.transpose(P, [0, 2, 1])
+ np.testing.assert_almost_equal(np.matmul(P_T, P), np.array([np.eye(2) for k in range(n_projs)]))
+
+
+def test_sliced_sphere_same_dist():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ u = ot.utils.unif(n)
+
+ res = ot.sliced_wasserstein_sphere(x, x, u, u, 10, seed=rng)
+ np.testing.assert_almost_equal(res, 0.)
+
+
+def test_sliced_sphere_bad_shapes():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ y = rng.randn(n, 4)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng)
+
+
+def test_sliced_sphere_values_on_the_sphere():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ y = rng.randn(n, 4)
+
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng)
+
+
+def test_sliced_sphere_log():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 4)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ y = rng.randn(n, 4)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+ u = ot.utils.unif(n)
+
+ res, log = ot.sliced_wasserstein_sphere(x, y, u, u, 10, p=1, seed=rng, log=True)
+ assert len(log) == 2
+ projections = log["projections"]
+ projected_emds = log["projected_emds"]
+
+ assert projections.shape[0] == len(projected_emds) == 10
+ for emd in projected_emds:
+ assert emd > 0
+
+
+def test_sliced_sphere_different_dists():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ u = ot.utils.unif(n)
+ y = rng.randn(n, 3)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+
+ res = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng)
+ assert res > 0.
+
+
+def test_1d_sliced_sphere_equals_emd():
+ n = 100
+ m = 120
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ x_coords = (np.arctan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi)
+ a = rng.uniform(0, 1, n)
+ a /= a.sum()
+
+ y = rng.randn(m, 2)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+ y_coords = (np.arctan2(-y[:, 1], -y[:, 0]) + np.pi) / (2 * np.pi)
+ u = ot.utils.unif(m)
+
+ res = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=2)
+ expected = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=2)
+
+ res1 = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=1)
+ expected1 = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=1)
+
+ np.testing.assert_almost_equal(res ** 2, expected)
+ np.testing.assert_almost_equal(res1, expected1, decimal=3)
+
+
+@pytest.skip_backend("tf")
+def test_sliced_sphere_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ y = rng.randn(2 * n, 3)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb, yb = nx.from_numpy(x, y, type_as=tp)
+
+ valb = ot.sliced_wasserstein_sphere(xb, yb)
+
+ nx.assert_same_dtype_device(xb, valb)
+
+
+def test_sliced_sphere_unif_values_on_the_sphere():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng)
+
+
+def test_sliced_sphere_unif_log():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 4)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ u = ot.utils.unif(n)
+
+ res, log = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng, log=True)
+ assert len(log) == 2
+ projections = log["projections"]
+ projected_emds = log["projected_emds"]
+
+ assert projections.shape[0] == len(projected_emds) == 10
+ for emd in projected_emds:
+ assert emd > 0
+
+
+def test_sliced_sphere_unif_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb = nx.from_numpy(x, type_as=tp)
+
+ valb = ot.sliced_wasserstein_sphere_unif(xb)
+
+ nx.assert_same_dtype_device(xb, valb)
diff --git a/test/test_solvers.py b/test/test_solvers.py
new file mode 100644
index 0000000..b792aca
--- /dev/null
+++ b/test/test_solvers.py
@@ -0,0 +1,133 @@
+"""Tests for ot solvers"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+
+import itertools
+import numpy as np
+import pytest
+
+import ot
+
+
+lst_reg = [None, 1.0]
+lst_reg_type = ['KL', 'entropy', 'L2']
+lst_unbalanced = [None, 0.9]
+lst_unbalanced_type = ['KL', 'L2', 'TV']
+
+
+def assert_allclose_sol(sol1, sol2):
+
+ lst_attr = ['value', 'value_linear', 'plan',
+ 'potential_a', 'potential_b', 'marginal_a', 'marginal_b']
+
+ nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend()
+ nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend()
+
+ for attr in lst_attr:
+ try:
+ np.testing.assert_allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr)))
+ except NotImplementedError:
+ pass
+
+
+def test_solve(nx):
+ n_samples_s = 10
+ n_samples_t = 7
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples_s, n_features)
+ y = rng.randn(n_samples_t, n_features)
+ a = ot.utils.unif(n_samples_s)
+ b = ot.utils.unif(n_samples_t)
+
+ M = ot.dist(x, y)
+
+ # solve unif weights
+ sol0 = ot.solve(M)
+
+ print(sol0)
+
+ # solve with given weights
+ sol = ot.solve(M, a, b)
+
+ # check some attributes
+ sol.potentials
+ sol.sparse_plan
+ sol.marginals
+ sol.status
+
+ assert_allclose_sol(sol0, sol)
+
+ # solve in backend
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+ solb = ot.solve(Mb, ab, bb)
+
+ assert_allclose_sol(sol, solb)
+
+ # test not implemented unbalanced and check raise
+ with pytest.raises(NotImplementedError):
+ sol0 = ot.solve(M, unbalanced=1, unbalanced_type='cryptic divergence')
+
+ # test not implemented reg_type and check raise
+ with pytest.raises(NotImplementedError):
+ sol0 = ot.solve(M, reg=1, reg_type='cryptic divergence')
+
+
+@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type))
+def test_solve_grid(nx, reg, reg_type, unbalanced, unbalanced_type):
+ n_samples_s = 10
+ n_samples_t = 7
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples_s, n_features)
+ y = rng.randn(n_samples_t, n_features)
+ a = ot.utils.unif(n_samples_s)
+ b = ot.utils.unif(n_samples_t)
+
+ M = ot.dist(x, y)
+
+ try:
+
+ # solve unif weights
+ sol0 = ot.solve(M, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type)
+
+ # solve with given weights
+ sol = ot.solve(M, a, b, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type)
+
+ assert_allclose_sol(sol0, sol)
+
+ # solve in backend
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+ solb = ot.solve(Mb, ab, bb, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type)
+
+ assert_allclose_sol(sol, solb)
+ except NotImplementedError:
+ pass
+
+
+def test_solve_not_implemented(nx):
+
+ n_samples_s = 10
+ n_samples_t = 7
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples_s, n_features)
+ y = rng.randn(n_samples_t, n_features)
+
+ M = ot.dist(x, y)
+
+ # test not implemented and check raise
+ with pytest.raises(NotImplementedError):
+ ot.solve(M, reg=1.0, reg_type='cryptic divergence')
+ with pytest.raises(NotImplementedError):
+ ot.solve(M, unbalanced=1.0, unbalanced_type='cryptic divergence')
+
+ # pairs of incompatible divergences
+ with pytest.raises(NotImplementedError):
+ ot.solve(M, reg=1.0, reg_type='kl', unbalanced=1.0, unbalanced_type='tv')
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index 02b3fc3..b76d738 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -5,6 +5,7 @@
#
# License: MIT License
+import itertools
import numpy as np
import ot
import pytest
@@ -289,32 +290,55 @@ def test_implemented_methods(nx):
method=method)
+@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2']))
+def test_lbfgsb_unbalanced(nx, reg_div, regm_div):
+
+ np.random.seed(42)
+
+ xs = np.random.randn(5, 2)
+ xt = np.random.randn(6, 2)
+
+ M = ot.dist(xs, xt)
+
+ a = ot.unif(5)
+ b = ot.unif(6)
+
+ G, log = ot.unbalanced.lbfgsb_unbalanced(a, b, M, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False)
+
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+
+ Gb, log = ot.unbalanced.lbfgsb_unbalanced(ab, bb, Mb, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False)
+
+ np.testing.assert_allclose(G, nx.to_numpy(Gb))
+
+
def test_mm_convergence(nx):
n = 100
rng = np.random.RandomState(42)
x = rng.randn(n, 2)
rng = np.random.RandomState(75)
y = rng.randn(n, 2)
- a = ot.utils.unif(n)
- b = ot.utils.unif(n)
+ a_np = ot.utils.unif(n)
+ b_np = ot.utils.unif(n)
M = ot.dist(x, y)
M = M / M.max()
reg_m = 100
- a, b, M = nx.from_numpy(a, b, M)
+ a, b, M = nx.from_numpy(a_np, b_np, M)
G_kl, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl',
- verbose=True, log=True)
- loss_kl = nx.to_numpy(ot.unbalanced.mm_unbalanced2(
- a, b, M, reg_m, div='kl', verbose=True))
+ verbose=False, log=True)
+ loss_kl = nx.to_numpy(
+ ot.unbalanced.mm_unbalanced2(a, b, M, reg_m, div='kl', verbose=True)
+ )
G_l2, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2',
verbose=False, log=True)
# check if the marginals come close to the true ones when large reg
- np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a, atol=1e-03)
- np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b, atol=1e-03)
- np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a, atol=1e-03)
- np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a_np, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b_np, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a_np, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b_np, atol=1e-03)
# check if mm_unbalanced2 returns the correct loss
np.testing.assert_allclose(nx.to_numpy(nx.sum(G_kl * M)), loss_kl,
@@ -324,15 +348,16 @@ def test_mm_convergence(nx):
a_np, b_np = np.array([]), np.array([])
a, b = nx.from_numpy(a_np, b_np)
- G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl')
- G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2')
- np.testing.assert_allclose(G_kl_null, G_kl)
- np.testing.assert_allclose(G_l2_null, G_l2)
+ G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', verbose=False)
+ G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', verbose=False)
+ np.testing.assert_allclose(nx.to_numpy(G_kl_null), nx.to_numpy(G_kl))
+ np.testing.assert_allclose(nx.to_numpy(G_l2_null), nx.to_numpy(G_l2))
# test when G0 is given
G0 = ot.emd(a, b, M)
+ G0_np = nx.to_numpy(G0)
reg_m = 10000
- G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0)
- G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0)
- np.testing.assert_allclose(G0, G_kl, atol=1e-05)
- np.testing.assert_allclose(G0, G_l2, atol=1e-05)
+ G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0, verbose=False)
+ G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0, verbose=False)
+ np.testing.assert_allclose(G0_np, nx.to_numpy(G_kl), atol=1e-05)
+ np.testing.assert_allclose(G0_np, nx.to_numpy(G_l2), atol=1e-05)
diff --git a/test/test_utils.py b/test/test_utils.py
index 3cfd295..31b12ef 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -143,6 +143,7 @@ def test_dist():
for metric in metrics_w:
print(metric)
ot.dist(x, x, metric=metric, p=3, w=np.random.random((2, )))
+ ot.dist(x, x, metric=metric, p=3, w=None) # check that not having any weight does not cause issues
for metric in metrics:
print(metric)
ot.dist(x, x, metric=metric, p=3)
@@ -300,3 +301,42 @@ def test_BaseEstimator():
cl.set_params(bibi=10)
assert cl.first == 'spam again'
+
+
+def test_OTResult():
+
+ res = ot.utils.OTResult()
+
+ # test print
+ print(res)
+
+ # test getting the citation string
+ print(res.citation)
+
+ lst_attributes = ['a_to_b',
+ 'b_to_a',
+ 'lazy_plan',
+ 'marginal_a',
+ 'marginal_b',
+ 'marginals',
+ 'plan',
+ 'potential_a',
+ 'potential_b',
+ 'potentials',
+ 'sparse_plan',
+ 'status',
+ 'value',
+ 'value_linear']
+ for at in lst_attributes:
+ with pytest.raises(NotImplementedError):
+ getattr(res, at)
+
+
+def test_get_coordinate_circle():
+
+ u = np.random.rand(1, 100)
+ x1, y1 = np.cos(u * (2 * np.pi)), np.sin(u * (2 * np.pi))
+ x = np.concatenate([x1, y1]).T
+ x_p = ot.utils.get_coordinate_circle(x)
+
+ np.testing.assert_allclose(u[0], x_p)