summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHicham Janati <hicham.janati@inria.fr>2019-06-18 16:40:06 +0200
committerHicham Janati <hicham.janati@inria.fr>2019-06-18 16:40:06 +0200
commit897982718a5fd81a9a591d80a7d50839399fc088 (patch)
tree48189493c09cda25ee19dfd0b7ef59c2f6819ba7
parent50bc90058940645a13e2f3e41129bdc97161dc63 (diff)
fix func names + add more tests
-rw-r--r--ot/__init__.py2
-rw-r--r--ot/bregman.py2
-rw-r--r--ot/unbalanced.py79
-rw-r--r--test/test_unbalanced.py79
4 files changed, 127 insertions, 35 deletions
diff --git a/ot/__init__.py b/ot/__init__.py
index 361be02..acb05e6 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -25,7 +25,7 @@ from . import unbalanced
# OT functions
from .lp import emd, emd2
from .bregman import sinkhorn, sinkhorn2, barycenter
-from .unbalanced import sinkhorn_unbalanced
+from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced
from .da import sinkhorn_lpl1_mm
# utils functions
diff --git a/ot/bregman.py b/ot/bregman.py
index 321712b..09716e6 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -241,7 +241,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
b = np.asarray(b, dtype=np.float64)
if len(b.shape) < 2:
- b = b.reshape((-1, 1))
+ b = b[:, None]
return sink()
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index a30fc18..97e2576 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -73,8 +73,9 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
>>> a=[.5, .5]
>>> b=[.5, .5]
>>> M=[[0., 1.], [1., 0.]]
- >>> ot.sinkhorn2(a, b, M, 1, 1)
- array([0.26894142])
+ >>> ot.sinkhorn_unbalanced(a, b, M, 1, 1)
+ array([[0.51122823, 0.18807035],
+ [0.18807035, 0.51122823]])
References
@@ -91,28 +92,36 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
See Also
--------
- ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
- ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10]
- ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10]
+ ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10]
+ ot.unbalanced.sinkhorn_stabilized_unbalanced: Unbalanced Stabilized sinkhorn [9][10]
+ ot.unbalanced.sinkhorn_epsilon_scaling_unbalanced: Unbalanced Sinkhorn with epslilon scaling [9][10]
"""
if method.lower() == 'sinkhorn':
def sink():
- return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
- else:
- warnings.warn('Unknown method. Falling back to classic Sinkhorn Knopp')
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']:
+ warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
def sink():
- return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError('Unknown method. Using classic Sinkhorn Knopp')
return sink()
-def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
+ numItermax=1000, stopThr=1e-9, verbose=False,
+ log=False, **kwargs):
u"""
Solve the entropic regularization unbalanced optimal transport problem and return the loss
@@ -173,8 +182,8 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
>>> a=[.5, .10]
>>> b=[.5, .5]
>>> M=[[0., 1.],[1., 0.]]
- >>> ot.sinkhorn2(a, b, M, 1., 1.)
- array([ 0.26894142])
+ >>> ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.)
+ array([0.31912866])
@@ -199,23 +208,31 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
if method.lower() == 'sinkhorn':
def sink():
- return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
- else:
- warnings.warn('Unknown method using classic Sinkhorn Knopp')
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']:
+ warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
def sink():
- return sinkhorn_knopp(a, b, M, reg, alpha, **kwargs)
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError('Unknown method. Using classic Sinkhorn Knopp')
b = np.asarray(b, dtype=np.float64)
if len(b.shape) < 2:
- b = b[None, :]
+ b = b[:, None]
return sink()
-def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
+ stopThr=1e-9, verbose=False, log=False, **kwargs):
"""
Solve the entropic regularization unbalanced optimal transport problem and return the loss
@@ -273,10 +290,9 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
>>> a=[.5, .15]
>>> b=[.5, .5]
>>> M=[[0., 1.],[1., 0.]]
- >>> ot.sinkhorn(a, b, M, 1., 1.)
- array([[ 0.36552929, 0.13447071],
- [ 0.13447071, 0.36552929]])
-
+ >>> ot.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.)
+ array([[0.52761554, 0.22392482],
+ [0.10286295, 0.32257641]])
References
----------
@@ -303,8 +319,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
if len(b) == 0:
b = np.ones(n_b, dtype=np.float64) / n_b
- assert n_a == len(a) and n_b == len(b)
- if b.ndim > 1:
+ if len(b.shape) > 1:
n_hists = b.shape[1]
else:
n_hists = 0
@@ -315,8 +330,9 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
# we assume that no distances are null except those of the diagonal of
# distances
if n_hists:
- u = np.ones((n_a, n_hists)) / n_a
+ u = np.ones((n_a, 1)) / n_a
v = np.ones((n_b, n_hists)) / n_b
+ a = a.reshape(n_a, 1)
else:
u = np.ones(n_a) / n_a
v = np.ones(n_b) / n_b
@@ -332,6 +348,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
cpt = 0
err = 1.
+
while (err > stopThr and cpt < numItermax):
uprev = u
vprev = v
@@ -473,7 +490,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
or np.any(np.isinf(u)) or np.any(np.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
- warnings.warn('Numerical errors at iteration', cpt)
+ warnings.warn('Numerical errors at iteration %s' % cpt)
u = uprev
v = vprev
break
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index b39e457..1395fe1 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -29,7 +29,8 @@ def test_unbalanced_convergence(method):
G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha,
stopThr=1e-10, method=method,
log=True)
-
+ loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ method=method)
# check fixed point equations
fi = alpha / (alpha + epsilon)
v_final = (b / K.T.dot(log["u"])) ** fi
@@ -40,6 +41,44 @@ def test_unbalanced_convergence(method):
np.testing.assert_allclose(
v_final, log["v"], atol=1e-05)
+ # check if sinkhorn_unbalanced2 returns the correct loss
+ np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn"])
+def test_unbalanced_multiple_inputs(method):
+ # test generalized sinkhorn for unbalanced OT
+ n = 100
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+
+ # make dists unbalanced
+ b = rng.rand(n, 2)
+
+ M = ot.dist(x, x)
+ epsilon = 1.
+ alpha = 1.
+ K = np.exp(- M / epsilon)
+
+ loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
+ alpha=alpha,
+ stopThr=1e-10, method=method,
+ log=True)
+ # check fixed point equations
+ fi = alpha / (alpha + epsilon)
+ v_final = (b / K.T.dot(log["u"])) ** fi
+
+ u_final = (a[:, None] / K.dot(log["v"])) ** fi
+
+ np.testing.assert_allclose(
+ u_final, log["u"], atol=1e-05)
+ np.testing.assert_allclose(
+ v_final, log["v"], atol=1e-05)
+
+ assert len(loss) == b.shape[1]
+
def test_unbalanced_barycenter():
# test generalized sinkhorn for unbalanced OT barycenter
@@ -59,7 +98,6 @@ def test_unbalanced_barycenter():
q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, alpha=alpha,
stopThr=1e-10,
log=True)
-
# check fixed point equations
fi = alpha / (alpha + epsilon)
v_final = (q[:, None] / K.T.dot(log["u"])) ** fi
@@ -69,3 +107,40 @@ def test_unbalanced_barycenter():
u_final, log["u"], atol=1e-05)
np.testing.assert_allclose(
v_final, log["v"], atol=1e-05)
+
+
+def test_implemented_methods():
+ IMPLEMENTED_METHODS = ['sinkhorn']
+ TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized',
+ 'sinkhorn_epsilon_scaling']
+ NOT_VALID_TOKENS = ['foo']
+ # test generalized sinkhorn for unbalanced OT barycenter
+ n = 3
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+
+ # make dists unbalanced
+ b = ot.utils.unif(n) * 1.5
+
+ M = ot.dist(x, x)
+ epsilon = 1.
+ alpha = 1.
+ for method in IMPLEMENTED_METHODS:
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
+ method=method)
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ method=method)
+ with pytest.warns(UserWarning, match='not implemented'):
+ for method in set(TO_BE_IMPLEMENTED_METHODS):
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
+ method=method)
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ method=method)
+ with pytest.raises(ValueError):
+ for method in set(NOT_VALID_TOKENS):
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
+ method=method)
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ method=method)