summary | refs | log | tree | commit | diff
path: root/test
diff options
context:
space:
mode:
author: Rémi Flamary <remi.flamary@gmail.com> 2019-09-09 14:55:04 +0200
committer: Rémi Flamary <remi.flamary@gmail.com> 2019-09-09 14:55:04 +0200
commit: b2a7afb848a78570d01f35f9b239be8838520edc (patch)
tree: fc243208d24f5488d5ce06298b2ebb39b76be9bb /test
parent: c698e0aa20d28e36d25f87082855a490283f3c88 (diff)
parent: f251b4d080a577c2cee890ca43d8ec3658332021 (diff)
merge new unbalanced
Diffstat (limited to 'test')
-rw-r--r-- test/test_bregman.py 97
-rw-r--r-- test/test_da.py 65
-rw-r--r-- test/test_unbalanced.py 163
3 files changed, 274 insertions, 51 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 7f4972c..f70df10 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -7,6 +7,7 @@
import numpy as np
import ot
+import pytest
def test_sinkhorn():
@@ -71,13 +72,11 @@ def test_sinkhorn_variants():
Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10)
Ges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
- Gerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10)
G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10)
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
- np.testing.assert_allclose(G0, Gerr)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
print(G0, G_green)
@@ -96,18 +95,17 @@ def test_sinkhorn_variants_log():
Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
Ges, loges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True)
- Gerr, logerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10, log=True)
G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
- np.testing.assert_allclose(G0, Gerr)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
print(G0, G_green)
-def test_bary():
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_barycenter(method):
n_bins = 100 # nb bins
@@ -126,14 +124,42 @@ def test_bary():
weights = np.array([1 - alpha, alpha])
# wasserstein
- reg = 1e-3
- bary_wass = ot.bregman.barycenter(A, M, reg, weights)
+ reg = 1e-2
+ bary_wass = ot.bregman.barycenter(A, M, reg, weights, method=method)
np.testing.assert_allclose(1, np.sum(bary_wass))
ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
+def test_barycenter_stabilization():
+    """Check that the stabilized barycenter solver matches plain Sinkhorn.
+
+    Two Gaussian distributions built with ``make_1D_gauss`` are averaged
+    with equal weights, once with ``method="sinkhorn_stabilized"`` and once
+    with ``method="sinkhorn"``; both barycenters must agree under
+    ``assert_allclose``.
+    """
+    nb_bins = 100
+
+    # Gaussian distributions (m = mean, s = std), stacked into one matrix
+    gauss_first = ot.datasets.make_1D_gauss(nb_bins, m=30, s=10)
+    gauss_second = ot.datasets.make_1D_gauss(nb_bins, m=40, s=10)
+    hists = np.vstack((gauss_first, gauss_second)).T
+
+    # loss matrix, scaled by its maximum
+    cost = ot.utils.dist0(nb_bins)
+    cost /= cost.max()
+
+    # equal mixing of the two inputs (0 <= mix <= 1)
+    mix = 0.5
+    bary_weights = np.array([1 - mix, mix])
+
+    regul = 1e-2
+    solver_args = (hists, cost, regul, bary_weights)
+    bary_stabilized = ot.bregman.barycenter(
+        *solver_args, method="sinkhorn_stabilized", stopThr=1e-8)
+    bary_plain = ot.bregman.barycenter(
+        *solver_args, method="sinkhorn", stopThr=1e-8)
+
+    np.testing.assert_allclose(bary_plain, bary_stabilized)
+
+
def test_wasserstein_bary_2d():
size = 100 # size of a square image
@@ -254,3 +280,60 @@ def test_empirical_sinkhorn_divergence():
emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
np.testing.assert_allclose(
emp_sinkhorn_div_log, sink_div_log, atol=1e-05) # cf conv emp sinkhorn
+
+
+def test_stabilized_vs_sinkhorn_multidim():
+    """Check stabilized Sinkhorn against plain Sinkhorn for a matrix ``b``.
+
+    One source distribution is matched against two target distributions
+    stacked in a single array; the first value returned by
+    ``ot.bregman.sinkhorn`` must agree between ``"sinkhorn_stabilized"``
+    and ``"sinkhorn"``.
+    """
+    size = 100
+
+    # Gaussian distributions (m = mean, s = std)
+    source = ot.datasets.make_1D_gauss(size, m=20, s=5)
+    target_first = ot.datasets.make_1D_gauss(size, m=60, s=8)
+    target_second = ot.datasets.make_1D_gauss(size, m=30, s=4)
+
+    # matrix containing both target distributions
+    targets = np.vstack((target_first, target_second)).T
+
+    # loss matrix, scaled by its median
+    cost = ot.utils.dist0(size)
+    cost /= np.median(cost)
+
+    regul = 0.1
+    # the log dictionaries are requested but not inspected by this test
+    res_stabilized, _ = ot.bregman.sinkhorn(
+        source, targets, cost, reg=regul,
+        method="sinkhorn_stabilized", log=True)
+    res_plain, _ = ot.bregman.sinkhorn(
+        source, targets, cost, regul, method="sinkhorn", log=True)
+
+    np.testing.assert_allclose(res_stabilized, res_plain)
+
+
+def test_implemented_methods():
+    """Check how the Bregman solvers react to the ``method`` argument.
+
+    * every name in ``IMPLEMENTED_METHODS`` must run through ``sinkhorn``,
+      ``sinkhorn2`` and ``barycenter`` without raising;
+    * every name in ``NOT_VALID_TOKENS`` must raise ``ValueError`` in each
+      of these three functions;
+    * every name in ``ONLY_1D_methods`` must run through ``sinkhorn`` but
+      raise ``ValueError`` in ``sinkhorn2``.
+    """
+    IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
+    ONLY_1D_methods = ['greenkhorn', 'sinkhorn_epsilon_scaling']
+    NOT_VALID_TOKENS = ['foo']
+    n = 3
+    rng = np.random.RandomState(42)
+
+    x = rng.randn(n, 2)
+    # both marginals are uniform over the n points
+    a = ot.utils.unif(n)
+    b = ot.utils.unif(n)
+    # random input matrix for the barycenter calls
+    A = rng.rand(n, 2)
+    M = ot.dist(x, x)
+    epsilon = 1.
+
+    for method in IMPLEMENTED_METHODS:
+        ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
+        ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+        ot.bregman.barycenter(A, M, reg=epsilon, method=method)
+
+    # One ``pytest.raises`` per call: a single context manager around all
+    # three calls exits at the first exception, so the remaining calls
+    # would never be exercised.
+    for method in NOT_VALID_TOKENS:
+        with pytest.raises(ValueError):
+            ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
+        with pytest.raises(ValueError):
+            ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+        with pytest.raises(ValueError):
+            ot.bregman.barycenter(A, M, reg=epsilon, method=method)
+
+    for method in ONLY_1D_methods:
+        ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
+        with pytest.raises(ValueError):
+            ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
diff --git a/test/test_da.py b/test/test_da.py
index f7f3a9d..2a5e50e 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -245,6 +245,71 @@ def test_sinkhorn_transport_class():
assert len(otda.log_.keys()) != 0
+def test_unbalanced_sinkhorn_transport_class():
+    """Exercise ``ot.da.UnbalancedSinkhornTransport`` end to end.
+
+    Fits the estimator on source/target samples, checks that ``cost_``,
+    ``coupling_`` and ``log_`` exist and that cost and coupling have shape
+    ``(n_source, n_target)``, then checks the output shapes of
+    ``transform``, ``inverse_transform`` and ``fit_transform``, including
+    on samples that were not used for fitting.  Finally compares an
+    unsupervised fit with a semi-supervised one and runs with ``log=True``.
+    """
+
+    ns = 150
+    nt = 200
+
+    Xs, ys = make_data_classif('3gauss', ns)
+    Xt, yt = make_data_classif('3gauss2', nt)
+
+    otda = ot.da.UnbalancedSinkhornTransport()
+
+    # test it is computed
+    otda.fit(Xs=Xs, Xt=Xt)
+    assert hasattr(otda, "cost_")
+    assert hasattr(otda, "coupling_")
+    assert hasattr(otda, "log_")
+
+    # test dimensions of coupling
+    assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+    assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+
+    # test transform
+    transp_Xs = otda.transform(Xs=Xs)
+    assert_equal(transp_Xs.shape, Xs.shape)
+
+    Xs_new, _ = make_data_classif('3gauss', ns + 1)
+    transp_Xs_new = otda.transform(Xs_new)
+
+    # check that the out-of-sample (oos) method is working
+    assert_equal(transp_Xs_new.shape, Xs_new.shape)
+
+    # test inverse transform
+    transp_Xt = otda.inverse_transform(Xt=Xt)
+    assert_equal(transp_Xt.shape, Xt.shape)
+
+    Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+    transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
+
+    # check that the out-of-sample (oos) method is working
+    assert_equal(transp_Xt_new.shape, Xt_new.shape)
+
+    # test fit_transform
+    transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt)
+    assert_equal(transp_Xs.shape, Xs.shape)
+
+    # test unsupervised vs semi-supervised mode
+    # NOTE(review): the class under test is used from here on instead of
+    # SinkhornTransport; this assumes UnbalancedSinkhornTransport accepts
+    # labels in fit() and log=True in the same way - confirm.
+    otda_unsup = ot.da.UnbalancedSinkhornTransport()
+    otda_unsup.fit(Xs=Xs, Xt=Xt)
+    n_unsup = np.sum(otda_unsup.cost_)
+
+    otda_semi = ot.da.UnbalancedSinkhornTransport()
+    otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+    assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+    n_semisup = np.sum(otda_semi.cost_)
+
+    # check that the cost matrix norms are indeed different
+    assert n_unsup != n_semisup, "semisupervised mode not working"
+
+    # check everything runs well with log=True
+    otda = ot.da.UnbalancedSinkhornTransport(log=True)
+    otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+    assert len(otda.log_.keys()) != 0
+
+
def test_emd_transport_class():
"""test_sinkhorn_transport
"""
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index 1395fe1..ca1efba 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -7,9 +7,12 @@
import numpy as np
import ot
import pytest
+from ot.unbalanced import barycenter_unbalanced
+from scipy.special import logsumexp
-@pytest.mark.parametrize("method", ["sinkhorn"])
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
def test_unbalanced_convergence(method):
# test generalized sinkhorn for unbalanced OT
n = 100
@@ -23,29 +26,35 @@ def test_unbalanced_convergence(method):
M = ot.dist(x, x)
epsilon = 1.
- alpha = 1.
- K = np.exp(- M / epsilon)
+ reg_m = 1.
- G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha,
- stopThr=1e-10, method=method,
+ G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
+ reg_m=reg_m,
+ method=method,
log=True)
- loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
method=method)
# check fixed point equations
- fi = alpha / (alpha + epsilon)
- v_final = (b / K.T.dot(log["u"])) ** fi
- u_final = (a / K.dot(log["v"])) ** fi
+ # in log-domain
+ fi = reg_m / (reg_m + epsilon)
+ logb = np.log(b + 1e-16)
+ loga = np.log(a + 1e-16)
+ logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1)
+ logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1)
+
+ v_final = fi * (logb - logKtu)
+ u_final = fi * (loga - logKv)
np.testing.assert_allclose(
- u_final, log["u"], atol=1e-05)
+ u_final, log["logu"], atol=1e-05)
np.testing.assert_allclose(
- v_final, log["v"], atol=1e-05)
+ v_final, log["logv"], atol=1e-05)
# check if sinkhorn_unbalanced2 returns the correct loss
np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5)
-@pytest.mark.parametrize("method", ["sinkhorn"])
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
def test_unbalanced_multiple_inputs(method):
# test generalized sinkhorn for unbalanced OT
n = 100
@@ -59,28 +68,59 @@ def test_unbalanced_multiple_inputs(method):
M = ot.dist(x, x)
epsilon = 1.
- alpha = 1.
- K = np.exp(- M / epsilon)
+ reg_m = 1.
loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
- alpha=alpha,
- stopThr=1e-10, method=method,
+ reg_m=reg_m,
+ method=method,
log=True)
# check fixed point equations
- fi = alpha / (alpha + epsilon)
- v_final = (b / K.T.dot(log["u"])) ** fi
-
- u_final = (a[:, None] / K.dot(log["v"])) ** fi
+ # in log-domain
+ fi = reg_m / (reg_m + epsilon)
+ logb = np.log(b + 1e-16)
+ loga = np.log(a + 1e-16)[:, None]
+ logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
+ axis=0)
+ logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ v_final = fi * (logb - logKtu)
+ u_final = fi * (loga - logKv)
np.testing.assert_allclose(
- u_final, log["u"], atol=1e-05)
+ u_final, log["logu"], atol=1e-05)
np.testing.assert_allclose(
- v_final, log["v"], atol=1e-05)
+ v_final, log["logv"], atol=1e-05)
assert len(loss) == b.shape[1]
-def test_unbalanced_barycenter():
+def test_stabilized_vs_sinkhorn():
+ # test if stable version matches sinkhorn
+ n = 100
+
+ # Gaussian distributions
+ a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+ b1 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+ b2 = ot.datasets.make_1D_gauss(n, m=30, s=4)
+
+ # creating matrix A containing all distributions
+ b = np.vstack((b1, b2)).T
+
+ M = ot.utils.dist0(n)
+ M /= np.median(M)
+ epsilon = 0.1
+ reg_m = 1.
+ G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon,
+ method="sinkhorn_stabilized",
+ reg_m=reg_m,
+ log=True)
+ G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
+ method="sinkhorn", log=True)
+
+ np.testing.assert_allclose(G, G2, atol=1e-5)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_unbalanced_barycenter(method):
# test generalized sinkhorn for unbalanced OT barycenter
n = 100
rng = np.random.RandomState(42)
@@ -92,27 +132,56 @@ def test_unbalanced_barycenter():
A = A * np.array([1, 2])[None, :]
M = ot.dist(x, x)
epsilon = 1.
- alpha = 1.
- K = np.exp(- M / epsilon)
+ reg_m = 1.
- q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, alpha=alpha,
- stopThr=1e-10,
- log=True)
+ q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method, log=True)
# check fixed point equations
- fi = alpha / (alpha + epsilon)
- v_final = (q[:, None] / K.T.dot(log["u"])) ** fi
- u_final = (A / K.dot(log["v"])) ** fi
+ fi = reg_m / (reg_m + epsilon)
+ logA = np.log(A + 1e-16)
+ logq = np.log(q + 1e-16)[:, None]
+ logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
+ axis=0)
+ logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ v_final = fi * (logq - logKtu)
+ u_final = fi * (logA - logKv)
np.testing.assert_allclose(
- u_final, log["u"], atol=1e-05)
+ u_final, log["logu"], atol=1e-05)
np.testing.assert_allclose(
- v_final, log["v"], atol=1e-05)
+ v_final, log["logv"], atol=1e-05)
+
+
+def test_barycenter_stabilized_vs_sinkhorn():
+    """Check that both unbalanced barycenter solvers return the same result.
+
+    ``barycenter_unbalanced`` is run with ``method="sinkhorn_stabilized"``
+    (with ``tau=100``) and with ``method="sinkhorn"`` on the same inputs;
+    the two barycenters must agree up to ``atol=1e-05``.
+    """
+    size = 100
+    rng = np.random.RandomState(42)
+
+    points = rng.randn(size, 2)
+    hists = rng.rand(size, 2)
+
+    # make dists unbalanced: the second column is scaled by 4
+    hists = hists * np.array([1, 4])[None, :]
+    cost = ot.dist(points, points)
+    regul = 0.5
+    marg_regul = 10
+
+    # the log dictionaries are requested but not inspected by this test
+    bary_stabilized, _ = barycenter_unbalanced(
+        hists, cost, reg=regul, reg_m=marg_regul, log=True, tau=100,
+        method="sinkhorn_stabilized")
+    bary_plain, _ = barycenter_unbalanced(
+        hists, cost, reg=regul, reg_m=marg_regul, method="sinkhorn",
+        log=True)
+
+    np.testing.assert_allclose(bary_plain, bary_stabilized, atol=1e-05)
def test_implemented_methods():
- IMPLEMENTED_METHODS = ['sinkhorn']
- TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized',
- 'sinkhorn_epsilon_scaling']
+ IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
+ TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling']
NOT_VALID_TOKENS = ['foo']
# test generalized sinkhorn for unbalanced OT barycenter
n = 3
@@ -123,24 +192,30 @@ def test_implemented_methods():
# make dists unbalanced
b = ot.utils.unif(n) * 1.5
-
+ A = rng.rand(n, 2)
M = ot.dist(x, x)
epsilon = 1.
- alpha = 1.
+ reg_m = 1.
for method in IMPLEMENTED_METHODS:
- ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
method=method)
- ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
method=method)
+ barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method)
with pytest.warns(UserWarning, match='not implemented'):
for method in set(TO_BE_IMPLEMENTED_METHODS):
- ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
method=method)
- ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
method=method)
+ barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method)
with pytest.raises(ValueError):
for method in set(NOT_VALID_TOKENS):
- ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
method=method)
- ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
method=method)
+ barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method)