Diffstat (limited to 'ot')
-rw-r--r--  ot/__init__.py                   81
-rw-r--r--  ot/bregman.py                  1787
-rw-r--r--  ot/da.py                       1916
-rw-r--r--  ot/datasets.py                  164
-rw-r--r--  ot/dr.py                        200
-rw-r--r--  ot/externals/__init__.py          0
-rw-r--r--  ot/externals/funcsigs.py        817
-rw-r--r--  ot/gpu/__init__.py               41
-rw-r--r--  ot/gpu/bregman.py               194
-rw-r--r--  ot/gpu/da.py                    144
-rw-r--r--  ot/gpu/utils.py                 101
-rw-r--r--  ot/gromov.py                   1179
-rw-r--r--  ot/lp/EMD.h                      35
-rw-r--r--  ot/lp/EMD_wrapper.cpp           107
-rw-r--r--  ot/lp/__init__.py               618
-rw-r--r--  ot/lp/core.h                    103
-rw-r--r--  ot/lp/cvx.py                    147
-rw-r--r--  ot/lp/emd_wrap.pyx              187
-rw-r--r--  ot/lp/full_bipartitegraph.h     215
-rw-r--r--  ot/lp/network_simplex_simple.h 1553
-rw-r--r--  ot/optim.py                     440
-rw-r--r--  ot/plot.py                       91
-rw-r--r--  ot/smooth.py                    600
-rw-r--r--  ot/stochastic.py                755
-rw-r--r--  ot/unbalanced.py               1022
-rw-r--r--  ot/utils.py                     498
26 files changed, 12995 insertions, 0 deletions
diff --git a/ot/__init__.py b/ot/__init__.py
new file mode 100644
index 0000000..89c7936
--- /dev/null
+++ b/ot/__init__.py
@@ -0,0 +1,81 @@
+"""
+
+This is the main module of the POT toolbox. It provides easy access to
+a number of sub-modules and functions described below.
+
+.. note::
+
+    Here is a list of the sub-modules and a short description of what
+    they contain (a minimal usage sketch follows this list).
+
+ - :any:`ot.lp` contains OT solvers for the exact (Linear Program) OT problems.
+ - :any:`ot.bregman` contains OT solvers for the entropic OT problems using
+ Bregman projections.
+ - :any:`ot.smooth` contains OT solvers for the regularized (l2 and kl) smooth OT
+ problems.
+ - :any:`ot.gromov` contains solvers for Gromov-Wasserstein and Fused Gromov
+ Wasserstein problems.
+    - :any:`ot.optim` contains generic solvers for OT-based optimization problems.
+ - :any:`ot.da` contains classes and function related to Monge mapping
+ estimation and Domain Adaptation (DA).
+    - :any:`ot.gpu` contains GPU (cupy) implementations of some OT solvers.
+ - :any:`ot.dr` contains Dimension Reduction (DR) methods such as Wasserstein
+ Discriminant Analysis.
+ - :any:`ot.utils` contains utility functions such as distance computation and
+ timing.
+ - :any:`ot.datasets` contains toy dataset generation functions.
+    - :any:`ot.plot` contains visualization functions.
+ - :any:`ot.stochastic` contains stochastic solvers for regularized OT.
+ - :any:`ot.unbalanced` contains solvers for regularized unbalanced OT.
+
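+A minimal usage sketch with the main entry points (illustrative values only):
+
+    >>> import numpy as np
+    >>> import ot
+    >>> a, b = ot.unif(2), ot.unif(2)        # uniform histograms
+    >>> M = np.array([[0., 1.], [1., 0.]])   # cost matrix
+    >>> G0 = ot.emd(a, b, M)                 # exact OT plan
+    >>> Gs = ot.sinkhorn(a, b, M, reg=1e-1)  # entropic OT plan
+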
+.. warning::
+    The list of automatically imported sub-modules is as follows:
+    :py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim`,
+    :py:mod:`ot.utils`, :py:mod:`ot.datasets`, :py:mod:`ot.da`,
+    :py:mod:`ot.gromov`, :py:mod:`ot.smooth`, :py:mod:`ot.stochastic`
+    and :py:mod:`ot.unbalanced`.
+
+ The following sub-modules are not imported due to additional dependencies:
+
+ - :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`.
+ - :any:`ot.gpu` : depends on :code:`cupy` and a CUDA GPU.
+    - :any:`ot.plot` : depends on :code:`matplotlib`.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+
+# All submodules and packages
+from . import lp
+from . import bregman
+from . import optim
+from . import utils
+from . import datasets
+from . import da
+from . import gromov
+from . import smooth
+from . import stochastic
+from . import unbalanced
+
+# OT functions
+from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
+from .bregman import sinkhorn, sinkhorn2, barycenter
+from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2
+from .da import sinkhorn_lpl1_mm
+
+# utils functions
+from .utils import dist, unif, tic, toc, toq
+
+__version__ = "0.6.0"
+
+__all__ = ['emd', 'emd2', 'emd_1d', 'emd2_1d', 'wasserstein_1d',
+           'sinkhorn', 'sinkhorn2', 'utils', 'datasets',
+           'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
+           'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
+           'sinkhorn_unbalanced', 'barycenter_unbalanced',
+           'sinkhorn_unbalanced2']
diff --git a/ot/bregman.py b/ot/bregman.py
new file mode 100644
index 0000000..2cd832b
--- /dev/null
+++ b/ot/bregman.py
@@ -0,0 +1,1787 @@
+# -*- coding: utf-8 -*-
+"""
+Bregman projections for regularized OT
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Nicolas Courty <ncourty@irisa.fr>
+# Kilian Fatras <kilian.fatras@irisa.fr>
+# Titouan Vayer <titouan.vayer@irisa.fr>
+# Hicham Janati <hicham.janati@inria.fr>
+#
+# License: MIT License
+
+import numpy as np
+import warnings
+from .utils import unif, dist
+
+
+def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
+ stopThr=1e-9, verbose=False, log=False, **kwargs):
+ r"""
+ Solve the entropic regularization optimal transport problem and return the OT matrix
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (histograms, both sum to 1)
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+
+
+ Parameters
+ ----------
+    a : ndarray, shape (dim_a,)
+        sample weights in the source domain
+    b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+        sample weights in the target domain; if b is a matrix, sinkhorn is
+        computed for each column with a fixed M (the OT losses and the dual
+        variables in log are then returned)
+ M : ndarray, shape (dim_a, dim_b)
+ loss matrix
+ reg : float
+ Regularization term >0
+ method : str
+ method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or
+        'sinkhorn_epsilon_scaling', see those functions for specific parameters
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : ndarray, shape (dim_a, dim_b)
+ Optimal transportation matrix for the given parameters
+ log : dict
+        log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.], [1., 0.]]
+ >>> ot.sinkhorn(a, b, M, 1)
+ array([[0.36552929, 0.13447071],
+ [0.13447071, 0.36552929]])
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+ ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
+ ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
+    ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epsilon scaling [9][10]
+
+ """
+
+ if method.lower() == 'sinkhorn':
+ return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'greenkhorn':
+ return greenkhorn(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log)
+ elif method.lower() == 'sinkhorn_stabilized':
+ return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ elif method.lower() == 'sinkhorn_epsilon_scaling':
+ return sinkhorn_epsilon_scaling(a, b, M, reg,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
+
+
+def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
+ stopThr=1e-9, verbose=False, log=False, **kwargs):
+ r"""
+ Solve the entropic regularization optimal transport problem and return the loss
+
+ The function solves the following optimization problem:
+
+ .. math::
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (histograms, both sum to 1)
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+
+
+ Parameters
+ ----------
+    a : ndarray, shape (dim_a,)
+        sample weights in the source domain
+    b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+        sample weights in the target domain; if b is a matrix, sinkhorn is
+        computed for each column with a fixed M (the OT losses and the dual
+        variables in log are then returned)
+ M : ndarray, shape (dim_a, dim_b)
+ loss matrix
+ reg : float
+ Regularization term >0
+ method : str
+ method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
+        'sinkhorn_epsilon_scaling', see those functions for specific parameters
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ W : (n_hists) ndarray or float
+ Optimal transportation loss for the given parameters
+ log : dict
+        log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.], [1., 0.]]
+ >>> ot.sinkhorn2(a, b, M, 1)
+ array([0.26894142])
+
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+    .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
+
+
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+ ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
+ ot.bregman.greenkhorn : Greenkhorn [21]
+ ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
+    ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epsilon scaling [9][10]
+
+ """
+ b = np.asarray(b, dtype=np.float64)
+ if len(b.shape) < 2:
+ b = b[:, None]
+ if method.lower() == 'sinkhorn':
+ return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_stabilized':
+ return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_epsilon_scaling':
+ return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
+
+
+def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
+ stopThr=1e-9, verbose=False, log=False, **kwargs):
+ r"""
+ Solve the entropic regularization optimal transport problem and return the OT matrix
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (histograms, both sum to 1)
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+
+
+ Parameters
+ ----------
+    a : ndarray, shape (dim_a,)
+        sample weights in the source domain
+    b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+        sample weights in the target domain; if b is a matrix, sinkhorn is
+        computed for each column with a fixed M (the OT losses and the dual
+        variables in log are then returned)
+ M : ndarray, shape (dim_a, dim_b)
+ loss matrix
+ reg : float
+ Regularization term >0
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ gamma : ndarray, shape (dim_a, dim_b)
+ Optimal transportation matrix for the given parameters
+ log : dict
+        log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.], [1., 0.]]
+ >>> ot.sinkhorn(a, b, M, 1)
+ array([[0.36552929, 0.13447071],
+ [0.13447071, 0.36552929]])
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+
+ if len(a) == 0:
+ a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ if len(b) == 0:
+ b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+
+ # init data
+ dim_a = len(a)
+ dim_b = len(b)
+
+ if len(b.shape) > 1:
+ n_hists = b.shape[1]
+ else:
+ n_hists = 0
+
+ if log:
+ log = {'err': []}
+
+ # we assume that no distances are null except those of the diagonal of
+ # distances
+ if n_hists:
+ u = np.ones((dim_a, n_hists)) / dim_a
+ v = np.ones((dim_b, n_hists)) / dim_b
+ else:
+ u = np.ones(dim_a) / dim_a
+ v = np.ones(dim_b) / dim_b
+
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ tmp2 = np.empty(b.shape, dtype=M.dtype)
+
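+    # precompute diag(1/a) K so that each u update is a single matrix product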
+ Kp = (1 / a).reshape(-1, 1) * K
+ cpt = 0
+ err = 1
+ while (err > stopThr and cpt < numItermax):
+ uprev = u
+ vprev = v
+
+ KtransposeU = np.dot(K.T, u)
+ v = np.divide(b, KtransposeU)
+ u = 1. / np.dot(Kp, v)
+
+ if (np.any(KtransposeU == 0)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ print('Warning: numerical errors at iteration', cpt)
+ u = uprev
+ v = vprev
+ break
+ if cpt % 10 == 0:
+            # we can speed up the process by checking for the error only
+            # every 10th iteration
+ if n_hists:
+ np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2)
+ else:
+ # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
+ np.einsum('i,ij,j->j', u, K, v, out=tmp2)
+ err = np.linalg.norm(tmp2 - b) # violation of marginal
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+ cpt = cpt + 1
+ if log:
+ log['u'] = u
+ log['v'] = v
+
+ if n_hists: # return only loss
+ res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
+ if log:
+ return res, log
+ else:
+ return res
+
+ else: # return OT matrix
+
+ if log:
+ return u.reshape((-1, 1)) * K * v.reshape((1, -1)), log
+ else:
+ return u.reshape((-1, 1)) * K * v.reshape((1, -1))
+
+
+def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
+ log=False):
+ r"""
+ Solve the entropic regularization optimal transport problem and return the OT matrix
+
+    The algorithm used is based on the paper
+
+    Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration
+    by Jason Altschuler, Jonathan Weed, Philippe Rigollet
+    which appeared at NIPS 2017,
+
+    and is a greedy version of the Sinkhorn-Knopp algorithm [2]_.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (histograms, both sum to 1)
+
+
+
+ Parameters
+ ----------
+    a : ndarray, shape (dim_a,)
+        sample weights in the source domain
+    b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+        sample weights in the target domain; if b is a matrix, sinkhorn is
+        computed for each column with a fixed M (the OT losses and the dual
+        variables in log are then returned)
+ M : ndarray, shape (dim_a, dim_b)
+ loss matrix
+ reg : float
+ Regularization term >0
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+ Returns
+ -------
+ gamma : ndarray, shape (dim_a, dim_b)
+ Optimal transportation matrix for the given parameters
+ log : dict
+        log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.], [1., 0.]]
+ >>> ot.bregman.greenkhorn(a, b, M, 1)
+ array([[0.36552929, 0.13447071],
+ [0.13447071, 0.36552929]])
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+    .. [22] J. Altschuler, J. Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
+
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+
+ if len(a) == 0:
+ a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ if len(b) == 0:
+ b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+
+ dim_a = a.shape[0]
+ dim_b = b.shape[0]
+
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty_like(M)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ u = np.full(dim_a, 1. / dim_a)
+ v = np.full(dim_b, 1. / dim_b)
+ G = u[:, np.newaxis] * K * v[np.newaxis, :]
+
+ viol = G.sum(1) - a
+ viol_2 = G.sum(0) - b
+ stopThr_val = 1
+
+ if log:
+ log = dict()
+ log['u'] = u
+ log['v'] = v
+
+ for i in range(numItermax):
+ i_1 = np.argmax(np.abs(viol))
+ i_2 = np.argmax(np.abs(viol_2))
+ m_viol_1 = np.abs(viol[i_1])
+ m_viol_2 = np.abs(viol_2[i_2])
+ stopThr_val = np.maximum(m_viol_1, m_viol_2)
+
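+        # greedy (Greenkhorn) step: rescale only the row or the column whose
+        # marginal violation is currently the largest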
+ if m_viol_1 > m_viol_2:
+ old_u = u[i_1]
+ u[i_1] = a[i_1] / (K[i_1, :].dot(v))
+ G[i_1, :] = u[i_1] * K[i_1, :] * v
+
+ viol[i_1] = u[i_1] * K[i_1, :].dot(v) - a[i_1]
+ viol_2 += (K[i_1, :].T * (u[i_1] - old_u) * v)
+
+ else:
+ old_v = v[i_2]
+ v[i_2] = b[i_2] / (K[:, i_2].T.dot(u))
+ G[:, i_2] = u * K[:, i_2] * v[i_2]
+ viol += (-old_v + v[i_2]) * K[:, i_2] * u
+ viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2]
+
+ if stopThr_val <= stopThr:
+ break
+ else:
+ print('Warning: Algorithm did not converge')
+
+ if log:
+ log['u'] = u
+ log['v'] = v
+
+ if log:
+ return G, log
+ else:
+ return G
+
+
+def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
+ warmstart=None, verbose=False, print_period=20,
+ log=False, **kwargs):
+ r"""
+ Solve the entropic regularization OT problem with log stabilization
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (histograms, both sum to 1)
+
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm as proposed in [2]_ but with the log stabilization
+    proposed in [10]_ and defined in [9]_ (Algo 3.1).
+
+
+ Parameters
+ ----------
+    a : ndarray, shape (dim_a,)
+        sample weights in the source domain
+    b : ndarray, shape (dim_b,)
+        sample weights in the target domain
+ M : ndarray, shape (dim_a, dim_b)
+ loss matrix
+ reg : float
+ Regularization term >0
+    tau : float
+        threshold for max value in u or v for log scaling
+    warmstart : tuple of vectors
+        if given, starting values for alpha and beta log scalings
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ gamma : ndarray, shape (dim_a, dim_b)
+ Optimal transportation matrix for the given parameters
+ log : dict
+        log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5,.5]
+ >>> b=[.5,.5]
+ >>> M=[[0.,1.],[1.,0.]]
+ >>> ot.bregman.sinkhorn_stabilized(a, b, M, 1)
+ array([[0.36552929, 0.13447071],
+ [0.13447071, 0.36552929]])
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+
+ if len(a) == 0:
+ a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ if len(b) == 0:
+ b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+
+ # test if multiple target
+ if len(b.shape) > 1:
+ n_hists = b.shape[1]
+ a = a[:, np.newaxis]
+ else:
+ n_hists = 0
+
+ # init data
+ dim_a = len(a)
+ dim_b = len(b)
+
+ cpt = 0
+ if log:
+ log = {'err': []}
+
+ # we assume that no distances are null except those of the diagonal of
+ # distances
+ if warmstart is None:
+ alpha, beta = np.zeros(dim_a), np.zeros(dim_b)
+ else:
+ alpha, beta = warmstart
+
+ if n_hists:
+ u = np.ones((dim_a, n_hists)) / dim_a
+ v = np.ones((dim_b, n_hists)) / dim_b
+ else:
+ u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b
+
+ def get_K(alpha, beta):
+ """log space computation"""
+ return np.exp(-(M - alpha.reshape((dim_a, 1))
+ - beta.reshape((1, dim_b))) / reg)
+
+ def get_Gamma(alpha, beta, u, v):
+ """log space gamma computation"""
+ return np.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b)))
+ / reg + np.log(u.reshape((dim_a, 1))) + np.log(v.reshape((1, dim_b))))
+
+ K = get_K(alpha, beta)
+ transp = K
+ loop = 1
+ cpt = 0
+ err = 1
+ while loop:
+
+ uprev = u
+ vprev = v
+
+ # sinkhorn update
+ v = b / (np.dot(K.T, u) + 1e-16)
+ u = a / (np.dot(K, v) + 1e-16)
+
+ # remove numerical problems and store them in K
+ if np.abs(u).max() > tau or np.abs(v).max() > tau:
+ if n_hists:
+                alpha, beta = alpha + reg * \
+                    np.max(np.log(u), 1), beta + reg * np.max(np.log(v), 1)
+ else:
+ alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v)
+ if n_hists:
+ u, v = np.ones((dim_a, n_hists)) / dim_a, np.ones((dim_b, n_hists)) / dim_b
+ else:
+ u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b
+ K = get_K(alpha, beta)
+
+ if cpt % print_period == 0:
+            # we can speed up the process by checking for the error only
+            # every print_period iterations
+ if n_hists:
+ err_u = abs(u - uprev).max()
+ err_u /= max(abs(u).max(), abs(uprev).max(), 1.)
+ err_v = abs(v - vprev).max()
+ err_v /= max(abs(v).max(), abs(vprev).max(), 1.)
+ err = 0.5 * (err_u + err_v)
+ else:
+ transp = get_Gamma(alpha, beta, u, v)
+ err = np.linalg.norm((np.sum(transp, axis=0) - b))
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % (print_period * 20) == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ if err <= stopThr:
+ loop = False
+
+ if cpt >= numItermax:
+ loop = False
+
+ if np.any(np.isnan(u)) or np.any(np.isnan(v)):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ print('Warning: numerical errors at iteration', cpt)
+ u = uprev
+ v = vprev
+ break
+
+ cpt = cpt + 1
+
+ if log:
+ if n_hists:
+ alpha = alpha[:, None]
+ beta = beta[:, None]
+ logu = alpha / reg + np.log(u)
+ logv = beta / reg + np.log(v)
+ log['logu'] = logu
+ log['logv'] = logv
+ log['alpha'] = alpha + reg * np.log(u)
+ log['beta'] = beta + reg * np.log(v)
+ log['warmstart'] = (log['alpha'], log['beta'])
+ if n_hists:
+ res = np.zeros((n_hists))
+ for i in range(n_hists):
+ res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
+ return res, log
+
+ else:
+ return get_Gamma(alpha, beta, u, v), log
+ else:
+ if n_hists:
+ res = np.zeros((n_hists))
+ for i in range(n_hists):
+ res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
+ return res
+ else:
+ return get_Gamma(alpha, beta, u, v)
+
+
+def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
+ numInnerItermax=100, tau=1e3, stopThr=1e-9,
+ warmstart=None, verbose=False, print_period=10,
+ log=False, **kwargs):
+ r"""
+ Solve the entropic regularization optimal transport problem with log
+ stabilization and epsilon scaling.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (histograms, both sum to 1)
+
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm as proposed in [2]_ but with the log stabilization
+    proposed in [10]_ and the epsilon scaling described in [9]_ (Algo 3.2).
+
+
+ Parameters
+ ----------
+    a : ndarray, shape (dim_a,)
+        sample weights in the source domain
+    b : ndarray, shape (dim_b,)
+        sample weights in the target domain
+ M : ndarray, shape (dim_a, dim_b)
+ loss matrix
+ reg : float
+ Regularization term >0
+    tau : float
+        threshold for max value in u or v for log scaling
+    warmstart : tuple of vectors
+        if given, starting values for alpha and beta log scalings
+    numItermax : int, optional
+        Max number of iterations
+    numInnerItermax : int, optional
+        Max number of iterations in the inner log-stabilized sinkhorn solver
+    epsilon0 : float, optional
+        first epsilon regularization value (then exponential decrease towards reg)
+    stopThr : float, optional
+        Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ gamma : ndarray, shape (dim_a, dim_b)
+ Optimal transportation matrix for the given parameters
+ log : dict
+        log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.], [1., 0.]]
+ >>> ot.bregman.sinkhorn_epsilon_scaling(a, b, M, 1)
+ array([[0.36552929, 0.13447071],
+ [0.13447071, 0.36552929]])
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+
+ if len(a) == 0:
+ a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ if len(b) == 0:
+ b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+
+ # init data
+ dim_a = len(a)
+ dim_b = len(b)
+
+    # relative numerical precision with 64 bits
+    numItermin = 35
+    numItermax = max(numItermin, numItermax)  # ensure that the last value is exact
+
+ cpt = 0
+ if log:
+ log = {'err': []}
+
+ # we assume that no distances are null except those of the diagonal of
+ # distances
+ if warmstart is None:
+ alpha, beta = np.zeros(dim_a), np.zeros(dim_b)
+ else:
+ alpha, beta = warmstart
+
+ def get_K(alpha, beta):
+ """log space computation"""
+ return np.exp(-(M - alpha.reshape((dim_a, 1))
+ - beta.reshape((1, dim_b))) / reg)
+
+ def get_reg(n): # exponential decreasing
+ return (epsilon0 - reg) * np.exp(-n) + reg
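+    # e.g. with epsilon0=1e4 and reg=1e-1 (illustrative values) the schedule
+    # is roughly 1e4, 3.7e3, 1.4e3, ... and decays towards reg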
+
+ loop = 1
+ cpt = 0
+ err = 1
+ while loop:
+
+ regi = get_reg(cpt)
+
+ G, logi = sinkhorn_stabilized(a, b, M, regi,
+ numItermax=numInnerItermax, stopThr=1e-9,
+ warmstart=(alpha, beta), verbose=False,
+ print_period=20, tau=tau, log=True)
+
+ alpha = logi['alpha']
+ beta = logi['beta']
+
+ if cpt >= numItermax:
+ loop = False
+
+        if cpt % (print_period) == 0:  # epsilon nearly converged
+            # we can speed up the process by checking for the error only
+            # every print_period iterations
+ transp = G
+ err = np.linalg.norm(
+ (np.sum(transp, axis=0) - b))**2 + np.linalg.norm((np.sum(transp, axis=1) - a))**2
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % (print_period * 10) == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ if err <= stopThr and cpt > numItermin:
+ loop = False
+
+ cpt = cpt + 1
+ if log:
+ log['alpha'] = alpha
+ log['beta'] = beta
+ log['warmstart'] = (log['alpha'], log['beta'])
+ return G, log
+ else:
+ return G
+
+
+def geometricBar(weights, alldistribT):
+ """return the weighted geometric mean of distributions"""
+ assert(len(weights) == alldistribT.shape[1])
+ return np.exp(np.dot(np.log(alldistribT), weights.T))
+
+
+def geometricMean(alldistribT):
+ """return the geometric mean of distributions"""
+ return np.exp(np.mean(np.log(alldistribT), axis=1))
+
+
+def projR(gamma, p):
+ """return the KL projection on the row constrints """
+ return np.multiply(gamma.T, p / np.maximum(np.sum(gamma, axis=1), 1e-10)).T
+
+
+def projC(gamma, q):
+ """return the KL projection on the column constrints """
+ return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10))
+
+
+def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
+ stopThr=1e-4, verbose=False, log=False, **kwargs):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+
+ Parameters
+ ----------
+ A : ndarray, shape (dim, n_hists)
+ n_hists training distributions a_i of size dim
+ M : ndarray, shape (dim, dim)
+ loss matrix for OT
+ reg : float
+ Regularization term > 0
+ method : str (optional)
+ method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized'
+ weights : ndarray, shape (n_hists,)
+        Weights of each histogram a_i on the simplex (barycentric coordinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Wasserstein barycenter
+ log : dict
+        log dictionary returned only if log==True in parameters
+
+
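+    Examples
+    --------
+    A small illustrative sketch (two histograms on a 4-bin 1D grid; an
+    assumed example, not reference output):
+
+    >>> import numpy as np
+    >>> import ot
+    >>> A = np.array([[1., 0.], [0., 0.], [0., 0.], [0., 1.]])
+    >>> x = np.arange(4, dtype=np.float64).reshape((4, 1))
+    >>> M = ot.dist(x, x)  # squared euclidean cost on the grid
+    >>> bary = ot.bregman.barycenter(A, M, reg=1e-1)
+    >>> bary.shape
+    (4,)
+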
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+
+ """
+
+    if method.lower() == 'sinkhorn':
+        return barycenter_sinkhorn(A, M, reg, weights=weights,
+                                   numItermax=numItermax,
+                                   stopThr=stopThr, verbose=verbose, log=log,
+                                   **kwargs)
+    elif method.lower() == 'sinkhorn_stabilized':
+        return barycenter_stabilized(A, M, reg, weights=weights,
+                                     numItermax=numItermax,
+                                     stopThr=stopThr, verbose=verbose,
+                                     log=log, **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
+
+
+def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+
+ Parameters
+ ----------
+ A : ndarray, shape (dim, n_hists)
+ n_hists training distributions a_i of size dim
+ M : ndarray, shape (dim, dim)
+ loss matrix for OT
+ reg : float
+ Regularization term > 0
+ weights : ndarray, shape (n_hists,)
+        Weights of each histogram a_i on the simplex (barycentric coordinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Wasserstein barycenter
+ log : dict
+        log dictionary returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+
+ """
+
+ if weights is None:
+ weights = np.ones(A.shape[1]) / A.shape[1]
+ else:
+ assert(len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ # M = M/np.median(M) # suggested by G. Peyre
+ K = np.exp(-M / reg)
+
+ cpt = 0
+ err = 1
+
+ UKv = np.dot(K, np.divide(A.T, np.sum(K, axis=0)).T)
+ u = (geometricMean(UKv) / UKv.T).T
+
+ while (err > stopThr and cpt < numItermax):
+ cpt = cpt + 1
+ UKv = u * np.dot(K, np.divide(A, np.dot(K, u)))
+ u = (u.T * geometricBar(weights, UKv)).T / UKv
+
+ if cpt % 10 == 1:
+ err = np.sum(np.std(UKv, axis=1))
+
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ if log:
+ log['niter'] = cpt
+ return geometricBar(weights, UKv), log
+ else:
+ return geometricBar(weights, UKv)
+
+
+def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
+ with stabilization.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+
+ Parameters
+ ----------
+ A : ndarray, shape (dim, n_hists)
+ n_hists training distributions a_i of size dim
+ M : ndarray, shape (dim, dim)
+ loss matrix for OT
+ reg : float
+ Regularization term > 0
+ tau : float
+        threshold for max value in u or v for log scaling
+ weights : ndarray, shape (n_hists,)
+        Weights of each histogram a_i on the simplex (barycentric coordinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Wasserstein barycenter
+ log : dict
+        log dictionary returned only if log==True in parameters
+
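+    Examples
+    --------
+    Usually reached through :any:`ot.bregman.barycenter` with
+    :code:`method='sinkhorn_stabilized'` (illustrative sketch, assumed values):
+
+    >>> import numpy as np
+    >>> import ot
+    >>> A = np.array([[1., 0.], [0., 0.], [0., 0.], [0., 1.]])
+    >>> x = np.arange(4, dtype=np.float64).reshape((4, 1))
+    >>> M = ot.dist(x, x)
+    >>> bary = ot.bregman.barycenter(A, M, reg=1e-1,
+    ...                              method='sinkhorn_stabilized')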
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+
+ """
+
+ dim, n_hists = A.shape
+ if weights is None:
+ weights = np.ones(n_hists) / n_hists
+ else:
+ assert(len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ u = np.ones((dim, n_hists)) / dim
+ v = np.ones((dim, n_hists)) / dim
+
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ cpt = 0
+ err = 1.
+ alpha = np.zeros(dim)
+ beta = np.zeros(dim)
+ q = np.ones(dim) / dim
+ while (err > stopThr and cpt < numItermax):
+ qprev = q
+ Kv = K.dot(v)
+ u = A / (Kv + 1e-16)
+ Ktu = K.T.dot(u)
+ q = geometricBar(weights, Ktu)
+ Q = q[:, None]
+ v = Q / (Ktu + 1e-16)
+ absorbing = False
+ if (u > tau).any() or (v > tau).any():
+ absorbing = True
+ alpha = alpha + reg * np.log(np.max(u, 1))
+ beta = beta + reg * np.log(np.max(v, 1))
+ K = np.exp((alpha[:, None] + beta[None, :] -
+ M) / reg)
+ v = np.ones_like(v)
+ Kv = K.dot(v)
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration %s' % cpt)
+ q = qprev
+ break
+ if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+            # we can speed up the process by checking for the error only
+            # every 10th iteration
+ err = abs(u * Kv - A).max()
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 50 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+ if err > stopThr:
+ warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
+ "Try a larger entropy `reg`" +
+ "Or a larger absorption threshold `tau`.")
+ if log:
+ log['niter'] = cpt
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
+ return q, log
+ else:
+ return q
+
+
+def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
+ stopThr=1e-9, stabThr=1e-30, verbose=False,
+ log=False):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
+ where A is a collection of 2D images.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+    - :math:`\mathbf{a}_i` are training distributions (2D images) in the last two dimensions of matrix :math:`\mathbf{A}`
+ - reg is the regularization strength scalar value
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [21]_
+
+ Parameters
+ ----------
+ A : ndarray, shape (n_hists, width, height)
+ n distributions (2D images) of size width x height
+ reg : float
+ Regularization term >0
+ weights : ndarray, shape (n_hists,)
+        Weights of each image on the simplex (barycentric coordinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (> 0)
+ stabThr : float, optional
+ Stabilization threshold to avoid numerical precision issue
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ a : ndarray, shape (width, height)
+ 2D Wasserstein barycenter
+ log : dict
+        log dictionary returned only if log==True in parameters
+
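+    Examples
+    --------
+    A tiny illustrative sketch (two uniform 3x3 images; only the output shape
+    below is asserted):
+
+    >>> import numpy as np
+    >>> import ot
+    >>> A = np.ones((2, 3, 3)) / 9.
+    >>> b = ot.bregman.convolutional_barycenter2d(A, reg=1e-1)
+    >>> b.shape
+    (3, 3)
+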
+ References
+ ----------
+
+ .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015).
+ Convolutional wasserstein distances: Efficient optimal transportation on geometric domains
+ ACM Transactions on Graphics (TOG), 34(4), 66
+
+
+ """
+
+ if weights is None:
+ weights = np.ones(A.shape[0]) / A.shape[0]
+ else:
+ assert(len(weights) == A.shape[0])
+
+ if log:
+ log = {'err': []}
+
+ b = np.zeros_like(A[0, :, :])
+ U = np.ones_like(A)
+ KV = np.ones_like(A)
+
+ cpt = 0
+ err = 1
+
+ # build the convolution operator
+ t = np.linspace(0, 1, A.shape[1])
+ [Y, X] = np.meshgrid(t, t)
+ xi1 = np.exp(-(X - Y)**2 / reg)
+
+ def K(x):
+ return np.dot(np.dot(xi1, x), xi1)
+
+ while (err > stopThr and cpt < numItermax):
+
+ bold = b
+ cpt = cpt + 1
+
+ b = np.zeros_like(A[0, :, :])
+ for r in range(A.shape[0]):
+ KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :])))
+ b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :]))
+ b = np.exp(b)
+ for r in range(A.shape[0]):
+ U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :])
+
+ if cpt % 10 == 1:
+ err = np.sum(np.abs(bold - b))
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ if log:
+ log['niter'] = cpt
+ log['U'] = U
+ return b, log
+ else:
+ return b
+
+
+def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
+ stopThr=1e-3, verbose=False, log=False):
+ r"""
+ Compute the unmixing of an observation with a given dictionary using Wasserstein distance
+
+    The function solves the following optimization problem:
+
+    .. math::
+       \mathbf{h} = arg\min_\mathbf{h} (1 - \alpha) W_{M,reg}(\mathbf{a},\mathbf{Dh}) + \alpha W_{M0,reg0}(\mathbf{h}_0,\mathbf{h})
+
+
+ where :
+
+ - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see ot.bregman.sinkhorn)
+    - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)`
+ - :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms`
+ - :math:`\mathbf{a}` is an observed distribution of dimension `dim_a`
+ - :math:`\mathbf{h}_0` is a prior on `h` of dimension `dim_prior`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix (dim_a, dim_a) for OT data fitting
+ - reg0 and :math:`\mathbf{M0}` are respectively the regularization term and the cost matrix (dim_prior, n_atoms) regularization
+    - :math:`\alpha` weights data fitting and regularization
+
+    The optimization problem is solved using the algorithm described in [4]_
+
+
+ Parameters
+ ----------
+ a : ndarray, shape (dim_a)
+ observed distribution (histogram, sums to 1)
+ D : ndarray, shape (dim_a, n_atoms)
+ dictionary matrix
+ M : ndarray, shape (dim_a, dim_a)
+ loss matrix
+ M0 : ndarray, shape (n_atoms, dim_prior)
+ loss matrix
+ h0 : ndarray, shape (n_atoms,)
+ prior on the estimated unmixing h
+ reg : float
+ Regularization term >0 (Wasserstein data fitting)
+ reg0 : float
+ Regularization term >0 (Wasserstein reg with h0)
+ alpha : float
+ How much should we trust the prior ([0,1])
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+    h : ndarray, shape (n_atoms,)
+        estimated unmixing of the observation
+    log : dict
+        log dictionary returned only if log==True in parameters
+
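+    Examples
+    --------
+    An illustrative sketch with a two-atom dictionary (all values below are
+    assumptions for illustration):
+
+    >>> import numpy as np
+    >>> import ot
+    >>> D = np.array([[.9, .1], [.1, .9]])  # atoms as columns
+    >>> a = np.array([.5, .5])              # observed histogram
+    >>> h0 = np.array([.5, .5])             # prior on the unmixing
+    >>> M = M0 = np.array([[0., 1.], [1., 0.]])
+    >>> h = ot.bregman.unmix(a, D, M, M0, h0, reg=1e-1, reg0=1e-1, alpha=.5)
+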
+ References
+ ----------
+
+    .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, Supervised planetary unmixing with optimal transport, Workshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016.
+
+ """
+
+ # M = M/np.median(M)
+ K = np.exp(-M / reg)
+
+ # M0 = M0/np.median(M0)
+ K0 = np.exp(-M0 / reg0)
+ old = h0
+
+ err = 1
+ cpt = 0
+ if log:
+ log = {'err': []}
+
+ while (err > stopThr and cpt < numItermax):
+ K = projC(K, a)
+ K0 = projC(K0, h0)
+ new = np.sum(K0, axis=1)
+        # we recombine the current selection from the dictionary
+ inv_new = np.dot(D, new)
+ other = np.sum(K, axis=1)
+ # geometric interpolation
+ delta = np.exp(alpha * np.log(other) + (1 - alpha) * np.log(inv_new))
+ K = projR(K, delta)
+ K0 = np.dot(np.diag(np.dot(D.T, delta / inv_new)), K0)
+
+ err = np.linalg.norm(np.sum(K0, axis=1) - old)
+ old = new
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt = cpt + 1
+
+ if log:
+ log['niter'] = cpt
+ return np.sum(K0, axis=1), log
+ else:
+ return np.sum(K0, axis=1)
+
+
+def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
+ numIterMax=10000, stopThr=1e-9, verbose=False,
+ log=False, **kwargs):
+ r'''
+ Solve the entropic regularization optimal transport problem and return the
+ OT matrix from empirical data
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`a` and :math:`b` are source and target weights (sum to 1)
+
+
+ Parameters
+ ----------
+ X_s : ndarray, shape (n_samples_a, dim)
+ samples in the source domain
+ X_t : ndarray, shape (n_samples_b, dim)
+ samples in the target domain
+ reg : float
+ Regularization term >0
+    a : ndarray, shape (n_samples_a,)
+        sample weights in the source domain
+    b : ndarray, shape (n_samples_b,)
+        sample weights in the target domain
+    numIterMax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : ndarray, shape (n_samples_a, n_samples_b)
+ Regularized optimal transportation matrix for the given parameters
+ log : dict
+        log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> n_samples_a = 2
+ >>> n_samples_b = 2
+ >>> reg = 0.1
+ >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
+ >>> empirical_sinkhorn(X_s, X_t, reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE
+ array([[4.99977301e-01, 2.26989344e-05],
+ [2.26989344e-05, 4.99977301e-01]])
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ '''
+
+ if a is None:
+ a = unif(np.shape(X_s)[0])
+ if b is None:
+ b = unif(np.shape(X_t)[0])
+
+ M = dist(X_s, X_t, metric=metric)
+
+ if log:
+ pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
+ return pi, log
+ else:
+ pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
+ return pi
+
+
+def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+ r'''
+ Solve the entropic regularization optimal transport problem from empirical
+ data and return the OT loss
+
+
+ The function solves the following optimization problem:
+
+ .. math::
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`a` and :math:`b` are source and target weights (sum to 1)
+
+
+ Parameters
+ ----------
+ X_s : ndarray, shape (n_samples_a, dim)
+ samples in the source domain
+ X_t : ndarray, shape (n_samples_b, dim)
+ samples in the target domain
+ reg : float
+ Regularization term >0
+    a : ndarray, shape (n_samples_a,)
+        sample weights in the source domain
+    b : ndarray, shape (n_samples_b,)
+        sample weights in the target domain
+    numIterMax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+    W : (n_hists) ndarray or float
+        Optimal transportation loss for the given parameters
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> n_samples_a = 2
+ >>> n_samples_b = 2
+ >>> reg = 0.1
+ >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
+ >>> empirical_sinkhorn2(X_s, X_t, reg, verbose=False)
+ array([4.53978687e-05])
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ '''
+
+ if a is None:
+ a = unif(np.shape(X_s)[0])
+ if b is None:
+ b = unif(np.shape(X_t)[0])
+
+ M = dist(X_s, X_t, metric=metric)
+
+ if log:
+ sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_loss, log
+ else:
+ sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_loss
+
+
+def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+ r'''
+ Compute the sinkhorn divergence loss from empirical data
+
+    The function solves the following optimization problems and returns the
+    sinkhorn divergence :math:`S`:
+
+ .. math::
+
+ W &= \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ W_a &= \min_{\gamma_a} <\gamma_a,M_a>_F + reg\cdot\Omega(\gamma_a)
+
+ W_b &= \min_{\gamma_b} <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b)
+
+ S &= W - 1/2 * (W_a + W_b)
+
+ .. math::
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+
+ \gamma_a 1 = a
+
+ \gamma_a^T 1= a
+
+ \gamma_a\geq 0
+
+ \gamma_b 1 = b
+
+ \gamma_b^T 1= b
+
+ \gamma_b\geq 0
+ where :
+
+ - :math:`M` (resp. :math:`M_a, M_b`) is the (n_samples_a, n_samples_b) metric cost matrix (resp (n_samples_a, n_samples_a) and (n_samples_b, n_samples_b))
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`a` and :math:`b` are source and target weights (sum to 1)
+
+
+ Parameters
+ ----------
+ X_s : ndarray, shape (n_samples_a, dim)
+ samples in the source domain
+ X_t : ndarray, shape (n_samples_b, dim)
+ samples in the target domain
+ reg : float
+ Regularization term >0
+    a : ndarray, shape (n_samples_a,)
+        sample weights in the source domain
+    b : ndarray, shape (n_samples_b,)
+        sample weights in the target domain
+    numIterMax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+    W : float
+        Sinkhorn divergence for the given parameters
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+ >>> n_samples_a = 2
+ >>> n_samples_b = 4
+ >>> reg = 0.1
+ >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
+ >>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS
+ array([1.499...])
+
+
+ References
+ ----------
+ .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
+ '''
+ if log:
+        sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+        sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+        sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+ sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
+
+ log = {}
+ log['sinkhorn_loss_ab'] = sinkhorn_loss_ab
+ log['sinkhorn_loss_a'] = sinkhorn_loss_a
+ log['sinkhorn_loss_b'] = sinkhorn_loss_b
+ log['log_sinkhorn_ab'] = log_ab
+ log['log_sinkhorn_a'] = log_a
+ log['log_sinkhorn_b'] = log_b
+
+ return max(0, sinkhorn_div), log
+
+ else:
+        sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+        sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+        sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+ sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
+ return max(0, sinkhorn_div)
diff --git a/ot/da.py b/ot/da.py
new file mode 100644
index 0000000..108a38d
--- /dev/null
+++ b/ot/da.py
@@ -0,0 +1,1916 @@
+# -*- coding: utf-8 -*-
+"""
+Domain adaptation with optimal transport
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Nicolas Courty <ncourty@irisa.fr>
+# Michael Perrot <michael.perrot@univ-st-etienne.fr>
+# Nathalie Gayraud <nat.gayraud@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+import scipy.linalg as linalg
+
+from .bregman import sinkhorn
+from .lp import emd
+from .utils import unif, dist, kernel, cost_normalization
+from .utils import check_params, BaseEstimator
+from .unbalanced import sinkhorn_unbalanced
+from .optim import cg
+from .optim import gcg
+
+
+def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
+ numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
+ log=False):
+ """
+ Solve the entropic regularization optimal transport problem with nonconvex
+ group lasso regularization
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)
+ + \eta \Omega_g(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e
+ (\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega_g` is the group lasso regularization term
+ :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1`
+ where :math:`\mathcal{I}_c` are the indices of samples from class c
+ in the source domain.
+ - a and b are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is the generalized conditional
+ gradient as proposed in [5]_ [7]_
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ labels_a : np.ndarray (ns,)
+ labels of samples in the source domain
+ b : np.ndarray (nt,)
+ samples weights in the target domain
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term for entropic regularization >0
+ eta : float, optional
+ Regularization term for group lasso regularization >0
+ numItermax : int, optional
+ Max number of iterations
+ numInnerItermax : int, optional
+ Max number of iterations (inner sinkhorn solver)
+ stopInnerThr : float, optional
+ Stop threshold on error (inner sinkhorn solver) (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
+ "Optimal Transport for Domain Adaptation," in IEEE
+ Transactions on Pattern Analysis and Machine Intelligence ,
+ vol.PP, no.99, pp.1-1
+ .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
+ Generalized conditional gradient: analysis of convergence
+ and applications. arXiv preprint arXiv:1510.06567.
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.bregman.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+ p = 0.5
+ epsilon = 1e-3
+
+ indices_labels = []
+ classes = np.unique(labels_a)
+ for c in classes:
+ idxc, = np.where(labels_a == c)
+ indices_labels.append(idxc)
+
+ W = np.zeros(M.shape)
+
+ for cpt in range(numItermax):
+ Mreg = M + eta * W
+ transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
+ stopThr=stopInnerThr)
+ # the transport has been computed. Check if classes are really
+ # separated
+ W = np.ones(M.shape)
+ for (i, c) in enumerate(classes):
+ majs = np.sum(transp[indices_labels[i]], axis=0)
+ majs = p * ((majs + epsilon)**(p - 1))
+ W[indices_labels[i]] = majs
+
+ return transp
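+
+# A minimal usage sketch (toy, hypothetical data): the weights W built
+# above penalize target samples that receive mass from several source
+# classes at once.
+#
+#   >>> xs, xt = np.random.randn(10, 2), np.random.randn(8, 2)
+#   >>> ys = np.array([0] * 5 + [1] * 5)
+#   >>> G = sinkhorn_lpl1_mm(unif(10), ys, unif(8), dist(xs, xt), reg=1.)
+#   >>> G.shape
+#   (10, 8)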
+
+
+def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
+ numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
+ log=False):
+ """
+ Solve the entropic regularization optimal transport problem with group
+ lasso regularization
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+
+ \eta \Omega_g(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega_e` is the entropic regularization term
+ :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega_g` is the group lasso regularization term
+ :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^2`
+ where :math:`\mathcal{I}_c` are the indices of samples from class
+ c in the source domain.
+ - a and b are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is the generalized conditional
+ gradient as proposed in [5]_ [7]_
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ labels_a : np.ndarray (ns,)
+ labels of samples in the source domain
+ b : np.ndarray (nt,)
+ samples weights in the target domain
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term for entropic regularization >0
+ eta : float, optional
+ Regularization term for group lasso regularization >0
+ numItermax : int, optional
+ Max number of iterations
+ numInnerItermax : int, optional
+ Max number of iterations (inner sinkhorn solver)
+ stopInnerThr : float, optional
+ Stop threshold on error (inner sinkhorn solver) (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
+ "Optimal Transport for Domain Adaptation," in IEEE Transactions
+ on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+ .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
+ Generalized conditional gradient: analysis of convergence and
+ applications. arXiv preprint arXiv:1510.06567.
+
+ See Also
+ --------
+ ot.optim.gcg : Generalized conditional gradient for OT problems
+
+ """
+ lstlab = np.unique(labels_a)
+
+ def f(G):
+ res = 0
+ for i in range(G.shape[1]):
+ for lab in lstlab:
+ temp = G[labels_a == lab, i]
+ res += np.linalg.norm(temp)
+ return res
+
+ def df(G):
+ W = np.zeros(G.shape)
+ for i in range(G.shape[1]):
+ for lab in lstlab:
+ temp = G[labels_a == lab, i]
+ n = np.linalg.norm(temp)
+ if n:
+ W[labels_a == lab, i] = temp / n
+ return W
+
+ return gcg(a, b, M, reg, eta, f, df, G0=None, numItermax=numItermax,
+ numInnerItermax=numInnerItermax, stopThr=stopInnerThr,
+ verbose=verbose, log=log)
+
+
+def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
+ verbose2=False, numItermax=100, numInnerItermax=10,
+ stopInnerThr=1e-6, stopThr=1e-5, log=False,
+ **kwargs):
+ """Joint OT and linear mapping estimation as proposed in [8]
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \min_{\gamma,L}\quad \|L(X_s) -n_s\gamma X_t\|^2_F +
+ \mu<\gamma,M>_F + \eta \|L -I\|^2_F
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) squared euclidean cost matrix between samples in
+ Xs and Xt (scaled by ns)
+ - :math:`L` is a dxd linear operator that approximates the barycentric
+ mapping
+ - :math:`I` is the identity matrix (neutral linear mapping)
+ - a and b are uniform source and target weights
+
+ The problem consists in solving jointly an optimal transport matrix
+ :math:`\gamma` and a linear mapping that fits the barycentric mapping
+ :math:`n_s\gamma X_t`.
+
+ One can also estimate a mapping with constant bias (see supplementary
+ material of [8]) using the bias optional argument.
+
+ The algorithm used for solving the problem is the block coordinate
+ descent that alternates between updates of G (using conditional gradient)
+ and the update of L using a classical least square solver.
+
+
+ Parameters
+ ----------
+ xs : np.ndarray (ns,d)
+ samples in the source domain
+ xt : np.ndarray (nt,d)
+ samples in the target domain
+ mu : float, optional
+ Weight for the linear OT loss (>0)
+ eta : float, optional
+ Regularization term for the linear mapping L (>0)
+ bias : bool, optional
+ Estimate linear mapping with constant bias
+ numItermax : int, optional
+ Max number of BCD iterations
+ stopThr : float, optional
+ Stop threshold on relative loss decrease (>0)
+ numInnerItermax : int, optional
+ Max number of iterations (inner CG solver)
+ stopInnerThr : float, optional
+ Stop threshold on error (inner CG solver) (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ L : (d x d) ndarray
+ Linear mapping matrix (d+1 x d if bias)
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
+ "Mapping estimation for discrete optimal transport",
+ Neural Information Processing Systems (NIPS), 2016.
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ ns, nt, d = xs.shape[0], xt.shape[0], xt.shape[1]
+
+ if bias:
+ xs1 = np.hstack((xs, np.ones((ns, 1))))
+ xstxs = xs1.T.dot(xs1)
+ Id = np.eye(d + 1)
+ Id[-1] = 0
+ I0 = Id[:, :-1]
+
+ def sel(x):
+ return x[:-1, :]
+ else:
+ xs1 = xs
+ xstxs = xs1.T.dot(xs1)
+ Id = np.eye(d)
+ I0 = Id
+
+ def sel(x):
+ return x
+
+ if log:
+ log = {'err': []}
+
+ a, b = unif(ns), unif(nt)
+ M = dist(xs, xt) * ns
+ G = emd(a, b, M)
+
+ vloss = []
+
+ def loss(L, G):
+ """Compute full loss"""
+ return np.sum((xs1.dot(L) - ns * G.dot(xt))**2) + mu * \
+ np.sum(G * M) + eta * np.sum(sel(L - I0)**2)
+
+ def solve_L(G):
+ """ solve L problem with fixed G (least square)"""
+ xst = ns * G.dot(xt)
+ return np.linalg.solve(xstxs + eta * Id, xs1.T.dot(xst) + eta * I0)
+
+ def solve_G(L, G0):
+ """Update G with CG algorithm"""
+ xsi = xs1.dot(L)
+
+ def f(G):
+ return np.sum((xsi - ns * G.dot(xt))**2)
+
+ def df(G):
+ return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T)
+ G = cg(a, b, M, 1.0 / mu, f, df, G0=G0,
+ numItermax=numInnerItermax, stopThr=stopInnerThr)
+ return G
+
+ L = solve_L(G)
+
+ vloss.append(loss(L, G))
+
+ if verbose:
+ print('{:5s}|{:12s}|{:8s}'.format(
+ 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
+ print('{:5d}|{:8e}|{:8e}'.format(0, vloss[-1], 0))
+
+ # init loop
+ if numItermax > 0:
+ loop = 1
+ else:
+ loop = 0
+ it = 0
+
+ while loop:
+
+ it += 1
+
+ # update G
+ G = solve_G(L, G)
+
+ # update L
+ L = solve_L(G)
+
+ vloss.append(loss(L, G))
+
+ if it >= numItermax:
+ loop = 0
+
+ if abs(vloss[-1] - vloss[-2]) / abs(vloss[-2]) < stopThr:
+ loop = 0
+
+ if verbose:
+ if it % 20 == 0:
+ print('{:5s}|{:12s}|{:8s}'.format(
+ 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
+ print('{:5d}|{:8e}|{:8e}'.format(
+ it, vloss[-1], (vloss[-1] - vloss[-2]) / abs(vloss[-2])))
+ if log:
+ log['loss'] = vloss
+ return G, L, log
+ else:
+ return G, L
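+
+# A minimal usage sketch (toy, hypothetical data): the estimated L maps
+# new source samples, which is the point of fitting a mapping on top of
+# the coupling.
+#
+#   >>> xs = np.random.randn(20, 2)
+#   >>> xt = xs.dot(np.array([[2., 0.], [0., .5]])) + 1.
+#   >>> G, L = joint_OT_mapping_linear(xs, xt, mu=1, eta=1e-3, bias=True)
+#   >>> xnew = np.random.randn(5, 2)
+#   >>> mapped = np.hstack((xnew, np.ones((5, 1)))).dot(L)  # bias case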
+
+
+def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
+ sigma=1, bias=False, verbose=False, verbose2=False,
+ numItermax=100, numInnerItermax=10,
+ stopInnerThr=1e-6, stopThr=1e-5, log=False,
+ **kwargs):
+ """Joint OT and nonlinear mapping estimation with kernels as proposed in [8]
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \min_{\gamma,L\in\mathcal{H}}\quad \|L(X_s) -
+ n_s\gamma X_t\|^2_F + \mu<\gamma,M>_F + \eta \|L\|^2_\mathcal{H}
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) squared euclidean cost matrix between samples in
+ Xs and Xt (scaled by ns)
+ - :math:`L` is a ns x d linear operator on a kernel matrix that
+ approximates the barycentric mapping
+ - a and b are uniform source and target weights
+
+ The problem consists in solving jointly an optimal transport matrix
+ :math:`\gamma` and the nonlinear mapping that fits the barycentric mapping
+ :math:`n_s\gamma X_t`.
+
+ One can also estimate a mapping with constant bias (see supplementary
+ material of [8]) using the bias optional argument.
+
+ The algorithm used for solving the problem is the block coordinate
+ descent that alternates between updates of G (using conditional gradient)
+ and the update of L using a classical kernel least square solver.
+
+
+ Parameters
+ ----------
+ xs : np.ndarray (ns,d)
+ samples in the source domain
+ xt : np.ndarray (nt,d)
+ samples in the target domain
+ mu : float, optional
+ Weight for the linear OT loss (>0)
+ eta : float, optional
+ Regularization term for the linear mapping L (>0)
+ kerneltype : str, optional
+ kernel used by calling function ot.utils.kernel (gaussian by default)
+ sigma : float, optional
+ Gaussian kernel bandwidth.
+ bias : bool, optional
+ Estimate linear mapping with constant bias
+ verbose : bool, optional
+ Print information along iterations
+ verbose2 : bool, optional
+ Print information along iterations
+ numItermax : int, optional
+ Max number of BCD iterations
+ numInnerItermax : int, optional
+ Max number of iterations (inner CG solver)
+ stopInnerThr : float, optional
+ Stop threshold on error (inner CG solver) (>0)
+ stopThr : float, optional
+ Stop threshold on relative loss decrease (>0)
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ L : (ns x d) ndarray
+ Nonlinear mapping matrix (ns+1 x d if bias)
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
+ "Mapping estimation for discrete optimal transport",
+ Neural Information Processing Systems (NIPS), 2016.
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ ns, nt = xs.shape[0], xt.shape[0]
+
+ K = kernel(xs, xs, method=kerneltype, sigma=sigma)
+ if bias:
+ K1 = np.hstack((K, np.ones((ns, 1))))
+ Id = np.eye(ns + 1)
+ Id[-1] = 0
+ Kp = np.eye(ns + 1)
+ Kp[:ns, :ns] = K
+
+ # ls regu
+ # K0 = K1.T.dot(K1)+eta*I
+ # Kreg=I
+
+ # RKHS regul
+ K0 = K1.T.dot(K1) + eta * Kp
+ Kreg = Kp
+
+ else:
+ K1 = K
+ Id = np.eye(ns)
+
+ # ls regul
+ # K0 = K1.T.dot(K1)+eta*I
+ # Kreg=I
+
+ # proper kernel ridge
+ K0 = K + eta * Id
+ Kreg = K
+
+ if log:
+ log = {'err': []}
+
+ a, b = unif(ns), unif(nt)
+ M = dist(xs, xt) * ns
+ G = emd(a, b, M)
+
+ vloss = []
+
+ def loss(L, G):
+ """Compute full loss"""
+ return np.sum((K1.dot(L) - ns * G.dot(xt))**2) + mu * \
+ np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L))
+
+ def solve_L_nobias(G):
+ """ solve L problem with fixed G (least square)"""
+ xst = ns * G.dot(xt)
+ return np.linalg.solve(K0, xst)
+
+ def solve_L_bias(G):
+ """ solve L problem with fixed G (least square)"""
+ xst = ns * G.dot(xt)
+ return np.linalg.solve(K0, K1.T.dot(xst))
+
+ def solve_G(L, G0):
+ """Update G with CG algorithm"""
+ xsi = K1.dot(L)
+
+ def f(G):
+ return np.sum((xsi - ns * G.dot(xt))**2)
+
+ def df(G):
+ return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T)
+ G = cg(a, b, M, 1.0 / mu, f, df, G0=G0,
+ numItermax=numInnerItermax, stopThr=stopInnerThr)
+ return G
+
+ if bias:
+ solve_L = solve_L_bias
+ else:
+ solve_L = solve_L_nobias
+
+ L = solve_L(G)
+
+ vloss.append(loss(L, G))
+
+ if verbose:
+ print('{:5s}|{:12s}|{:8s}'.format(
+ 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
+ print('{:5d}|{:8e}|{:8e}'.format(0, vloss[-1], 0))
+
+ # init loop
+ if numItermax > 0:
+ loop = 1
+ else:
+ loop = 0
+ it = 0
+
+ while loop:
+
+ it += 1
+
+ # update G
+ G = solve_G(L, G)
+
+ # update L
+ L = solve_L(G)
+
+ vloss.append(loss(L, G))
+
+ if it >= numItermax:
+ loop = 0
+
+ if abs(vloss[-1] - vloss[-2]) / abs(vloss[-2]) < stopThr:
+ loop = 0
+
+ if verbose:
+ if it % 20 == 0:
+ print('{:5s}|{:12s}|{:8s}'.format(
+ 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
+ print('{:5d}|{:8e}|{:8e}'.format(
+ it, vloss[-1], (vloss[-1] - vloss[-2]) / abs(vloss[-2])))
+ if log:
+ log['loss'] = vloss
+ return G, L, log
+ else:
+ return G, L
+
+
+def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
+ wt=None, bias=True, log=False):
+ """ return OT linear operator between samples
+
+ The function estimates the optimal linear operator that aligns the two
+ empirical distributions. This is equivalent to estimating the closed
+ form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)`
+ and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in remark
+ 2.29 in [15].
+
+ The linear operator from source to target :math:`M`
+
+ .. math::
+ M(x)=Ax+b
+
+ where :
+
+ .. math::
+ A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2}
+ \Sigma_s^{-1/2}
+ .. math::
+ b=\mu_t-A\mu_s
+
+ Parameters
+ ----------
+ xs : np.ndarray (ns,d)
+ samples in the source domain
+ xt : np.ndarray (nt,d)
+ samples in the target domain
+ reg : float, optional
+ regularization added to the diagonals of covariances (>0)
+ ws : np.ndarray (ns,1), optional
+ weights for the source samples
+ wt : np.ndarray (nt,1), optional
+ weights for the target samples
+ bias: boolean, optional
+ estimate bias b else b=0 (default:True)
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ A : (d x d) ndarray
+ Linear operator
+ b : (1 x d) ndarray
+ bias
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
+ distributions", Journal of Optimization Theory and Applications
+ Vol 43, 1984
+
+ .. [15] Peyré, G., & Cuturi, M. (2018). "Computational Optimal
+ Transport".
+
+
+ """
+
+ d = xs.shape[1]
+
+ if bias:
+ mxs = xs.mean(0, keepdims=True)
+ mxt = xt.mean(0, keepdims=True)
+
+ xs = xs - mxs
+ xt = xt - mxt
+ else:
+ mxs = np.zeros((1, d))
+ mxt = np.zeros((1, d))
+
+ if ws is None:
+ ws = np.ones((xs.shape[0], 1)) / xs.shape[0]
+
+ if wt is None:
+ wt = np.ones((xt.shape[0], 1)) / xt.shape[0]
+
+ Cs = (xs * ws).T.dot(xs) / ws.sum() + reg * np.eye(d)
+ Ct = (xt * wt).T.dot(xt) / wt.sum() + reg * np.eye(d)
+
+ Cs12 = linalg.sqrtm(Cs)
+ Cs_12 = linalg.inv(Cs12)
+
+ M0 = linalg.sqrtm(Cs12.dot(Ct.dot(Cs12)))
+
+ A = Cs_12.dot(M0.dot(Cs_12))
+
+ b = mxt - mxs.dot(A)
+
+ if log:
+ log = {}
+ log['Cs'] = Cs
+ log['Ct'] = Ct
+ log['Cs12'] = Cs12
+ log['Cs_12'] = Cs_12
+ return A, b, log
+ else:
+ return A, b
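+
+# A minimal sketch (toy, hypothetical data): when the target samples are
+# an affine transform of the source ones, the estimated operator should
+# recover that transform up to sampling noise.
+#
+#   >>> xs = np.random.randn(200, 2)
+#   >>> xt = xs.dot(np.array([[2., 0.], [0., 3.]])) + np.array([1., -1.])
+#   >>> A, b = OT_mapping_linear(xs, xt)
+#   >>> # xs.dot(A) + b reproduces xt up to sampling noise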
+
+
+def distribution_estimation_uniform(X):
+ """estimates a uniform distribution from an array of samples X
+
+ Parameters
+ ----------
+ X : array-like, shape (n_samples, n_features)
+ The array of samples
+
+ Returns
+ -------
+ mu : array-like, shape (n_samples,)
+ The uniform distribution estimated from X
+ """
+
+ return unif(X.shape[0])
+
+
+class BaseTransport(BaseEstimator):
+
+ """Base class for OTDA objects
+
+ Notes
+ -----
+ All estimators should specify all the parameters that can be set
+ at the class level in their ``__init__`` as explicit keyword
+ arguments (no ``*args`` or ``**kwargs``).
+
+ fit method should:
+ - estimate a cost matrix and store it in a `cost_` attribute
+ - estimate a coupling matrix and store it in a `coupling_`
+ attribute
+ - estimate distributions from source and target data and store them in
+ mu_s and mu_t attributes
+ - store Xs and Xt in attributes to be used later on in transform and
+ inverse_transform methods
+
+ transform method should always get as input a Xs parameter
+ inverse_transform method should always get as input a Xt parameter
+ """
+
+ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention, -1 cannot be used as a
+ class label
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs, Xt=Xt):
+
+ # pairwise distance
+ self.cost_ = dist(Xs, Xt, metric=self.metric)
+ self.cost_ = cost_normalization(self.cost_, self.norm)
+
+ if (ys is not None) and (yt is not None):
+
+ if self.limit_max != np.infty:
+ self.limit_max = self.limit_max * np.max(self.cost_)
+
+ # assumes labeled source samples occupy the first rows
+ # and labeled target samples occupy the first columns
+ classes = [c for c in np.unique(ys) if c != -1]
+ for c in classes:
+ idx_s = np.where((ys != c) & (ys != -1))
+ idx_t = np.where(yt == c)
+
+ # all the coefficients corresponding to a source sample
+ # and a target sample with different labels get the
+ # limit_max cost
+ for j in idx_t[0]:
+ self.cost_[idx_s[0], j] = self.limit_max
+
+ # distribution estimation
+ self.mu_s = self.distribution_estimation(Xs)
+ self.mu_t = self.distribution_estimation(Xt)
+
+ # store arrays of samples
+ self.xs_ = Xs
+ self.xt_ = Xt
+
+ return self
+
+ def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt) and transports source samples Xs onto target
+ ones Xt
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention, -1 cannot be used as a
+ class label
+
+ Returns
+ -------
+ transp_Xs : array-like, shape (n_source_samples, n_features)
+ The transported source samples.
+ """
+
+ return self.fit(Xs, ys, Xt, yt).transform(Xs, ys, Xt, yt)
+
+ def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
+ """Transports source samples Xs onto target ones Xt
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention, -1 cannot be used as a
+ class label
+ batch_size : int, optional (default=128)
+ The batch size for out of sample transform
+
+ Returns
+ -------
+ transp_Xs : array-like, shape (n_source_samples, n_features)
+ The transported source samples.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs):
+
+ if np.array_equal(self.xs_, Xs):
+
+ # perform standard barycentric mapping
+ transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
+
+ # set nans to 0
+ transp[~ np.isfinite(transp)] = 0
+
+ # compute transported samples
+ transp_Xs = np.dot(transp, self.xt_)
+ else:
+ # perform out of sample mapping
+ indices = np.arange(Xs.shape[0])
+ batch_ind = [
+ indices[i:i + batch_size]
+ for i in range(0, len(indices), batch_size)]
+
+ transp_Xs = []
+ for bi in batch_ind:
+
+ # get the nearest neighbor in the source domain
+ D0 = dist(Xs[bi], self.xs_)
+ idx = np.argmin(D0, axis=1)
+
+ # transport the source samples
+ transp = self.coupling_ / np.sum(
+ self.coupling_, 1)[:, None]
+ transp[~ np.isfinite(transp)] = 0
+ transp_Xs_ = np.dot(transp, self.xt_)
+
+ # define the transported points
+ transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - self.xs_[idx, :]
+
+ transp_Xs.append(transp_Xs_)
+
+ transp_Xs = np.concatenate(transp_Xs, axis=0)
+
+ return transp_Xs
+
+ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
+ batch_size=128):
+ """Transports target samples Xt onto target samples Xs
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention, -1 cannot be used as a
+ class label
+ batch_size : int, optional (default=128)
+ The batch size for out of sample inverse transform
+
+ Returns
+ -------
+ transp_Xt : array-like, shape (n_target_samples, n_features)
+ The transported target samples.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xt=Xt):
+
+ if np.array_equal(self.xt_, Xt):
+
+ # perform standard barycentric mapping
+ transp_ = self.coupling_.T / np.sum(self.coupling_, 0)[:, None]
+
+ # set nans to 0
+ transp_[~ np.isfinite(transp_)] = 0
+
+ # compute transported samples
+ transp_Xt = np.dot(transp_, self.xs_)
+ else:
+ # perform out of sample mapping
+ indices = np.arange(Xt.shape[0])
+ batch_ind = [
+ indices[i:i + batch_size]
+ for i in range(0, len(indices), batch_size)]
+
+ transp_Xt = []
+ for bi in batch_ind:
+
+ D0 = dist(Xt[bi], self.xt_)
+ idx = np.argmin(D0, axis=1)
+
+ # transport the target samples
+ transp_ = self.coupling_.T / np.sum(
+ self.coupling_, 0)[:, None]
+ transp_[~ np.isfinite(transp_)] = 0
+ transp_Xt_ = np.dot(transp_, self.xs_)
+
+ # define the transported points
+ transp_Xt_ = transp_Xt_[idx, :] + Xt[bi] - self.xt_[idx, :]
+
+ transp_Xt.append(transp_Xt_)
+
+ transp_Xt = np.concatenate(transp_Xt, axis=0)
+
+ return transp_Xt
+
+
+class LinearTransport(BaseTransport):
+ """ OT linear operator between empirical distributions
+
+ The function estimates the optimal linear operator that aligns the two
+ empirical distributions. This is equivalent to estimating the closed
+ form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)`
+ and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in
+ remark 2.29 in [15].
+
+ The linear operator from source to target :math:`M`
+
+ .. math::
+ M(x)=Ax+b
+
+ where :
+
+ .. math::
+ A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2}
+ \Sigma_s^{-1/2}
+ .. math::
+ b=\mu_t-A\mu_s
+
+ Parameters
+ ----------
+ reg : float, optional
+ regularization added to the diagonals of covariances (>0)
+ bias: boolean, optional
+ estimate bias b else b=0 (default:True)
+ log : bool, optional
+ record log if True
+
+ References
+ ----------
+
+ .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
+ distributions", Journal of Optimization Theory and Applications
+ Vol 43, 1984
+
+ .. [15] Peyré, G., & Cuturi, M. (2018). "Computational Optimal
+ Transport".
+
+ """
+
+ def __init__(self, reg=1e-8, bias=True, log=False,
+ distribution_estimation=distribution_estimation_uniform):
+
+ self.bias = bias
+ self.log = log
+ self.reg = reg
+ self.distribution_estimation = distribution_estimation
+
+ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention, -1 cannot be used as a
+ class label
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ self.mu_s = self.distribution_estimation(Xs)
+ self.mu_t = self.distribution_estimation(Xt)
+
+ # coupling estimation
+ returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg,
+ ws=self.mu_s.reshape((-1, 1)),
+ wt=self.mu_t.reshape((-1, 1)),
+ bias=self.bias, log=self.log)
+
+ # deal with the value of log
+ if self.log:
+ self.A_, self.B_, self.log_ = returned_
+ else:
+ self.A_, self.B_, = returned_
+ self.log_ = dict()
+
+ # compute the inverse mapping
+ self.A1_ = linalg.inv(self.A_)
+ self.B1_ = -self.B_.dot(self.A1_)
+
+ return self
+
+ def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
+ """Transports source samples Xs onto target ones Xt
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention, -1 cannot be used as a
+ class label
+ batch_size : int, optional (default=128)
+ The batch size for out of sample transform
+
+ Returns
+ -------
+ transp_Xs : array-like, shape (n_source_samples, n_features)
+ The transported source samples.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs):
+
+ transp_Xs = Xs.dot(self.A_) + self.B_
+
+ return transp_Xs
+
+ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
+ batch_size=128):
+ """Transports target samples Xt onto target samples Xs
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention, -1 cannot be used as a
+ class label
+ batch_size : int, optional (default=128)
+ The batch size for out of sample inverse transform
+
+ Returns
+ -------
+ transp_Xt : array-like, shape (n_target_samples, n_features)
+ The transported target samples.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xt=Xt):
+
+ transp_Xt = Xt.dot(self.A1_) + self.B1_
+
+ return transp_Xt
+
+
+class SinkhornTransport(BaseTransport):
+
+ """Domain Adapatation OT method based on Sinkhorn Algorithm
+
+ Parameters
+ ----------
+ reg_e : float, optional (default=1)
+ Entropic regularization parameter
+ max_iter : int, float, optional (default=1000)
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged
+ tol : float, optional (default=10e-9)
+ The precision required to stop the optimization algorithm.
+ verbose : bool, optional (default=False)
+ Controls the verbosity of the optimization algorithm
+ log : bool, optional (default=False)
+ Controls the logs of the optimization algorithm
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
+ distribution_estimation : callable, optional (defaults to the uniform)
+ The kind of distribution estimation to employ
+ out_of_sample_map : string, optional (default="ferradans")
+ The kind of out of sample mapping to apply to transport samples
+ from a domain into another one. Currently the only possible option is
+ "ferradans" which uses the method proposed in [6].
+ limit_max : float, optional (default=np.infty)
+ Controls the semi-supervised mode. Transport between labeled source
+ and target samples of different classes will exhibit a cost defined
+ by this variable
+
+ Attributes
+ ----------
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
+ log_ : dictionary
+ The dictionary of log, empty dict if parameter log is not True
+
+ References
+ ----------
+ .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
+ "Optimal Transport for Domain Adaptation," in IEEE Transactions
+ on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems (NIPS)
+ 26, 2013
+ """
+
+ def __init__(self, reg_e=1., max_iter=1000,
+ tol=10e-9, verbose=False, log=False,
+ metric="sqeuclidean", norm=None,
+ distribution_estimation=distribution_estimation_uniform,
+ out_of_sample_map='ferradans', limit_max=np.infty):
+
+ self.reg_e = reg_e
+ self.max_iter = max_iter
+ self.tol = tol
+ self.verbose = verbose
+ self.log = log
+ self.metric = metric
+ self.norm = norm
+ self.limit_max = limit_max
+ self.distribution_estimation = distribution_estimation
+ self.out_of_sample_map = out_of_sample_map
+
+ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention, -1 cannot be used as a
+ class label
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)
+
+ # coupling estimation
+ returned_ = sinkhorn(
+ a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e,
+ numItermax=self.max_iter, stopThr=self.tol,
+ verbose=self.verbose, log=self.log)
+
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.log_ = returned_
+ else:
+ self.coupling_ = returned_
+ self.log_ = dict()
+
+ return self
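+
+# A minimal sketch of the estimator API shared by the transport classes
+# (toy, hypothetical data):
+#
+#   >>> xs = np.random.randn(50, 2)
+#   >>> xt = np.random.randn(40, 2) + 3.
+#   >>> ot_sinkhorn = SinkhornTransport(reg_e=1.)
+#   >>> xs_mapped = ot_sinkhorn.fit_transform(Xs=xs, Xt=xt)  # (50, 2)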
+
+
+class EMDTransport(BaseTransport):
+
+ """Domain Adapatation OT method based on Earth Mover's Distance
+
+ Parameters
+ ----------
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
+ log : bool, optional (default=False)
+ Controls the logs of the optimization algorithm
+ distribution_estimation : callable, optional (defaults to the uniform)
+ The kind of distribution estimation to employ
+ out_of_sample_map : string, optional (default="ferradans")
+ The kind of out of sample mapping to apply to transport samples
+ from a domain into another one. Currently the only possible option is
+ "ferradans" which uses the method proposed in [6].
+ limit_max : float, optional (default=10)
+ Controls the semi-supervised mode. Transport between labeled source
+ and target samples of different classes will exhibit a cost equal to
+ limit_max times the maximum value of the cost matrix
+ max_iter : int, optional (default=100000)
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged.
+
+ Attributes
+ ----------
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
+
+ References
+ ----------
+ .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
+ "Optimal Transport for Domain Adaptation," in IEEE Transactions
+ on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+ """
+
+ def __init__(self, metric="sqeuclidean", norm=None, log=False,
+ distribution_estimation=distribution_estimation_uniform,
+ out_of_sample_map='ferradans', limit_max=10,
+ max_iter=100000):
+
+ self.metric = metric
+ self.norm = norm
+ self.log = log
+ self.limit_max = limit_max
+ self.distribution_estimation = distribution_estimation
+ self.out_of_sample_map = out_of_sample_map
+ self.max_iter = max_iter
+
+ def fit(self, Xs, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention, -1 cannot be used as a
+ class label
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ super(EMDTransport, self).fit(Xs, ys, Xt, yt)
+
+ returned_ = emd(
+ a=self.mu_s, b=self.mu_t, M=self.cost_, numItermax=self.max_iter,
+ log=self.log)
+
+ # coupling estimation
+ if self.log:
+ self.coupling_, self.log_ = returned_
+ else:
+ self.coupling_ = returned_
+ self.log_ = dict()
+ return self
+
+
+class SinkhornLpl1Transport(BaseTransport):
+
+ """Domain Adapatation OT method based on sinkhorn algorithm +
+ LpL1 class regularization.
+
+ Parameters
+ ----------
+ reg_e : float, optional (default=1)
+ Entropic regularization parameter
+ reg_cl : float, optional (default=0.1)
+ Class regularization parameter
+ max_iter : int, float, optional (default=10)
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged
+ max_inner_iter : int, float, optional (default=200)
+ The number of iterations in the inner loop
+ log : bool, optional (default=False)
+ Controls the logs of the optimization algorithm
+ tol : float, optional (default=10e-9)
+ Stop threshold on error (inner sinkhorn solver) (>0)
+ verbose : bool, optional (default=False)
+ Controls the verbosity of the optimization algorithm
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
+ distribution_estimation : callable, optional (defaults to the uniform)
+ The kind of distribution estimation to employ
+ out_of_sample_map : string, optional (default="ferradans")
+ The kind of out of sample mapping to apply to transport samples
+ from a domain into another one. Currently the only possible option is
+ "ferradans" which uses the method proposed in [6].
+ limit_max : float, optional (default=np.infty)
+ Controls the semi-supervised mode. Transport between labeled source
+ and target samples of different classes will exhibit a cost defined by
+ limit_max.
+
+ Attributes
+ ----------
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
+
+ References
+ ----------
+
+ .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
+ "Optimal Transport for Domain Adaptation," in IEEE
+ Transactions on Pattern Analysis and Machine Intelligence ,
+ vol.PP, no.99, pp.1-1
+ .. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
+ Generalized conditional gradient: analysis of convergence
+ and applications. arXiv preprint arXiv:1510.06567.
+
+ """
+
+ def __init__(self, reg_e=1., reg_cl=0.1,
+ max_iter=10, max_inner_iter=200, log=False,
+ tol=10e-9, verbose=False,
+ metric="sqeuclidean", norm=None,
+ distribution_estimation=distribution_estimation_uniform,
+ out_of_sample_map='ferradans', limit_max=np.infty):
+
+ self.reg_e = reg_e
+ self.reg_cl = reg_cl
+ self.max_iter = max_iter
+ self.max_inner_iter = max_inner_iter
+ self.tol = tol
+ self.log = log
+ self.verbose = verbose
+ self.metric = metric
+ self.norm = norm
+ self.distribution_estimation = distribution_estimation
+ self.out_of_sample_map = out_of_sample_map
+ self.limit_max = limit_max
+
+ def fit(self, Xs, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention, -1 cannot be used as a
+ class label
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs, Xt=Xt, ys=ys):
+
+ super(SinkhornLpl1Transport, self).fit(Xs, ys, Xt, yt)
+
+ returned_ = sinkhorn_lpl1_mm(
+ a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_,
+ reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter,
+ numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol,
+ verbose=self.verbose, log=self.log)
+
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.log_ = returned_
+ else:
+ self.coupling_ = returned_
+ self.log_ = dict()
+ return self
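+
+# A minimal semi-supervised sketch (toy, hypothetical data): source
+# labels are required by this solver, and also passing yt triggers the
+# label-aware cost capping implemented in BaseTransport.fit.
+#
+#   >>> xs, xt = np.random.randn(50, 2), np.random.randn(40, 2) + 3.
+#   >>> ys = np.random.randint(0, 2, 50)
+#   >>> ot_lpl1 = SinkhornLpl1Transport(reg_e=1., reg_cl=0.1).fit(
+#   ...     Xs=xs, ys=ys, Xt=xt)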
+
+
+class SinkhornL1l2Transport(BaseTransport):
+
+ """Domain Adapatation OT method based on sinkhorn algorithm +
+ l1l2 class regularization.
+
+ Parameters
+ ----------
+ reg_e : float, optional (default=1)
+ Entropic regularization parameter
+ reg_cl : float, optional (default=0.1)
+ Class regularization parameter
+ max_iter : int, float, optional (default=10)
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged
+ max_inner_iter : int, float, optional (default=200)
+ The number of iterations in the inner loop
+ tol : float, optional (default=10e-9)
+ Stop threshold on error (inner sinkhorn solver) (>0)
+ verbose : bool, optional (default=False)
+ Controls the verbosity of the optimization algorithm
+ log : bool, optional (default=False)
+ Controls the logs of the optimization algorithm
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
+ distribution_estimation : callable, optional (defaults to the uniform)
+ The kind of distribution estimation to employ
+ out_of_sample_map : string, optional (default="ferradans")
+ The kind of out of sample mapping to apply to transport samples
+ from a domain into another one. Currently the only possible option is
+ "ferradans" which uses the method proposed in [6].
+ limit_max : float, optional (default=10)
+ Controls the semi-supervised mode. Transport between labeled source
+ and target samples of different classes will exhibit a cost equal to
+ limit_max times the maximum value of the cost matrix
+
+ Attributes
+ ----------
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
+ log_ : dictionary
+ The dictionary of log, empty dict if parameter log is not True
+
+ References
+ ----------
+
+ .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
+ "Optimal Transport for Domain Adaptation," in IEEE
+ Transactions on Pattern Analysis and Machine Intelligence ,
+ vol.PP, no.99, pp.1-1
+ .. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
+ Generalized conditional gradient: analysis of convergence
+ and applications. arXiv preprint arXiv:1510.06567.
+
+ """
+
+ def __init__(self, reg_e=1., reg_cl=0.1,
+ max_iter=10, max_inner_iter=200,
+ tol=10e-9, verbose=False, log=False,
+ metric="sqeuclidean", norm=None,
+ distribution_estimation=distribution_estimation_uniform,
+ out_of_sample_map='ferradans', limit_max=10):
+
+ self.reg_e = reg_e
+ self.reg_cl = reg_cl
+ self.max_iter = max_iter
+ self.max_inner_iter = max_inner_iter
+ self.tol = tol
+ self.verbose = verbose
+ self.log = log
+ self.metric = metric
+ self.norm = norm
+ self.distribution_estimation = distribution_estimation
+ self.out_of_sample_map = out_of_sample_map
+ self.limit_max = limit_max
+
+ def fit(self, Xs, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention, -1 cannot be used as a
+ class label
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs, Xt=Xt, ys=ys):
+
+ super(SinkhornL1l2Transport, self).fit(Xs, ys, Xt, yt)
+
+ returned_ = sinkhorn_l1l2_gl(
+ a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_,
+ reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter,
+ numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol,
+ verbose=self.verbose, log=self.log)
+
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.log_ = returned_
+ else:
+ self.coupling_ = returned_
+ self.log_ = dict()
+
+ return self
+
+
+class MappingTransport(BaseEstimator):
+
+ """MappingTransport: DA methods that aims at jointly estimating a optimal
+ transport coupling and the associated mapping
+
+ Parameters
+ ----------
+ mu : float, optional (default=1)
+ Weight for the linear OT loss (>0)
+ eta : float, optional (default=0.001)
+ Regularization term for the linear mapping L (>0)
+ bias : bool, optional (default=False)
+ Estimate linear mapping with constant bias
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
+ kernel : string, optional (default="linear")
+ The kernel to use either linear or gaussian
+ sigma : float, optional (default=1)
+ The gaussian kernel parameter
+ max_iter : int, optional (default=100)
+ Max number of BCD iterations
+ tol : float, optional (default=1e-5)
+ Stop threshold on relative loss decrease (>0)
+ max_inner_iter : int, optional (default=10)
+ Max number of iterations (inner CG solver)
+ inner_tol : float, optional (default=1e-6)
+ Stop threshold on error (inner CG solver) (>0)
+ log : bool, optional (default=False)
+ record log if True
+ verbose : bool, optional (default=False)
+ Print information along iterations
+ verbose2 : bool, optional (default=False)
+ Print information along iterations
+
+ Attributes
+ ----------
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
+ mapping_ : array-like
+ The associated mapping, of shape
+ (n_features (+ 1 if bias), n_features) for kernel == linear, or
+ (n_source_samples (+ 1 if bias), n_features) for kernel == gaussian
+ log_ : dictionary
+ The dictionary of log, empty dict if parameter log is not True
+
+ References
+ ----------
+
+ .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
+ "Mapping estimation for discrete optimal transport",
+ Neural Information Processing Systems (NIPS), 2016.
+
+ """
+
+ def __init__(self, mu=1, eta=0.001, bias=False, metric="sqeuclidean",
+ norm=None, kernel="linear", sigma=1, max_iter=100, tol=1e-5,
+ max_inner_iter=10, inner_tol=1e-6, log=False, verbose=False,
+ verbose2=False):
+
+ self.metric = metric
+ self.norm = norm
+ self.mu = mu
+ self.eta = eta
+ self.bias = bias
+ self.kernel = kernel
+ self.sigma = sigma
+ self.max_iter = max_iter
+ self.tol = tol
+ self.max_inner_iter = max_inner_iter
+ self.inner_tol = inner_tol
+ self.log = log
+ self.verbose = verbose
+ self.verbose2 = verbose2
+
+ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
+ """Builds an optimal coupling and estimates the associated mapping
+ from source and target sets of samples (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention, -1 cannot be used as a
+ class label
+
+ Returns
+ -------
+ self : object
+ Returns self
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs, Xt=Xt):
+
+ self.xs_ = Xs
+ self.xt_ = Xt
+
+ if self.kernel == "linear":
+ returned_ = joint_OT_mapping_linear(
+ Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
+ verbose=self.verbose, verbose2=self.verbose2,
+ numItermax=self.max_iter,
+ numInnerItermax=self.max_inner_iter, stopThr=self.tol,
+ stopInnerThr=self.inner_tol, log=self.log)
+
+ elif self.kernel == "gaussian":
+ returned_ = joint_OT_mapping_kernel(
+ Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
+ sigma=self.sigma, verbose=self.verbose,
+ verbose2=self.verbose2, numItermax=self.max_iter,
+ numInnerItermax=self.max_inner_iter,
+ stopInnerThr=self.inner_tol, stopThr=self.tol,
+ log=self.log)
+
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.mapping_, self.log_ = returned_
+ else:
+ self.coupling_, self.mapping_ = returned_
+ self.log_ = dict()
+
+ return self
+
+ def transform(self, Xs):
+ """Transports source samples Xs onto target ones Xt
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+
+ Returns
+ -------
+ transp_Xs : array-like, shape (n_source_samples, n_features)
+ The transport source samples.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs):
+
+ if np.array_equal(self.xs_, Xs):
+ # perform standard barycentric mapping
+ transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
+
+ # set nans to 0
+ transp[~ np.isfinite(transp)] = 0
+
+ # compute transported samples
+ transp_Xs = np.dot(transp, self.xt_)
+ else:
+ if self.kernel == "gaussian":
+ K = kernel(Xs, self.xs_, method=self.kernel,
+ sigma=self.sigma)
+ elif self.kernel == "linear":
+ K = Xs
+ if self.bias:
+ K = np.hstack((K, np.ones((Xs.shape[0], 1))))
+ transp_Xs = K.dot(self.mapping_)
+
+ return transp_Xs
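+
+# A minimal out-of-sample sketch (toy, hypothetical data): unlike the
+# barycentric couplings, the fitted mapping_ also applies to samples
+# unseen at fit time.
+#
+#   >>> xs, xt = np.random.randn(30, 2), np.random.randn(30, 2) + 2.
+#   >>> mt = MappingTransport(kernel="linear", bias=True).fit(Xs=xs, Xt=xt)
+#   >>> mt.transform(Xs=np.random.randn(5, 2)).shape  # (5, 2)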
+
+
+class UnbalancedSinkhornTransport(BaseTransport):
+
+ """Domain Adapatation unbalanced OT method based on sinkhorn algorithm
+
+ Parameters
+ ----------
+ reg_e : float, optional (default=1)
+ Entropic regularization parameter
+ reg_m : float, optional (default=0.1)
+ Mass regularization parameter
+ method : str
+ method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
+ 'sinkhorn_epsilon_scaling', see those functions for specific parameters
+ max_iter : int, float, optional (default=10)
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged
+ tol : float, optional (default=1e-9)
+ Stop threshold on error (>0)
+ verbose : bool, optional (default=False)
+ Controls the verbosity of the optimization algorithm
+ log : bool, optional (default=False)
+ Controls the logs of the optimization algorithm
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
+ distribution_estimation : callable, optional (defaults to the uniform)
+ The kind of distribution estimation to employ
+ out_of_sample_map : string, optional (default="ferradans")
+ The kind of out of sample mapping to apply to transport samples
+ from a domain into another one. Currently the only possible option is
+ "ferradans" which uses the method proposed in [6].
+ limit_max : float, optional (default=10)
+ Controls the semi-supervised mode. Transport between labeled source
+ and target samples of different classes will exhibit a cost equal to
+ limit_max times the maximum value of the cost matrix
+
+ Attributes
+ ----------
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
+ log_ : dictionary
+ The dictionary of log, empty dict if parameter log is not True
+
+ References
+ ----------
+
+ .. [1] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
+ """
+
+ def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn',
+ max_iter=10, tol=1e-9, verbose=False, log=False,
+ metric="sqeuclidean", norm=None,
+ distribution_estimation=distribution_estimation_uniform,
+ out_of_sample_map='ferradans', limit_max=10):
+
+ self.reg_e = reg_e
+ self.reg_m = reg_m
+ self.method = method
+ self.max_iter = max_iter
+ self.tol = tol
+ self.verbose = verbose
+ self.log = log
+ self.metric = metric
+ self.norm = norm
+ self.distribution_estimation = distribution_estimation
+ self.out_of_sample_map = out_of_sample_map
+ self.limit_max = limit_max
+
+ def fit(self, Xs, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention, -1 cannot be used as a
+ class label
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs, Xt=Xt):
+
+ super(UnbalancedSinkhornTransport, self).fit(Xs, ys, Xt, yt)
+
+ returned_ = sinkhorn_unbalanced(
+ a=self.mu_s, b=self.mu_t, M=self.cost_,
+ reg=self.reg_e, reg_m=self.reg_m, method=self.method,
+ numItermax=self.max_iter, stopThr=self.tol,
+ verbose=self.verbose, log=self.log)
+
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.log_ = returned_
+ else:
+ self.coupling_ = returned_
+ self.log_ = dict()
+
+ return self
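+
+# A minimal sketch (toy, hypothetical data): with unbalanced OT the
+# coupling marginals need not match mu_s and mu_t exactly, which helps
+# when source and target class proportions differ.
+#
+#   >>> ot_unb = UnbalancedSinkhornTransport(reg_e=1., reg_m=0.1)
+#   >>> ot_unb = ot_unb.fit(Xs=np.random.randn(30, 2),
+#   ...                     Xt=np.random.randn(20, 2) + 2.)
+#   >>> mass = ot_unb.coupling_.sum()  # generally not exactly 1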
diff --git a/ot/datasets.py b/ot/datasets.py
new file mode 100644
index 0000000..ba0cfd9
--- /dev/null
+++ b/ot/datasets.py
@@ -0,0 +1,164 @@
+"""
+Simple example datasets for OT
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+
+import numpy as np
+import scipy as sp
+from .utils import check_random_state, deprecated
+
+
+def make_1D_gauss(n, m, s):
+ """return a 1D histogram for a gaussian distribution (n bins, mean m and std s)
+
+ Parameters
+ ----------
+ n : int
+ number of bins in the histogram
+ m : float
+ mean value of the gaussian distribution
+ s : float
+ standard deviation of the gaussian distribution
+
+ Returns
+ -------
+ h : ndarray (n,)
+ 1D histogram for a gaussian distribution
+ """
+ x = np.arange(n, dtype=np.float64)
+ h = np.exp(-(x - m)**2 / (2 * s**2))
+ return h / h.sum()
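+
+# A quick sketch (hypothetical bin grid): two such histograms can be fed
+# directly to the OT solvers together with a cost matrix on the bins.
+#
+#   >>> a = make_1D_gauss(100, m=20., s=5.)
+#   >>> b = make_1D_gauss(100, m=60., s=10.)
+#   >>> x = np.arange(100, dtype=np.float64).reshape((-1, 1))
+#   >>> # cost = ot.utils.dist(x, x); ot.emd2(a, b, cost) then gives the
+#   >>> # squared-euclidean transport cost between the two histograms.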
+
+
+@deprecated()
+def get_1D_gauss(n, m, sigma):
+ """ Deprecated see make_1D_gauss """
+ return make_1D_gauss(n, m, sigma)
+
+
+def make_2D_samples_gauss(n, m, sigma, random_state=None):
+ """Return n samples drawn from 2D gaussian N(m,sigma)
+
+ Parameters
+ ----------
+ n : int
+ number of samples to make
+ m : ndarray, shape (2,)
+ mean value of the gaussian distribution
+ sigma : ndarray, shape (2, 2)
+ covariance matrix of the gaussian distribution
+ random_state : int, RandomState instance or None, optional (default=None)
+ If int, random_state is the seed used by the random number generator;
+ If RandomState instance, random_state is the random number generator;
+ If None, the random number generator is the RandomState instance used
+ by `np.random`.
+
+ Returns
+ -------
+ X : ndarray, shape (n, 2)
+ n samples drawn from N(m, sigma).
+ """
+
+ generator = check_random_state(random_state)
+ if np.isscalar(sigma):
+ sigma = np.array([sigma, ])
+ if len(sigma) > 1:
+ P = sp.linalg.sqrtm(sigma)
+ res = generator.randn(n, 2).dot(P) + m
+ else:
+ res = generator.randn(n, 2) * np.sqrt(sigma) + m
+ return res
+
+
+@deprecated()
+def get_2D_samples_gauss(n, m, sigma, random_state=None):
+ """ Deprecated see make_2D_samples_gauss """
+ return make_2D_samples_gauss(n, m, sigma, random_state=random_state)
+
+
+def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
+ """Dataset generation for classification problems
+
+ Parameters
+ ----------
+ dataset : str
+ type of classification problem (see code)
+ n : int
+ number of training samples
+ nz : float
+ noise level (>0)
+ theta : float
+ rotation angle in radians (used by the 'gaussrot' dataset)
+ random_state : int, RandomState instance or None, optional (default=None)
+ If int, random_state is the seed used by the random number generator;
+ If RandomState instance, random_state is the random number generator;
+ If None, the random number generator is the RandomState instance used
+ by `np.random`.
+
+ Returns
+ -------
+ X : ndarray, shape (n, d)
+ n observation of size d
+ y : ndarray, shape (n,)
+ labels of the samples.
+ """
+ generator = check_random_state(random_state)
+
+ if dataset.lower() == '3gauss':
+ y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+ x = np.zeros((n, 2))
+ # class 1
+ x[y == 1, 0] = -1.
+ x[y == 1, 1] = -1.
+ x[y == 2, 0] = -1.
+ x[y == 2, 1] = 1.
+ x[y == 3, 0] = 1.
+ x[y == 3, 1] = 0
+
+ x[y != 3, :] += 1.5 * nz * generator.randn(sum(y != 3), 2)
+ x[y == 3, :] += 2 * nz * generator.randn(sum(y == 3), 2)
+
+ elif dataset.lower() == '3gauss2':
+ y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+ x = np.zeros((n, 2))
+ y[y == 4] = 3
+ # class means
+ x[y == 1, 0] = -2.
+ x[y == 1, 1] = -2.
+ x[y == 2, 0] = -2.
+ x[y == 2, 1] = 2.
+ x[y == 3, 0] = 2.
+ x[y == 3, 1] = 0
+
+ x[y != 3, :] += nz * generator.randn(sum(y != 3), 2)
+ x[y == 3, :] += 2 * nz * generator.randn(sum(y == 3), 2)
+
+ elif dataset.lower() == 'gaussrot':
+ rot = np.array(
+ [[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]])
+ m1 = np.array([-1, 1])
+ m2 = np.array([1, -1])
+ y = np.floor((np.arange(n) * 1.0 / n * 2)) + 1
+ n1 = np.sum(y == 1)
+ n2 = np.sum(y == 2)
+ x = np.zeros((n, 2))
+
+ x[y == 1, :] = make_2D_samples_gauss(n1, m1, nz, random_state=generator)
+ x[y == 2, :] = make_2D_samples_gauss(n2, m2, nz, random_state=generator)
+
+ x = x.dot(rot)
+
+ else:
+ x = np.array(0)
+ y = np.array(0)
+ print("unknown dataset")
+
+ return x, y.astype(int)
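+
+# Example (illustration only): a toy source/target pair for domain
+# adaptation; the '3gauss2' class means are shifted versions of '3gauss':
+#
+#     xs, ys = make_data_classif('3gauss', n=100, nz=.5, random_state=0)
+#     xt, yt = make_data_classif('3gauss2', n=100, nz=.5, random_state=1)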
+
+
+@deprecated()
+def get_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
+ """ Deprecated see make_data_classif """
+ return make_data_classif(dataset, n, nz=nz, theta=theta, random_state=random_state, **kwargs)
diff --git a/ot/dr.py b/ot/dr.py
new file mode 100644
index 0000000..680dabf
--- /dev/null
+++ b/ot/dr.py
@@ -0,0 +1,200 @@
+# -*- coding: utf-8 -*-
+"""
+Dimension reduction with optimal transport
+
+
+.. warning::
+ Note that by default the module is not imported in :mod:`ot`. In order to
+ use it you need to explicitly import :mod:`ot.dr`.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+from scipy import linalg
+import autograd.numpy as np
+from pymanopt.manifolds import Stiefel
+from pymanopt import Problem
+from pymanopt.solvers import SteepestDescent, TrustRegions
+
+
+def dist(x1, x2):
+ """ Compute squared euclidean distance between samples (autograd)
+ """
+ x1p2 = np.sum(np.square(x1), 1)
+ x2p2 = np.sum(np.square(x2), 1)
+ return x1p2.reshape((-1, 1)) + x2p2.reshape((1, -1)) - 2 * np.dot(x1, x2.T)
+
+
+def sinkhorn(w1, w2, M, reg, k):
+ """Sinkhorn algorithm with fixed number of iteration (autograd)
+ """
+ K = np.exp(-M / reg)
+ ui = np.ones((M.shape[0],))
+ vi = np.ones((M.shape[1],))
+ for i in range(k):
+ vi = w2 / (np.dot(K.T, ui))
+ ui = w1 / (np.dot(K, vi))
+ G = ui.reshape((M.shape[0], 1)) * K * vi.reshape((1, M.shape[1]))
+ return G
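+
+# Note (illustration only): because only autograd-compatible operations are
+# used above, the regularized OT cost can be differentiated through the k
+# fixed iterations, e.g. with respect to the cost matrix:
+#
+#     from autograd import grad
+#     w = np.ones(5) / 5
+#     cost = lambda M: np.sum(sinkhorn(w, w, M, reg=1., k=10) * M)
+#     gM = grad(cost)(dist(np.random.randn(5, 2), np.random.randn(5, 2)))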
+
+
+def split_classes(X, y):
+ """split samples in X by classes in y
+ """
+ lstsclass = np.unique(y)
+ return [X[y == i, :].astype(np.float32) for i in lstsclass]
+
+
+def fda(X, y, p=2, reg=1e-16):
+ """Fisher Discriminant Analysis
+
+ Parameters
+ ----------
+ X : ndarray, shape (n, d)
+ Training samples.
+ y : ndarray, shape (n,)
+ Labels for training samples.
+ p : int, optional
+ Size of the dimensionality reduction.
+ reg : float, optional
+ Regularization term >0 (ridge regularization)
+
+ Returns
+ -------
+ P : ndarray, shape (d, p)
+ Optimal projection matrix
+ proj : callable
+ Projection function including mean centering.
+ """
+
+ mx = np.mean(X, 0)
+ X -= mx.reshape((1, -1))
+
+ # data split between classes
+ d = X.shape[1]
+ xc = split_classes(X, y)
+ nc = len(xc)
+
+ p = min(nc - 1, p)
+
+ Cw = 0
+ for x in xc:
+ Cw += np.cov(x, rowvar=False)
+ Cw /= nc
+
+ mxc = np.zeros((d, nc))
+
+ for i in range(nc):
+ mxc[:, i] = np.mean(xc[i], 0)
+
+ mx0 = np.mean(mxc, 1)
+ Cb = 0
+ for i in range(nc):
+ Cb += (mxc[:, i] - mx0).reshape((-1, 1)) * \
+ (mxc[:, i] - mx0).reshape((1, -1))
+
+ w, V = linalg.eig(Cb, Cw + reg * np.eye(d))
+
+ idx = np.argsort(w.real)
+
+ Popt = V[:, idx[-p:]]
+
+ def proj(X):
+ return (X - mx.reshape((1, -1))).dot(Popt)
+
+ return Popt, proj
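+
+# Example (illustration only, with random data): with 3 classes at most
+# p = 2 discriminant directions are returned, and `proj` can be reused on
+# held-out samples:
+#
+#     X = np.random.randn(100, 10)
+#     y = np.random.randint(0, 3, 100)
+#     Popt, proj = fda(X, y, p=2)
+#     Xp = proj(X)  # shape (100, 2)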
+
+
+def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
+ """
+ Wasserstein Discriminant Analysis [11]_
+
+ The function solves the following optimization problem:
+
+ .. math::
+ P = \\text{arg}\min_P \\frac{\\sum_i W(PX^i,PX^i)}{\\sum_{i,j\\neq i} W(PX^i,PX^j)}
+
+ where :
+
+ - :math:`P` is a linear projection operator in the Stiefel(p,d) manifold
+ - :math:`W` is the entropic regularized Wasserstein distance
+ - :math:`X^i` are samples in the dataset corresponding to class i
+
+ Parameters
+ ----------
+ X : ndarray, shape (n, d)
+ Training samples.
+ y : ndarray, shape (n,)
+ Labels for training samples.
+ p : int, optional
+ Size of the dimensionality reduction.
+ reg : float, optional
+ Regularization term >0 (entropic regularization)
+ k : int, optional
+ Number of iterations of the inner fixed-point Sinkhorn solver.
+ solver : None | str, optional
+ None for steepest descent, 'TrustRegions' for the trust regions
+ algorithm, otherwise a pymanopt solver instance.
+ maxiter : int, optional
+ Maximum number of iterations of the manifold solver.
+ P0 : ndarray, shape (d, p)
+ Initial starting point for projection.
+ verbose : int, optional
+ Print information along iterations.
+
+ Returns
+ -------
+ P : ndarray, shape (d, p)
+ Optimal projection matrix on the Stiefel manifold
+ proj : callable
+ Projection function including mean centering.
+
+ References
+ ----------
+ .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
+ Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
+ """ # noqa
+
+ mx = np.mean(X, 0)
+ X -= mx.reshape((1, -1))
+
+ # data split between classes
+ d = X.shape[1]
+ xc = split_classes(X, y)
+ # compute uniform weights per class
+ wc = [np.ones((x.shape[0]), dtype=np.float32) / x.shape[0] for x in xc]
+
+ def cost(P):
+ # wda loss
+ loss_b = 0
+ loss_w = 0
+
+ for i, xi in enumerate(xc):
+ xi = np.dot(xi, P)
+ for j, xj in enumerate(xc[i:]):
+ xj = np.dot(xj, P)
+ M = dist(xi, xj)
+ G = sinkhorn(wc[i], wc[j + i], M, reg, k)
+ if j == 0:
+ loss_w += np.sum(G * M)
+ else:
+ loss_b += np.sum(G * M)
+
+ # ratio is inverted because the solver minimizes (within / between)
+ return loss_w / loss_b
+
+ # declare manifold and problem
+ manifold = Stiefel(d, p)
+ problem = Problem(manifold=manifold, cost=cost)
+
+ # declare solver and solve
+ if solver is None:
+ solver = SteepestDescent(maxiter=maxiter, logverbosity=verbose)
+ elif solver in ['tr', 'TrustRegions']:
+ solver = TrustRegions(maxiter=maxiter, logverbosity=verbose)
+
+ Popt = solver.solve(problem, x=P0)
+
+ def proj(X):
+ return (X - mx.reshape((1, -1))).dot(Popt)
+
+ return Popt, proj
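+
+# Example (illustration only, with X and y as in the fda example above):
+# same interface as fda, but the projection is found by minimizing the
+# Wasserstein ratio over the Stiefel manifold (needs pymanopt and autograd):
+#
+#     Popt, proj = wda(X, y, p=2, reg=1., k=10, maxiter=50)
+#     Xp = proj(X)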
diff --git a/ot/externals/__init__.py b/ot/externals/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/ot/externals/__init__.py
diff --git a/ot/externals/funcsigs.py b/ot/externals/funcsigs.py
new file mode 100644
index 0000000..106bde7
--- /dev/null
+++ b/ot/externals/funcsigs.py
@@ -0,0 +1,817 @@
+# Copyright 2001-2013 Python Software Foundation; All Rights Reserved
+"""Function signature objects for callables
+
+Back port of Python 3.3's function signature tools from the inspect module,
+modified to be compatible with Python 2.7 and 3.2+.
+"""
+from __future__ import absolute_import, division, print_function
+import itertools
+import functools
+import re
+import types
+
+from collections import OrderedDict
+
+__version__ = "0.4"
+
+__all__ = ['BoundArguments', 'Parameter', 'Signature', 'signature']
+
+
+_WrapperDescriptor = type(type.__call__)
+_MethodWrapper = type(all.__call__)
+
+_NonUserDefinedCallables = (_WrapperDescriptor,
+ _MethodWrapper,
+ types.BuiltinFunctionType)
+
+
+def formatannotation(annotation, base_module=None):
+ if isinstance(annotation, type):
+ if annotation.__module__ in ('builtins', '__builtin__', base_module):
+ return annotation.__name__
+ return annotation.__module__ + '.' + annotation.__name__
+ return repr(annotation)
+
+
+def _get_user_defined_method(cls, method_name, *nested):
+ try:
+ if cls is type:
+ return
+ meth = getattr(cls, method_name)
+ for name in nested:
+ meth = getattr(meth, name, meth)
+ except AttributeError:
+ return
+ else:
+ if not isinstance(meth, _NonUserDefinedCallables):
+ # Once '__signature__' will be added to 'C'-level
+ # callables, this check won't be necessary
+ return meth
+
+
+def signature(obj):
+ '''Get a signature object for the passed callable.'''
+
+ if not callable(obj):
+ raise TypeError('{0!r} is not a callable object'.format(obj))
+
+ if isinstance(obj, types.MethodType):
+ sig = signature(obj.__func__)
+ if obj.__self__ is None:
+ # Unbound method: the first parameter becomes positional-only
+ if sig.parameters:
+ first = sig.parameters.values()[0].replace(
+ kind=_POSITIONAL_ONLY)
+ return sig.replace(
+ parameters=(first,) + tuple(sig.parameters.values())[1:])
+ else:
+ return sig
+ else:
+ # In this case we skip the first parameter of the underlying
+ # function (usually `self` or `cls`).
+ return sig.replace(parameters=tuple(sig.parameters.values())[1:])
+
+ try:
+ sig = obj.__signature__
+ except AttributeError:
+ pass
+ else:
+ if sig is not None:
+ return sig
+
+ try:
+ # Was this function wrapped by a decorator?
+ wrapped = obj.__wrapped__
+ except AttributeError:
+ pass
+ else:
+ return signature(wrapped)
+
+ if isinstance(obj, types.FunctionType):
+ return Signature.from_function(obj)
+
+ if isinstance(obj, functools.partial):
+ sig = signature(obj.func)
+
+ new_params = OrderedDict(sig.parameters.items())
+
+ partial_args = obj.args or ()
+ partial_keywords = obj.keywords or {}
+ try:
+ ba = sig.bind_partial(*partial_args, **partial_keywords)
+ except TypeError:
+ msg = 'partial object {0!r} has incorrect arguments'.format(obj)
+ raise ValueError(msg)
+
+ for arg_name, arg_value in ba.arguments.items():
+ param = new_params[arg_name]
+ if arg_name in partial_keywords:
+ # We set a new default value, because the following code
+ # is correct:
+ #
+ # >>> def foo(a): print(a)
+ # >>> print(partial(partial(foo, a=10), a=20)())
+ # 20
+ # >>> print(partial(partial(foo, a=10), a=20)(a=30))
+ # 30
+ #
+ # So, with 'partial' objects, passing a keyword argument is
+ # like setting a new default value for the corresponding
+ # parameter
+ #
+ # We also mark this parameter with '_partial_kwarg'
+ # flag. Later, in '_bind', the 'default' value of this
+ # parameter will be added to 'kwargs', to simulate
+ # the 'functools.partial' real call.
+ new_params[arg_name] = param.replace(default=arg_value,
+ _partial_kwarg=True)
+
+ elif (param.kind not in (_VAR_KEYWORD, _VAR_POSITIONAL)
+ and not param._partial_kwarg):
+ new_params.pop(arg_name)
+
+ return sig.replace(parameters=new_params.values())
+
+ sig = None
+ if isinstance(obj, type):
+ # obj is a class or a metaclass
+
+ # First, let's see if it has an overloaded __call__ defined
+ # in its metaclass
+ call = _get_user_defined_method(type(obj), '__call__')
+ if call is not None:
+ sig = signature(call)
+ else:
+ # Now we check if the 'obj' class has a '__new__' method
+ new = _get_user_defined_method(obj, '__new__')
+ if new is not None:
+ sig = signature(new)
+ else:
+ # Finally, we should have at least __init__ implemented
+ init = _get_user_defined_method(obj, '__init__')
+ if init is not None:
+ sig = signature(init)
+ elif not isinstance(obj, _NonUserDefinedCallables):
+ # An object with __call__
+ # We also check that the 'obj' is not an instance of
+ # _WrapperDescriptor or _MethodWrapper to avoid
+ # infinite recursion (and even potential segfault)
+ call = _get_user_defined_method(type(obj), '__call__', 'im_func')
+ if call is not None:
+ sig = signature(call)
+
+ if sig is not None:
+ # For classes and objects we skip the first parameter of their
+ # __call__, __new__, or __init__ methods
+ return sig.replace(parameters=tuple(sig.parameters.values())[1:])
+
+ if isinstance(obj, types.BuiltinFunctionType):
+ # Raise a nicer error message for builtins
+ msg = 'no signature found for builtin function {0!r}'.format(obj)
+ raise ValueError(msg)
+
+ raise ValueError(
+ 'callable {0!r} is not supported by signature'.format(obj))
+
+
+class _void(object):
+ '''A private marker - used in Parameter & Signature'''
+
+
+class _empty(object):
+ pass
+
+
+class _ParameterKind(int):
+ def __new__(self, *args, **kwargs):
+ obj = int.__new__(self, *args)
+ obj._name = kwargs['name']
+ return obj
+
+ def __str__(self):
+ return self._name
+
+ def __repr__(self):
+ return '<_ParameterKind: {0!r}>'.format(self._name)
+
+
+_POSITIONAL_ONLY = _ParameterKind(0, name='POSITIONAL_ONLY')
+_POSITIONAL_OR_KEYWORD = _ParameterKind(1, name='POSITIONAL_OR_KEYWORD')
+_VAR_POSITIONAL = _ParameterKind(2, name='VAR_POSITIONAL')
+_KEYWORD_ONLY = _ParameterKind(3, name='KEYWORD_ONLY')
+_VAR_KEYWORD = _ParameterKind(4, name='VAR_KEYWORD')
+
+
+class Parameter(object):
+ '''Represents a parameter in a function signature.
+
+ Has the following public attributes:
+
+ * name : str
+ The name of the parameter as a string.
+ * default : object
+ The default value for the parameter if specified. If the
+ parameter has no default value, this attribute is not set.
+ * annotation
+ The annotation for the parameter if specified. If the
+ parameter has no annotation, this attribute is not set.
+ * kind : str
+ Describes how argument values are bound to the parameter.
+ Possible values: `Parameter.POSITIONAL_ONLY`,
+ `Parameter.POSITIONAL_OR_KEYWORD`, `Parameter.VAR_POSITIONAL`,
+ `Parameter.KEYWORD_ONLY`, `Parameter.VAR_KEYWORD`.
+ '''
+
+ __slots__ = ('_name', '_kind', '_default', '_annotation', '_partial_kwarg')
+
+ POSITIONAL_ONLY = _POSITIONAL_ONLY
+ POSITIONAL_OR_KEYWORD = _POSITIONAL_OR_KEYWORD
+ VAR_POSITIONAL = _VAR_POSITIONAL
+ KEYWORD_ONLY = _KEYWORD_ONLY
+ VAR_KEYWORD = _VAR_KEYWORD
+
+ empty = _empty
+
+ def __init__(self, name, kind, default=_empty, annotation=_empty,
+ _partial_kwarg=False):
+
+ if kind not in (_POSITIONAL_ONLY, _POSITIONAL_OR_KEYWORD,
+ _VAR_POSITIONAL, _KEYWORD_ONLY, _VAR_KEYWORD):
+ raise ValueError("invalid value for 'Parameter.kind' attribute")
+ self._kind = kind
+
+ if default is not _empty:
+ if kind in (_VAR_POSITIONAL, _VAR_KEYWORD):
+ msg = '{0} parameters cannot have default values'.format(kind)
+ raise ValueError(msg)
+ self._default = default
+ self._annotation = annotation
+
+ if name is None:
+ if kind != _POSITIONAL_ONLY:
+ raise ValueError("None is not a valid name for a "
+ "non-positional-only parameter")
+ self._name = name
+ else:
+ name = str(name)
+ if kind != _POSITIONAL_ONLY and not re.match(
+ r'[a-z_]\w*$', name, re.I):
+ msg = '{0!r} is not a valid parameter name'.format(name)
+ raise ValueError(msg)
+ self._name = name
+
+ self._partial_kwarg = _partial_kwarg
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def default(self):
+ return self._default
+
+ @property
+ def annotation(self):
+ return self._annotation
+
+ @property
+ def kind(self):
+ return self._kind
+
+ def replace(self, name=_void, kind=_void, annotation=_void,
+ default=_void, _partial_kwarg=_void):
+ '''Creates a customized copy of the Parameter.'''
+
+ if name is _void:
+ name = self._name
+
+ if kind is _void:
+ kind = self._kind
+
+ if annotation is _void:
+ annotation = self._annotation
+
+ if default is _void:
+ default = self._default
+
+ if _partial_kwarg is _void:
+ _partial_kwarg = self._partial_kwarg
+
+ return type(self)(name, kind, default=default, annotation=annotation,
+ _partial_kwarg=_partial_kwarg)
+
+ def __str__(self):
+ kind = self.kind
+
+ formatted = self._name
+ if kind == _POSITIONAL_ONLY:
+ if formatted is None:
+ formatted = ''
+ formatted = '<{0}>'.format(formatted)
+
+ # Add annotation and default value
+ if self._annotation is not _empty:
+ formatted = '{0}:{1}'.format(formatted,
+ formatannotation(self._annotation))
+
+ if self._default is not _empty:
+ formatted = '{0}={1}'.format(formatted, repr(self._default))
+
+ if kind == _VAR_POSITIONAL:
+ formatted = '*' + formatted
+ elif kind == _VAR_KEYWORD:
+ formatted = '**' + formatted
+
+ return formatted
+
+ def __repr__(self):
+ return '<{0} at {1:#x} {2!r}>'.format(self.__class__.__name__,
+ id(self), self.name)
+
+ def __hash__(self):
+ msg = "unhashable type: '{0}'".format(self.__class__.__name__)
+ raise TypeError(msg)
+
+ def __eq__(self, other):
+ return (issubclass(other.__class__, Parameter)
+ and self._name == other._name
+ and self._kind == other._kind
+ and self._default == other._default
+ and self._annotation == other._annotation)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+class BoundArguments(object):
+ '''Result of `Signature.bind` call. Holds the mapping of arguments
+ to the function's parameters.
+
+ Has the following public attributes:
+
+ * arguments : OrderedDict
+ An ordered mutable mapping of parameters' names to arguments' values.
+ Does not contain arguments' default values.
+ * signature : Signature
+ The Signature object that created this instance.
+ * args : tuple
+ Tuple of positional arguments values.
+ * kwargs : dict
+ Dict of keyword arguments values.
+ '''
+
+ def __init__(self, signature, arguments):
+ self.arguments = arguments
+ self._signature = signature
+
+ @property
+ def signature(self):
+ return self._signature
+
+ @property
+ def args(self):
+ args = []
+ for param_name, param in self._signature.parameters.items():
+ if (param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY)
+ or param._partial_kwarg):
+ # Keyword arguments mapped by 'functools.partial'
+ # (Parameter._partial_kwarg is True) are mapped
+ # in 'BoundArguments.kwargs', along with VAR_KEYWORD &
+ # KEYWORD_ONLY
+ break
+
+ try:
+ arg = self.arguments[param_name]
+ except KeyError:
+ # We're done here. Other arguments
+ # will be mapped in 'BoundArguments.kwargs'
+ break
+ else:
+ if param.kind == _VAR_POSITIONAL:
+ # *args
+ args.extend(arg)
+ else:
+ # plain argument
+ args.append(arg)
+
+ return tuple(args)
+
+ @property
+ def kwargs(self):
+ kwargs = {}
+ kwargs_started = False
+ for param_name, param in self._signature.parameters.items():
+ if not kwargs_started:
+ if (param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY)
+ or param._partial_kwarg):
+ kwargs_started = True
+ else:
+ if param_name not in self.arguments:
+ kwargs_started = True
+ continue
+
+ if not kwargs_started:
+ continue
+
+ try:
+ arg = self.arguments[param_name]
+ except KeyError:
+ pass
+ else:
+ if param.kind == _VAR_KEYWORD:
+ # **kwargs
+ kwargs.update(arg)
+ else:
+ # plain keyword argument
+ kwargs[param_name] = arg
+
+ return kwargs
+
+ def __hash__(self):
+ msg = "unhashable type: '{0}'".format(self.__class__.__name__)
+ raise TypeError(msg)
+
+ def __eq__(self, other):
+ return (issubclass(other.__class__, BoundArguments)
+ and self.signature == other.signature
+ and self.arguments == other.arguments)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+class Signature(object):
+ '''A Signature object represents the overall signature of a function.
+ It stores a Parameter object for each parameter accepted by the
+ function, as well as information specific to the function itself.
+
+ A Signature object has the following public attributes and methods:
+
+ * parameters : OrderedDict
+ An ordered mapping of parameters' names to the corresponding
+ Parameter objects (keyword-only arguments are in the same order
+ as listed in `code.co_varnames`).
+ * return_annotation : object
+ The annotation for the return type of the function if specified.
+ If the function has no annotation for its return type, this
+ attribute is not set.
+ * bind(*args, **kwargs) -> BoundArguments
+ Creates a mapping from positional and keyword arguments to
+ parameters.
+ * bind_partial(*args, **kwargs) -> BoundArguments
+ Creates a partial mapping from positional and keyword arguments
+ to parameters (simulating 'functools.partial' behavior.)
+ '''
+
+ __slots__ = ('_return_annotation', '_parameters')
+
+ _parameter_cls = Parameter
+ _bound_arguments_cls = BoundArguments
+
+ empty = _empty
+
+ def __init__(self, parameters=None, return_annotation=_empty,
+ __validate_parameters__=True):
+ '''Constructs Signature from the given list of Parameter
+ objects and 'return_annotation'. All arguments are optional.
+ '''
+
+ if parameters is None:
+ params = OrderedDict()
+ else:
+ if __validate_parameters__:
+ params = OrderedDict()
+ top_kind = _POSITIONAL_ONLY
+
+ for idx, param in enumerate(parameters):
+ kind = param.kind
+ if kind < top_kind:
+ msg = 'wrong parameter order: {0} before {1}'
+ msg = msg.format(top_kind, param.kind)
+ raise ValueError(msg)
+ else:
+ top_kind = kind
+
+ name = param.name
+ if name is None:
+ name = str(idx)
+ param = param.replace(name=name)
+
+ if name in params:
+ msg = 'duplicate parameter name: {0!r}'.format(name)
+ raise ValueError(msg)
+ params[name] = param
+ else:
+ params = OrderedDict(((param.name, param)
+ for param in parameters))
+
+ self._parameters = params
+ self._return_annotation = return_annotation
+
+ @classmethod
+ def from_function(cls, func):
+ '''Constructs Signature for the given python function'''
+
+ if not isinstance(func, types.FunctionType):
+ raise TypeError('{0!r} is not a Python function'.format(func))
+
+ Parameter = cls._parameter_cls
+
+ # Parameter information.
+ func_code = func.__code__
+ pos_count = func_code.co_argcount
+ arg_names = func_code.co_varnames
+ positional = tuple(arg_names[:pos_count])
+ keyword_only_count = getattr(func_code, 'co_kwonlyargcount', 0)
+ keyword_only = arg_names[pos_count:(pos_count + keyword_only_count)]
+ annotations = getattr(func, '__annotations__', {})
+ defaults = func.__defaults__
+ kwdefaults = getattr(func, '__kwdefaults__', None)
+
+ if defaults:
+ pos_default_count = len(defaults)
+ else:
+ pos_default_count = 0
+
+ parameters = []
+
+ # Non-keyword-only parameters w/o defaults.
+ non_default_count = pos_count - pos_default_count
+ for name in positional[:non_default_count]:
+ annotation = annotations.get(name, _empty)
+ parameters.append(Parameter(name, annotation=annotation,
+ kind=_POSITIONAL_OR_KEYWORD))
+
+ # ... w/ defaults.
+ for offset, name in enumerate(positional[non_default_count:]):
+ annotation = annotations.get(name, _empty)
+ parameters.append(Parameter(name, annotation=annotation,
+ kind=_POSITIONAL_OR_KEYWORD,
+ default=defaults[offset]))
+
+ # *args
+ if func_code.co_flags & 0x04:
+ name = arg_names[pos_count + keyword_only_count]
+ annotation = annotations.get(name, _empty)
+ parameters.append(Parameter(name, annotation=annotation,
+ kind=_VAR_POSITIONAL))
+
+ # Keyword-only parameters.
+ for name in keyword_only:
+ default = _empty
+ if kwdefaults is not None:
+ default = kwdefaults.get(name, _empty)
+
+ annotation = annotations.get(name, _empty)
+ parameters.append(Parameter(name, annotation=annotation,
+ kind=_KEYWORD_ONLY,
+ default=default))
+ # **kwargs
+ if func_code.co_flags & 0x08:
+ index = pos_count + keyword_only_count
+ if func_code.co_flags & 0x04:
+ index += 1
+
+ name = arg_names[index]
+ annotation = annotations.get(name, _empty)
+ parameters.append(Parameter(name, annotation=annotation,
+ kind=_VAR_KEYWORD))
+
+ return cls(parameters,
+ return_annotation=annotations.get('return', _empty),
+ __validate_parameters__=False)
+
+ @property
+ def parameters(self):
+ try:
+ return types.MappingProxyType(self._parameters)
+ except AttributeError:
+ return OrderedDict(self._parameters.items())
+
+ @property
+ def return_annotation(self):
+ return self._return_annotation
+
+ def replace(self, parameters=_void, return_annotation=_void):
+ '''Creates a customized copy of the Signature.
+ Pass 'parameters' and/or 'return_annotation' arguments
+ to override them in the new copy.
+ '''
+
+ if parameters is _void:
+ parameters = self.parameters.values()
+
+ if return_annotation is _void:
+ return_annotation = self._return_annotation
+
+ return type(self)(parameters,
+ return_annotation=return_annotation)
+
+ def __hash__(self):
+ msg = "unhashable type: '{0}'".format(self.__class__.__name__)
+ raise TypeError(msg)
+
+ def __eq__(self, other):
+ if (not issubclass(type(other), Signature)
+ or self.return_annotation != other.return_annotation
+ or len(self.parameters) != len(other.parameters)):
+ return False
+
+ other_positions = dict((param, idx)
+ for idx, param in enumerate(other.parameters.keys()))
+
+ for idx, (param_name, param) in enumerate(self.parameters.items()):
+ if param.kind == _KEYWORD_ONLY:
+ try:
+ other_param = other.parameters[param_name]
+ except KeyError:
+ return False
+ else:
+ if param != other_param:
+ return False
+ else:
+ try:
+ other_idx = other_positions[param_name]
+ except KeyError:
+ return False
+ else:
+ if (idx != other_idx
+ or param != other.parameters[param_name]):
+ return False
+
+ return True
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def _bind(self, args, kwargs, partial=False):
+ '''Private method. Don't use directly.'''
+
+ arguments = OrderedDict()
+
+ parameters = iter(self.parameters.values())
+ parameters_ex = ()
+ arg_vals = iter(args)
+
+ if partial:
+ # Support for binding arguments to 'functools.partial' objects.
+ # See 'functools.partial' case in 'signature()' implementation
+ # for details.
+ for param_name, param in self.parameters.items():
+ if (param._partial_kwarg and param_name not in kwargs):
+ # Simulating 'functools.partial' behavior
+ kwargs[param_name] = param.default
+
+ while True:
+ # Let's iterate through the positional arguments and corresponding
+ # parameters
+ try:
+ arg_val = next(arg_vals)
+ except StopIteration:
+ # No more positional arguments
+ try:
+ param = next(parameters)
+ except StopIteration:
+ # No more parameters. That's it. Just need to check that
+ # we have no `kwargs` after this while loop
+ break
+ else:
+ if param.kind == _VAR_POSITIONAL:
+ # That's OK, just empty *args. Let's start parsing
+ # kwargs
+ break
+ elif param.name in kwargs:
+ if param.kind == _POSITIONAL_ONLY:
+ msg = '{arg!r} parameter is positional only, ' \
+ 'but was passed as a keyword'
+ msg = msg.format(arg=param.name)
+ raise TypeError(msg)
+ parameters_ex = (param,)
+ break
+ elif (param.kind == _VAR_KEYWORD
+ or param.default is not _empty):
+ # That's fine too - we have a default value for this
+ # parameter. So, lets start parsing `kwargs`, starting
+ # with the current parameter
+ parameters_ex = (param,)
+ break
+ else:
+ if partial:
+ parameters_ex = (param,)
+ break
+ else:
+ msg = '{arg!r} parameter lacking default value'
+ msg = msg.format(arg=param.name)
+ raise TypeError(msg)
+ else:
+ # We have a positional argument to process
+ try:
+ param = next(parameters)
+ except StopIteration:
+ raise TypeError('too many positional arguments')
+ else:
+ if param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY):
+ # Looks like we have no parameter for this positional
+ # argument
+ raise TypeError('too many positional arguments')
+
+ if param.kind == _VAR_POSITIONAL:
+ # We have an '*args'-like argument, let's fill it with
+ # all positional arguments we have left and move on to
+ # the next phase
+ values = [arg_val]
+ values.extend(arg_vals)
+ arguments[param.name] = tuple(values)
+ break
+
+ if param.name in kwargs:
+ raise TypeError('multiple values for argument '
+ '{arg!r}'.format(arg=param.name))
+
+ arguments[param.name] = arg_val
+
+ # Now, we iterate through the remaining parameters to process
+ # keyword arguments
+ kwargs_param = None
+ for param in itertools.chain(parameters_ex, parameters):
+ if param.kind == _POSITIONAL_ONLY:
+ # This should never happen in case of a properly built
+ # Signature object (but let's have this check here
+ # to ensure correct behaviour just in case)
+ raise TypeError('{arg!r} parameter is positional only, '
+ 'but was passed as a keyword'.
+ format(arg=param.name))
+
+ if param.kind == _VAR_KEYWORD:
+ # Memorize that we have a '**kwargs'-like parameter
+ kwargs_param = param
+ continue
+
+ param_name = param.name
+ try:
+ arg_val = kwargs.pop(param_name)
+ except KeyError:
+ # We have no value for this parameter. It's fine though,
+ # if it has a default value, or it is an '*args'-like
+ # parameter, left alone by the processing of positional
+ # arguments.
+ if (not partial and param.kind != _VAR_POSITIONAL
+ and param.default is _empty):
+ raise TypeError('{arg!r} parameter lacking default value'.
+ format(arg=param_name))
+
+ else:
+ arguments[param_name] = arg_val
+
+ if kwargs:
+ if kwargs_param is not None:
+ # Process our '**kwargs'-like parameter
+ arguments[kwargs_param.name] = kwargs
+ else:
+ raise TypeError('too many keyword arguments')
+
+ return self._bound_arguments_cls(self, arguments)
+
+ def bind(self, *args, **kwargs):
+ '''Get a BoundArguments object, that maps the passed `args`
+ and `kwargs` to the function's signature. Raises `TypeError`
+ if the passed arguments can not be bound.
+ '''
+ return self._bind(args, kwargs)
+
+ def bind_partial(self, *args, **kwargs):
+ '''Get a BoundArguments object, that partially maps the
+ passed `args` and `kwargs` to the function's signature.
+ Raises `TypeError` if the passed arguments can not be bound.
+ '''
+ return self._bind(args, kwargs, partial=True)
+
+ def __str__(self):
+ result = []
+ render_kw_only_separator = True
+ for idx, param in enumerate(self.parameters.values()):
+ formatted = str(param)
+
+ kind = param.kind
+ if kind == _VAR_POSITIONAL:
+ # OK, we have an '*args'-like parameter, so we won't need
+ # a '*' to separate keyword-only arguments
+ render_kw_only_separator = False
+ elif kind == _KEYWORD_ONLY and render_kw_only_separator:
+ # We have a keyword-only parameter to render and we haven't
+ # rendered an '*args'-like parameter before, so add a '*'
+ # separator to the parameters list ("foo(arg1, *, arg2)" case)
+ result.append('*')
+ # This condition should be only triggered once, so
+ # reset the flag
+ render_kw_only_separator = False
+
+ result.append(formatted)
+
+ rendered = '({0})'.format(', '.join(result))
+
+ if self.return_annotation is not _empty:
+ anno = formatannotation(self.return_annotation)
+ rendered += ' -> {0}'.format(anno)
+
+ return rendered
diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py
new file mode 100644
index 0000000..1ab95bb
--- /dev/null
+++ b/ot/gpu/__init__.py
@@ -0,0 +1,41 @@
+# -*- coding: utf-8 -*-
+"""
+
+This module provides GPU implementations for several OT solvers and utility
+functions. The GPU backend is handled by `cupy
+<https://cupy.chainer.org/>`_.
+
+.. warning::
+ Note that by default the module is not imported in :mod:`ot`. In order to
+ use it you need to explicitly import :mod:`ot.gpu`.
+
+By default, the functions in this module accept and return numpy arrays
+in order to provide drop-in replacements for the other POT functions, but
+the transfer between CPU and GPU comes with a significant overhead.
+
+In order to get the best performance, we recommend giving only cupy
+arrays to the functions and deactivating the conversion of the result
+back to numpy with the parameter ``to_numpy=False``.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Leo Gautheron <https://github.com/aje>
+#
+# License: MIT License
+
+from . import bregman
+from . import da
+from .bregman import sinkhorn
+from .da import sinkhorn_lpl1_mm
+
+from . import utils
+from .utils import dist, to_gpu, to_np
+
+
+__all__ = ["utils", "dist", "sinkhorn", "sinkhorn_lpl1_mm",
+ "bregman", "da", "to_gpu", "to_np"]
+
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py
new file mode 100644
index 0000000..2e2df83
--- /dev/null
+++ b/ot/gpu/bregman.py
@@ -0,0 +1,194 @@
+# -*- coding: utf-8 -*-
+"""
+Bregman projections for regularized OT with GPU
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Leo Gautheron <https://github.com/aje>
+#
+# License: MIT License
+
+import cupy as np # np used for matrix computation
+import cupy as cp # cp used for cupy specific operations
+from . import utils
+
+
+def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
+ verbose=False, log=False, to_numpy=True, **kwargs):
+ """
+ Solve the entropic regularization optimal transport on GPU
+
+ If the input matrices are in numpy format, they will be uploaded to the
+ GPU first, which can incur significant time overhead.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ samples weights in the target domain; compute sinkhorn with multiple
+ targets and fixed M if b is a matrix (return OT loss + dual variables in log)
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ to_numpy : boolean, optional (default True)
+ If true convert back the GPU array result to numpy format.
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a = cp.asarray(a)
+ b = cp.asarray(b)
+ M = cp.asarray(M)
+
+ if len(a) == 0:
+ a = np.ones((M.shape[0],)) / M.shape[0]
+ if len(b) == 0:
+ b = np.ones((M.shape[1],)) / M.shape[1]
+
+ # init data
+ Nini = len(a)
+ Nfin = len(b)
+
+ if len(b.shape) > 1:
+ nbb = b.shape[1]
+ else:
+ nbb = 0
+
+ if log:
+ log = {'err': []}
+
+ # we assume that no distances are null except those of the diagonal of
+ # distances
+ if nbb:
+ u = np.ones((Nini, nbb)) / Nini
+ v = np.ones((Nfin, nbb)) / Nfin
+ else:
+ u = np.ones(Nini) / Nini
+ v = np.ones(Nfin) / Nfin
+
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ tmp2 = np.empty(b.shape, dtype=M.dtype)
+
+ Kp = (1 / a).reshape(-1, 1) * K
+ cpt = 0
+ err = 1
+ while (err > stopThr and cpt < numItermax):
+ uprev = u
+ vprev = v
+
+ KtransposeU = np.dot(K.T, u)
+ v = np.divide(b, KtransposeU)
+ u = 1. / np.dot(Kp, v)
+
+ if (np.any(KtransposeU == 0) or
+ np.any(np.isnan(u)) or np.any(np.isnan(v)) or
+ np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ print('Warning: numerical errors at iteration', cpt)
+ u = uprev
+ v = vprev
+ break
+ if cpt % 10 == 0:
+ # we can speed up the process by checking the error only every
+ # 10th iteration
+ if nbb:
+ err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
+ np.sum((v - vprev)**2) / np.sum((v)**2)
+ else:
+ # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
+ tmp2 = np.sum(u[:, None] * K * v[None, :], 0)
+ #tmp2=np.einsum('i,ij,j->j', u, K, v)
+ err = np.linalg.norm(tmp2 - b)**2 # violation of marginal
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+ cpt = cpt + 1
+ if log:
+ log['u'] = u
+ log['v'] = v
+
+ if nbb: # return only loss
+ #res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) (explodes cupy memory)
+ res = np.empty(nbb)
+ for i in range(nbb):
+ res[i] = np.sum(u[:, None, i] * (K * M) * v[None, :, i])
+ if to_numpy:
+ res = utils.to_np(res)
+ if log:
+ return res, log
+ else:
+ return res
+
+ else: # return OT matrix
+ res = u.reshape((-1, 1)) * K * v.reshape((1, -1))
+ if to_numpy:
+ res = utils.to_np(res)
+ if log:
+ return res, log
+ else:
+ return res
+
+
+# define sinkhorn as sinkhorn_knopp
+sinkhorn = sinkhorn_knopp
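+
+# Example (illustration only, with hypothetical numpy inputs a_np, b_np,
+# M_np): uploading once and keeping the result on the GPU avoids the
+# CPU/GPU transfer overhead discussed in the docstring:
+#
+#     a, b, M = utils.to_gpu(a_np, b_np, M_np)         # upload once
+#     G = sinkhorn(a, b, M, reg=1e-1, to_numpy=False)  # result stays on GPU
+#     G_np = utils.to_np(G)                            # download when needed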
diff --git a/ot/gpu/da.py b/ot/gpu/da.py
new file mode 100644
index 0000000..4a98038
--- /dev/null
+++ b/ot/gpu/da.py
@@ -0,0 +1,144 @@
+# -*- coding: utf-8 -*-
+"""
+Domain adaptation with optimal transport with GPU implementation
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Nicolas Courty <ncourty@irisa.fr>
+# Michael Perrot <michael.perrot@univ-st-etienne.fr>
+# Leo Gautheron <https://github.com/aje>
+#
+# License: MIT License
+
+
+import cupy as np # np used for matrix computation
+import cupy as cp # cp used for cupy specific operations
+import numpy as npp
+from . import utils
+
+from .bregman import sinkhorn
+
+
+def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
+ numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
+ log=False, to_numpy=True):
+ """
+ Solve the entropic regularization optimal transport problem with nonconvex
+ group lasso regularization on GPU
+
+ If the input matrices are in numpy format, they will be uploaded to the
+ GPU first, which can incur significant time overhead.
+
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)
+ + \eta \Omega_g(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega_e` is the entropic regularization term
+ :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega_g` is the group lasso regularization term
+ :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1`
+ where :math:`\mathcal{I}_c` are the indices of samples from class c
+ in the source domain.
+ - a and b are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is the generalised conditional
+ gradient as proposed in [5]_ [7]_
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ labels_a : np.ndarray (ns,)
+ labels of samples in the source domain
+ b : np.ndarray (nt,)
+ samples weights in the target domain
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term for entropic regularization >0
+ eta : float, optional
+ Regularization term for group lasso regularization >0
+ numItermax : int, optional
+ Max number of iterations
+ numInnerItermax : int, optional
+ Max number of iterations (inner sinkhorn solver)
+ stopInnerThr : float, optional
+ Stop threshold on error (inner sinkhorn solver) (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ to_numpy : boolean, optional (default True)
+ If true convert back the GPU array result to numpy format.
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
+ "Optimal Transport for Domain Adaptation," in IEEE
+ Transactions on Pattern Analysis and Machine Intelligence ,
+ vol.PP, no.99, pp.1-1
+ .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
+ Generalized conditional gradient: analysis of convergence
+ and applications. arXiv preprint arXiv:1510.06567.
+
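+# Example (illustration only): introspecting a plain Python function.
+#
+#     def f(a, b=1, *args, **kwargs):
+#         pass
+#
+#     sig = signature(f)
+#     str(sig)                  # '(a, b=1, *args, **kwargs)'
+#     sig.bind(1, 2).arguments  # OrderedDict([('a', 1), ('b', 2)])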
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.bregman.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a, labels_a, b, M = utils.to_gpu(a, labels_a, b, M)
+
+ p = 0.5
+ epsilon = 1e-3
+
+ indices_labels = []
+ labels_a2 = cp.asnumpy(labels_a)
+ classes = npp.unique(labels_a2)
+ for c in classes:
+ idxc, = utils.to_gpu(npp.where(labels_a2 == c))
+ indices_labels.append(idxc)
+
+ W = np.zeros(M.shape)
+
+ for cpt in range(numItermax):
+ Mreg = M + eta * W
+ transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
+ stopThr=stopInnerThr, to_numpy=False)
+ # the transport has been computed. Check if classes are really
+ # separated
+ W = np.ones(M.shape)
+ for (i, c) in enumerate(classes):
+
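+ # Majorization of the concave group-lasso term: for each class c, the
+ # gradient of (sum_{i in I_c} gamma_{ij})^p w.r.t. gamma is
+ # p * (column sum)^(p - 1); epsilon avoids division by zero.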
+ majs = np.sum(transp[indices_labels[i]], axis=0)
+ majs = p * ((majs + epsilon)**(p - 1))
+ W[indices_labels[i]] = majs
+
+ if to_numpy:
+ return utils.to_np(transp)
+ else:
+ return transp
diff --git a/ot/gpu/utils.py b/ot/gpu/utils.py
new file mode 100644
index 0000000..41e168a
--- /dev/null
+++ b/ot/gpu/utils.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+"""
+Utility functions for GPU
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Nicolas Courty <ncourty@irisa.fr>
+# Leo Gautheron <https://github.com/aje>
+#
+# License: MIT License
+
+import cupy as np # np used for matrix computation
+import cupy as cp # cp used for cupy specific operations
+
+
+def euclidean_distances(a, b, squared=False, to_numpy=True):
+ """
+ Compute the pairwise euclidean distances between the rows of a and b.
+
+ If the input matrices are in numpy format, they will be uploaded to the
+ GPU first, which can incur significant time overhead.
+
+ Parameters
+ ----------
+ a : np.ndarray (n, f)
+ first matrix
+ b : np.ndarray (m, f)
+ second matrix
+ to_numpy : boolean, optional (default True)
+ If true convert back the GPU array result to numpy format.
+ squared : boolean, optional (default False)
+ if True, return squared euclidean distance matrix
+
+ Returns
+ -------
+ c : (n x m) np.ndarray or cupy.ndarray
+ pairwise euclidean distance matrix
+ """
+
+ a, b = to_gpu(a, b)
+
+ a2 = np.sum(np.square(a), 1)
+ b2 = np.sum(np.square(b), 1)
+
+ c = -2 * np.dot(a, b.T)
+ c += a2[:, None]
+ c += b2[None, :]
+
+ if not squared:
+ np.sqrt(c, out=c)
+ if to_numpy:
+ return to_np(c)
+ else:
+ return c
+
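+# Sanity check (illustration only, with hypothetical cupy arrays a and b):
+# the expansion ||a_i - b_j||^2 = ||a_i||^2 + ||b_j||^2 - 2 a_i.b_j used
+# above avoids materializing the (n, m, f) difference tensor:
+#
+#     c = euclidean_distances(a, b, squared=True, to_numpy=False)
+#     ref = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
+#     np.allclose(c, ref)  # True up to rounding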
+
+def dist(x1, x2=None, metric='sqeuclidean', to_numpy=True):
+ """Compute distance between samples in x1 and x2 on gpu
+
+ Parameters
+ ----------
+
+ x1 : np.array (n1,d)
+ matrix with n1 samples of size d
+ x2 : np.array (n2,d), optional
+ matrix with n2 samples of size d (if None then x2=x1)
+ metric : str
+ Metric to use: 'sqeuclidean' or 'euclidean'
+ to_numpy : boolean, optional (default True)
+ If true convert back the GPU array result to numpy format.
+
+
+ Returns
+ -------
+
+ M : np.array (n1,n2)
+ distance matrix computed with given metric
+
+ """
+ if x2 is None:
+ x2 = x1
+ if metric == "sqeuclidean":
+ return euclidean_distances(x1, x2, squared=True, to_numpy=to_numpy)
+ elif metric == "euclidean":
+ return euclidean_distances(x1, x2, squared=False, to_numpy=to_numpy)
+ else:
+ raise NotImplementedError
+
+
+def to_gpu(*args):
+ """ Upload numpy arrays to GPU and return them"""
+ if len(args) > 1:
+ return (cp.asarray(x) for x in args)
+ else:
+ return cp.asarray(args[0])
+
+
+def to_np(*args):
+ """ convert GPU arras to numpy and return them"""
+ if len(args) > 1:
+ return (cp.asnumpy(x) for x in args)
+ else:
+ return cp.asnumpy(args[0])
diff --git a/ot/gromov.py b/ot/gromov.py
new file mode 100644
index 0000000..699ae4c
--- /dev/null
+++ b/ot/gromov.py
@@ -0,0 +1,1179 @@
+# -*- coding: utf-8 -*-
+"""
+Gromov-Wasserstein transport method
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+# Rémi Flamary <remi.flamary@unice.fr>
+# Titouan Vayer <titouan.vayer@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+
+
+from .bregman import sinkhorn
+from .utils import dist, UndefinedParameter
+from .optim import cg
+
+
+def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
+ """Return loss matrices and tensors for Gromov-Wasserstein fast computation
+
+ Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss
+ function as the loss function of the Gromov-Wasserstein discrepancy.
+
+ The matrices are computed as described in Proposition 1 in [12]
+
+ Where :
+ * C1 : Metric cost matrix in the source space
+ * C2 : Metric cost matrix in the target space
+ * T : A coupling between those two spaces
+
+ The square-loss function L(a,b)=|a-b|^2 is read as :
+ L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
+ * f1(a)=(a^2)
+ * f2(b)=(b^2)
+ * h1(a)=a
+ * h2(b)=2*b
+
+ The kl-loss function L(a,b)=a*log(a/b)-a+b is read as :
+ L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
+ * f1(a)=a*log(a)-a
+ * f2(b)=b
+ * h1(a)=a
+ * h2(b)=log(b)
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ loss_fun : str, optional
+ Loss function used: either 'square_loss' or 'kl_loss'
+
+ Returns
+ -------
+ constC : ndarray, shape (ns, nt)
+ Constant C matrix in Eq. (6)
+ hC1 : ndarray, shape (ns, ns)
+ h1(C1) matrix in Eq. (6)
+ hC2 : ndarray, shape (nt, nt)
+ h2(C) matrix in Eq. (6)
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+
+ if loss_fun == 'square_loss':
+ def f1(a):
+ return (a**2)
+
+ def f2(b):
+ return (b**2)
+
+ def h1(a):
+ return a
+
+ def h2(b):
+ return 2 * b
+ elif loss_fun == 'kl_loss':
+ def f1(a):
+ return a * np.log(a + 1e-15) - a
+
+ def f2(b):
+ return b
+
+ def h1(a):
+ return a
+
+ def h2(b):
+ return np.log(b + 1e-15)
+
+ constC1 = np.dot(np.dot(f1(C1), p.reshape(-1, 1)),
+ np.ones(len(q)).reshape(1, -1))
+ constC2 = np.dot(np.ones(len(p)).reshape(-1, 1),
+ np.dot(q.reshape(1, -1), f2(C2).T))
+ constC = constC1 + constC2
+ hC1 = h1(C1)
+ hC2 = h2(C2)
+
+ return constC, hC1, hC2
+
+
+def tensor_product(constC, hC1, hC2, T):
+ """Return the tensor for Gromov-Wasserstein fast computation
+
+ The tensor is computed as described in Proposition 1 Eq. (6) in [12].
+
+ Parameters
+ ----------
+ constC : ndarray, shape (ns, nt)
+ Constant C matrix in Eq. (6)
+ hC1 : ndarray, shape (ns, ns)
+ h1(C1) matrix in Eq. (6)
+ hC2 : ndarray, shape (nt, nt)
+ h2(C) matrix in Eq. (6)
+
+ Returns
+ -------
+ tens : ndarray, shape (ns, nt)
+ \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+ A = -np.dot(hC1, T).dot(hC2.T)
+ tens = constC + A
+ # tens -= tens.min()
+ return tens
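+
+# Sanity check (illustration only, with hypothetical C1 (ns,ns), C2 (nt,nt),
+# p, q): for the square loss and any coupling T with marginals p and q, the
+# factored form above equals the naive O(ns^2 * nt^2) contraction
+# sum_{k,l} (C1[i, k] - C2[j, l])**2 * T[k, l]:
+#
+#     T = np.outer(p, q)
+#     constC, hC1, hC2 = init_matrix(C1, C2, p, q, 'square_loss')
+#     naive = ((C1[:, None, :, None] - C2[None, :, None, :]) ** 2
+#              * T[None, None, :, :]).sum(axis=(2, 3))
+#     np.allclose(tensor_product(constC, hC1, hC2, T), naive)  # True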
+
+
+def gwloss(constC, hC1, hC2, T):
+ """Return the Loss for Gromov-Wasserstein
+
+ The loss is computed as described in Proposition 1 Eq. (6) in [12].
+
+ Parameters
+ ----------
+ constC : ndarray, shape (ns, nt)
+ Constant C matrix in Eq. (6)
+ hC1 : ndarray, shape (ns, ns)
+ h1(C1) matrix in Eq. (6)
+ hC2 : ndarray, shape (nt, nt)
+ h2(C) matrix in Eq. (6)
+ T : ndarray, shape (ns, nt)
+ Current value of transport matrix T
+
+ Returns
+ -------
+ loss : float
+ Gromov Wasserstein loss
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+
+ tens = tensor_product(constC, hC1, hC2, T)
+
+ return np.sum(tens * T)
+
+
+def gwggrad(constC, hC1, hC2, T):
+ """Return the gradient for Gromov-Wasserstein
+
+ The gradient is computed as described in Proposition 2 in [12].
+
+ Parameters
+ ----------
+ constC : ndarray, shape (ns, nt)
+ Constant C matrix in Eq. (6)
+ hC1 : ndarray, shape (ns, ns)
+ h1(C1) matrix in Eq. (6)
+ hC2 : ndarray, shape (nt, nt)
+ h2(C) matrix in Eq. (6)
+ T : ndarray, shape (ns, nt)
+ Current value of transport matrix T
+
+ Returns
+ -------
+ grad : ndarray, shape (ns, nt)
+ Gromov Wasserstein gradient
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+ return 2 * tensor_product(constC, hC1, hC2,
+ T) # [12] Prop. 2 misses a 2 factor
+
+
+def update_square_loss(p, lambdas, T, Cs):
+ """
+ Updates C according to the L2 Loss kernel with the S Ts couplings
+ calculated at each iteration
+
+ Parameters
+ ----------
+ p : ndarray, shape (N,)
+ Masses in the targeted barycenter.
+ lambdas : list of float
+ List of the S spaces' weights.
+ T : list of S np.ndarray of shape (ns,N)
+ The S Ts couplings calculated at each iteration.
+ Cs : list of S ndarray, shape(ns,ns)
+ Metric cost matrices.
+
+ Returns
+ ----------
+ C : ndarray, shape (N, N)
+ Updated C matrix.
+ """
+ tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s])
+ for s in range(len(T))])
+ ppt = np.outer(p, p)
+
+ return np.divide(tmpsum, ppt)
+
+
+def update_kl_loss(p, lambdas, T, Cs):
+ """
+ Updates C according to the KL Loss kernel with the S Ts couplings calculated at each iteration
+
+
+ Parameters
+ ----------
+ p : ndarray, shape (N,)
+ Weights in the targeted barycenter.
+ lambdas : list of float
+ List of the S spaces' weights.
+ T : list of S np.ndarray of shape (ns,N)
+ The S Ts couplings calculated at each iteration.
+ Cs : list of S ndarray, shape(ns,ns)
+ Metric cost matrices.
+
+ Returns
+ ----------
+ C : ndarray, shape (N, N)
+ Updated C matrix.
+ """
+ tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s])
+ for s in range(len(T))])
+ ppt = np.outer(p, p)
+
+ return np.exp(np.divide(tmpsum, ppt))
+
+
+def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
+ """
+ Returns the Gromov-Wasserstein transport between (C1,p) and (C2,q)
+
+ The function solves the following optimization problem:
+
+ .. math::
+ GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+
+ Where :
+ - C1 : Metric cost matrix in the source space
+ - C2 : Metric cost matrix in the target space
+ - p : distribution in the source space
+ - q : distribution in the target space
+ - L : loss function to account for the misfit between the similarity matrices
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ loss_fun : str
+ loss function used for the solver either 'square_loss' or 'kl_loss'
+
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ armijo : bool, optional
+ If True, the step of the line-search is found via an Armijo line search;
+ otherwise a closed form is used. If there are convergence issues, use False.
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+
+ Returns
+ -------
+ T : ndarray, shape (ns, nt)
+ Coupling between the two spaces that minimizes:
+ \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ log : dict
+ Convergence information and loss.
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
+ metric approach to object matching. Foundations of computational
+ mathematics 11.4 (2011): 417-487.
+
+ """
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G)
+
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G)
+
+ if log:
+ res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ log['gw_dist'] = gwloss(constC, hC1, hC2, res)
+ return res, log
+ else:
+ return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
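+
+# Example (illustration only): match two point clouds living in spaces of
+# different dimensions through their intra-domain distance matrices only:
+#
+#     xs, xt = np.random.randn(30, 2), np.random.randn(40, 3)
+#     C1, C2 = dist(xs, xs), dist(xt, xt)
+#     C1, C2 = C1 / C1.max(), C2 / C2.max()
+#     p, q = np.ones(30) / 30, np.ones(40) / 40
+#     T = gromov_wasserstein(C1, C2, p, q, 'square_loss')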
+
+
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+ """
+ Computes the FGW transport between two graphs, see [24]_
+
+ .. math::
+ \gamma = arg\min_\gamma (1-\\alpha)*<\gamma,M>_F + \\alpha* \sum_{i,j,k,l}
+ L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+
+ s.t. \gamma 1 = p
+ \gamma^T 1= q
+ \gamma\geq 0
+
+ where :
+ - M is the (ns,nt) metric cost matrix
+ - :math:`f` is the regularization term ( and df is its gradient)
+ - a and b are source and target weights (sum to 1)
+ - L is a loss function to account for the misfit between the similarity matrices
+
+ The algorithm used for solving the problem is conditional gradient as discussed in [24]_
+
+ Parameters
+ ----------
+ M : ndarray, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ loss_fun : str, optional
+ Loss function used for the solver
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ armijo : bool, optional
+ If True, the step of the line-search is found via an Armijo line search;
+ otherwise a closed form is used. If there are convergence issues, use False.
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+
+ Returns
+ -------
+ gamma : ndarray, shape (ns, nt)
+ Optimal transportation matrix for the given parameters.
+ log : dict
+ Log dictionary returned only if log==True in parameters.
+
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas "Optimal Transport for structured data with
+ application on graphs", International Conference on Machine Learning
+ (ICML). 2019.
+ """
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G)
+
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G)
+
+ if log:
+ res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ log['fgw_dist'] = log['loss'][-1]
+ return res, log
+ else:
+ return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+
+
+def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+ """
+ Computes the FGW distance between two graphs, see [24]_
+
+ .. math::
+ \min_\gamma (1-\\alpha)*<\gamma,M>_F + \\alpha* \sum_{i,j,k,l}
+ L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+
+
+ s.t. \gamma 1 = p
+ \gamma^T 1= q
+ \gamma\geq 0
+
+ where :
+ - M is the (ns,nt) metric cost matrix
+ - :math:`f` is the regularization term ( and df is its gradient)
+ - a and b are source and target weights (sum to 1)
+ - L is a loss function to account for the misfit between the similarity matrices
+ The algorithm used for solving the problem is conditional gradient as discussed in [24]_
+
+ Parameters
+ ----------
+ M : ndarray, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space.
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space.
+ p : ndarray, shape (ns,)
+ Distribution in the source space.
+ q : ndarray, shape (nt,)
+ Distribution in the target space.
+ loss_fun : str, optional
+ Loss function used for the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Record log if True.
+ armijo : bool, optional
+ If True, the step of the line-search is found via an Armijo line search;
+ otherwise a closed form is used. If there are convergence issues, use False.
+ **kwargs : dict
+ Parameters can be directly passed to the ot.optim.cg solver.
+
+ Returns
+ -------
+ fgw_dist : float
+ Fused Gromov-Wasserstein distance for the given parameters.
+ log : dict
+ Log dictionary returned only if log==True in parameters.
+
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G)
+
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G)
+
+ res, log0 = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ # cg is always run with log=True; do not shadow the boolean `log` flag
+ log0['fgw_dist'] = log0['loss'][-1]
+ log0['T'] = res
+ if log:
+ return log0['fgw_dist'], log0
+ else:
+ return log0['fgw_dist']
+
+
+def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
+ """
+ Returns the Gromov-Wasserstein discrepancy between (C1,p) and (C2,q)
+
+ The function solves the following optimization problem:
+
+ .. math::
+ GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+
+ Where :
+ - C1 : Metric cost matrix in the source space
+ - C2 : Metric cost matrix in the target space
+ - p : distribution in the source space
+ - q : distribution in the target space
+ - L : loss function to account for the misfit between the similarity matrices
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space.
+ q : ndarray, shape (nt,)
+ Distribution in the target space.
+ loss_fun : str
+ loss function used for the solver either 'square_loss' or 'kl_loss'
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ armijo : bool, optional
+        If True, the step sizes of the line search are found via an Armijo
+        line search; else a closed form is used. If there are convergence
+        issues, use False.
+
+ Returns
+ -------
+ gw_dist : float
+ Gromov-Wasserstein distance
+ log : dict
+        Convergence information and the coupling matrix
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
+ metric approach to object matching. Foundations of computational
+ mathematics 11.4 (2011): 417-487.
+
+ """
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G)
+
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G)
+
+    res, log0 = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+    log0['gw_dist'] = gwloss(constC, hC1, hC2, res)
+    log0['T'] = res
+    if log:
+        return log0['gw_dist'], log0
+    else:
+        return log0['gw_dist']
+
+
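+# A minimal usage sketch (made-up toy inputs; not part of the library API):
+# the GW discrepancy compares two point clouds of different sizes, and even
+# of different ambient dimensions, through their distance matrices only.
+def _demo_gromov_wasserstein2():
+    np.random.seed(0)
+    xs, xt = np.random.randn(10, 2), np.random.randn(8, 3)
+    C1, C2 = dist(xs, xs), dist(xt, xt)
+    C1 /= C1.max()
+    C2 /= C2.max()
+    p, q = np.ones(10) / 10, np.ones(8) / 8
+    return gromov_wasserstein2(C1, C2, p, q, 'square_loss')
+
+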
+def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
+ max_iter=1000, tol=1e-9, verbose=False, log=False):
+ """
+    Returns the entropic Gromov-Wasserstein transport between (C1,p) and (C2,q)
+
+    The function solves the following optimization problem:
+
+ .. math::
+ GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
+
+ s.t. T 1 = p
+
+ T^T 1= q
+
+ T\geq 0
+
+ Where :
+ - C1 : Metric cost matrix in the source space
+ - C2 : Metric cost matrix in the target space
+ - p : distribution in the source space
+ - q : distribution in the target space
+ - L : loss function to account for the misfit between the similarity matrices
+ - H : entropy
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+        Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ loss_fun : string
+ Loss function used for the solver either 'square_loss' or 'kl_loss'
+ epsilon : float
+ Regularization term >0
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Record log if True.
+
+ Returns
+ -------
+ T : ndarray, shape (ns, nt)
+ Optimal coupling between the two spaces
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+
+ C1 = np.asarray(C1, dtype=np.float64)
+ C2 = np.asarray(C2, dtype=np.float64)
+
+ T = np.outer(p, q) # Initialization
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ cpt = 0
+ err = 1
+
+ if log:
+ log = {'err': []}
+
+ while (err > tol and cpt < max_iter):
+
+ Tprev = T
+
+ # compute the gradient
+ tens = gwggrad(constC, hC1, hC2, T)
+
+ T = sinkhorn(p, q, tens, epsilon)
+
+ if cpt % 10 == 0:
+            # we can speed up the process by checking for the error only every
+            # 10th iteration
+ err = np.linalg.norm(T - Tprev)
+
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+
+ if log:
+ log['gw_dist'] = gwloss(constC, hC1, hC2, T)
+ return T, log
+ else:
+ return T
+
+
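+# A minimal usage sketch (made-up toy inputs; not part of the library API):
+# the entropic solver alternates a gradient step on the GW loss with a
+# Sinkhorn projection, so epsilon trades accuracy for stability and speed.
+def _demo_entropic_gromov_wasserstein():
+    np.random.seed(0)
+    xs, xt = np.random.randn(10, 2), np.random.randn(8, 2)
+    C1, C2 = dist(xs, xs), dist(xt, xt)
+    C1 /= C1.max()
+    C2 /= C2.max()
+    p, q = np.ones(10) / 10, np.ones(8) / 8
+    return entropic_gromov_wasserstein(C1, C2, p, q, 'square_loss',
+                                       epsilon=5e-2)
+
+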
+def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
+ max_iter=1000, tol=1e-9, verbose=False, log=False):
+ """
+ Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices
+
+ (C1,p) and (C2,q)
+
+ The function solves the following optimization problem:
+
+ .. math::
+ GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
+
+ Where :
+ - C1 : Metric cost matrix in the source space
+ - C2 : Metric cost matrix in the target space
+ - p : distribution in the source space
+ - q : distribution in the target space
+ - L : loss function to account for the misfit between the similarity matrices
+ - H : entropy
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+        Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ loss_fun : str
+ Loss function used for the solver either 'square_loss' or 'kl_loss'
+ epsilon : float
+ Regularization term >0
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Record log if True.
+
+ Returns
+ -------
+ gw_dist : float
+ Gromov-Wasserstein distance
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+ gw, logv = entropic_gromov_wasserstein(
+ C1, C2, p, q, loss_fun, epsilon, max_iter, tol, verbose, log=True)
+
+ logv['T'] = gw
+
+ if log:
+ return logv['gw_dist'], logv
+ else:
+ return logv['gw_dist']
+
+
+def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
+ max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
+ """
+ Returns the gromov-wasserstein barycenters of S measured similarity matrices
+
+ (Cs)_{s=1}^{s=S}
+
+ The function solves the following optimization problem:
+
+ .. math::
+ C = argmin_{C\in R^{NxN}} \sum_s \lambda_s GW(C,C_s,p,p_s)
+
+
+ Where :
+
+ - :math:`C_s` : metric cost matrix
+ - :math:`p_s` : distribution
+
+ Parameters
+ ----------
+ N : int
+ Size of the targeted barycenter
+ Cs : list of S np.ndarray of shape (ns,ns)
+ Metric cost matrices
+ ps : list of S np.ndarray of shape (ns,)
+ Sample weights in the S spaces
+ p : ndarray, shape(N,)
+ Weights in the targeted barycenter
+ lambdas : list of float
+ List of the S spaces' weights.
+    loss_fun : str
+        Loss function used for the solver, either 'square_loss' or 'kl_loss'.
+ epsilon : float
+ Regularization term >0
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+        Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations.
+ log : bool, optional
+ Record log if True.
+    init_C : ndarray, shape (N, N), optional
+        Initial value for the C matrix provided by the user (random if not set).
+
+ Returns
+ -------
+    C : ndarray, shape (N, N)
+        Similarity matrix in the barycenter space (permuted arbitrarily)
+    log : dict
+        Log dictionary with the list of errors, returned only if log==True.
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+ """
+
+ S = len(Cs)
+
+ Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
+ lambdas = np.asarray(lambdas, dtype=np.float64)
+
+ # Initialization of C : random SPD matrix (if not provided by user)
+ if init_C is None:
+ # XXX use random state
+ xalea = np.random.randn(N, 2)
+ C = dist(xalea, xalea)
+ C /= C.max()
+ else:
+ C = init_C
+
+ cpt = 0
+ err = 1
+
+ error = []
+
+ while (err > tol) and (cpt < max_iter):
+ Cprev = C
+
+        T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
+                                         max_iter, 1e-5, verbose, log=False) for s in range(S)]
+ if loss_fun == 'square_loss':
+ C = update_square_loss(p, lambdas, T, Cs)
+
+ elif loss_fun == 'kl_loss':
+ C = update_kl_loss(p, lambdas, T, Cs)
+
+ if cpt % 10 == 0:
+            # we can speed up the process by checking for the error only every
+            # 10th iteration
+ err = np.linalg.norm(C - Cprev)
+ error.append(err)
+
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+
+    if log:
+        return C, {'err': error}
+    else:
+        return C
+
+
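+# A minimal usage sketch (made-up toy inputs; not part of the library API):
+# an entropic GW barycenter on N = 4 atoms of two random structure matrices.
+def _demo_entropic_gromov_barycenters():
+    np.random.seed(0)
+    xs = [np.random.randn(6, 2), np.random.randn(5, 2)]
+    Cs = [dist(x, x) for x in xs]
+    Cs = [C / C.max() for C in Cs]
+    ps = [np.ones(6) / 6, np.ones(5) / 5]
+    p = np.ones(4) / 4
+    return entropic_gromov_barycenters(4, Cs, ps, p, [0.5, 0.5],
+                                       'square_loss', epsilon=5e-2,
+                                       max_iter=20, tol=1e-4)
+
+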
+def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
+ max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
+ """
+ Returns the gromov-wasserstein barycenters of S measured similarity matrices
+
+ (Cs)_{s=1}^{s=S}
+
+ The function solves the following optimization problem with block
+ coordinate descent:
+
+ .. math::
+ C = argmin_C\in R^NxN \sum_s \lambda_s GW(C,Cs,p,ps)
+
+ Where :
+
+ - Cs : metric cost matrix
+ - ps : distribution
+
+ Parameters
+ ----------
+ N : int
+ Size of the targeted barycenter
+ Cs : list of S np.ndarray of shape (ns, ns)
+ Metric cost matrices
+ ps : list of S np.ndarray of shape (ns,)
+ Sample weights in the S spaces
+ p : ndarray, shape (N,)
+ Weights in the targeted barycenter
+ lambdas : list of float
+ List of the S spaces' weights
+    loss_fun : str
+        Loss function used for the solver, either 'square_loss' or 'kl_loss'.
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+        Stop threshold on error (>0).
+ verbose : bool, optional
+ Print information along iterations.
+ log : bool, optional
+ Record log if True.
+    init_C : ndarray, shape (N, N), optional
+        Initial value for the C matrix provided by the user (random if not set).
+
+ Returns
+ -------
+    C : ndarray, shape (N, N)
+        Similarity matrix in the barycenter space (permuted arbitrarily)
+    log : dict
+        Log dictionary with the list of errors, returned only if log==True.
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+ S = len(Cs)
+
+ Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
+ lambdas = np.asarray(lambdas, dtype=np.float64)
+
+ # Initialization of C : random SPD matrix (if not provided by user)
+ if init_C is None:
+ # XXX : should use a random state and not use the global seed
+ xalea = np.random.randn(N, 2)
+ C = dist(xalea, xalea)
+ C /= C.max()
+ else:
+ C = init_C
+
+ cpt = 0
+ err = 1
+
+ error = []
+
+ while(err > tol and cpt < max_iter):
+ Cprev = C
+
+        T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun,
+                                numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=False) for s in range(S)]
+ if loss_fun == 'square_loss':
+ C = update_square_loss(p, lambdas, T, Cs)
+
+ elif loss_fun == 'kl_loss':
+ C = update_kl_loss(p, lambdas, T, Cs)
+
+ if cpt % 10 == 0:
+            # we can speed up the process by checking for the error only every
+            # 10th iteration
+ err = np.linalg.norm(C - Cprev)
+ error.append(err)
+
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+
+    if log:
+        return C, {'err': error}
+    else:
+        return C
+
+
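+# A minimal usage sketch (made-up toy inputs; not part of the library API):
+# same as the entropic variant above, but each inner coupling is obtained
+# with the exact conditional-gradient GW solver.
+def _demo_gromov_barycenters():
+    np.random.seed(0)
+    xs = [np.random.randn(6, 2), np.random.randn(5, 2)]
+    Cs = [dist(x, x) for x in xs]
+    Cs = [C / C.max() for C in Cs]
+    ps = [np.ones(6) / 6, np.ones(5) / 5]
+    p = np.ones(4) / 4
+    return gromov_barycenters(4, Cs, ps, p, [0.5, 0.5], 'square_loss',
+                              max_iter=50, tol=1e-4)
+
+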
+def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
+ p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
+ verbose=False, log=False, init_C=None, init_X=None):
+ """Compute the fgw barycenter as presented eq (5) in [24].
+
+ Parameters
+ ----------
+ N : integer
+ Desired number of samples of the target barycenter
+ Ys: list of ndarray, each element has shape (ns,d)
+ Features of all samples
+ Cs : list of ndarray, each element has shape (ns,ns)
+ Structure matrices of all samples
+ ps : list of ndarray, each element has shape (ns,)
+ Masses of all samples.
+ lambdas : list of float
+ List of the S spaces' weights
+ alpha : float
+ Alpha parameter for the fgw distance
+ fixed_structure : bool
+ Whether to fix the structure of the barycenter during the updates
+ fixed_features : bool
+ Whether to fix the feature of the barycenter during the updates
+ init_C : ndarray, shape (N,N), optional
+ Initialization for the barycenters' structure matrix. If not set
+ a random init is used.
+    init_X : ndarray, shape (N,d), optional
+        Initialization for the barycenters' features. If not set,
+        the features are initialized to zeros.
+
+ Returns
+ -------
+ X : ndarray, shape (N, d)
+ Barycenters' features
+ C : ndarray, shape (N, N)
+ Barycenters' structure matrix
+    log_ : dict
+        Only returned when log=True. It contains the keys:
+        - T : list of (N,ns) transport matrices
+        - Ms : all distance matrices between the features of the barycenter
+          and the other features dist(X,Ys), shape (N,ns)
+
+ References
+ ----------
+    .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+        and Courty Nicolas, "Optimal Transport for structured data with
+        application on graphs", International Conference on Machine Learning
+        (ICML). 2019.
+ """
+ S = len(Cs)
+    d = Ys[0].shape[1]  # dimension of the node features
+ if p is None:
+ p = np.ones(N) / N
+
+ Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
+ Ys = [np.asarray(Ys[s], dtype=np.float64) for s in range(S)]
+
+ lambdas = np.asarray(lambdas, dtype=np.float64)
+
+ if fixed_structure:
+ if init_C is None:
+ raise UndefinedParameter('If C is fixed it must be initialized')
+ else:
+ C = init_C
+ else:
+ if init_C is None:
+ xalea = np.random.randn(N, 2)
+ C = dist(xalea, xalea)
+ else:
+ C = init_C
+
+ if fixed_features:
+ if init_X is None:
+ raise UndefinedParameter('If X is fixed it must be initialized')
+ else:
+ X = init_X
+ else:
+ if init_X is None:
+ X = np.zeros((N, d))
+ else:
+ X = init_X
+
+ T = [np.outer(p, q) for q in ps]
+
+ Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] # Ms is N,ns
+
+ cpt = 0
+ err_feature = 1
+ err_structure = 1
+
+ if log:
+ log_ = {}
+ log_['err_feature'] = []
+ log_['err_structure'] = []
+ log_['Ts_iter'] = []
+
+ while((err_feature > tol or err_structure > tol) and cpt < max_iter):
+ Cprev = C
+ Xprev = X
+
+ if not fixed_features:
+ Ys_temp = [y.T for y in Ys]
+ X = update_feature_matrix(lambdas, Ys_temp, T, p).T
+
+ Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
+
+ if not fixed_structure:
+ if loss_fun == 'square_loss':
+ T_temp = [t.T for t in T]
+ C = update_sructure_matrix(p, lambdas, T_temp, Cs)
+
+ T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
+ numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
+
+ # T is N,ns
+ err_feature = np.linalg.norm(X - Xprev.reshape(N, d))
+ err_structure = np.linalg.norm(C - Cprev)
+
+ if log:
+ log_['err_feature'].append(err_feature)
+ log_['err_structure'].append(err_structure)
+ log_['Ts_iter'].append(T)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err_structure))
+ print('{:5d}|{:8e}|'.format(cpt, err_feature))
+
+ cpt += 1
+
+ if log:
+ log_['T'] = T # from target to Ys
+ log_['p'] = p
+ log_['Ms'] = Ms
+
+ if log:
+ return X, C, log_
+ else:
+ return X, C
+
+
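+# A minimal usage sketch (made-up toy inputs; not part of the library API):
+# an FGW barycenter on N = 4 atoms of two small attributed graphs with
+# 3d node features; returns both the features X and the structure C.
+def _demo_fgw_barycenters():
+    np.random.seed(0)
+    sizes = [6, 5]
+    Ys = [np.random.randn(n, 3) for n in sizes]
+    xs = [np.random.randn(n, 2) for n in sizes]
+    Cs = [dist(x, x) for x in xs]
+    ps = [np.ones(n) / n for n in sizes]
+    return fgw_barycenters(4, Ys, Cs, ps, [0.5, 0.5], alpha=0.5, max_iter=20)
+
+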
+def update_sructure_matrix(p, lambdas, T, Cs):
+ """Updates C according to the L2 Loss kernel with the S Ts couplings.
+
+ It is calculated at each iteration
+
+ Parameters
+ ----------
+ p : ndarray, shape (N,)
+ Masses in the targeted barycenter.
+ lambdas : list of float
+ List of the S spaces' weights.
+ T : list of S ndarray of shape (ns, N)
+ The S Ts couplings calculated at each iteration.
+ Cs : list of S ndarray, shape (ns, ns)
+ Metric cost matrices.
+
+ Returns
+ -------
+    C : ndarray, shape (N, N)
+        Updated C matrix.
+ """
+ tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
+ ppt = np.outer(p, p)
+
+ return np.divide(tmpsum, ppt)
+
+
+def update_feature_matrix(lambdas, Ys, Ts, p):
+ """Updates the feature with respect to the S Ts couplings.
+
+
+ See "Solving the barycenter problem with Block Coordinate Descent (BCD)"
+ in [24] calculated at each iteration
+
+ Parameters
+ ----------
+ p : ndarray, shape (N,)
+ masses in the targeted barycenter
+ lambdas : list of float
+ List of the S spaces' weights
+    Ts : list of S np.ndarray, shape (N, ns)
+        The S Ts couplings calculated at each iteration.
+ Ys : list of S ndarray, shape(d,ns)
+ The features.
+
+ Returns
+ -------
+ X : ndarray, shape (d, N)
+
+ References
+ ----------
+    .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+        and Courty Nicolas, "Optimal Transport for structured data with
+        application on graphs", International Conference on Machine Learning
+        (ICML). 2019.
+ """
+ p = np.array(1. / p).reshape(-1,)
+
+ tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T) * p[None, :] for s in range(len(Ts))])
+
+ return tmpsum
diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h
new file mode 100644
index 0000000..f42e222
--- /dev/null
+++ b/ot/lp/EMD.h
@@ -0,0 +1,35 @@
+/* This file is a c++ wrapper function for computing the transportation cost
+ * between two vectors given a cost matrix.
+ *
+ * It was written by Antoine Rolet (2014) and mainly consists of a wrapper
+ * of the code written by Nicolas Bonneel available on this page
+ * http://people.seas.harvard.edu/~nbonneel/FastTransport/
+ *
+ * It was then modified to make it more amenable to python inline calling
+ *
+ * Please give relevant credit to the original author (Nicolas Bonneel) if
+ * you use this code for a publication.
+ *
+ */
+
+
+#ifndef EMD_H
+#define EMD_H
+
+#include <iostream>
+#include <vector>
+#include "network_simplex_simple.h"
+
+using namespace lemon;
+typedef unsigned int node_id_type;
+
+enum ProblemType {
+ INFEASIBLE,
+ OPTIMAL,
+ UNBOUNDED,
+ MAX_ITER_REACHED
+};
+
+int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter);
+
+#endif
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
new file mode 100644
index 0000000..fc7ca63
--- /dev/null
+++ b/ot/lp/EMD_wrapper.cpp
@@ -0,0 +1,107 @@
+/* This file is a c++ wrapper function for computing the transportation cost
+ * between two vectors given a cost matrix.
+ *
+ * It was written by Antoine Rolet (2014) and mainly consists of a wrapper
+ * of the code written by Nicolas Bonneel available on this page
+ * http://people.seas.harvard.edu/~nbonneel/FastTransport/
+ *
+ * It was then modified to make it more amenable to python inline calling
+ *
+ * Please give relevant credit to the original author (Nicolas Bonneel) if
+ * you use this code for a publication.
+ *
+ */
+
+#include "EMD.h"
+
+
+int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
+ double* alpha, double* beta, double *cost, int maxIter) {
+// beware: M and C are stored in row-major C style!
+ int n, m, i, cur;
+
+ typedef FullBipartiteDigraph Digraph;
+ DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
+
+ // Get the number of non zero coordinates for r and c
+ n=0;
+ for (int i=0; i<n1; i++) {
+ double val=*(X+i);
+ if (val>0) {
+ n++;
+ }else if(val<0){
+ return INFEASIBLE;
+ }
+ }
+ m=0;
+ for (int i=0; i<n2; i++) {
+ double val=*(Y+i);
+ if (val>0) {
+ m++;
+ }else if(val<0){
+ return INFEASIBLE;
+ }
+ }
+
+ // Define the graph
+
+ std::vector<int> indI(n), indJ(m);
+ std::vector<double> weights1(n), weights2(m);
+ Digraph di(n, m);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
+
+ // Set supply and demand, don't account for 0 values (faster)
+
+ cur=0;
+ for (int i=0; i<n1; i++) {
+ double val=*(X+i);
+ if (val>0) {
+ weights1[ cur ] = val;
+ indI[cur++]=i;
+ }
+ }
+
+ // Demand is actually negative supply...
+
+ cur=0;
+ for (int i=0; i<n2; i++) {
+ double val=*(Y+i);
+ if (val>0) {
+ weights2[ cur ] = -val;
+ indJ[cur++]=i;
+ }
+ }
+
+
+ net.supplyMap(&weights1[0], n, &weights2[0], m);
+
+ // Set the cost of each edge
+ for (int i=0; i<n; i++) {
+ for (int j=0; j<m; j++) {
+ double val=*(D+indI[i]*n2+indJ[j]);
+ net.setCost(di.arcFromId(i*m+j), val);
+ }
+ }
+
+
+ // Solve the problem with the network simplex algorithm
+
+ int ret=net.run();
+ if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
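+        // Target node ids in the bipartite digraph are offset by n (the number
+        // of non-zero source weights), hence the j-n when indexing indJ and beta.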
+ *cost = 0;
+ Arc a; di.first(a);
+ for (; a != INVALID; di.next(a)) {
+ int i = di.source(a);
+ int j = di.target(a);
+ double flow = net.flow(a);
+ *cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
+ *(G+indI[i]*n2+indJ[j-n]) = flow;
+ *(alpha + indI[i]) = -net.potential(i);
+ *(beta + indJ[j-n]) = net.potential(j);
+ }
+
+ }
+
+
+ return ret;
+}
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
new file mode 100644
index 0000000..0c92810
--- /dev/null
+++ b/ot/lp/__init__.py
@@ -0,0 +1,618 @@
+# -*- coding: utf-8 -*-
+"""
+Solvers for the original linear program OT problem
+
+
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import multiprocessing
+import sys
+import numpy as np
+from scipy.sparse import coo_matrix
+
+from . import cvx
+
+# import compiled emd
+from .emd_wrap import emd_c, check_result, emd_1d_sorted
+from ..utils import parmap
+from .cvx import barycenter
+from ..utils import dist
+
+__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
+           'emd_1d', 'emd2_1d', 'wasserstein_1d']
+
+
+def emd(a, b, M, numItermax=100000, log=False):
+ r"""Solves the Earth Movers distance problem and returns the OT matrix
+
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F
+
+ s.t. \gamma 1 = a
+ \gamma^T 1= b
+ \gamma\geq 0
+ where :
+
+ - M is the metric cost matrix
+ - a and b are the sample weights
+
+ .. warning::
+ Note that the M matrix needs to be a C-order numpy.array in float64
+ format.
+
+ Uses the algorithm proposed in [1]_
+
+ Parameters
+ ----------
+ a : (ns,) numpy.ndarray, float64
+ Source histogram (uniform weight if empty list)
+ b : (nt,) numpy.ndarray, float64
+ Target histogram (uniform weight if empty list)
+ M : (ns,nt) numpy.ndarray, float64
+ Loss matrix (c-order array with type float64)
+ numItermax : int, optional (default=100000)
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged.
+ log: bool, optional (default=False)
+ If True, returns a dictionary containing the cost and dual
+ variables. Otherwise returns only the optimal transportation matrix.
+
+ Returns
+ -------
+ gamma: (ns x nt) numpy.ndarray
+ Optimal transportation matrix for the given parameters
+ log: dict
+ If input log is true, a dictionary containing the cost and dual
+ variables and exit status
+
+
+ Examples
+ --------
+
+ Simple example with obvious solution. The function emd accepts lists and
+    performs automatic conversion to numpy arrays
+
+ >>> import ot
+ >>> a=[.5,.5]
+ >>> b=[.5,.5]
+ >>> M=[[0.,1.],[1.,0.]]
+ >>> ot.emd(a,b,M)
+ array([[0.5, 0. ],
+ [0. , 0.5]])
+
+ References
+ ----------
+
+ .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
+ (2011, December). Displacement interpolation using Lagrangian mass
+ transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
+ 158). ACM.
+
+ See Also
+ --------
+ ot.bregman.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT"""
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+
+ # if empty array given then use uniform distributions
+ if len(a) == 0:
+ a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ if len(b) == 0:
+ b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
+ result_code_string = check_result(result_code)
+ if log:
+ log = {}
+ log['cost'] = cost
+ log['u'] = u
+ log['v'] = v
+ log['warning'] = result_code_string
+ log['result_code'] = result_code
+ return G, log
+ return G
+
+
+def emd2(a, b, M, processes=multiprocessing.cpu_count(),
+ numItermax=100000, log=False, return_matrix=False):
+ r"""Solves the Earth Movers distance problem and returns the loss
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F
+
+ s.t. \gamma 1 = a
+ \gamma^T 1= b
+ \gamma\geq 0
+ where :
+
+ - M is the metric cost matrix
+ - a and b are the sample weights
+
+ .. warning::
+ Note that the M matrix needs to be a C-order numpy.array in float64
+ format.
+
+ Uses the algorithm proposed in [1]_
+
+ Parameters
+ ----------
+ a : (ns,) numpy.ndarray, float64
+ Source histogram (uniform weight if empty list)
+ b : (nt,) numpy.ndarray, float64
+ Target histogram (uniform weight if empty list)
+ M : (ns,nt) numpy.ndarray, float64
+ Loss matrix (c-order array with type float64)
+    processes : int, optional (default=nb cpu)
+        Number of processes used for multiple emd computations (not used on Windows)
+ numItermax : int, optional (default=100000)
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged.
+ log: boolean, optional (default=False)
+ If True, returns a dictionary containing the cost and dual
+ variables. Otherwise returns only the optimal transportation cost.
+ return_matrix: boolean, optional (default=False)
+ If True, returns the optimal transportation matrix in the log.
+
+ Returns
+ -------
+ gamma: (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+    log: dict
+ If input log is true, a dictionary containing the cost and dual
+ variables and exit status
+
+
+ Examples
+ --------
+
+ Simple example with obvious solution. The function emd accepts lists and
+    performs automatic conversion to numpy arrays
+
+
+ >>> import ot
+ >>> a=[.5,.5]
+ >>> b=[.5,.5]
+ >>> M=[[0.,1.],[1.,0.]]
+ >>> ot.emd2(a,b,M)
+ 0.0
+
+ References
+ ----------
+
+ .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
+ (2011, December). Displacement interpolation using Lagrangian mass
+ transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
+ 158). ACM.
+
+ See Also
+ --------
+ ot.bregman.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT"""
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+
+    # pickling of forked processes is problematic on Windows, so fall back
+    # to a single process there
+    if sys.platform.endswith('win32'):
+        processes = 1
+
+ # if empty array given then use uniform distributions
+ if len(a) == 0:
+ a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ if len(b) == 0:
+ b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+
+ if log or return_matrix:
+ def f(b):
+ G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
+ result_code_string = check_result(resultCode)
+ log = {}
+ if return_matrix:
+ log['G'] = G
+ log['u'] = u
+ log['v'] = v
+ log['warning'] = result_code_string
+ log['result_code'] = resultCode
+ return [cost, log]
+ else:
+ def f(b):
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
+ check_result(result_code)
+ return cost
+
+ if len(b.shape) == 1:
+ return f(b)
+ nb = b.shape[1]
+
+ if processes>1:
+ res = parmap(f, [b[:, i] for i in range(nb)], processes)
+ else:
+ res = list(map(f, [b[:, i].copy() for i in range(nb)]))
+
+ return res
+
+
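+# A minimal usage sketch (made-up toy inputs; not part of the library API):
+# emd2 also accepts several target histograms at once, stacked as the columns
+# of b; it then returns a list of costs, computed in parallel when possible.
+def _demo_emd2_batch():
+    np.random.seed(0)
+    n = 4
+    a = np.ones(n) / n
+    b = np.random.rand(n, 3)
+    b /= b.sum(axis=0)  # each column is a histogram on the simplex
+    x = np.arange(n, dtype=np.float64)
+    M = np.abs(x[:, None] - x[None, :])  # 1d ground cost on a grid
+    return emd2(a, b, M)  # list of 3 costs
+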
+
+def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None):
+ """
+ Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
+
+ The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms.
+ This problem is considered in [1] (Algorithm 2). There are two differences with the following codes:
+ - we do not optimize over the weights
+    - we do not do line search for the locations updates; i.e. we use theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting.
+
+ Parameters
+ ----------
+ measures_locations : list of (k_i,d) numpy.ndarray
+ The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list)
+ measures_weights : list of (k_i,) numpy.ndarray
+ Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure
+
+ X_init : (k,d) np.ndarray
+ Initialization of the support locations (on k atoms) of the barycenter
+    b : (k,) np.ndarray
+        Weights of the barycenter (non-negatives, sum to 1; uniform if not provided)
+    weights : (N,) np.ndarray
+        Weights of each input measure in the barycenter objective
+        (non-negatives, sum to 1; uniform if not provided)
+
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ X : (k,d) np.ndarray
+ Support locations (on k atoms) of the barycenter
+
+ References
+ ----------
+
+ .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+
+ .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+
+ """
+
+ iter_count = 0
+
+ N = len(measures_locations)
+ k = X_init.shape[0]
+ d = X_init.shape[1]
+ if b is None:
+ b = np.ones((k,))/k
+ if weights is None:
+ weights = np.ones((N,)) / N
+
+ X = X_init
+
+ log_dict = {}
+ displacement_square_norms = []
+
+ displacement_square_norm = stopThr + 1.
+
+ while ( displacement_square_norm > stopThr and iter_count < numItermax ):
+
+ T_sum = np.zeros((k, d))
+
+ for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()):
+
+ M_i = dist(X, measure_locations_i)
+ T_i = emd(b, measure_weights_i, M_i)
+ T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
+
+ displacement_square_norm = np.sum(np.square(T_sum-X))
+ if log:
+ displacement_square_norms.append(displacement_square_norm)
+
+ X = T_sum
+
+ if verbose:
+            print('iteration %d, displacement_square_norm=%f' % (iter_count, displacement_square_norm))
+
+ iter_count += 1
+
+ if log:
+ log_dict['displacement_square_norms'] = displacement_square_norms
+ return X, log_dict
+ else:
+ return X
+
+
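+# A minimal usage sketch (made-up toy inputs; not part of the library API):
+# a free-support barycenter of two empirical 2d measures, constrained to
+# 4 atoms with uniform weights.
+def _demo_free_support_barycenter():
+    np.random.seed(0)
+    measures_locations = [np.random.randn(7, 2), np.random.randn(5, 2) + 1.]
+    measures_weights = [np.ones(7) / 7, np.ones(5) / 5]
+    X_init = np.random.randn(4, 2)
+    return free_support_barycenter(measures_locations, measures_weights, X_init)
+
+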
+def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
+ log=False):
+ r"""Solves the Earth Movers distance problem between 1d measures and returns
+ the OT matrix
+
+
+ .. math::
+ \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
+
+ s.t. \gamma 1 = a,
+ \gamma^T 1= b,
+ \gamma\geq 0
+ where :
+
+ - d is the metric
+ - x_a and x_b are the samples
+ - a and b are the sample weights
+
+ When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
+
+ Uses the algorithm detailed in [1]_
+
+ Parameters
+ ----------
+ x_a : (ns,) or (ns, 1) ndarray, float64
+ Source dirac locations (on the real line)
+    x_b : (nt,) or (nt, 1) ndarray, float64
+ Target dirac locations (on the real line)
+ a : (ns,) ndarray, float64, optional
+ Source histogram (default is uniform weight)
+ b : (nt,) ndarray, float64, optional
+ Target histogram (default is uniform weight)
+ metric: str, optional (default='sqeuclidean')
+ Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
+ Due to implementation details, this function runs faster when
+ `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
+ p: float, optional (default=1.0)
+ The p-norm to apply for if metric='minkowski'
+ dense: boolean, optional (default=True)
+        If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
+ Otherwise returns a sparse representation using scipy's `coo_matrix`
+ format. Due to implementation details, this function runs faster when
+ `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
+ are used.
+ log: boolean, optional (default=False)
+ If True, returns a dictionary containing the cost.
+ Otherwise returns only the optimal transportation matrix.
+
+ Returns
+ -------
+ gamma: (ns, nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log: dict
+ If input log is True, a dictionary containing the cost
+
+
+ Examples
+ --------
+
+ Simple example with obvious solution. The function emd_1d accepts lists and
+ performs automatic conversion to numpy arrays
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> x_a = [2., 0.]
+ >>> x_b = [0., 3.]
+ >>> ot.emd_1d(x_a, x_b, a, b)
+ array([[0. , 0.5],
+ [0.5, 0. ]])
+ >>> ot.emd_1d(x_a, x_b)
+ array([[0. , 0.5],
+ [0.5, 0. ]])
+
+ References
+ ----------
+
+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+
+ See Also
+ --------
+ ot.lp.emd : EMD for multidimensional distributions
+ ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
+ transportation matrix)
+ """
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ x_a = np.asarray(x_a, dtype=np.float64)
+ x_b = np.asarray(x_b, dtype=np.float64)
+
+ assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \
+ "emd_1d should only be used with monodimensional data"
+ assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \
+ "emd_1d should only be used with monodimensional data"
+
+ # if empty array given then use uniform distributions
+ if a.ndim == 0 or len(a) == 0:
+ a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0]
+ if b.ndim == 0 or len(b) == 0:
+ b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
+
+ x_a_1d = x_a.reshape((-1, ))
+ x_b_1d = x_b.reshape((-1, ))
+ perm_a = np.argsort(x_a_1d)
+ perm_b = np.argsort(x_b_1d)
+
+ G_sorted, indices, cost = emd_1d_sorted(a, b,
+ x_a_1d[perm_a], x_b_1d[perm_b],
+ metric=metric, p=p)
+ G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])),
+ shape=(a.shape[0], b.shape[0]))
+ if dense:
+ G = G.toarray()
+ if log:
+ log = {'cost': cost}
+ return G, log
+ return G
+
+
+def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
+ log=False):
+ r"""Solves the Earth Movers distance problem between 1d measures and returns
+ the loss
+
+
+ .. math::
+ \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
+
+ s.t. \gamma 1 = a,
+ \gamma^T 1= b,
+ \gamma\geq 0
+ where :
+
+ - d is the metric
+ - x_a and x_b are the samples
+ - a and b are the sample weights
+
+ When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
+
+ Uses the algorithm detailed in [1]_
+
+ Parameters
+ ----------
+ x_a : (ns,) or (ns, 1) ndarray, float64
+ Source dirac locations (on the real line)
+    x_b : (nt,) or (nt, 1) ndarray, float64
+ Target dirac locations (on the real line)
+ a : (ns,) ndarray, float64, optional
+ Source histogram (default is uniform weight)
+ b : (nt,) ndarray, float64, optional
+ Target histogram (default is uniform weight)
+ metric: str, optional (default='sqeuclidean')
+ Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
+ Due to implementation details, this function runs faster when
+ `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
+ are used.
+ p: float, optional (default=1.0)
+ The p-norm to apply for if metric='minkowski'
+ dense: boolean, optional (default=True)
+        If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
+ Otherwise returns a sparse representation using scipy's `coo_matrix`
+ format. Only used if log is set to True. Due to implementation details,
+ this function runs faster when dense is set to False.
+ log: boolean, optional (default=False)
+ If True, returns a dictionary containing the transportation matrix.
+ Otherwise returns only the loss.
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+ log: dict
+ If input log is True, a dictionary containing the Optimal transportation
+ matrix for the given parameters
+
+
+ Examples
+ --------
+
+ Simple example with obvious solution. The function emd2_1d accepts lists and
+ performs automatic conversion to numpy arrays
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> x_a = [2., 0.]
+ >>> x_b = [0., 3.]
+ >>> ot.emd2_1d(x_a, x_b, a, b)
+ 0.5
+ >>> ot.emd2_1d(x_a, x_b)
+ 0.5
+
+ References
+ ----------
+
+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+
+ See Also
+ --------
+ ot.lp.emd2 : EMD for multidimensional distributions
+ ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix
+ instead of the cost)
+ """
+    # If we do not return G (log==False), then we should not cast it to dense
+    # (useless overhead)
+ G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p,
+ dense=dense and log, log=True)
+ cost = log_emd['cost']
+ if log:
+ log_emd = {'G': G}
+ return cost, log_emd
+ return cost
+
+
+def wasserstein_1d(x_a, x_b, a=None, b=None, p=1.):
+ r"""Solves the p-Wasserstein distance problem between 1d measures and returns
+ the distance
+
+ .. math::
+ \min_\gamma \left( \sum_i \sum_j \gamma_{ij} \|x_a[i] - x_b[j]\|^p \right)^{1/p}
+
+ s.t. \gamma 1 = a,
+ \gamma^T 1= b,
+ \gamma\geq 0
+
+ where :
+
+ - x_a and x_b are the samples
+ - a and b are the sample weights
+
+ Uses the algorithm detailed in [1]_
+
+ Parameters
+ ----------
+ x_a : (ns,) or (ns, 1) ndarray, float64
+ Source dirac locations (on the real line)
+    x_b : (nt,) or (nt, 1) ndarray, float64
+ Target dirac locations (on the real line)
+ a : (ns,) ndarray, float64, optional
+ Source histogram (default is uniform weight)
+ b : (nt,) ndarray, float64, optional
+ Target histogram (default is uniform weight)
+ p: float, optional (default=1.0)
+ The order of the p-Wasserstein distance to be computed
+
+ Returns
+ -------
+ dist: float
+ p-Wasserstein distance
+
+
+ Examples
+ --------
+
+ Simple example with obvious solution. The function wasserstein_1d accepts
+ lists and performs automatic conversion to numpy arrays
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> x_a = [2., 0.]
+ >>> x_b = [0., 3.]
+ >>> ot.wasserstein_1d(x_a, x_b, a, b)
+ 0.5
+ >>> ot.wasserstein_1d(x_a, x_b)
+ 0.5
+
+ References
+ ----------
+
+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+
+ See Also
+ --------
+ ot.lp.emd_1d : EMD for 1d distributions
+ """
+ cost_emd = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p,
+ dense=False, log=False)
+ return np.power(cost_emd, 1. / p)
diff --git a/ot/lp/core.h b/ot/lp/core.h
new file mode 100644
index 0000000..04dddf7
--- /dev/null
+++ b/ot/lp/core.h
@@ -0,0 +1,103 @@
+/* -*- mode: C++; indent-tabs-mode: nil; -*-
+ *
+ * This file has been adapted by Nicolas Bonneel (2013),
+ * from full_graph.h from LEMON, a generic C++ optimization library,
+ * to make the other files independent from the rest of
+ * the original library.
+ *
+ *
+ **** Original file Copyright Notice :
+ * Copyright (C) 2003-2010
+ * Egervary Jeno Kombinatorikus Optimalizalasi Kutatocsoport
+ * (Egervary Research Group on Combinatorial Optimization, EGRES).
+ *
+ * Permission to use, modify and distribute this software is granted
+ * provided that this copyright notice appears in all copies. For
+ * precise terms see the accompanying LICENSE file.
+ *
+ * This software is provided "AS IS" with no warranty of any kind,
+ * express or implied, and with no claim as to its suitability for any
+ * purpose.
+ *
+ */
+
+#ifndef LEMON_CORE_H
+#define LEMON_CORE_H
+
+#include <vector>
+#include <algorithm>
+
+
+// Disable the following warnings when compiling with MSVC:
+// C4250: 'class1' : inherits 'class2::member' via dominance
+// C4355: 'this' : used in base member initializer list
+// C4503: 'function' : decorated name length exceeded, name was truncated
+// C4800: 'type' : forcing value to bool 'true' or 'false' (performance warning)
+// C4996: 'function': was declared deprecated
+#ifdef _MSC_VER
+#pragma warning( disable : 4250 4355 4503 4800 4996 )
+#endif
+
+///\file
+///\brief LEMON core utilities.
+///
+///This header file contains core utilities for LEMON.
+///It is automatically included by all graph types, therefore it usually
+///does not have to be included directly.
+
+namespace lemon {
+
+ /// \brief Dummy type to make it easier to create invalid iterators.
+ ///
+ /// Dummy type to make it easier to create invalid iterators.
+ /// See \ref INVALID for the usage.
+ struct Invalid {
+ public:
+ bool operator==(Invalid) { return true; }
+ bool operator!=(Invalid) { return false; }
+ bool operator< (Invalid) { return false; }
+ };
+
+ /// \brief Invalid iterators.
+ ///
+ /// \ref Invalid is a global type that converts to each iterator
+ /// in such a way that the value of the target iterator will be invalid.
+#ifdef LEMON_ONLY_TEMPLATES
+ const Invalid INVALID = Invalid();
+#else
+ extern const Invalid INVALID;
+#endif
+
+ /// \addtogroup gutils
+ /// @{
+
+ ///Create convenience typedefs for the digraph types and iterators
+
+ ///This \c \#define creates convenient type definitions for the following
+ ///types of \c Digraph: \c Node, \c NodeIt, \c Arc, \c ArcIt, \c InArcIt,
+ ///\c OutArcIt, \c BoolNodeMap, \c IntNodeMap, \c DoubleNodeMap,
+ ///\c BoolArcMap, \c IntArcMap, \c DoubleArcMap.
+ ///
+///\note If the graph type is a dependent type, i.e. the graph type depends
+ ///on a template parameter, then use \c TEMPLATE_DIGRAPH_TYPEDEFS()
+ ///macro.
+#define DIGRAPH_TYPEDEFS(Digraph) \
+ typedef Digraph::Node Node; \
+ typedef Digraph::Arc Arc; \
+
+
+ ///Create convenience typedefs for the digraph types and iterators
+
+ ///\see DIGRAPH_TYPEDEFS
+ ///
+ ///\note Use this macro, if the graph type is a dependent type,
+///i.e. the graph type depends on a template parameter.
+#define TEMPLATE_DIGRAPH_TYPEDEFS(Digraph) \
+ typedef typename Digraph::Node Node; \
+ typedef typename Digraph::Arc Arc; \
+
+
+
+} //namespace lemon
+
+#endif
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
new file mode 100644
index 0000000..8e763be
--- /dev/null
+++ b/ot/lp/cvx.py
@@ -0,0 +1,147 @@
+# -*- coding: utf-8 -*-
+"""
+LP solvers for optimal transport using cvxopt
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import scipy as sp
+import scipy.optimize  # makes sp.optimize.linprog available below
+import scipy.sparse as sps
+
+
+try:
+ import cvxopt
+ from cvxopt import solvers, matrix, spmatrix
+except ImportError:
+ cvxopt = False
+
+
+def scipy_sparse_to_spmatrix(A):
+ """Efficient conversion from scipy sparse matrix to cvxopt sparse matrix"""
+ coo = A.tocoo()
+ SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape)
+ return SP
+
+
+def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'):
+ """Compute the Wasserstein barycenter of distributions A
+
+ The function solves the following optimization problem [16]:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{1}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+    - :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.lp.emd)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+
+    The linear program is solved using the interior point solver from scipy.optimize.
+    If the cvxopt solver is installed, it can be used instead.
+
+    Note that this problem does not scale well (both in memory and computational time).
+
+ Parameters
+ ----------
+ A : np.ndarray (d,n)
+ n training distributions a_i of size d
+ M : np.ndarray (d,d)
+ loss matrix for OT
+    weights : np.ndarray (n,)
+        Weights of each histogram a_i on the simplex (barycentric coordinates)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ solver : string, optional
+        The solver used; the default 'interior-point' uses the LP solver from
+        scipy.optimize, while None, 'glpk' or 'mosek' use the solvers from cvxopt.
+
+ Returns
+ -------
+ a : (d,) ndarray
+ Wasserstein barycenter
+ log : dict
+        Log dictionary returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924.
+
+
+
+ """
+
+ if weights is None:
+ weights = np.ones(A.shape[1]) / A.shape[1]
+ else:
+ assert(len(weights) == A.shape[1])
+
+ n_distributions = A.shape[1]
+ n = A.shape[0]
+
+ n2 = n * n
+ c = np.zeros((0))
+ b_eq1 = np.zeros((0))
+ for i in range(n_distributions):
+ c = np.concatenate((c, M.ravel() * weights[i]))
+ b_eq1 = np.concatenate((b_eq1, A[:, i]))
+ c = np.concatenate((c, np.zeros(n)))
+
+ lst_idiag1 = [sps.kron(sps.eye(n), np.ones((1, n))) for i in range(n_distributions)]
+ # row constraints
+ A_eq1 = sps.hstack((sps.block_diag(lst_idiag1), sps.coo_matrix((n_distributions * n, n))))
+
+ # columns constraints
+ lst_idiag2 = []
+ lst_eye = []
+ for i in range(n_distributions):
+ if i == 0:
+ lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n)))
+ lst_eye.append(-sps.eye(n))
+ else:
+ lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n - 1, n)))
+ lst_eye.append(-sps.eye(n - 1, n))
+
+ A_eq2 = sps.hstack((sps.block_diag(lst_idiag2), sps.vstack(lst_eye)))
+ b_eq2 = np.zeros((A_eq2.shape[0]))
+
+ # full problem
+ A_eq = sps.vstack((A_eq1, A_eq2))
+ b_eq = np.concatenate((b_eq1, b_eq2))
+
+ if not cvxopt or solver in ['interior-point']:
+ # cvxopt not installed or interior point
+
+ if solver is None:
+ solver = 'interior-point'
+
+ options = {'sparse': True, 'disp': verbose}
+ sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver,
+ options=options)
+ x = sol.x
+ b = x[-n:]
+
+ else:
+
+ h = np.zeros((n_distributions * n2 + n))
+ G = -sps.eye(n_distributions * n2 + n)
+
+ sol = solvers.lp(matrix(c), scipy_sparse_to_spmatrix(G), matrix(h),
+ A=scipy_sparse_to_spmatrix(A_eq), b=matrix(b_eq),
+ solver=solver)
+
+ x = np.array(sol['x'])
+ b = x[-n:].ravel()
+
+ if log:
+ return b, sol
+ else:
+ return b
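+
+
+# A minimal usage sketch (made-up toy inputs; not part of the library API):
+# the LP barycenter of two histograms on a 1d grid with a squared euclidean
+# ground cost, using the default scipy.optimize interior-point solver.
+def _demo_lp_barycenter():
+    n = 10
+    x = np.arange(n, dtype=np.float64)
+    a1, a2 = np.zeros(n), np.zeros(n)
+    a1[2], a2[7] = 1., 1.  # two Dirac histograms at different bins
+    A = np.vstack((a1, a2)).T  # distributions as columns, shape (n, 2)
+    M = (x[:, None] - x[None, :]) ** 2
+    return barycenter(A, M, weights=np.array([0.5, 0.5]))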
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
new file mode 100644
index 0000000..2b6c495
--- /dev/null
+++ b/ot/lp/emd_wrap.pyx
@@ -0,0 +1,187 @@
+# -*- coding: utf-8 -*-
+"""
+Cython linker with C solver
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+cimport numpy as np
+
+from ..utils import dist
+
+cimport cython
+cimport libc.math as math
+
+import warnings
+
+
+cdef extern from "EMD.h":
+ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter)
+ cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
+
+
+def check_result(result_code):
+ if result_code == OPTIMAL:
+ return None
+
+ if result_code == INFEASIBLE:
+ message = "Problem infeasible. Check that a and b are in the simplex"
+ elif result_code == UNBOUNDED:
+ message = "Problem unbounded"
+ elif result_code == MAX_ITER_REACHED:
+ message = "numItermax reached before optimality. Try to increase numItermax."
+ warnings.warn(message)
+ return message
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter):
+ """
+ Solves the Earth Movers distance problem and returns the optimal transport matrix
+
+ gamm=emd(a,b,M)
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the metric cost matrix
+ - a and b are the sample weights
+
+ .. warning::
+        Note that the M matrix needs to be a C-order :py:class:`numpy.ndarray`
+
+ Parameters
+ ----------
+ a : (ns,) numpy.ndarray, float64
+ source histogram
+ b : (nt,) numpy.ndarray, float64
+ target histogram
+ M : (ns,nt) numpy.ndarray, float64
+ loss matrix
+ max_iter : int
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged.
+
+
+ Returns
+ -------
+ gamma: (ns x nt) numpy.ndarray
+ Optimal transportation matrix for the given parameters
+
+ """
+ cdef int n1= M.shape[0]
+ cdef int n2= M.shape[1]
+
+ cdef double cost=0
+ cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
+ cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1)
+ cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2)
+
+
+ if not len(a):
+ a=np.ones((n1,))/n1
+
+ if not len(b):
+ b=np.ones((n2,))/n2
+
+ # calling the function
+ cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
+
+ return G, cost, alpha, beta, result_code
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
+ np.ndarray[double, ndim=1, mode="c"] v_weights,
+ np.ndarray[double, ndim=1, mode="c"] u,
+ np.ndarray[double, ndim=1, mode="c"] v,
+ str metric='sqeuclidean',
+ double p=1.):
+ r"""
+ Solves the Earth Movers distance problem between sorted 1d measures and
+ returns the OT matrix and the associated cost
+
+ Parameters
+ ----------
+ u_weights : (ns,) ndarray, float64
+ Source histogram
+ v_weights : (nt,) ndarray, float64
+ Target histogram
+ u : (ns,) ndarray, float64
+ Source dirac locations (on the real line)
+ v : (nt,) ndarray, float64
+ Target dirac locations (on the real line)
+ metric: str, optional (default='sqeuclidean')
+ Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
+ Due to implementation details, this function runs faster when
+ `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
+ are used.
+ p: float, optional (default=1.0)
+ The p-norm to apply for if metric='minkowski'
+
+ Returns
+ -------
+ gamma: (n, ) ndarray, float64
+ Values in the Optimal transportation matrix
+ indices: (n, 2) ndarray, int64
+ Indices of the values stored in gamma for the Optimal transportation
+ matrix
+    cost : float
+        Cost associated to the optimal transportation
+ """
+ cdef double cost = 0.
+ cdef int n = u_weights.shape[0]
+ cdef int m = v_weights.shape[0]
+
+ cdef int i = 0
+ cdef double w_i = u_weights[0]
+ cdef int j = 0
+ cdef double w_j = v_weights[0]
+
+ cdef double m_ij = 0.
+
+ cdef np.ndarray[double, ndim=1, mode="c"] G = np.zeros((n + m - 1, ),
+ dtype=np.float64)
+ cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2),
+ dtype=np.int)
+ cdef int cur_idx = 0
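+    # Greedy sweep over the two sorted supports: transport as much mass as
+    # possible between the current pair (i, j), then advance on the side whose
+    # remaining mass was exhausted. Every step consumes at least one atom, so
+    # at most n + m - 1 entries are created, matching the size of G above.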
+ while i < n and j < m:
+ if metric == 'sqeuclidean':
+ m_ij = (u[i] - v[j]) * (u[i] - v[j])
+ elif metric == 'cityblock' or metric == 'euclidean':
+ m_ij = math.fabs(u[i] - v[j])
+ elif metric == 'minkowski':
+ m_ij = math.pow(math.fabs(u[i] - v[j]), p)
+ else:
+ m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
+ metric=metric)[0, 0]
+ if w_i < w_j or j == m - 1:
+ cost += m_ij * w_i
+ G[cur_idx] = w_i
+ indices[cur_idx, 0] = i
+ indices[cur_idx, 1] = j
+            i += 1
+            if i < n:  # guard the read-ahead when the last source atom is consumed
+                w_j -= w_i
+                w_i = u_weights[i]
+ else:
+ cost += m_ij * w_j
+ G[cur_idx] = w_j
+ indices[cur_idx, 0] = i
+ indices[cur_idx, 1] = j
+            j += 1
+            if j < m:  # guard the read-ahead when the last target atom is consumed
+                w_i -= w_j
+                w_j = v_weights[j]
+ cur_idx += 1
+ return G[:cur_idx], indices[:cur_idx], cost
diff --git a/ot/lp/full_bipartitegraph.h b/ot/lp/full_bipartitegraph.h
new file mode 100644
index 0000000..87a1bec
--- /dev/null
+++ b/ot/lp/full_bipartitegraph.h
@@ -0,0 +1,215 @@
+/* -*- mode: C++; indent-tabs-mode: nil; -*-
+ *
+ * This file has been adapted by Nicolas Bonneel (2013),
+ * from full_graph.h from LEMON, a generic C++ optimization library,
+ * to implement a lightweight fully connected bipartite graph. A previous
+ * version of this file is used as part of the Displacement Interpolation
+ * project,
+ * Web: http://www.cs.ubc.ca/labs/imager/tr/2011/DisplacementInterpolation/
+ *
+ *
+ **** Original file Copyright Notice :
+ * Copyright (C) 2003-2010
+ * Egervary Jeno Kombinatorikus Optimalizalasi Kutatocsoport
+ * (Egervary Research Group on Combinatorial Optimization, EGRES).
+ *
+ * Permission to use, modify and distribute this software is granted
+ * provided that this copyright notice appears in all copies. For
+ * precise terms see the accompanying LICENSE file.
+ *
+ * This software is provided "AS IS" with no warranty of any kind,
+ * express or implied, and with no claim as to its suitability for any
+ * purpose.
+ *
+ */
+
+#ifndef LEMON_FULL_BIPARTITE_GRAPH_H
+#define LEMON_FULL_BIPARTITE_GRAPH_H
+
+#include "core.h"
+
+///\ingroup graphs
+///\file
+///\brief FullBipartiteDigraph and FullBipartiteGraph classes.
+
+
+namespace lemon {
+
+
+ class FullBipartiteDigraphBase {
+ public:
+
+ typedef FullBipartiteDigraphBase Digraph;
+
+ //class Node;
+ typedef int Node;
+ //class Arc;
+ typedef long long Arc;
+
+ protected:
+
+ int _node_num;
+ long long _arc_num;
+
+ FullBipartiteDigraphBase() {}
+
+ void construct(int n1, int n2) { _node_num = n1+n2; _arc_num = n1 * n2; _n1=n1; _n2=n2;}
+
+ public:
+
+ int _n1, _n2;
+
+
+ Node operator()(int ix) const { return Node(ix); }
+ static int index(const Node& node) { return node; }
+
+ Arc arc(const Node& s, const Node& t) const {
+ if (s<_n1 && t>=_n1)
+ return Arc(s * _n2 + (t-_n1) );
+ else
+ return Arc(-1);
+ }
+
+ int nodeNum() const { return _node_num; }
+ long long arcNum() const { return _arc_num; }
+
+ int maxNodeId() const { return _node_num - 1; }
+ long long maxArcId() const { return _arc_num - 1; }
+
+ Node source(Arc arc) const { return arc / _n2; }
+ Node target(Arc arc) const { return (arc % _n2) + _n1; }
+
+ static int id(Node node) { return node; }
+ static long long id(Arc arc) { return arc; }
+
+ static Node nodeFromId(int id) { return Node(id);}
+ static Arc arcFromId(int id) { return Arc(id);}
+
+
+ Arc findArc(Node s, Node t, Arc prev = -1) const {
+ return prev == -1 ? arc(s, t) : -1;
+ }
+
+ void first(Node& node) const {
+ node = _node_num - 1;
+ }
+
+ static void next(Node& node) {
+ --node;
+ }
+
+ void first(Arc& arc) const {
+ arc = _arc_num - 1;
+ }
+
+ static void next(Arc& arc) {
+ --arc;
+ }
+
+ void firstOut(Arc& arc, const Node& node) const {
+ if (node>=_n1)
+ arc = -1;
+ else
+ arc = (node + 1) * _n2 - 1;
+ }
+
+ void nextOut(Arc& arc) const {
+ if (arc % _n2 == 0) arc = 0;
+ --arc;
+ }
+
+ void firstIn(Arc& arc, const Node& node) const {
+ if (node<_n1)
+ arc = -1;
+ else
+ arc = _arc_num + node - _node_num;
+ }
+
+ void nextIn(Arc& arc) const {
+ arc -= _n2;
+ if (arc < 0) arc = -1;
+ }
+
+ };
+
+ /// \ingroup graphs
+ ///
+ /// \brief A directed full graph class.
+ ///
+  /// FullBipartiteDigraph is a simple and fast implementation of directed full
+  /// (complete) bipartite graphs. It contains an arc from each node of the
+  /// first part to each node of the second part, therefore the number of
+  /// arcs is the product of the sizes of the two parts.
+ /// This class is completely static and it needs constant memory space.
+ /// Thus you can neither add nor delete nodes or arcs, however
+ /// the structure can be resized using resize().
+ ///
+ /// This type fully conforms to the \ref concepts::Digraph "Digraph concept".
+ /// Most of its member functions and nested classes are documented
+ /// only in the concept class.
+ ///
+ /// This class provides constant time counting for nodes and arcs.
+ ///
+ /// \note FullBipartiteDigraph and FullBipartiteGraph classes are very similar,
+ /// but there are two differences. While this class conforms only
+ /// to the \ref concepts::Digraph "Digraph" concept, FullBipartiteGraph
+ /// conforms to the \ref concepts::Graph "Graph" concept,
+ /// moreover FullBipartiteGraph does not contain a loop for each
+ /// node as this class does.
+ ///
+ /// \sa FullBipartiteGraph
+ class FullBipartiteDigraph : public FullBipartiteDigraphBase {
+ typedef FullBipartiteDigraphBase Parent;
+
+ public:
+
+ /// \brief Default constructor.
+ ///
+ /// Default constructor. The number of nodes and arcs will be zero.
+ FullBipartiteDigraph() { construct(0,0); }
+
+ /// \brief Constructor
+ ///
+ /// Constructor.
+    /// \param n1 The number of nodes in the first part.
+    /// \param n2 The number of nodes in the second part.
+ FullBipartiteDigraph(int n1, int n2) { construct(n1, n2); }
+
+
+ /// \brief Returns the node with the given index.
+ ///
+ /// Returns the node with the given index. Since this structure is
+ /// completely static, the nodes can be indexed with integers from
+ /// the range <tt>[0..nodeNum()-1]</tt>.
+ /// The index of a node is the same as its ID.
+ /// \sa index()
+ Node operator()(int ix) const { return Parent::operator()(ix); }
+
+ /// \brief Returns the index of the given node.
+ ///
+ /// Returns the index of the given node. Since this structure is
+ /// completely static, the nodes can be indexed with integers from
+ /// the range <tt>[0..nodeNum()-1]</tt>.
+ /// The index of a node is the same as its ID.
+ /// \sa operator()()
+ static int index(const Node& node) { return Parent::index(node); }
+
+ /// \brief Returns the arc connecting the given nodes.
+ ///
+ /// Returns the arc connecting the given nodes.
+ /*Arc arc(Node u, Node v) const {
+ return Parent::arc(u, v);
+ }*/
+
+ /// \brief Number of nodes.
+ int nodeNum() const { return Parent::nodeNum(); }
+ /// \brief Number of arcs.
+ long long arcNum() const { return Parent::arcNum(); }
+ };
+
+
+
+
+} //namespace lemon
+
+
+#endif //LEMON_FULL_BIPARTITE_GRAPH_H
diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h
new file mode 100644
index 0000000..7c6a4ce
--- /dev/null
+++ b/ot/lp/network_simplex_simple.h
@@ -0,0 +1,1553 @@
+/* -*- mode: C++; indent-tabs-mode: nil; -*-
+ *
+ *
+ * This file has been adapted by Nicolas Bonneel (2013),
+ * from network_simplex.h from LEMON, a generic C++ optimization library,
+ * to implement a lightweight network simplex for mass transport, more
+ * memory efficient than the original file. A previous version of this file
+ * is used as part of the Displacement Interpolation project,
+ * Web: http://www.cs.ubc.ca/labs/imager/tr/2011/DisplacementInterpolation/
+ *
+ *
+ **** Original file Copyright Notice :
+ *
+ * Copyright (C) 2003-2010
+ * Egervary Jeno Kombinatorikus Optimalizalasi Kutatocsoport
+ * (Egervary Research Group on Combinatorial Optimization, EGRES).
+ *
+ * Permission to use, modify and distribute this software is granted
+ * provided that this copyright notice appears in all copies. For
+ * precise terms see the accompanying LICENSE file.
+ *
+ * This software is provided "AS IS" with no warranty of any kind,
+ * express or implied, and with no claim as to its suitability for any
+ * purpose.
+ *
+ */
+
+#ifndef LEMON_NETWORK_SIMPLEX_SIMPLE_H
+#define LEMON_NETWORK_SIMPLEX_SIMPLE_H
+#define DEBUG_LVL 0
+
+#if DEBUG_LVL>0
+#include <iomanip>
+#endif
+
+
+#define EPSILON 2.2204460492503131e-15
+#define _EPSILON 1e-8
+#define MAX_DEBUG_ITER 100000
+
+
+/// \ingroup min_cost_flow_algs
+///
+/// \file
+/// \brief Network Simplex algorithm for finding a minimum cost flow.
+
+// if your compiler supports stdext hash maps, uncomment the following line
+// to use them instead of the (slower) std::map
+//#define HASHMAP
+
+#include <vector>
+#include <limits>
+#include <algorithm>
+#include <cstdio>
+#ifdef HASHMAP
+#include <hash_map>
+#else
+#include <map>
+#endif
+#include <cmath>
+//#include "core.h"
+//#include "lmath.h"
+
+//#include "sparse_array_n.h"
+#include "full_bipartitegraph.h"
+
+#define INVALIDNODE -1
+#define INVALID (-1)
+
+namespace lemon {
+
+
+ template <typename T>
+ class ProxyObject;
+
+ template<typename T>
+ class SparseValueVector
+ {
+ public:
+ SparseValueVector(int n=0)
+ {
+ }
+ void resize(int n=0){};
+ T operator[](const int id) const
+ {
+#ifdef HASHMAP
+ typename stdext::hash_map<int,T>::const_iterator it = data.find(id);
+#else
+ typename std::map<int,T>::const_iterator it = data.find(id);
+#endif
+ if (it==data.end())
+ return 0;
+ else
+ return it->second;
+ }
+
+ ProxyObject<T> operator[](const int id)
+ {
+ return ProxyObject<T>( this, id );
+ }
+
+ //private:
+#ifdef HASHMAP
+ stdext::hash_map<int,T> data;
+#else
+ std::map<int,T> data;
+#endif
+
+ };
+
+ template <typename T>
+ class ProxyObject {
+ public:
+ ProxyObject( SparseValueVector<T> *v, int idx ){_v=v; _idx=idx;};
+ ProxyObject<T> & operator=( const T &v ) {
+ // If we get here, we know that operator[] was called to perform a write access,
+ // so we can insert an item in the vector if needed
+ if (v!=0)
+ _v->data[_idx]=v;
+ return *this;
+ }
+
+ operator T() {
+ // If we get here, we know that operator[] was called to perform a read access,
+ // so we can simply return the existing object
+#ifdef HASHMAP
+ typename stdext::hash_map<int,T>::iterator it = _v->data.find(_idx);
+#else
+ typename std::map<int,T>::iterator it = _v->data.find(_idx);
+#endif
+ if (it==_v->data.end())
+ return 0;
+ else
+ return it->second;
+ }
+
+ void operator+=(T val)
+ {
+ if (val==0) return;
+#ifdef HASHMAP
+ typename stdext::hash_map<int,T>::iterator it = _v->data.find(_idx);
+#else
+ typename std::map<int,T>::iterator it = _v->data.find(_idx);
+#endif
+ if (it==_v->data.end())
+ _v->data[_idx] = val;
+ else
+ {
+ T sum = it->second + val;
+ if (sum==0)
+ _v->data.erase(it);
+ else
+ it->second = sum;
+ }
+ }
+ void operator-=(T val)
+ {
+ if (val==0) return;
+#ifdef HASHMAP
+ typename stdext::hash_map<int,T>::iterator it = _v->data.find(_idx);
+#else
+ typename std::map<int,T>::iterator it = _v->data.find(_idx);
+#endif
+ if (it==_v->data.end())
+ _v->data[_idx] = -val;
+ else
+ {
+ T sum = it->second - val;
+ if (sum==0)
+ _v->data.erase(it);
+ else
+ it->second = sum;
+ }
+ }
+
+ SparseValueVector<T> *_v;
+ int _idx;
+ };
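+
+    // A minimal usage sketch of the sparse vector / proxy pair above:
+    // reads go through ProxyObject::operator T() and return 0 for absent
+    // entries; writes go through operator=, operator+= and operator-= and
+    // never store zeros, so the underlying map stays sparse.
+    //
+    //   SparseValueVector<double> v;
+    //   v[3] = 1.5;       // write access: creates the entry for index 3
+    //   double x = v[3];  // read access: returns 1.5
+    //   v[3] -= 1.5;      // the entry becomes zero and is erased again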
+
+
+
+ /// \addtogroup min_cost_flow_algs
+ /// @{
+
+ /// \brief Implementation of the primal Network Simplex algorithm
+ /// for finding a \ref min_cost_flow "minimum cost flow".
+ ///
+ /// \ref NetworkSimplexSimple implements the primal Network Simplex algorithm
+ /// for finding a \ref min_cost_flow "minimum cost flow"
+ /// \ref amo93networkflows, \ref dantzig63linearprog,
+ /// \ref kellyoneill91netsimplex.
+ /// This algorithm is a highly efficient specialized version of the
+ /// linear programming simplex method directly for the minimum cost
+ /// flow problem.
+ ///
+ /// In general, %NetworkSimplexSimple is the fastest implementation available
+ /// in LEMON for this problem.
+ /// Moreover, it supports both directions of the supply/demand inequality
+ /// constraints. For more information, see \ref SupplyType.
+ ///
+ /// Most of the parameters of the problem (except for the digraph)
+ /// can be given using separate functions, and the algorithm can be
+ /// executed using the \ref run() function. If some parameters are not
+ /// specified, then default values will be used.
+ ///
+ /// \tparam GR The digraph type the algorithm runs on.
+ /// \tparam V The number type used for flow amounts, capacity bounds
+ /// and supply values in the algorithm. By default, it is \c int.
+ /// \tparam C The number type used for costs and potentials in the
+ /// algorithm. By default, it is the same as \c V.
+ ///
+ /// \warning Both number types must be signed and all input data must
+ /// be integer.
+ ///
+    /// \note This simplified version of %NetworkSimplexSimple provides only
+    /// the Block Search pivot rule, which is in practice the most
+    /// efficient one.
+ template <typename GR, typename V = int, typename C = V, typename NodesType = unsigned short int>
+ class NetworkSimplexSimple
+ {
+ public:
+
+ /// \brief Constructor.
+ ///
+ /// The constructor of the class.
+ ///
+ /// \param graph The digraph the algorithm runs on.
+ /// \param arc_mixing Indicate if the arcs have to be stored in a
+ /// mixed order in the internal data structure.
+ /// In special cases, it could lead to better overall performance,
+    /// but it is usually slower. Therefore it is disabled by default.
+    /// \param nbnodes The number of nodes of the digraph.
+    /// \param nb_arcs The number of arcs of the digraph.
+    /// \param maxiters The maximum number of pivot iterations
+    /// (a value <= 0 means no limit).
+ NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, long long nb_arcs,int maxiters) :
+ _graph(graph), //_arc_id(graph),
+ _arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
+ MAX(std::numeric_limits<Value>::max()),
+ INF(std::numeric_limits<Value>::has_infinity ?
+ std::numeric_limits<Value>::infinity() : MAX)
+ {
+ // Reset data structures
+ reset();
+ max_iter=maxiters;
+ }
+
+ /// The type of the flow amounts, capacity bounds and supply values
+ typedef V Value;
+ /// The type of the arc costs
+ typedef C Cost;
+
+ public:
+
+ /// \brief Problem type constants for the \c run() function.
+ ///
+ /// Enum type containing the problem type constants that can be
+ /// returned by the \ref run() function of the algorithm.
+ enum ProblemType {
+ /// The problem has no feasible solution (flow).
+ INFEASIBLE,
+ /// The problem has optimal solution (i.e. it is feasible and
+ /// bounded), and the algorithm has found optimal flow and node
+ /// potentials (primal and dual solutions).
+ OPTIMAL,
+ /// The objective function of the problem is unbounded, i.e.
+ /// there is a directed cycle having negative total cost and
+ /// infinite upper bound.
+ UNBOUNDED,
+      /// The maximum number of iterations has been reached
+ MAX_ITER_REACHED
+ };
+
+ /// \brief Constants for selecting the type of the supply constraints.
+ ///
+ /// Enum type containing constants for selecting the supply type,
+ /// i.e. the direction of the inequalities in the supply/demand
+ /// constraints of the \ref min_cost_flow "minimum cost flow problem".
+ ///
+ /// The default supply type is \c GEQ, the \c LEQ type can be
+ /// selected using \ref supplyType().
+ /// The equality form is a special case of both supply types.
+ enum SupplyType {
+ /// This option means that there are <em>"greater or equal"</em>
+ /// supply/demand constraints in the definition of the problem.
+ GEQ,
+ /// This option means that there are <em>"less or equal"</em>
+ /// supply/demand constraints in the definition of the problem.
+ LEQ
+ };
+
+
+
+ private:
+
+ int max_iter;
+ TEMPLATE_DIGRAPH_TYPEDEFS(GR);
+
+ typedef std::vector<int> IntVector;
+ typedef std::vector<NodesType> UHalfIntVector;
+ typedef std::vector<Value> ValueVector;
+ typedef std::vector<Cost> CostVector;
+ // typedef SparseValueVector<Cost> CostVector;
+ typedef std::vector<char> BoolVector;
+ // Note: vector<char> is used instead of vector<bool> for efficiency reasons
+
+ // State constants for arcs
+ enum ArcState {
+ STATE_UPPER = -1,
+ STATE_TREE = 0,
+ STATE_LOWER = 1
+ };
+
+ typedef std::vector<signed char> StateVector;
+ // Note: vector<signed char> is used instead of vector<ArcState> for
+ // efficiency reasons
+
+ private:
+
+ // Data related to the underlying digraph
+ const GR &_graph;
+ int _node_num;
+ int _arc_num;
+ int _all_arc_num;
+ int _search_arc_num;
+
+ // Parameters of the problem
+ SupplyType _stype;
+ Value _sum_supply;
+
+ inline int _node_id(int n) const {return _node_num-n-1;} ;
+
+ //IntArcMap _arc_id;
+ UHalfIntVector _source;
+ UHalfIntVector _target;
+ bool _arc_mixing;
+ public:
+ // Node and arc data
+ CostVector _cost;
+ ValueVector _supply;
+ ValueVector _flow;
+ //SparseValueVector<Value> _flow;
+ CostVector _pi;
+
+
+ private:
+ // Data for storing the spanning tree structure
+ IntVector _parent;
+ IntVector _pred;
+ IntVector _thread;
+ IntVector _rev_thread;
+ IntVector _succ_num;
+ IntVector _last_succ;
+ IntVector _dirty_revs;
+ BoolVector _forward;
+ StateVector _state;
+ int _root;
+
+ // Temporary data used in the current pivot iteration
+ int in_arc, join, u_in, v_in, u_out, v_out;
+ int first, second, right, last;
+ int stem, par_stem, new_stem;
+ Value delta;
+
+ const Value MAX;
+
+ int mixingCoeff;
+
+ public:
+
+ /// \brief Constant for infinite upper bounds (capacities).
+ ///
+ /// Constant for infinite upper bounds (capacities).
+ /// It is \c std::numeric_limits<Value>::infinity() if available,
+ /// \c std::numeric_limits<Value>::max() otherwise.
+ const Value INF;
+
+ private:
+
+ // thank you to DVK and MizardX from StackOverflow for this function!
+ inline int sequence(int k) const {
+ int smallv = (k > num_total_big_subsequence_numbers) & 1;
+
+ k -= num_total_big_subsequence_numbers * smallv;
+ int subsequence_length2 = subsequence_length- smallv;
+            int subsequence_num = (k / subsequence_length2) + num_big_subsequences * smallv;
+ int subsequence_offset = (k % subsequence_length2) * mixingCoeff;
+
+ return subsequence_offset + subsequence_num;
+ }
+ int subsequence_length;
+        int num_big_subsequences;
+ int num_total_big_subsequence_numbers;
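+
+        // Note: sequence() above inverts the interleaved ("mixed") arc
+        // storage order built in reset(): the arcs are distributed over
+        // mixingCoeff subsequences, so that consecutive candidates scanned
+        // by the pivot rule are spread across the whole arc set rather
+        // than being consecutive entries of the cost matrix.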
+
+ inline int getArcID(const Arc &arc) const
+ {
+ //int n = _arc_num-arc._id-1;
+ int n = _arc_num-GR::id(arc)-1;
+
+ //int a = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff;
+ //int b = _arc_id[arc];
+ if (_arc_mixing)
+ return sequence(n);
+ else
+ return n;
+ }
+
+ // finally unused because too slow
+ inline int getSource(const int arc) const
+ {
+ //int a = _source[arc];
+ //return a;
+
+ int n = _arc_num-arc-1;
+ if (_arc_mixing)
+ n = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff;
+
+ int b;
+ if (n>=0)
+ b = _node_id(_graph.source(GR::arcFromId( n ) ));
+ else
+ {
+ n = arc+1-_arc_num;
+ if ( n<=_node_num)
+ b = _node_num;
+ else
+ if ( n>=_graph._n1)
+ b = _graph._n1;
+ else
+ b = _graph._n1-n;
+ }
+
+ return b;
+ }
+
+
+
+ // Implementation of the Block Search pivot rule
+ class BlockSearchPivotRule
+ {
+ private:
+
+ // References to the NetworkSimplexSimple class
+ const UHalfIntVector &_source;
+ const UHalfIntVector &_target;
+ const CostVector &_cost;
+ const StateVector &_state;
+ const CostVector &_pi;
+ int &_in_arc;
+ int _search_arc_num;
+
+ // Pivot rule data
+ int _block_size;
+ int _next_arc;
+ NetworkSimplexSimple &_ns;
+
+ public:
+
+ // Constructor
+ BlockSearchPivotRule(NetworkSimplexSimple &ns) :
+ _source(ns._source), _target(ns._target),
+ _cost(ns._cost), _state(ns._state), _pi(ns._pi),
+ _in_arc(ns.in_arc), _search_arc_num(ns._search_arc_num),
+ _next_arc(0),_ns(ns)
+ {
+ // The main parameters of the pivot rule
+ const double BLOCK_SIZE_FACTOR = 1.0;
+ const int MIN_BLOCK_SIZE = 10;
+
+ _block_size = std::max( int(BLOCK_SIZE_FACTOR *
+ std::sqrt(double(_search_arc_num))),
+ MIN_BLOCK_SIZE );
+ }
+ // Find next entering arc
+ bool findEnteringArc() {
+ Cost c, min = 0;
+ int e;
+ int cnt = _block_size;
+ double a;
+ for (e = _next_arc; e != _search_arc_num; ++e) {
+ c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]);
+ if (c < min) {
+ min = c;
+ _in_arc = e;
+ }
+ if (--cnt == 0) {
+ a=fabs(_pi[_source[_in_arc]])>fabs(_pi[_target[_in_arc]]) ? fabs(_pi[_source[_in_arc]]):fabs(_pi[_target[_in_arc]]);
+ a=a>fabs(_cost[_in_arc])?a:fabs(_cost[_in_arc]);
+ if (min < -EPSILON*a) goto search_end;
+ cnt = _block_size;
+ }
+ }
+ for (e = 0; e != _next_arc; ++e) {
+ c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]);
+ if (c < min) {
+ min = c;
+ _in_arc = e;
+ }
+ if (--cnt == 0) {
+ a=fabs(_pi[_source[_in_arc]])>fabs(_pi[_target[_in_arc]]) ? fabs(_pi[_source[_in_arc]]):fabs(_pi[_target[_in_arc]]);
+ a=a>fabs(_cost[_in_arc])?a:fabs(_cost[_in_arc]);
+ if (min < -EPSILON*a) goto search_end;
+ cnt = _block_size;
+ }
+ }
+ a=fabs(_pi[_source[_in_arc]])>fabs(_pi[_target[_in_arc]]) ? fabs(_pi[_source[_in_arc]]):fabs(_pi[_target[_in_arc]]);
+ a=a>fabs(_cost[_in_arc])?a:fabs(_cost[_in_arc]);
+ if (min >= -EPSILON*a) return false;
+
+ search_end:
+ _next_arc = e;
+ return true;
+ }
+
+ }; //class BlockSearchPivotRule
+
+
+
+ public:
+
+
+
+ int _init_nb_nodes;
+ long long _init_nb_arcs;
+
+ /// \name Parameters
+ /// The parameters of the algorithm can be specified using these
+ /// functions.
+
+ /// @{
+
+
+ /// \brief Set the costs of the arcs.
+ ///
+ /// This function sets the costs of the arcs.
+ /// If it is not used before calling \ref run(), the costs
+ /// will be set to \c 1 on all arcs.
+ ///
+ /// \param map An arc map storing the costs.
+ /// Its \c Value type must be convertible to the \c Cost type
+ /// of the algorithm.
+ ///
+ /// \return <tt>(*this)</tt>
+ template<typename CostMap>
+ NetworkSimplexSimple& costMap(const CostMap& map) {
+ Arc a; _graph.first(a);
+ for (; a != INVALID; _graph.next(a)) {
+ _cost[getArcID(a)] = map[a];
+ }
+ return *this;
+ }
+
+
+    /// \brief Set the cost of one arc.
+    ///
+    /// This function sets the cost of a single arc.
+    /// It is provided for memory reasons, to avoid building a full cost map.
+    ///
+    /// \param arc An arc.
+    /// \param cost The new cost value.
+ ///
+ /// \return <tt>(*this)</tt>
+ template<typename Value>
+ NetworkSimplexSimple& setCost(const Arc& arc, const Value cost) {
+ _cost[getArcID(arc)] = cost;
+ return *this;
+ }
+
+
+ /// \brief Set the supply values of the nodes.
+ ///
+ /// This function sets the supply values of the nodes.
+ /// If neither this function nor \ref stSupply() is used before
+ /// calling \ref run(), the supply of each node will be set to zero.
+ ///
+ /// \param map A node map storing the supply values.
+ /// Its \c Value type must be convertible to the \c Value type
+ /// of the algorithm.
+ ///
+ /// \return <tt>(*this)</tt>
+ template<typename SupplyMap>
+ NetworkSimplexSimple& supplyMap(const SupplyMap& map) {
+ Node n; _graph.first(n);
+ for (; n != INVALIDNODE; _graph.next(n)) {
+ _supply[_node_id(n)] = map[n];
+ }
+ return *this;
+ }
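+
+    /// \brief Set the supply values of the nodes from two separate maps.
+    ///
+    /// Same as \ref supplyMap() above, but the supplies of the two parts
+    /// of the bipartite graph are given in two separate arrays:
+    /// \c map1 for the first \c n1 nodes and \c map2 for the remaining ones.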
+ template<typename SupplyMap>
+ NetworkSimplexSimple& supplyMap(const SupplyMap* map1, int n1, const SupplyMap* map2, int n2) {
+ Node n; _graph.first(n);
+ for (; n != INVALIDNODE; _graph.next(n)) {
+ if (n<n1)
+ _supply[_node_id(n)] = map1[n];
+ else
+ _supply[_node_id(n)] = map2[n-n1];
+ }
+ return *this;
+ }
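+
+    /// \brief Set constant supply values for the two parts of the graph.
+    ///
+    /// Assigns the constant supply \c val1 to each of the first \c n1 nodes
+    /// and \c val2 to the remaining nodes.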
+ template<typename SupplyMap>
+ NetworkSimplexSimple& supplyMapAll(SupplyMap val1, int n1, SupplyMap val2, int n2) {
+ Node n; _graph.first(n);
+ for (; n != INVALIDNODE; _graph.next(n)) {
+ if (n<n1)
+ _supply[_node_id(n)] = val1;
+ else
+ _supply[_node_id(n)] = val2;
+ }
+ return *this;
+ }
+
+ /// \brief Set single source and target nodes and a supply value.
+ ///
+ /// This function sets a single source node and a single target node
+ /// and the required flow value.
+ /// If neither this function nor \ref supplyMap() is used before
+ /// calling \ref run(), the supply of each node will be set to zero.
+ ///
+ /// Using this function has the same effect as using \ref supplyMap()
+ /// with such a map in which \c k is assigned to \c s, \c -k is
+ /// assigned to \c t and all other nodes have zero supply value.
+ ///
+ /// \param s The source node.
+ /// \param t The target node.
+ /// \param k The required amount of flow from node \c s to node \c t
+ /// (i.e. the supply of \c s and the demand of \c t).
+ ///
+ /// \return <tt>(*this)</tt>
+ NetworkSimplexSimple& stSupply(const Node& s, const Node& t, Value k) {
+ for (int i = 0; i != _node_num; ++i) {
+ _supply[i] = 0;
+ }
+ _supply[_node_id(s)] = k;
+ _supply[_node_id(t)] = -k;
+ return *this;
+ }
+
+ /// \brief Set the type of the supply constraints.
+ ///
+ /// This function sets the type of the supply/demand constraints.
+ /// If it is not used before calling \ref run(), the \ref GEQ supply
+ /// type will be used.
+ ///
+ /// For more information, see \ref SupplyType.
+ ///
+ /// \return <tt>(*this)</tt>
+ NetworkSimplexSimple& supplyType(SupplyType supply_type) {
+ _stype = supply_type;
+ return *this;
+ }
+
+ /// @}
+
+ /// \name Execution Control
+ /// The algorithm can be executed using \ref run().
+
+ /// @{
+
+ /// \brief Run the algorithm.
+ ///
+ /// This function runs the algorithm.
+    /// The parameters can be specified using functions \ref costMap(),
+    /// \ref setCost(), \ref supplyMap(), \ref stSupply() and
+    /// \ref supplyType().
+    /// For example,
+    /// \code
+    ///   NetworkSimplexSimple<FullBipartiteDigraph> ns(graph, arc_mixing, n, m, max_iter);
+    ///   ns.costMap(cost).supplyMap(sup).run();
+    /// \endcode
+ ///
+ /// This function can be called more than once. All the given parameters
+ /// are kept for the next call, unless \ref resetParams() or \ref reset()
+ /// is used, thus only the modified parameters have to be set again.
+ /// If the underlying digraph was also modified after the construction
+ /// of the class (or the last \ref reset() call), then the \ref reset()
+ /// function must be called.
+    ///
+ /// \return \c INFEASIBLE if no feasible flow exists,
+ /// \n \c OPTIMAL if the problem has optimal solution
+ /// (i.e. it is feasible and bounded), and the algorithm has found
+ /// optimal flow and node potentials (primal and dual solutions),
+ /// \n \c UNBOUNDED if the objective function of the problem is
+ /// unbounded, i.e. there is a directed cycle having negative total
+ /// cost and infinite upper bound.
+ ///
+    /// \see ProblemType
+ /// \see resetParams(), reset()
+ ProblemType run() {
+#if DEBUG_LVL>0
+            std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED = " << MAX_ITER_REACHED << "\n";
+#endif
+
+ if (!init()) return INFEASIBLE;
+#if DEBUG_LVL>0
+ std::cout << "Init done, starting iterations\n";
+#endif
+ return start();
+ }
+
+ /// \brief Reset all the parameters that have been given before.
+ ///
+    /// This function resets all the parameters that have been given
+ /// before using functions \ref lowerMap(), \ref upperMap(),
+ /// \ref costMap(), \ref supplyMap(), \ref stSupply(), \ref supplyType().
+ ///
+ /// It is useful for multiple \ref run() calls. Basically, all the given
+ /// parameters are kept for the next \ref run() call, unless
+ /// \ref resetParams() or \ref reset() is used.
+ /// If the underlying digraph was also modified after the construction
+ /// of the class or the last \ref reset() call, then the \ref reset()
+ /// function must be used, otherwise \ref resetParams() is sufficient.
+ ///
+    /// For example,
+    /// \code
+    ///   NetworkSimplexSimple<FullBipartiteDigraph> ns(graph, arc_mixing, n, m, max_iter);
+    ///
+    ///   // First run
+    ///   ns.costMap(cost).supplyMap(sup).run();
+    ///
+    ///   // Run again with modified cost map (resetParams() is not called,
+    ///   // so only the cost map has to be set again)
+    ///   cost[e] += 100;
+    ///   ns.costMap(cost).run();
+    ///
+    ///   // Run again from scratch using resetParams()
+    ///   // (all costs and supplies are reset to their default values)
+    ///   ns.resetParams();
+    ///   ns.costMap(cost).supplyMap(sup).run();
+    /// \endcode
+ ///
+ /// \return <tt>(*this)</tt>
+ ///
+ /// \see reset(), run()
+ NetworkSimplexSimple& resetParams() {
+ for (int i = 0; i != _node_num; ++i) {
+ _supply[i] = 0;
+ }
+ for (int i = 0; i != _arc_num; ++i) {
+ _cost[i] = 1;
+ }
+ _stype = GEQ;
+ return *this;
+ }
+
+
+
+    // Integer division helper (apparently unused in this file).
+    int divid(int x, int y)
+    {
+      return (x - x % y) / y;
+    }
+
+ /// \brief Reset the internal data structures and all the parameters
+ /// that have been given before.
+ ///
+ /// This function resets the internal data structures and all the
+    /// paramaters that have been given before using functions \ref lowerMap(),
+ /// \ref upperMap(), \ref costMap(), \ref supplyMap(), \ref stSupply(),
+ /// \ref supplyType().
+ ///
+ /// It is useful for multiple \ref run() calls. Basically, all the given
+ /// parameters are kept for the next \ref run() call, unless
+ /// \ref resetParams() or \ref reset() is used.
+ /// If the underlying digraph was also modified after the construction
+ /// of the class or the last \ref reset() call, then the \ref reset()
+ /// function must be used, otherwise \ref resetParams() is sufficient.
+ ///
+ /// See \ref resetParams() for examples.
+ ///
+ /// \return <tt>(*this)</tt>
+ ///
+ /// \see resetParams(), run()
+ NetworkSimplexSimple& reset() {
+ // Resize vectors
+ _node_num = _init_nb_nodes;
+ _arc_num = _init_nb_arcs;
+ int all_node_num = _node_num + 1;
+ int max_arc_num = _arc_num + 2 * _node_num;
+
+ _source.resize(max_arc_num);
+ _target.resize(max_arc_num);
+
+ _cost.resize(max_arc_num);
+ _supply.resize(all_node_num);
+ _flow.resize(max_arc_num);
+ _pi.resize(all_node_num);
+
+ _parent.resize(all_node_num);
+ _pred.resize(all_node_num);
+ _forward.resize(all_node_num);
+ _thread.resize(all_node_num);
+ _rev_thread.resize(all_node_num);
+ _succ_num.resize(all_node_num);
+ _last_succ.resize(all_node_num);
+ _state.resize(max_arc_num);
+
+
+ //_arc_mixing=false;
+ if (_arc_mixing) {
+ // Store the arcs in a mixed order
+ int k = std::max(int(std::sqrt(double(_arc_num))), 10);
+ mixingCoeff = k;
+ subsequence_length = _arc_num / mixingCoeff + 1;
+                num_big_subsequences = _arc_num % mixingCoeff;
+ num_total_big_subsequence_numbers = subsequence_length * num_big_subseqiences;
+
+ int i = 0, j = 0;
+ Arc a; _graph.first(a);
+ for (; a != INVALID; _graph.next(a)) {
+ _source[i] = _node_id(_graph.source(a));
+ _target[i] = _node_id(_graph.target(a));
+ //_arc_id[a] = i;
+ if ((i += k) >= _arc_num) i = ++j;
+ }
+ } else {
+ // Store the arcs in the original order
+ int i = 0;
+ Arc a; _graph.first(a);
+ for (; a != INVALID; _graph.next(a), ++i) {
+ _source[i] = _node_id(_graph.source(a));
+ _target[i] = _node_id(_graph.target(a));
+ //_arc_id[a] = i;
+ }
+ }
+
+ // Reset parameters
+ resetParams();
+ return *this;
+ }
+
+ /// @}
+
+ /// \name Query Functions
+ /// The results of the algorithm can be obtained using these
+ /// functions.\n
+ /// The \ref run() function must be called before using them.
+
+ /// @{
+
+ /// \brief Return the total cost of the found flow.
+ ///
+ /// This function returns the total cost of the found flow.
+ /// Its complexity is O(e).
+ ///
+ /// \note The return type of the function can be specified as a
+ /// template parameter. For example,
+ /// \code
+ /// ns.totalCost<double>();
+ /// \endcode
+ /// It is useful if the total cost cannot be stored in the \c Cost
+ /// type of the algorithm, which is the default return type of the
+ /// function.
+ ///
+ /// \pre \ref run() must be called before using this function.
+ /*template <typename Number>
+ Number totalCost() const {
+ Number c = 0;
+ for (ArcIt a(_graph); a != INVALID; ++a) {
+ int i = getArcID(a);
+ c += Number(_flow[i]) * Number(_cost[i]);
+ }
+ return c;
+ }*/
+
+ template <typename Number>
+ Number totalCost() const {
+ Number c = 0;
+
+ /*#ifdef HASHMAP
+ typename stdext::hash_map<int, Value>::const_iterator it;
+ #else
+ typename std::map<int, Value>::const_iterator it;
+ #endif
+ for (it = _flow.data.begin(); it!=_flow.data.end(); ++it)
+ c += Number(it->second) * Number(_cost[it->first]);
+ return c;*/
+
+            for (size_t i = 0; i < _flow.size(); i++)
+                c += Number(_flow[i]) * Number(_cost[i]);
+            return c;
+
+ }
+
+#ifndef DOXYGEN
+ Cost totalCost() const {
+ return totalCost<Cost>();
+ }
+#endif
+
+ /// \brief Return the flow on the given arc.
+ ///
+ /// This function returns the flow on the given arc.
+ ///
+ /// \pre \ref run() must be called before using this function.
+ Value flow(const Arc& a) const {
+ return _flow[getArcID(a)];
+ }
+
+ /// \brief Return the flow map (the primal solution).
+ ///
+ /// This function copies the flow value on each arc into the given
+ /// map. The \c Value type of the algorithm must be convertible to
+ /// the \c Value type of the map.
+ ///
+ /// \pre \ref run() must be called before using this function.
+ template <typename FlowMap>
+ void flowMap(FlowMap &map) const {
+ Arc a; _graph.first(a);
+ for (; a != INVALID; _graph.next(a)) {
+ map.set(a, _flow[getArcID(a)]);
+ }
+ }
+
+ /// \brief Return the potential (dual value) of the given node.
+ ///
+ /// This function returns the potential (dual value) of the
+ /// given node.
+ ///
+ /// \pre \ref run() must be called before using this function.
+ Cost potential(const Node& n) const {
+ return _pi[_node_id(n)];
+ }
+
+ /// \brief Return the potential map (the dual solution).
+ ///
+ /// This function copies the potential (dual value) of each node
+ /// into the given map.
+ /// The \c Cost type of the algorithm must be convertible to the
+ /// \c Value type of the map.
+ ///
+ /// \pre \ref run() must be called before using this function.
+ template <typename PotentialMap>
+ void potentialMap(PotentialMap &map) const {
+ Node n; _graph.first(n);
+ for (; n != INVALID; _graph.next(n)) {
+ map.set(n, _pi[_node_id(n)]);
+ }
+ }
+
+ /// @}
+
+ private:
+
+ // Initialize internal data structures
+ bool init() {
+ if (_node_num == 0) return false;
+
+ // Check the sum of supply values
+ _sum_supply = 0;
+ for (int i = 0; i != _node_num; ++i) {
+ _sum_supply += _supply[i];
+ }
+ if ( fabs(_sum_supply) > _EPSILON ) return false;
+
+ _sum_supply = 0;
+
+      // Initialize the artificial cost (chosen larger than any real arc cost)
+ Cost ART_COST;
+ if (std::numeric_limits<Cost>::is_exact) {
+ ART_COST = std::numeric_limits<Cost>::max() / 2 + 1;
+ } else {
+ ART_COST = 0;
+ for (int i = 0; i != _arc_num; ++i) {
+ if (_cost[i] > ART_COST) ART_COST = _cost[i];
+ }
+ ART_COST = (ART_COST + 1) * _node_num;
+ }
+
+ // Initialize arc maps
+ for (int i = 0; i != _arc_num; ++i) {
+ //_flow[i] = 0; //by default, the sparse matrix is empty
+ _state[i] = STATE_LOWER;
+ }
+
+ // Set data for the artificial root node
+ _root = _node_num;
+ _parent[_root] = -1;
+ _pred[_root] = -1;
+ _thread[_root] = 0;
+ _rev_thread[0] = _root;
+ _succ_num[_root] = _node_num + 1;
+ _last_succ[_root] = _root - 1;
+ _supply[_root] = -_sum_supply;
+ _pi[_root] = 0;
+
+ // Add artificial arcs and initialize the spanning tree data structure
+ if (_sum_supply == 0) {
+ // EQ supply constraints
+ _search_arc_num = _arc_num;
+ _all_arc_num = _arc_num + _node_num;
+ for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
+ _parent[u] = _root;
+ _pred[u] = e;
+ _thread[u] = u + 1;
+ _rev_thread[u + 1] = u;
+ _succ_num[u] = 1;
+ _last_succ[u] = u;
+ _state[e] = STATE_TREE;
+ if (_supply[u] >= 0) {
+ _forward[u] = true;
+ _pi[u] = 0;
+ _source[e] = u;
+ _target[e] = _root;
+ _flow[e] = _supply[u];
+ _cost[e] = 0;
+ } else {
+ _forward[u] = false;
+ _pi[u] = ART_COST;
+ _source[e] = _root;
+ _target[e] = u;
+ _flow[e] = -_supply[u];
+ _cost[e] = ART_COST;
+ }
+ }
+ }
+ else if (_sum_supply > 0) {
+ // LEQ supply constraints
+ _search_arc_num = _arc_num + _node_num;
+ int f = _arc_num + _node_num;
+ for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
+ _parent[u] = _root;
+ _thread[u] = u + 1;
+ _rev_thread[u + 1] = u;
+ _succ_num[u] = 1;
+ _last_succ[u] = u;
+ if (_supply[u] >= 0) {
+ _forward[u] = true;
+ _pi[u] = 0;
+ _pred[u] = e;
+ _source[e] = u;
+ _target[e] = _root;
+ _flow[e] = _supply[u];
+ _cost[e] = 0;
+ _state[e] = STATE_TREE;
+ } else {
+ _forward[u] = false;
+ _pi[u] = ART_COST;
+ _pred[u] = f;
+ _source[f] = _root;
+ _target[f] = u;
+ _flow[f] = -_supply[u];
+ _cost[f] = ART_COST;
+ _state[f] = STATE_TREE;
+ _source[e] = u;
+ _target[e] = _root;
+ //_flow[e] = 0; //by default, the sparse matrix is empty
+ _cost[e] = 0;
+ _state[e] = STATE_LOWER;
+ ++f;
+ }
+ }
+ _all_arc_num = f;
+ }
+ else {
+ // GEQ supply constraints
+ _search_arc_num = _arc_num + _node_num;
+ int f = _arc_num + _node_num;
+ for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
+ _parent[u] = _root;
+ _thread[u] = u + 1;
+ _rev_thread[u + 1] = u;
+ _succ_num[u] = 1;
+ _last_succ[u] = u;
+ if (_supply[u] <= 0) {
+ _forward[u] = false;
+ _pi[u] = 0;
+ _pred[u] = e;
+ _source[e] = _root;
+ _target[e] = u;
+ _flow[e] = -_supply[u];
+ _cost[e] = 0;
+ _state[e] = STATE_TREE;
+ } else {
+ _forward[u] = true;
+ _pi[u] = -ART_COST;
+ _pred[u] = f;
+ _source[f] = u;
+ _target[f] = _root;
+ _flow[f] = _supply[u];
+ _state[f] = STATE_TREE;
+ _cost[f] = ART_COST;
+ _source[e] = _root;
+ _target[e] = u;
+ //_flow[e] = 0; //by default, the sparse matrix is empty
+ _cost[e] = 0;
+ _state[e] = STATE_LOWER;
+ ++f;
+ }
+ }
+ _all_arc_num = f;
+ }
+
+ return true;
+ }
+
+        // Find the join node, i.e. the lowest common ancestor of the two
+        // endpoints of the entering arc in the current spanning tree
+ void findJoinNode() {
+ int u = _source[in_arc];
+ int v = _target[in_arc];
+ while (u != v) {
+ if (_succ_num[u] < _succ_num[v]) {
+ u = _parent[u];
+ } else {
+ v = _parent[v];
+ }
+ }
+ join = u;
+ }
+
+ // Find the leaving arc of the cycle and returns true if the
+ // leaving arc is not the same as the entering arc
+ bool findLeavingArc() {
+ // Initialize first and second nodes according to the direction
+ // of the cycle
+ if (_state[in_arc] == STATE_LOWER) {
+ first = _source[in_arc];
+ second = _target[in_arc];
+ } else {
+ first = _target[in_arc];
+ second = _source[in_arc];
+ }
+ delta = INF;
+ int result = 0;
+ Value d;
+ int e;
+
+        // Search the cycle along the path from the first node to the root
+ for (int u = first; u != join; u = _parent[u]) {
+ e = _pred[u];
+ d = _forward[u] ? _flow[e] : INF ;
+ if (d < delta) {
+ delta = d;
+ u_out = u;
+ result = 1;
+ }
+ }
+        // Search the cycle along the path from the second node to the root
+ for (int u = second; u != join; u = _parent[u]) {
+ e = _pred[u];
+ d = _forward[u] ? INF : _flow[e];
+ if (d <= delta) {
+ delta = d;
+ u_out = u;
+ result = 2;
+ }
+ }
+
+ if (result == 1) {
+ u_in = first;
+ v_in = second;
+ } else {
+ u_in = second;
+ v_in = first;
+ }
+ return result != 0;
+ }
+
+ // Change _flow and _state vectors
+ void changeFlow(bool change) {
+ // Augment along the cycle
+ if (delta > 0) {
+ Value val = _state[in_arc] * delta;
+ _flow[in_arc] += val;
+ for (int u = _source[in_arc]; u != join; u = _parent[u]) {
+ _flow[_pred[u]] += _forward[u] ? -val : val;
+ }
+ for (int u = _target[in_arc]; u != join; u = _parent[u]) {
+ _flow[_pred[u]] += _forward[u] ? val : -val;
+ }
+ }
+ // Update the state of the entering and leaving arcs
+ if (change) {
+ _state[in_arc] = STATE_TREE;
+ _state[_pred[u_out]] =
+ (_flow[_pred[u_out]] == 0) ? STATE_LOWER : STATE_UPPER;
+ } else {
+ _state[in_arc] = -_state[in_arc];
+ }
+ }
+
+ // Update the tree structure
+ void updateTreeStructure() {
+ int u, w;
+ int old_rev_thread = _rev_thread[u_out];
+ int old_succ_num = _succ_num[u_out];
+ int old_last_succ = _last_succ[u_out];
+ v_out = _parent[u_out];
+
+ u = _last_succ[u_in]; // the last successor of u_in
+ right = _thread[u]; // the node after it
+
+ // Handle the case when old_rev_thread equals to v_in
+ // (it also means that join and v_out coincide)
+ if (old_rev_thread == v_in) {
+ last = _thread[_last_succ[u_out]];
+ } else {
+ last = _thread[v_in];
+ }
+
+ // Update _thread and _parent along the stem nodes (i.e. the nodes
+ // between u_in and u_out, whose parent have to be changed)
+ _thread[v_in] = stem = u_in;
+ _dirty_revs.clear();
+ _dirty_revs.push_back(v_in);
+ par_stem = v_in;
+ while (stem != u_out) {
+ // Insert the next stem node into the thread list
+ new_stem = _parent[stem];
+ _thread[u] = new_stem;
+ _dirty_revs.push_back(u);
+
+ // Remove the subtree of stem from the thread list
+ w = _rev_thread[stem];
+ _thread[w] = right;
+ _rev_thread[right] = w;
+
+ // Change the parent node and shift stem nodes
+ _parent[stem] = par_stem;
+ par_stem = stem;
+ stem = new_stem;
+
+ // Update u and right
+ u = _last_succ[stem] == _last_succ[par_stem] ?
+ _rev_thread[par_stem] : _last_succ[stem];
+ right = _thread[u];
+ }
+ _parent[u_out] = par_stem;
+ _thread[u] = last;
+ _rev_thread[last] = u;
+ _last_succ[u_out] = u;
+
+ // Remove the subtree of u_out from the thread list except for
+ // the case when old_rev_thread equals to v_in
+ // (it also means that join and v_out coincide)
+ if (old_rev_thread != v_in) {
+ _thread[old_rev_thread] = right;
+ _rev_thread[right] = old_rev_thread;
+ }
+
+ // Update _rev_thread using the new _thread values
+ for (int i = 0; i != int(_dirty_revs.size()); ++i) {
+ u = _dirty_revs[i];
+ _rev_thread[_thread[u]] = u;
+ }
+
+ // Update _pred, _forward, _last_succ and _succ_num for the
+ // stem nodes from u_out to u_in
+ int tmp_sc = 0, tmp_ls = _last_succ[u_out];
+ u = u_out;
+ while (u != u_in) {
+ w = _parent[u];
+ _pred[u] = _pred[w];
+ _forward[u] = !_forward[w];
+ tmp_sc += _succ_num[u] - _succ_num[w];
+ _succ_num[u] = tmp_sc;
+ _last_succ[w] = tmp_ls;
+ u = w;
+ }
+ _pred[u_in] = in_arc;
+ _forward[u_in] = (u_in == _source[in_arc]);
+ _succ_num[u_in] = old_succ_num;
+
+            // Set limits for updating _last_succ from v_in and v_out
+ // towards the root
+ int up_limit_in = -1;
+ int up_limit_out = -1;
+ if (_last_succ[join] == v_in) {
+ up_limit_out = join;
+ } else {
+ up_limit_in = join;
+ }
+
+ // Update _last_succ from v_in towards the root
+ for (u = v_in; u != up_limit_in && _last_succ[u] == v_in;
+ u = _parent[u]) {
+ _last_succ[u] = _last_succ[u_out];
+ }
+ // Update _last_succ from v_out towards the root
+ if (join != old_rev_thread && v_in != old_rev_thread) {
+ for (u = v_out; u != up_limit_out && _last_succ[u] == old_last_succ;
+ u = _parent[u]) {
+ _last_succ[u] = old_rev_thread;
+ }
+ } else {
+ for (u = v_out; u != up_limit_out && _last_succ[u] == old_last_succ;
+ u = _parent[u]) {
+ _last_succ[u] = _last_succ[u_out];
+ }
+ }
+
+ // Update _succ_num from v_in to join
+ for (u = v_in; u != join; u = _parent[u]) {
+ _succ_num[u] += old_succ_num;
+ }
+ // Update _succ_num from v_out to join
+ for (u = v_out; u != join; u = _parent[u]) {
+ _succ_num[u] -= old_succ_num;
+ }
+ }
+
+ // Update potentials
+ void updatePotential() {
+ Cost sigma = _forward[u_in] ?
+ _pi[v_in] - _pi[u_in] - _cost[_pred[u_in]] :
+ _pi[v_in] - _pi[u_in] + _cost[_pred[u_in]];
+ // Update potentials in the subtree, which has been moved
+ int end = _thread[_last_succ[u_in]];
+ for (int u = u_in; u != end; u = _thread[u]) {
+ _pi[u] += sigma;
+ }
+ }
+
+ // Heuristic initial pivots
+ bool initialPivots() {
+ Value curr, total = 0;
+ std::vector<Node> supply_nodes, demand_nodes;
+ Node u; _graph.first(u);
+ for (; u != INVALIDNODE; _graph.next(u)) {
+ curr = _supply[_node_id(u)];
+ if (curr > 0) {
+ total += curr;
+ supply_nodes.push_back(u);
+ }
+ else if (curr < 0) {
+ demand_nodes.push_back(u);
+ }
+ }
+ if (_sum_supply > 0) total -= _sum_supply;
+ if (total <= 0) return true;
+
+ IntVector arc_vector;
+ if (_sum_supply >= 0) {
+ if (supply_nodes.size() == 1 && demand_nodes.size() == 1) {
+ // Perform a reverse graph search from the sink to the source
+ //typename GR::template NodeMap<bool> reached(_graph, false);
+ BoolVector reached(_node_num, false);
+ Node s = supply_nodes[0], t = demand_nodes[0];
+ std::vector<Node> stack;
+ reached[t] = true;
+ stack.push_back(t);
+ while (!stack.empty()) {
+ Node u, v = stack.back();
+ stack.pop_back();
+ if (v == s) break;
+ Arc a; _graph.firstIn(a, v);
+ for (; a != INVALID; _graph.nextIn(a)) {
+ if (reached[u = _graph.source(a)]) continue;
+ int j = getArcID(a);
+ if (INF >= total) {
+ arc_vector.push_back(j);
+ reached[u] = true;
+ stack.push_back(u);
+ }
+ }
+ }
+ } else {
+                // Find the min. cost incoming arc for each demand node
+ for (int i = 0; i != int(demand_nodes.size()); ++i) {
+ Node v = demand_nodes[i];
+ Cost c, min_cost = std::numeric_limits<Cost>::max();
+ Arc min_arc = INVALID;
+ Arc a; _graph.firstIn(a, v);
+ for (; a != INVALID; _graph.nextIn(a)) {
+ c = _cost[getArcID(a)];
+ if (c < min_cost) {
+ min_cost = c;
+ min_arc = a;
+ }
+ }
+ if (min_arc != INVALID) {
+ arc_vector.push_back(getArcID(min_arc));
+ }
+ }
+ }
+ } else {
+ // Find the min. cost outgoing arc for each supply node
+ for (int i = 0; i != int(supply_nodes.size()); ++i) {
+ Node u = supply_nodes[i];
+ Cost c, min_cost = std::numeric_limits<Cost>::max();
+ Arc min_arc = INVALID;
+ Arc a; _graph.firstOut(a, u);
+ for (; a != INVALID; _graph.nextOut(a)) {
+ c = _cost[getArcID(a)];
+ if (c < min_cost) {
+ min_cost = c;
+ min_arc = a;
+ }
+ }
+ if (min_arc != INVALID) {
+ arc_vector.push_back(getArcID(min_arc));
+ }
+ }
+ }
+
+ // Perform heuristic initial pivots
+ for (int i = 0; i != int(arc_vector.size()); ++i) {
+ in_arc = arc_vector[i];
+                // the error is probably here...
+ if (_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -
+ _pi[_target[in_arc]]) >= 0) continue;
+ findJoinNode();
+ bool change = findLeavingArc();
+ if (delta >= MAX) return false;
+ changeFlow(change);
+ if (change) {
+ updateTreeStructure();
+ updatePotential();
+ }
+ }
+ return true;
+ }
+
+ // Execute the algorithm
+ ProblemType start() {
+ return start<BlockSearchPivotRule>();
+ }
+
+ template <typename PivotRuleImpl>
+ ProblemType start() {
+ PivotRuleImpl pivot(*this);
+ ProblemType retVal = OPTIMAL;
+
+ // Perform heuristic initial pivots
+ if (!initialPivots()) return UNBOUNDED;
+
+ int iter_number=0;
+ //pivot.setDantzig(true);
+ // Execute the Network Simplex algorithm
+ while (pivot.findEnteringArc()) {
+                if (max_iter > 0 && ++iter_number >= max_iter) {
+                    char errMess[1000];
+                    sprintf( errMess, "RESULT MIGHT BE INACCURATE\nMax number of iterations reached, currently %d. Sometimes iterations keep cycling even though the solution has been reached; to check whether that is the case here, have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution; if not, try setting the maximum number of iterations a bit higher.\n", iter_number );
+ std::cerr << errMess;
+ retVal = MAX_ITER_REACHED;
+ break;
+ }
+#if DEBUG_LVL>0
+ if(iter_number>MAX_DEBUG_ITER)
+ break;
+ if(iter_number%1000==0||iter_number%1000==1){
+ double curCost=totalCost();
+ double sumFlow=0;
+ double a;
+ a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]);
+ a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]);
+ for (int i=0; i<_flow.size(); i++) {
+ sumFlow+=_state[i]*_flow[i];
+ }
+ std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n";
+ std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n";
+ std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n";
+ std::cout << _cost[in_arc] << "\n";
+ std::cout << _pi[_source[in_arc]] << "\n";
+ std::cout << _pi[_target[in_arc]] << "\n";
+ std::cout << a << "\n";
+ }
+#endif
+
+ findJoinNode();
+ bool change = findLeavingArc();
+ if (delta >= MAX) return UNBOUNDED;
+ changeFlow(change);
+ if (change) {
+ updateTreeStructure();
+ updatePotential();
+ }
+#if DEBUG_LVL>0
+ else{
+ std::cout << "No change\n";
+ }
+#endif
+#if DEBUG_LVL>1
+ std::cout << "Arc in = (" << _source[in_arc] << ", " << _target[in_arc] << ")\n";
+#endif
+
+ }
+
+
+#if DEBUG_LVL>0
+ double curCost=totalCost();
+ double sumFlow=0;
+ double a;
+ a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]);
+ a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]);
+ for (int i=0; i<_flow.size(); i++) {
+ sumFlow+=_state[i]*_flow[i];
+ }
+
+            std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n";
+
+ std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n";
+ std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n";
+
+#endif
+
+#if DEBUG_LVL>1
+ sumFlow=0;
+ for (int i=0; i<_flow.size(); i++) {
+ sumFlow+=_state[i]*_flow[i];
+ if (_state[i]==STATE_TREE) {
+ std::cout << "Non zero value at (" << _node_num+1-_source[i] << ", " << _node_num+1-_target[i] << ")\n";
+ }
+ }
+            std::cout << "Sum of the flow " << sumFlow << "\n"<< iter_number <<" iterations, current cost=" << totalCost() << "\n";
+#endif
+ // Check feasibility
+ if( retVal == OPTIMAL){
+ for (int e = _search_arc_num; e != _all_arc_num; ++e) {
+ if (_flow[e] != 0){
+                    if (fabs(_flow[e]) > EPSILON)
+ return INFEASIBLE;
+ else
+ _flow[e]=0;
+
+ }
+ }
+ }
+
+ // Shift potentials to meet the requirements of the GEQ/LEQ type
+ // optimality conditions
+ if (_sum_supply == 0) {
+ if (_stype == GEQ) {
+ Cost max_pot = -std::numeric_limits<Cost>::max();
+ for (int i = 0; i != _node_num; ++i) {
+ if (_pi[i] > max_pot) max_pot = _pi[i];
+ }
+ if (max_pot > 0) {
+ for (int i = 0; i != _node_num; ++i)
+ _pi[i] -= max_pot;
+ }
+ } else {
+ Cost min_pot = std::numeric_limits<Cost>::max();
+ for (int i = 0; i != _node_num; ++i) {
+ if (_pi[i] < min_pot) min_pot = _pi[i];
+ }
+ if (min_pot < 0) {
+ for (int i = 0; i != _node_num; ++i)
+ _pi[i] -= min_pot;
+ }
+ }
+ }
+
+ return retVal;
+ }
+
+ }; //class NetworkSimplexSimple
+
+ ///@}
+
+} //namespace lemon
+
+#endif //LEMON_NETWORK_SIMPLEX_SIMPLE_H
diff --git a/ot/optim.py b/ot/optim.py
new file mode 100644
index 0000000..0abd9e9
--- /dev/null
+++ b/ot/optim.py
@@ -0,0 +1,440 @@
+# -*- coding: utf-8 -*-
+"""
+Optimization algorithms for OT
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Titouan Vayer <titouan.vayer@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+from scipy.optimize.linesearch import scalar_search_armijo
+from .lp import emd
+from .bregman import sinkhorn
+
+# The corresponding scipy function does not work for matrices
+
+
+def line_search_armijo(f, xk, pk, gfk, old_fval,
+ args=(), c1=1e-4, alpha0=0.99):
+ """
+ Armijo linesearch function that works with matrices
+
+    find an approximate minimum of f(xk+alpha*pk) that satisfies the
+    Armijo conditions.
+
+ Parameters
+ ----------
+ f : callable
+ loss function
+ xk : ndarray
+ initial position
+ pk : ndarray
+ descent direction
+ gfk : ndarray
+ gradient of f at xk
+ old_fval : float
+ loss value at xk
+ args : tuple, optional
+ arguments given to f
+ c1 : float, optional
+        c1 constant in the Armijo rule (>0)
+ alpha0 : float, optional
+ initial step (>0)
+
+ Returns
+ -------
+    alpha : float
+        step that satisfies the Armijo conditions
+    fc : int
+        number of function calls
+ fa : float
+ loss value at step alpha
+
+ """
+ xk = np.atleast_1d(xk)
+ fc = [0]
+
+ def phi(alpha1):
+ fc[0] += 1
+ return f(xk + alpha1 * pk, *args)
+
+ if old_fval is None:
+ phi0 = phi(0.)
+ else:
+ phi0 = old_fval
+
+ derphi0 = np.sum(pk * gfk) # Quickfix for matrices
+ alpha, phi1 = scalar_search_armijo(
+ phi, phi0, derphi0, c1=c1, alpha0=alpha0)
+
+ return alpha, fc[0], phi1
+
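+# A minimal sketch of calling line_search_armijo above on a toy quadratic
+# (the values below are illustrative, not part of the library):
+#
+#     f = lambda x: np.sum(x ** 2)
+#     xk = np.array([1.0, 1.0])
+#     gfk = 2 * xk                    # gradient of f at xk
+#     pk = -gfk                       # descent direction
+#     alpha, fc, fa = line_search_armijo(f, xk, pk, gfk, f(xk))
+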
+
+def solve_linesearch(cost, G, deltaG, Mi, f_val,
+ armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
+ """
+    Solve the linesearch in the FW (Frank-Wolfe) iterations
+
+    Parameters
+ ----------
+ cost : method
+ Cost in the FW for the linesearch
+ G : ndarray, shape(ns,nt)
+ The transport map at a given iteration of the FW
+ deltaG : ndarray (ns,nt)
+ Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
+ Mi : ndarray (ns,nt)
+ Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost
+ f_val : float
+ Value of the cost at G
+    armijo : bool, optional
+        If True, the step of the line search is found via an Armijo search.
+        Otherwise, a closed form is used. If there are convergence issues,
+        use False.
+ C1 : ndarray (ns,ns), optional
+ Structure matrix in the source domain. Only used and necessary when armijo=False
+ C2 : ndarray (nt,nt), optional
+ Structure matrix in the target domain. Only used and necessary when armijo=False
+ reg : float, optional
+ Regularization parameter. Only used and necessary when armijo=False
+ Gc : ndarray (ns,nt)
+ Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False
+ constC : ndarray (ns,nt)
+ Constant for the gromov cost. See [24]. Only used and necessary when armijo=False
+ M : ndarray (ns,nt), optional
+ Cost matrix between the features. Only used and necessary when armijo=False
+
+    Returns
+    -------
+    alpha : float
+        The optimal step size of the FW
+    fc : int
+        number of function calls (not used here)
+    f_val : float
+        The value of the cost for the next iteration
+
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+ if armijo:
+ alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
+    else:  # requires symmetric matrices
+ dot1 = np.dot(C1, deltaG)
+ dot12 = dot1.dot(C2)
+ a = -2 * reg * np.sum(dot12 * deltaG)
+ b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG))
+ c = cost(G)
+
+ alpha = solve_1d_linesearch_quad(a, b, c)
+ fc = None
+ f_val = cost(G + alpha * deltaG)
+
+ return alpha, fc, f_val
+
+
+def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
+ stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
+ """
+ Solve the general regularized OT problem with conditional gradient
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg*f(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) metric cost matrix
+    - :math:`f` is the regularization term (and df is its gradient)
+ - a and b are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is conditional gradient as discussed in [1]_
+
+
+ Parameters
+ ----------
+ a : ndarray, shape (ns,)
+ samples weights in the source domain
+ b : ndarray, shape (nt,)
+        samples weights in the target domain
+ M : ndarray, shape (ns, nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ G0 : ndarray, shape (ns,nt), optional
+ initial guess (default is indep joint density)
+ numItermax : int, optional
+ Max number of iterations
+    stopThr : float, optional
+        Stop threshold on the relative variation (>0)
+    stopThr2 : float, optional
+        Stop threshold on the absolute variation (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ **kwargs : dict
+ Parameters for linesearch
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
+
+ See Also
+ --------
+    ot.lp.emd : Unregularized optimal transport
+ ot.bregman.sinkhorn : Entropic regularized optimal transport
+
+ """
+
+ loop = 1
+
+ if log:
+ log = {'loss': []}
+
+ if G0 is None:
+ G = np.outer(a, b)
+ else:
+ G = G0
+
+ def cost(G):
+ return np.sum(M * G) + reg * f(G)
+
+ f_val = cost(G)
+ if log:
+ log['loss'].append(f_val)
+
+ it = 0
+
+ if verbose:
+ print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+ 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
+
+ while loop:
+
+ it += 1
+ old_fval = f_val
+
+ # problem linearization
+ Mi = M + reg * df(G)
+        # shift the linearized cost so it is nonnegative (adding a constant
+        # to the cost matrix does not change the optimal transport plan)
+        Mi += abs(Mi.min())
+
+ # solve linear program
+ Gc = emd(a, b, Mi)
+
+ deltaG = Gc - G
+
+ # line search
+ alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
+
+ G = G + alpha * deltaG
+
+ # test convergence
+ if it >= numItermax:
+ loop = 0
+
+ abs_delta_fval = abs(f_val - old_fval)
+ relative_delta_fval = abs_delta_fval / abs(f_val)
+ if relative_delta_fval < stopThr or abs_delta_fval < stopThr2:
+ loop = 0
+
+ if log:
+ log['loss'].append(f_val)
+
+ if verbose:
+ if it % 20 == 0:
+ print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+ 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
+
+ if log:
+ return G, log
+ else:
+ return G
+
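+# A minimal sketch of calling cg above with a quadratic (squared Frobenius)
+# regularizer, f(G) = 0.5 * ||G||_F^2, whose gradient is df(G) = G
+# (toy values, purely illustrative):
+#
+#     a = np.array([0.5, 0.5])
+#     b = np.array([0.5, 0.5])
+#     M = np.array([[0.0, 1.0], [1.0, 0.0]])
+#     f = lambda G: 0.5 * np.sum(G ** 2)
+#     df = lambda G: G
+#     G = cg(a, b, M, reg=1e-1, f=f, df=df)
+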
+
+def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
+ numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False):
+ """
+ Solve the general regularized OT problem with the generalized conditional gradient
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg1\cdot\Omega(\gamma) + reg2\cdot f(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+    - :math:`f` is the regularization term (and df is its gradient)
+ - a and b are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is the generalized conditional gradient as discussed in [5,7]_
+
+
+ Parameters
+ ----------
+ a : ndarray, shape (ns,)
+ samples weights in the source domain
+    b : ndarray, shape (nt,)
+        samples weights in the target domain
+ M : ndarray, shape (ns, nt)
+ loss matrix
+ reg1 : float
+ Entropic Regularization term >0
+ reg2 : float
+ Second Regularization term >0
+ G0 : ndarray, shape (ns, nt), optional
+ initial guess (default is indep joint density)
+ numItermax : int, optional
+ Max number of iterations
+ numInnerItermax : int, optional
+ Max number of iterations of Sinkhorn
+    stopThr : float, optional
+        Stop threshold on the relative variation (>0)
+    stopThr2 : float, optional
+        Stop threshold on the absolute variation (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ gamma : ndarray, shape (ns, nt)
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+ References
+ ----------
+ .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+ .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
+
+ See Also
+ --------
+ ot.optim.cg : conditional gradient
+
+ """
+
+ loop = 1
+
+ if log:
+ log = {'loss': []}
+
+ if G0 is None:
+ G = np.outer(a, b)
+ else:
+ G = G0
+
+ def cost(G):
+ return np.sum(M * G) + reg1 * np.sum(G * np.log(G)) + reg2 * f(G)
+
+ f_val = cost(G)
+ if log:
+ log['loss'].append(f_val)
+
+ it = 0
+
+ if verbose:
+ print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+ 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
+
+ while loop:
+
+ it += 1
+ old_fval = f_val
+
+ # problem linearization
+ Mi = M + reg2 * df(G)
+
+ # solve linear program with Sinkhorn
+ # Gc = sinkhorn_stabilized(a,b, Mi, reg1, numItermax = numInnerItermax)
+ Gc = sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax)
+
+ deltaG = Gc - G
+
+ # line search
+        # gradient of the full cost: Mi = M + reg2 * df(G) already holds the
+        # linear and f parts, and reg1 * (1 + log(G)) is the entropy gradient
+        dcost = Mi + reg1 * (1 + np.log(G))
+ alpha, fc, f_val = line_search_armijo(cost, G, deltaG, dcost, f_val)
+
+ G = G + alpha * deltaG
+
+ # test convergence
+ if it >= numItermax:
+ loop = 0
+
+ abs_delta_fval = abs(f_val - old_fval)
+ relative_delta_fval = abs_delta_fval / abs(f_val)
+
+ if relative_delta_fval < stopThr or abs_delta_fval < stopThr2:
+ loop = 0
+
+ if log:
+ log['loss'].append(f_val)
+
+ if verbose:
+ if it % 20 == 0:
+ print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+ 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
+
+ if log:
+ return G, log
+ else:
+ return G
+
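+# The same toy problem can be passed to gcg above, with reg1 weighting the
+# entropic term and reg2 weighting f (illustrative values):
+#
+#     G = gcg(a, b, M, reg1=1e-1, reg2=1e-1, f=f, df=df)
+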
+
+def solve_1d_linesearch_quad(a, b, c):
+ """
+    For any convex or non-convex 1d quadratic function f, solve the following problem on [0, 1]:
+
+    .. math::
+        \mathop{\arg \min}_{0 \leq x \leq 1} \quad f(x) = ax^{2} + bx + c
+
+ Parameters
+ ----------
+ a,b,c : float
+ The coefficients of the quadratic function
+
+ Returns
+ -------
+ x : float
+ The optimal value which leads to the minimal cost
+ """
+ f0 = c
+ df0 = b
+ f1 = a + f0 + df0
+
+ if a > 0: # convex
+ minimum = min(1, max(0, np.divide(-b, 2.0 * a)))
+ return minimum
+ else: # non convex
+ if f0 > f1:
+ return 1
+ else:
+ return 0
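+
+# For example, with f(x) = x**2 - x (a=1, b=-1, c=0) the function is convex
+# and its minimizer on [0, 1] is -b / (2a) = 0.5, so
+# solve_1d_linesearch_quad(1, -1, 0) returns 0.5.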
diff --git a/ot/plot.py b/ot/plot.py
new file mode 100644
index 0000000..f403e98
--- /dev/null
+++ b/ot/plot.py
@@ -0,0 +1,91 @@
+"""
+Functions for plotting OT matrices
+
+.. warning::
+    Note that by default the module is not imported in :mod:`ot`. In order to
+    use it you need to explicitly import :mod:`ot.plot`
+
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+from matplotlib import gridspec
+
+
+def plot1D_mat(a, b, M, title=''):
+ """ Plot matrix M with the source and target 1D distribution
+
+    Creates a subplot with the source distribution a on the left and
+    target distribution b on the top. The matrix M is shown in between.
+
+
+ Parameters
+ ----------
+ a : ndarray, shape (na,)
+ Source distribution
+ b : ndarray, shape (nb,)
+ Target distribution
+    M : ndarray, shape (na, nb)
+        Matrix to plot
+    title : str, optional
+        Title of the figure
+ """
+ na, nb = M.shape
+
+ gs = gridspec.GridSpec(3, 3)
+
+ xa = np.arange(na)
+ xb = np.arange(nb)
+
+ ax1 = pl.subplot(gs[0, 1:])
+ pl.plot(xb, b, 'r', label='Target distribution')
+ pl.yticks(())
+ pl.title(title)
+
+ ax2 = pl.subplot(gs[1:, 0])
+ pl.plot(a, xa, 'b', label='Source distribution')
+ pl.gca().invert_xaxis()
+ pl.gca().invert_yaxis()
+ pl.xticks(())
+
+ pl.subplot(gs[1:, 1:], sharex=ax1, sharey=ax2)
+ pl.imshow(M, interpolation='nearest')
+ pl.axis('off')
+
+ pl.xlim((0, nb))
+ pl.tight_layout()
+ pl.subplots_adjust(wspace=0., hspace=0.2)
+
+
+def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
+ """ Plot matrix M in 2D with lines using alpha values
+
+ Plot lines between source and target 2D samples with a color
+ proportional to the value of the matrix G between samples.
+
+
+ Parameters
+ ----------
+ xs : ndarray, shape (ns,2)
+ Source samples positions
+    xt : ndarray, shape (nt, 2)
+ Target samples positions
+    G : ndarray, shape (ns, nt)
+        OT matrix
+    thr : float, optional
+        relative threshold on G[i, j] / G.max() above which a line is drawn
+    **kwargs : dict
+        parameters given to the plot functions (default color is black if
+ nothing given)
+ """
+ if ('color' not in kwargs) and ('c' not in kwargs):
+ kwargs['color'] = 'k'
+ mx = G.max()
+ for i in range(xs.shape[0]):
+ for j in range(xt.shape[0]):
+ if G[i, j] / mx > thr:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]],
+ alpha=G[i, j] / mx, **kwargs)
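+
+
+if __name__ == '__main__':
+    # Editor's sketch (illustration only, not part of the original module):
+    # plot a uniform coupling between two small random 2D point clouds.
+    # Any nonnegative (ns, nt) matrix G works here; a uniform one connects
+    # every source/target pair with equal line intensity.
+    rng = np.random.RandomState(0)
+    xs = rng.randn(5, 2)
+    xt = rng.randn(4, 2) + 2
+    G = np.ones((5, 4)) / 20.
+    pl.figure()
+    plot2D_samples_mat(xs, xt, G, color='b')
+    pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+    pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+    pl.legend()
+    pl.show()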
diff --git a/ot/smooth.py b/ot/smooth.py
new file mode 100644
index 0000000..5a8e4b5
--- /dev/null
+++ b/ot/smooth.py
@@ -0,0 +1,600 @@
+#Copyright (c) 2018, Mathieu Blondel
+#All rights reserved.
+#
+#Redistribution and use in source and binary forms, with or without
+#modification, are permitted provided that the following conditions are met:
+#
+#1. Redistributions of source code must retain the above copyright notice, this
+#list of conditions and the following disclaimer.
+#
+#2. Redistributions in binary form must reproduce the above copyright notice,
+#this list of conditions and the following disclaimer in the documentation and/or
+#other materials provided with the distribution.
+#
+#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+#ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+#WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+#IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
+#INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
+#NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
+#OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+#LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+#OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+#THE POSSIBILITY OF SUCH DAMAGE.
+
+# Author: Mathieu Blondel
+# Remi Flamary <remi.flamary@unice.fr>
+
+"""
+Implementation of
+Smooth and Sparse Optimal Transport.
+Mathieu Blondel, Vivien Seguy, Antoine Rolet.
+In Proc. of AISTATS 2018.
+https://arxiv.org/abs/1710.06276
+
+[17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal
+Transport. Proceedings of the Twenty-First International Conference on
+Artificial Intelligence and Statistics (AISTATS).
+
+Original code from https://github.com/mblondel/smooth-ot/
+
+"""
+
+import numpy as np
+from scipy.optimize import minimize
+
+
+def projection_simplex(V, z=1, axis=None):
+ """ Projection of x onto the simplex, scaled by z
+
+ P(x; z) = argmin_{y >= 0, sum(y) = z} ||y - x||^2
+ z: float or array
+ If array, len(z) must be compatible with V
+ axis: None or int
+ - axis=None: project V by P(V.ravel(); z)
+ - axis=1: project each V[i] by P(V[i]; z[i])
+ - axis=0: project each V[:, j] by P(V[:, j]; z[j])
+ """
+ if axis == 1:
+ n_features = V.shape[1]
+ U = np.sort(V, axis=1)[:, ::-1]
+ z = np.ones(len(V)) * z
+ cssv = np.cumsum(U, axis=1) - z[:, np.newaxis]
+ ind = np.arange(n_features) + 1
+ cond = U - cssv / ind > 0
+ rho = np.count_nonzero(cond, axis=1)
+ theta = cssv[np.arange(len(V)), rho - 1] / rho
+ return np.maximum(V - theta[:, np.newaxis], 0)
+
+ elif axis == 0:
+ return projection_simplex(V.T, z, axis=1).T
+
+ else:
+ V = V.ravel().reshape(1, -1)
+ return projection_simplex(V, z, axis=1).ravel()
+
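+# Editor's note (illustration only): a quick property check of the
+# projection above. The output is nonnegative and each projected column
+# sums to z:
+#
+# >>> V = np.array([[1.0, -0.5], [0.2, 2.0]])
+# >>> P = projection_simplex(V, z=1, axis=0)
+# >>> bool(np.allclose(P.sum(axis=0), 1.0) and (P >= 0).all())
+# True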
+
+class Regularization(object):
+ """Base class for Regularization objects
+
+ Notes
+ -----
+    This class is not intended for direct use but as a parent class for
+    concrete regularization implementations.
+ """
+
+ def __init__(self, gamma=1.0):
+ """
+
+ Parameters
+ ----------
+ gamma: float
+ Regularization parameter.
+ We recover unregularized OT when gamma -> 0.
+
+ """
+ self.gamma = gamma
+
+    def delta_Omega(self, X):
+ """
+ Compute delta_Omega(X[:, j]) for each X[:, j].
+ delta_Omega(x) = sup_{y >= 0} y^T x - Omega(y).
+
+ Parameters
+ ----------
+ X: array, shape = len(a) x len(b)
+ Input array.
+
+ Returns
+ -------
+ v: array, len(b)
+ Values: v[j] = delta_Omega(X[:, j])
+ G: array, len(a) x len(b)
+ Gradients: G[:, j] = nabla delta_Omega(X[:, j])
+ """
+ raise NotImplementedError
+
+    def max_Omega(self, X, b):
+ """
+ Compute max_Omega_j(X[:, j]) for each X[:, j].
+ max_Omega_j(x) = sup_{y >= 0, sum(y) = 1} y^T x - Omega(b[j] y) / b[j].
+
+ Parameters
+ ----------
+        X: array, shape = len(a) x len(b)
+            Input array.
+        b: array, shape = len(b)
+            Second input histogram.
+
+ Returns
+ -------
+ v: array, len(b)
+ Values: v[j] = max_Omega_j(X[:, j])
+ G: array, len(a) x len(b)
+ Gradients: G[:, j] = nabla max_Omega_j(X[:, j])
+ """
+ raise NotImplementedError
+
+    def Omega(self, T):
+ """
+ Compute regularization term.
+
+ Parameters
+ ----------
+ T: array, shape = len(a) x len(b)
+ Input array.
+
+ Returns
+ -------
+ value: float
+ Regularization term.
+ """
+ raise NotImplementedError
+
+
+class NegEntropy(Regularization):
+ """ NegEntropy regularization """
+
+ def delta_Omega(self, X):
+ G = np.exp(X / self.gamma - 1)
+ val = self.gamma * np.sum(G, axis=0)
+ return val, G
+
+ def max_Omega(self, X, b):
+ max_X = np.max(X, axis=0) / self.gamma
+ exp_X = np.exp(X / self.gamma - max_X)
+ val = self.gamma * (np.log(np.sum(exp_X, axis=0)) + max_X)
+ val -= self.gamma * np.log(b)
+ G = exp_X / np.sum(exp_X, axis=0)
+ return val, G
+
+ def Omega(self, T):
+ return self.gamma * np.sum(T * np.log(T))
+
+
+class SquaredL2(Regularization):
+ """ Squared L2 regularization """
+
+ def delta_Omega(self, X):
+ max_X = np.maximum(X, 0)
+ val = np.sum(max_X ** 2, axis=0) / (2 * self.gamma)
+ G = max_X / self.gamma
+ return val, G
+
+ def max_Omega(self, X, b):
+ G = projection_simplex(X / (b * self.gamma), axis=0)
+ val = np.sum(X * G, axis=0)
+ val -= 0.5 * self.gamma * b * np.sum(G * G, axis=0)
+ return val, G
+
+ def Omega(self, T):
+ return 0.5 * self.gamma * np.sum(T ** 2)
+
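+# Editor's note (illustration only): the practical difference between the
+# two regularizers is the gradient of delta_Omega, which gives the
+# (column-wise) transport plan. NegEntropy yields a strictly positive,
+# hence dense, plan while SquaredL2 thresholds at zero and can be sparse:
+#
+# >>> X = np.array([[1.0, -2.0], [-3.0, 0.5]])
+# >>> bool((NegEntropy(gamma=1.0).delta_Omega(X)[1] > 0).all())
+# True
+# >>> bool((SquaredL2(gamma=1.0).delta_Omega(X)[1] == 0).any())
+# True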
+
+def dual_obj_grad(alpha, beta, a, b, C, regul):
+ """
+ Compute objective value and gradients of dual objective.
+
+ Parameters
+ ----------
+ alpha: array, shape = len(a)
+ beta: array, shape = len(b)
+ Current iterate of dual potentials.
+ a: array, shape = len(a)
+ b: array, shape = len(b)
+ Input histograms (should be non-negative and sum to 1).
+ C: array, shape = len(a) x len(b)
+ Ground cost matrix.
+ regul: Regularization object
+ Should implement a delta_Omega(X) method.
+
+ Returns
+ -------
+ obj: float
+ Objective value (higher is better).
+ grad_alpha: array, shape = len(a)
+ Gradient w.r.t. alpha.
+ grad_beta: array, shape = len(b)
+ Gradient w.r.t. beta.
+ """
+ obj = np.dot(alpha, a) + np.dot(beta, b)
+ grad_alpha = a.copy()
+ grad_beta = b.copy()
+
+ # X[:, j] = alpha + beta[j] - C[:, j]
+ X = alpha[:, np.newaxis] + beta - C
+
+ # val.shape = len(b)
+ # G.shape = len(a) x len(b)
+ val, G = regul.delta_Omega(X)
+
+ obj -= np.sum(val)
+ grad_alpha -= G.sum(axis=1)
+ grad_beta -= G.sum(axis=0)
+
+ return obj, grad_alpha, grad_beta
+
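+# Editor's sketch (illustration only): the analytic gradient above can be
+# validated with a central finite difference in any dual coordinate, here
+# on small uniform histograms with a random cost:
+#
+# >>> rng = np.random.RandomState(0)
+# >>> a, b = np.ones(3) / 3., np.ones(4) / 4.
+# >>> C = rng.rand(3, 4)
+# >>> regul = SquaredL2(gamma=1.0)
+# >>> alpha, beta = rng.randn(3), rng.randn(4)
+# >>> _, ga, _ = dual_obj_grad(alpha, beta, a, b, C, regul)
+# >>> eps = 1e-6; da = np.zeros(3); da[0] = eps
+# >>> fd = (dual_obj_grad(alpha + da, beta, a, b, C, regul)[0]
+# ...       - dual_obj_grad(alpha - da, beta, a, b, C, regul)[0]) / (2 * eps)
+# >>> bool(np.isclose(ga[0], fd, atol=1e-4))
+# True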
+
+def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
+ verbose=False):
+ """
+ Solve the "smoothed" dual objective.
+
+ Parameters
+ ----------
+ a: array, shape = len(a)
+ b: array, shape = len(b)
+ Input histograms (should be non-negative and sum to 1).
+ C: array, shape = len(a) x len(b)
+ Ground cost matrix.
+ regul: Regularization object
+ Should implement a delta_Omega(X) method.
+ method: str
+ Solver to be used (passed to `scipy.optimize.minimize`).
+ tol: float
+ Tolerance parameter.
+ max_iter: int
+ Maximum number of iterations.
+
+ Returns
+ -------
+ alpha: array, shape = len(a)
+ beta: array, shape = len(b)
+ Dual potentials.
+ """
+
+ def _func(params):
+ # Unpack alpha and beta.
+ alpha = params[:len(a)]
+ beta = params[len(a):]
+
+ obj, grad_alpha, grad_beta = dual_obj_grad(alpha, beta, a, b, C, regul)
+
+ # Pack grad_alpha and grad_beta.
+ grad = np.concatenate((grad_alpha, grad_beta))
+
+ # We need to maximize the dual.
+ return -obj, -grad
+
+ # Unfortunately, `minimize` only supports functions whose argument is a
+ # vector. So, we need to concatenate alpha and beta.
+ alpha_init = np.zeros(len(a))
+ beta_init = np.zeros(len(b))
+ params_init = np.concatenate((alpha_init, beta_init))
+
+ res = minimize(_func, params_init, method=method, jac=True,
+ tol=tol, options=dict(maxiter=max_iter, disp=verbose))
+
+ alpha = res.x[:len(a)]
+ beta = res.x[len(a):]
+
+ return alpha, beta, res
+
+
+def semi_dual_obj_grad(alpha, a, b, C, regul):
+ """
+ Compute objective value and gradient of semi-dual objective.
+
+ Parameters
+ ----------
+ alpha: array, shape = len(a)
+ Current iterate of semi-dual potentials.
+ a: array, shape = len(a)
+ b: array, shape = len(b)
+ Input histograms (should be non-negative and sum to 1).
+ C: array, shape = len(a) x len(b)
+ Ground cost matrix.
+ regul: Regularization object
+ Should implement a max_Omega(X) method.
+
+ Returns
+ -------
+ obj: float
+ Objective value (higher is better).
+ grad: array, shape = len(a)
+ Gradient w.r.t. alpha.
+ """
+ obj = np.dot(alpha, a)
+ grad = a.copy()
+
+ # X[:, j] = alpha - C[:, j]
+ X = alpha[:, np.newaxis] - C
+
+ # val.shape = len(b)
+ # G.shape = len(a) x len(b)
+ val, G = regul.max_Omega(X, b)
+
+ obj -= np.dot(b, val)
+ grad -= np.dot(G, b)
+
+ return obj, grad
+
+
+def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
+ verbose=False):
+ """
+ Solve the "smoothed" semi-dual objective.
+
+ Parameters
+ ----------
+ a: array, shape = len(a)
+ b: array, shape = len(b)
+ Input histograms (should be non-negative and sum to 1).
+ C: array, shape = len(a) x len(b)
+ Ground cost matrix.
+ regul: Regularization object
+ Should implement a max_Omega(X) method.
+ method: str
+ Solver to be used (passed to `scipy.optimize.minimize`).
+ tol: float
+ Tolerance parameter.
+ max_iter: int
+ Maximum number of iterations.
+
+ Returns
+ -------
+ alpha: array, shape = len(a)
+ Semi-dual potentials.
+ """
+
+ def _func(alpha):
+ obj, grad = semi_dual_obj_grad(alpha, a, b, C, regul)
+ # We need to maximize the semi-dual.
+ return -obj, -grad
+
+ alpha_init = np.zeros(len(a))
+
+ res = minimize(_func, alpha_init, method=method, jac=True,
+ tol=tol, options=dict(maxiter=max_iter, disp=verbose))
+
+ return res.x, res
+
+
+def get_plan_from_dual(alpha, beta, C, regul):
+ """
+ Retrieve optimal transportation plan from optimal dual potentials.
+
+ Parameters
+ ----------
+ alpha: array, shape = len(a)
+ beta: array, shape = len(b)
+ Optimal dual potentials.
+ C: array, shape = len(a) x len(b)
+ Ground cost matrix.
+ regul: Regularization object
+ Should implement a delta_Omega(X) method.
+
+ Returns
+ -------
+ T: array, shape = len(a) x len(b)
+ Optimal transportation plan.
+ """
+ X = alpha[:, np.newaxis] + beta - C
+ return regul.delta_Omega(X)[1]
+
+
+def get_plan_from_semi_dual(alpha, b, C, regul):
+ """
+ Retrieve optimal transportation plan from optimal semi-dual potentials.
+
+ Parameters
+ ----------
+ alpha: array, shape = len(a)
+ Optimal semi-dual potentials.
+ b: array, shape = len(b)
+ Second input histogram (should be non-negative and sum to 1).
+ C: array, shape = len(a) x len(b)
+ Ground cost matrix.
+ regul: Regularization object
+ Should implement a delta_Omega(X) method.
+
+ Returns
+ -------
+ T: array, shape = len(a) x len(b)
+ Optimal transportation plan.
+ """
+ X = alpha[:, np.newaxis] - C
+ return regul.max_Omega(X, b)[1] * b
+
+
+def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
+ numItermax=500, verbose=False, log=False):
+ r"""
+ Solve the regularized OT problem in the dual and return the OT matrix
+
+ The function solves the smooth relaxed dual formulation (7) in [17]_ :
+
+ .. math::
+ \max_{\alpha,\beta}\quad a^T\alpha+b^T\beta-\sum_j\delta_\Omega(\alpha+\beta_j-\mathbf{m}_j)
+
+ where :
+
+ - :math:`\mathbf{m}_j` is the jth column of the cost matrix
+ - :math:`\delta_\Omega` is the convex conjugate of the regularization term :math:`\Omega`
+ - a and b are source and target weights (sum to 1)
+
+    The OT matrix is reconstructed from the gradient of :math:`\delta_\Omega`
+    (see [17]_ Proposition 1).
+    The optimization algorithm uses gradient descent (L-BFGS by default).
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+    b : np.ndarray (nt,)
+        samples weights in the target domain
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ reg_type : str
+ Regularization type, can be the following (default ='l2'):
+ - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_)
+ - 'l2' : Squared Euclidean regularization
+ method : str
+ Solver to use for scipy.optimize.minimize
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+        log dictionary returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+    ot.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ if reg_type.lower() in ['l2', 'squaredl2']:
+ regul = SquaredL2(gamma=reg)
+ elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
+ regul = NegEntropy(gamma=reg)
+ else:
+ raise NotImplementedError('Unknown regularization')
+
+ # solve dual
+ alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax,
+ tol=stopThr, verbose=verbose)
+
+ # reconstruct transport matrix
+ G = get_plan_from_dual(alpha, beta, M, regul)
+
+ if log:
+ log = {'alpha': alpha, 'beta': beta, 'res': res}
+ return G, log
+ else:
+ return G
+
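+# Editor's sketch (illustration only): typical usage on a toy problem.
+# With the smooth l2 regularizer the returned plan satisfies the marginal
+# constraints only approximately, up to the solver tolerance:
+#
+# >>> a = np.array([.5, .5])
+# >>> b = np.array([.5, .5])
+# >>> M = np.array([[0., 1.], [1., 0.]])
+# >>> G = smooth_ot_dual(a, b, M, reg=1e-1, reg_type='l2')
+# >>> bool(np.allclose(G.sum(axis=1), a, atol=1e-3))
+# True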
+
+def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
+ numItermax=500, verbose=False, log=False):
+ r"""
+ Solve the regularized OT problem in the semi-dual and return the OT matrix
+
+ The function solves the smooth relaxed dual formulation (10) in [17]_ :
+
+ .. math::
+ \max_{\alpha}\quad a^T\alpha-OT_\Omega^*(\alpha,b)
+
+ where :
+
+    .. math::
+        OT_\Omega^*(\alpha,b)=\sum_j b_j \max_{\Omega,j}(\alpha-\mathbf{m}_j)
+
+    - :math:`\mathbf{m}_j` is the jth column of the cost matrix
+    - :math:`\max_{\Omega,j}` is defined in Eq. (9) in [17]_
+    - a and b are source and target weights (sum to 1)
+
+    The OT matrix is reconstructed using [17]_ Proposition 2.
+    The optimization algorithm uses gradient descent (L-BFGS by default).
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+    b : np.ndarray (nt,)
+        samples weights in the target domain
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ reg_type : str
+ Regularization type, can be the following (default ='l2'):
+ - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_)
+ - 'l2' : Squared Euclidean regularization
+ method : str
+ Solver to use for scipy.optimize.minimize
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+        log dictionary returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+    ot.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+ if reg_type.lower() in ['l2', 'squaredl2']:
+ regul = SquaredL2(gamma=reg)
+ elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
+ regul = NegEntropy(gamma=reg)
+ else:
+ raise NotImplementedError('Unknown regularization')
+
+ # solve dual
+ alpha, res = solve_semi_dual(a, b, M, regul, max_iter=numItermax,
+ tol=stopThr, verbose=verbose)
+
+ # reconstruct transport matrix
+ G = get_plan_from_semi_dual(alpha, b, M, regul)
+
+ if log:
+ log = {'alpha': alpha, 'res': res}
+ return G, log
+ else:
+ return G
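+
+
+if __name__ == '__main__':
+    # Editor's sketch (illustration only, not part of the original module):
+    # the dual and semi-dual solvers target the same primal problem, so
+    # their plans should agree up to solver tolerance.
+    a = np.array([.5, .5])
+    b = np.array([.5, .5])
+    M = np.array([[0., 1.], [1., 0.]])
+    G_dual = smooth_ot_dual(a, b, M, reg=1e-1, reg_type='l2')
+    G_semi = smooth_ot_semi_dual(a, b, M, reg=1e-1, reg_type='l2')
+    print(np.abs(G_dual - G_semi).max())  # expected to be small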
diff --git a/ot/stochastic.py b/ot/stochastic.py
new file mode 100644
index 0000000..13ed9cc
--- /dev/null
+++ b/ot/stochastic.py
@@ -0,0 +1,755 @@
+"""
+Stochastic solvers for regularized OT.
+
+
+"""
+
+# Author: Kilian Fatras <kilian.fatras@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+
+
+##############################################################################
+# Optimization toolbox for SEMI - DUAL problems
+##############################################################################
+
+
+def coordinate_grad_semi_dual(b, M, reg, beta, i):
+ r'''
+ Compute the coordinate gradient update for regularized discrete distributions for (i, :)
+
+ The function computes the gradient of the semi dual problem:
+
+ .. math::
+ \max_v \sum_i (\sum_j v_j * b_j - reg * log(\sum_j exp((v_j - M_{i,j})/reg) * b_j)) * a_i
+
+ Where :
+
+ - M is the (ns,nt) metric cost matrix
+ - v is a dual variable in R^J
+ - reg is the regularization term
+ - a and b are source and target weights (sum to 1)
+
+    The algorithms used for solving the problem are the ASGD and SAG
+    algorithms, as proposed in [18]_ (Alg. 1 and Alg. 2)
+
+
+ Parameters
+ ----------
+ b : ndarray, shape (nt,)
+ Target measure.
+ M : ndarray, shape (ns, nt)
+ Cost matrix.
+ reg : float
+ Regularization term > 0.
+    beta : ndarray, shape (nt,)
+ Dual variable.
+ i : int
+ Picked number i.
+
+ Returns
+ -------
+ coordinate gradient : ndarray, shape (nt,)
+
+ Examples
+ --------
+ >>> import ot
+ >>> np.random.seed(0)
+ >>> n_source = 7
+ >>> n_target = 4
+ >>> a = ot.utils.unif(n_source)
+ >>> b = ot.utils.unif(n_target)
+ >>> X_source = np.random.randn(n_source, 2)
+ >>> Y_target = np.random.randn(n_target, 2)
+ >>> M = ot.dist(X_source, Y_target)
+ >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000)
+ array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06],
+ [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03],
+ [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07],
+ [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04],
+ [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01],
+ [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
+ [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
+
+
+ References
+ ----------
+    .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016).
+        Stochastic Optimization for Large-scale Optimal Transport.
+        Advances in Neural Information Processing Systems (NIPS),
+        arXiv preprint arXiv:1605.08527.
+ '''
+ r = M[i, :] - beta
+ exp_beta = np.exp(-r / reg) * b
+ khi = exp_beta / (np.sum(exp_beta))
+ return b - khi
+
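+# Editor's note (illustration only): since khi above is a probability
+# vector and b sums to one, the coordinate gradient b - khi sums to zero.
+# This reflects the invariance of the semi-dual objective to adding a
+# constant to the dual variable:
+#
+# >>> beta = np.zeros(4)
+# >>> b = np.ones(4) / 4.
+# >>> M = np.arange(8.).reshape(2, 4)
+# >>> bool(np.isclose(coordinate_grad_semi_dual(b, M, 1., beta, 0).sum(), 0.))
+# True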
+
+def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
+ r'''
+ Compute the SAG algorithm to solve the regularized discrete measures
+ optimal transport max problem
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1 = b
+
+ \gamma \geq 0
+
+ Where :
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is the SAG algorithm
+ as proposed in [18]_ [alg.1]
+
+
+ Parameters
+ ----------
+
+    a : ndarray, shape (ns,)
+        Source measure.
+    b : ndarray, shape (nt,)
+        Target measure.
+    M : ndarray, shape (ns, nt)
+        Cost matrix.
+    reg : float
+        Regularization term > 0
+    numItermax : int
+        Number of iterations.
+ lr : float
+ Learning rate.
+
+ Returns
+ -------
+ v : ndarray, shape (nt,)
+ Dual variable.
+
+ Examples
+ --------
+ >>> import ot
+ >>> np.random.seed(0)
+ >>> n_source = 7
+ >>> n_target = 4
+ >>> a = ot.utils.unif(n_source)
+ >>> b = ot.utils.unif(n_target)
+ >>> X_source = np.random.randn(n_source, 2)
+ >>> Y_target = np.random.randn(n_target, 2)
+ >>> M = ot.dist(X_source, Y_target)
+ >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000)
+ array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06],
+ [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03],
+ [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07],
+ [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04],
+ [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01],
+ [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
+ [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
+
+ References
+ ----------
+
+    .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016).
+        Stochastic Optimization for Large-scale Optimal Transport.
+        Advances in Neural Information Processing Systems (NIPS),
+        arXiv preprint arXiv:1605.08527.
+ '''
+
+ if lr is None:
+ lr = 1. / max(a / reg)
+ n_source = np.shape(M)[0]
+ n_target = np.shape(M)[1]
+ cur_beta = np.zeros(n_target)
+ stored_gradient = np.zeros((n_source, n_target))
+ sum_stored_gradient = np.zeros(n_target)
+ for _ in range(numItermax):
+ i = np.random.randint(n_source)
+ cur_coord_grad = a[i] * coordinate_grad_semi_dual(b, M, reg,
+ cur_beta, i)
+ sum_stored_gradient += (cur_coord_grad - stored_gradient[i])
+ stored_gradient[i] = cur_coord_grad
+ cur_beta += lr * (1. / n_source) * sum_stored_gradient
+ return cur_beta
+
+
+def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
+ r'''
+    Compute the ASGD algorithm to solve the regularized semi-continuous measures optimal transport max problem
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma \geq 0
+
+ Where :
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is the ASGD algorithm
+ as proposed in [18]_ [alg.2]
+
+
+ Parameters
+ ----------
+    a : ndarray, shape (ns,)
+        Source measure.
+    b : ndarray, shape (nt,)
+        Target measure.
+    M : ndarray, shape (ns, nt)
+        Cost matrix.
+    reg : float
+        Regularization term > 0
+    numItermax : int
+        Number of iterations.
+    lr : float
+        Learning rate.
+
+ Returns
+ -------
+ ave_v : ndarray, shape (nt,)
+ dual variable
+
+ Examples
+ --------
+ >>> import ot
+ >>> np.random.seed(0)
+ >>> n_source = 7
+ >>> n_target = 4
+ >>> a = ot.utils.unif(n_source)
+ >>> b = ot.utils.unif(n_target)
+ >>> X_source = np.random.randn(n_source, 2)
+ >>> Y_target = np.random.randn(n_target, 2)
+ >>> M = ot.dist(X_source, Y_target)
+ >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000)
+ array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06],
+ [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03],
+ [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07],
+ [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04],
+ [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01],
+ [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
+ [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
+
+ References
+ ----------
+
+    .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016).
+        Stochastic Optimization for Large-scale Optimal Transport.
+        Advances in Neural Information Processing Systems (NIPS),
+        arXiv preprint arXiv:1605.08527.
+ '''
+
+ if lr is None:
+ lr = 1. / max(a / reg)
+ n_source = np.shape(M)[0]
+ n_target = np.shape(M)[1]
+ cur_beta = np.zeros(n_target)
+ ave_beta = np.zeros(n_target)
+ for cur_iter in range(numItermax):
+ k = cur_iter + 1
+ i = np.random.randint(n_source)
+ cur_coord_grad = coordinate_grad_semi_dual(b, M, reg, cur_beta, i)
+ cur_beta += (lr / np.sqrt(k)) * cur_coord_grad
+ ave_beta = (1. / k) * cur_beta + (1 - 1. / k) * ave_beta
+ return ave_beta
+
+
+def c_transform_entropic(b, M, reg, beta):
+ r'''
+ The goal is to recover u from the c-transform.
+
+ The function computes the c_transform of a dual variable from the other
+ dual variable:
+
+    .. math::
+        u_i = v^{c,reg}_i = - reg \log \sum_j exp((v_j - M_{i,j})/reg) b_j
+
+ Where :
+
+ - M is the (ns,nt) metric cost matrix
+ - u, v are dual variables in R^IxR^J
+ - reg is the regularization term
+
+ It is used to recover an optimal u from optimal v solving the semi dual
+ problem, see Proposition 2.1 of [18]_
+
+
+ Parameters
+ ----------
+ b : ndarray, shape (nt,)
+ Target measure
+ M : ndarray, shape (ns, nt)
+ Cost matrix
+ reg : float
+ Regularization term > 0
+    beta : ndarray, shape (nt,)
+ Dual variable.
+
+ Returns
+ -------
+ u : ndarray, shape (ns,)
+ Dual variable.
+
+ Examples
+ --------
+ >>> import ot
+ >>> np.random.seed(0)
+ >>> n_source = 7
+ >>> n_target = 4
+ >>> a = ot.utils.unif(n_source)
+ >>> b = ot.utils.unif(n_target)
+ >>> X_source = np.random.randn(n_source, 2)
+ >>> Y_target = np.random.randn(n_target, 2)
+ >>> M = ot.dist(X_source, Y_target)
+ >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000)
+ array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06],
+ [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03],
+ [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07],
+ [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04],
+ [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01],
+ [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
+ [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
+
+ References
+ ----------
+
+    .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016).
+        Stochastic Optimization for Large-scale Optimal Transport.
+        Advances in Neural Information Processing Systems (NIPS),
+        arXiv preprint arXiv:1605.08527.
+ '''
+
+ n_source = np.shape(M)[0]
+ alpha = np.zeros(n_source)
+ for i in range(n_source):
+ r = M[i, :] - beta
+ min_r = np.min(r)
+ exp_beta = np.exp(-(r - min_r) / reg) * b
+ alpha[i] = min_r - reg * np.log(np.sum(exp_beta))
+ return alpha
+
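+# Editor's note (illustration only): the min_r shift in the loop above is a
+# standard log-sum-exp stabilization and leaves the result unchanged:
+#
+# >>> b = np.ones(3) / 3.
+# >>> M = np.array([[0., 1., 2.]])
+# >>> u = c_transform_entropic(b, M, 1., np.zeros(3))
+# >>> bool(np.isclose(u[0], -np.log(np.sum(np.exp(-M[0]) * b))))
+# True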
+
+def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
+ log=False):
+ r'''
+ Compute the transportation matrix to solve the regularized discrete
+ measures optimal transport max problem
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma \geq 0
+
+ Where :
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (sum to 1)
+
+    The algorithm used for solving the problem is the SAG or ASGD algorithm,
+    as proposed in [18]_
+
+
+ Parameters
+ ----------
+
+ a : ndarray, shape (ns,)
+ source measure
+ b : ndarray, shape (nt,)
+ target measure
+ M : ndarray, shape (ns, nt)
+ cost matrix
+ reg : float
+ Regularization term > 0
+    method : str
+        Method used (SAG or ASGD).
+    numItermax : int
+        Number of iterations.
+    lr : float
+        Learning rate.
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ pi : ndarray, shape (ns, nt)
+ transportation matrix
+ log : dict
+        log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+ >>> import ot
+ >>> np.random.seed(0)
+ >>> n_source = 7
+ >>> n_target = 4
+ >>> a = ot.utils.unif(n_source)
+ >>> b = ot.utils.unif(n_target)
+ >>> X_source = np.random.randn(n_source, 2)
+ >>> Y_target = np.random.randn(n_target, 2)
+ >>> M = ot.dist(X_source, Y_target)
+ >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000)
+ array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06],
+ [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03],
+ [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07],
+ [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04],
+ [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01],
+ [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
+ [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
+
+ References
+ ----------
+
+    .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016).
+        Stochastic Optimization for Large-scale Optimal Transport.
+        Advances in Neural Information Processing Systems (NIPS),
+        arXiv preprint arXiv:1605.08527.
+ '''
+
+ if method.lower() == "sag":
+ opt_beta = sag_entropic_transport(a, b, M, reg, numItermax, lr)
+ elif method.lower() == "asgd":
+ opt_beta = averaged_sgd_entropic_transport(a, b, M, reg, numItermax, lr)
+    else:
+        raise ValueError("Unknown method '%s'. Please select 'SAG' or 'ASGD'."
+                         % method)
+
+ opt_alpha = c_transform_entropic(b, M, reg, opt_beta)
+ pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg) *
+ a[:, None] * b[None, :])
+
+ if log:
+ log = {}
+ log['alpha'] = opt_alpha
+ log['beta'] = opt_beta
+ return pi, log
+ else:
+ return pi
+
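+# Editor's sketch (illustration only): for a converged run the plan's
+# marginals should be close to a and b. A quick check on the toy problem
+# used in the docstrings above:
+#
+# >>> import ot
+# >>> np.random.seed(0)
+# >>> a = ot.utils.unif(7)
+# >>> b = ot.utils.unif(4)
+# >>> M = ot.dist(np.random.randn(7, 2), np.random.randn(4, 2))
+# >>> pi = solve_semi_dual_entropic(a, b, M, reg=1, method="SAG",
+# ...                               numItermax=100000)
+# >>> bool(np.allclose(pi.sum(1), a, atol=1e-2))
+# True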
+
+##############################################################################
+# Optimization toolbox for DUAL problems
+##############################################################################
+
+
+def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
+ batch_beta):
+ r'''
+ Computes the partial gradient of the dual optimal transport problem.
+
+ For each (i,j) in a batch of coordinates, the partial gradients are :
+
+ .. math::
+        \partial_{u_i} F = a_i * l_v/n_t - \sum_{j \in B_v} exp((u_i + v_j - M_{i,j})/reg) * a_i * b_j
+
+        \partial_{v_j} F = b_j * l_u/n_s - \sum_{i \in B_u} exp((u_i + v_j - M_{i,j})/reg) * a_i * b_j
+
+ Where :
+
+ - M is the (ns,nt) metric cost matrix
+    - u, v are dual variables in R^I and R^J respectively
+    - reg is the regularization term
+    - :math:`B_u` and :math:`B_v` are lists of indices
+    - :math:`l_u` and :math:`l_v` are the lengths of :math:`B_u` and :math:`B_v`
+    - :math:`n_s` and :math:`n_t` are the numbers of source and target samples
+ - a and b are source and target weights (sum to 1)
+
+
+ The algorithm used for solving the dual problem is the SGD algorithm
+ as proposed in [19]_ [alg.1]
+
+
+ Parameters
+ ----------
+ a : ndarray, shape (ns,)
+ source measure
+ b : ndarray, shape (nt,)
+ target measure
+ M : ndarray, shape (ns, nt)
+ cost matrix
+ reg : float
+ Regularization term > 0
+ alpha : ndarray, shape (ns,)
+ dual variable
+ beta : ndarray, shape (nt,)
+ dual variable
+ batch_size : int
+ size of the batch
+    batch_alpha : ndarray, shape (bs,)
+        Batch of indices of alpha.
+    batch_beta : ndarray, shape (bs,)
+        Batch of indices of beta.
+
+ Returns
+ -------
+    grad_alpha : ndarray, shape (ns,)
+        Partial gradient of F w.r.t. alpha.
+    grad_beta : ndarray, shape (nt,)
+        Partial gradient of F w.r.t. beta.
+
+ Examples
+ --------
+ >>> import ot
+ >>> np.random.seed(0)
+ >>> n_source = 7
+ >>> n_target = 4
+ >>> a = ot.utils.unif(n_source)
+ >>> b = ot.utils.unif(n_target)
+ >>> X_source = np.random.randn(n_source, 2)
+ >>> Y_target = np.random.randn(n_target, 2)
+ >>> M = ot.dist(X_source, Y_target)
+ >>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg=1, batch_size=3, numItermax=30000, lr=0.1, log=True)
+ >>> log['alpha']
+ array([0.71759102, 1.57057384, 0.85576566, 0.1208211 , 0.59190466,
+ 1.197148 , 0.17805133])
+ >>> log['beta']
+ array([0.49741367, 0.57478564, 1.40075528, 2.75890102])
+ >>> sgd_dual_pi
+ array([[2.09730063e-02, 8.38169324e-02, 7.50365455e-03, 8.72731415e-09],
+ [5.58432437e-03, 5.89881299e-04, 3.09558411e-05, 8.35469849e-07],
+ [3.26489515e-03, 7.15536035e-02, 2.99778211e-02, 3.02601593e-10],
+ [4.05390622e-02, 5.31085068e-02, 6.65191787e-02, 1.55812785e-06],
+ [7.82299812e-02, 6.12099102e-03, 4.44989098e-02, 2.37719187e-03],
+ [5.06266486e-02, 2.16230494e-03, 2.26215141e-03, 6.81514609e-04],
+ [6.06713990e-02, 3.98139808e-02, 5.46829338e-02, 8.62371424e-06]])
+
+ References
+ ----------
+
+    .. [19] Seguy, V., Damodaran, B. B., Flamary, R., Courty, N., Rolet, A.
+        & Blondel, M. (2018). Large-Scale Optimal Transport and Mapping
+        Estimation. International Conference on Learning Representations
+        (ICLR), arXiv preprint arXiv:1711.02283.
+ '''
+ G = - (np.exp((alpha[batch_alpha, None] + beta[None, batch_beta] -
+ M[batch_alpha, :][:, batch_beta]) / reg) *
+ a[batch_alpha, None] * b[None, batch_beta])
+ grad_beta = np.zeros(np.shape(M)[1])
+ grad_alpha = np.zeros(np.shape(M)[0])
+ grad_beta[batch_beta] = (b[batch_beta] * len(batch_alpha) / np.shape(M)[0]
+ + G.sum(0))
+ grad_alpha[batch_alpha] = (a[batch_alpha] * len(batch_beta)
+ / np.shape(M)[1] + G.sum(1))
+
+ return grad_alpha, grad_beta
+
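+# Editor's note (illustration only): with full batches the stochastic
+# gradient above reduces to the exact gradient of the entropic dual,
+# e.g. grad_alpha = a - (exp(-M/reg) b) * a when u = v = 0:
+#
+# >>> a, b = np.ones(3) / 3., np.ones(2) / 2.
+# >>> M = np.arange(6.).reshape(3, 2)
+# >>> ga, gb = batch_grad_dual(a, b, M, 1., np.zeros(3), np.zeros(2), 3,
+# ...                          np.arange(3), np.arange(2))
+# >>> bool(np.allclose(ga, a - np.exp(-M).dot(b) * a))
+# True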
+
+def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
+ r'''
+    Compute the SGD algorithm to solve the regularized discrete measures
+ optimal transport dual problem
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma \geq 0
+
+ Where :
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (sum to 1)
+
+ Parameters
+ ----------
+ a : ndarray, shape (ns,)
+ source measure
+ b : ndarray, shape (nt,)
+ target measure
+ M : ndarray, shape (ns, nt)
+ cost matrix
+ reg : float
+ Regularization term > 0
+ batch_size : int
+ size of the batch
+ numItermax : int
+        number of iterations
+ lr : float
+ learning rate
+
+ Returns
+ -------
+ alpha : ndarray, shape (ns,)
+ dual variable
+ beta : ndarray, shape (nt,)
+ dual variable
+
+ Examples
+ --------
+ >>> import ot
+ >>> n_source = 7
+ >>> n_target = 4
+ >>> reg = 1
+ >>> numItermax = 20000
+ >>> lr = 0.1
+ >>> batch_size = 3
+ >>> log = True
+ >>> a = ot.utils.unif(n_source)
+ >>> b = ot.utils.unif(n_target)
+ >>> rng = np.random.RandomState(0)
+ >>> X_source = rng.randn(n_source, 2)
+ >>> Y_target = rng.randn(n_target, 2)
+ >>> M = ot.dist(X_source, Y_target)
+ >>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, numItermax, lr, log)
+ >>> log['alpha']
+ array([0.64171798, 1.27932201, 0.78132257, 0.15638935, 0.54888354,
+ 1.03663469, 0.20595781])
+ >>> log['beta']
+ array([0.51207194, 0.58033189, 1.28922676, 2.26859736])
+ >>> sgd_dual_pi
+ array([[1.97276541e-02, 7.81248547e-02, 6.22136048e-03, 4.95442423e-09],
+ [4.23494310e-03, 4.43286263e-04, 2.06927079e-05, 3.82389139e-07],
+ [3.07542414e-03, 6.67897769e-02, 2.48904999e-02, 1.72030247e-10],
+ [4.26271990e-02, 5.53375455e-02, 6.16535024e-02, 9.88812650e-07],
+ [7.60423265e-02, 5.89585256e-03, 3.81267087e-02, 1.39458256e-03],
+ [4.37557504e-02, 1.85189176e-03, 1.72335760e-03, 3.55491279e-04],
+ [6.33096109e-02, 4.11683954e-02, 5.02962051e-02, 5.43097516e-06]])
+
+ References
+ ----------
+    .. [19] Seguy, V., Damodaran, B. B., Flamary, R., Courty, N., Rolet, A.
+        & Blondel, M. (2018). Large-Scale Optimal Transport and Mapping
+        Estimation. International Conference on Learning Representations
+        (ICLR), arXiv preprint arXiv:1711.02283.
+ '''
+
+ n_source = np.shape(M)[0]
+ n_target = np.shape(M)[1]
+ cur_alpha = np.zeros(n_source)
+ cur_beta = np.zeros(n_target)
+ for cur_iter in range(numItermax):
+ k = np.sqrt(cur_iter + 1)
+ batch_alpha = np.random.choice(n_source, batch_size, replace=False)
+ batch_beta = np.random.choice(n_target, batch_size, replace=False)
+ update_alpha, update_beta = batch_grad_dual(a, b, M, reg, cur_alpha,
+ cur_beta, batch_size,
+ batch_alpha, batch_beta)
+ cur_alpha[batch_alpha] += (lr / k) * update_alpha[batch_alpha]
+ cur_beta[batch_beta] += (lr / k) * update_beta[batch_beta]
+
+ return cur_alpha, cur_beta
+
+
+def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
+ log=False):
+ r'''
+ Compute the transportation matrix to solve the regularized discrete measures
+ optimal transport dual problem
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma \geq 0
+
+ Where :
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (sum to 1)
+
+ Parameters
+ ----------
+ a : ndarray, shape (ns,)
+ source measure
+ b : ndarray, shape (nt,)
+ target measure
+ M : ndarray, shape (ns, nt)
+ cost matrix
+ reg : float
+ Regularization term > 0
+ batch_size : int
+ size of the batch
+ numItermax : int
+        number of iterations
+ lr : float
+ learning rate
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ pi : ndarray, shape (ns, nt)
+ transportation matrix
+ log : dict
+        log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+ >>> import ot
+ >>> n_source = 7
+ >>> n_target = 4
+ >>> reg = 1
+ >>> numItermax = 20000
+ >>> lr = 0.1
+ >>> batch_size = 3
+ >>> log = True
+ >>> a = ot.utils.unif(n_source)
+ >>> b = ot.utils.unif(n_target)
+ >>> rng = np.random.RandomState(0)
+ >>> X_source = rng.randn(n_source, 2)
+ >>> Y_target = rng.randn(n_target, 2)
+ >>> M = ot.dist(X_source, Y_target)
+ >>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, numItermax, lr, log)
+ >>> log['alpha']
+ array([0.64057733, 1.2683513 , 0.75610161, 0.16024284, 0.54926534,
+ 1.0514201 , 0.19958936])
+ >>> log['beta']
+ array([0.51372571, 0.58843489, 1.27993921, 2.24344807])
+ >>> sgd_dual_pi
+ array([[1.97377795e-02, 7.86706853e-02, 6.15682001e-03, 4.82586997e-09],
+ [4.19566963e-03, 4.42016865e-04, 2.02777272e-05, 3.68823708e-07],
+ [3.00379244e-03, 6.56562018e-02, 2.40462171e-02, 1.63579656e-10],
+ [4.28626062e-02, 5.60031599e-02, 6.13193826e-02, 9.67977735e-07],
+ [7.61972739e-02, 5.94609051e-03, 3.77886693e-02, 1.36046648e-03],
+ [4.44810042e-02, 1.89476742e-03, 1.73285847e-03, 3.51826036e-04],
+ [6.30118293e-02, 4.12398660e-02, 4.95148998e-02, 5.26247246e-06]])
+
+ References
+ ----------
+
+    .. [19] Seguy, V., Damodaran, B. B., Flamary, R., Courty, N., Rolet, A.
+        & Blondel, M. (2018). Large-Scale Optimal Transport and Mapping
+        Estimation. International Conference on Learning Representations
+        (ICLR), arXiv preprint arXiv:1711.02283.
+ '''
+
+ opt_alpha, opt_beta = sgd_entropic_regularization(a, b, M, reg, batch_size,
+ numItermax, lr)
+ pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg) *
+ a[:, None] * b[None, :])
+ if log:
+ log = {}
+ log['alpha'] = opt_alpha
+ log['beta'] = opt_beta
+ return pi, log
+ else:
+ return pi
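+
+
+if __name__ == '__main__':
+    # Editor's sketch (illustration only, not part of the original module):
+    # run the SGD dual solver on a toy problem. The marginal constraints
+    # are only enforced approximately and tighten as numItermax grows.
+    rng = np.random.RandomState(0)
+    a, b = np.ones(7) / 7., np.ones(4) / 4.
+    xs, xt = rng.randn(7, 2), rng.randn(4, 2)
+    M = ((xs[:, None, :] - xt[None, :, :]) ** 2).sum(-1)
+    pi = solve_dual_entropic(a, b, M, reg=1., batch_size=3,
+                             numItermax=20000, lr=0.1)
+    print(np.abs(pi.sum(1) - a).max())  # marginal violation, shrinks with more iterations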
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
new file mode 100644
index 0000000..d516dfc
--- /dev/null
+++ b/ot/unbalanced.py
@@ -0,0 +1,1022 @@
+# -*- coding: utf-8 -*-
+"""
+Regularized Unbalanced OT
+"""
+
+# Author: Hicham Janati <hicham.janati@inria.fr>
+# License: MIT License
+
+from __future__ import division
+import warnings
+import numpy as np
+from scipy.special import logsumexp
+
+# from .utils import unif, dist
+
+
+def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
+ stopThr=1e-6, verbose=False, log=False, **kwargs):
+ r"""
+ Solve the unbalanced entropic regularization optimal transport problem
+ and return the OT plan
+
+ The function solves the following optimization problem:
+
+ .. math::
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
+
+ s.t.
+ \gamma\geq 0
+ where :
+
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization
+ term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target unbalanced distributions
+ - KL is the Kullback-Leibler divergence
+
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+
+
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
+ loss matrix
+ reg : float
+ Entropy regularization term > 0
+ reg_m: float
+ Marginal relaxation term > 0
+ method : str
+ method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
+        'sinkhorn_reg_scaling', see those functions for specific parameters
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ if n_hists == 1:
+ gamma : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+ else:
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.], [1., 0.]]
+ >>> ot.sinkhorn_unbalanced(a, b, M, 1, 1)
+ array([[0.51122823, 0.18807035],
+ [0.18807035, 0.51122823]])
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems
+ (NIPS) 26, 2013
+
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
+ Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
+
+
+ See Also
+ --------
+ ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10]
+ ot.unbalanced.sinkhorn_stabilized_unbalanced:
+ Unbalanced Stabilized sinkhorn [9][10]
+ ot.unbalanced.sinkhorn_reg_scaling_unbalanced:
+        Unbalanced Sinkhorn with epsilon scaling [9][10]
+
+ """
+
+ if method.lower() == 'sinkhorn':
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() == 'sinkhorn_stabilized':
+ return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
+ elif method.lower() in ['sinkhorn_reg_scaling']:
+ warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
+
+
+def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
+ numItermax=1000, stopThr=1e-6, verbose=False,
+ log=False, **kwargs):
+ r"""
+ Solve the entropic regularization unbalanced optimal transport problem and
+ return the loss
+
+ The function solves the following optimization problem:
+
+ .. math::
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
+
+ s.t.
+ \gamma\geq 0
+ where :
+
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target unbalanced distributions
+ - KL is the Kullback-Leibler divergence
+
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+
+
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
+ loss matrix
+ reg : float
+ Entropy regularization term > 0
+ reg_m: float
+ Marginal relaxation term > 0
+ method : str
+ method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
+        'sinkhorn_reg_scaling', see those functions for specific parameters
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5, .10]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.],[1., 0.]]
+ >>> ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.)
+ array([0.31912866])
+
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems
+ (NIPS) 26, 2013
+
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
+ Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
+
+ See Also
+ --------
+ ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
+ ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10]
+    ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epsilon scaling [9][10]
+
+ """
+ b = np.asarray(b, dtype=np.float64)
+ if len(b.shape) < 2:
+ b = b[:, None]
+ if method.lower() == 'sinkhorn':
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() == 'sinkhorn_stabilized':
+ return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
+ elif method.lower() in ['sinkhorn_reg_scaling']:
+ warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError('Unknown method %s.' % method)
+
+
+def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
+ stopThr=1e-6, verbose=False, log=False, **kwargs):
+ r"""
+    Solve the entropic regularization unbalanced optimal transport problem and return the OT plan
+
+ The function solves the following optimization problem:
+
+ .. math::
+        W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
+
+ s.t.
+ \gamma\geq 0
+ where :
+
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target unbalanced distributions
+ - KL is the Kullback-Leibler divergence
+
+ The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+
+
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
+ loss matrix
+ reg : float
+ Entropy regularization term > 0
+ reg_m: float
+ Marginal relaxation term > 0
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ if n_hists == 1:
+ gamma : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+ else:
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
+ log : dict
+ log dictionary returned only if `log` is `True`
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.],[1., 0.]]
+ >>> ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.)
+ array([[0.51122823, 0.18807035],
+ [0.18807035, 0.51122823]])
+
+ References
+ ----------
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+
+ dim_a, dim_b = M.shape
+
+ if len(a) == 0:
+ a = np.ones(dim_a, dtype=np.float64) / dim_a
+ if len(b) == 0:
+ b = np.ones(dim_b, dtype=np.float64) / dim_b
+
+ if len(b.shape) > 1:
+ n_hists = b.shape[1]
+ else:
+ n_hists = 0
+
+ if log:
+ log = {'err': []}
+
+ # we assume that no distances are null except those of the diagonal of
+ # distances
+ if n_hists:
+ u = np.ones((dim_a, 1)) / dim_a
+ v = np.ones((dim_b, n_hists)) / dim_b
+ a = a.reshape(dim_a, 1)
+ else:
+ u = np.ones(dim_a) / dim_a
+ v = np.ones(dim_b) / dim_b
+
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ fi = reg_m / (reg_m + reg)
+
+ cpt = 0
+ err = 1.
+
+ while (err > stopThr and cpt < numItermax):
+ uprev = u
+ vprev = v
+
+ Kv = K.dot(v)
+ u = (a / Kv) ** fi
+ Ktu = K.T.dot(u)
+ v = (b / Ktu) ** fi
+
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration %s' % cpt)
+ u = uprev
+ v = vprev
+ break
+ if cpt % 10 == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+ err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.)
+ err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.)
+ err = 0.5 * (err_u + err_v)
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+ cpt += 1
+
+ if log:
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
+
+ if n_hists: # return only loss
+ res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
+ if log:
+ return res, log
+ else:
+ return res
+
+ else: # return OT matrix
+
+ if log:
+ return u[:, None] * K * v[None, :], log
+ else:
+ return u[:, None] * K * v[None, :]
+
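+# Editor's note (illustration only): the marginal relaxation enters only
+# through the exponent fi = reg_m / (reg_m + reg): as reg_m grows, fi -> 1
+# and the updates recover the classic balanced Sinkhorn scaling, while
+# fi -> 0 leaves the marginals essentially unconstrained. A quick property
+# check on the symmetric toy problem from the docstring:
+#
+# >>> G = sinkhorn_knopp_unbalanced([.5, .5], [.5, .5],
+# ...                               [[0., 1.], [1., 0.]], 1., 1.)
+# >>> bool(np.allclose(G, G.T))  # symmetric data gives a symmetric plan
+# True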
+
+def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000,
+ stopThr=1e-6, verbose=False, log=False,
+ **kwargs):
+ r"""
+ Solve the entropic regularization unbalanced optimal transport
+ problem and return the loss
+
+ The function solves the following optimization problem using log-domain
+ stabilization as proposed in [10]:
+
+ .. math::
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
+
+ s.t.
+ \gamma\geq 0
+ where :
+
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization
+ term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target unbalanced distributions
+ - KL is the Kullback-Leibler divergence
+
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+
+
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
+ loss matrix
+ reg : float
+ Entropy regularization term > 0
+ reg_m: float
+ Marginal relaxation term > 0
+ tau : float
+        threshold on the max value of u or v triggering a log-domain absorption
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ if n_hists == 1:
+ gamma : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+ else:
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
+ log : dict
+            log dictionary returned only if `log` is `True`
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.],[1., 0.]]
+ >>> ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.)
+ array([[0.51122823, 0.18807035],
+ [0.18807035, 0.51122823]])
+
+ References
+ ----------
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+
+ dim_a, dim_b = M.shape
+
+ if len(a) == 0:
+ a = np.ones(dim_a, dtype=np.float64) / dim_a
+ if len(b) == 0:
+ b = np.ones(dim_b, dtype=np.float64) / dim_b
+
+ if len(b.shape) > 1:
+ n_hists = b.shape[1]
+ else:
+ n_hists = 0
+
+ if log:
+ log = {'err': []}
+
+ # we assume that no distances are null except those of the diagonal of
+ # distances
+ if n_hists:
+ u = np.ones((dim_a, n_hists)) / dim_a
+ v = np.ones((dim_b, n_hists)) / dim_b
+ a = a.reshape(dim_a, 1)
+ else:
+ u = np.ones(dim_a) / dim_a
+ v = np.ones(dim_b) / dim_b
+
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ fi = reg_m / (reg_m + reg)
+
+ cpt = 0
+ err = 1.
+ alpha = np.zeros(dim_a)
+ beta = np.zeros(dim_b)
+ while (err > stopThr and cpt < numItermax):
+ uprev = u
+ vprev = v
+
+ Kv = K.dot(v)
+ f_alpha = np.exp(- alpha / (reg + reg_m))
+ f_beta = np.exp(- beta / (reg + reg_m))
+
+ if n_hists:
+ f_alpha = f_alpha[:, None]
+ f_beta = f_beta[:, None]
+ u = ((a / (Kv + 1e-16)) ** fi) * f_alpha
+ Ktu = K.T.dot(u)
+ v = ((b / (Ktu + 1e-16)) ** fi) * f_beta
+ absorbing = False
+ if (u > tau).any() or (v > tau).any():
+ absorbing = True
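+ # log-domain absorption: fold the large scalings into the dual
+ # potentials alpha and beta, rebuild the kernel at the shifted cost,
+ # then restart v from ones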
+ if n_hists:
+ alpha = alpha + reg * np.log(np.max(u, 1))
+ beta = beta + reg * np.log(np.max(v, 1))
+ else:
+ alpha = alpha + reg * np.log(np.max(u))
+ beta = beta + reg * np.log(np.max(v))
+ K = np.exp((alpha[:, None] + beta[None, :] -
+ M) / reg)
+ v = np.ones_like(v)
+ Kv = K.dot(v)
+
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration %s' % cpt)
+ u = uprev
+ v = vprev
+ break
+ if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+ # we can speed up the process by checking for the error only every
+ # 10th iteration
+ err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(),
+ 1.)
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+ cpt = cpt + 1
+
+ if err > stopThr:
+ warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
+ "Try a larger entropy `reg` or a lower mass `reg_m`." +
+ "Or a larger absorption threshold `tau`.")
+ if n_hists:
+ logu = alpha[:, None] / reg + np.log(u)
+ logv = beta[:, None] / reg + np.log(v)
+ else:
+ logu = alpha / reg + np.log(u)
+ logv = beta / reg + np.log(v)
+ if log:
+ log['logu'] = logu
+ log['logv'] = logv
+ if n_hists: # return only loss
+ res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] +
+ logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1))
+ res = np.exp(res)
+ if log:
+ return res, log
+ else:
+ return res
+
+ else: # return OT matrix
+ ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg)
+ if log:
+ return ot_matrix, log
+ else:
+ return ot_matrix
+
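+# A minimal usage sketch (illustrative only, not part of the API): a smaller
+# absorption threshold `tau` triggers log-domain absorptions more often,
+# which costs a few kernel rebuilds but keeps u and v in a safe range when
+# `reg` is small:
+#
+# >>> import numpy as np
+# >>> a = np.array([.5, .5])
+# >>> M = np.array([[0., 1.], [1., 0.]])
+# >>> G = sinkhorn_stabilized_unbalanced(a, a, M, reg=0.1, reg_m=1., tau=10.)
+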
+
+def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
+ numItermax=1000, stopThr=1e-6,
+ verbose=False, log=False):
+ r"""Compute the entropic unbalanced wasserstein barycenter of A with stabilization.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+ Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of
+ matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
+ - reg_m is the marginal relaxation hyperparameter
+
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+ Parameters
+ ----------
+ A : np.ndarray (dim, n_hists)
+ `n_hists` training distributions a_i of dimension dim
+ M : np.ndarray (dim, dim)
+ ground metric matrix for OT.
+ reg : float
+ Entropy regularization term > 0
+ reg_m : float
+ Marginal relaxation term > 0
+ tau : float
+ Stabilization threshold for log domain absorption.
+ weights : np.ndarray (n_hists,), optional
+ Weight of each distribution (barycentric coordinates)
+ If None, uniform weights are used.
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Unbalanced Wasserstein barycenter
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré,
+ G. (2015). Iterative Bregman projections for regularized transportation
+ problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
+
+ """
+ dim, n_hists = A.shape
+ if weights is None:
+ weights = np.ones(n_hists) / n_hists
+ else:
+ assert len(weights) == A.shape[1]
+
+ if log:
+ log = {'err': []}
+
+ fi = reg_m / (reg_m + reg)
+
+ u = np.ones((dim, n_hists)) / dim
+ v = np.ones((dim, n_hists)) / dim
+
+ # Next 3 lines equivalent to K = np.exp(-M / reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ cpt = 0
+ err = 1.
+ alpha = np.zeros(dim)
+ beta = np.zeros(dim)
+ q = np.ones(dim) / dim
+ while (err > stopThr and cpt < numItermax):
+ qprev = q
+ Kv = K.dot(v)
+ f_alpha = np.exp(- alpha / (reg + reg_m))
+ f_beta = np.exp(- beta / (reg + reg_m))
+ f_alpha = f_alpha[:, None]
+ f_beta = f_beta[:, None]
+ u = ((A / (Kv + 1e-16)) ** fi) * f_alpha
+ Ktu = K.T.dot(u)
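+ # barycenter update: q is a weighted power mean of the projected
+ # marginals Ktu (exponent 1 - fi, then root 1 / (1 - fi)); it tends
+ # to a geometric mean as fi -> 1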
+ q = (Ktu ** (1 - fi)) * f_beta
+ q = q.dot(weights) ** (1 / (1 - fi))
+ Q = q[:, None]
+ v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta
+ absorbing = False
+ if (u > tau).any() or (v > tau).any():
+ absorbing = True
+ alpha = alpha + reg * np.log(np.max(u, 1))
+ beta = beta + reg * np.log(np.max(v, 1))
+ K = np.exp((alpha[:, None] + beta[None, :] -
+ M) / reg)
+ v = np.ones_like(v)
+ Kv = K.dot(v)
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration %s' % cpt)
+ q = qprev
+ break
+ if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+ # we can speed up the process by checking for the error only every
+ # 10th iteration
+ err = abs(q - qprev).max() / max(abs(q).max(),
+ abs(qprev).max(), 1.)
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 50 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+ if err > stopThr:
+ warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
+ "Try a larger entropy `reg` or a lower mass `reg_m`." +
+ "Or a larger absorption threshold `tau`.")
+ if log:
+ log['niter'] = cpt
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
+ return q, log
+ else:
+ return q
+
+
+def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
+ numItermax=1000, stopThr=1e-6,
+ verbose=False, log=False):
+ r"""Compute the entropic unbalanced wasserstein barycenter of A.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+ Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
+ - reg_m is the marginal relaxation hyperparameter
+
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+ Parameters
+ ----------
+ A : np.ndarray (dim, n_hists)
+ `n_hists` training distributions a_i of dimension dim
+ M : np.ndarray (dim, dim)
+ ground metric matrix for OT.
+ reg : float
+ Entropy regularization term > 0
+ reg_m: float
+ Marginal relaxation term > 0
+ weights : np.ndarray (n_hists,), optional
+ Weight of each distribution (barycentric coordinates)
+ If None, uniform weights are used.
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Unbalanced Wasserstein barycenter
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
+ (2015). Iterative Bregman projections for regularized transportation
+ problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
+
+ """
+ dim, n_hists = A.shape
+ if weights is None:
+ weights = np.ones(n_hists) / n_hists
+ else:
+ assert len(weights) == A.shape[1]
+
+ if log:
+ log = {'err': []}
+
+ K = np.exp(- M / reg)
+
+ fi = reg_m / (reg_m + reg)
+
+ v = np.ones((dim, n_hists)) / dim
+ u = np.ones((dim, 1)) / dim
+
+ cpt = 0
+ err = 1.
+
+ while (err > stopThr and cpt < numItermax):
+ uprev = u
+ vprev = v
+
+ Kv = K.dot(v)
+ u = (A / Kv) ** fi
+ Ktu = K.T.dot(u)
+ q = ((Ktu ** (1 - fi)).dot(weights))
+ q = q ** (1 / (1 - fi))
+ Q = q[:, None]
+ v = (Q / Ktu) ** fi
+
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration %s' % cpt)
+ u = uprev
+ v = vprev
+ break
+ if cpt % 10 == 0:
+ # we can speed up the process by checking for the error only every
+ # 10th iteration
+ err_u = abs(u - uprev).max()
+ err_u /= max(abs(u).max(), abs(uprev).max(), 1.)
+ err_v = abs(v - vprev).max()
+ err_v /= max(abs(v).max(), abs(vprev).max(), 1.)
+ err = 0.5 * (err_u + err_v)
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 50 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+ if log:
+ log['niter'] = cpt
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
+ return q, log
+ else:
+ return q
+
+
+def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
+ numItermax=1000, stopThr=1e-6,
+ verbose=False, log=False, **kwargs):
+ r"""Compute the entropic unbalanced wasserstein barycenter of A.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+ Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
+ - reg_m is the marginal relaxation hyperparameter
+
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+ Parameters
+ ----------
+ A : np.ndarray (dim, n_hists)
+ `n_hists` training distributions a_i of dimension dim
+ M : np.ndarray (dim, dim)
+ ground metric matrix for OT.
+ reg : float
+ Entropy regularization term > 0
+ reg_m: float
+ Marginal relaxation term > 0
+ weights : np.ndarray (n_hists,), optional
+ Weight of each distribution (barycentric coordinates)
+ If None, uniform weights are used.
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Unbalanced Wasserstein barycenter
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
+ (2015). Iterative Bregman projections for regularized transportation
+ problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
+ """
+
+ if method.lower() == 'sinkhorn':
+ return barycenter_unbalanced_sinkhorn(A, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() == 'sinkhorn_stabilized':
+ return barycenter_unbalanced_stabilized(A, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
+ elif method.lower() == 'sinkhorn_reg_scaling':
+ warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
+ return barycenter_unbalanced(A, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
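+
+
+# A minimal usage sketch (illustrative only): an unbalanced barycenter of
+# two histograms on a 3-bin 1D grid with a quadratic ground cost, via the
+# stabilized solver:
+#
+# >>> import numpy as np
+# >>> A = np.array([[.2, .5], [.5, .3], [.3, .2]])
+# >>> x = np.arange(3, dtype=np.float64).reshape(-1, 1)
+# >>> M = (x - x.T) ** 2
+# >>> bary = barycenter_unbalanced(A, M, reg=1., reg_m=1.,
+# ... method='sinkhorn_stabilized')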
diff --git a/ot/utils.py b/ot/utils.py
new file mode 100644
index 0000000..b71458b
--- /dev/null
+++ b/ot/utils.py
@@ -0,0 +1,498 @@
+# -*- coding: utf-8 -*-
+"""
+Various useful functions
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import multiprocessing
+from functools import reduce
+import time
+
+import numpy as np
+from scipy.spatial.distance import cdist
+import sys
+import warnings
+try:
+ from inspect import signature
+except ImportError:
+ from .externals.funcsigs import signature
+
+__time_tic_toc = time.time()
+
+
+def tic():
+ """ Python implementation of Matlab tic() function """
+ global __time_tic_toc
+ __time_tic_toc = time.time()
+
+
+def toc(message='Elapsed time : {} s'):
+ """ Python implementation of Matlab toc() function """
+ t = time.time()
+ print(message.format(t - __time_tic_toc))
+ return t - __time_tic_toc
+
+
+def toq():
+ """ Python implementation of Julia toc() function """
+ t = time.time()
+ return t - __time_tic_toc
+
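+# Typical use (illustrative): bracket a computation with tic() / toc() to
+# print the elapsed time, or call toq() to retrieve it silently:
+#
+# >>> tic()
+# >>> _ = sum(range(10 ** 6))
+# >>> elapsed = toq()
+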
+
+def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
+ """Compute kernel matrix"""
+ if method.lower() in ['gaussian', 'gauss', 'rbf']:
+ K = np.exp(-dist(x1, x2) / (2 * sigma**2))
+ return K
+
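+# Example (illustrative): a Gaussian kernel matrix between two sample sets,
+# with bandwidth sigma:
+#
+# >>> import numpy as np
+# >>> K = kernel(np.random.randn(5, 2), np.random.randn(4, 2), sigma=0.5)
+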
+
+def unif(n):
+ """ return a uniform histogram of length n (simplex)
+
+ Parameters
+ ----------
+
+ n : int
+ number of bins in the histogram
+
+ Returns
+ -------
+ h : np.array (n,)
+ histogram of length n such that h_i=1/n for all i
+
+
+ """
+ return np.ones((n,)) / n
+
+
+def clean_zeros(a, b, M):
+ """ Remove all components with zeros weights in a and b
+ """
+ M2 = M[a > 0, :][:, b > 0].copy() # copy forces a C-style matrix (for emd)
+ a2 = a[a > 0]
+ b2 = b[b > 0]
+ return a2, b2, M2
+
+
+def euclidean_distances(X, Y, squared=False):
+ """
+ Considering the rows of X (and Y=X) as vectors, compute the
+ distance matrix between each pair of vectors.
+
+ Parameters
+ ----------
+ X : array-like, shape (n_samples_1, n_features)
+ Y : array-like, shape (n_samples_2, n_features)
+ squared : boolean, optional
+ Return squared Euclidean distances.
+
+ Returns
+ -------
+ distances : array, shape (n_samples_1, n_samples_2)
+ """
+ XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis]
+ YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :]
+ distances = np.dot(X, Y.T)
+ distances *= -2
+ distances += XX
+ distances += YY
+ np.maximum(distances, 0, out=distances)
+ if X is Y:
+ # Ensure that distances between vectors and themselves are set to 0.0.
+ # This may not be the case due to floating point rounding errors.
+ distances.flat[::distances.shape[0] + 1] = 0.0
+ return distances if squared else np.sqrt(distances, out=distances)
+
+
+def dist(x1, x2=None, metric='sqeuclidean'):
+ """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
+
+ Parameters
+ ----------
+
+ x1 : ndarray, shape (n1,d)
+ matrix with n1 samples of size d
+ x2 : array, shape (n2,d), optional
+ matrix with n2 samples of size d (if None then x2=x1)
+ metric : str | callable, optional
+ Name of the metric to be computed (full list in the doc of scipy), If a string,
+ the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock',
+ 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski',
+ 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
+ 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'.
+
+
+ Returns
+ -------
+
+ M : np.array (n1,n2)
+ distance matrix computed with given metric
+
+ """
+ if x2 is None:
+ x2 = x1
+ if metric == "sqeuclidean":
+ return euclidean_distances(x1, x2, squared=True)
+ return cdist(x1, x2, metric=metric)
+
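+# Example (illustrative): pairwise costs between two point clouds; the
+# default metric is the squared Euclidean distance:
+#
+# >>> import numpy as np
+# >>> x, y = np.random.randn(5, 2), np.random.randn(4, 2)
+# >>> M = dist(x, y) # shape (5, 4)
+# >>> M1 = dist(x, y, metric='cityblock') # delegates to scipy's cdist
+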
+
+def dist0(n, method='lin_square'):
+ """Compute standard cost matrices of size (n, n) for OT problems
+
+ Parameters
+ ----------
+ n : int
+ Size of the cost matrix.
+ method : str, optional
+ Type of loss matrix chosen from:
+
+ * 'lin_square' : linear sampling between 0 and n-1, quadratic loss
+
+ Returns
+ -------
+ M : ndarray, shape (n1,n2)
+ Distance matrix computed with given metric.
+ """
+ res = 0
+ if method == 'lin_square':
+ x = np.arange(n, dtype=np.float64).reshape((n, 1))
+ res = dist(x, x)
+ return res
+
+
+def cost_normalization(C, norm=None):
+ """ Apply normalization to the loss matrix
+
+ Parameters
+ ----------
+ C : ndarray, shape (n1, n2)
+ The cost matrix to normalize.
+ norm : str
+ Type of normalization from 'median', 'max', 'log', 'loglog'. None
+ leaves the matrix unchanged; any other value raises a ValueError.
+
+ Returns
+ -------
+ C : ndarray, shape (n1, n2)
+ The input cost matrix normalized according to given norm.
+ """
+
+ if norm is None:
+ pass
+ elif norm == "median":
+ C /= float(np.median(C))
+ elif norm == "max":
+ C /= float(np.max(C))
+ elif norm == "log":
+ C = np.log(1 + C)
+ elif norm == "loglog":
+ C = np.log1p(np.log1p(C))
+ else:
+ raise ValueError('Norm %s is not a valid option.\n'
+ 'Valid options are:\n'
+ 'median, max, log, loglog' % norm)
+ return C
+
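+# Example (illustrative): 'max' rescales costs to [0, 1] while 'log'
+# compresses their dynamic range; both help keep exp(-M / reg) away from
+# underflow in entropic solvers:
+#
+# >>> C = cost_normalization(dist0(4), norm='max')
+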
+
+def dots(*args):
+ """ dots function for multiple matrix multiply """
+ return reduce(np.dot, args)
+
+
+def fun(f, q_in, q_out):
+ """ Utility function for parmap with no serializing problems """
+ while True:
+ i, x = q_in.get()
+ if i is None:
+ break
+ q_out.put((i, f(x)))
+
+
+def parmap(f, X, nprocs=multiprocessing.cpu_count()):
+ """ paralell map for multiprocessing (only map on windows)"""
+
+ if not sys.platform.endswith('win32'):
+
+ q_in = multiprocessing.Queue(1)
+ q_out = multiprocessing.Queue()
+
+ proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out))
+ for _ in range(nprocs)]
+ for p in proc:
+ p.daemon = True
+ p.start()
+
+ sent = [q_in.put((i, x)) for i, x in enumerate(X)]
+ [q_in.put((None, None)) for _ in range(nprocs)]
+ res = [q_out.get() for _ in range(len(sent))]
+
+ [p.join() for p in proc]
+
+ return [x for i, x in sorted(res)]
+ else:
+ return list(map(f, X))
+
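+# Example (illustrative): apply a top-level function to each item in
+# parallel on POSIX platforms (plain map is used on Windows):
+#
+# >>> def _square(v):
+# ... return v ** 2
+# >>> res = parmap(_square, range(8), nprocs=2)
+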
+
+def check_params(**kwargs):
+ """check_params: check whether some parameters are missing
+ """
+
+ missing_params = []
+ check = True
+
+ for param in kwargs:
+ if kwargs[param] is None:
+ missing_params.append(param)
+
+ if len(missing_params) > 0:
+ print("POT - Warning: following necessary parameters are missing")
+ for p in missing_params:
+ print("\n", p)
+
+ check = False
+
+ return check
+
+
+def check_random_state(seed):
+ """Turn seed into a np.random.RandomState instance
+
+ Parameters
+ ----------
+ seed : None | int | instance of RandomState
+ If seed is None, return the RandomState singleton used by np.random.
+ If seed is an int, return a new RandomState instance seeded with seed.
+ If seed is already a RandomState instance, return it.
+ Otherwise raise ValueError.
+ """
+ if seed is None or seed is np.random:
+ return np.random.mtrand._rand
+ if isinstance(seed, (int, np.integer)):
+ return np.random.RandomState(seed)
+ if isinstance(seed, np.random.RandomState):
+ return seed
+ raise ValueError('{} cannot be used to seed a numpy.random.RandomState'
+ ' instance'.format(seed))
+
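+# Example (illustrative): normalize a user-provided seed once, then draw all
+# randomness from the returned RandomState:
+#
+# >>> rng = check_random_state(42)
+# >>> noise = rng.randn(3)
+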
+
+class deprecated(object):
+ """Decorator to mark a function or class as deprecated.
+
+ deprecated class from scikit-learn package
+ https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py
+ Issue a warning when the function is called/the class is instantiated and
+ adds a warning to the docstring.
+ The optional extra argument will be appended to the deprecation message
+ and the docstring. Note: to use this with the default value for extra, put
+ in an empty set of parentheses:
+ >>> from ot.deprecation import deprecated # doctest: +SKIP
+ >>> @deprecated() # doctest: +SKIP
+ ... def some_function(): pass # doctest: +SKIP
+
+ Parameters
+ ----------
+ extra : str
+ To be added to the deprecation messages.
+ """
+
+ # Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary,
+ # but with many changes.
+
+ def __init__(self, extra=''):
+ self.extra = extra
+
+ def __call__(self, obj):
+ """Call method
+ Parameters
+ ----------
+ obj : object
+ """
+ if isinstance(obj, type):
+ return self._decorate_class(obj)
+ else:
+ return self._decorate_fun(obj)
+
+ def _decorate_class(self, cls):
+ msg = "Class %s is deprecated" % cls.__name__
+ if self.extra:
+ msg += "; %s" % self.extra
+
+ # FIXME: we should probably reset __new__ for full generality
+ init = cls.__init__
+
+ def wrapped(*args, **kwargs):
+ warnings.warn(msg, category=DeprecationWarning)
+ return init(*args, **kwargs)
+
+ cls.__init__ = wrapped
+
+ wrapped.__name__ = '__init__'
+ wrapped.__doc__ = self._update_doc(init.__doc__)
+ wrapped.deprecated_original = init
+
+ return cls
+
+ def _decorate_fun(self, fun):
+ """Decorate function fun"""
+
+ msg = "Function %s is deprecated" % fun.__name__
+ if self.extra:
+ msg += "; %s" % self.extra
+
+ def wrapped(*args, **kwargs):
+ warnings.warn(msg, category=DeprecationWarning)
+ return fun(*args, **kwargs)
+
+ wrapped.__name__ = fun.__name__
+ wrapped.__dict__ = fun.__dict__
+ wrapped.__doc__ = self._update_doc(fun.__doc__)
+
+ return wrapped
+
+ def _update_doc(self, olddoc):
+ newdoc = "DEPRECATED"
+ if self.extra:
+ newdoc = "%s: %s" % (newdoc, self.extra)
+ if olddoc:
+ newdoc = "%s\n\n%s" % (newdoc, olddoc)
+ return newdoc
+
+
+def _is_deprecated(func):
+ """Helper to check if func is wraped by our deprecated decorator"""
+ if sys.version_info < (3, 5):
+ raise NotImplementedError("This is only available for python3.5 "
+ "or above")
+ closures = getattr(func, '__closure__', [])
+ if closures is None:
+ closures = []
+ is_deprecated = ('deprecated' in ''.join([c.cell_contents
+ for c in closures
+ if isinstance(c.cell_contents, str)]))
+ return is_deprecated
+
+
+class BaseEstimator(object):
+ """Base class for most objects in POT
+
+ Code adapted from sklearn BaseEstimator class
+
+ Notes
+ -----
+ All estimators should specify all the parameters that can be set
+ at the class level in their ``__init__`` as explicit keyword
+ arguments (no ``*args`` or ``**kwargs``).
+ """
+
+ @classmethod
+ def _get_param_names(cls):
+ """Get parameter names for the estimator"""
+
+ # fetch the constructor or the original constructor before
+ # deprecation wrapping if any
+ init = getattr(cls.__init__, 'deprecated_original', cls.__init__)
+ if init is object.__init__:
+ # No explicit constructor to introspect
+ return []
+
+ # introspect the constructor arguments to find the model parameters
+ # to represent
+ init_signature = signature(init)
+ # Consider the constructor parameters excluding 'self'
+ parameters = [p for p in init_signature.parameters.values()
+ if p.name != 'self' and p.kind != p.VAR_KEYWORD]
+ for p in parameters:
+ if p.kind == p.VAR_POSITIONAL:
+ raise RuntimeError("POT estimators should always "
+ "specify their parameters in the signature"
+ " of their __init__ (no varargs)."
+ " %s with constructor %s doesn't "
+ " follow this convention."
+ % (cls, init_signature))
+ # Extract and sort argument names excluding 'self'
+ return sorted([p.name for p in parameters])
+
+ def get_params(self, deep=True):
+ """Get parameters for this estimator.
+
+ Parameters
+ ----------
+ deep : bool, optional
+ If True, will return the parameters for this estimator and
+ contained subobjects that are estimators.
+
+ Returns
+ -------
+ params : mapping of string to any
+ Parameter names mapped to their values.
+ """
+ out = dict()
+ for key in self._get_param_names():
+ # We need deprecation warnings to always be on in order to
+ # catch deprecated param values.
+ # This is set in utils/__init__.py but it gets overwritten
+ # when running under python3 somehow.
+ warnings.simplefilter("always", DeprecationWarning)
+ try:
+ with warnings.catch_warnings(record=True) as w:
+ value = getattr(self, key, None)
+ if len(w) and w[0].category == DeprecationWarning:
+ # if the parameter is deprecated, don't show it
+ continue
+ finally:
+ warnings.filters.pop(0)
+
+ # XXX: should we rather test if instance of estimator?
+ if deep and hasattr(value, 'get_params'):
+ deep_items = value.get_params().items()
+ out.update((key + '__' + k, val) for k, val in deep_items)
+ out[key] = value
+ return out
+
+ def set_params(self, **params):
+ """Set the parameters of this estimator.
+
+ The method works on simple estimators as well as on nested objects
+ (such as pipelines). The latter have parameters of the form
+ ``<component>__<parameter>`` so that it's possible to update each
+ component of a nested object.
+
+ Returns
+ -------
+ self
+ """
+ if not params:
+ # Simple optimisation to gain speed (inspect is slow)
+ return self
+ valid_params = self.get_params(deep=True)
+ for key, value in params.items():
+ split = key.split('__', 1)
+ if len(split) > 1:
+ # nested objects case
+ name, sub_name = split
+ if name not in valid_params:
+ raise ValueError('Invalid parameter %s for estimator %s. '
+ 'Check the list of available parameters '
+ 'with `estimator.get_params().keys()`.' %
+ (name, self))
+ sub_object = valid_params[name]
+ sub_object.set_params(**{sub_name: value})
+ else:
+ # simple objects case
+ if key not in valid_params:
+ raise ValueError('Invalid parameter %s for estimator %s. '
+ 'Check the list of available parameters '
+ 'with `estimator.get_params().keys()`.' %
+ (key, self.__class__.__name__))
+ setattr(self, key, value)
+ return self
+
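+# A minimal usage sketch (illustrative, hypothetical subclass): estimators
+# expose their constructor arguments through get_params / set_params:
+#
+# >>> class Mapping(BaseEstimator):
+# ... def __init__(self, reg=1.0):
+# ... self.reg = reg
+# >>> m = Mapping().set_params(reg=0.5)
+# >>> m.get_params()
+# {'reg': 0.5}
+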
+
+class UndefinedParameter(Exception):
+ """
+ Aims at raising an Exception when an undefined parameter is called
+
+ """
+ pass