summaryrefslogtreecommitdiff
path: root/ot/unbalanced.py
diff options
context:
space:
mode:
authorHicham Janati <hicham.janati@inria.fr>2019-06-18 16:40:06 +0200
committerHicham Janati <hicham.janati@inria.fr>2019-06-18 16:40:06 +0200
commit897982718a5fd81a9a591d80a7d50839399fc088 (patch)
tree48189493c09cda25ee19dfd0b7ef59c2f6819ba7 /ot/unbalanced.py
parent50bc90058940645a13e2f3e41129bdc97161dc63 (diff)
fix func names + add more tests
Diffstat (limited to 'ot/unbalanced.py')
-rw-r--r--ot/unbalanced.py79
1 files changed, 48 insertions, 31 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index a30fc18..97e2576 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -73,8 +73,9 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
>>> a=[.5, .5]
>>> b=[.5, .5]
>>> M=[[0., 1.], [1., 0.]]
- >>> ot.sinkhorn2(a, b, M, 1, 1)
- array([0.26894142])
+ >>> ot.sinkhorn_unbalanced(a, b, M, 1, 1)
+ array([[0.51122823, 0.18807035],
+ [0.18807035, 0.51122823]])
References
@@ -91,28 +92,36 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
See Also
--------
- ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
- ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10]
- ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10]
+ ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10]
+ ot.unbalanced.sinkhorn_stabilized_unbalanced: Unbalanced Stabilized sinkhorn [9][10]
+ ot.unbalanced.sinkhorn_epsilon_scaling_unbalanced: Unbalanced Sinkhorn with epslilon scaling [9][10]
"""
if method.lower() == 'sinkhorn':
def sink():
- return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
- else:
- warnings.warn('Unknown method. Falling back to classic Sinkhorn Knopp')
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']:
+ warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
def sink():
- return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError('Unknown method. Using classic Sinkhorn Knopp')
return sink()
-def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
+ numItermax=1000, stopThr=1e-9, verbose=False,
+ log=False, **kwargs):
u"""
Solve the entropic regularization unbalanced optimal transport problem and return the loss
@@ -173,8 +182,8 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
>>> a=[.5, .10]
>>> b=[.5, .5]
>>> M=[[0., 1.],[1., 0.]]
- >>> ot.sinkhorn2(a, b, M, 1., 1.)
- array([ 0.26894142])
+ >>> ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.)
+ array([0.31912866])
@@ -199,23 +208,31 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
if method.lower() == 'sinkhorn':
def sink():
- return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
- else:
- warnings.warn('Unknown method using classic Sinkhorn Knopp')
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']:
+ warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
def sink():
- return sinkhorn_knopp(a, b, M, reg, alpha, **kwargs)
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError('Unknown method. Using classic Sinkhorn Knopp')
b = np.asarray(b, dtype=np.float64)
if len(b.shape) < 2:
- b = b[None, :]
+ b = b[:, None]
return sink()
-def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
+ stopThr=1e-9, verbose=False, log=False, **kwargs):
"""
Solve the entropic regularization unbalanced optimal transport problem and return the loss
@@ -273,10 +290,9 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
>>> a=[.5, .15]
>>> b=[.5, .5]
>>> M=[[0., 1.],[1., 0.]]
- >>> ot.sinkhorn(a, b, M, 1., 1.)
- array([[ 0.36552929, 0.13447071],
- [ 0.13447071, 0.36552929]])
-
+ >>> ot.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.)
+ array([[0.52761554, 0.22392482],
+ [0.10286295, 0.32257641]])
References
----------
@@ -303,8 +319,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
if len(b) == 0:
b = np.ones(n_b, dtype=np.float64) / n_b
- assert n_a == len(a) and n_b == len(b)
- if b.ndim > 1:
+ if len(b.shape) > 1:
n_hists = b.shape[1]
else:
n_hists = 0
@@ -315,8 +330,9 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
# we assume that no distances are null except those of the diagonal of
# distances
if n_hists:
- u = np.ones((n_a, n_hists)) / n_a
+ u = np.ones((n_a, 1)) / n_a
v = np.ones((n_b, n_hists)) / n_b
+ a = a.reshape(n_a, 1)
else:
u = np.ones(n_a) / n_a
v = np.ones(n_b) / n_b
@@ -332,6 +348,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
cpt = 0
err = 1.
+
while (err > stopThr and cpt < numItermax):
uprev = u
vprev = v
@@ -473,7 +490,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
or np.any(np.isinf(u)) or np.any(np.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
- warnings.warn('Numerical errors at iteration', cpt)
+ warnings.warn('Numerical errors at iteration %s' % cpt)
u = uprev
v = vprev
break