author    Rémi Flamary <remi.flamary@gmail.com>    2019-09-09 14:55:04 +0200
committer Rémi Flamary <remi.flamary@gmail.com>    2019-09-09 14:55:04 +0200
commit    b2a7afb848a78570d01f35f9b239be8838520edc (patch)
tree      fc243208d24f5488d5ce06298b2ebb39b76be9bb /ot
parent    c698e0aa20d28e36d25f87082855a490283f3c88 (diff)
parent    f251b4d080a577c2cee890ca43d8ec3658332021 (diff)
merge new unbalanced
Diffstat (limited to 'ot')
-rw-r--r--  ot/__init__.py     34
-rw-r--r--  ot/bregman.py     604
-rw-r--r--  ot/da.py          121
-rw-r--r--  ot/datasets.py     32
-rw-r--r--  ot/dr.py           57
-rw-r--r--  ot/gromov.py      360
-rw-r--r--  ot/optim.py        32
-rw-r--r--  ot/plot.py          6
-rw-r--r--  ot/stochastic.py  199
-rw-r--r--  ot/unbalanced.py  808
-rw-r--r--  ot/utils.py        56
11 files changed, 1536 insertions(+), 773 deletions(-)
diff --git a/ot/__init__.py b/ot/__init__.py
index 571f235..89c7936 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -1,7 +1,7 @@
"""
-This is the main module of the POT toolbox. It provides easy access to
-a number of sub-modules and functions described below.
+This is the main module of the POT toolbox. It provides easy access to
+a number of sub-modules and functions described below.
.. note::
@@ -14,27 +14,27 @@ a number of sub-modules and functions described below.
- :any:`ot.lp` contains OT solvers for the exact (Linear Program) OT problems.
- :any:`ot.smooth` contains OT solvers for the regularized (l2 and kl) smooth OT
problems.
- - :any:`ot.gromov` contains solvers for Gromov-Wasserstein and Fused Gromov
+ - :any:`ot.gromov` contains solvers for Gromov-Wasserstein and Fused Gromov
Wasserstein problems.
- - :any:`ot.optim` contains generic solvers OT based optimization problems
+ - :any:`ot.optim` contains generic solvers for OT based optimization problems
- :any:`ot.da` contains classes and function related to Monge mapping
estimation and Domain Adaptation (DA).
- :any:`ot.gpu` contains GPU (cupy) implementation of some OT solvers
- - :any:`ot.dr` contains Dimension Reduction (DR) methods such as Wasserstein
+ - :any:`ot.dr` contains Dimension Reduction (DR) methods such as Wasserstein
Discriminant Analysis.
- - :any:`ot.utils` contains utility functions such as distance computation and
- timing.
+ - :any:`ot.utils` contains utility functions such as distance computation and
+ timing.
- :any:`ot.datasets` contains toy dataset generation functions.
- :any:`ot.plot` contains visualization functions
- :any:`ot.stochastic` contains stochastic solvers for regularized OT.
- :any:`ot.unbalanced` contains solvers for regularized unbalanced OT.
.. warning::
- The list of automatically imported sub-modules is as follows:
+ The list of automatically imported sub-modules is as follows:
:py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim`
:py:mod:`ot.utils`, :py:mod:`ot.datasets`,
:py:mod:`ot.gromov`, :py:mod:`ot.smooth`
- :py:mod:`ot.stochastic`
+ :py:mod:`ot.stochastic`
The following sub-modules are not imported due to additional dependencies:
@@ -65,17 +65,17 @@ from . import unbalanced
# OT functions
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
from .bregman import sinkhorn, sinkhorn2, barycenter
-from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced
+from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2
from .da import sinkhorn_lpl1_mm
# utils functions
from .utils import dist, unif, tic, toc, toq
-__version__ = "1.0.0"
+__version__ = "0.6.0"
-__all__ = ["emd", "emd2", 'emd_1d','emd2_1d', 'wasserstein_1d',
- "sinkhorn", "sinkhorn2", 'barycenter',
- 'sinkhorn_lpl1_mm',
- 'sinkhorn_unbalanced', "barycenter_unbalanced",
- 'dist', 'unif', 'tic', 'toc', 'toq',
- "utils", 'datasets', 'bregman', 'lp', 'gromov', 'da', 'optim']
+__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets',
+ 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
+ 'emd_1d', 'emd2_1d', 'wasserstein_1d',
+ 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
+ 'sinkhorn_unbalanced', 'barycenter_unbalanced',
+ 'sinkhorn_unbalanced2']
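A minimal usage sketch of the newly exported unbalanced solvers (toy data, illustrative only; reg is the entropic term and reg_m the marginal relaxation, matching the keywords used in the ot/da.py call further below):

    import numpy as np
    import ot

    # unbalanced OT does not require the two histograms to have equal mass
    a = np.array([0.5, 0.5])          # total mass 1.0
    b = np.array([0.6, 0.4, 0.2])     # total mass 1.2
    x = np.arange(2, dtype=np.float64).reshape(-1, 1)
    y = np.arange(3, dtype=np.float64).reshape(-1, 1)
    M = ot.dist(x, y)                 # squared Euclidean cost by default

    G = ot.sinkhorn_unbalanced(a, b, M, reg=0.1, reg_m=1.0)    # coupling matrix
    w = ot.sinkhorn_unbalanced2(a, b, M, reg=0.1, reg_m=1.0)   # transport loss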
diff --git a/ot/bregman.py b/ot/bregman.py
index 13dfa3b..2cd832b 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -7,10 +7,12 @@ Bregman projections for regularized OT
# Nicolas Courty <ncourty@irisa.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
+# Hicham Janati <hicham.janati@inria.fr>
#
# License: MIT License
import numpy as np
+import warnings
from .utils import unif, dist
@@ -31,21 +33,21 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - a and b are source and target weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -64,7 +66,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -103,30 +105,23 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
"""
if method.lower() == 'sinkhorn':
- def sink():
- return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
elif method.lower() == 'greenkhorn':
- def sink():
- return greenkhorn(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log)
+ return greenkhorn(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log)
elif method.lower() == 'sinkhorn_stabilized':
- def sink():
- return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
elif method.lower() == 'sinkhorn_epsilon_scaling':
- def sink():
- return sinkhorn_epsilon_scaling(
- a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_epsilon_scaling(a, b, M, reg,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
else:
- print('Warning : unknown method using classic Sinkhorn Knopp')
-
- def sink():
- return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
-
- return sink()
+ raise ValueError("Unknown method '%s'." % method)
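With this dispatch, an unrecognized method now fails loudly instead of silently falling back to Sinkhorn-Knopp; a short sketch (toy data, illustrative only):

    import numpy as np
    import ot

    a = np.array([.5, .5])
    b = np.array([.5, .5])
    M = np.array([[0., 1.], [1., 0.]])

    G = ot.sinkhorn(a, b, M, reg=1., method='sinkhorn')              # Sinkhorn-Knopp
    Gs = ot.sinkhorn(a, b, M, reg=1., method='sinkhorn_stabilized')  # log-stabilized
    try:
        ot.sinkhorn(a, b, M, reg=1., method='sinkhron')  # misspelled on purpose
    except ValueError as err:
        print(err)  # Unknown method 'sinkhron'.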
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
@@ -146,21 +141,21 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - a and b are source and target weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -176,11 +171,10 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
log : bool, optional
record log if True
-
Returns
-------
- W : (nt) ndarray or float
- Optimal transportation matrix for the given parameters
+ W : (n_hists) ndarray or float
+ Optimal transportation loss for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -219,31 +213,23 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epsilon scaling [9][10]
"""
-
+ b = np.asarray(b, dtype=np.float64)
+ if len(b.shape) < 2:
+ b = b[:, None]
if method.lower() == 'sinkhorn':
- def sink():
- return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
elif method.lower() == 'sinkhorn_stabilized':
- def sink():
- return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
elif method.lower() == 'sinkhorn_epsilon_scaling':
- def sink():
- return sinkhorn_epsilon_scaling(
- a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
else:
- print('Warning : unknown method using classic Sinkhorn Knopp')
-
- def sink():
- return sinkhorn_knopp(a, b, M, reg, **kwargs)
-
- b = np.asarray(b, dtype=np.float64)
- if len(b.shape) < 2:
- b = b[:, None]
-
- return sink()
+ raise ValueError("Unknown method '%s'." % method)
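Since b is now promoted to a (dim_b, n_hists) column matrix before dispatch, sinkhorn2 consistently returns an array of losses, one per target histogram; a sketch (illustrative only):

    import numpy as np
    import ot

    a = np.array([.5, .5])
    M = np.array([[0., 1.], [1., 0.]])

    # two target histograms stacked as columns -> one loss per column
    B = np.array([[.5, .3],
                  [.5, .7]])
    losses = ot.sinkhorn2(a, B, M, reg=1.)   # shape (2,)

    # a single 1D histogram also comes back as a length-1 array
    loss = ot.sinkhorn2(a, np.array([.5, .5]), M, reg=1.)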
def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
@@ -263,21 +249,21 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - a and b are source and target weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -290,10 +276,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -333,25 +318,25 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
# init data
- Nini = len(a)
- Nfin = len(b)
+ dim_a = len(a)
+ dim_b = len(b)
if len(b.shape) > 1:
- nbb = b.shape[1]
+ n_hists = b.shape[1]
else:
- nbb = 0
+ n_hists = 0
if log:
log = {'err': []}
# we assume that no distances are null except those of the diagonal of
# distances
- if nbb:
- u = np.ones((Nini, nbb)) / Nini
- v = np.ones((Nfin, nbb)) / Nfin
+ if n_hists:
+ u = np.ones((dim_a, n_hists)) / dim_a
+ v = np.ones((dim_b, n_hists)) / dim_b
else:
- u = np.ones(Nini) / Nini
- v = np.ones(Nfin) / Nfin
+ u = np.ones(dim_a) / dim_a
+ v = np.ones(dim_b) / dim_b
# print(reg)
@@ -386,13 +371,12 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- if nbb:
- err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
- np.sum((v - vprev)**2) / np.sum((v)**2)
+ if n_hists:
+ np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2)
else:
# compute right marginal tmp2= (diag(u)Kdiag(v))^T1
np.einsum('i,ij,j->j', u, K, v, out=tmp2)
- err = np.linalg.norm(tmp2 - b)**2 # violation of marginal
+ err = np.linalg.norm(tmp2 - b) # violation of marginal
if log:
log['err'].append(err)
@@ -406,7 +390,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
log['u'] = u
log['v'] = v
- if nbb: # return only loss
+ if n_hists: # return only loss
res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
if log:
return res, log
@@ -421,7 +405,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
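The returned coupling is diag(u) K diag(v); a stripped-down numpy sketch of the underlying scaling iteration for a single pair of histograms (no convergence test or logging, illustrative only):

    import numpy as np

    def sinkhorn_knopp_sketch(a, b, M, reg, n_iter=1000):
        K = np.exp(-M / reg)               # Gibbs kernel
        u = np.ones_like(a) / a.size
        v = np.ones_like(b) / b.size
        for _ in range(n_iter):
            # alternate projections onto the two marginal constraints
            u = a / K.dot(v)
            v = b / K.T.dot(u)
        return u[:, None] * K * v[None, :]   # diag(u) K diag(v)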
-def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False):
+def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
+ log=False):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -445,20 +430,20 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - a and b are source and target weights (histograms, both sum to 1)
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -469,10 +454,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -512,16 +496,16 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
if len(b) == 0:
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
- n = a.shape[0]
- m = b.shape[0]
+ dim_a = a.shape[0]
+ dim_b = b.shape[0]
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
K = np.empty_like(M)
np.divide(M, -reg, out=K)
np.exp(K, out=K)
- u = np.full(n, 1. / n)
- v = np.full(m, 1. / m)
+ u = np.full(dim_a, 1. / dim_a)
+ v = np.full(dim_b, 1. / dim_b)
G = u[:, np.newaxis] * K * v[np.newaxis, :]
viol = G.sum(1) - a
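Greenkhorn then updates one scaling coordinate at a time, greedily picking the row or column with the largest marginal violation; a simplified sketch of one greedy step (the real loop maintains viol incrementally, illustrative only):

    import numpy as np

    def greenkhorn_step(u, v, K, a, b):
        viol_a = u * K.dot(v) - a          # row marginal violations
        viol_b = v * K.T.dot(u) - b        # column marginal violations
        i = np.argmax(np.abs(viol_a))
        j = np.argmax(np.abs(viol_b))
        if np.abs(viol_a[i]) >= np.abs(viol_b[j]):
            u[i] = a[i] / K[i, :].dot(v)   # rescale the worst row
        else:
            v[j] = b[j] / K[:, j].dot(u)   # rescale the worst column
        return u, v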
@@ -575,7 +559,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
- warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
+ warmstart=None, verbose=False, print_period=20,
+ log=False, **kwargs):
r"""
Solve the entropic regularization OT problem with log stabilization
@@ -591,9 +576,10 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - a and b are source and target weights (histograms, both sum to 1)
+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
scaling algorithm as proposed in [2]_ but with the log stabilization
@@ -602,11 +588,11 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (dim_b,)
samples in the target domain
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -623,10 +609,9 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -638,7 +623,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
- >>> ot.bregman.sinkhorn_stabilized(a,b,M,1)
+ >>> ot.bregman.sinkhorn_stabilized(a, b, M, 1)
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])
@@ -671,14 +656,14 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
# test if multiple target
if len(b.shape) > 1:
- nbb = b.shape[1]
+ n_hists = b.shape[1]
a = a[:, np.newaxis]
else:
- nbb = 0
+ n_hists = 0
# init data
- na = len(a)
- nb = len(b)
+ dim_a = len(a)
+ dim_b = len(b)
cpt = 0
if log:
@@ -687,24 +672,25 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
# we assume that no distances are null except those of the diagonal of
# distances
if warmstart is None:
- alpha, beta = np.zeros(na), np.zeros(nb)
+ alpha, beta = np.zeros(dim_a), np.zeros(dim_b)
else:
alpha, beta = warmstart
- if nbb:
- u, v = np.ones((na, nbb)) / na, np.ones((nb, nbb)) / nb
+ if n_hists:
+ u = np.ones((dim_a, n_hists)) / dim_a
+ v = np.ones((dim_b, n_hists)) / dim_b
else:
- u, v = np.ones(na) / na, np.ones(nb) / nb
+ u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b
def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M - alpha.reshape((na, 1))
- - beta.reshape((1, nb))) / reg)
+ return np.exp(-(M - alpha.reshape((dim_a, 1))
+ - beta.reshape((1, dim_b))) / reg)
def get_Gamma(alpha, beta, u, v):
"""log space gamma computation"""
- return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb)))
- / reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb))))
+ return np.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b)))
+ / reg + np.log(u.reshape((dim_a, 1))) + np.log(v.reshape((1, dim_b))))
# print(np.min(K))
@@ -724,26 +710,29 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
# remove numerical problems and store them in K
if np.abs(u).max() > tau or np.abs(v).max() > tau:
- if nbb:
+ if n_hists:
alpha, beta = alpha + reg * \
np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
else:
alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v)
- if nbb:
- u, v = np.ones((na, nbb)) / na, np.ones((nb, nbb)) / nb
+ if n_hists:
+ u, v = np.ones((dim_a, n_hists)) / dim_a, np.ones((dim_b, n_hists)) / dim_b
else:
- u, v = np.ones(na) / na, np.ones(nb) / nb
+ u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b
K = get_K(alpha, beta)
if cpt % print_period == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- if nbb:
- err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
- np.sum((v - vprev)**2) / np.sum((v)**2)
+ if n_hists:
+ err_u = abs(u - uprev).max()
+ err_u /= max(abs(u).max(), abs(uprev).max(), 1.)
+ err_v = abs(v - vprev).max()
+ err_v /= max(abs(v).max(), abs(vprev).max(), 1.)
+ err = 0.5 * (err_u + err_v)
else:
transp = get_Gamma(alpha, beta, u, v)
- err = np.linalg.norm((np.sum(transp, axis=0) - b))**2
+ err = np.linalg.norm((np.sum(transp, axis=0) - b))
if log:
log['err'].append(err)
@@ -769,33 +758,39 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
cpt = cpt + 1
- # print('err=',err,' cpt=',cpt)
if log:
- log['logu'] = alpha / reg + np.log(u)
- log['logv'] = beta / reg + np.log(v)
+ if n_hists:
+ alpha = alpha[:, None]
+ beta = beta[:, None]
+ logu = alpha / reg + np.log(u)
+ logv = beta / reg + np.log(v)
+ log['logu'] = logu
+ log['logv'] = logv
log['alpha'] = alpha + reg * np.log(u)
log['beta'] = beta + reg * np.log(v)
log['warmstart'] = (log['alpha'], log['beta'])
- if nbb:
- res = np.zeros((nbb))
- for i in range(nbb):
+ if n_hists:
+ res = np.zeros((n_hists))
+ for i in range(n_hists):
res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
return res, log
else:
return get_Gamma(alpha, beta, u, v), log
else:
- if nbb:
- res = np.zeros((nbb))
- for i in range(nbb):
+ if n_hists:
+ res = np.zeros((n_hists))
+ for i in range(n_hists):
res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
return res
else:
return get_Gamma(alpha, beta, u, v)
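When reg is small enough that plain Sinkhorn under- or overflows, the stabilized solver absorbs large scalings into the dual potentials alpha, beta whenever they exceed tau; a short usage sketch (toy data, illustrative only):

    import numpy as np
    import ot

    a = np.array([.5, .5])
    b = np.array([.5, .5])
    M = np.array([[0., 1.], [1., 0.]])

    G, log = ot.bregman.sinkhorn_stabilized(a, b, M, reg=1e-2, tau=1e3, log=True)
    alpha, beta = log['warmstart']   # duals, reusable as warmstart=(alpha, beta)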
-def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100,
- tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs):
+def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
+ numInnerItermax=100, tau=1e3, stopThr=1e-9,
+ warmstart=None, verbose=False, print_period=10,
+ log=False, **kwargs):
r"""
Solve the entropic regularization optimal transport problem with log
stabilization and epsilon scaling.
@@ -812,9 +807,10 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - a and b are source and target weights (histograms, both sum to 1)
+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
scaling algorithm as proposed in [2]_ but with the log stabilization
@@ -823,19 +819,17 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (dim_b,)
samples in the target domain
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
tau : float
threshold for max value in u or v for log scaling
- tau : float
- thershold for max value in u or v for log scaling
- warmstart : tible of vectors
+ warmstart : tuple of vectors
if given then starting values for alpha and beta log scalings
numItermax : int, optional
Max number of iterations
@@ -850,10 +844,9 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -894,8 +887,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
# init data
- na = len(a)
- nb = len(b)
+ dim_a = len(a)
+ dim_b = len(b)
# relative numerical precision with 64 bits
numItermin = 35
@@ -908,14 +901,14 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
# we assume that no distances are null except those of the diagonal of
# distances
if warmstart is None:
- alpha, beta = np.zeros(na), np.zeros(nb)
+ alpha, beta = np.zeros(dim_a), np.zeros(dim_b)
else:
alpha, beta = warmstart
def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M - alpha.reshape((na, 1))
- - beta.reshape((1, nb))) / reg)
+ return np.exp(-(M - alpha.reshape((dim_a, 1))
+ - beta.reshape((1, dim_b))) / reg)
# print(np.min(K))
def get_reg(n): # exponential decreasing
@@ -928,8 +921,10 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
regi = get_reg(cpt)
- G, logi = sinkhorn_stabilized(a, b, M, regi, numItermax=numInnerItermax, stopThr=1e-9, warmstart=(
- alpha, beta), verbose=False, print_period=20, tau=tau, log=True)
+ G, logi = sinkhorn_stabilized(a, b, M, regi,
+ numItermax=numInnerItermax, stopThr=1e-9,
+ warmstart=(alpha, beta), verbose=False,
+ print_period=20, tau=tau, log=True)
alpha = logi['alpha']
beta = logi['beta']
@@ -987,8 +982,8 @@ def projC(gamma, q):
return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10))
-def barycenter(A, M, reg, weights=None, numItermax=1000,
- stopThr=1e-4, verbose=False, log=False):
+def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
+ stopThr=1e-4, verbose=False, log=False, **kwargs):
r"""Compute the entropic regularized wasserstein barycenter of distributions A
The function solves the following optimization problem:
@@ -1006,13 +1001,15 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
Parameters
----------
- A : np.ndarray (d,n)
- n training distributions a_i of size d
- M : np.ndarray (d,d)
- loss matrix for OT
+ A : ndarray, shape (dim, n_hists)
+ n_hists training distributions a_i of size dim
+ M : ndarray, shape (dim, dim)
+ loss matrix for OT
reg : float
- Regularization term >0
- weights : np.ndarray (n,)
+ Regularization term > 0
+ method : str (optional)
+ method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized'
+ weights : ndarray, shape (n_hists,)
Weights of each histogram a_i on the simplex (barycentric coordinates)
numItermax : int, optional
Max number of iterations
@@ -1026,7 +1023,7 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
Returns
-------
- a : (d,) ndarray
+ a : (dim,) ndarray
Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -1037,7 +1034,69 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
.. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ """
+
+ if method.lower() == 'sinkhorn':
+ return barycenter_sinkhorn(A, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_stabilized':
+ return barycenter_stabilized(A, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
+
+
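barycenter is now a thin dispatcher over the two implementations below; a sketch of selecting the stabilized variant on toy 1D histograms (illustrative only):

    import numpy as np
    import ot

    n = 50
    x = np.arange(n, dtype=np.float64).reshape(-1, 1)
    a1 = ot.datasets.make_1D_gauss(n, m=15, s=5)
    a2 = ot.datasets.make_1D_gauss(n, m=35, s=5)
    A = np.vstack([a1, a2]).T        # shape (dim, n_hists)
    M = ot.dist(x, x)
    M /= M.max()

    bary = ot.bregman.barycenter(A, M, reg=1e-2, method='sinkhorn_stabilized')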
+def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+
+ Parameters
+ ----------
+ A : ndarray, shape (dim, n_hists)
+ n_hists training distributions a_i of size dim
+ M : ndarray, shape (dim, dim)
+ loss matrix for OT
+ reg : float
+ Regularization term > 0
+ weights : ndarray, shape (n_hists,)
+ Weights of each histogram a_i on the simplex (barycentric coordinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Wasserstein barycenter
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
"""
@@ -1083,7 +1142,138 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
return geometricBar(weights, UKv)
-def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False):
+def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
+ with stabilization.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+
+ Parameters
+ ----------
+ A : ndarray, shape (dim, n_hists)
+ n_hists training distributions a_i of size dim
+ M : ndarray, shape (dim, dim)
+ loss matrix for OT
+ reg : float
+ Regularization term > 0
+ tau : float
+ threshold for max value in u or v for log scaling
+ weights : ndarray, shape (n_hists,)
+ Weights of each histogram a_i on the simplex (barycentric coordinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Wasserstein barycenter
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+
+ """
+
+ dim, n_hists = A.shape
+ if weights is None:
+ weights = np.ones(n_hists) / n_hists
+ else:
+ assert(len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ u = np.ones((dim, n_hists)) / dim
+ v = np.ones((dim, n_hists)) / dim
+
+ # print(reg)
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ cpt = 0
+ err = 1.
+ alpha = np.zeros(dim)
+ beta = np.zeros(dim)
+ q = np.ones(dim) / dim
+ while (err > stopThr and cpt < numItermax):
+ qprev = q
+ Kv = K.dot(v)
+ u = A / (Kv + 1e-16)
+ Ktu = K.T.dot(u)
+ q = geometricBar(weights, Ktu)
+ Q = q[:, None]
+ v = Q / (Ktu + 1e-16)
+ absorbing = False
+ if (u > tau).any() or (v > tau).any():
+ absorbing = True
+ alpha = alpha + reg * np.log(np.max(u, 1))
+ beta = beta + reg * np.log(np.max(v, 1))
+ K = np.exp((alpha[:, None] + beta[None, :] -
+ M) / reg)
+ v = np.ones_like(v)
+ Kv = K.dot(v)
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration %s' % cpt)
+ q = qprev
+ break
+ if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+ err = abs(u * Kv - A).max()
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 50 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+ if err > stopThr:
+ warnings.warn("Stabilized Sinkhorn did not converge. " +
+ "Try a larger entropy `reg` " +
+ "or a larger absorption threshold `tau`.")
+ if log:
+ log['niter'] = cpt
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
+ return q, log
+ else:
+ return q
+
+
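A hedged sketch of calling the stabilized barycenter directly and reading its log output (illustrative data; the warning above fires if numItermax is reached before stopThr):

    import numpy as np
    import ot

    A = np.array([[0.6, 0.2],
                  [0.3, 0.3],
                  [0.1, 0.5]])       # two histograms of dimension 3, as columns
    x = np.arange(3, dtype=np.float64).reshape(-1, 1)
    M = ot.dist(x, x)
    M /= M.max()

    q, log = ot.bregman.barycenter_stabilized(A, M, reg=1e-2, tau=1e10, log=True)
    # log['niter'] is the iteration count; log['logu'], log['logv'] the scalings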
+def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
+ stopThr=1e-9, stabThr=1e-30, verbose=False,
+ log=False):
r"""Compute the entropic regularized wasserstein barycenter of distributions A
where A is a collection of 2D images.
@@ -1102,16 +1292,16 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
Parameters
----------
- A : np.ndarray (n,w,h)
- n distributions (2D images) of size w x h
+ A : ndarray, shape (n_hists, width, height)
+ n distributions (2D images) of size width x height
reg : float
Regularization term >0
- weights : np.ndarray (n,)
+ weights : ndarray, shape (n_hists,)
Weights of each image on the simplex (barycentric coodinates)
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (> 0)
stabThr : float, optional
Stabilization threshold to avoid numerical precision issue
verbose : bool, optional
@@ -1119,15 +1309,13 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
log : bool, optional
record log if True
-
Returns
-------
- a : (w,h) ndarray
+ a : ndarray, shape (width, height)
2D Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
-
References
----------
@@ -1207,9 +1395,12 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
where :
- :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see ot.bregman.sinkhorn)
- - :math:`\mathbf{a}` is an observed distribution, :math:`\mathbf{h}_0` is aprior on unmixing
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT data fitting
- - reg0 and :math:`\mathbf{M0}` are respectively the regularization term and the cost matrix for regularization
+ - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)`
+ - :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms`
+ - :math:`\mathbf{a}` is an observed distribution of dimension `dim_a`
+ - :math:`\mathbf{h}_0` is a prior on `h` of dimension `dim_prior`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix (dim_a, dim_a) for OT data fitting
+ - reg0 and :math:`\mathbf{M0}` are respectively the regularization term and the cost matrix (dim_prior, n_atoms) for regularization
- :math:`\\alpha` weights data fitting and regularization
The optimization problem is solved using the algorithm described in [4]
@@ -1217,16 +1408,16 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
Parameters
----------
- a : np.ndarray (d)
- observed distribution
- D : np.ndarray (d,n)
+ a : ndarray, shape (dim_a)
+ observed distribution (histogram, sums to 1)
+ D : ndarray, shape (dim_a, n_atoms)
dictionary matrix
- M : np.ndarray (d,d)
+ M : ndarray, shape (dim_a, dim_a)
loss matrix
- M0 : np.ndarray (n,n)
+ M0 : ndarray, shape (n_atoms, dim_prior)
loss matrix
- h0 : np.ndarray (n,)
- prior on h
+ h0 : ndarray, shape (n_atoms,)
+ prior on the estimated unmixing h
reg : float
Regularization term >0 (Wasserstein data fitting)
reg0 : float
@@ -1245,7 +1436,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
Returns
-------
- a : (d,) ndarray
+ h : ndarray, shape (n_atoms,)
Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -1301,7 +1492,9 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
return np.sum(K0, axis=1)
-def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
+ numIterMax=10000, stopThr=1e-9, verbose=False,
+ log=False, **kwargs):
r'''
Solve the entropic regularization optimal transport problem and return the
OT matrix from empirical data
@@ -1318,22 +1511,22 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
\gamma\geq 0
where :
- - :math:`M` is the (ns,nt) metric cost matrix
+ - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`a` and :math:`b` are source and target weights (sum to 1)
Parameters
----------
- X_s : np.ndarray (ns, d)
+ X_s : ndarray, shape (n_samples_a, dim)
samples in the source domain
- X_t : np.ndarray (nt, d)
+ X_t : ndarray, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : np.ndarray (ns,)
+ a : ndarray, shape (n_samples_a,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
@@ -1347,7 +1540,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (n_samples_a, n_samples_b)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1355,11 +1548,11 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
Examples
--------
- >>> n_s = 2
- >>> n_t = 2
+ >>> n_samples_a = 2
+ >>> n_samples_b = 2
>>> reg = 0.1
- >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
- >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+ >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
>>> empirical_sinkhorn(X_s, X_t, reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE
array([[4.99977301e-01, 2.26989344e-05],
[2.26989344e-05, 4.99977301e-01]])
@@ -1408,22 +1601,22 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
\gamma\geq 0
where :
- - :math:`M` is the (ns,nt) metric cost matrix
+ - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`a` and :math:`b` are source and target weights (sum to 1)
Parameters
----------
- X_s : np.ndarray (ns, d)
+ X_s : ndarray, shape (n_samples_a, dim)
samples in the source domain
- X_t : np.ndarray (nt, d)
+ X_t : ndarray, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : np.ndarray (ns,)
+ a : ndarray, shape (n_samples_a,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
@@ -1437,7 +1630,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (n_samples_a, n_samples_b)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1445,11 +1638,11 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
Examples
--------
- >>> n_s = 2
- >>> n_t = 2
+ >>> n_samples_a = 2
+ >>> n_samples_b = 2
>>> reg = 0.1
- >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
- >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+ >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
>>> empirical_sinkhorn2(X_s, X_t, reg, verbose=False)
array([4.53978687e-05])
@@ -1516,22 +1709,22 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
\gamma_b\geq 0
where :
- - :math:`M` (resp. :math:`M_a, M_b`) is the (ns,nt) metric cost matrix (resp (ns, ns) and (nt, nt))
+ - :math:`M` (resp. :math:`M_a, M_b`) is the (n_samples_a, n_samples_b) metric cost matrix (resp (n_samples_a, n_samples_a) and (n_samples_b, n_samples_b))
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`a` and :math:`b` are source and target weights (sum to 1)
Parameters
----------
- X_s : np.ndarray (ns, d)
+ X_s : ndarray, shape (n_samples_a, dim)
samples in the source domain
- X_t : np.ndarray (nt, d)
+ X_t : ndarray, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : np.ndarray (ns,)
+ a : ndarray, shape (n_samples_a,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
@@ -1542,29 +1735,26 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (n_samples_a, n_samples_b)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
Examples
--------
-
- >>> n_s = 2
- >>> n_t = 4
+ >>> n_samples_a = 2
+ >>> n_samples_b = 4
>>> reg = 0.1
- >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
- >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+ >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
>>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS
array([1.499...])
References
----------
-
.. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
'''
if log:
diff --git a/ot/da.py b/ot/da.py
index 83f9027..108a38d 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -6,6 +6,7 @@ Domain adaptation with optimal transport
# Author: Remi Flamary <remi.flamary@unice.fr>
# Nicolas Courty <ncourty@irisa.fr>
# Michael Perrot <michael.perrot@univ-st-etienne.fr>
+# Nathalie Gayraud <nat.gayraud@gmail.com>
#
# License: MIT License
@@ -16,6 +17,7 @@ from .bregman import sinkhorn
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization
from .utils import check_params, BaseEstimator
+from .unbalanced import sinkhorn_unbalanced
from .optim import cg
from .optim import gcg
@@ -1793,3 +1795,122 @@ class MappingTransport(BaseEstimator):
transp_Xs = K.dot(self.mapping_)
return transp_Xs
+
+
+class UnbalancedSinkhornTransport(BaseTransport):
+
+ """Domain Adaptation unbalanced OT method based on sinkhorn algorithm
+
+ Parameters
+ ----------
+ reg_e : float, optional (default=1)
+ Entropic regularization parameter
+ reg_m : float, optional (default=0.1)
+ Mass regularization parameter
+ method : str
+ method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
+ 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ max_iter : int, float, optional (default=10)
+ The minimum number of iterations before stopping the optimization
+ algorithm if it has not converged
+ tol : float, optional (default=10e-9)
+ Stop threshold on error (inner sinkhorn solver) (>0)
+ verbose : bool, optional (default=False)
+ Controls the verbosity of the optimization algorithm
+ log : bool, optional (default=False)
+ Controls the logs of the optimization algorithm
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
+ distribution_estimation : callable, optional (defaults to the uniform)
+ The kind of distribution estimation to employ
+ out_of_sample_map : string, optional (default="ferradans")
+ The kind of out of sample mapping to apply to transport samples
+ from a domain into another one. Currently the only possible option is
+ "ferradans" which uses the method proposed in [6].
+ limit_max: float, optional (default=10)
+ Controls the semi-supervised mode. Transport between labeled source
+ and target samples of different classes will exhibit a prohibitively
+ large cost (limit_max times the maximum value of the cost matrix)
+
+ Attributes
+ ----------
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
+ log_ : dictionary
+ The dictionary of log, empty dic if parameter log is not True
+
+ References
+ ----------
+
+ .. [1] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
+ """
+
+ def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn',
+ max_iter=10, tol=1e-9, verbose=False, log=False,
+ metric="sqeuclidean", norm=None,
+ distribution_estimation=distribution_estimation_uniform,
+ out_of_sample_map='ferradans', limit_max=10):
+
+ self.reg_e = reg_e
+ self.reg_m = reg_m
+ self.method = method
+ self.max_iter = max_iter
+ self.tol = tol
+ self.verbose = verbose
+ self.log = log
+ self.metric = metric
+ self.norm = norm
+ self.distribution_estimation = distribution_estimation
+ self.out_of_sample_map = out_of_sample_map
+ self.limit_max = limit_max
+
+ def fit(self, Xs, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs, Xt=Xt):
+
+ super(UnbalancedSinkhornTransport, self).fit(Xs, ys, Xt, yt)
+
+ returned_ = sinkhorn_unbalanced(
+ a=self.mu_s, b=self.mu_t, M=self.cost_,
+ reg=self.reg_e, reg_m=self.reg_m, method=self.method,
+ numItermax=self.max_iter, stopThr=self.tol,
+ verbose=self.verbose, log=self.log)
+
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.log_ = returned_
+ else:
+ self.coupling_ = returned_
+ self.log_ = dict()
+
+ return self
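A usage sketch of the new estimator on toy classification clouds, following the fit/transform pattern of the other transport classes (transform is inherited from BaseTransport; data illustrative only):

    import ot

    Xs, ys = ot.datasets.make_data_classif('3gauss', n=50)
    Xt, yt = ot.datasets.make_data_classif('3gauss', n=60)

    mapping = ot.da.UnbalancedSinkhornTransport(reg_e=1., reg_m=0.1,
                                                max_iter=100, tol=1e-9)
    mapping.fit(Xs=Xs, Xt=Xt)
    Xs_new = mapping.transform(Xs=Xs)   # transported source samples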
diff --git a/ot/datasets.py b/ot/datasets.py
index e76e75d..ba0cfd9 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -17,7 +17,6 @@ def make_1D_gauss(n, m, s):
Parameters
----------
-
n : int
number of bins in the histogram
m : float
@@ -25,12 +24,10 @@ def make_1D_gauss(n, m, s):
s : float
standard deviation of the gaussian distribution
-
Returns
-------
- h : np.array (n,)
- 1D histogram for a gaussian distribution
-
+ h : ndarray (n,)
+ 1D histogram for a gaussian distribution
"""
x = np.arange(n, dtype=np.float64)
h = np.exp(-(x - m)**2 / (2 * s**2))
@@ -44,16 +41,15 @@ def get_1D_gauss(n, m, sigma):
def make_2D_samples_gauss(n, m, sigma, random_state=None):
- """return n samples drawn from 2D gaussian N(m,sigma)
+ """Return n samples drawn from 2D gaussian N(m,sigma)
Parameters
----------
-
n : int
number of samples to make
- m : np.array (2,)
+ m : ndarray, shape (2,)
mean value of the gaussian distribution
- sigma : np.array (2,2)
+ sigma : ndarray, shape (2, 2)
covariance matrix of the gaussian distribution
random_state : int, RandomState instance or None, optional (default=None)
If int, random_state is the seed used by the random number generator;
@@ -63,9 +59,8 @@ def make_2D_samples_gauss(n, m, sigma, random_state=None):
Returns
-------
- X : np.array (n,2)
- n samples drawn from N(m,sigma)
-
+ X : ndarray, shape (n, 2)
+ n samples drawn from N(m, sigma).
"""
generator = check_random_state(random_state)
@@ -86,11 +81,10 @@ def get_2D_samples_gauss(n, m, sigma, random_state=None):
def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
- """ dataset generation for classification problems
+ """Dataset generation for classification problems
Parameters
----------
-
dataset : str
type of classification problem (see code)
n : int
@@ -105,13 +99,11 @@ def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
Returns
-------
- X : np.array (n,d)
- n observation of size d
- y : np.array (n,)
- labels of the samples
-
+ X : ndarray, shape (n, d)
+ n observations of size d
+ y : ndarray, shape (n,)
+ labels of the samples.
"""
-
generator = check_random_state(random_state)
if dataset.lower() == '3gauss':
diff --git a/ot/dr.py b/ot/dr.py
index d2bf6e2..680dabf 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -49,30 +49,25 @@ def split_classes(X, y):
def fda(X, y, p=2, reg=1e-16):
- """
- Fisher Discriminant Analysis
-
+ """Fisher Discriminant Analysis
Parameters
----------
- X : numpy.ndarray (n,d)
- Training samples
- y : np.ndarray (n,)
- labels for training samples
+ X : ndarray, shape (n, d)
+ Training samples.
+ y : ndarray, shape (n,)
+ Labels for training samples.
p : int, optional
- size of dimensionnality reduction
+ Size of dimensionality reduction.
reg : float, optional
Regularization term >0 (ridge regularization)
-
Returns
-------
- P : (d x p) ndarray
+ P : ndarray, shape (d, p)
Optimal transportation matrix for the given parameters
- proj : fun
+ proj : callable
projection function including mean centering
-
-
"""
mx = np.mean(X)
@@ -130,37 +125,33 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
Parameters
----------
- X : numpy.ndarray (n,d)
- Training samples
- y : np.ndarray (n,)
- labels for training samples
+ X : ndarray, shape (n, d)
+ Training samples.
+ y : ndarray, shape (n,)
+ Labels for training samples.
p : int, optional
- size of dimensionnality reduction
+ Size of dimensionality reduction.
reg : float, optional
Regularization term >0 (entropic regularization)
- solver : str, optional
- None for steepest decsent or 'TrustRegions' for trust regions algorithm
- else shoudl be a pymanopt.solvers
- P0 : numpy.ndarray (d,p)
- Initial starting point for projection
+ solver : None | str, optional
+ None for steepest descent or 'TrustRegions' for trust regions algorithm
+ else should be a pymanopt.solvers
+ P0 : ndarray, shape (d, p)
+ Initial starting point for projection.
verbose : int, optional
- Print information along iterations
-
-
+ Print information along iterations.
Returns
-------
- P : (d x p) ndarray
+ P : ndarray, shape (d, p)
Optimal transportation matrix for the given parameters
- proj : fun
- projection function including mean centering
-
+ proj : callable
+ Projection function including mean centering.
References
----------
-
- .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
-
+ .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
+ Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
""" # noqa
mx = np.mean(X)
diff --git a/ot/gromov.py b/ot/gromov.py
index 3a7e24c..699ae4c 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -1,9 +1,6 @@
-
# -*- coding: utf-8 -*-
"""
Gromov-Wasserstein transport method
-
-
"""
# Author: Erwan Vautier <erwan.vautier@gmail.com>
@@ -22,7 +19,7 @@ from .optim import cg
def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
- """ Return loss matrices and tensors for Gromov-Wasserstein fast computation
+ """Return loss matrices and tensors for Gromov-Wasserstein fast computation
Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss
function as the loss function of Gromov-Wasserstein discrepancy.
@@ -51,23 +48,21 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
Parameters
----------
C1 : ndarray, shape (ns, ns)
- Metric cost matrix in the source space
+ Metric cost matrix in the source space
C2 : ndarray, shape (nt, nt)
- Metric costfr matrix in the target space
+ Metric cost matrix in the target space
T : ndarray, shape (ns, nt)
- Coupling between source and target spaces
+ Coupling between source and target spaces
p : ndarray, shape (ns,)
-
Returns
-------
-
constC : ndarray, shape (ns, nt)
- Constant C matrix in Eq. (6)
+ Constant C matrix in Eq. (6)
hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
+ h1(C1) matrix in Eq. (6)
hC2 : ndarray, shape (nt, nt)
- h2(C) matrix in Eq. (6)
+ h2(C) matrix in Eq. (6)
References
----------
@@ -114,25 +109,23 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
def tensor_product(constC, hC1, hC2, T):
- """ Return the tensor for Gromov-Wasserstein fast computation
+ """Return the tensor for Gromov-Wasserstein fast computation
The tensor is computed as described in Proposition 1 Eq. (6) in [12].
Parameters
----------
constC : ndarray, shape (ns, nt)
- Constant C matrix in Eq. (6)
+ Constant C matrix in Eq. (6)
hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
+ h1(C1) matrix in Eq. (6)
hC2 : ndarray, shape (nt, nt)
- h2(C) matrix in Eq. (6)
-
+ h2(C) matrix in Eq. (6)
Returns
-------
-
tens : ndarray, shape (ns, nt)
- \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
+ \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
References
----------
@@ -148,26 +141,25 @@ def tensor_product(constC, hC1, hC2, T):
def gwloss(constC, hC1, hC2, T):
- """ Return the Loss for Gromov-Wasserstein
+ """Return the Loss for Gromov-Wasserstein
The loss is computed as described in Proposition 1 Eq. (6) in [12].
Parameters
----------
constC : ndarray, shape (ns, nt)
- Constant C matrix in Eq. (6)
+ Constant C matrix in Eq. (6)
hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
+ h1(C1) matrix in Eq. (6)
hC2 : ndarray, shape (nt, nt)
- h2(C) matrix in Eq. (6)
+ h2(C) matrix in Eq. (6)
T : ndarray, shape (ns, nt)
- Current value of transport matrix T
+ Current value of transport matrix T
Returns
-------
-
loss : float
- Gromov Wasserstein loss
+ Gromov Wasserstein loss
References
----------
@@ -183,24 +175,23 @@ def gwloss(constC, hC1, hC2, T):
def gwggrad(constC, hC1, hC2, T):
- """ Return the gradient for Gromov-Wasserstein
+ """Return the gradient for Gromov-Wasserstein
The gradient is computed as described in Proposition 2 in [12].
Parameters
----------
constC : ndarray, shape (ns, nt)
- Constant C matrix in Eq. (6)
+ Constant C matrix in Eq. (6)
hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
+ h1(C1) matrix in Eq. (6)
hC2 : ndarray, shape (nt, nt)
- h2(C) matrix in Eq. (6)
+ h2(C) matrix in Eq. (6)
T : ndarray, shape (ns, nt)
- Current value of transport matrix T
+ Current value of transport matrix T
Returns
-------
-
grad : ndarray, shape (ns, nt)
Gromov Wasserstein gradient
@@ -222,19 +213,19 @@ def update_square_loss(p, lambdas, T, Cs):
Parameters
----------
- p : ndarray, shape (N,)
- masses in the targeted barycenter
+ p : ndarray, shape (N,)
+ Masses in the targeted barycenter.
lambdas : list of float
- list of the S spaces' weights
- T : list of S np.ndarray(ns,N)
- the S Ts couplings calculated at each iteration
+ List of the S spaces' weights.
+ T : list of S np.ndarray of shape (ns,N)
+ The S Ts couplings calculated at each iteration.
Cs : list of S ndarray, shape(ns,ns)
- Metric cost matrices
+ Metric cost matrices.
Returns
----------
- C : ndarray, shape (nt,nt)
- updated C matrix
+ C : ndarray, shape (nt, nt)
+ Updated C matrix.
"""
tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s])
for s in range(len(T))])
@@ -251,12 +242,12 @@ def update_kl_loss(p, lambdas, T, Cs):
Parameters
----------
p : ndarray, shape (N,)
- weights in the targeted barycenter
+ Weights in the targeted barycenter.
lambdas : list of the S spaces' weights
- T : list of S np.ndarray(ns,N)
- the S Ts couplings calculated at each iteration
+ T : list of S np.ndarray of shape (ns,N)
+ The S Ts couplings calculated at each iteration.
Cs : list of S ndarray, shape(ns,ns)
- Metric cost matrices
+ Metric cost matrices.
Returns
----------
@@ -290,14 +281,14 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
Parameters
----------
C1 : ndarray, shape (ns, ns)
- Metric cost matrix in the source space
+ Metric cost matrix in the source space
C2 : ndarray, shape (nt, nt)
- Metric costfr matrix in the target space
- p : ndarray, shape (ns,)
- distribution in the source space
- q : ndarray, shape (nt,)
- distribution in the target space
- loss_fun : string
+ Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'
max_iter : int, optional
@@ -317,10 +308,10 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
Returns
-------
T : ndarray, shape (ns, nt)
- coupling between the two spaces that minimizes :
+ Coupling between the two spaces that minimizes:
\sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
log : dict
- convergence information and loss
+ Convergence information and loss.
References
----------
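A usage sketch for the solver documented above (random point clouds as stand-in data, not part of this patch):

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    xs, xt = rng.randn(10, 2), rng.randn(8, 2)
    C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
    C1 /= C1.max(); C2 /= C2.max()
    p, q = ot.unif(10), ot.unif(8)
    T, log = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', log=True)
    # T is the (10, 8) coupling; log carries convergence information and loss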
@@ -374,18 +365,18 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
Parameters
----------
- M : ndarray, shape (ns, nt)
- Metric cost matrix between features across domains
+ M : ndarray, shape (ns, nt)
+ Metric cost matrix between features across domains
C1 : ndarray, shape (ns, ns)
- Metric cost matrix representative of the structure in the source space
+ Metric cost matrix representative of the structure in the source space
C2 : ndarray, shape (nt, nt)
- Metric cost matrix representative of the structure in the target space
- p : ndarray, shape (ns,)
- distribution in the source space
- q : ndarray, shape (nt,)
- distribution in the target space
- loss_fun : string,optional
- loss function used for the solver
+ Metric cost matrix representative of the structure in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ loss_fun : str, optional
+ Loss function used for the solver
max_iter : int, optional
Max number of iterations
tol : float, optional
@@ -402,11 +393,10 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
Returns
-------
- gamma : (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
+ gamma : ndarray, shape (ns, nt)
+ Optimal transportation matrix for the given parameters.
log : dict
- log dictionary return only if log==True in parameters
-
+ Log dictionary return only if log==True in parameters.
References
----------
@@ -414,7 +404,6 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
and Courty Nicolas "Optimal Transport for structured data with
application on graphs", International Conference on Machine Learning
(ICML). 2019.
-
"""
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
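A matching sketch for the fused solver (features and structures below are random placeholders):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    ys, yt = rng.randn(10, 3), rng.randn(8, 3)      # node features
    M = ot.dist(ys, yt)                             # feature cost across domains
    C1, C2 = ot.dist(ys, ys), ot.dist(yt, yt)       # structure matrices
    p, q = ot.unif(10), ot.unif(8)
    gamma = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q,
                                               loss_fun='square_loss', alpha=0.5)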
@@ -457,18 +446,18 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
Parameters
----------
- M : ndarray, shape (ns, nt)
- Metric cost matrix between features across domains
+ M : ndarray, shape (ns, nt)
+ Metric cost matrix between features across domains
C1 : ndarray, shape (ns, ns)
- Metric cost matrix respresentative of the structure in the source space
+ Metric cost matrix representative of the structure in the source space.
C2 : ndarray, shape (nt, nt)
- Metric cost matrix espresentative of the structure in the target space
+ Metric cost matrix representative of the structure in the target space.
p : ndarray, shape (ns,)
- distribution in the source space
+ Distribution in the source space.
q : ndarray, shape (nt,)
- distribution in the target space
- loss_fun : string,optional
- loss function used for the solver
+ Distribution in the target space.
+ loss_fun : str, optional
+ Loss function used for the solver.
max_iter : int, optional
Max number of iterations
tol : float, optional
@@ -476,19 +465,19 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
verbose : bool, optional
Print information along iterations
log : bool, optional
- record log if True
+ Record log if True.
armijo : bool, optional
- If True the steps of the line-search is found via an armijo research. Else closed form is used.
- If there is convergence issues use False.
+ If True, the step of the line search is found via an Armijo search.
+ Otherwise the closed form is used. If there are convergence issues, use False.
**kwargs : dict
- parameters can be directly pased to the ot.optim.cg solver
+ Parameters can be directly passed to the ot.optim.cg solver.
Returns
-------
- gamma : (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
+ gamma : ndarray, shape (ns, nt)
+ Optimal transportation matrix for the given parameters.
log : dict
- log dictionary return only if log==True in parameters
+ Log dictionary return only if log==True in parameters.
References
----------
@@ -537,16 +526,15 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg
Parameters
----------
C1 : ndarray, shape (ns, ns)
- Metric cost matrix in the source space
+ Metric cost matrix in the source space
C2 : ndarray, shape (nt, nt)
- Metric cost matrix in the target space
- p : ndarray, shape (ns,)
- distribution in the source space
+ Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space.
q : ndarray, shape (nt,)
- distribution in the target space
- loss_fun : string
+ Distribution in the target space.
+ loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'
-
max_iter : int, optional
Max number of iterations
tol : float, optional
@@ -558,6 +546,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg
armijo : bool, optional
If True, the step of the line search is found via an Armijo search. Else the closed form is used.
If there are convergence issues, use False.
+
Returns
-------
gw_dist : float
@@ -624,25 +613,25 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
Parameters
----------
C1 : ndarray, shape (ns, ns)
- Metric cost matrix in the source space
+ Metric cost matrix in the source space
C2 : ndarray, shape (nt, nt)
- Metric costfr matrix in the target space
+ Metric costfr matrix in the target space
p : ndarray, shape (ns,)
- distribution in the source space
+ Distribution in the source space
q : ndarray, shape (nt,)
- distribution in the target space
+ Distribution in the target space
loss_fun : string
- loss function used for the solver either 'square_loss' or 'kl_loss'
+ Loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float
Regularization term >0
max_iter : int, optional
- Max number of iterations
+ Max number of iterations
tol : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
- record log if True
+ Record log if True.
Returns
-------
@@ -725,15 +714,15 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
Parameters
----------
C1 : ndarray, shape (ns, ns)
- Metric cost matrix in the source space
+ Metric cost matrix in the source space
C2 : ndarray, shape (nt, nt)
- Metric costfr matrix in the target space
+ Metric costfr matrix in the target space
p : ndarray, shape (ns,)
- distribution in the source space
+ Distribution in the source space
q : ndarray, shape (nt,)
- distribution in the target space
- loss_fun : string
- loss function used for the solver either 'square_loss' or 'kl_loss'
+ Distribution in the target space
+ loss_fun : str
+ Loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float
Regularization term >0
max_iter : int, optional
@@ -743,7 +732,7 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
verbose : bool, optional
Print information along iterations
log : bool, optional
- record log if True
+ Record log if True.
Returns
-------
@@ -757,7 +746,6 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
International Conference on Machine Learning (ICML). 2016.
"""
-
gw, logv = entropic_gromov_wasserstein(
C1, C2, p, q, loss_fun, epsilon, max_iter, tol, verbose, log=True)
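For reference, a sketch of calling the underlying entropic solver directly (toy data; the epsilon value is arbitrary):

    import numpy as np
    import ot

    rng = np.random.RandomState(1)
    xs, xt = rng.randn(10, 2), rng.randn(8, 2)
    C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
    C1 /= C1.max(); C2 /= C2.max()
    p, q = ot.unif(10), ot.unif(8)
    T = ot.gromov.entropic_gromov_wasserstein(C1, C2, p, q, 'square_loss',
                                              epsilon=0.1)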
@@ -789,19 +777,21 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
Parameters
----------
- N : Integer
- Size of the targeted barycenter
- Cs : list of S np.ndarray(ns,ns)
- Metric cost matrices
- ps : list of S np.ndarray(ns,)
- sample weights in the S spaces
- p : ndarray, shape(N,)
- weights in the targeted barycenter
+ N : int
+ Size of the targeted barycenter
+ Cs : list of S np.ndarray of shape (ns,ns)
+ Metric cost matrices
+ ps : list of S np.ndarray of shape (ns,)
+ Sample weights in the S spaces
+ p : ndarray, shape(N,)
+ Weights in the targeted barycenter
lambdas : list of float
- list of the S spaces' weights
- loss_fun : tensor-matrix multiplication function based on specific loss function
- update : function(p,lambdas,T,Cs) that updates C according to a specific Kernel
- with the S Ts couplings calculated at each iteration
+ List of the S spaces' weights.
+ loss_fun : callable
+ Tensor-matrix multiplication function based on specific loss function.
+ update : callable
+ function(p,lambdas,T,Cs) that updates C according to a specific Kernel
+ with the S Ts couplings calculated at each iteration
epsilon : float
Regularization term >0
max_iter : int, optional
@@ -809,11 +799,11 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
tol : float, optional
Stop threshold on error (>0)
verbose : bool, optional
- Print information along iterations
+ Print information along iterations.
log : bool, optional
- record log if True
- init_C : bool, ndarray, shape(N,N)
- random initial value for the C matrix provided by user
+ Record log if True.
+ init_C : bool | ndarray, shape (N, N)
+ Random initial value for the C matrix provided by user.
Returns
-------
@@ -825,7 +815,6 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
-
"""
S = len(Cs)
@@ -835,6 +824,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
# Initialization of C : random SPD matrix (if not provided by user)
if init_C is None:
+ # XXX use random state
xalea = np.random.randn(N, 2)
C = dist(xalea, xalea)
C /= C.max()
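One way to resolve the XXX above, sketched with a hypothetical `random_state` parameter (not part of this patch), so the init no longer consumes the global seed:

    rng = np.random.RandomState(random_state)   # random_state: hypothetical argument
    xalea = rng.randn(N, 2)
    C = dist(xalea, xalea)
    C /= C.max()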
@@ -846,7 +836,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
error = []
- while(err > tol and cpt < max_iter):
+ while (err > tol) and (cpt < max_iter):
Cprev = C
T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
@@ -890,7 +880,6 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
.. math::
C = argmin_C\in R^NxN \sum_s \lambda_s GW(C,Cs,p,ps)
-
Where :
- Cs : metric cost matrix
@@ -898,29 +887,29 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
Parameters
----------
- N : Integer
- Size of the targeted barycenter
- Cs : list of S np.ndarray(ns,ns)
- Metric cost matrices
- ps : list of S np.ndarray(ns,)
- sample weights in the S spaces
- p : ndarray, shape(N,)
- weights in the targeted barycenter
+ N : int
+ Size of the targeted barycenter
+ Cs : list of S np.ndarray of shape (ns, ns)
+ Metric cost matrices
+ ps : list of S np.ndarray of shape (ns,)
+ Sample weights in the S spaces
+ p : ndarray, shape (N,)
+ Weights in the targeted barycenter
lambdas : list of float
- list of the S spaces' weights
+ List of the S spaces' weights
loss_fun : tensor-matrix multiplication function based on specific loss function
update : function(p,lambdas,T,Cs) that updates C according to a specific Kernel
with the S Ts couplings calculated at each iteration
max_iter : int, optional
Max number of iterations
tol : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0).
verbose : bool, optional
- Print information along iterations
+ Print information along iterations.
log : bool, optional
- record log if True
- init_C : bool, ndarray, shape(N,N)
- random initial value for the C matrix provided by user
+ Record log if True.
+ init_C : bool | ndarray, shape(N,N)
+ Random initial value for the C matrix provided by user.
Returns
-------
@@ -934,7 +923,6 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
International Conference on Machine Learning (ICML). 2016.
"""
-
S = len(Cs)
Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
@@ -942,6 +930,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
# Initialization of C : random SPD matrix (if not provided by user)
if init_C is None:
+ # XXX : should use a random state and not use the global seed
xalea = np.random.randn(N, 2)
C = dist(xalea, xalea)
C /= C.max()
@@ -987,8 +976,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
verbose=False, log=False, init_C=None, init_X=None):
- """
- Compute the fgw barycenter as presented eq (5) in [24].
+ """Compute the fgw barycenter as presented eq (5) in [24].
Parameters
----------
@@ -997,30 +985,32 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Ys: list of ndarray, each element has shape (ns,d)
Features of all samples
Cs : list of ndarray, each element has shape (ns,ns)
- Structure matrices of all samples
+ Structure matrices of all samples
ps : list of ndarray, each element has shape (ns,)
- masses of all samples
+ Masses of all samples.
lambdas : list of float
- list of the S spaces' weights
+ List of the S spaces' weights
alpha : float
- Alpha parameter for the fgw distance
- fixed_structure : bool
- Wether to fix the structure of the barycenter during the updates
- fixed_features : bool
- Wether to fix the feature of the barycenter during the updates
- init_C : ndarray, shape (N,N), optional
- initialization for the barycenters' structure matrix. If not set random init
- init_X : ndarray, shape (N,d), optional
- initialization for the barycenters' features. If not set random init
+ Alpha parameter for the fgw distance
+ fixed_structure : bool
+ Whether to fix the structure of the barycenter during the updates
+ fixed_features : bool
+ Whether to fix the feature of the barycenter during the updates
+ init_C : ndarray, shape (N,N), optional
+ Initialization for the barycenters' structure matrix. If not set
+ a random init is used.
+ init_X : ndarray, shape (N,d), optional
+ Initialization for the barycenters' features. If not set a
+ random init is used.
Returns
-------
- X : ndarray, shape (N,d)
+ X : ndarray, shape (N, d)
Barycenters' features
- C : ndarray, shape (N,N)
+ C : ndarray, shape (N, N)
Barycenters' structure matrix
- log_: dictionary
- Only returned when log=True
+ log_: dict
+ Only returned when log=True. It contains the keys:
T : list of (N,ns) transport matrices
Ms : all distance matrices between the feature of the barycenter and the
other features dist(X,Ys) shape (N,ns)
@@ -1032,7 +1022,6 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
-
S = len(Cs)
d = Ys[0].shape[1] # dimension on the node features
if p is None:
@@ -1095,7 +1084,8 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
T_temp = [t.T for t in T]
C = update_sructure_matrix(p, lambdas, T_temp, Cs)
- T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
+ T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
+ numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
# T is N,ns
err_feature = np.linalg.norm(X - Xprev.reshape(N, d))
@@ -1114,6 +1104,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
print('{:5d}|{:8e}|'.format(cpt, err_feature))
cpt += 1
+
if log:
log_['T'] = T # from target to Ys
log_['p'] = p
@@ -1126,25 +1117,25 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
def update_sructure_matrix(p, lambdas, T, Cs):
- """
- Updates C according to the L2 Loss kernel with the S Ts couplings
- calculated at each iteration
+ """Updates C according to the L2 Loss kernel with the S Ts couplings.
+
+ The couplings are calculated at each iteration.
Parameters
----------
- p : ndarray, shape (N,)
- masses in the targeted barycenter
+ p : ndarray, shape (N,)
+ Masses in the targeted barycenter.
lambdas : list of float
- list of the S spaces' weights
- T : list of S np.ndarray(ns,N)
- the S Ts couplings calculated at each iteration
- Cs : list of S ndarray, shape(ns,ns)
- Metric cost matrices
+ List of the S spaces' weights.
+ T : list of S ndarray of shape (ns, N)
+ The S Ts couplings calculated at each iteration.
+ Cs : list of S ndarray, shape (ns, ns)
+ Metric cost matrices.
Returns
- ----------
- C : ndarray, shape (nt,nt)
- updated C matrix
+ -------
+ C : ndarray, shape (nt, nt)
+ Updated C matrix.
"""
tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
ppt = np.outer(p, p)
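A sanity check of this closed-form update, assuming the elided return is `np.divide(tmpsum, ppt)`: with a single space and the identity-like coupling T = diag(p), the division by outer(p, p) recovers the input structure matrix.

    import numpy as np
    from ot.gromov import update_sructure_matrix

    N = 4
    rng = np.random.RandomState(0)
    C = rng.rand(N, N); C = (C + C.T) / 2
    p = rng.rand(N); p /= p.sum()
    T = np.diag(p)                     # (T.T C T)_ij = p_i C_ij p_j
    assert np.allclose(update_sructure_matrix(p, [1.], [T], [C]), C)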
@@ -1153,24 +1144,26 @@ def update_sructure_matrix(p, lambdas, T, Cs):
def update_feature_matrix(lambdas, Ys, Ts, p):
- """
- Updates the feature with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [24]
- calculated at each iteration
+ """Updates the feature with respect to the S Ts couplings.
+
+
+ See "Solving the barycenter problem with Block Coordinate Descent (BCD)"
+ in [24] calculated at each iteration
Parameters
----------
- p : ndarray, shape (N,)
- masses in the targeted barycenter
+ p : ndarray, shape (N,)
+ Masses in the targeted barycenter.
lambdas : list of float
- list of the S spaces' weights
+ List of the S spaces' weights
Ts : list of S np.ndarray(ns,N)
the S Ts couplings calculated at each iteration
Ys : list of S ndarray, shape(d,ns)
- The features
+ The features.
Returns
- ----------
- X : ndarray, shape (d,N)
+ -------
+ X : ndarray, shape (d, N)
+ Updated feature matrix.
References
----------
@@ -1179,9 +1172,8 @@ def update_feature_matrix(lambdas, Ys, Ts, p):
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
+ p = np.array(1. / p).reshape(-1,)
- p = np.diag(np.array(1 / p).reshape(-1,))
-
- tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T).dot(p) for s in range(len(Ts))])
+ tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T) * p[None, :] for s in range(len(Ts))])
return tmpsum
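A quick equivalence check for the broadcast rewrite above (shapes are illustrative): scaling columns by `p[None, :]` matches the former dense `np.diag` product.

    import numpy as np

    d, ns, N = 3, 5, 4
    rng = np.random.RandomState(0)
    Y, T = rng.rand(d, ns), rng.rand(N, ns)
    w = np.array(1. / (rng.rand(N) + 0.1)).reshape(-1,)
    old = np.dot(Y, T.T).dot(np.diag(w))   # previous version, builds an (N, N) diag
    new = np.dot(Y, T.T) * w[None, :]      # new version, pure broadcasting
    assert np.allclose(old, new)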
diff --git a/ot/optim.py b/ot/optim.py
index f94aceb..0abd9e9 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -26,14 +26,13 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
Parameters
----------
-
- f : function
+ f : callable
loss function
- xk : np.ndarray
+ xk : ndarray
initial position
- pk : np.ndarray
+ pk : ndarray
descent direction
- gfk : np.ndarray
+ gfk : ndarray
gradient of f at xk
old_fval : float
loss value at xk
@@ -161,15 +160,15 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (nt,)
samples in the target domain
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
- G0 : np.ndarray (ns,nt), optional
+ G0 : ndarray, shape (ns,nt), optional
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
@@ -299,17 +298,17 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (nt,)
samples in the target domain
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg1 : float
Entropic Regularization term >0
reg2 : float
Second Regularization term >0
- G0 : np.ndarray (ns,nt), optional
+ G0 : ndarray, shape (ns, nt), optional
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
@@ -326,15 +325,13 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
-
References
----------
-
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
@@ -422,13 +419,12 @@ def solve_1d_linesearch_quad(a, b, c):
Parameters
----------
a,b,c : float
- The coefficients of the quadratic function
+ The coefficients of the quadratic function
Returns
-------
x : float
The optimal value which leads to the minimal cost
-
"""
f0 = c
df0 = b
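The closed form this sets up, sketched for both branches (the remainder of the function is elided here, so treat this as an assumption about its behaviour): for a > 0 the unconstrained minimizer -b/(2a) is clipped to [0, 1]; otherwise the better endpoint wins.

    def solve_1d_linesearch_quad_sketch(a, b, c):
        f0 = c                        # value at x = 0
        f1 = a + b + c                # value at x = 1
        if a > 0:                     # convex case: clip the stationary point
            return min(1., max(0., -b / (2. * a)))
        return 1. if f0 > f1 else 0.  # concave/linear case: best endpoint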
diff --git a/ot/plot.py b/ot/plot.py
index a409d4a..f403e98 100644
--- a/ot/plot.py
+++ b/ot/plot.py
@@ -26,11 +26,11 @@ def plot1D_mat(a, b, M, title=''):
Parameters
----------
- a : np.array, shape (na,)
+ a : ndarray, shape (na,)
Source distribution
- b : np.array, shape (nb,)
+ b : ndarray, shape (nb,)
Target distribution
- M : np.array, shape (na,nb)
+ M : ndarray, shape (na, nb)
Matrix to plot
"""
na, nb = M.shape
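A usage sketch with toy 1D histograms (the Gaussian-like data is a placeholder):

    import numpy as np
    import matplotlib.pyplot as plt
    import ot
    import ot.plot

    n = 50
    x = np.arange(n, dtype=np.float64)
    a = np.exp(-(x - 15) ** 2 / 50.); a /= a.sum()   # source histogram
    b = np.exp(-(x - 35) ** 2 / 50.); b /= b.sum()   # target histogram
    M = ot.dist(x.reshape(-1, 1), x.reshape(-1, 1))
    ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
    plt.show()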
diff --git a/ot/stochastic.py b/ot/stochastic.py
index 5754968..13ed9cc 100644
--- a/ot/stochastic.py
+++ b/ot/stochastic.py
@@ -38,22 +38,20 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
Parameters
----------
-
- b : np.ndarray(nt,)
- target measure
- M : np.ndarray(ns, nt)
- cost matrix
- reg : float nu
- Regularization term > 0
- v : np.ndarray(nt,)
- dual variable
- i : number int
- picked number i
+ b : ndarray, shape (nt,)
+ Target measure.
+ M : ndarray, shape (ns, nt)
+ Cost matrix.
+ reg : float
+ Regularization term > 0.
+ beta : ndarray, shape (nt,)
+ Dual variable.
+ i : int
+ Index of the picked coordinate.
Returns
-------
-
- coordinate gradient : np.ndarray(nt,)
+ coordinate gradient : ndarray, shape (nt,)
Examples
--------
@@ -78,14 +76,11 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
References
----------
-
[Genevay et al., 2016] :
- Stochastic Optimization for Large-scale Optimal Transport,
- Advances in Neural Information Processing Systems (2016),
- arXiv preprint arxiv:1605.08527.
-
+ Stochastic Optimization for Large-scale Optimal Transport,
+ Advances in Neural Information Processing Systems (2016),
+ arXiv preprint arxiv:1605.08527.
'''
-
r = M[i, :] - beta
exp_beta = np.exp(-r / reg) * b
khi = exp_beta / (np.sum(exp_beta))
@@ -121,24 +116,23 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
Parameters
----------
- a : np.ndarray(ns,),
- source measure
- b : np.ndarray(nt,),
- target measure
- M : np.ndarray(ns, nt),
- cost matrix
- reg : float number,
+ a : ndarray, shape (ns,)
+ Source measure.
+ b : ndarray, shape (nt,)
+ Target measure.
+ M : ndarray, shape (ns, nt)
+ Cost matrix.
+ reg : float
Regularization term > 0
- numItermax : int number
- number of iteration
- lr : float number
- learning rate
+ numItermax : int
+ Number of iterations.
+ lr : float
+ Learning rate.
Returns
-------
-
- v : np.ndarray(nt,)
- dual variable
+ v : ndarray, shape (nt,)
+ Dual variable.
Examples
--------
@@ -213,23 +207,20 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
Parameters
----------
-
- b : np.ndarray(nt,)
+ b : ndarray, shape (nt,)
target measure
- M : np.ndarray(ns, nt)
+ M : ndarray, shape (ns, nt)
cost matrix
- reg : float number
+ reg : float
Regularization term > 0
- numItermax : int number
- number of iteration
- lr : float number
- learning rate
-
+ numItermax : int
+ Number of iterations.
+ lr : float
+ Learning rate.
Returns
-------
-
- ave_v : np.ndarray(nt,)
+ ave_v : ndarray, shape (nt,)
dual variable
Examples
@@ -256,9 +247,9 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
----------
[Genevay et al., 2016] :
- Stochastic Optimization for Large-scale Optimal Transport,
- Advances in Neural Information Processing Systems (2016),
- arXiv preprint arxiv:1605.08527.
+ Stochastic Optimization for Large-scale Optimal Transport,
+ Advances in Neural Information Processing Systems (2016),
+ arXiv preprint arxiv:1605.08527.
'''
if lr is None:
@@ -298,21 +289,19 @@ def c_transform_entropic(b, M, reg, beta):
Parameters
----------
-
- b : np.ndarray(nt,)
- target measure
- M : np.ndarray(ns, nt)
- cost matrix
+ b : ndarray, shape (nt,)
+ Target measure
+ M : ndarray, shape (ns, nt)
+ Cost matrix
reg : float
- regularization term > 0
- v : np.ndarray(nt,)
- dual variable
+ Regularization term > 0
+ beta : ndarray, shape (nt,)
+ Dual variable.
Returns
-------
-
- u : np.ndarray(ns,)
- dual variable
+ u : ndarray, shape (ns,)
+ Dual variable.
Examples
--------
@@ -338,9 +327,9 @@ def c_transform_entropic(b, M, reg, beta):
----------
[Genevay et al., 2016] :
- Stochastic Optimization for Large-scale Optimal Transport,
- Advances in Neural Information Processing Systems (2016),
- arXiv preprint arxiv:1605.08527.
+ Stochastic Optimization for Large-scale Optimal Transport,
+ Advances in Neural Information Processing Systems (2016),
+ arXiv preprint arxiv:1605.08527.
'''
n_source = np.shape(M)[0]
@@ -382,31 +371,30 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
Parameters
----------
- a : np.ndarray(ns,)
+ a : ndarray, shape (ns,)
source measure
- b : np.ndarray(nt,)
+ b : ndarray, shape (nt,)
target measure
- M : np.ndarray(ns, nt)
+ M : ndarray, shape (ns, nt)
cost matrix
- reg : float number
+ reg : float
Regularization term > 0
method : str
used method (SAG or ASGD)
- numItermax : int number
+ numItermax : int
number of iterations
- lr : float number
+ lr : float
learning rate
- n_source : int number
+ n_source : int
size of the source measure
- n_target : int number
+ n_target : int
size of the target measure
log : bool, optional
record log if True
Returns
-------
-
- pi : np.ndarray(ns, nt)
+ pi : ndarray, shape (ns, nt)
transportation matrix
log : dict
log dictionary return only if log==True in parameters
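A usage sketch for the semi-dual solver (toy data; "SAG" as named in the method description above):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    a, b = ot.unif(7), ot.unif(4)
    M = rng.rand(7, 4)
    pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1., method="SAG",
                                                numItermax=10000)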
@@ -495,30 +483,28 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
Parameters
----------
-
- a : np.ndarray(ns,)
+ a : ndarray, shape (ns,)
source measure
- b : np.ndarray(nt,)
+ b : ndarray, shape (nt,)
target measure
- M : np.ndarray(ns, nt)
+ M : ndarray, shape (ns, nt)
cost matrix
- reg : float number
+ reg : float
Regularization term > 0
- alpha : np.ndarray(ns,)
+ alpha : ndarray, shape (ns,)
dual variable
- beta : np.ndarray(nt,)
+ beta : ndarray, shape (nt,)
dual variable
- batch_size : int number
+ batch_size : int
size of the batch
- batch_alpha : np.ndarray(bs,)
+ batch_alpha : ndarray, shape (bs,)
batch of indices of alpha
- batch_beta : np.ndarray(bs,)
+ batch_beta : ndarray, shape (bs,)
batch of indices of beta
Returns
-------
-
- grad : np.ndarray(ns,)
+ grad : ndarray, shape (ns,)
partial grad F
Examples
@@ -591,28 +577,26 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
Parameters
----------
-
- a : np.ndarray(ns,)
+ a : ndarray, shape (ns,)
source measure
- b : np.ndarray(nt,)
+ b : ndarray, shape (nt,)
target measure
- M : np.ndarray(ns, nt)
+ M : ndarray, shape (ns, nt)
cost matrix
- reg : float number
+ reg : float
Regularization term > 0
- batch_size : int number
+ batch_size : int
size of the batch
- numItermax : int number
+ numItermax : int
number of iterations
- lr : float number
+ lr : float
learning rate
Returns
-------
-
- alpha : np.ndarray(ns,)
+ alpha : ndarray, shape (ns,)
dual variable
- beta : np.ndarray(nt,)
+ beta : ndarray, shape (nt,)
dual variable
Examples
@@ -648,10 +632,9 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
References
----------
-
[Seguy et al., 2018] :
- International Conference on Learning Representation (2018),
- arXiv preprint arxiv:1711.02283.
+ International Conference on Learning Representation (2018),
+ arXiv preprint arxiv:1711.02283.
'''
n_source = np.shape(M)[0]
@@ -696,28 +679,26 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
Parameters
----------
-
- a : np.ndarray(ns,)
+ a : ndarray, shape (ns,)
source measure
- b : np.ndarray(nt,)
+ b : ndarray, shape (nt,)
target measure
- M : np.ndarray(ns, nt)
+ M : ndarray, shape (ns, nt)
cost matrix
- reg : float number
+ reg : float
Regularization term > 0
- batch_size : int number
+ batch_size : int
size of the batch
- numItermax : int number
+ numItermax : int
number of iterations
- lr : float number
+ lr : float
learning rate
log : bool, optional
record log if True
Returns
-------
-
- pi : np.ndarray(ns, nt)
+ pi : ndarray, shape (ns, nt)
transportation matrix
log : dict
log dictionary return only if log==True in parameters
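And the dual counterpart (batch size and learning rate are arbitrary choices):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    a, b = ot.unif(7), ot.unif(4)
    M = rng.rand(7, 4)
    pi = ot.stochastic.solve_dual_entropic(a, b, M, reg=1., batch_size=3,
                                           numItermax=10000, lr=0.1)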
@@ -757,8 +738,8 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
----------
[Seguy et al., 2018] :
- International Conference on Learning Representation (2018),
- arXiv preprint arxiv:1711.02283.
+ International Conference on Learning Representation (2018),
+ arXiv preprint arxiv:1711.02283.
'''
opt_alpha, opt_beta = sgd_entropic_regularization(a, b, M, reg, batch_size,
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 50ec03c..d516dfc 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -9,51 +9,56 @@ Regularized Unbalanced OT
from __future__ import division
import warnings
import numpy as np
+from scipy.special import logsumexp
+
# from .utils import unif, dist
-def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
+ stopThr=1e-6, verbose=False, log=False, **kwargs):
r"""
- Solve the unbalanced entropic regularization optimal transport problem and return the loss
+ Solve the unbalanced entropic regularization optimal transport problem
+ and return the OT plan
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b)
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
s.t.
\gamma\geq 0
where :
- - M is the (ns, nt) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization
+ term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
Parameters
----------
- a : np.ndarray (ns,)
- samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,n_hists)
- samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns, nt)
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
- alpha : float
+ reg_m : float
Marginal relaxation term > 0
method : str
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ 'sinkhorn_reg_scaling', see those function for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (> 0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -62,10 +67,16 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
Returns
-------
- W : (nt) ndarray or float
- Optimal transportation matrix for the given parameters
- log : dict
- log dictionary return only if log==True in parameters
+ if n_hists == 1:
+ gamma : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+ else:
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
+ log : dict
+ log dictionary returned only if `log` is `True`
Examples
--------
@@ -82,83 +93,96 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems
+ (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
+ Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
- .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
See Also
--------
ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10]
- ot.unbalanced.sinkhorn_stabilized_unbalanced: Unbalanced Stabilized sinkhorn [9][10]
- ot.unbalanced.sinkhorn_epsilon_scaling_unbalanced: Unbalanced Sinkhorn with epslilon scaling [9][10]
+ ot.unbalanced.sinkhorn_stabilized_unbalanced:
+ Unbalanced Stabilized sinkhorn [9][10]
+ ot.unbalanced.sinkhorn_reg_scaling_unbalanced:
+ Unbalanced Sinkhorn with epsilon scaling [9][10]
"""
if method.lower() == 'sinkhorn':
- def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
-
- elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']:
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() == 'sinkhorn_stabilized':
+ return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
+ elif method.lower() in ['sinkhorn_reg_scaling']:
warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
-
- def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
else:
- raise ValueError('Unknown method. Using classic Sinkhorn Knopp')
+ raise ValueError("Unknown method '%s'." % method)
- return sink()
-
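With the dispatch above, each method string resolves to a concrete solver. A usage sketch on the toy data of the stabilized doctest further down (both 'sinkhorn' and 'sinkhorn_stabilized' should agree up to the stopping tolerance):

    import ot

    a, b = [.5, .5], [.5, .5]
    M = [[0., 1.], [1., 0.]]
    G = ot.sinkhorn_unbalanced(a, b, M, reg=1., reg_m=1.,
                               method='sinkhorn_stabilized')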
-def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
- numItermax=1000, stopThr=1e-9, verbose=False,
+def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
+ numItermax=1000, stopThr=1e-6, verbose=False,
log=False, **kwargs):
r"""
- Solve the entropic regularization unbalanced optimal transport problem and return the loss
+ Solve the entropic regularization unbalanced optimal transport problem and
+ return the loss
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b)
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
s.t.
\gamma\geq 0
where :
- - M is the (ns, nt) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
Parameters
----------
- a : np.ndarray (ns,)
- samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt, n_hists)
- samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
- alpha : float
+ reg_m : float
Marginal relaxation term > 0
method : str
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ 'sinkhorn_reg_scaling', see those function for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
@@ -171,10 +195,10 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
Returns
-------
- W : (nt) ndarray or float
- Optimal transportation matrix for the given parameters
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
log : dict
- log dictionary return only if log==True in parameters
+ log dictionary returned only if `log` is `True`
Examples
--------
@@ -191,64 +215,70 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems
+ (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
+ Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
- .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
See Also
--------
ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10]
- ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10]
+ ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epsilon scaling [9][10]
"""
-
- if method.lower() == 'sinkhorn':
- def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
-
- elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']:
- warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
-
- def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
- else:
- raise ValueError('Unknown method. Using classic Sinkhorn Knopp')
-
b = np.asarray(b, dtype=np.float64)
if len(b.shape) < 2:
b = b[:, None]
-
- return sink()
+ if method.lower() == 'sinkhorn':
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() == 'sinkhorn_stabilized':
+ return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
+ elif method.lower() in ['sinkhorn_reg_scaling']:
+ warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError('Unknown method %s.' % method)
-def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
+ stopThr=1e-6, verbose=False, log=False, **kwargs):
r"""
Solve the entropic regularization unbalanced optimal transport problem and return the loss
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b)
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
s.t.
\gamma\geq 0
where :
- - M is the (ns, nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights
+ - a and b are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
@@ -256,21 +286,21 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
Parameters
----------
- a : np.ndarray (ns,)
- samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt, n_hists)
- samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
- alpha : float
+ reg_m : float
Marginal relaxation term > 0
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -279,11 +309,16 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
Returns
-------
- gamma : (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
- log : dict
- log dictionary return only if log==True in parameters
-
+ if n_hists == 1:
+ gamma : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+ else:
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
+ log : dict
+ log dictionary returned only if `log` is `True`
Examples
--------
@@ -298,9 +333,13 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
References
----------
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
- .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
See Also
--------
@@ -313,12 +352,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
- n_a, n_b = M.shape
+ dim_a, dim_b = M.shape
if len(a) == 0:
- a = np.ones(n_a, dtype=np.float64) / n_a
+ a = np.ones(dim_a, dtype=np.float64) / dim_a
if len(b) == 0:
- b = np.ones(n_b, dtype=np.float64) / n_b
+ b = np.ones(dim_b, dtype=np.float64) / dim_b
if len(b.shape) > 1:
n_hists = b.shape[1]
@@ -331,21 +370,19 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
# we assume that no distances are null except those of the diagonal of
# distances
if n_hists:
- u = np.ones((n_a, 1)) / n_a
- v = np.ones((n_b, n_hists)) / n_b
- a = a.reshape(n_a, 1)
+ u = np.ones((dim_a, 1)) / dim_a
+ v = np.ones((dim_b, n_hists)) / dim_b
+ a = a.reshape(dim_a, 1)
else:
- u = np.ones(n_a) / n_a
- v = np.ones(n_b) / n_b
+ u = np.ones(dim_a) / dim_a
+ v = np.ones(dim_b) / dim_b
- # print(reg)
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
K = np.empty(M.shape, dtype=M.dtype)
np.divide(M, -reg, out=K)
np.exp(K, out=K)
- # print(np.min(K))
- fi = alpha / (alpha + reg)
+ fi = reg_m / (reg_m + reg)
cpt = 0
err = 1.
@@ -364,15 +401,16 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
or np.any(np.isinf(u)) or np.any(np.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
- warnings.warn('Numerical errors at iteration', cpt)
+ warnings.warn('Numerical errors at iteration %s' % cpt)
u = uprev
v = vprev
break
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
- np.sum((v - vprev)**2) / np.sum((v)**2)
+ err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.)
+ err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.)
+ err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
if verbose:
@@ -380,10 +418,11 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err))
- cpt = cpt + 1
+ cpt += 1
+
if log:
- log['u'] = u
- log['v'] = v
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
if n_hists: # return only loss
res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
@@ -400,9 +439,224 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
return u[:, None] * K * v[None, :]
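Since the log now carries `logu`/`logv` rather than `u`/`v`, the plan can be rebuilt from the log alone (a sketch; K is recomputed exactly as in the solver):

    import numpy as np
    import ot

    a, b = [.5, .5], [.5, .5]
    M = np.array([[0., 1.], [1., 0.]])
    reg, reg_m = 1., 1.
    G, log = ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, log=True)
    K = np.exp(-M / reg)
    G2 = np.exp(log['logu'])[:, None] * K * np.exp(log['logv'])[None, :]
    assert np.allclose(G, G2)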
-def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
- stopThr=1e-4, verbose=False, log=False):
- r"""Compute the entropic regularized unbalanced wasserstein barycenter of distributions A
+def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000,
+ stopThr=1e-6, verbose=False, log=False,
+ **kwargs):
+ r"""
+ Solve the entropic regularization unbalanced optimal transport
+ problem and return the loss
+
+ The function solves the following optimization problem using log-domain
+ stabilization as proposed in [10]:
+
+ .. math::
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
+
+ s.t.
+ \gamma\geq 0
+ where :
+
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization
+ term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target unbalanced distributions
+ - KL is the Kullback-Leibler divergence
+
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+
+
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
+ loss matrix
+ reg : float
+ Entropy regularization term > 0
+ reg_m : float
+ Marginal relaxation term > 0
+ tau : float
+ Threshold for max value in u or v for log scaling
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ if n_hists == 1:
+ gamma : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+ else:
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
+ log : dict
+ log dictionary returned only if `log` is `True`
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.],[1., 0.]]
+ >>> ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.)
+ array([[0.51122823, 0.18807035],
+ [0.18807035, 0.51122823]])
+
+ References
+ ----------
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+
+ dim_a, dim_b = M.shape
+
+ if len(a) == 0:
+ a = np.ones(dim_a, dtype=np.float64) / dim_a
+ if len(b) == 0:
+ b = np.ones(dim_b, dtype=np.float64) / dim_b
+
+ if len(b.shape) > 1:
+ n_hists = b.shape[1]
+ else:
+ n_hists = 0
+
+ if log:
+ log = {'err': []}
+
+ # we assume that no distances are null except those of the diagonal of
+ # distances
+ if n_hists:
+ u = np.ones((dim_a, n_hists)) / dim_a
+ v = np.ones((dim_b, n_hists)) / dim_b
+ a = a.reshape(dim_a, 1)
+ else:
+ u = np.ones(dim_a) / dim_a
+ v = np.ones(dim_b) / dim_b
+
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ fi = reg_m / (reg_m + reg)
+
+ cpt = 0
+ err = 1.
+ alpha = np.zeros(dim_a)
+ beta = np.zeros(dim_b)
+ while (err > stopThr and cpt < numItermax):
+ uprev = u
+ vprev = v
+
+ Kv = K.dot(v)
+ f_alpha = np.exp(- alpha / (reg + reg_m))
+ f_beta = np.exp(- beta / (reg + reg_m))
+
+ if n_hists:
+ f_alpha = f_alpha[:, None]
+ f_beta = f_beta[:, None]
+ u = ((a / (Kv + 1e-16)) ** fi) * f_alpha
+ Ktu = K.T.dot(u)
+ v = ((b / (Ktu + 1e-16)) ** fi) * f_beta
+ absorbing = False
+ if (u > tau).any() or (v > tau).any():
+ absorbing = True
+ if n_hists:
+ alpha = alpha + reg * np.log(np.max(u, 1))
+ beta = beta + reg * np.log(np.max(v, 1))
+ else:
+ alpha = alpha + reg * np.log(np.max(u))
+ beta = beta + reg * np.log(np.max(v))
+ K = np.exp((alpha[:, None] + beta[None, :] -
+ M) / reg)
+ v = np.ones_like(v)
+ Kv = K.dot(v)
+
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration %s' % cpt)
+ u = uprev
+ v = vprev
+ break
+ if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+ err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(),
+ 1.)
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+ cpt = cpt + 1
+
+ if err > stopThr:
+ warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
+ "Try a larger entropy `reg` or a lower mass `reg_m`." +
+ "Or a larger absorption threshold `tau`.")
+ if n_hists:
+ logu = alpha[:, None] / reg + np.log(u)
+ logv = beta[:, None] / reg + np.log(v)
+ else:
+ logu = alpha / reg + np.log(u)
+ logv = beta / reg + np.log(v)
+ if log:
+ log['logu'] = logu
+ log['logv'] = logv
+ if n_hists: # return only loss
+ res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] +
+ logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1))
+ res = np.exp(res)
+ if log:
+ return res, log
+ else:
+ return res
+
+ else: # return OT matrix
+ ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg)
+ if log:
+ return ot_matrix, log
+ else:
+ return ot_matrix
+
+
+def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
+ numItermax=1000, stopThr=1e-6,
+ verbose=False, log=False):
+ r"""Compute the entropic unbalanced wasserstein barycenter of A with stabilization.
The function solves the following optimization problem:
@@ -411,28 +665,35 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
where :
- - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
- - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
- - alpha is the marginal relaxation hyperparameter
- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+ - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+ Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of
+ matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
+ - reg_m is the marginal relaxation hyperparameter
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
Parameters
----------
- A : np.ndarray (d,n)
- n training distributions a_i of size d
- M : np.ndarray (d,d)
- loss matrix for OT
+ A : np.ndarray (dim, n_hists)
+ `n_hists` training distributions a_i of dimension dim
+ M : np.ndarray (dim, dim)
+ ground metric matrix for OT.
reg : float
Entropy regularization term > 0
- alpha : float
+ reg_m : float
Marginal relaxation term > 0
- weights : np.ndarray (n,)
- Weights of each histogram a_i on the simplex (barycentric coodinates)
+ tau : float
+ Stabilization threshold for log domain absorption.
+ weights : np.ndarray (n_hists,), optional
+ Weight of each distribution (barycentric coodinates)
+ If None, uniform weights are used.
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -441,7 +702,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
Returns
-------
- a : (d,) ndarray
+ a : (dim,) ndarray
Unbalanced Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -450,12 +711,165 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
References
----------
- .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré,
+ G. (2015). Iterative Bregman projections for regularized transportation
+ problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
"""
- p, n_hists = A.shape
+ dim, n_hists = A.shape
+ if weights is None:
+ weights = np.ones(n_hists) / n_hists
+ else:
+ assert len(weights) == A.shape[1]
+
+ if log:
+ log = {'err': []}
+
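+ # exponent of the Sinkhorn scaling updates below; as reg_m -> inf,
+ # fi -> 1 and the balanced barycenter iterations are recovered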
+ fi = reg_m / (reg_m + reg)
+
+ u = np.ones((dim, n_hists)) / dim
+ v = np.ones((dim, n_hists)) / dim
+
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ cpt = 0
+ err = 1.
+ alpha = np.zeros(dim)
+ beta = np.zeros(dim)
+ q = np.ones(dim) / dim
+ while (err > stopThr and cpt < numItermax):
+ qprev = q
+ Kv = K.dot(v)
+ f_alpha = np.exp(- alpha / (reg + reg_m))
+ f_beta = np.exp(- beta / (reg + reg_m))
+ f_alpha = f_alpha[:, None]
+ f_beta = f_beta[:, None]
+ u = ((A / (Kv + 1e-16)) ** fi) * f_alpha
+ Ktu = K.T.dot(u)
+ q = (Ktu ** (1 - fi)) * f_beta
+ q = q.dot(weights) ** (1 / (1 - fi))
+ Q = q[:, None]
+ v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta
+ absorbing = False
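+ # log-domain absorption: fold large scalings into the duals alpha
+ # and beta, rebuild the kernel K and reset v to avoid overflow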
+ if (u > tau).any() or (v > tau).any():
+ absorbing = True
+ alpha = alpha + reg * np.log(np.max(u, 1))
+ beta = beta + reg * np.log(np.max(v, 1))
+ K = np.exp((alpha[:, None] + beta[None, :] -
+ M) / reg)
+ v = np.ones_like(v)
+ Kv = K.dot(v)
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration %s' % cpt)
+ q = qprev
+ break
+ if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+ # we can speed up the process by checking for the error only
+ # every 10th iteration
+ err = abs(q - qprev).max() / max(abs(q).max(),
+ abs(qprev).max(), 1.)
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 50 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+ if err > stopThr:
+ warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
+ "Try a larger entropy `reg` or a lower mass `reg_m`." +
+ "Or a larger absorption threshold `tau`.")
+ if log:
+ log['niter'] = cpt
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
+ return q, log
+ else:
+ return q
+
+
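A hedged usage sketch for the stabilized barycenter defined above (toy data, not from the patch); tau controls how aggressively the scalings are absorbed into the duals:

    # barycenter of two shifted Gaussian histograms, stabilized solver
    import numpy as np
    import ot

    n = 50
    x = np.arange(n, dtype=np.float64)
    M = (x[:, None] - x[None, :]) ** 2
    M /= M.max()

    a1 = np.exp(-(x - 15.) ** 2 / 20.)
    a2 = np.exp(-(x - 35.) ** 2 / 20.)
    A = np.vstack([a1 / a1.sum(), a2 / a2.sum()]).T   # shape (dim, n_hists)

    q = ot.unbalanced.barycenter_unbalanced_stabilized(A, M, reg=0.05,
                                                       reg_m=5., tau=1e3)
    print(q.shape)   # (50,)
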
+def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
+ numItermax=1000, stopThr=1e-6,
+ verbose=False, log=False):
+ r"""Compute the entropic unbalanced wasserstein barycenter of A.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+ Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
+ - reg_m is the marginal relaxation hyperparameter
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+ Parameters
+ ----------
+ A : np.ndarray (dim, n_hists)
+ `n_hists` training distributions a_i of dimension dim
+ M : np.ndarray (dim, dim)
+ ground metric matrix for OT.
+ reg : float
+ Entropy regularization term > 0
+ reg_m : float
+ Marginal relaxation term > 0
+ weights : np.ndarray (n_hists,) optional
+ Weight of each distribution (barycentric coordinates)
+ If None, uniform weights are used.
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Unbalanced Wasserstein barycenter
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
+ (2015). Iterative Bregman projections for regularized transportation
+ problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
+
+ """
+ dim, n_hists = A.shape
if weights is None:
weights = np.ones(n_hists) / n_hists
else:
@@ -466,10 +880,10 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
K = np.exp(- M / reg)
- fi = alpha / (alpha + reg)
+ fi = reg_m / (reg_m + reg)
- v = np.ones((p, n_hists)) / p
- u = np.ones((p, 1)) / p
+ v = np.ones((dim, n_hists)) / dim
+ u = np.ones((dim, 1)) / dim
cpt = 0
err = 1.
@@ -498,8 +912,11 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = np.sum((u - uprev) ** 2) / np.sum((u) ** 2) + \
- np.sum((v - vprev) ** 2) / np.sum((v) ** 2)
+ err_u = abs(u - uprev).max()
+ err_u /= max(abs(u).max(), abs(uprev).max(), 1.)
+ err_v = abs(v - vprev).max()
+ err_v /= max(abs(v).max(), abs(vprev).max(), 1.)
+ err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
if verbose:
@@ -511,8 +928,95 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
cpt += 1
if log:
log['niter'] = cpt
- log['u'] = u
- log['v'] = v
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
return q, log
else:
return q
+
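The relative max-norm stopping criterion introduced above can be watched through the log dict; a short illustrative sketch (not part of the patch):

    # monitor convergence of the plain barycenter solver via log=True
    import numpy as np
    import ot

    x = np.arange(50, dtype=np.float64)
    M = (x[:, None] - x[None, :]) ** 2
    M /= M.max()
    A = np.vstack([np.exp(-(x - 15.) ** 2 / 20.),
                   np.exp(-(x - 35.) ** 2 / 20.)]).T
    A /= A.sum(axis=0)

    q, log = ot.unbalanced.barycenter_unbalanced_sinkhorn(A, M, reg=0.05,
                                                          reg_m=5., log=True)
    print(log['niter'], log['err'][-1])   # iterations used, final error
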
+
+def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
+ numItermax=1000, stopThr=1e-6,
+ verbose=False, log=False, **kwargs):
+ r"""Compute the entropic unbalanced wasserstein barycenter of A.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+ Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
+ - reg_m is the marginal relaxation hyperparameter
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+ Parameters
+ ----------
+ A : np.ndarray (dim, n_hists)
+ `n_hists` training distributions a_i of dimension dim
+ M : np.ndarray (dim, dim)
+ ground metric matrix for OT.
+ reg : float
+ Entropy regularization term > 0
+ reg_m : float
+ Marginal relaxation term > 0
+ method : str, optional
+ Method used: 'sinkhorn', 'sinkhorn_stabilized' or
+ 'sinkhorn_reg_scaling'
+ weights : np.ndarray (n_hists,) optional
+ Weight of each distribution (barycentric coordinates)
+ If None, uniform weights are used.
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Unbalanced Wasserstein barycenter
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
+ (2015). Iterative Bregman projections for regularized transportation
+ problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
+ """
+
+ if method.lower() == 'sinkhorn':
+ return barycenter_unbalanced_sinkhorn(A, M, reg, reg_m,
+ weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() == 'sinkhorn_stabilized':
+ return barycenter_unbalanced_stabilized(A, M, reg, reg_m,
+ weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
+ elif method.lower() in ['sinkhorn_reg_scaling']:
+ warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
+ return barycenter_unbalanced(A, M, reg, reg_m,
+ weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
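A quick sketch of the dispatch above (illustrative, not part of the patch): both methods should return the same barycenter up to numerical error, with the stabilized path being the safer choice for small reg:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    A = rng.rand(40, 3)
    A /= A.sum(axis=0)                      # three histograms of dimension 40
    x = np.arange(40, dtype=np.float64)
    M = (x[:, None] - x[None, :]) ** 2
    M /= M.max()

    q_plain = ot.barycenter_unbalanced(A, M, reg=0.05, reg_m=2.,
                                       method='sinkhorn')
    q_stab = ot.barycenter_unbalanced(A, M, reg=0.05, reg_m=2.,
                                      method='sinkhorn_stabilized')
    print(np.abs(q_plain - q_stab).max())   # small if both converged
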
diff --git a/ot/utils.py b/ot/utils.py
index 5707d9b..b71458b 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -111,12 +111,12 @@ def dist(x1, x2=None, metric='sqeuclidean'):
Parameters
----------
- x1 : np.array (n1,d)
+ x1 : ndarray, shape (n1,d)
matrix with n1 samples of size d
- x2 : np.array (n2,d), optional
+ x2 : ndarray, shape (n2,d), optional
matrix with n2 samples of size d (if None then x2=x1)
- metric : str, fun, optional
- name of the metric to be computed (full list in the doc of scipy), If a string,
+ metric : str | callable, optional
+ Name of the metric to be computed (full list in the doc of scipy). If a string,
the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock',
'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski',
'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
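An illustrative sketch of dist (not part of the patch); the default metric is 'sqeuclidean' and any metric accepted by scipy's cdist can be passed:

    import numpy as np
    import ot

    x1 = np.array([[0., 0.], [1., 0.]])
    x2 = np.array([[0., 1.], [1., 1.], [2., 2.]])

    M_sq = ot.dist(x1, x2)                       # squared Euclidean (default)
    M_l1 = ot.dist(x1, x2, metric='cityblock')   # any scipy cdist metric
    print(M_sq.shape)                            # (2, 3)
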
@@ -138,26 +138,21 @@ def dist(x1, x2=None, metric='sqeuclidean'):
def dist0(n, method='lin_square'):
- """Compute standard cost matrices of size (n,n) for OT problems
+ """Compute standard cost matrices of size (n, n) for OT problems
Parameters
----------
-
n : int
- size of the cost matrix
+ Size of the cost matrix.
method : str, optional
Type of loss matrix chosen from:
* 'lin_square' : linear sampling between 0 and n-1, quadratic loss
-
Returns
-------
-
- M : np.array (n1,n2)
- distance matrix computed with given metric
-
-
+ M : ndarray, shape (n, n)
+ Distance matrix computed with given metric.
"""
res = 0
if method == 'lin_square':
@@ -169,33 +164,34 @@ def dist0(n, method='lin_square'):
def cost_normalization(C, norm=None):
""" Apply normalization to the loss matrix
-
Parameters
----------
- C : np.array (n1, n2)
+ C : ndarray, shape (n1, n2)
The cost matrix to normalize.
norm : str
- type of normalization from 'median','max','log','loglog'. Any other
- value do not normalize.
-
+ Type of normalization from 'median', 'max', 'log', 'loglog'. Any
+ other value raises a ValueError.
Returns
-------
-
- C : np.array (n1, n2)
+ C : ndarray, shape (n1, n2)
The input cost matrix normalized according to given norm.
-
"""
- if norm == "median":
+ if norm is None:
+ pass
+ elif norm == "median":
C /= float(np.median(C))
elif norm == "max":
C /= float(np.max(C))
elif norm == "log":
C = np.log(1 + C)
elif norm == "loglog":
- C = np.log(1 + np.log(1 + C))
-
+ C = np.log1p(np.log1p(C))
+ else:
+ raise ValueError('Norm %s is not a valid option.\n'
+ 'Valid options are:\n'
+ 'median, max, log, loglog' % norm)
return C
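An illustrative sketch of the branches above (not part of the patch), including the None pass-through and the new ValueError on unknown norms:

    import numpy as np
    from ot.utils import cost_normalization

    C = np.array([[0., 1.], [4., 9.]])
    print(cost_normalization(C.copy(), 'max'))     # C / 9
    print(cost_normalization(C.copy(), 'median'))  # C / 2.5
    print(cost_normalization(C.copy(), 'loglog'))  # log1p(log1p(C))
    print(cost_normalization(C.copy()))            # norm=None: unchanged
    # cost_normalization(C, 'l2') now raises ValueError
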
@@ -261,6 +257,7 @@ def check_params(**kwargs):
def check_random_state(seed):
"""Turn seed into a np.random.RandomState instance
+
Parameters
----------
seed : None | int | instance of RandomState
@@ -280,7 +277,6 @@ def check_random_state(seed):
class deprecated(object):
-
"""Decorator to mark a function or class as deprecated.
deprecated class from scikit-learn package
@@ -296,8 +292,8 @@ class deprecated(object):
Parameters
----------
- extra : string
- to be added to the deprecation messages
+ extra : str
+ To be added to the deprecation messages.
"""
# Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary,
@@ -378,9 +374,9 @@ def _is_deprecated(func):
class BaseEstimator(object):
-
"""Base class for most objects in POT
- adapted from sklearn BaseEstimator class
+
+ Code adapted from sklearn BaseEstimator class
Notes
-----
@@ -422,7 +418,7 @@ class BaseEstimator(object):
Parameters
----------
- deep : boolean, optional
+ deep : bool, optional
If True, will return the parameters for this estimator and
contained subobjects that are estimators.