diff options
author | Hicham Janati <hicham.janati@inria.fr> | 2019-06-18 16:40:06 +0200 |
---|---|---|
committer | Hicham Janati <hicham.janati@inria.fr> | 2019-06-18 16:40:06 +0200 |
commit | 897982718a5fd81a9a591d80a7d50839399fc088 (patch) | |
tree | 48189493c09cda25ee19dfd0b7ef59c2f6819ba7 | |
parent | 50bc90058940645a13e2f3e41129bdc97161dc63 (diff) |
fix func names + add more tests
-rw-r--r-- | ot/__init__.py | 2 | ||||
-rw-r--r-- | ot/bregman.py | 2 | ||||
-rw-r--r-- | ot/unbalanced.py | 79 | ||||
-rw-r--r-- | test/test_unbalanced.py | 79 |
4 files changed, 127 insertions, 35 deletions
diff --git a/ot/__init__.py b/ot/__init__.py index 361be02..acb05e6 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -25,7 +25,7 @@ from . import unbalanced # OT functions from .lp import emd, emd2 from .bregman import sinkhorn, sinkhorn2, barycenter -from .unbalanced import sinkhorn_unbalanced +from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced from .da import sinkhorn_lpl1_mm # utils functions diff --git a/ot/bregman.py b/ot/bregman.py index 321712b..09716e6 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -241,7 +241,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, b = np.asarray(b, dtype=np.float64) if len(b.shape) < 2: - b = b.reshape((-1, 1)) + b = b[:, None] return sink() diff --git a/ot/unbalanced.py b/ot/unbalanced.py index a30fc18..97e2576 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -73,8 +73,9 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[0., 1.], [1., 0.]] - >>> ot.sinkhorn2(a, b, M, 1, 1) - array([0.26894142]) + >>> ot.sinkhorn_unbalanced(a, b, M, 1, 1) + array([[0.51122823, 0.18807035], + [0.18807035, 0.51122823]]) References @@ -91,28 +92,36 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, See Also -------- - ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10] - ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10] - ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10] + ot.unbalanced.sinkhorn_stabilized_unbalanced: Unbalanced Stabilized sinkhorn [9][10] + ot.unbalanced.sinkhorn_epsilon_scaling_unbalanced: Unbalanced Sinkhorn with epslilon scaling [9][10] """ if method.lower() == 'sinkhorn': def sink(): - return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, **kwargs) - else: - warnings.warn('Unknown method. Falling back to classic Sinkhorn Knopp') + return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + + elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']: + warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') def sink(): - return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, **kwargs) + return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError('Unknown method. Using classic Sinkhorn Knopp') return sink() -def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', + numItermax=1000, stopThr=1e-9, verbose=False, + log=False, **kwargs): u""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -173,8 +182,8 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, >>> a=[.5, .10] >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] - >>> ot.sinkhorn2(a, b, M, 1., 1.) - array([ 0.26894142]) + >>> ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.) + array([0.31912866]) @@ -199,23 +208,31 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, if method.lower() == 'sinkhorn': def sink(): - return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, **kwargs) - else: - warnings.warn('Unknown method using classic Sinkhorn Knopp') + return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + + elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']: + warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') def sink(): - return sinkhorn_knopp(a, b, M, reg, alpha, **kwargs) + return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError('Unknown method. Using classic Sinkhorn Knopp') b = np.asarray(b, dtype=np.float64) if len(b.shape) < 2: - b = b[None, :] + b = b[:, None] return sink() -def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): """ Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -273,10 +290,9 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, >>> a=[.5, .15] >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] - >>> ot.sinkhorn(a, b, M, 1., 1.) - array([[ 0.36552929, 0.13447071], - [ 0.13447071, 0.36552929]]) - + >>> ot.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.) + array([[0.52761554, 0.22392482], + [0.10286295, 0.32257641]]) References ---------- @@ -303,8 +319,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, if len(b) == 0: b = np.ones(n_b, dtype=np.float64) / n_b - assert n_a == len(a) and n_b == len(b) - if b.ndim > 1: + if len(b.shape) > 1: n_hists = b.shape[1] else: n_hists = 0 @@ -315,8 +330,9 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((n_a, n_hists)) / n_a + u = np.ones((n_a, 1)) / n_a v = np.ones((n_b, n_hists)) / n_b + a = a.reshape(n_a, 1) else: u = np.ones(n_a) / n_a v = np.ones(n_b) / n_b @@ -332,6 +348,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, cpt = 0 err = 1. + while (err > stopThr and cpt < numItermax): uprev = u vprev = v @@ -473,7 +490,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, or np.any(np.isinf(u)) or np.any(np.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration', cpt) + warnings.warn('Numerical errors at iteration %s' % cpt) u = uprev v = vprev break diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index b39e457..1395fe1 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -29,7 +29,8 @@ def test_unbalanced_convergence(method): G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha, stopThr=1e-10, method=method, log=True) - + loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + method=method) # check fixed point equations fi = alpha / (alpha + epsilon) v_final = (b / K.T.dot(log["u"])) ** fi @@ -40,6 +41,44 @@ def test_unbalanced_convergence(method): np.testing.assert_allclose( v_final, log["v"], atol=1e-05) + # check if sinkhorn_unbalanced2 returns the correct loss + np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5) + + +@pytest.mark.parametrize("method", ["sinkhorn"]) +def test_unbalanced_multiple_inputs(method): + # test generalized sinkhorn for unbalanced OT + n = 100 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = rng.rand(n, 2) + + M = ot.dist(x, x) + epsilon = 1. + alpha = 1. + K = np.exp(- M / epsilon) + + loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, + alpha=alpha, + stopThr=1e-10, method=method, + log=True) + # check fixed point equations + fi = alpha / (alpha + epsilon) + v_final = (b / K.T.dot(log["u"])) ** fi + + u_final = (a[:, None] / K.dot(log["v"])) ** fi + + np.testing.assert_allclose( + u_final, log["u"], atol=1e-05) + np.testing.assert_allclose( + v_final, log["v"], atol=1e-05) + + assert len(loss) == b.shape[1] + def test_unbalanced_barycenter(): # test generalized sinkhorn for unbalanced OT barycenter @@ -59,7 +98,6 @@ def test_unbalanced_barycenter(): q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, alpha=alpha, stopThr=1e-10, log=True) - # check fixed point equations fi = alpha / (alpha + epsilon) v_final = (q[:, None] / K.T.dot(log["u"])) ** fi @@ -69,3 +107,40 @@ def test_unbalanced_barycenter(): u_final, log["u"], atol=1e-05) np.testing.assert_allclose( v_final, log["v"], atol=1e-05) + + +def test_implemented_methods(): + IMPLEMENTED_METHODS = ['sinkhorn'] + TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized', + 'sinkhorn_epsilon_scaling'] + NOT_VALID_TOKENS = ['foo'] + # test generalized sinkhorn for unbalanced OT barycenter + n = 3 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = ot.utils.unif(n) * 1.5 + + M = ot.dist(x, x) + epsilon = 1. + alpha = 1. + for method in IMPLEMENTED_METHODS: + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, + method=method) + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + method=method) + with pytest.warns(UserWarning, match='not implemented'): + for method in set(TO_BE_IMPLEMENTED_METHODS): + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, + method=method) + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + method=method) + with pytest.raises(ValueError): + for method in set(NOT_VALID_TOKENS): + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, + method=method) + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + method=method) |