author    | Gard Spreemann <gspr@nonempty.org> | 2020-01-20 14:07:53 +0100
committer | Gard Spreemann <gspr@nonempty.org> | 2020-01-20 14:07:53 +0100
commit    | bdfb24ff37ea777d6e266b145047cd4e281ebac3 (patch)
tree      | 00cbac5f3dc25a4ee76164828abd72c1cbab37cc /ot
parent    | abc441b00f0fe2fa4ef0efc4e1aa67b27cca9a13 (diff)
parent    | 5e70a77fbb2feec513f21c9ef65dcc535329ace6 (diff)
Merge tag '0.6.0' into debian/sid
Diffstat (limited to 'ot')
-rw-r--r-- | ot/__init__.py | 81
-rw-r--r-- | ot/bregman.py | 1787
-rw-r--r-- | ot/da.py | 1916
-rw-r--r-- | ot/datasets.py | 164
-rw-r--r-- | ot/dr.py | 200
-rw-r--r-- | ot/externals/__init__.py | 0
-rw-r--r-- | ot/externals/funcsigs.py | 817
-rw-r--r-- | ot/gpu/__init__.py | 41
-rw-r--r-- | ot/gpu/bregman.py | 194
-rw-r--r-- | ot/gpu/da.py | 144
-rw-r--r-- | ot/gpu/utils.py | 101
-rw-r--r-- | ot/gromov.py | 1179
-rw-r--r-- | ot/lp/EMD.h | 35
-rw-r--r-- | ot/lp/EMD_wrapper.cpp | 107
-rw-r--r-- | ot/lp/__init__.py | 618
-rw-r--r-- | ot/lp/core.h | 103
-rw-r--r-- | ot/lp/cvx.py | 147
-rw-r--r-- | ot/lp/emd_wrap.pyx | 187
-rw-r--r-- | ot/lp/full_bipartitegraph.h | 215
-rw-r--r-- | ot/lp/network_simplex_simple.h | 1553
-rw-r--r-- | ot/optim.py | 440
-rw-r--r-- | ot/plot.py | 91
-rw-r--r-- | ot/smooth.py | 600
-rw-r--r-- | ot/stochastic.py | 755
-rw-r--r-- | ot/unbalanced.py | 1022
-rw-r--r-- | ot/utils.py | 498
26 files changed, 12995 insertions, 0 deletions
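The merge brings in the full POT 0.6.0 Python package. For orientation, here is a minimal usage sketch of the top-level API that the `ot/__init__.py` below re-exports; this example is illustrative, not part of the patch, and assumes POT 0.6.0 is installed:

import numpy as np
import ot

# toy point clouds (illustrative data, not from the patch)
rng = np.random.RandomState(0)
xs = rng.randn(5, 2)        # 5 source samples in R^2
xt = rng.randn(6, 2) + 1.0  # 6 target samples in R^2

a = ot.unif(5)              # uniform source weights
b = ot.unif(6)              # uniform target weights
M = ot.dist(xs, xt)         # squared Euclidean cost matrix, shape (5, 6)

G_exact = ot.emd(a, b, M)          # exact (LP) transport plan
G_ent = ot.sinkhorn(a, b, M, 1.0)  # entropic transport plan, reg = 1
loss = ot.sinkhorn2(a, b, M, 1.0)  # entropic OT loss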
diff --git a/ot/__init__.py b/ot/__init__.py new file mode 100644 index 0000000..89c7936 --- /dev/null +++ b/ot/__init__.py @@ -0,0 +1,81 @@ +""" + +This is the main module of the POT toolbox. It provides easy access to +a number of sub-modules and functions described below. + +.. note:: + + Here is a list of the submodules and a short description of what they contain. + + - :any:`ot.lp` contains OT solvers for the exact (Linear Program) OT problems. + - :any:`ot.bregman` contains OT solvers for the entropic OT problems using + Bregman projections. + - :any:`ot.smooth` contains OT solvers for the regularized (l2 and kl) smooth OT + problems. + - :any:`ot.gromov` contains solvers for Gromov-Wasserstein and Fused Gromov + Wasserstein problems. + - :any:`ot.optim` contains generic solvers for OT-based optimization problems. + - :any:`ot.da` contains classes and functions related to Monge mapping + estimation and Domain Adaptation (DA). + - :any:`ot.gpu` contains GPU (cupy) implementations of some OT solvers. + - :any:`ot.dr` contains Dimension Reduction (DR) methods such as Wasserstein + Discriminant Analysis. + - :any:`ot.utils` contains utility functions such as distance computation and + timing. + - :any:`ot.datasets` contains toy dataset generation functions. + - :any:`ot.plot` contains visualization functions. + - :any:`ot.stochastic` contains stochastic solvers for regularized OT. + - :any:`ot.unbalanced` contains solvers for regularized unbalanced OT. + +.. warning:: + The list of automatically imported sub-modules is as follows: + :py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim`, + :py:mod:`ot.utils`, :py:mod:`ot.datasets`, + :py:mod:`ot.gromov`, :py:mod:`ot.smooth`, + :py:mod:`ot.stochastic`. + + The following sub-modules are not imported due to additional dependencies: + + - :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`. + - :any:`ot.gpu` : depends on :code:`cupy` and a CUDA GPU. + - :any:`ot.plot` : depends on :code:`matplotlib`. + +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> + # Nicolas Courty <ncourty@irisa.fr> + # + # License: MIT License + + + # All submodules and packages + from . import lp + from . import bregman + from . import optim + from . import utils + from . import datasets + from . import da + from . import gromov + from . import smooth + from . import stochastic + from . 
import unbalanced + +# OT functions +from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d +from .bregman import sinkhorn, sinkhorn2, barycenter +from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2 +from .da import sinkhorn_lpl1_mm + +# utils functions +from .utils import dist, unif, tic, toc, toq + +__version__ = "0.6.0" + +__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', + 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', + 'emd_1d', 'emd2_1d', 'wasserstein_1d', + 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', + 'sinkhorn_unbalanced', 'barycenter_unbalanced', + 'sinkhorn_unbalanced2'] diff --git a/ot/bregman.py b/ot/bregman.py new file mode 100644 index 0000000..2cd832b --- /dev/null +++ b/ot/bregman.py @@ -0,0 +1,1787 @@ +# -*- coding: utf-8 -*- +""" +Bregman projections for regularized OT +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# Nicolas Courty <ncourty@irisa.fr> +# Kilian Fatras <kilian.fatras@irisa.fr> +# Titouan Vayer <titouan.vayer@irisa.fr> +# Hicham Janati <hicham.janati@inria.fr> +# +# License: MIT License + +import numpy as np +import warnings +from .utils import unif, dist + + +def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): + r""" + Solve the entropic regularization optimal transport problem and return the OT matrix + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the (dim_a, dim_b) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights (histograms, both sum to 1) + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + + + Parameters + ---------- + a : ndarray, shape (dim_a,) + samples weights in the source domain + b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + samples in the target domain, compute sinkhorn with multiple targets + and fixed M if b is a matrix (return OT loss + dual variables in log) + M : ndarray, shape (dim_a, dim_b) + loss matrix + reg : float + Regularization term >0 + method : str + method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or + 'sinkhorn_epsilon_scaling', see those function for specific parameters + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : ndarray, shape (dim_a, dim_b) + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.], [1., 0.]] + >>> ot.sinkhorn(a, b, M, 1) + array([[0.36552929, 0.13447071], + [0.13447071, 0.36552929]]) + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + + .. 
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + + + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] + ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] + + """ + + if method.lower() == 'sinkhorn': + return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'greenkhorn': + return greenkhorn(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log) + elif method.lower() == 'sinkhorn_stabilized': + return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + elif method.lower() == 'sinkhorn_epsilon_scaling': + return sinkhorn_epsilon_scaling(a, b, M, reg, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + + +def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): + r""" + Solve the entropic regularization optimal transport problem and return the loss + + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the (dim_a, dim_b) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights (histograms, both sum to 1) + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + + + Parameters + ---------- + a : ndarray, shape (dim_a,) + samples weights in the source domain + b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + samples in the target domain, compute sinkhorn with multiple targets + and fixed M if b is a matrix (return OT loss + dual variables in log) + M : ndarray, shape (dim_a, dim_b) + loss matrix + reg : float + Regularization term >0 + method : str + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + 'sinkhorn_epsilon_scaling', see those function for specific parameters + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + W : (n_hists) ndarray or float + Optimal transportation loss for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.], [1., 0.]] + >>> ot.sinkhorn2(a, b, M, 1) + array([0.26894142]) + + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. 
+ + [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 + + + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] + ot.bregman.greenkhorn : Greenkhorn [21] + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] + ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] + + """ + b = np.asarray(b, dtype=np.float64) + if len(b.shape) < 2: + b = b[:, None] + if method.lower() == 'sinkhorn': + return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_stabilized': + return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_epsilon_scaling': + return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + + +def sinkhorn_knopp(a, b, M, reg, numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): + r""" + Solve the entropic regularization optimal transport problem and return the OT matrix + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the (dim_a, dim_b) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights (histograms, both sum to 1) + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + + + Parameters + ---------- + a : ndarray, shape (dim_a,) + samples weights in the source domain + b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + samples in the target domain, compute sinkhorn with multiple targets + and fixed M if b is a matrix (return OT loss + dual variables in log) + M : ndarray, shape (dim_a, dim_b) + loss matrix + reg : float + Regularization term >0 + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + gamma : ndarray, shape (dim_a, dim_b) + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.], [1., 0.]] + >>> ot.sinkhorn(a, b, M, 1) + array([[0.36552929, 0.13447071], + [0.13447071, 0.36552929]]) + + + References + ---------- + + .. [2] M. 
Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64) + + if len(a) == 0: + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + if len(b) == 0: + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + + # init data + dim_a = len(a) + dim_b = len(b) + + if len(b.shape) > 1: + n_hists = b.shape[1] + else: + n_hists = 0 + + if log: + log = {'err': []} + + # we assume that no distances are null except those of the diagonal of + # distances + if n_hists: + u = np.ones((dim_a, n_hists)) / dim_a + v = np.ones((dim_b, n_hists)) / dim_b + else: + u = np.ones(dim_a) / dim_a + v = np.ones(dim_b) / dim_b + + # print(reg) + + # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + + # print(np.min(K)) + tmp2 = np.empty(b.shape, dtype=M.dtype) + + Kp = (1 / a).reshape(-1, 1) * K + cpt = 0 + err = 1 + while (err > stopThr and cpt < numItermax): + uprev = u + vprev = v + + KtransposeU = np.dot(K.T, u) + v = np.divide(b, KtransposeU) + u = 1. / np.dot(Kp, v) + + if (np.any(KtransposeU == 0) + or np.any(np.isnan(u)) or np.any(np.isnan(v)) + or np.any(np.isinf(u)) or np.any(np.isinf(v))): + # we have reached the machine precision + # come back to previous solution and quit loop + print('Warning: numerical errors at iteration', cpt) + u = uprev + v = vprev + break + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + if n_hists: + np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2) + else: + # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 + np.einsum('i,ij,j->j', u, K, v, out=tmp2) + err = np.linalg.norm(tmp2 - b) # violation of marginal + if log: + log['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + cpt = cpt + 1 + if log: + log['u'] = u + log['v'] = v + + if n_hists: # return only loss + res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) + if log: + return res, log + else: + return res + + else: # return OT matrix + + if log: + return u.reshape((-1, 1)) * K * v.reshape((1, -1)), log + else: + return u.reshape((-1, 1)) * K * v.reshape((1, -1)) + + +def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, + log=False): + r""" + Solve the entropic regularization optimal transport problem and return the OT matrix + + The algorithm used is based on the paper + + Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration + by Jason Altschuler, Jonathan Weed, Philippe Rigollet + appeared at NIPS 2017 + + which is a stochastic version of the Sinkhorn-Knopp algorithm [2]. + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. 
\gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the (dim_a, dim_b) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights (histograms, both sum to 1) + + + + Parameters + ---------- + a : ndarray, shape (dim_a,) + samples weights in the source domain + b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + samples in the target domain, compute sinkhorn with multiple targets + and fixed M if b is a matrix (return OT loss + dual variables in log) + M : ndarray, shape (dim_a, dim_b) + loss matrix + reg : float + Regularization term >0 + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + log : bool, optional + record log if True + + Returns + ------- + gamma : ndarray, shape (dim_a, dim_b) + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.], [1., 0.]] + >>> ot.bregman.greenkhorn(a, b, M, 1) + array([[0.36552929, 0.13447071], + [0.13447071, 0.36552929]]) + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 + + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64) + + if len(a) == 0: + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + if len(b) == 0: + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + + dim_a = a.shape[0] + dim_b = b.shape[0] + + # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute + K = np.empty_like(M) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + + u = np.full(dim_a, 1. / dim_a) + v = np.full(dim_b, 1. 
/ dim_b) + G = u[:, np.newaxis] * K * v[np.newaxis, :] + + viol = G.sum(1) - a + viol_2 = G.sum(0) - b + stopThr_val = 1 + + if log: + log = dict() + log['u'] = u + log['v'] = v + + for i in range(numItermax): + i_1 = np.argmax(np.abs(viol)) + i_2 = np.argmax(np.abs(viol_2)) + m_viol_1 = np.abs(viol[i_1]) + m_viol_2 = np.abs(viol_2[i_2]) + stopThr_val = np.maximum(m_viol_1, m_viol_2) + + if m_viol_1 > m_viol_2: + old_u = u[i_1] + u[i_1] = a[i_1] / (K[i_1, :].dot(v)) + G[i_1, :] = u[i_1] * K[i_1, :] * v + + viol[i_1] = u[i_1] * K[i_1, :].dot(v) - a[i_1] + viol_2 += (K[i_1, :].T * (u[i_1] - old_u) * v) + + else: + old_v = v[i_2] + v[i_2] = b[i_2] / (K[:, i_2].T.dot(u)) + G[:, i_2] = u * K[:, i_2] * v[i_2] + # aviol = (G@one_m - a) + # aviol_2 = (G.T@one_n - b) + viol += (-old_v + v[i_2]) * K[:, i_2] * u + viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2] + + # print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2))) + + if stopThr_val <= stopThr: + break + else: + print('Warning: Algorithm did not converge') + + if log: + log['u'] = u + log['v'] = v + + if log: + return G, log + else: + return G + + +def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, + warmstart=None, verbose=False, print_period=20, + log=False, **kwargs): + r""" + Solve the entropic regularization OT problem with log stabilization + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the (dim_a, dim_b) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights (histograms, both sum to 1) + + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in [2]_ but with the log stabilization + proposed in [10]_ and defined in [9]_ (Algo 3.1). + + + Parameters + ---------- + a : ndarray, shape (dim_a,) + samples weights in the source domain + b : ndarray, shape (dim_b,) + samples in the target domain + M : ndarray, shape (dim_a, dim_b) + loss matrix + reg : float + Regularization term >0 + tau : float + threshold for max value in u or v for log scaling + warmstart : tuple of vectors + if given then starting values for alpha and beta log scalings + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + gamma : ndarray, shape (dim_a, dim_b) + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5,.5] + >>> b=[.5,.5] + >>> M=[[0.,1.],[1.,0.]] + >>> ot.bregman.sinkhorn_stabilized(a, b, M, 1) + array([[0.36552929, 0.13447071], + [0.13447071, 0.36552929]]) + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. 
+ + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64) + + if len(a) == 0: + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + if len(b) == 0: + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + + # test if multiple target + if len(b.shape) > 1: + n_hists = b.shape[1] + a = a[:, np.newaxis] + else: + n_hists = 0 + + # init data + dim_a = len(a) + dim_b = len(b) + + cpt = 0 + if log: + log = {'err': []} + + # we assume that no distances are null except those of the diagonal of + # distances + if warmstart is None: + alpha, beta = np.zeros(dim_a), np.zeros(dim_b) + else: + alpha, beta = warmstart + + if n_hists: + u = np.ones((dim_a, n_hists)) / dim_a + v = np.ones((dim_b, n_hists)) / dim_b + else: + u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b + + def get_K(alpha, beta): + """log space computation""" + return np.exp(-(M - alpha.reshape((dim_a, 1)) + - beta.reshape((1, dim_b))) / reg) + + def get_Gamma(alpha, beta, u, v): + """log space gamma computation""" + return np.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) + / reg + np.log(u.reshape((dim_a, 1))) + np.log(v.reshape((1, dim_b)))) + + # print(np.min(K)) + + K = get_K(alpha, beta) + transp = K + loop = 1 + cpt = 0 + err = 1 + while loop: + + uprev = u + vprev = v + + # sinkhorn update + v = b / (np.dot(K.T, u) + 1e-16) + u = a / (np.dot(K, v) + 1e-16) + + # remove numerical problems and store them in K + if np.abs(u).max() > tau or np.abs(v).max() > tau: + if n_hists: + alpha, beta = alpha + reg * \ + np.max(np.log(u), 1), beta + reg * np.max(np.log(v)) + else: + alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v) + if n_hists: + u, v = np.ones((dim_a, n_hists)) / dim_a, np.ones((dim_b, n_hists)) / dim_b + else: + u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b + K = get_K(alpha, beta) + + if cpt % print_period == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + if n_hists: + err_u = abs(u - uprev).max() + err_u /= max(abs(u).max(), abs(uprev).max(), 1.) + err_v = abs(v - vprev).max() + err_v /= max(abs(v).max(), abs(vprev).max(), 1.) 
+ err = 0.5 * (err_u + err_v) + else: + transp = get_Gamma(alpha, beta, u, v) + err = np.linalg.norm((np.sum(transp, axis=0) - b)) + if log: + log['err'].append(err) + + if verbose: + if cpt % (print_period * 20) == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + if err <= stopThr: + loop = False + + if cpt >= numItermax: + loop = False + + if np.any(np.isnan(u)) or np.any(np.isnan(v)): + # we have reached the machine precision + # come back to previous solution and quit loop + print('Warning: numerical errors at iteration', cpt) + u = uprev + v = vprev + break + + cpt = cpt + 1 + + if log: + if n_hists: + alpha = alpha[:, None] + beta = beta[:, None] + logu = alpha / reg + np.log(u) + logv = beta / reg + np.log(v) + log['logu'] = logu + log['logv'] = logv + log['alpha'] = alpha + reg * np.log(u) + log['beta'] = beta + reg * np.log(v) + log['warmstart'] = (log['alpha'], log['beta']) + if n_hists: + res = np.zeros((n_hists)) + for i in range(n_hists): + res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) + return res, log + + else: + return get_Gamma(alpha, beta, u, v), log + else: + if n_hists: + res = np.zeros((n_hists)) + for i in range(n_hists): + res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) + return res + else: + return get_Gamma(alpha, beta, u, v) + + +def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, + numInnerItermax=100, tau=1e3, stopThr=1e-9, + warmstart=None, verbose=False, print_period=10, + log=False, **kwargs): + r""" + Solve the entropic regularization optimal transport problem with log + stabilization and epsilon scaling. + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. 
\gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the (dim_a, dim_b) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights (histograms, both sum to 1) + + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in [2]_ but with the log stabilization + proposed in [10]_ and the log scaling proposed in [9]_ algorithm 3.2 + + + Parameters + ---------- + a : ndarray, shape (dim_a,) + samples weights in the source domain + b : ndarray, shape (dim_b,) + samples in the target domain + M : ndarray, shape (dim_a, dim_b) + loss matrix + reg : float + Regularization term >0 + tau : float + threshold for max value in u or v for log scaling + warmstart : tuple of vectors + if given then starting values for alpha and beta log scalings + numItermax : int, optional + Max number of iterations + numInnerItermax : int, optional + Max number of iterations in the inner log-stabilized sinkhorn + epsilon0 : float, optional + first epsilon regularization value (then exponential decrease to reg) + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + gamma : ndarray, shape (dim_a, dim_b) + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.], [1., 0.]] + >>> ot.bregman.sinkhorn_epsilon_scaling(a, b, M, 1) + array([[0.36552929, 0.13447071], + [0.13447071, 0.36552929]]) + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. 
+ + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64) + + if len(a) == 0: + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + if len(b) == 0: + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + + # init data + dim_a = len(a) + dim_b = len(b) + + # relative numerical precision with 64 bits + numItermin = 35 + numItermax = max(numItermin, numItermax) # ensure that the last value is exact + + cpt = 0 + if log: + log = {'err': []} + + # we assume that no distances are null except those of the diagonal of + # distances + if warmstart is None: + alpha, beta = np.zeros(dim_a), np.zeros(dim_b) + else: + alpha, beta = warmstart + + def get_K(alpha, beta): + """log space computation""" + return np.exp(-(M - alpha.reshape((dim_a, 1)) + - beta.reshape((1, dim_b))) / reg) + + # print(np.min(K)) + def get_reg(n): # exponential decreasing + return (epsilon0 - reg) * np.exp(-n) + reg + + loop = 1 + cpt = 0 + err = 1 + while loop: + + regi = get_reg(cpt) + + G, logi = sinkhorn_stabilized(a, b, M, regi, + numItermax=numInnerItermax, stopThr=1e-9, + warmstart=(alpha, beta), verbose=False, + print_period=20, tau=tau, log=True) + + alpha = logi['alpha'] + beta = logi['beta'] + + if cpt >= numItermax: + loop = False + + if cpt % (print_period) == 0: # epsilon nearly converged + # we can speed up the process by checking for the error only all + # the 10th iterations + transp = G + err = np.linalg.norm( + (np.sum(transp, axis=0) - b))**2 + np.linalg.norm((np.sum(transp, axis=1) - a))**2 + if log: + log['err'].append(err) + + if verbose: + if cpt % (print_period * 10) == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + if err <= stopThr and cpt > numItermin: + loop = False + + cpt = cpt + 1 + # print('err=',err,' cpt=',cpt) + if log: + log['alpha'] = alpha + log['beta'] = beta + log['warmstart'] = (log['alpha'], log['beta']) + return G, log + else: + return G + + +def geometricBar(weights, alldistribT): + """return the weighted geometric mean of distributions""" + assert(len(weights) == alldistribT.shape[1]) + return np.exp(np.dot(np.log(alldistribT), weights.T)) + + +def geometricMean(alldistribT): + """return the geometric mean of distributions""" + return np.exp(np.mean(np.log(alldistribT), axis=1)) + + +def projR(gamma, p): + """return the KL projection on the row constraints """ + return np.multiply(gamma.T, p / np.maximum(np.sum(gamma, axis=1), 1e-10)).T + + +def projC(gamma, q): + """return the KL projection on the column constraints """ + return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10)) + + +def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, + stopThr=1e-4, verbose=False, log=False, **kwargs): + r"""Compute the entropic regularized Wasserstein barycenter of distributions A + + The function solves the following optimization problem: + + .. 
math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` + - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_ + + Parameters + ---------- + A : ndarray, shape (dim, n_hists) + n_hists training distributions a_i of size dim + M : ndarray, shape (dim, dim) + loss matrix for OT + reg : float + Regularization term > 0 + method : str (optional) + method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' + weights : ndarray, shape (n_hists,) + Weights of each histogram a_i on the simplex (barycentric coodinates) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + a : (dim,) ndarray + Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + + """ + + if method.lower() == 'sinkhorn': + return barycenter_sinkhorn(A, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_stabilized': + return barycenter_stabilized(A, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + + +def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False): + r"""Compute the entropic regularized wasserstein barycenter of distributions A + + The function solves the following optimization problem: + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` + - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_ + + Parameters + ---------- + A : ndarray, shape (dim, n_hists) + n_hists training distributions a_i of size dim + M : ndarray, shape (dim, dim) + loss matrix for OT + reg : float + Regularization term > 0 + weights : ndarray, shape (n_hists,) + Weights of each histogram a_i on the simplex (barycentric coodinates) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + a : (dim,) ndarray + Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). 
Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + + """ + + if weights is None: + weights = np.ones(A.shape[1]) / A.shape[1] + else: + assert(len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + # M = M/np.median(M) # suggested by G. Peyre + K = np.exp(-M / reg) + + cpt = 0 + err = 1 + + UKv = np.dot(K, np.divide(A.T, np.sum(K, axis=0)).T) + u = (geometricMean(UKv) / UKv.T).T + + while (err > stopThr and cpt < numItermax): + cpt = cpt + 1 + UKv = u * np.dot(K, np.divide(A, np.dot(K, u))) + u = (u.T * geometricBar(weights, UKv)).T / UKv + + if cpt % 10 == 1: + err = np.sum(np.std(UKv, axis=1)) + + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + if log: + log['niter'] = cpt + return geometricBar(weights, UKv), log + else: + return geometricBar(weights, UKv) + + +def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False): + r"""Compute the entropic regularized wasserstein barycenter of distributions A + with stabilization. + + The function solves the following optimization problem: + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` + - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_ + + Parameters + ---------- + A : ndarray, shape (dim, n_hists) + n_hists training distributions a_i of size dim + M : ndarray, shape (dim, dim) + loss matrix for OT + reg : float + Regularization term > 0 + tau : float + thershold for max value in u or v for log scaling + weights : ndarray, shape (n_hists,) + Weights of each histogram a_i on the simplex (barycentric coodinates) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + a : (dim,) ndarray + Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + + """ + + dim, n_hists = A.shape + if weights is None: + weights = np.ones(n_hists) / n_hists + else: + assert(len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + u = np.ones((dim, n_hists)) / dim + v = np.ones((dim, n_hists)) / dim + + # print(reg) + # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + + cpt = 0 + err = 1. 
+ alpha = np.zeros(dim) + beta = np.zeros(dim) + q = np.ones(dim) / dim + while (err > stopThr and cpt < numItermax): + qprev = q + Kv = K.dot(v) + u = A / (Kv + 1e-16) + Ktu = K.T.dot(u) + q = geometricBar(weights, Ktu) + Q = q[:, None] + v = Q / (Ktu + 1e-16) + absorbing = False + if (u > tau).any() or (v > tau).any(): + absorbing = True + alpha = alpha + reg * np.log(np.max(u, 1)) + beta = beta + reg * np.log(np.max(v, 1)) + K = np.exp((alpha[:, None] + beta[None, :] - + M) / reg) + v = np.ones_like(v) + Kv = K.dot(v) + if (np.any(Ktu == 0.) + or np.any(np.isnan(u)) or np.any(np.isnan(v)) + or np.any(np.isinf(u)) or np.any(np.isinf(v))): + # we have reached the machine precision + # come back to previous solution and quit loop + warnings.warn('Numerical errors at iteration %s' % cpt) + q = qprev + break + if (cpt % 10 == 0 and not absorbing) or cpt == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = abs(u * Kv - A).max() + if log: + log['err'].append(err) + if verbose: + if cpt % 50 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + cpt += 1 + if err > stopThr: + warnings.warn("Stabilized Sinkhorn did not converge. " + + "Try a larger entropy `reg` " + + "or a larger absorption threshold `tau`.") + if log: + log['niter'] = cpt + log['logu'] = np.log(u + 1e-16) + log['logv'] = np.log(v + 1e-16) + return q, log + else: + return q + + +def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, + stopThr=1e-9, stabThr=1e-30, verbose=False, + log=False): + r"""Compute the entropic regularized Wasserstein barycenter of distributions A + where A is a collection of 2D images. + + The function solves the following optimization problem: + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) + - :math:`\mathbf{a}_i` are training distributions (2D images) in the last two dimensions of matrix :math:`\mathbf{A}` + - reg is the regularization strength scalar value + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [21]_ + + Parameters + ---------- + A : ndarray, shape (n_hists, width, height) + n distributions (2D images) of size width x height + reg : float + Regularization term >0 + weights : ndarray, shape (n_hists,) + Weights of each image on the simplex (barycentric coordinates) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (> 0) + stabThr : float, optional + Stabilization threshold to avoid numerical precision issues + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + a : ndarray, shape (width, height) + 2D Wasserstein barycenter + log : dict + log dictionary returned only if log==True in parameters + + References + ---------- + + .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). 
+ Convolutional Wasserstein distances: Efficient optimal transportation on geometric domains + ACM Transactions on Graphics (TOG), 34(4), 66 + + + """ + + if weights is None: + weights = np.ones(A.shape[0]) / A.shape[0] + else: + assert(len(weights) == A.shape[0]) + + if log: + log = {'err': []} + + b = np.zeros_like(A[0, :, :]) + U = np.ones_like(A) + KV = np.ones_like(A) + + cpt = 0 + err = 1 + + # build the convolution operator + t = np.linspace(0, 1, A.shape[1]) + [Y, X] = np.meshgrid(t, t) + xi1 = np.exp(-(X - Y)**2 / reg) + + def K(x): + return np.dot(np.dot(xi1, x), xi1) + + while (err > stopThr and cpt < numItermax): + + bold = b + cpt = cpt + 1 + + b = np.zeros_like(A[0, :, :]) + for r in range(A.shape[0]): + KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :]))) + b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :])) + b = np.exp(b) + for r in range(A.shape[0]): + U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :]) + + if cpt % 10 == 1: + err = np.sum(np.abs(bold - b)) + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + if log: + log['niter'] = cpt + log['U'] = U + return b, log + else: + return b + + +def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, + stopThr=1e-3, verbose=False, log=False): + r""" + Compute the unmixing of an observation with a given dictionary using Wasserstein distance + + The function solves the following optimization problem: + + .. math:: + \mathbf{h} = arg\min_\mathbf{h} (1- \alpha) W_{M,reg}(\mathbf{a},\mathbf{Dh})+\alpha W_{M0,reg0}(\mathbf{h}_0,\mathbf{h}) + + + where : + + - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see ot.bregman.sinkhorn) + - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)` + - :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms` + - :math:`\mathbf{a}` is an observed distribution of dimension `dim_a` + - :math:`\mathbf{h}_0` is a prior on `h` of dimension `dim_prior` + - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix (dim_a, dim_a) for OT data fitting + - reg0 and :math:`\mathbf{M0}` are respectively the regularization term and the cost matrix (dim_prior, n_atoms) regularization + - :math:`\alpha` weights data fitting and regularization + + The optimization problem is solved using the algorithm described in [4]_ + + + Parameters + ---------- + a : ndarray, shape (dim_a) + observed distribution (histogram, sums to 1) + D : ndarray, shape (dim_a, n_atoms) + dictionary matrix + M : ndarray, shape (dim_a, dim_a) + loss matrix + M0 : ndarray, shape (n_atoms, dim_prior) + loss matrix + h0 : ndarray, shape (n_atoms,) + prior on the estimated unmixing h + reg : float + Regularization term >0 (Wasserstein data fitting) + reg0 : float + Regularization term >0 (Wasserstein reg with h0) + alpha : float + How much should we trust the prior ([0,1]) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + h : ndarray, shape (n_atoms,) + estimated unmixing of the observation + log : dict + log dictionary returned only if log==True in parameters + + References + ---------- + + .. [4] S. Nakhostin, N. Courty, R. 
Flamary, D. Tuia, T. Corpetti, Supervised planetary unmixing with optimal transport, Workshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016. + + """ + + # M = M/np.median(M) + K = np.exp(-M / reg) + + # M0 = M0/np.median(M0) + K0 = np.exp(-M0 / reg0) + old = h0 + + err = 1 + cpt = 0 + # log = {'niter':0, 'all_err':[]} + if log: + log = {'err': []} + + while (err > stopThr and cpt < numItermax): + K = projC(K, a) + K0 = projC(K0, h0) + new = np.sum(K0, axis=1) + # we recombine the current selection from the dictionary + inv_new = np.dot(D, new) + other = np.sum(K, axis=1) + # geometric interpolation + delta = np.exp(alpha * np.log(other) + (1 - alpha) * np.log(inv_new)) + K = projR(K, delta) + K0 = np.dot(np.diag(np.dot(D.T, delta / inv_new)), K0) + + err = np.linalg.norm(np.sum(K0, axis=1) - old) + old = new + if log: + log['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + cpt = cpt + 1 + + if log: + log['niter'] = cpt + return np.sum(K0, axis=1), log + else: + return np.sum(K0, axis=1) + + +def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', + numIterMax=10000, stopThr=1e-9, verbose=False, + log=False, **kwargs): + r''' + Solve the entropic regularization optimal transport problem and return the + OT matrix from empirical data + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`a` and :math:`b` are source and target weights (sum to 1) + + + Parameters + ---------- + X_s : ndarray, shape (n_samples_a, dim) + samples in the source domain + X_t : ndarray, shape (n_samples_b, dim) + samples in the target domain + reg : float + Regularization term >0 + a : ndarray, shape (n_samples_a,) + samples weights in the source domain + b : ndarray, shape (n_samples_b,) + samples weights in the target domain + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : ndarray, shape (n_samples_a, n_samples_b) + Regularized optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if log==True in parameters + + Examples + -------- + + >>> n_samples_a = 2 + >>> n_samples_b = 2 + >>> reg = 0.1 + >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) + >>> empirical_sinkhorn(X_s, X_t, reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE + array([[4.99977301e-01, 2.26989344e-05], + [2.26989344e-05, 4.99977301e-01]]) + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. 
arXiv preprint arXiv:1607.05816. + ''' + + if a is None: + a = unif(np.shape(X_s)[0]) + if b is None: + b = unif(np.shape(X_t)[0]) + + M = dist(X_s, X_t, metric=metric) + + if log: + pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) + return pi, log + else: + pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs) + return pi + + +def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs): + r''' + Solve the entropic regularization optimal transport problem from empirical + data and return the OT loss + + + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`a` and :math:`b` are source and target weights (sum to 1) + + + Parameters + ---------- + X_s : ndarray, shape (n_samples_a, dim) + samples in the source domain + X_t : ndarray, shape (n_samples_b, dim) + samples in the target domain + reg : float + Regularization term >0 + a : ndarray, shape (n_samples_a,) + samples weights in the source domain + b : ndarray, shape (n_samples_b,) + samples weights in the target domain + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : ndarray, shape (n_samples_a, n_samples_b) + Regularized optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> n_samples_a = 2 + >>> n_samples_b = 2 + >>> reg = 0.1 + >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) + >>> empirical_sinkhorn2(X_s, X_t, reg, verbose=False) + array([4.53978687e-05]) + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + ''' + + if a is None: + a = unif(np.shape(X_s)[0]) + if b is None: + b = unif(np.shape(X_t)[0]) + + M = dist(X_s, X_t, metric=metric) + + if log: + sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + return sinkhorn_loss, log + else: + sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + return sinkhorn_loss + + +def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs): + r''' + Compute the sinkhorn divergence loss from empirical data + + The function solves the following optimization problems and return the + sinkhorn divergence :math:`S`: + + .. 
math:: + + W &= \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + W_a &= \min_{\gamma_a} <\gamma_a,M_a>_F + reg\cdot\Omega(\gamma_a) + + W_b &= \min_{\gamma_b} <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b) + + S &= W - 1/2 * (W_a + W_b) + + .. math:: + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + + \gamma_a 1 = a + + \gamma_a^T 1= a + + \gamma_a\geq 0 + + \gamma_b 1 = b + + \gamma_b^T 1= b + + \gamma_b\geq 0 + where : + + - :math:`M` (resp. :math:`M_a, M_b`) is the (n_samples_a, n_samples_b) metric cost matrix (resp (n_samples_a, n_samples_a) and (n_samples_b, n_samples_b)) + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`a` and :math:`b` are source and target weights (sum to 1) + + + Parameters + ---------- + X_s : ndarray, shape (n_samples_a, dim) + samples in the source domain + X_t : ndarray, shape (n_samples_b, dim) + samples in the target domain + reg : float + Regularization term >0 + a : ndarray, shape (n_samples_a,) + samples weights in the source domain + b : ndarray, shape (n_samples_b,) + samples weights in the target domain + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + gamma : ndarray, shape (n_samples_a, n_samples_b) + Regularized optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + >>> n_samples_a = 2 + >>> n_samples_b = 4 + >>> reg = 0.1 + >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) + >>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS + array([1.499...]) + + + References + ---------- + .. 
[23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS) 21, 2018 + ''' + if log: + sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b) + + log = {} + log['sinkhorn_loss_ab'] = sinkhorn_loss_ab + log['sinkhorn_loss_a'] = sinkhorn_loss_a + log['sinkhorn_loss_b'] = sinkhorn_loss_b + log['log_sinkhorn_ab'] = log_ab + log['log_sinkhorn_a'] = log_a + log['log_sinkhorn_b'] = log_b + + return max(0, sinkhorn_div), log + + else: + sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b) + return max(0, sinkhorn_div)
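As a quick illustration of the divergence routine defined just above, a minimal sketch mirroring the docstring example (toy 1D samples; assumes POT 0.6.0 is installed):

import numpy as np
from ot.bregman import empirical_sinkhorn_divergence

# same toy setup as the doctest above: two vs. four 1D samples
X_s = np.reshape(np.arange(2.0), (2, 1))
X_t = np.reshape(np.arange(4.0), (4, 1))

div = empirical_sinkhorn_divergence(X_s, X_t, reg=0.1)
print(div)  # ~1.499; clipped to be >= 0 by the max(0, .) above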
+ - a and b are source and target weights (sum to 1) + + The algorithm used for solving the problem is the generalized conditional + gradient as proposed in [5]_ [7]_ + + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + labels_a : np.ndarray (ns,) + labels of samples in the source domain + b : np.ndarray (nt,) + samples weights in the target domain + M : np.ndarray (ns,nt) + loss matrix + reg : float + Regularization term for entropic regularization >0 + eta : float, optional + Regularization term for group lasso regularization >0 + numItermax : int, optional + Max number of iterations + numInnerItermax : int, optional + Max number of iterations (inner sinkhorn solver) + stopInnerThr : float, optional + Stop threshold on error (inner sinkhorn solver) (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, + "Optimal Transport for Domain Adaptation," in IEEE + Transactions on Pattern Analysis and Machine Intelligence , + vol.PP, no.99, pp.1-1 + .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). + Generalized conditional gradient: analysis of convergence + and applications. arXiv preprint arXiv:1510.06567. + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.bregman.sinkhorn : Entropic regularized OT + ot.optim.cg : General regularized OT + + """ + p = 0.5 + epsilon = 1e-3 + + indices_labels = [] + classes = np.unique(labels_a) + for c in classes: + idxc, = np.where(labels_a == c) + indices_labels.append(idxc) + + W = np.zeros(M.shape) + + for cpt in range(numItermax): + Mreg = M + eta * W + transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, + stopThr=stopInnerThr) + # the transport has been computed. Check if classes are really + # separated + W = np.ones(M.shape) + for (i, c) in enumerate(classes): + majs = np.sum(transp[indices_labels[i]], axis=0) + majs = p * ((majs + epsilon)**(p - 1)) + W[indices_labels[i]] = majs + + return transp + + +def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, + numInnerItermax=200, stopInnerThr=1e-9, verbose=False, + log=False): + """ + Solve the entropic regularization optimal transport problem with group + lasso regularization + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+ + \eta \Omega_g(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the (ns,nt) metric cost matrix + - :math:`\Omega_e` is the entropic regularization term + :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega_g` is the group lasso regulaization term + :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^2` + where :math:`\mathcal{I}_c` are the index of samples from class + c in the source domain. 
+ - a and b are source and target weights (sum to 1) + + The algorithm used for solving the problem is the generalized conditional + gradient as proposed in [5]_ [7]_ + + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + labels_a : np.ndarray (ns,) + labels of samples in the source domain + b : np.ndarray (nt,) + samples weights in the target domain + M : np.ndarray (ns,nt) + loss matrix + reg : float + Regularization term for entropic regularization >0 + eta : float, optional + Regularization term for group lasso regularization >0 + numItermax : int, optional + Max number of iterations + numInnerItermax : int, optional + Max number of iterations (inner sinkhorn solver) + stopInnerThr : float, optional + Stop threshold on error (inner sinkhorn solver) (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if log==True in parameters + + + References + ---------- + + .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, + "Optimal Transport for Domain Adaptation," in IEEE Transactions + on Pattern Analysis and Machine Intelligence, vol.PP, no.99, pp.1-1 + .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). + Generalized conditional gradient: analysis of convergence and + applications. arXiv preprint arXiv:1510.06567. + + See Also + -------- + ot.optim.gcg : Generalized conditional gradient for OT problems + + """ + lstlab = np.unique(labels_a) + + def f(G): + res = 0 + for i in range(G.shape[1]): + for lab in lstlab: + temp = G[labels_a == lab, i] + res += np.linalg.norm(temp) + return res + + def df(G): + W = np.zeros(G.shape) + for i in range(G.shape[1]): + for lab in lstlab: + temp = G[labels_a == lab, i] + n = np.linalg.norm(temp) + if n: + W[labels_a == lab, i] = temp / n + return W + + return gcg(a, b, M, reg, eta, f, df, G0=None, numItermax=numItermax, + numInnerItermax=numInnerItermax, stopThr=stopInnerThr, + verbose=verbose, log=log) + + +def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, + verbose2=False, numItermax=100, numInnerItermax=10, + stopInnerThr=1e-6, stopThr=1e-5, log=False, + **kwargs): + """Joint OT and linear mapping estimation as proposed in [8] + + The function solves the following optimization problem: + + .. math:: + \min_{\gamma,L}\quad \|L(X_s) - n_s\gamma X_t\|^2_F + + \mu<\gamma,M>_F + \eta \|L - I\|^2_F + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the (ns,nt) squared euclidean cost matrix between samples in + Xs and Xt (scaled by ns) + - :math:`L` is a dxd linear operator that approximates the barycentric + mapping + - :math:`I` is the identity matrix (neutral linear mapping) + - a and b are uniform source and target weights + + The problem consists in solving jointly an optimal transport matrix + :math:`\gamma` and a linear mapping that fits the barycentric mapping + :math:`n_s\gamma X_t`. + + One can also estimate a mapping with constant bias (see supplementary + material of [8]) using the bias optional argument. + + The algorithm used for solving the problem is the block coordinate + descent that alternates between updates of G (using conditional gradient) + and the update of L using a classical least squares solver.
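+ + A minimal usage sketch (illustrative only; the shapes and parameter values below are arbitrary assumptions, and outputs are omitted): + + >>> import numpy as np  # doctest: +SKIP + >>> import ot  # doctest: +SKIP + >>> xs = np.random.randn(20, 2)  # doctest: +SKIP + >>> xt = np.random.randn(20, 2) + 2  # doctest: +SKIP + >>> G, L = ot.da.joint_OT_mapping_linear(xs, xt, mu=1, eta=1e-3)  # doctest: +SKIP + >>> xs_mapped = xs.dot(L)  # doctest: +SKIP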
+ + + Parameters + ---------- + xs : np.ndarray (ns,d) + samples in the source domain + xt : np.ndarray (nt,d) + samples in the target domain + mu : float,optional + Weight for the linear OT loss (>0) + eta : float, optional + Regularization term for the linear mapping L (>0) + bias : bool,optional + Estimate linear mapping with constant bias + numItermax : int, optional + Max number of BCD iterations + stopThr : float, optional + Stop threshold on relative loss decrease (>0) + numInnerItermax : int, optional + Max number of iterations (inner CG solver) + stopInnerThr : float, optional + Stop threshold on error (inner CG solver) (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + L : (d x d) ndarray + Linear mapping matrix (d+1 x d if bias) + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, + "Mapping estimation for discrete optimal transport", + Neural Information Processing Systems (NIPS), 2016. + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + ns, nt, d = xs.shape[0], xt.shape[0], xt.shape[1] + + if bias: + xs1 = np.hstack((xs, np.ones((ns, 1)))) + xstxs = xs1.T.dot(xs1) + Id = np.eye(d + 1) + Id[-1] = 0 + I0 = Id[:, :-1] + + def sel(x): + return x[:-1, :] + else: + xs1 = xs + xstxs = xs1.T.dot(xs1) + Id = np.eye(d) + I0 = Id + + def sel(x): + return x + + if log: + log = {'err': []} + + a, b = unif(ns), unif(nt) + M = dist(xs, xt) * ns + G = emd(a, b, M) + + vloss = [] + + def loss(L, G): + """Compute full loss""" + return np.sum((xs1.dot(L) - ns * G.dot(xt))**2) + mu * \ + np.sum(G * M) + eta * np.sum(sel(L - I0)**2) + + def solve_L(G): + """ solve L problem with fixed G (least square)""" + xst = ns * G.dot(xt) + return np.linalg.solve(xstxs + eta * Id, xs1.T.dot(xst) + eta * I0) + + def solve_G(L, G0): + """Update G with CG algorithm""" + xsi = xs1.dot(L) + + def f(G): + return np.sum((xsi - ns * G.dot(xt))**2) + + def df(G): + return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T) + G = cg(a, b, M, 1.0 / mu, f, df, G0=G0, + numItermax=numInnerItermax, stopThr=stopInnerThr) + return G + + L = solve_L(G) + + vloss.append(loss(L, G)) + + if verbose: + print('{:5s}|{:12s}|{:8s}'.format( + 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) + print('{:5d}|{:8e}|{:8e}'.format(0, vloss[-1], 0)) + + # init loop + if numItermax > 0: + loop = 1 + else: + loop = 0 + it = 0 + + while loop: + + it += 1 + + # update G + G = solve_G(L, G) + + # update L + L = solve_L(G) + + vloss.append(loss(L, G)) + + if it >= numItermax: + loop = 0 + + if abs(vloss[-1] - vloss[-2]) / abs(vloss[-2]) < stopThr: + loop = 0 + + if verbose: + if it % 20 == 0: + print('{:5s}|{:12s}|{:8s}'.format( + 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) + print('{:5d}|{:8e}|{:8e}'.format( + it, vloss[-1], (vloss[-1] - vloss[-2]) / abs(vloss[-2]))) + if log: + log['loss'] = vloss + return G, L, log + else: + return G, L + + +def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', + sigma=1, bias=False, verbose=False, verbose2=False, + numItermax=100, numInnerItermax=10, + stopInnerThr=1e-6, stopThr=1e-5, log=False, + **kwargs): + """Joint OT and nonlinear mapping estimation with kernels as proposed in [8] + + The function solves the following optimization problem: + + .. 
math:: + \min_{\gamma,L\in\mathcal{H}}\quad \|L(X_s) - + n_s\gamma X_t\|^2_F + \mu<\gamma,M>_F + \eta \|L\|^2_\mathcal{H} + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the (ns,nt) squared euclidean cost matrix between samples in + Xs and Xt (scaled by ns) + - :math:`L` is a ns x d linear operator on a kernel matrix that + approximates the barycentric mapping + - a and b are uniform source and target weights + + The problem consist in solving jointly an optimal transport matrix + :math:`\gamma` and the nonlinear mapping that fits the barycentric mapping + :math:`n_s\gamma X_t`. + + One can also estimate a mapping with constant bias (see supplementary + material of [8]) using the bias optional argument. + + The algorithm used for solving the problem is the block coordinate + descent that alternates between updates of G (using conditionnal gradient) + and the update of L using a classical kernel least square solver. + + + Parameters + ---------- + xs : np.ndarray (ns,d) + samples in the source domain + xt : np.ndarray (nt,d) + samples in the target domain + mu : float,optional + Weight for the linear OT loss (>0) + eta : float, optional + Regularization term for the linear mapping L (>0) + kerneltype : str,optional + kernel used by calling function ot.utils.kernel (gaussian by default) + sigma : float, optional + Gaussian kernel bandwidth. + bias : bool,optional + Estimate linear mapping with constant bias + verbose : bool, optional + Print information along iterations + verbose2 : bool, optional + Print information along iterations + numItermax : int, optional + Max number of BCD iterations + numInnerItermax : int, optional + Max number of iterations (inner CG solver) + stopInnerThr : float, optional + Stop threshold on error (inner CG solver) (>0) + stopThr : float, optional + Stop threshold on relative loss decrease (>0) + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + L : (ns x d) ndarray + Nonlinear mapping matrix (ns+1 x d if bias) + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, + "Mapping estimation for discrete optimal transport", + Neural Information Processing Systems (NIPS), 2016. 
+ + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + ns, nt = xs.shape[0], xt.shape[0] + + K = kernel(xs, xs, method=kerneltype, sigma=sigma) + if bias: + K1 = np.hstack((K, np.ones((ns, 1)))) + Id = np.eye(ns + 1) + Id[-1] = 0 + Kp = np.eye(ns + 1) + Kp[:ns, :ns] = K + + # ls regu + # K0 = K1.T.dot(K1)+eta*I + # Kreg=I + + # RKHS regul + K0 = K1.T.dot(K1) + eta * Kp + Kreg = Kp + + else: + K1 = K + Id = np.eye(ns) + + # ls regul + # K0 = K1.T.dot(K1)+eta*I + # Kreg=I + + # proper kernel ridge + K0 = K + eta * Id + Kreg = K + + if log: + log = {'err': []} + + a, b = unif(ns), unif(nt) + M = dist(xs, xt) * ns + G = emd(a, b, M) + + vloss = [] + + def loss(L, G): + """Compute full loss""" + return np.sum((K1.dot(L) - ns * G.dot(xt))**2) + mu * \ + np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L)) + + def solve_L_nobias(G): + """ solve L problem with fixed G (least square)""" + xst = ns * G.dot(xt) + return np.linalg.solve(K0, xst) + + def solve_L_bias(G): + """ solve L problem with fixed G (least square)""" + xst = ns * G.dot(xt) + return np.linalg.solve(K0, K1.T.dot(xst)) + + def solve_G(L, G0): + """Update G with CG algorithm""" + xsi = K1.dot(L) + + def f(G): + return np.sum((xsi - ns * G.dot(xt))**2) + + def df(G): + return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T) + G = cg(a, b, M, 1.0 / mu, f, df, G0=G0, + numItermax=numInnerItermax, stopThr=stopInnerThr) + return G + + if bias: + solve_L = solve_L_bias + else: + solve_L = solve_L_nobias + + L = solve_L(G) + + vloss.append(loss(L, G)) + + if verbose: + print('{:5s}|{:12s}|{:8s}'.format( + 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) + print('{:5d}|{:8e}|{:8e}'.format(0, vloss[-1], 0)) + + # init loop + if numItermax > 0: + loop = 1 + else: + loop = 0 + it = 0 + + while loop: + + it += 1 + + # update G + G = solve_G(L, G) + + # update L + L = solve_L(G) + + vloss.append(loss(L, G)) + + if it >= numItermax: + loop = 0 + + if abs(vloss[-1] - vloss[-2]) / abs(vloss[-2]) < stopThr: + loop = 0 + + if verbose: + if it % 20 == 0: + print('{:5s}|{:12s}|{:8s}'.format( + 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) + print('{:5d}|{:8e}|{:8e}'.format( + it, vloss[-1], (vloss[-1] - vloss[-2]) / abs(vloss[-2]))) + if log: + log['loss'] = vloss + return G, L, log + else: + return G, L + + +def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, + wt=None, bias=True, log=False): + """ return OT linear operator between samples + + The function estimates the optimal linear operator that aligns the two + empirical distributions. This is equivalent to estimating the closed + form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)` + and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in remark + 2.29 in [15]. + + The linear operator from source to target :math:`M` + + .. math:: + M(x)=Ax+b + + where : + + .. math:: + A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2} + \Sigma_s^{-1/2} + .. 
math:: + b=\mu_t-A\mu_s + + Parameters + ---------- + xs : np.ndarray (ns,d) + samples in the source domain + xt : np.ndarray (nt,d) + samples in the target domain + reg : float,optional + regularization added to the diagonals of covariances (>0) + ws : np.ndarray (ns,1), optional + weights for the source samples + wt : np.ndarray (nt,1), optional + weights for the target samples + bias: boolean, optional + estimate bias b else b=0 (default:True) + log : bool, optional + record log if True + + + Returns + ------- + A : (d x d) ndarray + Linear operator + b : (1 x d) ndarray + bias + log : dict + log dictionary returned only if log==True in parameters + + + References + ---------- + + .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of + distributions", Journal of Optimization Theory and Applications + Vol 43, 1984 + + .. [15] Peyré, G., & Cuturi, M. (2018). "Computational Optimal + Transport". + + + """ + + d = xs.shape[1] + + if bias: + mxs = xs.mean(0, keepdims=True) + mxt = xt.mean(0, keepdims=True) + + xs = xs - mxs + xt = xt - mxt + else: + mxs = np.zeros((1, d)) + mxt = np.zeros((1, d)) + + if ws is None: + ws = np.ones((xs.shape[0], 1)) / xs.shape[0] + + if wt is None: + wt = np.ones((xt.shape[0], 1)) / xt.shape[0] + + Cs = (xs * ws).T.dot(xs) / ws.sum() + reg * np.eye(d) + Ct = (xt * wt).T.dot(xt) / wt.sum() + reg * np.eye(d) + + Cs12 = linalg.sqrtm(Cs) + Cs_12 = linalg.inv(Cs12) + + M0 = linalg.sqrtm(Cs12.dot(Ct.dot(Cs12))) + + A = Cs_12.dot(M0.dot(Cs_12)) + + b = mxt - mxs.dot(A) + + if log: + log = {} + log['Cs'] = Cs + log['Ct'] = Ct + log['Cs12'] = Cs12 + log['Cs_12'] = Cs_12 + return A, b, log + else: + return A, b + + +def distribution_estimation_uniform(X): + """estimates a uniform distribution from an array of samples X + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + The array of samples + + Returns + ------- + mu : array-like, shape (n_samples,) + The uniform distribution estimated from X + """ + + return unif(X.shape[0]) + + +class BaseTransport(BaseEstimator): + + """Base class for OTDA objects + + Notes + ----- + All estimators should specify all the parameters that can be set + at the class level in their ``__init__`` as explicit keyword + arguments (no ``*args`` or ``**kwargs``). + + fit method should: + - estimate a cost matrix and store it in a `cost_` attribute + - estimate a coupling matrix and store it in a `coupling_` + attribute + - estimate distributions from source and target data and store them in + mu_s and mu_t attributes + - store Xs and Xt in attributes to be used later on in transform and + inverse_transform methods + + transform method should always get an Xs parameter as input + inverse_transform method should always get an Xt parameter as input + """ + + def fit(self, Xs=None, ys=None, Xt=None, yt=None): + """Build a coupling matrix from source and target sets of samples + (Xs, ys) and (Xt, yt) + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + + Returns + ------- + self : object + Returns self.
+ """ + + # check the necessary inputs parameters are here + if check_params(Xs=Xs, Xt=Xt): + + # pairwise distance + self.cost_ = dist(Xs, Xt, metric=self.metric) + self.cost_ = cost_normalization(self.cost_, self.norm) + + if (ys is not None) and (yt is not None): + + if self.limit_max != np.infty: + self.limit_max = self.limit_max * np.max(self.cost_) + + # assumes labeled source samples occupy the first rows + # and labeled target samples occupy the first columns + classes = [c for c in np.unique(ys) if c != -1] + for c in classes: + idx_s = np.where((ys != c) & (ys != -1)) + idx_t = np.where(yt == c) + + # all the coefficients corresponding to a source sample + # and a target sample : + # with different labels get a infinite + for j in idx_t[0]: + self.cost_[idx_s[0], j] = self.limit_max + + # distribution estimation + self.mu_s = self.distribution_estimation(Xs) + self.mu_t = self.distribution_estimation(Xt) + + # store arrays of samples + self.xs_ = Xs + self.xt_ = Xt + + return self + + def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None): + """Build a coupling matrix from source and target sets of samples + (Xs, ys) and (Xt, yt) and transports source samples Xs onto target + ones Xt + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + + Returns + ------- + transp_Xs : array-like, shape (n_source_samples, n_features) + The source samples samples. + """ + + return self.fit(Xs, ys, Xt, yt).transform(Xs, ys, Xt, yt) + + def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): + """Transports source samples Xs onto target ones Xt + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + batch_size : int, optional (default=128) + The batch size for out of sample inverse transform + + Returns + ------- + transp_Xs : array-like, shape (n_source_samples, n_features) + The transport source samples. 
+ """ + + # check the necessary inputs parameters are here + if check_params(Xs=Xs): + + if np.array_equal(self.xs_, Xs): + + # perform standard barycentric mapping + transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 + + # compute transported samples + transp_Xs = np.dot(transp, self.xt_) + else: + # perform out of sample mapping + indices = np.arange(Xs.shape[0]) + batch_ind = [ + indices[i:i + batch_size] + for i in range(0, len(indices), batch_size)] + + transp_Xs = [] + for bi in batch_ind: + + # get the nearest neighbor in the source domain + D0 = dist(Xs[bi], self.xs_) + idx = np.argmin(D0, axis=1) + + # transport the source samples + transp = self.coupling_ / np.sum( + self.coupling_, 1)[:, None] + transp[~ np.isfinite(transp)] = 0 + transp_Xs_ = np.dot(transp, self.xt_) + + # define the transported points + transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - self.xs_[idx, :] + + transp_Xs.append(transp_Xs_) + + transp_Xs = np.concatenate(transp_Xs, axis=0) + + return transp_Xs + + def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, + batch_size=128): + """Transports target samples Xt onto target samples Xs + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + batch_size : int, optional (default=128) + The batch size for out of sample inverse transform + + Returns + ------- + transp_Xt : array-like, shape (n_source_samples, n_features) + The transported target samples. + """ + + # check the necessary inputs parameters are here + if check_params(Xt=Xt): + + if np.array_equal(self.xt_, Xt): + + # perform standard barycentric mapping + transp_ = self.coupling_.T / np.sum(self.coupling_, 0)[:, None] + + # set nans to 0 + transp_[~ np.isfinite(transp_)] = 0 + + # compute transported samples + transp_Xt = np.dot(transp_, self.xs_) + else: + # perform out of sample mapping + indices = np.arange(Xt.shape[0]) + batch_ind = [ + indices[i:i + batch_size] + for i in range(0, len(indices), batch_size)] + + transp_Xt = [] + for bi in batch_ind: + + D0 = dist(Xt[bi], self.xt_) + idx = np.argmin(D0, axis=1) + + # transport the target samples + transp_ = self.coupling_.T / np.sum( + self.coupling_, 0)[:, None] + transp_[~ np.isfinite(transp_)] = 0 + transp_Xt_ = np.dot(transp_, self.xs_) + + # define the transported points + transp_Xt_ = transp_Xt_[idx, :] + Xt[bi] - self.xt_[idx, :] + + transp_Xt.append(transp_Xt_) + + transp_Xt = np.concatenate(transp_Xt, axis=0) + + return transp_Xt + + +class LinearTransport(BaseTransport): + """ OT linear operator between empirical distributions + + The function estimates the optimal linear operator that aligns the two + empirical distributions. This is equivalent to estimating the closed + form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)` + and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in + remark 2.29 in [15]. + + The linear operator from source to target :math:`M` + + .. math:: + M(x)=Ax+b + + where : + + .. math:: + A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2} + \Sigma_s^{-1/2} + .. 
math:: + b=\mu_t-A\mu_s + + Parameters + ---------- + reg : float,optional + regularization added to the diagonals of covariances (>0) + bias: boolean, optional + estimate bias b else b=0 (default:True) + log : bool, optional + record log if True + + References + ---------- + + .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of + distributions", Journal of Optimization Theory and Applications + Vol 43, 1984 + + .. [15] Peyré, G., & Cuturi, M. (2018). "Computational Optimal + Transport". + + """ + + def __init__(self, reg=1e-8, bias=True, log=False, + distribution_estimation=distribution_estimation_uniform): + + self.bias = bias + self.log = log + self.reg = reg + self.distribution_estimation = distribution_estimation + + def fit(self, Xs=None, ys=None, Xt=None, yt=None): + """Build a coupling matrix from source and target sets of samples + (Xs, ys) and (Xt, yt) + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + + Returns + ------- + self : object + Returns self. + """ + + self.mu_s = self.distribution_estimation(Xs) + self.mu_t = self.distribution_estimation(Xt) + + # coupling estimation + returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg, + ws=self.mu_s.reshape((-1, 1)), + wt=self.mu_t.reshape((-1, 1)), + bias=self.bias, log=self.log) + + # deal with the value of log + if self.log: + self.A_, self.B_, self.log_ = returned_ + else: + self.A_, self.B_ = returned_ + self.log_ = dict() + + # recompute the inverse mapping + self.A1_ = linalg.inv(self.A_) + self.B1_ = -self.B_.dot(self.A1_) + + return self + + def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): + """Transports source samples Xs onto target ones Xt + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + batch_size : int, optional (default=128) + The batch size for out of sample inverse transform + + Returns + ------- + transp_Xs : array-like, shape (n_source_samples, n_features) + The transported source samples. + """ + + # check the necessary inputs parameters are here + if check_params(Xs=Xs): + + transp_Xs = Xs.dot(self.A_) + self.B_ + + return transp_Xs + + def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, + batch_size=128): + """Transports target samples Xt onto source samples Xs + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1.
+ + Warning: Note that, due to this convention -1 cannot be used as a + class label + batch_size : int, optional (default=128) + The batch size for out of sample inverse transform + + Returns + ------- + transp_Xt : array-like, shape (n_target_samples, n_features) + The transported target samples. + """ + + # check the necessary inputs parameters are here + if check_params(Xt=Xt): + + transp_Xt = Xt.dot(self.A1_) + self.B1_ + + return transp_Xt + + +class SinkhornTransport(BaseTransport): + + """Domain Adaptation OT method based on Sinkhorn Algorithm + + Parameters + ---------- + reg_e : float, optional (default=1) + Entropic regularization parameter + max_iter : int, float, optional (default=1000) + The maximum number of iterations before stopping the optimization + algorithm if it has not converged + tol : float, optional (default=10e-9) + The precision required to stop the optimization algorithm. + verbose : bool, optional (default=False) + Controls the verbosity of the optimization algorithm + log : bool, optional (default=False) + Controls the logs of the optimization algorithm + metric : string, optional (default="sqeuclidean") + The ground metric for the Wasserstein problem + norm : string, optional (default=None) + If given, normalize the ground metric to avoid numerical errors that + can occur with large metric values. + distribution_estimation : callable, optional (defaults to the uniform) + The kind of distribution estimation to employ + out_of_sample_map : string, optional (default="ferradans") + The kind of out of sample mapping to apply to transport samples + from a domain into another one. Currently the only possible option is + "ferradans" which uses the method proposed in [6]. + limit_max: float, optional (default=np.infty) + Controls the semi supervised mode. Transport between labeled source + and target samples of different classes will exhibit a cost defined + by this variable + + Attributes + ---------- + coupling_ : array-like, shape (n_source_samples, n_target_samples) + The optimal coupling + log_ : dictionary + The dictionary of log, empty dict if parameter log is not True + + References + ---------- + .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, + "Optimal Transport for Domain Adaptation," in IEEE Transactions + on Pattern Analysis and Machine Intelligence, vol.PP, no.99, pp.1-1 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems (NIPS) + 26, 2013 + """ + + def __init__(self, reg_e=1., max_iter=1000, + tol=10e-9, verbose=False, log=False, + metric="sqeuclidean", norm=None, + distribution_estimation=distribution_estimation_uniform, + out_of_sample_map='ferradans', limit_max=np.infty): + + self.reg_e = reg_e + self.max_iter = max_iter + self.tol = tol + self.verbose = verbose + self.log = log + self.metric = metric + self.norm = norm + self.limit_max = limit_max + self.distribution_estimation = distribution_estimation + self.out_of_sample_map = out_of_sample_map + + def fit(self, Xs=None, ys=None, Xt=None, yt=None): + """Build a coupling matrix from source and target sets of samples + (Xs, ys) and (Xt, yt) + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels.
If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + + Returns + ------- + self : object + Returns self. + """ + + super(SinkhornTransport, self).fit(Xs, ys, Xt, yt) + + # coupling estimation + returned_ = sinkhorn( + a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e, + numItermax=self.max_iter, stopThr=self.tol, + verbose=self.verbose, log=self.log) + + # deal with the value of log + if self.log: + self.coupling_, self.log_ = returned_ + else: + self.coupling_ = returned_ + self.log_ = dict() + + return self + + +class EMDTransport(BaseTransport): + + """Domain Adapatation OT method based on Earth Mover's Distance + + Parameters + ---------- + metric : string, optional (default="sqeuclidean") + The ground metric for the Wasserstein problem + norm : string, optional (default=None) + If given, normalize the ground metric to avoid numerical errors that + can occur with large metric values. + log : int, optional (default=False) + Controls the logs of the optimization algorithm + distribution_estimation : callable, optional (defaults to the uniform) + The kind of distribution estimation to employ + out_of_sample_map : string, optional (default="ferradans") + The kind of out of sample mapping to apply to transport samples + from a domain into another one. Currently the only possible option is + "ferradans" which uses the method proposed in [6]. + limit_max: float, optional (default=10) + Controls the semi supervised mode. Transport between labeled source + and target samples of different classes will exhibit an infinite cost + (10 times the maximum value of the cost matrix) + max_iter : int, optional (default=100000) + The maximum number of iterations before stopping the optimization + algorithm if it has not converged. + + Attributes + ---------- + coupling_ : array-like, shape (n_source_samples, n_target_samples) + The optimal coupling + + References + ---------- + .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, + "Optimal Transport for Domain Adaptation," in IEEE Transactions + on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + """ + + def __init__(self, metric="sqeuclidean", norm=None, log=False, + distribution_estimation=distribution_estimation_uniform, + out_of_sample_map='ferradans', limit_max=10, + max_iter=100000): + + self.metric = metric + self.norm = norm + self.log = log + self.limit_max = limit_max + self.distribution_estimation = distribution_estimation + self.out_of_sample_map = out_of_sample_map + self.max_iter = max_iter + + def fit(self, Xs, ys=None, Xt=None, yt=None): + """Build a coupling matrix from source and target sets of samples + (Xs, ys) and (Xt, yt) + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + + Returns + ------- + self : object + Returns self. 
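+ + A minimal usage sketch (illustrative; random data, outputs omitted): + + >>> import numpy as np  # doctest: +SKIP + >>> import ot  # doctest: +SKIP + >>> Xs = np.random.randn(10, 2)  # doctest: +SKIP + >>> Xt = np.random.randn(10, 2) + 1  # doctest: +SKIP + >>> otda = ot.da.EMDTransport().fit(Xs=Xs, Xt=Xt)  # doctest: +SKIP + >>> otda.coupling_.shape  # doctest: +SKIP + (10, 10)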
+ """ + + super(EMDTransport, self).fit(Xs, ys, Xt, yt) + + returned_ = emd( + a=self.mu_s, b=self.mu_t, M=self.cost_, numItermax=self.max_iter, + log=self.log) + + # coupling estimation + if self.log: + self.coupling_, self.log_ = returned_ + else: + self.coupling_ = returned_ + self.log_ = dict() + return self + + +class SinkhornLpl1Transport(BaseTransport): + + """Domain Adapatation OT method based on sinkhorn algorithm + + LpL1 class regularization. + + Parameters + ---------- + reg_e : float, optional (default=1) + Entropic regularization parameter + reg_cl : float, optional (default=0.1) + Class regularization parameter + max_iter : int, float, optional (default=10) + The minimum number of iteration before stopping the optimization + algorithm if no it has not converged + max_inner_iter : int, float, optional (default=200) + The number of iteration in the inner loop + log : bool, optional (default=False) + Controls the logs of the optimization algorithm + tol : float, optional (default=10e-9) + Stop threshold on error (inner sinkhorn solver) (>0) + verbose : bool, optional (default=False) + Controls the verbosity of the optimization algorithm + metric : string, optional (default="sqeuclidean") + The ground metric for the Wasserstein problem + norm : string, optional (default=None) + If given, normalize the ground metric to avoid numerical errors that + can occur with large metric values. + distribution_estimation : callable, optional (defaults to the uniform) + The kind of distribution estimation to employ + out_of_sample_map : string, optional (default="ferradans") + The kind of out of sample mapping to apply to transport samples + from a domain into another one. Currently the only possible option is + "ferradans" which uses the method proposed in [6]. + limit_max: float, optional (defaul=np.infty) + Controls the semi supervised mode. Transport between labeled source + and target samples of different classes will exhibit a cost defined by + limit_max. + + Attributes + ---------- + coupling_ : array-like, shape (n_source_samples, n_target_samples) + The optimal coupling + + References + ---------- + + .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, + "Optimal Transport for Domain Adaptation," in IEEE + Transactions on Pattern Analysis and Machine Intelligence , + vol.PP, no.99, pp.1-1 + .. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). + Generalized conditional gradient: analysis of convergence + and applications. arXiv preprint arXiv:1510.06567. + + """ + + def __init__(self, reg_e=1., reg_cl=0.1, + max_iter=10, max_inner_iter=200, log=False, + tol=10e-9, verbose=False, + metric="sqeuclidean", norm=None, + distribution_estimation=distribution_estimation_uniform, + out_of_sample_map='ferradans', limit_max=np.infty): + + self.reg_e = reg_e + self.reg_cl = reg_cl + self.max_iter = max_iter + self.max_inner_iter = max_inner_iter + self.tol = tol + self.log = log + self.verbose = verbose + self.metric = metric + self.norm = norm + self.distribution_estimation = distribution_estimation + self.out_of_sample_map = out_of_sample_map + self.limit_max = limit_max + + def fit(self, Xs, ys=None, Xt=None, yt=None): + """Build a coupling matrix from source and target sets of samples + (Xs, ys) and (Xt, yt) + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. 
+ yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + + Returns + ------- + self : object + Returns self. + """ + + # check the necessary inputs parameters are here + if check_params(Xs=Xs, Xt=Xt, ys=ys): + + super(SinkhornLpl1Transport, self).fit(Xs, ys, Xt, yt) + + returned_ = sinkhorn_lpl1_mm( + a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_, + reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter, + numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol, + verbose=self.verbose, log=self.log) + + # deal with the value of log + if self.log: + self.coupling_, self.log_ = returned_ + else: + self.coupling_ = returned_ + self.log_ = dict() + return self + + +class SinkhornL1l2Transport(BaseTransport): + + """Domain Adapatation OT method based on sinkhorn algorithm + + l1l2 class regularization. + + Parameters + ---------- + reg_e : float, optional (default=1) + Entropic regularization parameter + reg_cl : float, optional (default=0.1) + Class regularization parameter + max_iter : int, float, optional (default=10) + The minimum number of iteration before stopping the optimization + algorithm if no it has not converged + max_inner_iter : int, float, optional (default=200) + The number of iteration in the inner loop + tol : float, optional (default=10e-9) + Stop threshold on error (inner sinkhorn solver) (>0) + verbose : bool, optional (default=False) + Controls the verbosity of the optimization algorithm + log : bool, optional (default=False) + Controls the logs of the optimization algorithm + metric : string, optional (default="sqeuclidean") + The ground metric for the Wasserstein problem + norm : string, optional (default=None) + If given, normalize the ground metric to avoid numerical errors that + can occur with large metric values. + distribution_estimation : callable, optional (defaults to the uniform) + The kind of distribution estimation to employ + out_of_sample_map : string, optional (default="ferradans") + The kind of out of sample mapping to apply to transport samples + from a domain into another one. Currently the only possible option is + "ferradans" which uses the method proposed in [6]. + limit_max: float, optional (default=10) + Controls the semi supervised mode. Transport between labeled source + and target samples of different classes will exhibit an infinite cost + (10 times the maximum value of the cost matrix) + + Attributes + ---------- + coupling_ : array-like, shape (n_source_samples, n_target_samples) + The optimal coupling + log_ : dictionary + The dictionary of log, empty dic if parameter log is not True + + References + ---------- + + .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, + "Optimal Transport for Domain Adaptation," in IEEE + Transactions on Pattern Analysis and Machine Intelligence , + vol.PP, no.99, pp.1-1 + .. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). + Generalized conditional gradient: analysis of convergence + and applications. arXiv preprint arXiv:1510.06567. 
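+ + Examples + -------- + A minimal sketch (illustrative; this estimator requires source labels ys, here taken from the toy datasets of this toolbox): + + >>> import ot  # doctest: +SKIP + >>> Xs, ys = ot.datasets.make_data_classif('3gauss', 30, random_state=0)  # doctest: +SKIP + >>> Xt, _ = ot.datasets.make_data_classif('3gauss2', 30, random_state=1)  # doctest: +SKIP + >>> otda = ot.da.SinkhornL1l2Transport(reg_e=1., reg_cl=0.1)  # doctest: +SKIP + >>> otda = otda.fit(Xs=Xs, ys=ys, Xt=Xt)  # doctest: +SKIP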
+ + """ + + def __init__(self, reg_e=1., reg_cl=0.1, + max_iter=10, max_inner_iter=200, + tol=10e-9, verbose=False, log=False, + metric="sqeuclidean", norm=None, + distribution_estimation=distribution_estimation_uniform, + out_of_sample_map='ferradans', limit_max=10): + + self.reg_e = reg_e + self.reg_cl = reg_cl + self.max_iter = max_iter + self.max_inner_iter = max_inner_iter + self.tol = tol + self.verbose = verbose + self.log = log + self.metric = metric + self.norm = norm + self.distribution_estimation = distribution_estimation + self.out_of_sample_map = out_of_sample_map + self.limit_max = limit_max + + def fit(self, Xs, ys=None, Xt=None, yt=None): + """Build a coupling matrix from source and target sets of samples + (Xs, ys) and (Xt, yt) + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + + Returns + ------- + self : object + Returns self. + """ + + # check the necessary inputs parameters are here + if check_params(Xs=Xs, Xt=Xt, ys=ys): + + super(SinkhornL1l2Transport, self).fit(Xs, ys, Xt, yt) + + returned_ = sinkhorn_l1l2_gl( + a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_, + reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter, + numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol, + verbose=self.verbose, log=self.log) + + # deal with the value of log + if self.log: + self.coupling_, self.log_ = returned_ + else: + self.coupling_ = returned_ + self.log_ = dict() + + return self + + +class MappingTransport(BaseEstimator): + + """MappingTransport: DA methods that aims at jointly estimating a optimal + transport coupling and the associated mapping + + Parameters + ---------- + mu : float, optional (default=1) + Weight for the linear OT loss (>0) + eta : float, optional (default=0.001) + Regularization term for the linear mapping L (>0) + bias : bool, optional (default=False) + Estimate linear mapping with constant bias + metric : string, optional (default="sqeuclidean") + The ground metric for the Wasserstein problem + norm : string, optional (default=None) + If given, normalize the ground metric to avoid numerical errors that + can occur with large metric values. 
+ kernel : string, optional (default="linear") + The kernel to use either linear or gaussian + sigma : float, optional (default=1) + The gaussian kernel parameter + max_iter : int, optional (default=100) + Max number of BCD iterations + tol : float, optional (default=1e-5) + Stop threshold on relative loss decrease (>0) + max_inner_iter : int, optional (default=10) + Max number of iterations (inner CG solver) + inner_tol : float, optional (default=1e-6) + Stop threshold on error (inner CG solver) (>0) + log : bool, optional (default=False) + record log if True + verbose : bool, optional (default=False) + Print information along iterations + verbose2 : bool, optional (default=False) + Print information along iterations + + Attributes + ---------- + coupling_ : array-like, shape (n_source_samples, n_target_samples) + The optimal coupling + mapping_ : array-like, shape (n_features (+ 1), n_features) + (if bias) for kernel == linear + The associated mapping + array-like, shape (n_source_samples (+ 1), n_features) + (if bias) for kernel == gaussian + log_ : dictionary + The dictionary of log, empty dic if parameter log is not True + + References + ---------- + + .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, + "Mapping estimation for discrete optimal transport", + Neural Information Processing Systems (NIPS), 2016. + + """ + + def __init__(self, mu=1, eta=0.001, bias=False, metric="sqeuclidean", + norm=None, kernel="linear", sigma=1, max_iter=100, tol=1e-5, + max_inner_iter=10, inner_tol=1e-6, log=False, verbose=False, + verbose2=False): + + self.metric = metric + self.norm = norm + self.mu = mu + self.eta = eta + self.bias = bias + self.kernel = kernel + self.sigma = sigma + self.max_iter = max_iter + self.tol = tol + self.max_inner_iter = max_inner_iter + self.inner_tol = inner_tol + self.log = log + self.verbose = verbose + self.verbose2 = verbose2 + + def fit(self, Xs=None, ys=None, Xt=None, yt=None): + """Builds an optimal coupling and estimates the associated mapping + from source and target sets of samples (Xs, ys) and (Xt, yt) + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. 
+ + Warning: Note that, due to this convention -1 cannot be used as a + class label + + Returns + ------- + self : object + Returns self + """ + + # check the necessary inputs parameters are here + if check_params(Xs=Xs, Xt=Xt): + + self.xs_ = Xs + self.xt_ = Xt + + if self.kernel == "linear": + returned_ = joint_OT_mapping_linear( + Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias, + verbose=self.verbose, verbose2=self.verbose2, + numItermax=self.max_iter, + numInnerItermax=self.max_inner_iter, stopThr=self.tol, + stopInnerThr=self.inner_tol, log=self.log) + + elif self.kernel == "gaussian": + returned_ = joint_OT_mapping_kernel( + Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias, + sigma=self.sigma, verbose=self.verbose, + verbose2=self.verbose, numItermax=self.max_iter, + numInnerItermax=self.max_inner_iter, + stopInnerThr=self.inner_tol, stopThr=self.tol, + log=self.log) + + # deal with the value of log + if self.log: + self.coupling_, self.mapping_, self.log_ = returned_ + else: + self.coupling_, self.mapping_ = returned_ + self.log_ = dict() + + return self + + def transform(self, Xs): + """Transports source samples Xs onto target ones Xt + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + + Returns + ------- + transp_Xs : array-like, shape (n_source_samples, n_features) + The transport source samples. + """ + + # check the necessary inputs parameters are here + if check_params(Xs=Xs): + + if np.array_equal(self.xs_, Xs): + # perform standard barycentric mapping + transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 + + # compute transported samples + transp_Xs = np.dot(transp, self.xt_) + else: + if self.kernel == "gaussian": + K = kernel(Xs, self.xs_, method=self.kernel, + sigma=self.sigma) + elif self.kernel == "linear": + K = Xs + if self.bias: + K = np.hstack((K, np.ones((Xs.shape[0], 1)))) + transp_Xs = K.dot(self.mapping_) + + return transp_Xs + + +class UnbalancedSinkhornTransport(BaseTransport): + + """Domain Adapatation unbalanced OT method based on sinkhorn algorithm + + Parameters + ---------- + reg_e : float, optional (default=1) + Entropic regularization parameter + reg_m : float, optional (default=0.1) + Mass regularization parameter + method : str + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + 'sinkhorn_epsilon_scaling', see those function for specific parameters + max_iter : int, float, optional (default=10) + The minimum number of iteration before stopping the optimization + algorithm if no it has not converged + tol : float, optional (default=10e-9) + Stop threshold on error (inner sinkhorn solver) (>0) + verbose : bool, optional (default=False) + Controls the verbosity of the optimization algorithm + log : bool, optional (default=False) + Controls the logs of the optimization algorithm + metric : string, optional (default="sqeuclidean") + The ground metric for the Wasserstein problem + norm : string, optional (default=None) + If given, normalize the ground metric to avoid numerical errors that + can occur with large metric values. + distribution_estimation : callable, optional (defaults to the uniform) + The kind of distribution estimation to employ + out_of_sample_map : string, optional (default="ferradans") + The kind of out of sample mapping to apply to transport samples + from a domain into another one. Currently the only possible option is + "ferradans" which uses the method proposed in [6]. 
+ limit_max: float, optional (default=10) + Controls the semi supervised mode. Transport between labeled source + and target samples of different classes will exhibit an infinite cost + (10 times the maximum value of the cost matrix) + + Attributes + ---------- + coupling_ : array-like, shape (n_source_samples, n_target_samples) + The optimal coupling + log_ : dictionary + The dictionary of log, empty dic if parameter log is not True + + References + ---------- + + .. [1] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint + arXiv:1607.05816. + + """ + + def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn', + max_iter=10, tol=1e-9, verbose=False, log=False, + metric="sqeuclidean", norm=None, + distribution_estimation=distribution_estimation_uniform, + out_of_sample_map='ferradans', limit_max=10): + + self.reg_e = reg_e + self.reg_m = reg_m + self.method = method + self.max_iter = max_iter + self.tol = tol + self.verbose = verbose + self.log = log + self.metric = metric + self.norm = norm + self.distribution_estimation = distribution_estimation + self.out_of_sample_map = out_of_sample_map + self.limit_max = limit_max + + def fit(self, Xs, ys=None, Xt=None, yt=None): + """Build a coupling matrix from source and target sets of samples + (Xs, ys) and (Xt, yt) + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + + Returns + ------- + self : object + Returns self. 
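+ + A minimal usage sketch (illustrative; random data, outputs omitted). Unlike the balanced solvers, the coupling marginals need not match a and b exactly: + + >>> import numpy as np  # doctest: +SKIP + >>> import ot  # doctest: +SKIP + >>> Xs = np.random.randn(10, 2)  # doctest: +SKIP + >>> Xt = np.random.randn(15, 2) + 1  # doctest: +SKIP + >>> otda = ot.da.UnbalancedSinkhornTransport(reg_e=1., reg_m=0.1)  # doctest: +SKIP + >>> otda = otda.fit(Xs=Xs, Xt=Xt)  # doctest: +SKIP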
+ """ + + # check the necessary inputs parameters are here + if check_params(Xs=Xs, Xt=Xt): + + super(UnbalancedSinkhornTransport, self).fit(Xs, ys, Xt, yt) + + returned_ = sinkhorn_unbalanced( + a=self.mu_s, b=self.mu_t, M=self.cost_, + reg=self.reg_e, reg_m=self.reg_m, method=self.method, + numItermax=self.max_iter, stopThr=self.tol, + verbose=self.verbose, log=self.log) + + # deal with the value of log + if self.log: + self.coupling_, self.log_ = returned_ + else: + self.coupling_ = returned_ + self.log_ = dict() + + return self diff --git a/ot/datasets.py b/ot/datasets.py new file mode 100644 index 0000000..ba0cfd9 --- /dev/null +++ b/ot/datasets.py @@ -0,0 +1,164 @@ +""" +Simple example datasets for OT +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + + +import numpy as np +import scipy as sp +from .utils import check_random_state, deprecated + + +def make_1D_gauss(n, m, s): + """return a 1D histogram for a gaussian distribution (n bins, mean m and std s) + + Parameters + ---------- + n : int + number of bins in the histogram + m : float + mean value of the gaussian distribution + s : float + standard deviaton of the gaussian distribution + + Returns + ------- + h : ndarray (n,) + 1D histogram for a gaussian distribution + """ + x = np.arange(n, dtype=np.float64) + h = np.exp(-(x - m)**2 / (2 * s**2)) + return h / h.sum() + + +@deprecated() +def get_1D_gauss(n, m, sigma): + """ Deprecated see make_1D_gauss """ + return make_1D_gauss(n, m, sigma) + + +def make_2D_samples_gauss(n, m, sigma, random_state=None): + """Return n samples drawn from 2D gaussian N(m,sigma) + + Parameters + ---------- + n : int + number of samples to make + m : ndarray, shape (2,) + mean value of the gaussian distribution + sigma : ndarray, shape (2, 2) + covariance matrix of the gaussian distribution + random_state : int, RandomState instance or None, optional (default=None) + If int, random_state is the seed used by the random number generator; + If RandomState instance, random_state is the random number generator; + If None, the random number generator is the RandomState instance used + by `np.random`. + + Returns + ------- + X : ndarray, shape (n, 2) + n samples drawn from N(m, sigma). + """ + + generator = check_random_state(random_state) + if np.isscalar(sigma): + sigma = np.array([sigma, ]) + if len(sigma) > 1: + P = sp.linalg.sqrtm(sigma) + res = generator.randn(n, 2).dot(P) + m + else: + res = generator.randn(n, 2) * np.sqrt(sigma) + m + return res + + +@deprecated() +def get_2D_samples_gauss(n, m, sigma, random_state=None): + """ Deprecated see make_2D_samples_gauss """ + return make_2D_samples_gauss(n, m, sigma, random_state=None) + + +def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs): + """Dataset generation for classification problems + + Parameters + ---------- + dataset : str + type of classification problem (see code) + n : int + number of training samples + nz : float + noise level (>0) + random_state : int, RandomState instance or None, optional (default=None) + If int, random_state is the seed used by the random number generator; + If RandomState instance, random_state is the random number generator; + If None, the random number generator is the RandomState instance used + by `np.random`. + + Returns + ------- + X : ndarray, shape (n, d) + n observation of size d + y : ndarray, shape (n,) + labels of the samples. 
+ """ + generator = check_random_state(random_state) + + if dataset.lower() == '3gauss': + y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 + x = np.zeros((n, 2)) + # class 1 + x[y == 1, 0] = -1. + x[y == 1, 1] = -1. + x[y == 2, 0] = -1. + x[y == 2, 1] = 1. + x[y == 3, 0] = 1. + x[y == 3, 1] = 0 + + x[y != 3, :] += 1.5 * nz * generator.randn(sum(y != 3), 2) + x[y == 3, :] += 2 * nz * generator.randn(sum(y == 3), 2) + + elif dataset.lower() == '3gauss2': + y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 + x = np.zeros((n, 2)) + y[y == 4] = 3 + # class 1 + x[y == 1, 0] = -2. + x[y == 1, 1] = -2. + x[y == 2, 0] = -2. + x[y == 2, 1] = 2. + x[y == 3, 0] = 2. + x[y == 3, 1] = 0 + + x[y != 3, :] += nz * generator.randn(sum(y != 3), 2) + x[y == 3, :] += 2 * nz * generator.randn(sum(y == 3), 2) + + elif dataset.lower() == 'gaussrot': + rot = np.array( + [[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]]) + m1 = np.array([-1, 1]) + m2 = np.array([1, -1]) + y = np.floor((np.arange(n) * 1.0 / n * 2)) + 1 + n1 = np.sum(y == 1) + n2 = np.sum(y == 2) + x = np.zeros((n, 2)) + + x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz, random_state=generator) + x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz, random_state=generator) + + x = x.dot(rot) + + else: + x = np.array(0) + y = np.array(0) + print("unknown dataset") + + return x, y.astype(int) + + +@deprecated() +def get_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs): + """ Deprecated see make_data_classif """ + return make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs) diff --git a/ot/dr.py b/ot/dr.py new file mode 100644 index 0000000..680dabf --- /dev/null +++ b/ot/dr.py @@ -0,0 +1,200 @@ +# -*- coding: utf-8 -*- +""" +Dimension reduction with optimal transport + + +.. warning:: + Note that by default the module is not import in :mod:`ot`. In order to + use it you need to explicitely import :mod:`ot.dr` + +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + +from scipy import linalg +import autograd.numpy as np +from pymanopt.manifolds import Stiefel +from pymanopt import Problem +from pymanopt.solvers import SteepestDescent, TrustRegions + + +def dist(x1, x2): + """ Compute squared euclidean distance between samples (autograd) + """ + x1p2 = np.sum(np.square(x1), 1) + x2p2 = np.sum(np.square(x2), 1) + return x1p2.reshape((-1, 1)) + x2p2.reshape((1, -1)) - 2 * np.dot(x1, x2.T) + + +def sinkhorn(w1, w2, M, reg, k): + """Sinkhorn algorithm with fixed number of iteration (autograd) + """ + K = np.exp(-M / reg) + ui = np.ones((M.shape[0],)) + vi = np.ones((M.shape[1],)) + for i in range(k): + vi = w2 / (np.dot(K.T, ui)) + ui = w1 / (np.dot(K, vi)) + G = ui.reshape((M.shape[0], 1)) * K * vi.reshape((1, M.shape[1])) + return G + + +def split_classes(X, y): + """split samples in X by classes in y + """ + lstsclass = np.unique(y) + return [X[y == i, :].astype(np.float32) for i in lstsclass] + + +def fda(X, y, p=2, reg=1e-16): + """Fisher Discriminant Analysis + + Parameters + ---------- + X : ndarray, shape (n, d) + Training samples. + y : ndarray, shape (n,) + Labels for training samples. + p : int, optional + Size of dimensionnality reduction. 
+ reg : float, optional + Regularization term >0 (ridge regularization) + + Returns + ------- + P : ndarray, shape (d, p) + Optimal transportation matrix for the given parameters + proj : callable + projection function including mean centering + """ + + mx = np.mean(X) + X -= mx.reshape((1, -1)) + + # data split between classes + d = X.shape[1] + xc = split_classes(X, y) + nc = len(xc) + + p = min(nc - 1, p) + + Cw = 0 + for x in xc: + Cw += np.cov(x, rowvar=False) + Cw /= nc + + mxc = np.zeros((d, nc)) + + for i in range(nc): + mxc[:, i] = np.mean(xc[i]) + + mx0 = np.mean(mxc, 1) + Cb = 0 + for i in range(nc): + Cb += (mxc[:, i] - mx0).reshape((-1, 1)) * \ + (mxc[:, i] - mx0).reshape((1, -1)) + + w, V = linalg.eig(Cb, Cw + reg * np.eye(d)) + + idx = np.argsort(w.real) + + Popt = V[:, idx[-p:]] + + def proj(X): + return (X - mx.reshape((1, -1))).dot(Popt) + + return Popt, proj + + +def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None): + """ + Wasserstein Discriminant Analysis [11]_ + + The function solves the following optimization problem: + + .. math:: + P = \\text{arg}\min_P \\frac{\\sum_i W(PX^i,PX^i)}{\\sum_{i,j\\neq i} W(PX^i,PX^j)} + + where : + + - :math:`P` is a linear projection operator in the Stiefel(p,d) manifold + - :math:`W` is entropic regularized Wasserstein distances + - :math:`X^i` are samples in the dataset corresponding to class i + + Parameters + ---------- + X : ndarray, shape (n, d) + Training samples. + y : ndarray, shape (n,) + Labels for training samples. + p : int, optional + Size of dimensionnality reduction. + reg : float, optional + Regularization term >0 (entropic regularization) + solver : None | str, optional + None for steepest descent or 'TrustRegions' for trust regions algorithm + else should be a pymanopt.solvers + P0 : ndarray, shape (d, p) + Initial starting point for projection. + verbose : int, optional + Print information along iterations. + + Returns + ------- + P : ndarray, shape (d, p) + Optimal transportation matrix for the given parameters + proj : callable + Projection function including mean centering. + + References + ---------- + .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). + Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063. 
+ """ # noqa + + mx = np.mean(X) + X -= mx.reshape((1, -1)) + + # data split between classes + d = X.shape[1] + xc = split_classes(X, y) + # compute uniform weighs + wc = [np.ones((x.shape[0]), dtype=np.float32) / x.shape[0] for x in xc] + + def cost(P): + # wda loss + loss_b = 0 + loss_w = 0 + + for i, xi in enumerate(xc): + xi = np.dot(xi, P) + for j, xj in enumerate(xc[i:]): + xj = np.dot(xj, P) + M = dist(xi, xj) + G = sinkhorn(wc[i], wc[j + i], M, reg, k) + if j == 0: + loss_w += np.sum(G * M) + else: + loss_b += np.sum(G * M) + + # loss inversed because minimization + return loss_w / loss_b + + # declare manifold and problem + manifold = Stiefel(d, p) + problem = Problem(manifold=manifold, cost=cost) + + # declare solver and solve + if solver is None: + solver = SteepestDescent(maxiter=maxiter, logverbosity=verbose) + elif solver in ['tr', 'TrustRegions']: + solver = TrustRegions(maxiter=maxiter, logverbosity=verbose) + + Popt = solver.solve(problem, x=P0) + + def proj(X): + return (X - mx.reshape((1, -1))).dot(Popt) + + return Popt, proj diff --git a/ot/externals/__init__.py b/ot/externals/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/ot/externals/__init__.py diff --git a/ot/externals/funcsigs.py b/ot/externals/funcsigs.py new file mode 100644 index 0000000..106bde7 --- /dev/null +++ b/ot/externals/funcsigs.py @@ -0,0 +1,817 @@ +# Copyright 2001-2013 Python Software Foundation; All Rights Reserved +"""Function signature objects for callables + +Back port of Python 3.3's function signature tools from the inspect module, +modified to be compatible with Python 2.7 and 3.2+. +""" +from __future__ import absolute_import, division, print_function +import itertools +import functools +import re +import types + +from collections import OrderedDict + +__version__ = "0.4" + +__all__ = ['BoundArguments', 'Parameter', 'Signature', 'signature'] + + +_WrapperDescriptor = type(type.__call__) +_MethodWrapper = type(all.__call__) + +_NonUserDefinedCallables = (_WrapperDescriptor, + _MethodWrapper, + types.BuiltinFunctionType) + + +def formatannotation(annotation, base_module=None): + if isinstance(annotation, type): + if annotation.__module__ in ('builtins', '__builtin__', base_module): + return annotation.__name__ + return annotation.__module__ + '.' + annotation.__name__ + return repr(annotation) + + +def _get_user_defined_method(cls, method_name, *nested): + try: + if cls is type: + return + meth = getattr(cls, method_name) + for name in nested: + meth = getattr(meth, name, meth) + except AttributeError: + return + else: + if not isinstance(meth, _NonUserDefinedCallables): + # Once '__signature__' will be added to 'C'-level + # callables, this check won't be necessary + return meth + + +def signature(obj): + '''Get a signature object for the passed callable.''' + + if not callable(obj): + raise TypeError('{0!r} is not a callable object'.format(obj)) + + if isinstance(obj, types.MethodType): + sig = signature(obj.__func__) + if obj.__self__ is None: + # Unbound method: the first parameter becomes positional-only + if sig.parameters: + first = sig.parameters.values()[0].replace( + kind=_POSITIONAL_ONLY) + return sig.replace( + parameters=(first,) + tuple(sig.parameters.values())[1:]) + else: + return sig + else: + # In this case we skip the first parameter of the underlying + # function (usually `self` or `cls`). 
+ return sig.replace(parameters=tuple(sig.parameters.values())[1:]) + + try: + sig = obj.__signature__ + except AttributeError: + pass + else: + if sig is not None: + return sig + + try: + # Was this function wrapped by a decorator? + wrapped = obj.__wrapped__ + except AttributeError: + pass + else: + return signature(wrapped) + + if isinstance(obj, types.FunctionType): + return Signature.from_function(obj) + + if isinstance(obj, functools.partial): + sig = signature(obj.func) + + new_params = OrderedDict(sig.parameters.items()) + + partial_args = obj.args or () + partial_keywords = obj.keywords or {} + try: + ba = sig.bind_partial(*partial_args, **partial_keywords) + except TypeError: + msg = 'partial object {0!r} has incorrect arguments'.format(obj) + raise ValueError(msg) + + for arg_name, arg_value in ba.arguments.items(): + param = new_params[arg_name] + if arg_name in partial_keywords: + # We set a new default value, because the following code + # is correct: + # + # >>> def foo(a): print(a) + # >>> print(partial(partial(foo, a=10), a=20)()) + # 20 + # >>> print(partial(partial(foo, a=10), a=20)(a=30)) + # 30 + # + # So, with 'partial' objects, passing a keyword argument is + # like setting a new default value for the corresponding + # parameter + # + # We also mark this parameter with '_partial_kwarg' + # flag. Later, in '_bind', the 'default' value of this + # parameter will be added to 'kwargs', to simulate + # the 'functools.partial' real call. + new_params[arg_name] = param.replace(default=arg_value, + _partial_kwarg=True) + + elif (param.kind not in (_VAR_KEYWORD, _VAR_POSITIONAL) + and not param._partial_kwarg): + new_params.pop(arg_name) + + return sig.replace(parameters=new_params.values()) + + sig = None + if isinstance(obj, type): + # obj is a class or a metaclass + + # First, let's see if it has an overloaded __call__ defined + # in its metaclass + call = _get_user_defined_method(type(obj), '__call__') + if call is not None: + sig = signature(call) + else: + # Now we check if the 'obj' class has a '__new__' method + new = _get_user_defined_method(obj, '__new__') + if new is not None: + sig = signature(new) + else: + # Finally, we should have at least __init__ implemented + init = _get_user_defined_method(obj, '__init__') + if init is not None: + sig = signature(init) + elif not isinstance(obj, _NonUserDefinedCallables): + # An object with __call__ + # We also check that the 'obj' is not an instance of + # _WrapperDescriptor or _MethodWrapper to avoid + # infinite recursion (and even potential segfault) + call = _get_user_defined_method(type(obj), '__call__', 'im_func') + if call is not None: + sig = signature(call) + + if sig is not None: + # For classes and objects we skip the first parameter of their + # __call__, __new__, or __init__ methods + return sig.replace(parameters=tuple(sig.parameters.values())[1:]) + + if isinstance(obj, types.BuiltinFunctionType): + # Raise a nicer error message for builtins + msg = 'no signature found for builtin function {0!r}'.format(obj) + raise ValueError(msg) + + raise ValueError( + 'callable {0!r} is not supported by signature'.format(obj)) + + +class _void(object): + '''A private marker - used in Parameter & Signature''' + + +class _empty(object): + pass + + +class _ParameterKind(int): + def __new__(self, *args, **kwargs): + obj = int.__new__(self, *args) + obj._name = kwargs['name'] + return obj + + def __str__(self): + return self._name + + def __repr__(self): + return '<_ParameterKind: {0!r}>'.format(self._name) + + 
+_POSITIONAL_ONLY = _ParameterKind(0, name='POSITIONAL_ONLY') +_POSITIONAL_OR_KEYWORD = _ParameterKind(1, name='POSITIONAL_OR_KEYWORD') +_VAR_POSITIONAL = _ParameterKind(2, name='VAR_POSITIONAL') +_KEYWORD_ONLY = _ParameterKind(3, name='KEYWORD_ONLY') +_VAR_KEYWORD = _ParameterKind(4, name='VAR_KEYWORD') + + +class Parameter(object): + '''Represents a parameter in a function signature. + + Has the following public attributes: + + * name : str + The name of the parameter as a string. + * default : object + The default value for the parameter if specified. If the + parameter has no default value, this attribute is not set. + * annotation + The annotation for the parameter if specified. If the + parameter has no annotation, this attribute is not set. + * kind : str + Describes how argument values are bound to the parameter. + Possible values: `Parameter.POSITIONAL_ONLY`, + `Parameter.POSITIONAL_OR_KEYWORD`, `Parameter.VAR_POSITIONAL`, + `Parameter.KEYWORD_ONLY`, `Parameter.VAR_KEYWORD`. + ''' + + __slots__ = ('_name', '_kind', '_default', '_annotation', '_partial_kwarg') + + POSITIONAL_ONLY = _POSITIONAL_ONLY + POSITIONAL_OR_KEYWORD = _POSITIONAL_OR_KEYWORD + VAR_POSITIONAL = _VAR_POSITIONAL + KEYWORD_ONLY = _KEYWORD_ONLY + VAR_KEYWORD = _VAR_KEYWORD + + empty = _empty + + def __init__(self, name, kind, default=_empty, annotation=_empty, + _partial_kwarg=False): + + if kind not in (_POSITIONAL_ONLY, _POSITIONAL_OR_KEYWORD, + _VAR_POSITIONAL, _KEYWORD_ONLY, _VAR_KEYWORD): + raise ValueError("invalid value for 'Parameter.kind' attribute") + self._kind = kind + + if default is not _empty: + if kind in (_VAR_POSITIONAL, _VAR_KEYWORD): + msg = '{0} parameters cannot have default values'.format(kind) + raise ValueError(msg) + self._default = default + self._annotation = annotation + + if name is None: + if kind != _POSITIONAL_ONLY: + raise ValueError("None is not a valid name for a " + "non-positional-only parameter") + self._name = name + else: + name = str(name) + if kind != _POSITIONAL_ONLY and not re.match( + r'[a-z_]\w*$', name, re.I): + msg = '{0!r} is not a valid parameter name'.format(name) + raise ValueError(msg) + self._name = name + + self._partial_kwarg = _partial_kwarg + + @property + def name(self): + return self._name + + @property + def default(self): + return self._default + + @property + def annotation(self): + return self._annotation + + @property + def kind(self): + return self._kind + + def replace(self, name=_void, kind=_void, annotation=_void, + default=_void, _partial_kwarg=_void): + '''Creates a customized copy of the Parameter.''' + + if name is _void: + name = self._name + + if kind is _void: + kind = self._kind + + if annotation is _void: + annotation = self._annotation + + if default is _void: + default = self._default + + if _partial_kwarg is _void: + _partial_kwarg = self._partial_kwarg + + return type(self)(name, kind, default=default, annotation=annotation, + _partial_kwarg=_partial_kwarg) + + def __str__(self): + kind = self.kind + + formatted = self._name + if kind == _POSITIONAL_ONLY: + if formatted is None: + formatted = '' + formatted = '<{0}>'.format(formatted) + + # Add annotation and default value + if self._annotation is not _empty: + formatted = '{0}:{1}'.format(formatted, + formatannotation(self._annotation)) + + if self._default is not _empty: + formatted = '{0}={1}'.format(formatted, repr(self._default)) + + if kind == _VAR_POSITIONAL: + formatted = '*' + formatted + elif kind == _VAR_KEYWORD: + formatted = '**' + formatted + + return formatted + + def 
__repr__(self): + return '<{0} at {1:#x} {2!r}>'.format(self.__class__.__name__, + id(self), self.name) + + def __hash__(self): + msg = "unhashable type: '{0}'".format(self.__class__.__name__) + raise TypeError(msg) + + def __eq__(self, other): + return (issubclass(other.__class__, Parameter) + and self._name == other._name + and self._kind == other._kind + and self._default == other._default + and self._annotation == other._annotation) + + def __ne__(self, other): + return not self.__eq__(other) + + +class BoundArguments(object): + '''Result of `Signature.bind` call. Holds the mapping of arguments + to the function's parameters. + + Has the following public attributes: + + * arguments : OrderedDict + An ordered mutable mapping of parameters' names to arguments' values. + Does not contain arguments' default values. + * signature : Signature + The Signature object that created this instance. + * args : tuple + Tuple of positional arguments values. + * kwargs : dict + Dict of keyword arguments values. + ''' + + def __init__(self, signature, arguments): + self.arguments = arguments + self._signature = signature + + @property + def signature(self): + return self._signature + + @property + def args(self): + args = [] + for param_name, param in self._signature.parameters.items(): + if (param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY) + or param._partial_kwarg): + # Keyword arguments mapped by 'functools.partial' + # (Parameter._partial_kwarg is True) are mapped + # in 'BoundArguments.kwargs', along with VAR_KEYWORD & + # KEYWORD_ONLY + break + + try: + arg = self.arguments[param_name] + except KeyError: + # We're done here. Other arguments + # will be mapped in 'BoundArguments.kwargs' + break + else: + if param.kind == _VAR_POSITIONAL: + # *args + args.extend(arg) + else: + # plain argument + args.append(arg) + + return tuple(args) + + @property + def kwargs(self): + kwargs = {} + kwargs_started = False + for param_name, param in self._signature.parameters.items(): + if not kwargs_started: + if (param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY) + or param._partial_kwarg): + kwargs_started = True + else: + if param_name not in self.arguments: + kwargs_started = True + continue + + if not kwargs_started: + continue + + try: + arg = self.arguments[param_name] + except KeyError: + pass + else: + if param.kind == _VAR_KEYWORD: + # **kwargs + kwargs.update(arg) + else: + # plain keyword argument + kwargs[param_name] = arg + + return kwargs + + def __hash__(self): + msg = "unhashable type: '{0}'".format(self.__class__.__name__) + raise TypeError(msg) + + def __eq__(self, other): + return (issubclass(other.__class__, BoundArguments) + and self.signature == other.signature + and self.arguments == other.arguments) + + def __ne__(self, other): + return not self.__eq__(other) + + +class Signature(object): + '''A Signature object represents the overall signature of a function. + It stores a Parameter object for each parameter accepted by the + function, as well as information specific to the function itself. + + A Signature object has the following public attributes and methods: + + * parameters : OrderedDict + An ordered mapping of parameters' names to the corresponding + Parameter objects (keyword-only arguments are in the same order + as listed in `code.co_varnames`). + * return_annotation : object + The annotation for the return type of the function if specified. + If the function has no annotation for its return type, this + attribute is not set. 
+ * bind(*args, **kwargs) -> BoundArguments + Creates a mapping from positional and keyword arguments to + parameters. + * bind_partial(*args, **kwargs) -> BoundArguments + Creates a partial mapping from positional and keyword arguments + to parameters (simulating 'functools.partial' behavior.) + ''' + + __slots__ = ('_return_annotation', '_parameters') + + _parameter_cls = Parameter + _bound_arguments_cls = BoundArguments + + empty = _empty + + def __init__(self, parameters=None, return_annotation=_empty, + __validate_parameters__=True): + '''Constructs Signature from the given list of Parameter + objects and 'return_annotation'. All arguments are optional. + ''' + + if parameters is None: + params = OrderedDict() + else: + if __validate_parameters__: + params = OrderedDict() + top_kind = _POSITIONAL_ONLY + + for idx, param in enumerate(parameters): + kind = param.kind + if kind < top_kind: + msg = 'wrong parameter order: {0} before {1}' + msg = msg.format(top_kind, param.kind) + raise ValueError(msg) + else: + top_kind = kind + + name = param.name + if name is None: + name = str(idx) + param = param.replace(name=name) + + if name in params: + msg = 'duplicate parameter name: {0!r}'.format(name) + raise ValueError(msg) + params[name] = param + else: + params = OrderedDict(((param.name, param) + for param in parameters)) + + self._parameters = params + self._return_annotation = return_annotation + + @classmethod + def from_function(cls, func): + '''Constructs Signature for the given python function''' + + if not isinstance(func, types.FunctionType): + raise TypeError('{0!r} is not a Python function'.format(func)) + + Parameter = cls._parameter_cls + + # Parameter information. + func_code = func.__code__ + pos_count = func_code.co_argcount + arg_names = func_code.co_varnames + positional = tuple(arg_names[:pos_count]) + keyword_only_count = getattr(func_code, 'co_kwonlyargcount', 0) + keyword_only = arg_names[pos_count:(pos_count + keyword_only_count)] + annotations = getattr(func, '__annotations__', {}) + defaults = func.__defaults__ + kwdefaults = getattr(func, '__kwdefaults__', None) + + if defaults: + pos_default_count = len(defaults) + else: + pos_default_count = 0 + + parameters = [] + + # Non-keyword-only parameters w/o defaults. + non_default_count = pos_count - pos_default_count + for name in positional[:non_default_count]: + annotation = annotations.get(name, _empty) + parameters.append(Parameter(name, annotation=annotation, + kind=_POSITIONAL_OR_KEYWORD)) + + # ... w/ defaults. + for offset, name in enumerate(positional[non_default_count:]): + annotation = annotations.get(name, _empty) + parameters.append(Parameter(name, annotation=annotation, + kind=_POSITIONAL_OR_KEYWORD, + default=defaults[offset])) + + # *args + if func_code.co_flags & 0x04: + name = arg_names[pos_count + keyword_only_count] + annotation = annotations.get(name, _empty) + parameters.append(Parameter(name, annotation=annotation, + kind=_VAR_POSITIONAL)) + + # Keyword-only parameters. 
+ for name in keyword_only: + default = _empty + if kwdefaults is not None: + default = kwdefaults.get(name, _empty) + + annotation = annotations.get(name, _empty) + parameters.append(Parameter(name, annotation=annotation, + kind=_KEYWORD_ONLY, + default=default)) + # **kwargs + if func_code.co_flags & 0x08: + index = pos_count + keyword_only_count + if func_code.co_flags & 0x04: + index += 1 + + name = arg_names[index] + annotation = annotations.get(name, _empty) + parameters.append(Parameter(name, annotation=annotation, + kind=_VAR_KEYWORD)) + + return cls(parameters, + return_annotation=annotations.get('return', _empty), + __validate_parameters__=False) + + @property + def parameters(self): + try: + return types.MappingProxyType(self._parameters) + except AttributeError: + return OrderedDict(self._parameters.items()) + + @property + def return_annotation(self): + return self._return_annotation + + def replace(self, parameters=_void, return_annotation=_void): + '''Creates a customized copy of the Signature. + Pass 'parameters' and/or 'return_annotation' arguments + to override them in the new copy. + ''' + + if parameters is _void: + parameters = self.parameters.values() + + if return_annotation is _void: + return_annotation = self._return_annotation + + return type(self)(parameters, + return_annotation=return_annotation) + + def __hash__(self): + msg = "unhashable type: '{0}'".format(self.__class__.__name__) + raise TypeError(msg) + + def __eq__(self, other): + if (not issubclass(type(other), Signature) + or self.return_annotation != other.return_annotation + or len(self.parameters) != len(other.parameters)): + return False + + other_positions = dict((param, idx) + for idx, param in enumerate(other.parameters.keys())) + + for idx, (param_name, param) in enumerate(self.parameters.items()): + if param.kind == _KEYWORD_ONLY: + try: + other_param = other.parameters[param_name] + except KeyError: + return False + else: + if param != other_param: + return False + else: + try: + other_idx = other_positions[param_name] + except KeyError: + return False + else: + if (idx != other_idx + or param != other.parameters[param_name]): + return False + + return True + + def __ne__(self, other): + return not self.__eq__(other) + + def _bind(self, args, kwargs, partial=False): + '''Private method. Don't use directly.''' + + arguments = OrderedDict() + + parameters = iter(self.parameters.values()) + parameters_ex = () + arg_vals = iter(args) + + if partial: + # Support for binding arguments to 'functools.partial' objects. + # See 'functools.partial' case in 'signature()' implementation + # for details. + for param_name, param in self.parameters.items(): + if (param._partial_kwarg and param_name not in kwargs): + # Simulating 'functools.partial' behavior + kwargs[param_name] = param.default + + while True: + # Let's iterate through the positional arguments and corresponding + # parameters + try: + arg_val = next(arg_vals) + except StopIteration: + # No more positional arguments + try: + param = next(parameters) + except StopIteration: + # No more parameters. That's it. Just need to check that + # we have no `kwargs` after this while loop + break + else: + if param.kind == _VAR_POSITIONAL: + # That's OK, just empty *args. 
Let's start parsing + # kwargs + break + elif param.name in kwargs: + if param.kind == _POSITIONAL_ONLY: + msg = '{arg!r} parameter is positional only, ' \ + 'but was passed as a keyword' + msg = msg.format(arg=param.name) + raise TypeError(msg) + parameters_ex = (param,) + break + elif (param.kind == _VAR_KEYWORD + or param.default is not _empty): + # That's fine too - we have a default value for this + # parameter. So, lets start parsing `kwargs`, starting + # with the current parameter + parameters_ex = (param,) + break + else: + if partial: + parameters_ex = (param,) + break + else: + msg = '{arg!r} parameter lacking default value' + msg = msg.format(arg=param.name) + raise TypeError(msg) + else: + # We have a positional argument to process + try: + param = next(parameters) + except StopIteration: + raise TypeError('too many positional arguments') + else: + if param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY): + # Looks like we have no parameter for this positional + # argument + raise TypeError('too many positional arguments') + + if param.kind == _VAR_POSITIONAL: + # We have an '*args'-like argument, let's fill it with + # all positional arguments we have left and move on to + # the next phase + values = [arg_val] + values.extend(arg_vals) + arguments[param.name] = tuple(values) + break + + if param.name in kwargs: + raise TypeError('multiple values for argument ' + '{arg!r}'.format(arg=param.name)) + + arguments[param.name] = arg_val + + # Now, we iterate through the remaining parameters to process + # keyword arguments + kwargs_param = None + for param in itertools.chain(parameters_ex, parameters): + if param.kind == _POSITIONAL_ONLY: + # This should never happen in case of a properly built + # Signature object (but let's have this check here + # to ensure correct behaviour just in case) + raise TypeError('{arg!r} parameter is positional only, ' + 'but was passed as a keyword'. + format(arg=param.name)) + + if param.kind == _VAR_KEYWORD: + # Memorize that we have a '**kwargs'-like parameter + kwargs_param = param + continue + + param_name = param.name + try: + arg_val = kwargs.pop(param_name) + except KeyError: + # We have no value for this parameter. It's fine though, + # if it has a default value, or it is an '*args'-like + # parameter, left alone by the processing of positional + # arguments. + if (not partial and param.kind != _VAR_POSITIONAL + and param.default is _empty): + raise TypeError('{arg!r} parameter lacking default value'. + format(arg=param_name)) + + else: + arguments[param_name] = arg_val + + if kwargs: + if kwargs_param is not None: + # Process our '**kwargs'-like parameter + arguments[kwargs_param.name] = kwargs + else: + raise TypeError('too many keyword arguments') + + return self._bound_arguments_cls(self, arguments) + + def bind(self, *args, **kwargs): + '''Get a BoundArguments object, that maps the passed `args` + and `kwargs` to the function's signature. Raises `TypeError` + if the passed arguments can not be bound. + ''' + return self._bind(args, kwargs) + + def bind_partial(self, *args, **kwargs): + '''Get a BoundArguments object, that partially maps the + passed `args` and `kwargs` to the function's signature. + Raises `TypeError` if the passed arguments can not be bound. 
+ ''' + return self._bind(args, kwargs, partial=True) + + def __str__(self): + result = [] + render_kw_only_separator = True + for idx, param in enumerate(self.parameters.values()): + formatted = str(param) + + kind = param.kind + if kind == _VAR_POSITIONAL: + # OK, we have an '*args'-like parameter, so we won't need + # a '*' to separate keyword-only arguments + render_kw_only_separator = False + elif kind == _KEYWORD_ONLY and render_kw_only_separator: + # We have a keyword-only parameter to render and we haven't + # rendered an '*args'-like parameter before, so add a '*' + # separator to the parameters list ("foo(arg1, *, arg2)" case) + result.append('*') + # This condition should be only triggered once, so + # reset the flag + render_kw_only_separator = False + + result.append(formatted) + + rendered = '({0})'.format(', '.join(result)) + + if self.return_annotation is not _empty: + anno = formatannotation(self.return_annotation) + rendered += ' -> {0}'.format(anno) + + return rendered diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py new file mode 100644 index 0000000..1ab95bb --- /dev/null +++ b/ot/gpu/__init__.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +""" + +This module provides GPU implementation for several OT solvers and utility +functions. The GPU backend in handled by `cupy +<https://cupy.chainer.org/>`_. + +.. warning:: + Note that by default the module is not import in :mod:`ot`. In order to + use it you need to explicitely import :mod:`ot.gpu` . + +By default, the functions in this module accept and return numpy arrays +in order to proide drop-in replacement for the other POT function but +the transfer between CPU en GPU comes with a significant overhead. + +In order to get the best performances, we recommend to give only cupy +arrays to the functions and desactivate the conversion to numpy of the +result of the function with parameter ``to_numpy=False``. + +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# Leo Gautheron <https://github.com/aje> +# +# License: MIT License + +from . import bregman +from . import da +from .bregman import sinkhorn +from .da import sinkhorn_lpl1_mm + +from . import utils +from .utils import dist, to_gpu, to_np + + + + + +__all__ = ["utils", "dist", "sinkhorn", + "sinkhorn_lpl1_mm", 'bregman', 'da', 'to_gpu', 'to_np'] + diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py new file mode 100644 index 0000000..2e2df83 --- /dev/null +++ b/ot/gpu/bregman.py @@ -0,0 +1,194 @@ +# -*- coding: utf-8 -*- +""" +Bregman projections for regularized OT with GPU +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# Leo Gautheron <https://github.com/aje> +# +# License: MIT License + +import cupy as np # np used for matrix computation +import cupy as cp # cp used for cupy specific operations +from . import utils + + +def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, + verbose=False, log=False, to_numpy=True, **kwargs): + """ + Solve the entropic regularization optimal transport on GPU + + If the input matrix are in numpy format, they will be uploaded to the + GPU first which can incur significant time overhead. + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. 
\gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the (ns,nt) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights (sum to 1) + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) or np.ndarray (nt,nbb) + samples in the target domain, compute sinkhorn with multiple targets + and fixed M if b is a matrix (return OT loss + dual variables in log) + M : np.ndarray (ns,nt) + loss matrix + reg : float + Regularization term >0 + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + to_numpy : boolean, optional (default True) + If true convert back the GPU array result to numpy format. + + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + a = cp.asarray(a) + b = cp.asarray(b) + M = cp.asarray(M) + + if len(a) == 0: + a = np.ones((M.shape[0],)) / M.shape[0] + if len(b) == 0: + b = np.ones((M.shape[1],)) / M.shape[1] + + # init data + Nini = len(a) + Nfin = len(b) + + if len(b.shape) > 1: + nbb = b.shape[1] + else: + nbb = 0 + + if log: + log = {'err': []} + + # we assume that no distances are null except those of the diagonal of + # distances + if nbb: + u = np.ones((Nini, nbb)) / Nini + v = np.ones((Nfin, nbb)) / Nfin + else: + u = np.ones(Nini) / Nini + v = np.ones(Nfin) / Nfin + + # print(reg) + + # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + + # print(np.min(K)) + tmp2 = np.empty(b.shape, dtype=M.dtype) + + Kp = (1 / a).reshape(-1, 1) * K + cpt = 0 + err = 1 + while (err > stopThr and cpt < numItermax): + uprev = u + vprev = v + + KtransposeU = np.dot(K.T, u) + v = np.divide(b, KtransposeU) + u = 1. 
/ np.dot(Kp, v) + + if (np.any(KtransposeU == 0) or + np.any(np.isnan(u)) or np.any(np.isnan(v)) or + np.any(np.isinf(u)) or np.any(np.isinf(v))): + # we have reached the machine precision + # come back to previous solution and quit loop + print('Warning: numerical errors at iteration', cpt) + u = uprev + v = vprev + break + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + if nbb: + err = np.sum((u - uprev)**2) / np.sum((u)**2) + \ + np.sum((v - vprev)**2) / np.sum((v)**2) + else: + # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 + tmp2 = np.sum(u[:, None] * K * v[None, :], 0) + #tmp2=np.einsum('i,ij,j->j', u, K, v) + err = np.linalg.norm(tmp2 - b)**2 # violation of marginal + if log: + log['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + cpt = cpt + 1 + if log: + log['u'] = u + log['v'] = v + + if nbb: # return only loss + #res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) (explodes cupy memory) + res = np.empty(nbb) + for i in range(nbb): + res[i] = np.sum(u[:, None, i] * (K * M) * v[None, :, i]) + if to_numpy: + res = utils.to_np(res) + if log: + return res, log + else: + return res + + else: # return OT matrix + res = u.reshape((-1, 1)) * K * v.reshape((1, -1)) + if to_numpy: + res = utils.to_np(res) + if log: + return res, log + else: + return res + + +# define sinkhorn as sinkhorn_knopp +sinkhorn = sinkhorn_knopp diff --git a/ot/gpu/da.py b/ot/gpu/da.py new file mode 100644 index 0000000..4a98038 --- /dev/null +++ b/ot/gpu/da.py @@ -0,0 +1,144 @@ +# -*- coding: utf-8 -*- +""" +Domain adaptation with optimal transport with GPU implementation +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# Nicolas Courty <ncourty@irisa.fr> +# Michael Perrot <michael.perrot@univ-st-etienne.fr> +# Leo Gautheron <https://github.com/aje> +# +# License: MIT License + + +import cupy as np # np used for matrix computation +import cupy as cp # cp used for cupy specific operations +import numpy as npp +from . import utils + +from .bregman import sinkhorn + + +def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, + numInnerItermax=200, stopInnerThr=1e-9, verbose=False, + log=False, to_numpy=True): + """ + Solve the entropic regularization optimal transport problem with nonconvex + group lasso regularization on GPU + + If the input matrix are in numpy format, they will be uploaded to the + GPU first which can incur significant time overhead. + + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma) + + \eta \Omega_g(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the (ns,nt) metric cost matrix + - :math:`\Omega_e` is the entropic regularization term + :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega_g` is the group lasso regulaization term + :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1` + where :math:`\mathcal{I}_c` are the index of samples from class c + in the source domain. 
+ - a and b are source and target weights (sum to 1) + + The algorithm used for solving the problem is the generalised conditional + gradient as proposed in [5]_ [7]_ + + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + labels_a : np.ndarray (ns,) + labels of samples in the source domain + b : np.ndarray (nt,) + samples weights in the target domain + M : np.ndarray (ns,nt) + loss matrix + reg : float + Regularization term for entropic regularization >0 + eta : float, optional + Regularization term for group lasso regularization >0 + numItermax : int, optional + Max number of iterations + numInnerItermax : int, optional + Max number of iterations (inner sinkhorn solver) + stopInnerThr : float, optional + Stop threshold on error (inner sinkhorn solver) (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + to_numpy : boolean, optional (default True) + If true convert back the GPU array result to numpy format. + + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, + "Optimal Transport for Domain Adaptation," in IEEE + Transactions on Pattern Analysis and Machine Intelligence , + vol.PP, no.99, pp.1-1 + .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). + Generalized conditional gradient: analysis of convergence + and applications. arXiv preprint arXiv:1510.06567. + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.bregman.sinkhorn : Entropic regularized OT + ot.optim.cg : General regularized OT + + """ + + a, labels_a, b, M = utils.to_gpu(a, labels_a, b, M) + + p = 0.5 + epsilon = 1e-3 + + indices_labels = [] + labels_a2 = cp.asnumpy(labels_a) + classes = npp.unique(labels_a2) + for c in classes: + idxc, = utils.to_gpu(npp.where(labels_a2 == c)) + indices_labels.append(idxc) + + W = np.zeros(M.shape) + + for cpt in range(numItermax): + Mreg = M + eta * W + transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, + stopThr=stopInnerThr, to_numpy=False) + # the transport has been computed. Check if classes are really + # separated + W = np.ones(M.shape) + for (i, c) in enumerate(classes): + + majs = np.sum(transp[indices_labels[i]], axis=0) + majs = p * ((majs + epsilon)**(p - 1)) + W[indices_labels[i]] = majs + + if to_numpy: + return utils.to_np(transp) + else: + return transp diff --git a/ot/gpu/utils.py b/ot/gpu/utils.py new file mode 100644 index 0000000..41e168a --- /dev/null +++ b/ot/gpu/utils.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +""" +Utility functions for GPU +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# Nicolas Courty <ncourty@irisa.fr> +# Leo Gautheron <https://github.com/aje> +# +# License: MIT License + +import cupy as np # np used for matrix computation +import cupy as cp # cp used for cupy specific operations + + +def euclidean_distances(a, b, squared=False, to_numpy=True): + """ + Compute the pairwise euclidean distance between matrices a and b. + + If the input matrix are in numpy format, they will be uploaded to the + GPU first which can incur significant time overhead. + + Parameters + ---------- + a : np.ndarray (n, f) + first matrix + b : np.ndarray (m, f) + second matrix + to_numpy : boolean, optional (default True) + If true convert back the GPU array result to numpy format. 
+ squared : boolean, optional (default False) + if True, return squared euclidean distance matrix + + Returns + ------- + c : (n x m) np.ndarray or cupy.ndarray + pairwise euclidean distance distance matrix + """ + + a, b = to_gpu(a, b) + + a2 = np.sum(np.square(a), 1) + b2 = np.sum(np.square(b), 1) + + c = -2 * np.dot(a, b.T) + c += a2[:, None] + c += b2[None, :] + + if not squared: + np.sqrt(c, out=c) + if to_numpy: + return to_np(c) + else: + return c + + +def dist(x1, x2=None, metric='sqeuclidean', to_numpy=True): + """Compute distance between samples in x1 and x2 on gpu + + Parameters + ---------- + + x1 : np.array (n1,d) + matrix with n1 samples of size d + x2 : np.array (n2,d), optional + matrix with n2 samples of size d (if None then x2=x1) + metric : str + Metric from 'sqeuclidean', 'euclidean', + + + Returns + ------- + + M : np.array (n1,n2) + distance matrix computed with given metric + + """ + if x2 is None: + x2 = x1 + if metric == "sqeuclidean": + return euclidean_distances(x1, x2, squared=True, to_numpy=to_numpy) + elif metric == "euclidean": + return euclidean_distances(x1, x2, squared=False, to_numpy=to_numpy) + else: + raise NotImplementedError + + +def to_gpu(*args): + """ Upload numpy arrays to GPU and return them""" + if len(args) > 1: + return (cp.asarray(x) for x in args) + else: + return cp.asarray(args[0]) + + +def to_np(*args): + """ convert GPU arras to numpy and return them""" + if len(args) > 1: + return (cp.asnumpy(x) for x in args) + else: + return cp.asnumpy(args[0]) diff --git a/ot/gromov.py b/ot/gromov.py new file mode 100644 index 0000000..699ae4c --- /dev/null +++ b/ot/gromov.py @@ -0,0 +1,1179 @@ +# -*- coding: utf-8 -*-
+"""
+Gromov-Wasserstein transport method
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+# Rémi Flamary <remi.flamary@unice.fr>
+# Titouan Vayer <titouan.vayer@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+
+
+from .bregman import sinkhorn
+from .utils import dist, UndefinedParameter
+from .optim import cg
+
+
+def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
+ """Return loss matrices and tensors for Gromov-Wasserstein fast computation
+
+ Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss
+ function as the loss function of the Gromov-Wasserstein discrepancy.
+
+ The matrices are computed as described in Proposition 1 in [12]
+
+ Where :
+ * C1 : Metric cost matrix in the source space
+ * C2 : Metric cost matrix in the target space
+ * T : A coupling between those two spaces
+
+ The square-loss function L(a,b)=|a-b|^2 is read as :
+ L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
+ * f1(a)=(a^2)
+ * f2(b)=(b^2)
+ * h1(a)=a
+ * h2(b)=2*b
+
+ The kl-loss function L(a,b)=a*log(a/b)-a+b is read as :
+ L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
+ * f1(a)=a*log(a)-a
+ * f2(b)=b
+ * h1(a)=a
+ * h2(b)=log(b)
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ loss_fun : str, optional
+ Loss function used for the solver, either 'square_loss' or 'kl_loss'
+
+ Returns
+ -------
+ constC : ndarray, shape (ns, nt)
+ Constant C matrix in Eq. (6)
+ hC1 : ndarray, shape (ns, ns)
+ h1(C1) matrix in Eq. (6)
+ hC2 : ndarray, shape (nt, nt)
+ h2(C2) matrix in Eq. (6)
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+
+ if loss_fun == 'square_loss':
+ def f1(a):
+ return (a**2)
+
+ def f2(b):
+ return (b**2)
+
+ def h1(a):
+ return a
+
+ def h2(b):
+ return 2 * b
+ elif loss_fun == 'kl_loss':
+ def f1(a):
+ return a * np.log(a + 1e-15) - a
+
+ def f2(b):
+ return b
+
+ def h1(a):
+ return a
+
+ def h2(b):
+ return np.log(b + 1e-15)
+
+ constC1 = np.dot(np.dot(f1(C1), p.reshape(-1, 1)),
+ np.ones(len(q)).reshape(1, -1))
+ constC2 = np.dot(np.ones(len(p)).reshape(-1, 1),
+ np.dot(q.reshape(1, -1), f2(C2).T))
+ constC = constC1 + constC2
+ hC1 = h1(C1)
+ hC2 = h2(C2)
+
+ return constC, hC1, hC2
+
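+# A quick sanity sketch (illustrative, not part of the module): for the square
+# loss, f1(a) + f2(b) - h1(a) * h2(b) = a**2 + b**2 - 2 * a * b = (a - b)**2,
+# so the decomposition used above recovers L(a, b) exactly:
+#
+#     a, b = 0.3, 0.7
+#     assert np.isclose(a ** 2 + b ** 2 - a * (2 * b), (a - b) ** 2)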
+
+def tensor_product(constC, hC1, hC2, T):
+ """Return the tensor for Gromov-Wasserstein fast computation
+
+ The tensor is computed as described in Proposition 1 Eq. (6) in [12].
+
+ Parameters
+ ----------
+ constC : ndarray, shape (ns, nt)
+ Constant C matrix in Eq. (6)
+ hC1 : ndarray, shape (ns, ns)
+ h1(C1) matrix in Eq. (6)
+ hC2 : ndarray, shape (nt, nt)
+ h2(C2) matrix in Eq. (6)
+
+ Returns
+ -------
+ tens : ndarray, shape (ns, nt)
+ \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+ A = -np.dot(hC1, T).dot(hC2.T)
+ tens = constC + A
+ # tens -= tens.min()
+ return tens
+
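+# Illustrative check (not part of the module): tensor_product avoids forming
+# the full (ns, nt, ns, nt) tensor. A naive reference for the square loss,
+# assuming constC, hC1, hC2 = init_matrix(C1, C2, p, q, 'square_loss') and a
+# coupling T whose marginals are exactly (p, q):
+#
+#     T = np.outer(p, q)
+#     L = (C1[:, None, :, None] - C2[None, :, None, :]) ** 2
+#     assert np.allclose(np.einsum('ijkl,kl->ij', L, T),
+#                        tensor_product(constC, hC1, hC2, T))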
+
+def gwloss(constC, hC1, hC2, T):
+ """Return the Loss for Gromov-Wasserstein
+
+ The loss is computed as described in Proposition 1 Eq. (6) in [12].
+
+ Parameters
+ ----------
+ constC : ndarray, shape (ns, nt)
+ Constant C matrix in Eq. (6)
+ hC1 : ndarray, shape (ns, ns)
+ h1(C1) matrix in Eq. (6)
+ hC2 : ndarray, shape (nt, nt)
+ h2(C2) matrix in Eq. (6)
+ T : ndarray, shape (ns, nt)
+ Current value of transport matrix T
+
+ Returns
+ -------
+ loss : float
+ Gromov Wasserstein loss
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+
+ tens = tensor_product(constC, hC1, hC2, T)
+
+ return np.sum(tens * T)
+
+
+def gwggrad(constC, hC1, hC2, T):
+ """Return the gradient for Gromov-Wasserstein
+
+ The gradient is computed as described in Proposition 2 in [12].
+
+ Parameters
+ ----------
+ constC : ndarray, shape (ns, nt)
+ Constant C matrix in Eq. (6)
+ hC1 : ndarray, shape (ns, ns)
+ h1(C1) matrix in Eq. (6)
+ hC2 : ndarray, shape (nt, nt)
+ h2(C2) matrix in Eq. (6)
+ T : ndarray, shape (ns, nt)
+ Current value of transport matrix T
+
+ Returns
+ -------
+ grad : ndarray, shape (ns, nt)
+ Gromov Wasserstein gradient
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+ return 2 * tensor_product(constC, hC1, hC2,
+ T) # [12] Prop. 2 misses a 2 factor
+
+
+def update_square_loss(p, lambdas, T, Cs):
+ """
+ Updates C according to the L2 loss kernel using the S couplings T_s
+ computed at the current iteration
+
+ Parameters
+ ----------
+ p : ndarray, shape (N,)
+ Masses in the targeted barycenter.
+ lambdas : list of float
+ List of the S spaces' weights.
+ T : list of S np.ndarray of shape (ns,N)
+ The S Ts couplings calculated at each iteration.
+ Cs : list of S ndarray, shape(ns,ns)
+ Metric cost matrices.
+
+ Returns
+ ----------
+ C : ndarray, shape (nt, nt)
+ Updated C matrix.
+ """
+ tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s])
+ for s in range(len(T))])
+ ppt = np.outer(p, p)
+
+ return np.divide(tmpsum, ppt)
+
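+# Shape sketch (illustrative; T1, T2, Cs1, Cs2 are hypothetical): each term
+# T[s].T @ Cs[s] @ T[s] is (N, N), so with couplings of shapes (ns, N) the
+# update returns an (N, N) barycenter matrix:
+#
+#     C = update_square_loss(p, [0.5, 0.5], [T1, T2], [Cs1, Cs2])
+#     assert C.shape == (len(p), len(p))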
+
+def update_kl_loss(p, lambdas, T, Cs):
+ """
+ Updates C according to the KL loss kernel using the S couplings T_s computed at the current iteration
+
+
+ Parameters
+ ----------
+ p : ndarray, shape (N,)
+ Weights in the targeted barycenter.
+ lambdas : list of float
+ List of the S spaces' weights.
+ T : list of S np.ndarray of shape (ns,N)
+ The S Ts couplings calculated at each iteration.
+ Cs : list of S ndarray, shape(ns,ns)
+ Metric cost matrices.
+
+ Returns
+ -------
+ C : ndarray, shape (N, N)
+ Updated C matrix.
+ """
+ tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s])
+ for s in range(len(T))])
+ ppt = np.outer(p, p)
+
+ return np.exp(np.divide(tmpsum, ppt))
+
+
+def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
+ """
+ Returns the Gromov-Wasserstein transport between (C1, p) and (C2, q)
+
+ The function solves the following optimization problem:
+
+ .. math::
+ GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+
+ Where :
+ - C1 : Metric cost matrix in the source space
+ - C2 : Metric cost matrix in the target space
+ - p : distribution in the source space
+ - q : distribution in the target space
+ - L : loss function to account for the misfit between the similarity matrices
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ loss_fun : str
+ Loss function used for the solver, either 'square_loss' or 'kl_loss'
+
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ armijo : bool, optional
+ If True, the step of the line-search is found via an Armijo search.
+ Else the closed form is used. If there are convergence issues, use False.
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+
+ Returns
+ -------
+ T : ndarray, shape (ns, nt)
+ Coupling between the two spaces that minimizes:
+ \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ log : dict
+ Convergence information and loss.
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
+ metric approach to object matching. Foundations of computational
+ mathematics 11.4 (2011): 417-487.
+
+ """
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G)
+
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G)
+
+ if log:
+ res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ log['gw_dist'] = gwloss(constC, hC1, hC2, res)
+ return res, log
+ else:
+ return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+
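+# Minimal usage sketch (illustrative, not part of the module): compare two
+# point clouds that may live in spaces of different dimension through their
+# intra-space distance matrices:
+#
+#     import ot
+#     xs = np.random.randn(30, 2)
+#     xt = np.random.randn(40, 3)
+#     C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
+#     C1, C2 = C1 / C1.max(), C2 / C2.max()
+#     p, q = ot.unif(30), ot.unif(40)
+#     T = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss')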
+
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+ """
+ Computes the FGW transport between two graphs see [24]
+
+ .. math::
+ \gamma = arg\min_\gamma (1-\\alpha)*<\gamma,M>_F + \\alpha* \sum_{i,j,k,l}
+ L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+
+ s.t. \gamma 1 = p
+ \gamma^T 1= q
+ \gamma\geq 0
+
+ where :
+ - M is the (ns,nt) metric cost matrix
+ - :math:`f` is the regularization term (and df is its gradient)
+ - p and q are source and target weights (sum to 1)
+ - L is a loss function to account for the misfit between the similarity matrices
+
+ The algorithm used for solving the problem is conditional gradient as discussed in [24]_
+
+ Parameters
+ ----------
+ M : ndarray, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ loss_fun : str, optional
+ Loss function used for the solver
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ armijo : bool, optional
+ If True, the step of the line-search is found via an Armijo search.
+ Else the closed form is used. If there are convergence issues, use False.
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+
+ Returns
+ -------
+ gamma : ndarray, shape (ns, nt)
+ Optimal transportation matrix for the given parameters.
+ log : dict
+ Log dictionary, returned only if log==True in parameters.
+
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas "Optimal Transport for structured data with
+ application on graphs", International Conference on Machine Learning
+ (ICML). 2019.
+ """
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G)
+
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G)
+
+ if log:
+ res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ log['fgw_dist'] = log['loss'][-1]
+ return res, log
+ else:
+ return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+
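+# Minimal usage sketch (illustrative; fs and ft are hypothetical feature
+# arrays for the source and target nodes): M carries the feature cost, C1/C2
+# the structure costs, and alpha trades them off:
+#
+#     M = ot.dist(fs, ft)
+#     T = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q,
+#                                            loss_fun='square_loss', alpha=0.5)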
+
+def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+ """
+ Computes the FGW distance between two graphs see [24]
+
+ .. math::
+ \min_\gamma (1-\\alpha)*<\gamma,M>_F + \\alpha* \sum_{i,j,k,l}
+ L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+
+
+ s.t. \gamma 1 = p
+ \gamma^T 1= q
+ \gamma\geq 0
+
+ where :
+ - M is the (ns,nt) metric cost matrix
+ - :math:`f` is the regularization term (and df is its gradient)
+ - p and q are source and target weights (sum to 1)
+ - L is a loss function to account for the misfit between the similarity matrices
+
+ The algorithm used for solving the problem is conditional gradient as discussed in [24]_
+
+ Parameters
+ ----------
+ M : ndarray, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space.
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space.
+ p : ndarray, shape (ns,)
+ Distribution in the source space.
+ q : ndarray, shape (nt,)
+ Distribution in the target space.
+ loss_fun : str, optional
+ Loss function used for the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Record log if True.
+ armijo : bool, optional
+ If True, the step of the line-search is found via an Armijo search.
+ Else the closed form is used. If there are convergence issues, use False.
+ **kwargs : dict
+ Parameters can be directly passed to the ot.optim.cg solver.
+
+ Returns
+ -------
+ gamma : ndarray, shape (ns, nt)
+ Optimal transportation matrix for the given parameters.
+ log : dict
+ Log dictionary, returned only if log==True in parameters.
+
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G)
+
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G)
+
+ # use a distinct name for the returned dict so the `log` flag is not shadowed
+ res, log0 = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ log0['fgw_dist'] = log0['loss'][-1]
+ log0['T'] = res
+ if log:
+ return log0['fgw_dist'], log0
+ else:
+ return log0['fgw_dist']
+
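+# Illustrative: unlike fused_gromov_wasserstein above, this variant returns
+# the scalar FGW value (and, with log=True, the coupling in log['T']):
+#
+#     d_fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q,
+#                                                      alpha=0.5, log=True)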
+
+def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
+ """
+ Returns the Gromov-Wasserstein discrepancy between (C1, p) and (C2, q)
+
+ The function solves the following optimization problem:
+
+ .. math::
+ GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+
+ Where :
+ - C1 : Metric cost matrix in the source space
+ - C2 : Metric cost matrix in the target space
+ - p : distribution in the source space
+ - q : distribution in the target space
+ - L : loss function to account for the misfit between the similarity matrices
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space.
+ q : ndarray, shape (nt,)
+ Distribution in the target space.
+ loss_fun : str
+ Loss function used for the solver, either 'square_loss' or 'kl_loss'
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ armijo : bool, optional
+ If True, the step of the line-search is found via an Armijo search.
+ Else the closed form is used. If there are convergence issues, use False.
+
+ Returns
+ -------
+ gw_dist : float
+ Gromov-Wasserstein distance
+ log : dict
+ Convergence information and coupling matrix
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
+ metric approach to object matching. Foundations of computational
+ mathematics 11.4 (2011): 417-487.
+
+ """
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G)
+
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G)
+
+ # use a distinct name for the returned dict so the `log` flag is not shadowed
+ res, log0 = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ log0['gw_dist'] = gwloss(constC, hC1, hC2, res)
+ log0['T'] = res
+ if log:
+ return log0['gw_dist'], log0
+ else:
+ return log0['gw_dist']
+
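+# Illustrative one-liner (same inputs as for gromov_wasserstein above),
+# returning the scalar GW value instead of the coupling:
+#
+#     gw_dist = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss')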
+
+def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
+ max_iter=1000, tol=1e-9, verbose=False, log=False):
+ """
+ Returns the entropic Gromov-Wasserstein transport between (C1, p) and (C2, q)
+
+ The function solves the following optimization problem:
+
+ .. math::
+ GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} - \epsilon H(T)
+
+ s.t. T 1 = p
+
+ T^T 1= q
+
+ T\geq 0
+
+ Where :
+ - C1 : Metric cost matrix in the source space
+ - C2 : Metric cost matrix in the target space
+ - p : distribution in the source space
+ - q : distribution in the target space
+ - L : loss function to account for the misfit between the similarity matrices
+ - H : entropy
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ loss_fun : string
+ Loss function used for the solver either 'square_loss' or 'kl_loss'
+ epsilon : float
+ Regularization term >0
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Record log if True.
+
+ Returns
+ -------
+ T : ndarray, shape (ns, nt)
+ Optimal coupling between the two spaces
+ log : dict
+ Convergence information, only returned if log is True.
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
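+ Examples
+ --------
+ A minimal usage sketch (editor's illustration; the cost matrices are
+ normalized so the Sinkhorn iterations stay well conditioned):
+
+ >>> import numpy as np
+ >>> import ot
+ >>> C1 = ot.dist(np.random.randn(5, 2))
+ >>> C2 = ot.dist(np.random.randn(6, 2))
+ >>> C1, C2 = C1 / C1.max(), C2 / C2.max()
+ >>> p, q = ot.unif(5), ot.unif(6)
+ >>> T = ot.gromov.entropic_gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-2)
+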
+ """
+
+ C1 = np.asarray(C1, dtype=np.float64)
+ C2 = np.asarray(C2, dtype=np.float64)
+
+ T = np.outer(p, q) # Initialization
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ cpt = 0
+ err = 1
+
+ if log:
+ log = {'err': []}
+
+ while (err > tol and cpt < max_iter):
+
+ Tprev = T
+
+ # compute the gradient
+ tens = gwggrad(constC, hC1, hC2, T)
+
+ T = sinkhorn(p, q, tens, epsilon)
+
+ if cpt % 10 == 0:
+ # we can speed up the process by checking for the error only every
+ # 10th iteration
+ err = np.linalg.norm(T - Tprev)
+
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+
+ if log:
+ log['gw_dist'] = gwloss(constC, hC1, hC2, T)
+ return T, log
+ else:
+ return T
+
+
+def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
+ max_iter=1000, tol=1e-9, verbose=False, log=False):
+ """
+ Returns the entropic Gromov-Wasserstein discrepancy between the two measured similarity matrices
+
+ (C1,p) and (C2,q)
+
+ The function solves the following optimization problem:
+
+ .. math::
+ GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
+
+ Where :
+ - C1 : Metric cost matrix in the source space
+ - C2 : Metric cost matrix in the target space
+ - p : distribution in the source space
+ - q : distribution in the target space
+ - L : loss function to account for the misfit between the similarity matrices
+ - H : entropy
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ loss_fun : str
+ Loss function used for the solver either 'square_loss' or 'kl_loss'
+ epsilon : float
+ Regularization term >0
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Record log if True.
+
+ Returns
+ -------
+ gw_dist : float
+ Gromov-Wasserstein distance
+ log : dict
+ Convergence information, only returned if log is True.
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
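+ Examples
+ --------
+ A minimal usage sketch (editor's illustration, same toy setup as in
+ :any:`entropic_gromov_wasserstein`):
+
+ >>> import numpy as np
+ >>> import ot
+ >>> C1 = ot.dist(np.random.randn(5, 2))
+ >>> C2 = ot.dist(np.random.randn(6, 2))
+ >>> C1, C2 = C1 / C1.max(), C2 / C2.max()
+ >>> p, q = ot.unif(5), ot.unif(6)
+ >>> gw_dist = ot.gromov.entropic_gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-2)
+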
+ """
+ gw, logv = entropic_gromov_wasserstein(
+ C1, C2, p, q, loss_fun, epsilon, max_iter, tol, verbose, log=True)
+
+ logv['T'] = gw
+
+ if log:
+ return logv['gw_dist'], logv
+ else:
+ return logv['gw_dist']
+
+
+def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
+ max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
+ """
+ Returns the entropic Gromov-Wasserstein barycenters of S measured similarity matrices
+
+ (Cs)_{s=1}^{s=S}
+
+ The function solves the following optimization problem:
+
+ .. math::
+ C = \mathop{\arg \min}_{C \in \mathbb{R}^{N \times N}} \sum_s \lambda_s GW(C, C_s, p, p_s)
+
+
+ Where :
+
+ - :math:`C_s` : metric cost matrix
+ - :math:`p_s` : distribution
+
+ Parameters
+ ----------
+ N : int
+ Size of the targeted barycenter
+ Cs : list of S np.ndarray of shape (ns,ns)
+ Metric cost matrices
+ ps : list of S np.ndarray of shape (ns,)
+ Sample weights in the S spaces
+ p : ndarray, shape(N,)
+ Weights in the targeted barycenter
+ lambdas : list of float
+ List of the S spaces' weights.
+ loss_fun : str
+ Loss function used for the solver, either 'square_loss' or 'kl_loss'.
+ epsilon : float
+ Regularization term >0
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations.
+ log : bool, optional
+ Record log if True.
+ init_C : ndarray, shape (N, N), optional
+ Initial value for the C matrix. If None, a random SPD matrix is used.
+
+ Returns
+ -------
+ C : ndarray, shape (N, N)
+ Similarity matrix in the barycenter space (permuted arbitrarily)
+ log : dict
+ Log dictionary containing the error evolution, only returned if log is True.
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
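+
+ Examples
+ --------
+ A minimal usage sketch (editor's illustration; max_iter is kept small so
+ the toy run stays cheap):
+
+ >>> import numpy as np
+ >>> import ot
+ >>> C1 = ot.dist(np.random.randn(5, 2))
+ >>> C2 = ot.dist(np.random.randn(6, 2))
+ >>> C1, C2 = C1 / C1.max(), C2 / C2.max()
+ >>> ps = [ot.unif(5), ot.unif(6)]
+ >>> C = ot.gromov.entropic_gromov_barycenters(4, [C1, C2], ps, ot.unif(4),
+ ... [0.5, 0.5], 'square_loss', epsilon=5e-2, max_iter=20)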
+ """
+
+ S = len(Cs)
+
+ Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
+ lambdas = np.asarray(lambdas, dtype=np.float64)
+
+ # Initialization of C : random SPD matrix (if not provided by user)
+ if init_C is None:
+ # XXX use random state
+ xalea = np.random.randn(N, 2)
+ C = dist(xalea, xalea)
+ C /= C.max()
+ else:
+ C = init_C
+
+ cpt = 0
+ err = 1
+
+ error = []
+
+ while (err > tol) and (cpt < max_iter):
+ Cprev = C
+
+ # the inner solver must be called with log=False so T is a list of couplings
+ T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
+ max_iter, 1e-5, verbose, log=False) for s in range(S)]
+ if loss_fun == 'square_loss':
+ C = update_square_loss(p, lambdas, T, Cs)
+
+ elif loss_fun == 'kl_loss':
+ C = update_kl_loss(p, lambdas, T, Cs)
+
+ if cpt % 10 == 0:
+ # we can speed up the process by checking for the error only every
+ # 10th iteration
+ err = np.linalg.norm(C - Cprev)
+ error.append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+
+ if log:
+ # the original code appended to the boolean flag itself; return the
+ # collected error list instead
+ return C, {'err': error}
+ else:
+ return C
+
+
+def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
+ max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
+ """
+ Returns the Gromov-Wasserstein barycenters of S measured similarity matrices
+
+ (Cs)_{s=1}^{s=S}
+
+ The function solves the following optimization problem with block
+ coordinate descent:
+
+ .. math::
+ C = \mathop{\arg \min}_{C \in \mathbb{R}^{N \times N}} \sum_s \lambda_s GW(C, C_s, p, p_s)
+
+ Where :
+
+ - :math:`C_s` : metric cost matrix
+ - :math:`p_s` : distribution
+
+ Parameters
+ ----------
+ N : int
+ Size of the targeted barycenter
+ Cs : list of S np.ndarray of shape (ns, ns)
+ Metric cost matrices
+ ps : list of S np.ndarray of shape (ns,)
+ Sample weights in the S spaces
+ p : ndarray, shape (N,)
+ Weights in the targeted barycenter
+ lambdas : list of float
+ List of the S spaces' weights
+ loss_fun : str
+ Loss function used for the solver, either 'square_loss' or 'kl_loss'.
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0).
+ verbose : bool, optional
+ Print information along iterations.
+ log : bool, optional
+ Record log if True.
+ init_C : ndarray, shape (N, N), optional
+ Initial value for the C matrix. If None, a random SPD matrix is used.
+
+ Returns
+ -------
+ C : ndarray, shape (N, N)
+ Similarity matrix in the barycenter space (permuted arbitrarily)
+ log : dict
+ Log dictionary containing the error evolution, only returned if log is True.
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
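+ Examples
+ --------
+ A minimal usage sketch (editor's illustration on two random toy spaces):
+
+ >>> import numpy as np
+ >>> import ot
+ >>> C1 = ot.dist(np.random.randn(5, 2))
+ >>> C2 = ot.dist(np.random.randn(6, 2))
+ >>> ps = [ot.unif(5), ot.unif(6)]
+ >>> C = ot.gromov.gromov_barycenters(4, [C1, C2], ps, ot.unif(4),
+ ... [0.5, 0.5], 'square_loss', max_iter=20)
+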
+ """
+ S = len(Cs)
+
+ Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
+ lambdas = np.asarray(lambdas, dtype=np.float64)
+
+ # Initialization of C : random SPD matrix (if not provided by user)
+ if init_C is None:
+ # XXX : should use a random state and not use the global seed
+ xalea = np.random.randn(N, 2)
+ C = dist(xalea, xalea)
+ C /= C.max()
+ else:
+ C = init_C
+
+ cpt = 0
+ err = 1
+
+ error = []
+
+ while (err > tol) and (cpt < max_iter):
+ Cprev = C
+
+ # the inner solver must be called with log=False so T is a list of couplings
+ T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun,
+ numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=False) for s in range(S)]
+ if loss_fun == 'square_loss':
+ C = update_square_loss(p, lambdas, T, Cs)
+
+ elif loss_fun == 'kl_loss':
+ C = update_kl_loss(p, lambdas, T, Cs)
+
+ if cpt % 10 == 0:
+ # we can speed up the process by checking for the error only every
+ # 10th iteration
+ err = np.linalg.norm(C - Cprev)
+ error.append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+
+ if log:
+ # the original code appended to the boolean flag itself; return the
+ # collected error list instead
+ return C, {'err': error}
+ else:
+ return C
+
+
+def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
+ p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
+ verbose=False, log=False, init_C=None, init_X=None):
+ """Compute the fgw barycenter as presented eq (5) in [24].
+
+ Parameters
+ ----------
+ N : integer
+ Desired number of samples of the target barycenter
+ Ys: list of ndarray, each element has shape (ns,d)
+ Features of all samples
+ Cs : list of ndarray, each element has shape (ns,ns)
+ Structure matrices of all samples
+ ps : list of ndarray, each element has shape (ns,)
+ Masses of all samples.
+ lambdas : list of float
+ List of the S spaces' weights
+ alpha : float
+ Alpha parameter for the fgw distance
+ fixed_structure : bool
+ Whether to fix the structure of the barycenter during the updates
+ fixed_features : bool
+ Whether to fix the feature of the barycenter during the updates
+ init_C : ndarray, shape (N,N), optional
+ Initialization for the barycenters' structure matrix. If not set
+ a random init is used.
+ init_X : ndarray, shape (N,d), optional
+ Initialization for the barycenters' features. If not set,
+ the features are initialized with zeros.
+ p : ndarray, shape (N,), optional
+ Weights in the targeted barycenter (uniform if None)
+ loss_fun : str, optional
+ Loss function used for the solver (default 'square_loss')
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Record log if True
+
+ Returns
+ -------
+ X : ndarray, shape (N, d)
+ Barycenters' features
+ C : ndarray, shape (N, N)
+ Barycenters' structure matrix
+ log_ : dict
+ Only returned when log=True. It contains the keys 'T' (list of (N, ns)
+ transport matrices), 'Ms' (distance matrices between the features of the
+ barycenter and the other features, dist(X, Ys), each of shape (N, ns)),
+ 'p', 'err_feature', 'err_structure' and 'Ts_iter'.
+
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
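+
+ Examples
+ --------
+ A minimal usage sketch (editor's illustration on two random attributed
+ toy graphs):
+
+ >>> import numpy as np
+ >>> import ot
+ >>> Ys = [np.random.randn(5, 3), np.random.randn(6, 3)]
+ >>> Cs = [ot.dist(np.random.randn(5, 2)), ot.dist(np.random.randn(6, 2))]
+ >>> ps = [ot.unif(5), ot.unif(6)]
+ >>> X, C = ot.gromov.fgw_barycenters(4, Ys, Cs, ps, [0.5, 0.5], alpha=0.5,
+ ... max_iter=20)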
+ """
+ S = len(Cs)
+ d = Ys[0].shape[1] # dimension of the node features
+ if p is None:
+ p = np.ones(N) / N
+
+ Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
+ Ys = [np.asarray(Ys[s], dtype=np.float64) for s in range(S)]
+
+ lambdas = np.asarray(lambdas, dtype=np.float64)
+
+ if fixed_structure:
+ if init_C is None:
+ raise UndefinedParameter('If C is fixed it must be initialized')
+ else:
+ C = init_C
+ else:
+ if init_C is None:
+ xalea = np.random.randn(N, 2)
+ C = dist(xalea, xalea)
+ else:
+ C = init_C
+
+ if fixed_features:
+ if init_X is None:
+ raise UndefinedParameter('If X is fixed it must be initialized')
+ else:
+ X = init_X
+ else:
+ if init_X is None:
+ X = np.zeros((N, d))
+ else:
+ X = init_X
+
+ T = [np.outer(p, q) for q in ps]
+
+ Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] # Ms is N,ns
+
+ cpt = 0
+ err_feature = 1
+ err_structure = 1
+
+ if log:
+ log_ = {}
+ log_['err_feature'] = []
+ log_['err_structure'] = []
+ log_['Ts_iter'] = []
+
+ while (err_feature > tol or err_structure > tol) and cpt < max_iter:
+ Cprev = C
+ Xprev = X
+
+ if not fixed_features:
+ Ys_temp = [y.T for y in Ys]
+ X = update_feature_matrix(lambdas, Ys_temp, T, p).T
+
+ Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
+
+ if not fixed_structure:
+ if loss_fun == 'square_loss':
+ T_temp = [t.T for t in T]
+ C = update_structure_matrix(p, lambdas, T_temp, Cs)
+
+ T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
+ numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
+
+ # T is N,ns
+ err_feature = np.linalg.norm(X - Xprev.reshape(N, d))
+ err_structure = np.linalg.norm(C - Cprev)
+
+ if log:
+ log_['err_feature'].append(err_feature)
+ log_['err_structure'].append(err_structure)
+ log_['Ts_iter'].append(T)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err_structure))
+ print('{:5d}|{:8e}|'.format(cpt, err_feature))
+
+ cpt += 1
+
+ if log:
+ log_['T'] = T # from target to Ys
+ log_['p'] = p
+ log_['Ms'] = Ms
+
+ if log:
+ return X, C, log_
+ else:
+ return X, C
+
+
+ def update_structure_matrix(p, lambdas, T, Cs):
+ """Updates C according to the L2 loss kernel with the S Ts couplings
+ calculated at each iteration.
+
+ Parameters
+ ----------
+ p : ndarray, shape (N,)
+ Masses in the targeted barycenter.
+ lambdas : list of float
+ List of the S spaces' weights.
+ T : list of S ndarray of shape (ns, N)
+ The S Ts couplings calculated at each iteration.
+ Cs : list of S ndarray, shape (ns, ns)
+ Metric cost matrices.
+
+ Returns
+ -------
+ C : ndarray, shape (N, N)
+ Updated C matrix.
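+
+ Notes
+ -----
+ As implemented below, the update is the entrywise ratio
+
+ .. math::
+ \mathbf{C} = \frac{\sum_s \lambda_s \mathbf{T}_s^\top \mathbf{C}_s \mathbf{T}_s}{\mathbf{p} \mathbf{p}^\top}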
+ """
+ tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
+ ppt = np.outer(p, p)
+
+ return np.divide(tmpsum, ppt)
+
+
+def update_feature_matrix(lambdas, Ys, Ts, p):
+ """Updates the feature with respect to the S Ts couplings.
+
+
+ See "Solving the barycenter problem with Block Coordinate Descent (BCD)"
+ in [24] calculated at each iteration
+
+ Parameters
+ ----------
+ p : ndarray, shape (N,)
+ masses in the targeted barycenter
+ lambdas : list of float
+ List of the S spaces' weights
+ Ts : list of S ndarray, shape (N, ns)
+ The S Ts couplings calculated at each iteration.
+ Ys : list of S ndarray, shape(d,ns)
+ The features.
+
+ Returns
+ -------
+ X : ndarray, shape (d, N)
+
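+ Notes
+ -----
+ As implemented below, the update reads
+
+ .. math::
+ \mathbf{X} = \sum_s \lambda_s \, \mathbf{Y}_s \mathbf{T}_s^\top \mathrm{diag}(1 / \mathbf{p})
+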
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+ p = np.array(1. / p).reshape(-1,)
+
+ tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T) * p[None, :] for s in range(len(Ts))])
+
+ return tmpsum
diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h new file mode 100644 index 0000000..f42e222 --- /dev/null +++ b/ot/lp/EMD.h @@ -0,0 +1,35 @@ +/* This file is a c++ wrapper function for computing the transportation cost + * between two vectors given a cost matrix. + * + * It was written by Antoine Rolet (2014) and mainly consists of a wrapper + * of the code written by Nicolas Bonneel available on this page + * http://people.seas.harvard.edu/~nbonneel/FastTransport/ + * + * It was then modified to make it more amenable to python inline calling + * + * Please give relevant credit to the original author (Nicolas Bonneel) if + * you use this code for a publication. + * + */ + + +#ifndef EMD_H +#define EMD_H + +#include <iostream> +#include <vector> +#include "network_simplex_simple.h" + +using namespace lemon; +typedef unsigned int node_id_type; + +enum ProblemType { + INFEASIBLE, + OPTIMAL, + UNBOUNDED, + MAX_ITER_REACHED +}; + +int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter); + +#endif diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp new file mode 100644 index 0000000..fc7ca63 --- /dev/null +++ b/ot/lp/EMD_wrapper.cpp @@ -0,0 +1,107 @@ +/* This file is a c++ wrapper function for computing the transportation cost + * between two vectors given a cost matrix. + * + * It was written by Antoine Rolet (2014) and mainly consists of a wrapper + * of the code written by Nicolas Bonneel available on this page + * http://people.seas.harvard.edu/~nbonneel/FastTransport/ + * + * It was then modified to make it more amenable to python inline calling + * + * Please give relevant credit to the original author (Nicolas Bonneel) if + * you use this code for a publication. + * + */ + +#include "EMD.h" + + +int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, + double* alpha, double* beta, double *cost, int maxIter) { +// beware M and C anre strored in row major C style!!! + int n, m, i, cur; + + typedef FullBipartiteDigraph Digraph; + DIGRAPH_TYPEDEFS(FullBipartiteDigraph); + + // Get the number of non zero coordinates for r and c + n=0; + for (int i=0; i<n1; i++) { + double val=*(X+i); + if (val>0) { + n++; + }else if(val<0){ + return INFEASIBLE; + } + } + m=0; + for (int i=0; i<n2; i++) { + double val=*(Y+i); + if (val>0) { + m++; + }else if(val<0){ + return INFEASIBLE; + } + } + + // Define the graph + + std::vector<int> indI(n), indJ(m); + std::vector<double> weights1(n), weights2(m); + Digraph di(n, m); + NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter); + + // Set supply and demand, don't account for 0 values (faster) + + cur=0; + for (int i=0; i<n1; i++) { + double val=*(X+i); + if (val>0) { + weights1[ cur ] = val; + indI[cur++]=i; + } + } + + // Demand is actually negative supply... 
+ + cur=0; + for (int i=0; i<n2; i++) { + double val=*(Y+i); + if (val>0) { + weights2[ cur ] = -val; + indJ[cur++]=i; + } + } + + + net.supplyMap(&weights1[0], n, &weights2[0], m); + + // Set the cost of each edge + for (int i=0; i<n; i++) { + for (int j=0; j<m; j++) { + double val=*(D+indI[i]*n2+indJ[j]); + net.setCost(di.arcFromId(i*m+j), val); + } + } + + + // Solve the problem with the network simplex algorithm + + int ret=net.run(); + if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) { + *cost = 0; + Arc a; di.first(a); + for (; a != INVALID; di.next(a)) { + int i = di.source(a); + int j = di.target(a); + double flow = net.flow(a); + *cost += flow * (*(D+indI[i]*n2+indJ[j-n])); + *(G+indI[i]*n2+indJ[j-n]) = flow; + *(alpha + indI[i]) = -net.potential(i); + *(beta + indJ[j-n]) = net.potential(j); + } + + } + + + return ret; +} diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py new file mode 100644 index 0000000..0c92810 --- /dev/null +++ b/ot/lp/__init__.py @@ -0,0 +1,618 @@ +# -*- coding: utf-8 -*- +""" +Solvers for the original linear program OT problem + + + +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + +import multiprocessing +import sys +import numpy as np +from scipy.sparse import coo_matrix + +from .import cvx + +# import compiled emd +from .emd_wrap import emd_c, check_result, emd_1d_sorted +from ..utils import parmap +from .cvx import barycenter +from ..utils import dist + +__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', + 'emd_1d', 'emd2_1d', 'wasserstein_1d'] + + +def emd(a, b, M, numItermax=100000, log=False): + r"""Solves the Earth Movers distance problem and returns the OT matrix + + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + + s.t. \gamma 1 = a + \gamma^T 1= b + \gamma\geq 0 + where : + + - M is the metric cost matrix + - a and b are the sample weights + + .. warning:: + Note that the M matrix needs to be a C-order numpy.array in float64 + format. + + Uses the algorithm proposed in [1]_ + + Parameters + ---------- + a : (ns,) numpy.ndarray, float64 + Source histogram (uniform weight if empty list) + b : (nt,) numpy.ndarray, float64 + Target histogram (uniform weight if empty list) + M : (ns,nt) numpy.ndarray, float64 + Loss matrix (c-order array with type float64) + numItermax : int, optional (default=100000) + The maximum number of iterations before stopping the optimization + algorithm if it has not converged. + log: bool, optional (default=False) + If True, returns a dictionary containing the cost and dual + variables. Otherwise returns only the optimal transportation matrix. + + Returns + ------- + gamma: (ns x nt) numpy.ndarray + Optimal transportation matrix for the given parameters + log: dict + If input log is true, a dictionary containing the cost and dual + variables and exit status + + + Examples + -------- + + Simple example with obvious solution. The function emd accepts lists and + perform automatic conversion to numpy arrays + + >>> import ot + >>> a=[.5,.5] + >>> b=[.5,.5] + >>> M=[[0.,1.],[1.,0.]] + >>> ot.emd(a,b,M) + array([[0.5, 0. ], + [0. , 0.5]]) + + References + ---------- + + .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. + (2011, December). Displacement interpolation using Lagrangian mass + transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. + 158). ACM. 
+ + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT + ot.optim.cg : General regularized OT""" + + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64) + + # if empty array given then use uniform distributions + if len(a) == 0: + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + if len(b) == 0: + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + + G, cost, u, v, result_code = emd_c(a, b, M, numItermax) + result_code_string = check_result(result_code) + if log: + log = {} + log['cost'] = cost + log['u'] = u + log['v'] = v + log['warning'] = result_code_string + log['result_code'] = result_code + return G, log + return G + + +def emd2(a, b, M, processes=multiprocessing.cpu_count(), + numItermax=100000, log=False, return_matrix=False): + r"""Solves the Earth Movers distance problem and returns the loss + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + + s.t. \gamma 1 = a + \gamma^T 1= b + \gamma\geq 0 + where : + + - M is the metric cost matrix + - a and b are the sample weights + + .. warning:: + Note that the M matrix needs to be a C-order numpy.array in float64 + format. + + Uses the algorithm proposed in [1]_ + + Parameters + ---------- + a : (ns,) numpy.ndarray, float64 + Source histogram (uniform weight if empty list) + b : (nt,) numpy.ndarray, float64 + Target histogram (uniform weight if empty list) + M : (ns,nt) numpy.ndarray, float64 + Loss matrix (c-order array with type float64) + processes : int, optional (default=nb cpu) + Nb of processes used for multiple emd computation (not used on windows) + numItermax : int, optional (default=100000) + The maximum number of iterations before stopping the optimization + algorithm if it has not converged. + log: boolean, optional (default=False) + If True, returns a dictionary containing the cost and dual + variables. Otherwise returns only the optimal transportation cost. + return_matrix: boolean, optional (default=False) + If True, returns the optimal transportation matrix in the log. + + Returns + ------- + gamma: (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log: dictnp + If input log is true, a dictionary containing the cost and dual + variables and exit status + + + Examples + -------- + + Simple example with obvious solution. The function emd accepts lists and + perform automatic conversion to numpy arrays + + + >>> import ot + >>> a=[.5,.5] + >>> b=[.5,.5] + >>> M=[[0.,1.],[1.,0.]] + >>> ot.emd2(a,b,M) + 0.0 + + References + ---------- + + .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. + (2011, December). Displacement interpolation using Lagrangian mass + transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. + 158). ACM. 
+ + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT + ot.optim.cg : General regularized OT""" + + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64) + + # problem with pikling Forks + if sys.platform.endswith('win32'): + processes=1 + + # if empty array given then use uniform distributions + if len(a) == 0: + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + if len(b) == 0: + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + + if log or return_matrix: + def f(b): + G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) + result_code_string = check_result(resultCode) + log = {} + if return_matrix: + log['G'] = G + log['u'] = u + log['v'] = v + log['warning'] = result_code_string + log['result_code'] = resultCode + return [cost, log] + else: + def f(b): + G, cost, u, v, result_code = emd_c(a, b, M, numItermax) + check_result(result_code) + return cost + + if len(b.shape) == 1: + return f(b) + nb = b.shape[1] + + if processes>1: + res = parmap(f, [b[:, i] for i in range(nb)], processes) + else: + res = list(map(f, [b[:, i].copy() for i in range(nb)])) + + return res + + + +def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None): + """ + Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance) + + The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms. + This problem is considered in [1] (Algorithm 2). There are two differences with the following codes: + - we do not optimize over the weights + - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting. + + Parameters + ---------- + measures_locations : list of (k_i,d) numpy.ndarray + The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list) + measures_weights : list of (k_i,) numpy.ndarray + Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure + + X_init : (k,d) np.ndarray + Initialization of the support locations (on k atoms) of the barycenter + b : (k,) np.ndarray + Initialization of the weights of the barycenter (non-negatives, sum to 1) + weights : (k,) np.ndarray + Initialization of the coefficients of the barycenter (non-negatives, sum to 1) + + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + X : (k,d) np.ndarray + Support locations (on k atoms) of the barycenter + + References + ---------- + + .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. 
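+
+ Examples
+ --------
+ A minimal usage sketch (editor's illustration on two small random
+ empirical measures; the sizes are arbitrary):
+
+ >>> import numpy as np
+ >>> import ot
+ >>> measures_locations = [np.random.randn(7, 2), np.random.randn(5, 2)]
+ >>> measures_weights = [ot.unif(7), ot.unif(5)]
+ >>> X_init = np.random.randn(4, 2)
+ >>> X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init)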
+ + """ + + iter_count = 0 + + N = len(measures_locations) + k = X_init.shape[0] + d = X_init.shape[1] + if b is None: + b = np.ones((k,))/k + if weights is None: + weights = np.ones((N,)) / N + + X = X_init + + log_dict = {} + displacement_square_norms = [] + + displacement_square_norm = stopThr + 1. + + while ( displacement_square_norm > stopThr and iter_count < numItermax ): + + T_sum = np.zeros((k, d)) + + for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()): + + M_i = dist(X, measure_locations_i) + T_i = emd(b, measure_weights_i, M_i) + T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i) + + displacement_square_norm = np.sum(np.square(T_sum-X)) + if log: + displacement_square_norms.append(displacement_square_norm) + + X = T_sum + + if verbose: + print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm) + + iter_count += 1 + + if log: + log_dict['displacement_square_norms'] = displacement_square_norms + return X, log_dict + else: + return X + + +def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, + log=False): + r"""Solves the Earth Movers distance problem between 1d measures and returns + the OT matrix + + + .. math:: + \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) + + s.t. \gamma 1 = a, + \gamma^T 1= b, + \gamma\geq 0 + where : + + - d is the metric + - x_a and x_b are the samples + - a and b are the sample weights + + When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. + + Uses the algorithm detailed in [1]_ + + Parameters + ---------- + x_a : (ns,) or (ns, 1) ndarray, float64 + Source dirac locations (on the real line) + x_b : (nt,) or (ns, 1) ndarray, float64 + Target dirac locations (on the real line) + a : (ns,) ndarray, float64, optional + Source histogram (default is uniform weight) + b : (nt,) ndarray, float64, optional + Target histogram (default is uniform weight) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only strings listed in :func:`ot.dist` are accepted. + Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used. + p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' + dense: boolean, optional (default=True) + If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics + are used. + log: boolean, optional (default=False) + If True, returns a dictionary containing the cost. + Otherwise returns only the optimal transportation matrix. + + Returns + ------- + gamma: (ns, nt) ndarray + Optimal transportation matrix for the given parameters + log: dict + If input log is True, a dictionary containing the cost + + + Examples + -------- + + Simple example with obvious solution. The function emd_1d accepts lists and + performs automatic conversion to numpy arrays + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> x_a = [2., 0.] + >>> x_b = [0., 3.] + >>> ot.emd_1d(x_a, x_b, a, b) + array([[0. , 0.5], + [0.5, 0. ]]) + >>> ot.emd_1d(x_a, x_b) + array([[0. , 0.5], + [0.5, 0. ]]) + + References + ---------- + + .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. 
+ + See Also + -------- + ot.lp.emd : EMD for multidimensional distributions + ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the + transportation matrix) + """ + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + x_a = np.asarray(x_a, dtype=np.float64) + x_b = np.asarray(x_b, dtype=np.float64) + + assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \ + "emd_1d should only be used with monodimensional data" + assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \ + "emd_1d should only be used with monodimensional data" + + # if empty array given then use uniform distributions + if a.ndim == 0 or len(a) == 0: + a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0] + if b.ndim == 0 or len(b) == 0: + b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0] + + x_a_1d = x_a.reshape((-1, )) + x_b_1d = x_b.reshape((-1, )) + perm_a = np.argsort(x_a_1d) + perm_b = np.argsort(x_b_1d) + + G_sorted, indices, cost = emd_1d_sorted(a, b, + x_a_1d[perm_a], x_b_1d[perm_b], + metric=metric, p=p) + G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])), + shape=(a.shape[0], b.shape[0])) + if dense: + G = G.toarray() + if log: + log = {'cost': cost} + return G, log + return G + + +def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, + log=False): + r"""Solves the Earth Movers distance problem between 1d measures and returns + the loss + + + .. math:: + \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) + + s.t. \gamma 1 = a, + \gamma^T 1= b, + \gamma\geq 0 + where : + + - d is the metric + - x_a and x_b are the samples + - a and b are the sample weights + + When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. + + Uses the algorithm detailed in [1]_ + + Parameters + ---------- + x_a : (ns,) or (ns, 1) ndarray, float64 + Source dirac locations (on the real line) + x_b : (nt,) or (ns, 1) ndarray, float64 + Target dirac locations (on the real line) + a : (ns,) ndarray, float64, optional + Source histogram (default is uniform weight) + b : (nt,) ndarray, float64, optional + Target histogram (default is uniform weight) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only strings listed in :func:`ot.dist` are accepted. + Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics + are used. + p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' + dense: boolean, optional (default=True) + If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. Only used if log is set to True. Due to implementation details, + this function runs faster when dense is set to False. + log: boolean, optional (default=False) + If True, returns a dictionary containing the transportation matrix. + Otherwise returns only the loss. + + Returns + ------- + loss: float + Cost associated to the optimal transportation + log: dict + If input log is True, a dictionary containing the Optimal transportation + matrix for the given parameters + + + Examples + -------- + + Simple example with obvious solution. The function emd2_1d accepts lists and + performs automatic conversion to numpy arrays + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> x_a = [2., 0.] + >>> x_b = [0., 3.] 
+ >>> ot.emd2_1d(x_a, x_b, a, b) + 0.5 + >>> ot.emd2_1d(x_a, x_b) + 0.5 + + References + ---------- + + .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. + + See Also + -------- + ot.lp.emd2 : EMD for multidimensional distributions + ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix + instead of the cost) + """ + # If we do not return G (log==False), then we should not to cast it to dense + # (useless overhead) + G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p, + dense=dense and log, log=True) + cost = log_emd['cost'] + if log: + log_emd = {'G': G} + return cost, log_emd + return cost + + +def wasserstein_1d(x_a, x_b, a=None, b=None, p=1.): + r"""Solves the p-Wasserstein distance problem between 1d measures and returns + the distance + + .. math:: + \min_\gamma \left( \sum_i \sum_j \gamma_{ij} \|x_a[i] - x_b[j]\|^p \right)^{1/p} + + s.t. \gamma 1 = a, + \gamma^T 1= b, + \gamma\geq 0 + + where : + + - x_a and x_b are the samples + - a and b are the sample weights + + Uses the algorithm detailed in [1]_ + + Parameters + ---------- + x_a : (ns,) or (ns, 1) ndarray, float64 + Source dirac locations (on the real line) + x_b : (nt,) or (ns, 1) ndarray, float64 + Target dirac locations (on the real line) + a : (ns,) ndarray, float64, optional + Source histogram (default is uniform weight) + b : (nt,) ndarray, float64, optional + Target histogram (default is uniform weight) + p: float, optional (default=1.0) + The order of the p-Wasserstein distance to be computed + + Returns + ------- + dist: float + p-Wasserstein distance + + + Examples + -------- + + Simple example with obvious solution. The function wasserstein_1d accepts + lists and performs automatic conversion to numpy arrays + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> x_a = [2., 0.] + >>> x_b = [0., 3.] + >>> ot.wasserstein_1d(x_a, x_b, a, b) + 0.5 + >>> ot.wasserstein_1d(x_a, x_b) + 0.5 + + References + ---------- + + .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. + + See Also + -------- + ot.lp.emd_1d : EMD for 1d distributions + """ + cost_emd = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p, + dense=False, log=False) + return np.power(cost_emd, 1. / p) diff --git a/ot/lp/core.h b/ot/lp/core.h new file mode 100644 index 0000000..04dddf7 --- /dev/null +++ b/ot/lp/core.h @@ -0,0 +1,103 @@ +/* -*- mode: C++; indent-tabs-mode: nil; -*- + * + * This file has been adapted by Nicolas Bonneel (2013), + * from full_graph.h from LEMON, a generic C++ optimization library, + * to make the other files independant from the rest of + * the original library. + * + * + **** Original file Copyright Notice : + * Copyright (C) 2003-2010 + * Egervary Jeno Kombinatorikus Optimalizalasi Kutatocsoport + * (Egervary Research Group on Combinatorial Optimization, EGRES). + * + * Permission to use, modify and distribute this software is granted + * provided that this copyright notice appears in all copies. For + * precise terms see the accompanying LICENSE file. + * + * This software is provided "AS IS" with no warranty of any kind, + * express or implied, and with no claim as to its suitability for any + * purpose. 
+ * + */ + +#ifndef LEMON_CORE_H +#define LEMON_CORE_H + +#include <vector> +#include <algorithm> + + +// Disable the following warnings when compiling with MSVC: +// C4250: 'class1' : inherits 'class2::member' via dominance +// C4355: 'this' : used in base member initializer list +// C4503: 'function' : decorated name length exceeded, name was truncated +// C4800: 'type' : forcing value to bool 'true' or 'false' (performance warning) +// C4996: 'function': was declared deprecated +#ifdef _MSC_VER +#pragma warning( disable : 4250 4355 4503 4800 4996 ) +#endif + +///\file +///\brief LEMON core utilities. +/// +///This header file contains core utilities for LEMON. +///It is automatically included by all graph types, therefore it usually +///do not have to be included directly. + +namespace lemon { + + /// \brief Dummy type to make it easier to create invalid iterators. + /// + /// Dummy type to make it easier to create invalid iterators. + /// See \ref INVALID for the usage. + struct Invalid { + public: + bool operator==(Invalid) { return true; } + bool operator!=(Invalid) { return false; } + bool operator< (Invalid) { return false; } + }; + + /// \brief Invalid iterators. + /// + /// \ref Invalid is a global type that converts to each iterator + /// in such a way that the value of the target iterator will be invalid. +#ifdef LEMON_ONLY_TEMPLATES + const Invalid INVALID = Invalid(); +#else + extern const Invalid INVALID; +#endif + + /// \addtogroup gutils + /// @{ + + ///Create convenience typedefs for the digraph types and iterators + + ///This \c \#define creates convenient type definitions for the following + ///types of \c Digraph: \c Node, \c NodeIt, \c Arc, \c ArcIt, \c InArcIt, + ///\c OutArcIt, \c BoolNodeMap, \c IntNodeMap, \c DoubleNodeMap, + ///\c BoolArcMap, \c IntArcMap, \c DoubleArcMap. + /// + ///\note If the graph type is a dependent type, ie. the graph type depend + ///on a template parameter, then use \c TEMPLATE_DIGRAPH_TYPEDEFS() + ///macro. +#define DIGRAPH_TYPEDEFS(Digraph) \ + typedef Digraph::Node Node; \ + typedef Digraph::Arc Arc; \ + + + ///Create convenience typedefs for the digraph types and iterators + + ///\see DIGRAPH_TYPEDEFS + /// + ///\note Use this macro, if the graph type is a dependent type, + ///ie. the graph type depend on a template parameter. +#define TEMPLATE_DIGRAPH_TYPEDEFS(Digraph) \ + typedef typename Digraph::Node Node; \ + typedef typename Digraph::Arc Arc; \ + + + +} //namespace lemon + +#endif diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py new file mode 100644 index 0000000..8e763be --- /dev/null +++ b/ot/lp/cvx.py @@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- +""" +LP solvers for optimal transport using cvxopt +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + +import numpy as np +import scipy as sp +import scipy.sparse as sps + + +try: + import cvxopt + from cvxopt import solvers, matrix, spmatrix +except ImportError: + cvxopt = False + + +def scipy_sparse_to_spmatrix(A): + """Efficient conversion from scipy sparse matrix to cvxopt sparse matrix""" + coo = A.tocoo() + SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape) + return SP + + +def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'): + """Compute the Wasserstein barycenter of distributions A + + The function solves the following optimization problem [16]: + + .. 
math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{1}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` + + The linear program is solved using the interior point solver from scipy.optimize. + If cvxopt solver if installed it can use cvxopt + + Note that this problem do not scale well (both in memory and computational time). + + Parameters + ---------- + A : np.ndarray (d,n) + n training distributions a_i of size d + M : np.ndarray (d,d) + loss matrix for OT + reg : float + Regularization term >0 + weights : np.ndarray (n,) + Weights of each histogram a_i on the simplex (barycentric coodinates) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + solver : string, optional + the solver used, default 'interior-point' use the lp solver from + scipy.optimize. None, or 'glpk' or 'mosek' use the solver from cvxopt. + + Returns + ------- + a : (d,) ndarray + Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924. + + + + """ + + if weights is None: + weights = np.ones(A.shape[1]) / A.shape[1] + else: + assert(len(weights) == A.shape[1]) + + n_distributions = A.shape[1] + n = A.shape[0] + + n2 = n * n + c = np.zeros((0)) + b_eq1 = np.zeros((0)) + for i in range(n_distributions): + c = np.concatenate((c, M.ravel() * weights[i])) + b_eq1 = np.concatenate((b_eq1, A[:, i])) + c = np.concatenate((c, np.zeros(n))) + + lst_idiag1 = [sps.kron(sps.eye(n), np.ones((1, n))) for i in range(n_distributions)] + # row constraints + A_eq1 = sps.hstack((sps.block_diag(lst_idiag1), sps.coo_matrix((n_distributions * n, n)))) + + # columns constraints + lst_idiag2 = [] + lst_eye = [] + for i in range(n_distributions): + if i == 0: + lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n))) + lst_eye.append(-sps.eye(n)) + else: + lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n - 1, n))) + lst_eye.append(-sps.eye(n - 1, n)) + + A_eq2 = sps.hstack((sps.block_diag(lst_idiag2), sps.vstack(lst_eye))) + b_eq2 = np.zeros((A_eq2.shape[0])) + + # full problem + A_eq = sps.vstack((A_eq1, A_eq2)) + b_eq = np.concatenate((b_eq1, b_eq2)) + + if not cvxopt or solver in ['interior-point']: + # cvxopt not installed or interior point + + if solver is None: + solver = 'interior-point' + + options = {'sparse': True, 'disp': verbose} + sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver, + options=options) + x = sol.x + b = x[-n:] + + else: + + h = np.zeros((n_distributions * n2 + n)) + G = -sps.eye(n_distributions * n2 + n) + + sol = solvers.lp(matrix(c), scipy_sparse_to_spmatrix(G), matrix(h), + A=scipy_sparse_to_spmatrix(A_eq), b=matrix(b_eq), + solver=solver) + + x = np.array(sol['x']) + b = x[-n:].ravel() + + if log: + return b, sol + else: + return b diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx new file mode 100644 index 0000000..2b6c495 --- /dev/null +++ b/ot/lp/emd_wrap.pyx @@ -0,0 +1,187 @@ +# -*- coding: utf-8 -*- +""" +Cython linker with C solver +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + +import numpy as np +cimport numpy as np + +from ..utils import dist + +cimport cython +cimport libc.math as math + +import warnings + + +cdef extern from 
"EMD.h": + int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) + cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED + + +def check_result(result_code): + if result_code == OPTIMAL: + return None + + if result_code == INFEASIBLE: + message = "Problem infeasible. Check that a and b are in the simplex" + elif result_code == UNBOUNDED: + message = "Problem unbounded" + elif result_code == MAX_ITER_REACHED: + message = "numItermax reached before optimality. Try to increase numItermax." + warnings.warn(message) + return message + + +@cython.boundscheck(False) +@cython.wraparound(False) +def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter): + """ + Solves the Earth Movers distance problem and returns the optimal transport matrix + + gamm=emd(a,b,M) + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the metric cost matrix + - a and b are the sample weights + + .. warning:: + Note that the M matrix needs to be a C-order :py.cls:`numpy.array` + + Parameters + ---------- + a : (ns,) numpy.ndarray, float64 + source histogram + b : (nt,) numpy.ndarray, float64 + target histogram + M : (ns,nt) numpy.ndarray, float64 + loss matrix + max_iter : int + The maximum number of iterations before stopping the optimization + algorithm if it has not converged. + + + Returns + ------- + gamma: (ns x nt) numpy.ndarray + Optimal transportation matrix for the given parameters + + """ + cdef int n1= M.shape[0] + cdef int n2= M.shape[1] + + cdef double cost=0 + cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) + cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1) + cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2) + + + if not len(a): + a=np.ones((n1,))/n1 + + if not len(b): + b=np.ones((n2,))/n2 + + # calling the function + cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) + + return G, cost, alpha, beta, result_code + + +@cython.boundscheck(False) +@cython.wraparound(False) +def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, + np.ndarray[double, ndim=1, mode="c"] v_weights, + np.ndarray[double, ndim=1, mode="c"] u, + np.ndarray[double, ndim=1, mode="c"] v, + str metric='sqeuclidean', + double p=1.): + r""" + Solves the Earth Movers distance problem between sorted 1d measures and + returns the OT matrix and the associated cost + + Parameters + ---------- + u_weights : (ns,) ndarray, float64 + Source histogram + v_weights : (nt,) ndarray, float64 + Target histogram + u : (ns,) ndarray, float64 + Source dirac locations (on the real line) + v : (nt,) ndarray, float64 + Target dirac locations (on the real line) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only strings listed in :func:`ot.dist` are accepted. + Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics + are used. 
+ p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' + + Returns + ------- + gamma: (n, ) ndarray, float64 + Values in the Optimal transportation matrix + indices: (n, 2) ndarray, int64 + Indices of the values stored in gamma for the Optimal transportation + matrix + cost + cost associated to the optimal transportation + """ + cdef double cost = 0. + cdef int n = u_weights.shape[0] + cdef int m = v_weights.shape[0] + + cdef int i = 0 + cdef double w_i = u_weights[0] + cdef int j = 0 + cdef double w_j = v_weights[0] + + cdef double m_ij = 0. + + cdef np.ndarray[double, ndim=1, mode="c"] G = np.zeros((n + m - 1, ), + dtype=np.float64) + cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2), + dtype=np.int) + cdef int cur_idx = 0 + while i < n and j < m: + if metric == 'sqeuclidean': + m_ij = (u[i] - v[j]) * (u[i] - v[j]) + elif metric == 'cityblock' or metric == 'euclidean': + m_ij = math.fabs(u[i] - v[j]) + elif metric == 'minkowski': + m_ij = math.pow(math.fabs(u[i] - v[j]), p) + else: + m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)), + metric=metric)[0, 0] + if w_i < w_j or j == m - 1: + cost += m_ij * w_i + G[cur_idx] = w_i + indices[cur_idx, 0] = i + indices[cur_idx, 1] = j + i += 1 + w_j -= w_i + w_i = u_weights[i] + else: + cost += m_ij * w_j + G[cur_idx] = w_j + indices[cur_idx, 0] = i + indices[cur_idx, 1] = j + j += 1 + w_i -= w_j + w_j = v_weights[j] + cur_idx += 1 + return G[:cur_idx], indices[:cur_idx], cost diff --git a/ot/lp/full_bipartitegraph.h b/ot/lp/full_bipartitegraph.h new file mode 100644 index 0000000..87a1bec --- /dev/null +++ b/ot/lp/full_bipartitegraph.h @@ -0,0 +1,215 @@ +/* -*- mode: C++; indent-tabs-mode: nil; -*- + * + * This file has been adapted by Nicolas Bonneel (2013), + * from full_graph.h from LEMON, a generic C++ optimization library, + * to implement a lightweight fully connected bipartite graph. A previous + * version of this file is used as part of the Displacement Interpolation + * project, + * Web: http://www.cs.ubc.ca/labs/imager/tr/2011/DisplacementInterpolation/ + * + * + **** Original file Copyright Notice : + * Copyright (C) 2003-2010 + * Egervary Jeno Kombinatorikus Optimalizalasi Kutatocsoport + * (Egervary Research Group on Combinatorial Optimization, EGRES). + * + * Permission to use, modify and distribute this software is granted + * provided that this copyright notice appears in all copies. For + * precise terms see the accompanying LICENSE file. + * + * This software is provided "AS IS" with no warranty of any kind, + * express or implied, and with no claim as to its suitability for any + * purpose. + * + */ + +#ifndef LEMON_FULL_BIPARTITE_GRAPH_H +#define LEMON_FULL_BIPARTITE_GRAPH_H + +#include "core.h" + +///\ingroup graphs +///\file +///\brief FullBipartiteDigraph and FullBipartiteGraph classes. 
+ + +namespace lemon { + + + class FullBipartiteDigraphBase { + public: + + typedef FullBipartiteDigraphBase Digraph; + + //class Node; + typedef int Node; + //class Arc; + typedef long long Arc; + + protected: + + int _node_num; + long long _arc_num; + + FullBipartiteDigraphBase() {} + + void construct(int n1, int n2) { _node_num = n1+n2; _arc_num = n1 * n2; _n1=n1; _n2=n2;} + + public: + + int _n1, _n2; + + + Node operator()(int ix) const { return Node(ix); } + static int index(const Node& node) { return node; } + + Arc arc(const Node& s, const Node& t) const { + if (s<_n1 && t>=_n1) + return Arc(s * _n2 + (t-_n1) ); + else + return Arc(-1); + } + + int nodeNum() const { return _node_num; } + long long arcNum() const { return _arc_num; } + + int maxNodeId() const { return _node_num - 1; } + long long maxArcId() const { return _arc_num - 1; } + + Node source(Arc arc) const { return arc / _n2; } + Node target(Arc arc) const { return (arc % _n2) + _n1; } + + static int id(Node node) { return node; } + static long long id(Arc arc) { return arc; } + + static Node nodeFromId(int id) { return Node(id);} + static Arc arcFromId(int id) { return Arc(id);} + + + Arc findArc(Node s, Node t, Arc prev = -1) const { + return prev == -1 ? arc(s, t) : -1; + } + + void first(Node& node) const { + node = _node_num - 1; + } + + static void next(Node& node) { + --node; + } + + void first(Arc& arc) const { + arc = _arc_num - 1; + } + + static void next(Arc& arc) { + --arc; + } + + void firstOut(Arc& arc, const Node& node) const { + if (node>=_n1) + arc = -1; + else + arc = (node + 1) * _n2 - 1; + } + + void nextOut(Arc& arc) const { + if (arc % _n2 == 0) arc = 0; + --arc; + } + + void firstIn(Arc& arc, const Node& node) const { + if (node<_n1) + arc = -1; + else + arc = _arc_num + node - _node_num; + } + + void nextIn(Arc& arc) const { + arc -= _n2; + if (arc < 0) arc = -1; + } + + }; + + /// \ingroup graphs + /// + /// \brief A directed full graph class. + /// + /// FullBipartiteDigraph is a simple and fast implmenetation of directed full + /// (complete) graphs. It contains an arc from each node to each node + /// (including a loop for each node), therefore the number of arcs + /// is the square of the number of nodes. + /// This class is completely static and it needs constant memory space. + /// Thus you can neither add nor delete nodes or arcs, however + /// the structure can be resized using resize(). + /// + /// This type fully conforms to the \ref concepts::Digraph "Digraph concept". + /// Most of its member functions and nested classes are documented + /// only in the concept class. + /// + /// This class provides constant time counting for nodes and arcs. + /// + /// \note FullBipartiteDigraph and FullBipartiteGraph classes are very similar, + /// but there are two differences. While this class conforms only + /// to the \ref concepts::Digraph "Digraph" concept, FullBipartiteGraph + /// conforms to the \ref concepts::Graph "Graph" concept, + /// moreover FullBipartiteGraph does not contain a loop for each + /// node as this class does. + /// + /// \sa FullBipartiteGraph + class FullBipartiteDigraph : public FullBipartiteDigraphBase { + typedef FullBipartiteDigraphBase Parent; + + public: + + /// \brief Default constructor. + /// + /// Default constructor. The number of nodes and arcs will be zero. + FullBipartiteDigraph() { construct(0,0); } + + /// \brief Constructor + /// + /// Constructor. + /// \param n The number of the nodes. 
+ FullBipartiteDigraph(int n1, int n2) { construct(n1, n2); } + + + /// \brief Returns the node with the given index. + /// + /// Returns the node with the given index. Since this structure is + /// completely static, the nodes can be indexed with integers from + /// the range <tt>[0..nodeNum()-1]</tt>. + /// The index of a node is the same as its ID. + /// \sa index() + Node operator()(int ix) const { return Parent::operator()(ix); } + + /// \brief Returns the index of the given node. + /// + /// Returns the index of the given node. Since this structure is + /// completely static, the nodes can be indexed with integers from + /// the range <tt>[0..nodeNum()-1]</tt>. + /// The index of a node is the same as its ID. + /// \sa operator()() + static int index(const Node& node) { return Parent::index(node); } + + /// \brief Returns the arc connecting the given nodes. + /// + /// Returns the arc connecting the given nodes. + /*Arc arc(Node u, Node v) const { + return Parent::arc(u, v); + }*/ + + /// \brief Number of nodes. + int nodeNum() const { return Parent::nodeNum(); } + /// \brief Number of arcs. + long long arcNum() const { return Parent::arcNum(); } + }; + + + + +} //namespace lemon + + +#endif //LEMON_FULL_GRAPH_H diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h new file mode 100644 index 0000000..7c6a4ce --- /dev/null +++ b/ot/lp/network_simplex_simple.h @@ -0,0 +1,1553 @@ +/* -*- mode: C++; indent-tabs-mode: nil; -*- + * + * + * This file has been adapted by Nicolas Bonneel (2013), + * from network_simplex.h from LEMON, a generic C++ optimization library, + * to implement a lightweight network simplex for mass transport, more + * memory efficient that the original file. A previous version of this file + * is used as part of the Displacement Interpolation project, + * Web: http://www.cs.ubc.ca/labs/imager/tr/2011/DisplacementInterpolation/ + * + * + **** Original file Copyright Notice : + * + * Copyright (C) 2003-2010 + * Egervary Jeno Kombinatorikus Optimalizalasi Kutatocsoport + * (Egervary Research Group on Combinatorial Optimization, EGRES). + * + * Permission to use, modify and distribute this software is granted + * provided that this copyright notice appears in all copies. For + * precise terms see the accompanying LICENSE file. + * + * This software is provided "AS IS" with no warranty of any kind, + * express or implied, and with no claim as to its suitability for any + * purpose. + * + */ + +#ifndef LEMON_NETWORK_SIMPLEX_SIMPLE_H +#define LEMON_NETWORK_SIMPLEX_SIMPLE_H +#define DEBUG_LVL 0 + +#if DEBUG_LVL>0 +#include <iomanip> +#endif + + +#define EPSILON 2.2204460492503131e-15 +#define _EPSILON 1e-8 +#define MAX_DEBUG_ITER 100000 + + +/// \ingroup min_cost_flow_algs +/// +/// \file +/// \brief Network Simplex algorithm for finding a minimum cost flow. 
+ +// if your compiler has troubles with stdext or hashmaps, just comment the following line to use a slower std::map instead +//#define HASHMAP + +#include <vector> +#include <limits> +#include <algorithm> +#include <cstdio> +#ifdef HASHMAP +#include <hash_map> +#else +#include <map> +#endif +#include <cmath> +//#include "core.h" +//#include "lmath.h" + +//#include "sparse_array_n.h" +#include "full_bipartitegraph.h" + +#define INVALIDNODE -1 +#define INVALID (-1) + +namespace lemon { + + + template <typename T> + class ProxyObject; + + template<typename T> + class SparseValueVector + { + public: + SparseValueVector(int n=0) + { + } + void resize(int n=0){}; + T operator[](const int id) const + { +#ifdef HASHMAP + typename stdext::hash_map<int,T>::const_iterator it = data.find(id); +#else + typename std::map<int,T>::const_iterator it = data.find(id); +#endif + if (it==data.end()) + return 0; + else + return it->second; + } + + ProxyObject<T> operator[](const int id) + { + return ProxyObject<T>( this, id ); + } + + //private: +#ifdef HASHMAP + stdext::hash_map<int,T> data; +#else + std::map<int,T> data; +#endif + + }; + + template <typename T> + class ProxyObject { + public: + ProxyObject( SparseValueVector<T> *v, int idx ){_v=v; _idx=idx;}; + ProxyObject<T> & operator=( const T &v ) { + // If we get here, we know that operator[] was called to perform a write access, + // so we can insert an item in the vector if needed + if (v!=0) + _v->data[_idx]=v; + return *this; + } + + operator T() { + // If we get here, we know that operator[] was called to perform a read access, + // so we can simply return the existing object +#ifdef HASHMAP + typename stdext::hash_map<int,T>::iterator it = _v->data.find(_idx); +#else + typename std::map<int,T>::iterator it = _v->data.find(_idx); +#endif + if (it==_v->data.end()) + return 0; + else + return it->second; + } + + void operator+=(T val) + { + if (val==0) return; +#ifdef HASHMAP + typename stdext::hash_map<int,T>::iterator it = _v->data.find(_idx); +#else + typename std::map<int,T>::iterator it = _v->data.find(_idx); +#endif + if (it==_v->data.end()) + _v->data[_idx] = val; + else + { + T sum = it->second + val; + if (sum==0) + _v->data.erase(it); + else + it->second = sum; + } + } + void operator-=(T val) + { + if (val==0) return; +#ifdef HASHMAP + typename stdext::hash_map<int,T>::iterator it = _v->data.find(_idx); +#else + typename std::map<int,T>::iterator it = _v->data.find(_idx); +#endif + if (it==_v->data.end()) + _v->data[_idx] = -val; + else + { + T sum = it->second - val; + if (sum==0) + _v->data.erase(it); + else + it->second = sum; + } + } + + SparseValueVector<T> *_v; + int _idx; + }; + + + + /// \addtogroup min_cost_flow_algs + /// @{ + + /// \brief Implementation of the primal Network Simplex algorithm + /// for finding a \ref min_cost_flow "minimum cost flow". + /// + /// \ref NetworkSimplexSimple implements the primal Network Simplex algorithm + /// for finding a \ref min_cost_flow "minimum cost flow" + /// \ref amo93networkflows, \ref dantzig63linearprog, + /// \ref kellyoneill91netsimplex. + /// This algorithm is a highly efficient specialized version of the + /// linear programming simplex method directly for the minimum cost + /// flow problem. + /// + /// In general, %NetworkSimplexSimple is the fastest implementation available + /// in LEMON for this problem. + /// Moreover, it supports both directions of the supply/demand inequality + /// constraints. For more information, see \ref SupplyType. 
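+ ///
+ /// \par Example
+ /// The snippet below is an editor's sketch (not part of the original
+ /// documentation) of how this class is driven together with
+ /// FullBipartiteDigraph; n1, n2, C, a, b and max_iter are placeholders.
+ /// \code
+ /// FullBipartiteDigraph di(n1, n2);
+ /// NetworkSimplexSimple<FullBipartiteDigraph, double, double, int>
+ /// net(di, true, n1 + n2, (long long)n1 * n2, max_iter);
+ /// for (int i = 0; i < n1; ++i)
+ /// for (int j = 0; j < n2; ++j)
+ /// net.setCost(di.arc(i, n1 + j), C[i * n2 + j]); // cost of arc (i, j)
+ /// net.supplyMap(a, n1, b, n2); // a: n1 supplies, b: n2 (negative) demands
+ /// int ret = net.run(); // OPTIMAL, INFEASIBLE, UNBOUNDED or MAX_ITER_REACHED
+ /// double cost = net.totalCost<double>();
+ /// \endcode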
+ /// + /// Most of the parameters of the problem (except for the digraph) + /// can be given using separate functions, and the algorithm can be + /// executed using the \ref run() function. If some parameters are not + /// specified, then default values will be used. + /// + /// \tparam GR The digraph type the algorithm runs on. + /// \tparam V The number type used for flow amounts, capacity bounds + /// and supply values in the algorithm. By default, it is \c int. + /// \tparam C The number type used for costs and potentials in the + /// algorithm. By default, it is the same as \c V. + /// + /// \warning Both number types must be signed and all input data must + /// be integer. + /// + /// \note %NetworkSimplexSimple provides five different pivot rule + /// implementations, from which the most efficient one is used + /// by default. For more information, see \ref PivotRule. + template <typename GR, typename V = int, typename C = V, typename NodesType = unsigned short int> + class NetworkSimplexSimple + { + public: + + /// \brief Constructor. + /// + /// The constructor of the class. + /// + /// \param graph The digraph the algorithm runs on. + /// \param arc_mixing Indicate if the arcs have to be stored in a + /// mixed order in the internal data structure. + /// In special cases, it could lead to better overall performance, + /// but it is usually slower. Therefore it is disabled by default. + NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, long long nb_arcs,int maxiters) : + _graph(graph), //_arc_id(graph), + _arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs), + MAX(std::numeric_limits<Value>::max()), + INF(std::numeric_limits<Value>::has_infinity ? + std::numeric_limits<Value>::infinity() : MAX) + { + // Reset data structures + reset(); + max_iter=maxiters; + } + + /// The type of the flow amounts, capacity bounds and supply values + typedef V Value; + /// The type of the arc costs + typedef C Cost; + + public: + + /// \brief Problem type constants for the \c run() function. + /// + /// Enum type containing the problem type constants that can be + /// returned by the \ref run() function of the algorithm. + enum ProblemType { + /// The problem has no feasible solution (flow). + INFEASIBLE, + /// The problem has optimal solution (i.e. it is feasible and + /// bounded), and the algorithm has found optimal flow and node + /// potentials (primal and dual solutions). + OPTIMAL, + /// The objective function of the problem is unbounded, i.e. + /// there is a directed cycle having negative total cost and + /// infinite upper bound. + UNBOUNDED, + /// The maximum number of iteration has been reached + MAX_ITER_REACHED + }; + + /// \brief Constants for selecting the type of the supply constraints. + /// + /// Enum type containing constants for selecting the supply type, + /// i.e. the direction of the inequalities in the supply/demand + /// constraints of the \ref min_cost_flow "minimum cost flow problem". + /// + /// The default supply type is \c GEQ, the \c LEQ type can be + /// selected using \ref supplyType(). + /// The equality form is a special case of both supply types. + enum SupplyType { + /// This option means that there are <em>"greater or equal"</em> + /// supply/demand constraints in the definition of the problem. + GEQ, + /// This option means that there are <em>"less or equal"</em> + /// supply/demand constraints in the definition of the problem. 
+ LEQ + }; + + + + private: + + int max_iter; + TEMPLATE_DIGRAPH_TYPEDEFS(GR); + + typedef std::vector<int> IntVector; + typedef std::vector<NodesType> UHalfIntVector; + typedef std::vector<Value> ValueVector; + typedef std::vector<Cost> CostVector; + // typedef SparseValueVector<Cost> CostVector; + typedef std::vector<char> BoolVector; + // Note: vector<char> is used instead of vector<bool> for efficiency reasons + + // State constants for arcs + enum ArcState { + STATE_UPPER = -1, + STATE_TREE = 0, + STATE_LOWER = 1 + }; + + typedef std::vector<signed char> StateVector; + // Note: vector<signed char> is used instead of vector<ArcState> for + // efficiency reasons + + private: + + // Data related to the underlying digraph + const GR &_graph; + int _node_num; + int _arc_num; + int _all_arc_num; + int _search_arc_num; + + // Parameters of the problem + SupplyType _stype; + Value _sum_supply; + + inline int _node_id(int n) const {return _node_num-n-1;} ; + + //IntArcMap _arc_id; + UHalfIntVector _source; + UHalfIntVector _target; + bool _arc_mixing; + public: + // Node and arc data + CostVector _cost; + ValueVector _supply; + ValueVector _flow; + //SparseValueVector<Value> _flow; + CostVector _pi; + + + private: + // Data for storing the spanning tree structure + IntVector _parent; + IntVector _pred; + IntVector _thread; + IntVector _rev_thread; + IntVector _succ_num; + IntVector _last_succ; + IntVector _dirty_revs; + BoolVector _forward; + StateVector _state; + int _root; + + // Temporary data used in the current pivot iteration + int in_arc, join, u_in, v_in, u_out, v_out; + int first, second, right, last; + int stem, par_stem, new_stem; + Value delta; + + const Value MAX; + + int mixingCoeff; + + public: + + /// \brief Constant for infinite upper bounds (capacities). + /// + /// Constant for infinite upper bounds (capacities). + /// It is \c std::numeric_limits<Value>::infinity() if available, + /// \c std::numeric_limits<Value>::max() otherwise. + const Value INF; + + private: + + // thank you to DVK and MizardX from StackOverflow for this function! 
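+ // Editor's note: when arc mixing is enabled, reset() stores the arcs
+ // interleaved with stride mixingCoeff, so that consecutive internal ids
+ // come from distant parts of the cost matrix (which tends to help the
+ // block-search pivot rule). sequence(k) below maps the k-th arc of the
+ // iteration order to its slot in that interleaved layout: the slots form
+ // num_big_subseqiences subsequences of length subsequence_length (the
+ // remaining ones are one element shorter), and k is decomposed into a
+ // subsequence index and an offset within it.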
+ inline int sequence(int k) const { + int smallv = (k > num_total_big_subsequence_numbers) & 1; + + k -= num_total_big_subsequence_numbers * smallv; + int subsequence_length2 = subsequence_length- smallv; + int subsequence_num = (k / subsequence_length2) + num_big_subseqiences * smallv; + int subsequence_offset = (k % subsequence_length2) * mixingCoeff; + + return subsequence_offset + subsequence_num; + } + int subsequence_length; + int num_big_subseqiences; + int num_total_big_subsequence_numbers; + + inline int getArcID(const Arc &arc) const + { + //int n = _arc_num-arc._id-1; + int n = _arc_num-GR::id(arc)-1; + + //int a = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff; + //int b = _arc_id[arc]; + if (_arc_mixing) + return sequence(n); + else + return n; + } + + // finally unused because too slow + inline int getSource(const int arc) const + { + //int a = _source[arc]; + //return a; + + int n = _arc_num-arc-1; + if (_arc_mixing) + n = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff; + + int b; + if (n>=0) + b = _node_id(_graph.source(GR::arcFromId( n ) )); + else + { + n = arc+1-_arc_num; + if ( n<=_node_num) + b = _node_num; + else + if ( n>=_graph._n1) + b = _graph._n1; + else + b = _graph._n1-n; + } + + return b; + } + + + + // Implementation of the Block Search pivot rule + class BlockSearchPivotRule + { + private: + + // References to the NetworkSimplexSimple class + const UHalfIntVector &_source; + const UHalfIntVector &_target; + const CostVector &_cost; + const StateVector &_state; + const CostVector &_pi; + int &_in_arc; + int _search_arc_num; + + // Pivot rule data + int _block_size; + int _next_arc; + NetworkSimplexSimple &_ns; + + public: + + // Constructor + BlockSearchPivotRule(NetworkSimplexSimple &ns) : + _source(ns._source), _target(ns._target), + _cost(ns._cost), _state(ns._state), _pi(ns._pi), + _in_arc(ns.in_arc), _search_arc_num(ns._search_arc_num), + _next_arc(0),_ns(ns) + { + // The main parameters of the pivot rule + const double BLOCK_SIZE_FACTOR = 1.0; + const int MIN_BLOCK_SIZE = 10; + + _block_size = std::max( int(BLOCK_SIZE_FACTOR * + std::sqrt(double(_search_arc_num))), + MIN_BLOCK_SIZE ); + } + // Find next entering arc + bool findEnteringArc() { + Cost c, min = 0; + int e; + int cnt = _block_size; + double a; + for (e = _next_arc; e != _search_arc_num; ++e) { + c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]); + if (c < min) { + min = c; + _in_arc = e; + } + if (--cnt == 0) { + a=fabs(_pi[_source[_in_arc]])>fabs(_pi[_target[_in_arc]]) ? fabs(_pi[_source[_in_arc]]):fabs(_pi[_target[_in_arc]]); + a=a>fabs(_cost[_in_arc])?a:fabs(_cost[_in_arc]); + if (min < -EPSILON*a) goto search_end; + cnt = _block_size; + } + } + for (e = 0; e != _next_arc; ++e) { + c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]); + if (c < min) { + min = c; + _in_arc = e; + } + if (--cnt == 0) { + a=fabs(_pi[_source[_in_arc]])>fabs(_pi[_target[_in_arc]]) ? fabs(_pi[_source[_in_arc]]):fabs(_pi[_target[_in_arc]]); + a=a>fabs(_cost[_in_arc])?a:fabs(_cost[_in_arc]); + if (min < -EPSILON*a) goto search_end; + cnt = _block_size; + } + } + a=fabs(_pi[_source[_in_arc]])>fabs(_pi[_target[_in_arc]]) ? 
fabs(_pi[_source[_in_arc]]):fabs(_pi[_target[_in_arc]]); + a=a>fabs(_cost[_in_arc])?a:fabs(_cost[_in_arc]); + if (min >= -EPSILON*a) return false; + + search_end: + _next_arc = e; + return true; + } + + }; //class BlockSearchPivotRule + + + + public: + + + + int _init_nb_nodes; + long long _init_nb_arcs; + + /// \name Parameters + /// The parameters of the algorithm can be specified using these + /// functions. + + /// @{ + + + /// \brief Set the costs of the arcs. + /// + /// This function sets the costs of the arcs. + /// If it is not used before calling \ref run(), the costs + /// will be set to \c 1 on all arcs. + /// + /// \param map An arc map storing the costs. + /// Its \c Value type must be convertible to the \c Cost type + /// of the algorithm. + /// + /// \return <tt>(*this)</tt> + template<typename CostMap> + NetworkSimplexSimple& costMap(const CostMap& map) { + Arc a; _graph.first(a); + for (; a != INVALID; _graph.next(a)) { + _cost[getArcID(a)] = map[a]; + } + return *this; + } + + + /// \brief Set the costs of one arc. + /// + /// This function sets the costs of one arcs. + /// Done for memory reasons + /// + /// \param arc An arc. + /// \param arc A cost + /// + /// \return <tt>(*this)</tt> + template<typename Value> + NetworkSimplexSimple& setCost(const Arc& arc, const Value cost) { + _cost[getArcID(arc)] = cost; + return *this; + } + + + /// \brief Set the supply values of the nodes. + /// + /// This function sets the supply values of the nodes. + /// If neither this function nor \ref stSupply() is used before + /// calling \ref run(), the supply of each node will be set to zero. + /// + /// \param map A node map storing the supply values. + /// Its \c Value type must be convertible to the \c Value type + /// of the algorithm. + /// + /// \return <tt>(*this)</tt> + template<typename SupplyMap> + NetworkSimplexSimple& supplyMap(const SupplyMap& map) { + Node n; _graph.first(n); + for (; n != INVALIDNODE; _graph.next(n)) { + _supply[_node_id(n)] = map[n]; + } + return *this; + } + template<typename SupplyMap> + NetworkSimplexSimple& supplyMap(const SupplyMap* map1, int n1, const SupplyMap* map2, int n2) { + Node n; _graph.first(n); + for (; n != INVALIDNODE; _graph.next(n)) { + if (n<n1) + _supply[_node_id(n)] = map1[n]; + else + _supply[_node_id(n)] = map2[n-n1]; + } + return *this; + } + template<typename SupplyMap> + NetworkSimplexSimple& supplyMapAll(SupplyMap val1, int n1, SupplyMap val2, int n2) { + Node n; _graph.first(n); + for (; n != INVALIDNODE; _graph.next(n)) { + if (n<n1) + _supply[_node_id(n)] = val1; + else + _supply[_node_id(n)] = val2; + } + return *this; + } + + /// \brief Set single source and target nodes and a supply value. + /// + /// This function sets a single source node and a single target node + /// and the required flow value. + /// If neither this function nor \ref supplyMap() is used before + /// calling \ref run(), the supply of each node will be set to zero. + /// + /// Using this function has the same effect as using \ref supplyMap() + /// with such a map in which \c k is assigned to \c s, \c -k is + /// assigned to \c t and all other nodes have zero supply value. + /// + /// \param s The source node. + /// \param t The target node. + /// \param k The required amount of flow from node \c s to node \c t + /// (i.e. the supply of \c s and the demand of \c t). 
+
+ /// \return <tt>(*this)</tt>
+ NetworkSimplexSimple& stSupply(const Node& s, const Node& t, Value k) {
+ for (int i = 0; i != _node_num; ++i) {
+ _supply[i] = 0;
+ }
+ _supply[_node_id(s)] = k;
+ _supply[_node_id(t)] = -k;
+ return *this;
+ }
+
+ /// \brief Set the type of the supply constraints.
+ ///
+ /// This function sets the type of the supply/demand constraints.
+ /// If it is not used before calling \ref run(), the \ref GEQ supply
+ /// type will be used.
+ ///
+ /// For more information, see \ref SupplyType.
+ ///
+ /// \return <tt>(*this)</tt>
+ NetworkSimplexSimple& supplyType(SupplyType supply_type) {
+ _stype = supply_type;
+ return *this;
+ }
+
+ /// @}
+
+ /// \name Execution Control
+ /// The algorithm can be executed using \ref run().
+
+ /// @{
+
+ /// \brief Run the algorithm.
+ ///
+ /// This function runs the algorithm.
+ /// The parameters can be specified using the functions \ref costMap(),
+ /// \ref setCost(), \ref supplyMap(), \ref stSupply() and
+ /// \ref supplyType().
+ /// For example,
+ /// \code
+ /// NetworkSimplexSimple<FullBipartiteDigraph> ns(graph, arc_mixing, n, m, max_iter);
+ /// ns.costMap(cost).supplyMap(sup).run();
+ /// \endcode
+ ///
+ /// This function can be called more than once. All the given parameters
+ /// are kept for the next call, unless \ref resetParams() or \ref reset()
+ /// is used, thus only the modified parameters have to be set again.
+ /// If the underlying digraph was also modified after the construction
+ /// of the class (or the last \ref reset() call), then the \ref reset()
+ /// function must be called.
+ ///
+ /// \return \c INFEASIBLE if no feasible flow exists,
+ /// \n \c OPTIMAL if the problem has optimal solution
+ /// (i.e. it is feasible and bounded), and the algorithm has found
+ /// optimal flow and node potentials (primal and dual solutions),
+ /// \n \c UNBOUNDED if the objective function of the problem is
+ /// unbounded, i.e. there is a directed cycle having negative total
+ /// cost and infinite upper bound,
+ /// \n \c MAX_ITER_REACHED if the iteration limit given at construction
+ /// was hit before convergence.
+ ///
+ /// \see ProblemType
+ /// \see resetParams(), reset()
+ ProblemType run() {
+#if DEBUG_LVL>0
+ std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED = " << MAX_ITER_REACHED << "\n";
+#endif
+
+ if (!init()) return INFEASIBLE;
+#if DEBUG_LVL>0
+ std::cout << "Init done, starting iterations\n";
+#endif
+ return start();
+ }
+
+ /// \brief Reset all the parameters that have been given before.
+ ///
+ /// This function resets all the parameters that have been given
+ /// before using functions \ref lowerMap(), \ref upperMap(),
+ /// \ref costMap(), \ref supplyMap(), \ref stSupply(), \ref supplyType().
+ ///
+ /// It is useful for multiple \ref run() calls. Basically, all the given
+ /// parameters are kept for the next \ref run() call, unless
+ /// \ref resetParams() or \ref reset() is used.
+ /// If the underlying digraph was also modified after the construction
+ /// of the class or the last \ref reset() call, then the \ref reset()
+ /// function must be used, otherwise \ref resetParams() is sufficient.
+ /// + /// For example, + /// \code + /// NetworkSimplexSimple<ListDigraph> ns(graph); + /// + /// // First run + /// ns.lowerMap(lower).upperMap(upper).costMap(cost) + /// .supplyMap(sup).run(); + /// + /// // Run again with modified cost map (resetParams() is not called, + /// // so only the cost map have to be set again) + /// cost[e] += 100; + /// ns.costMap(cost).run(); + /// + /// // Run again from scratch using resetParams() + /// // (the lower bounds will be set to zero on all arcs) + /// ns.resetParams(); + /// ns.upperMap(capacity).costMap(cost) + /// .supplyMap(sup).run(); + /// \endcode + /// + /// \return <tt>(*this)</tt> + /// + /// \see reset(), run() + NetworkSimplexSimple& resetParams() { + for (int i = 0; i != _node_num; ++i) { + _supply[i] = 0; + } + for (int i = 0; i != _arc_num; ++i) { + _cost[i] = 1; + } + _stype = GEQ; + return *this; + } + + + + int divid (int x, int y) + { + return (x-x%y)/y; + } + + /// \brief Reset the internal data structures and all the parameters + /// that have been given before. + /// + /// This function resets the internal data structures and all the + /// paramaters that have been given before using functions \ref lowerMap(), + /// \ref upperMap(), \ref costMap(), \ref supplyMap(), \ref stSupply(), + /// \ref supplyType(). + /// + /// It is useful for multiple \ref run() calls. Basically, all the given + /// parameters are kept for the next \ref run() call, unless + /// \ref resetParams() or \ref reset() is used. + /// If the underlying digraph was also modified after the construction + /// of the class or the last \ref reset() call, then the \ref reset() + /// function must be used, otherwise \ref resetParams() is sufficient. + /// + /// See \ref resetParams() for examples. + /// + /// \return <tt>(*this)</tt> + /// + /// \see resetParams(), run() + NetworkSimplexSimple& reset() { + // Resize vectors + _node_num = _init_nb_nodes; + _arc_num = _init_nb_arcs; + int all_node_num = _node_num + 1; + int max_arc_num = _arc_num + 2 * _node_num; + + _source.resize(max_arc_num); + _target.resize(max_arc_num); + + _cost.resize(max_arc_num); + _supply.resize(all_node_num); + _flow.resize(max_arc_num); + _pi.resize(all_node_num); + + _parent.resize(all_node_num); + _pred.resize(all_node_num); + _forward.resize(all_node_num); + _thread.resize(all_node_num); + _rev_thread.resize(all_node_num); + _succ_num.resize(all_node_num); + _last_succ.resize(all_node_num); + _state.resize(max_arc_num); + + + //_arc_mixing=false; + if (_arc_mixing) { + // Store the arcs in a mixed order + int k = std::max(int(std::sqrt(double(_arc_num))), 10); + mixingCoeff = k; + subsequence_length = _arc_num / mixingCoeff + 1; + num_big_subseqiences = _arc_num % mixingCoeff; + num_total_big_subsequence_numbers = subsequence_length * num_big_subseqiences; + + int i = 0, j = 0; + Arc a; _graph.first(a); + for (; a != INVALID; _graph.next(a)) { + _source[i] = _node_id(_graph.source(a)); + _target[i] = _node_id(_graph.target(a)); + //_arc_id[a] = i; + if ((i += k) >= _arc_num) i = ++j; + } + } else { + // Store the arcs in the original order + int i = 0; + Arc a; _graph.first(a); + for (; a != INVALID; _graph.next(a), ++i) { + _source[i] = _node_id(_graph.source(a)); + _target[i] = _node_id(_graph.target(a)); + //_arc_id[a] = i; + } + } + + // Reset parameters + resetParams(); + return *this; + } + + /// @} + + /// \name Query Functions + /// The results of the algorithm can be obtained using these + /// functions.\n + /// The \ref run() function must be called before using them. 
+ + /// @{ + + /// \brief Return the total cost of the found flow. + /// + /// This function returns the total cost of the found flow. + /// Its complexity is O(e). + /// + /// \note The return type of the function can be specified as a + /// template parameter. For example, + /// \code + /// ns.totalCost<double>(); + /// \endcode + /// It is useful if the total cost cannot be stored in the \c Cost + /// type of the algorithm, which is the default return type of the + /// function. + /// + /// \pre \ref run() must be called before using this function. + /*template <typename Number> + Number totalCost() const { + Number c = 0; + for (ArcIt a(_graph); a != INVALID; ++a) { + int i = getArcID(a); + c += Number(_flow[i]) * Number(_cost[i]); + } + return c; + }*/ + + template <typename Number> + Number totalCost() const { + Number c = 0; + + /*#ifdef HASHMAP + typename stdext::hash_map<int, Value>::const_iterator it; + #else + typename std::map<int, Value>::const_iterator it; + #endif + for (it = _flow.data.begin(); it!=_flow.data.end(); ++it) + c += Number(it->second) * Number(_cost[it->first]); + return c;*/ + + for (int i=0; i<_flow.size(); i++) + c += _flow[i] * Number(_cost[i]); + return c; + + } + +#ifndef DOXYGEN + Cost totalCost() const { + return totalCost<Cost>(); + } +#endif + + /// \brief Return the flow on the given arc. + /// + /// This function returns the flow on the given arc. + /// + /// \pre \ref run() must be called before using this function. + Value flow(const Arc& a) const { + return _flow[getArcID(a)]; + } + + /// \brief Return the flow map (the primal solution). + /// + /// This function copies the flow value on each arc into the given + /// map. The \c Value type of the algorithm must be convertible to + /// the \c Value type of the map. + /// + /// \pre \ref run() must be called before using this function. + template <typename FlowMap> + void flowMap(FlowMap &map) const { + Arc a; _graph.first(a); + for (; a != INVALID; _graph.next(a)) { + map.set(a, _flow[getArcID(a)]); + } + } + + /// \brief Return the potential (dual value) of the given node. + /// + /// This function returns the potential (dual value) of the + /// given node. + /// + /// \pre \ref run() must be called before using this function. + Cost potential(const Node& n) const { + return _pi[_node_id(n)]; + } + + /// \brief Return the potential map (the dual solution). + /// + /// This function copies the potential (dual value) of each node + /// into the given map. + /// The \c Cost type of the algorithm must be convertible to the + /// \c Value type of the map. + /// + /// \pre \ref run() must be called before using this function. 
+ template <typename PotentialMap> + void potentialMap(PotentialMap &map) const { + Node n; _graph.first(n); + for (; n != INVALID; _graph.next(n)) { + map.set(n, _pi[_node_id(n)]); + } + } + + /// @} + + private: + + // Initialize internal data structures + bool init() { + if (_node_num == 0) return false; + + // Check the sum of supply values + _sum_supply = 0; + for (int i = 0; i != _node_num; ++i) { + _sum_supply += _supply[i]; + } + if ( fabs(_sum_supply) > _EPSILON ) return false; + + _sum_supply = 0; + + // Initialize artifical cost + Cost ART_COST; + if (std::numeric_limits<Cost>::is_exact) { + ART_COST = std::numeric_limits<Cost>::max() / 2 + 1; + } else { + ART_COST = 0; + for (int i = 0; i != _arc_num; ++i) { + if (_cost[i] > ART_COST) ART_COST = _cost[i]; + } + ART_COST = (ART_COST + 1) * _node_num; + } + + // Initialize arc maps + for (int i = 0; i != _arc_num; ++i) { + //_flow[i] = 0; //by default, the sparse matrix is empty + _state[i] = STATE_LOWER; + } + + // Set data for the artificial root node + _root = _node_num; + _parent[_root] = -1; + _pred[_root] = -1; + _thread[_root] = 0; + _rev_thread[0] = _root; + _succ_num[_root] = _node_num + 1; + _last_succ[_root] = _root - 1; + _supply[_root] = -_sum_supply; + _pi[_root] = 0; + + // Add artificial arcs and initialize the spanning tree data structure + if (_sum_supply == 0) { + // EQ supply constraints + _search_arc_num = _arc_num; + _all_arc_num = _arc_num + _node_num; + for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + _parent[u] = _root; + _pred[u] = e; + _thread[u] = u + 1; + _rev_thread[u + 1] = u; + _succ_num[u] = 1; + _last_succ[u] = u; + _state[e] = STATE_TREE; + if (_supply[u] >= 0) { + _forward[u] = true; + _pi[u] = 0; + _source[e] = u; + _target[e] = _root; + _flow[e] = _supply[u]; + _cost[e] = 0; + } else { + _forward[u] = false; + _pi[u] = ART_COST; + _source[e] = _root; + _target[e] = u; + _flow[e] = -_supply[u]; + _cost[e] = ART_COST; + } + } + } + else if (_sum_supply > 0) { + // LEQ supply constraints + _search_arc_num = _arc_num + _node_num; + int f = _arc_num + _node_num; + for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + _parent[u] = _root; + _thread[u] = u + 1; + _rev_thread[u + 1] = u; + _succ_num[u] = 1; + _last_succ[u] = u; + if (_supply[u] >= 0) { + _forward[u] = true; + _pi[u] = 0; + _pred[u] = e; + _source[e] = u; + _target[e] = _root; + _flow[e] = _supply[u]; + _cost[e] = 0; + _state[e] = STATE_TREE; + } else { + _forward[u] = false; + _pi[u] = ART_COST; + _pred[u] = f; + _source[f] = _root; + _target[f] = u; + _flow[f] = -_supply[u]; + _cost[f] = ART_COST; + _state[f] = STATE_TREE; + _source[e] = u; + _target[e] = _root; + //_flow[e] = 0; //by default, the sparse matrix is empty + _cost[e] = 0; + _state[e] = STATE_LOWER; + ++f; + } + } + _all_arc_num = f; + } + else { + // GEQ supply constraints + _search_arc_num = _arc_num + _node_num; + int f = _arc_num + _node_num; + for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + _parent[u] = _root; + _thread[u] = u + 1; + _rev_thread[u + 1] = u; + _succ_num[u] = 1; + _last_succ[u] = u; + if (_supply[u] <= 0) { + _forward[u] = false; + _pi[u] = 0; + _pred[u] = e; + _source[e] = _root; + _target[e] = u; + _flow[e] = -_supply[u]; + _cost[e] = 0; + _state[e] = STATE_TREE; + } else { + _forward[u] = true; + _pi[u] = -ART_COST; + _pred[u] = f; + _source[f] = u; + _target[f] = _root; + _flow[f] = _supply[u]; + _state[f] = STATE_TREE; + _cost[f] = ART_COST; + _source[e] = _root; + _target[e] = u; + //_flow[e] = 0; //by default, the 
sparse matrix is empty + _cost[e] = 0; + _state[e] = STATE_LOWER; + ++f; + } + } + _all_arc_num = f; + } + + return true; + } + + // Find the join node + void findJoinNode() { + int u = _source[in_arc]; + int v = _target[in_arc]; + while (u != v) { + if (_succ_num[u] < _succ_num[v]) { + u = _parent[u]; + } else { + v = _parent[v]; + } + } + join = u; + } + + // Find the leaving arc of the cycle and returns true if the + // leaving arc is not the same as the entering arc + bool findLeavingArc() { + // Initialize first and second nodes according to the direction + // of the cycle + if (_state[in_arc] == STATE_LOWER) { + first = _source[in_arc]; + second = _target[in_arc]; + } else { + first = _target[in_arc]; + second = _source[in_arc]; + } + delta = INF; + int result = 0; + Value d; + int e; + + // Search the cycle along the path form the first node to the root + for (int u = first; u != join; u = _parent[u]) { + e = _pred[u]; + d = _forward[u] ? _flow[e] : INF ; + if (d < delta) { + delta = d; + u_out = u; + result = 1; + } + } + // Search the cycle along the path form the second node to the root + for (int u = second; u != join; u = _parent[u]) { + e = _pred[u]; + d = _forward[u] ? INF : _flow[e]; + if (d <= delta) { + delta = d; + u_out = u; + result = 2; + } + } + + if (result == 1) { + u_in = first; + v_in = second; + } else { + u_in = second; + v_in = first; + } + return result != 0; + } + + // Change _flow and _state vectors + void changeFlow(bool change) { + // Augment along the cycle + if (delta > 0) { + Value val = _state[in_arc] * delta; + _flow[in_arc] += val; + for (int u = _source[in_arc]; u != join; u = _parent[u]) { + _flow[_pred[u]] += _forward[u] ? -val : val; + } + for (int u = _target[in_arc]; u != join; u = _parent[u]) { + _flow[_pred[u]] += _forward[u] ? val : -val; + } + } + // Update the state of the entering and leaving arcs + if (change) { + _state[in_arc] = STATE_TREE; + _state[_pred[u_out]] = + (_flow[_pred[u_out]] == 0) ? STATE_LOWER : STATE_UPPER; + } else { + _state[in_arc] = -_state[in_arc]; + } + } + + // Update the tree structure + void updateTreeStructure() { + int u, w; + int old_rev_thread = _rev_thread[u_out]; + int old_succ_num = _succ_num[u_out]; + int old_last_succ = _last_succ[u_out]; + v_out = _parent[u_out]; + + u = _last_succ[u_in]; // the last successor of u_in + right = _thread[u]; // the node after it + + // Handle the case when old_rev_thread equals to v_in + // (it also means that join and v_out coincide) + if (old_rev_thread == v_in) { + last = _thread[_last_succ[u_out]]; + } else { + last = _thread[v_in]; + } + + // Update _thread and _parent along the stem nodes (i.e. the nodes + // between u_in and u_out, whose parent have to be changed) + _thread[v_in] = stem = u_in; + _dirty_revs.clear(); + _dirty_revs.push_back(v_in); + par_stem = v_in; + while (stem != u_out) { + // Insert the next stem node into the thread list + new_stem = _parent[stem]; + _thread[u] = new_stem; + _dirty_revs.push_back(u); + + // Remove the subtree of stem from the thread list + w = _rev_thread[stem]; + _thread[w] = right; + _rev_thread[right] = w; + + // Change the parent node and shift stem nodes + _parent[stem] = par_stem; + par_stem = stem; + stem = new_stem; + + // Update u and right + u = _last_succ[stem] == _last_succ[par_stem] ? 
+ _rev_thread[par_stem] : _last_succ[stem]; + right = _thread[u]; + } + _parent[u_out] = par_stem; + _thread[u] = last; + _rev_thread[last] = u; + _last_succ[u_out] = u; + + // Remove the subtree of u_out from the thread list except for + // the case when old_rev_thread equals to v_in + // (it also means that join and v_out coincide) + if (old_rev_thread != v_in) { + _thread[old_rev_thread] = right; + _rev_thread[right] = old_rev_thread; + } + + // Update _rev_thread using the new _thread values + for (int i = 0; i != int(_dirty_revs.size()); ++i) { + u = _dirty_revs[i]; + _rev_thread[_thread[u]] = u; + } + + // Update _pred, _forward, _last_succ and _succ_num for the + // stem nodes from u_out to u_in + int tmp_sc = 0, tmp_ls = _last_succ[u_out]; + u = u_out; + while (u != u_in) { + w = _parent[u]; + _pred[u] = _pred[w]; + _forward[u] = !_forward[w]; + tmp_sc += _succ_num[u] - _succ_num[w]; + _succ_num[u] = tmp_sc; + _last_succ[w] = tmp_ls; + u = w; + } + _pred[u_in] = in_arc; + _forward[u_in] = (u_in == _source[in_arc]); + _succ_num[u_in] = old_succ_num; + + // Set limits for updating _last_succ form v_in and v_out + // towards the root + int up_limit_in = -1; + int up_limit_out = -1; + if (_last_succ[join] == v_in) { + up_limit_out = join; + } else { + up_limit_in = join; + } + + // Update _last_succ from v_in towards the root + for (u = v_in; u != up_limit_in && _last_succ[u] == v_in; + u = _parent[u]) { + _last_succ[u] = _last_succ[u_out]; + } + // Update _last_succ from v_out towards the root + if (join != old_rev_thread && v_in != old_rev_thread) { + for (u = v_out; u != up_limit_out && _last_succ[u] == old_last_succ; + u = _parent[u]) { + _last_succ[u] = old_rev_thread; + } + } else { + for (u = v_out; u != up_limit_out && _last_succ[u] == old_last_succ; + u = _parent[u]) { + _last_succ[u] = _last_succ[u_out]; + } + } + + // Update _succ_num from v_in to join + for (u = v_in; u != join; u = _parent[u]) { + _succ_num[u] += old_succ_num; + } + // Update _succ_num from v_out to join + for (u = v_out; u != join; u = _parent[u]) { + _succ_num[u] -= old_succ_num; + } + } + + // Update potentials + void updatePotential() { + Cost sigma = _forward[u_in] ? 
+ _pi[v_in] - _pi[u_in] - _cost[_pred[u_in]] :
+ _pi[v_in] - _pi[u_in] + _cost[_pred[u_in]];
+ // Update potentials in the subtree, which has been moved
+ int end = _thread[_last_succ[u_in]];
+ for (int u = u_in; u != end; u = _thread[u]) {
+ _pi[u] += sigma;
+ }
+ }
+
+ // Heuristic initial pivots
+ bool initialPivots() {
+ Value curr, total = 0;
+ std::vector<Node> supply_nodes, demand_nodes;
+ Node u; _graph.first(u);
+ for (; u != INVALIDNODE; _graph.next(u)) {
+ curr = _supply[_node_id(u)];
+ if (curr > 0) {
+ total += curr;
+ supply_nodes.push_back(u);
+ }
+ else if (curr < 0) {
+ demand_nodes.push_back(u);
+ }
+ }
+ if (_sum_supply > 0) total -= _sum_supply;
+ if (total <= 0) return true;
+
+ IntVector arc_vector;
+ if (_sum_supply >= 0) {
+ if (supply_nodes.size() == 1 && demand_nodes.size() == 1) {
+ // Perform a reverse graph search from the sink to the source
+ //typename GR::template NodeMap<bool> reached(_graph, false);
+ BoolVector reached(_node_num, false);
+ Node s = supply_nodes[0], t = demand_nodes[0];
+ std::vector<Node> stack;
+ reached[t] = true;
+ stack.push_back(t);
+ while (!stack.empty()) {
+ Node u, v = stack.back();
+ stack.pop_back();
+ if (v == s) break;
+ Arc a; _graph.firstIn(a, v);
+ for (; a != INVALID; _graph.nextIn(a)) {
+ if (reached[u = _graph.source(a)]) continue;
+ int j = getArcID(a);
+ if (INF >= total) {
+ arc_vector.push_back(j);
+ reached[u] = true;
+ stack.push_back(u);
+ }
+ }
+ }
+ } else {
+ // Find the min. cost incoming arc for each demand node
+ for (int i = 0; i != int(demand_nodes.size()); ++i) {
+ Node v = demand_nodes[i];
+ Cost c, min_cost = std::numeric_limits<Cost>::max();
+ Arc min_arc = INVALID;
+ Arc a; _graph.firstIn(a, v);
+ for (; a != INVALID; _graph.nextIn(a)) {
+ c = _cost[getArcID(a)];
+ if (c < min_cost) {
+ min_cost = c;
+ min_arc = a;
+ }
+ }
+ if (min_arc != INVALID) {
+ arc_vector.push_back(getArcID(min_arc));
+ }
+ }
+ }
+ } else {
+ // Find the min. cost outgoing arc for each supply node
+ for (int i = 0; i != int(supply_nodes.size()); ++i) {
+ Node u = supply_nodes[i];
+ Cost c, min_cost = std::numeric_limits<Cost>::max();
+ Arc min_arc = INVALID;
+ Arc a; _graph.firstOut(a, u);
+ for (; a != INVALID; _graph.nextOut(a)) {
+ c = _cost[getArcID(a)];
+ if (c < min_cost) {
+ min_cost = c;
+ min_arc = a;
+ }
+ }
+ if (min_arc != INVALID) {
+ arc_vector.push_back(getArcID(min_arc));
+ }
+ }
+ }
+
+ // Perform heuristic initial pivots
+ for (int i = 0; i != int(arc_vector.size()); ++i) {
+ in_arc = arc_vector[i];
+ // the error is probably here...
+ if (_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -
+ _pi[_target[in_arc]]) >= 0) continue;
+ findJoinNode();
+ bool change = findLeavingArc();
+ if (delta >= MAX) return false;
+ changeFlow(change);
+ if (change) {
+ updateTreeStructure();
+ updatePotential();
+ }
+ }
+ return true;
+ }
+
+ // Execute the algorithm
+ ProblemType start() {
+ return start<BlockSearchPivotRule>();
+ }
+
+ template <typename PivotRuleImpl>
+ ProblemType start() {
+ PivotRuleImpl pivot(*this);
+ double prevCost=-1;
+ ProblemType retVal = OPTIMAL;
+
+ // Perform heuristic initial pivots
+ if (!initialPivots()) return UNBOUNDED;
+
+ int iter_number=0;
+ //pivot.setDantzig(true);
+ // Execute the Network Simplex algorithm
+ while (pivot.findEnteringArc()) {
+ if (max_iter > 0 && ++iter_number >= max_iter) {
+ char errMess[1000];
+ sprintf( errMess, "RESULT MIGHT BE INACCURATE\nMax number of iterations reached, currently %d.
Sometimes the iterations cycle even though the solution has been reached; to check whether that is the case here, have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution; if not, try setting the maximum number of iterations a bit higher\n",iter_number );
+ std::cerr << errMess;
+ retVal = MAX_ITER_REACHED;
+ break;
+ }
+#if DEBUG_LVL>0
+ if(iter_number>MAX_DEBUG_ITER)
+ break;
+ if(iter_number%1000==0||iter_number%1000==1){
+ double curCost=totalCost();
+ double sumFlow=0;
+ double a;
+ a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]);
+ a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]);
+ for (int i=0; i<_flow.size(); i++) {
+ sumFlow+=_state[i]*_flow[i];
+ }
+ std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n";
+ std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n";
+ std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n";
+ std::cout << _cost[in_arc] << "\n";
+ std::cout << _pi[_source[in_arc]] << "\n";
+ std::cout << _pi[_target[in_arc]] << "\n";
+ std::cout << a << "\n";
+ }
+#endif
+
+ findJoinNode();
+ bool change = findLeavingArc();
+ if (delta >= MAX) return UNBOUNDED;
+ changeFlow(change);
+ if (change) {
+ updateTreeStructure();
+ updatePotential();
+ }
+#if DEBUG_LVL>0
+ else{
+ std::cout << "No change\n";
+ }
+#endif
+#if DEBUG_LVL>1
+ std::cout << "Arc in = (" << _source[in_arc] << ", " << _target[in_arc] << ")\n";
+#endif
+
+ }
+
+
+#if DEBUG_LVL>0
+ double curCost=totalCost();
+ double sumFlow=0;
+ double a;
+ a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ?
fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]);
+ a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]);
+ for (int i=0; i<_flow.size(); i++) {
+ sumFlow+=_state[i]*_flow[i];
+ }
+
+ std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n";
+
+ std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n";
+ std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n";
+
+#endif
+
+#if DEBUG_LVL>1
+ sumFlow=0;
+ for (int i=0; i<_flow.size(); i++) {
+ sumFlow+=_state[i]*_flow[i];
+ if (_state[i]==STATE_TREE) {
+ std::cout << "Non zero value at (" << _node_num+1-_source[i] << ", " << _node_num+1-_target[i] << ")\n";
+ }
+ }
+ std::cout << "Sum of the flow " << sumFlow << "\n"<< iter_number <<" iterations, current cost=" << totalCost() << "\n";
+#endif
+ // Check feasibility
+ if( retVal == OPTIMAL){
+ for (int e = _search_arc_num; e != _all_arc_num; ++e) {
+ if (_flow[e] != 0){
+ if (fabs(_flow[e]) > EPSILON)
+ return INFEASIBLE;
+ else
+ _flow[e]=0;
+
+ }
+ }
+ }
+
+ // Shift potentials to meet the requirements of the GEQ/LEQ type
+ // optimality conditions
+ if (_sum_supply == 0) {
+ if (_stype == GEQ) {
+ Cost max_pot = -std::numeric_limits<Cost>::max();
+ for (int i = 0; i != _node_num; ++i) {
+ if (_pi[i] > max_pot) max_pot = _pi[i];
+ }
+ if (max_pot > 0) {
+ for (int i = 0; i != _node_num; ++i)
+ _pi[i] -= max_pot;
+ }
+ } else {
+ Cost min_pot = std::numeric_limits<Cost>::max();
+ for (int i = 0; i != _node_num; ++i) {
+ if (_pi[i] < min_pot) min_pot = _pi[i];
+ }
+ if (min_pot < 0) {
+ for (int i = 0; i != _node_num; ++i)
+ _pi[i] -= min_pot;
+ }
+ }
+ }
+
+ return retVal;
+ }
+
+ }; //class NetworkSimplexSimple
+
+ ///@}
+
+} //namespace lemon
+
+#endif //LEMON_NETWORK_SIMPLEX_SIMPLE_H
diff --git a/ot/optim.py b/ot/optim.py
new file mode 100644
index 0000000..0abd9e9
--- /dev/null
+++ b/ot/optim.py
@@ -0,0 +1,440 @@
+# -*- coding: utf-8 -*-
+"""
+Optimization algorithms for OT
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Titouan Vayer <titouan.vayer@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+from scipy.optimize.linesearch import scalar_search_armijo
+from .lp import emd
+from .bregman import sinkhorn
+
+# The corresponding scipy function does not work for matrices
+
+
+def line_search_armijo(f, xk, pk, gfk, old_fval,
+ args=(), c1=1e-4, alpha0=0.99):
+ """
+ Armijo linesearch function that works with matrices
+
+ Find an approximate minimum of f(xk + alpha * pk) that satisfies the
+ Armijo conditions.
+
+ Parameters
+ ----------
+ f : callable
+ loss function
+ xk : ndarray
+ initial position
+ pk : ndarray
+ descent direction
+ gfk : ndarray
+ gradient of f at xk
+ old_fval : float
+ loss value at xk
+ args : tuple, optional
+ arguments given to f
+ c1 : float, optional
+ c1 constant in the Armijo rule (>0)
+ alpha0 : float, optional
+ initial step (>0)
+
+ Returns
+ -------
+ alpha : float
+ step that satisfies the Armijo conditions
+ fc : int
+ number of function calls
+ phi1 : float
+ loss value at step alpha
+
+ """
+ xk = np.atleast_1d(xk)
+ fc = [0]
+
+ def phi(alpha1):
+ fc[0] += 1
+ return f(xk + alpha1 * pk, *args)
+
+ if old_fval is None:
+ phi0 = phi(0.)
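+ # note: phi(0.) evaluates the loss at xk itself and, via fc, is
+ # counted as a function call, keeping the reported call count exact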
+ else: + phi0 = old_fval + + derphi0 = np.sum(pk * gfk) # Quickfix for matrices + alpha, phi1 = scalar_search_armijo( + phi, phi0, derphi0, c1=c1, alpha0=alpha0) + + return alpha, fc[0], phi1 + + +def solve_linesearch(cost, G, deltaG, Mi, f_val, + armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): + """ + Solve the linesearch in the FW iterations + Parameters + ---------- + cost : method + Cost in the FW for the linesearch + G : ndarray, shape(ns,nt) + The transport map at a given iteration of the FW + deltaG : ndarray (ns,nt) + Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration + Mi : ndarray (ns,nt) + Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost + f_val : float + Value of the cost at G + armijo : bool, optional + If True the steps of the line-search is found via an armijo research. Else closed form is used. + If there is convergence issues use False. + C1 : ndarray (ns,ns), optional + Structure matrix in the source domain. Only used and necessary when armijo=False + C2 : ndarray (nt,nt), optional + Structure matrix in the target domain. Only used and necessary when armijo=False + reg : float, optional + Regularization parameter. Only used and necessary when armijo=False + Gc : ndarray (ns,nt) + Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False + constC : ndarray (ns,nt) + Constant for the gromov cost. See [24]. Only used and necessary when armijo=False + M : ndarray (ns,nt), optional + Cost matrix between the features. Only used and necessary when armijo=False + Returns + ------- + alpha : float + The optimal step size of the FW + fc : int + nb of function call. Useless here + f_val : float + The value of the cost for the next iteration + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + """ + if armijo: + alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val) + else: # requires symetric matrices + dot1 = np.dot(C1, deltaG) + dot12 = dot1.dot(C2) + a = -2 * reg * np.sum(dot12 * deltaG) + b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG)) + c = cost(G) + + alpha = solve_1d_linesearch_quad(a, b, c) + fc = None + f_val = cost(G + alpha * deltaG) + + return alpha, fc, f_val + + +def cg(a, b, M, reg, f, df, G0=None, numItermax=200, + stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): + """ + Solve the general regularized OT problem with conditional gradient + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg*f(\gamma) + + s.t. 
\gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the (ns,nt) metric cost matrix + - :math:`f` is the regularization term ( and df is its gradient) + - a and b are source and target weights (sum to 1) + + The algorithm used for solving the problem is conditional gradient as discussed in [1]_ + + + Parameters + ---------- + a : ndarray, shape (ns,) + samples weights in the source domain + b : ndarray, shape (nt,) + samples in the target domain + M : ndarray, shape (ns, nt) + loss matrix + reg : float + Regularization term >0 + G0 : ndarray, shape (ns,nt), optional + initial guess (default is indep joint density) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on the relative variation (>0) + stopThr2 : float, optional + Stop threshol on the absolute variation (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + **kwargs : dict + Parameters for linesearch + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. + + See Also + -------- + ot.lp.emd : Unregularized optimal ransport + ot.bregman.sinkhorn : Entropic regularized optimal transport + + """ + + loop = 1 + + if log: + log = {'loss': []} + + if G0 is None: + G = np.outer(a, b) + else: + G = G0 + + def cost(G): + return np.sum(M * G) + reg * f(G) + + f_val = cost(G) + if log: + log['loss'].append(f_val) + + it = 0 + + if verbose: + print('{:5s}|{:12s}|{:8s}|{:8s}'.format( + 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0)) + + while loop: + + it += 1 + old_fval = f_val + + # problem linearization + Mi = M + reg * df(G) + # set M positive + Mi += Mi.min() + + # solve linear program + Gc = emd(a, b, Mi) + + deltaG = Gc - G + + # line search + alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs) + + G = G + alpha * deltaG + + # test convergence + if it >= numItermax: + loop = 0 + + abs_delta_fval = abs(f_val - old_fval) + relative_delta_fval = abs_delta_fval / abs(f_val) + if relative_delta_fval < stopThr or abs_delta_fval < stopThr2: + loop = 0 + + if log: + log['loss'].append(f_val) + + if verbose: + if it % 20 == 0: + print('{:5s}|{:12s}|{:8s}|{:8s}'.format( + 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) + + if log: + return G, log + else: + return G + + +def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, + numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False): + """ + Solve the general regularized OT problem with the generalized conditional gradient + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg1\cdot\Omega(\gamma) + reg2\cdot f(\gamma) + + s.t. 
\gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the (ns,nt) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`f` is the regularization term ( and df is its gradient) + - a and b are source and target weights (sum to 1) + + The algorithm used for solving the problem is the generalized conditional gradient as discussed in [5,7]_ + + + Parameters + ---------- + a : ndarray, shape (ns,) + samples weights in the source domain + b : ndarrayv (nt,) + samples in the target domain + M : ndarray, shape (ns, nt) + loss matrix + reg1 : float + Entropic Regularization term >0 + reg2 : float + Second Regularization term >0 + G0 : ndarray, shape (ns, nt), optional + initial guess (default is indep joint density) + numItermax : int, optional + Max number of iterations + numInnerItermax : int, optional + Max number of iterations of Sinkhorn + stopThr : float, optional + Stop threshol on the relative variation (>0) + stopThr2 : float, optional + Stop threshol on the absolute variation (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + gamma : ndarray, shape (ns, nt) + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + References + ---------- + .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567. + + See Also + -------- + ot.optim.cg : conditional gradient + + """ + + loop = 1 + + if log: + log = {'loss': []} + + if G0 is None: + G = np.outer(a, b) + else: + G = G0 + + def cost(G): + return np.sum(M * G) + reg1 * np.sum(G * np.log(G)) + reg2 * f(G) + + f_val = cost(G) + if log: + log['loss'].append(f_val) + + it = 0 + + if verbose: + print('{:5s}|{:12s}|{:8s}|{:8s}'.format( + 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0)) + + while loop: + + it += 1 + old_fval = f_val + + # problem linearization + Mi = M + reg2 * df(G) + + # solve linear program with Sinkhorn + # Gc = sinkhorn_stabilized(a,b, Mi, reg1, numItermax = numInnerItermax) + Gc = sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax) + + deltaG = Gc - G + + # line search + dcost = Mi + reg1 * (1 + np.log(G)) # ?? + alpha, fc, f_val = line_search_armijo(cost, G, deltaG, dcost, f_val) + + G = G + alpha * deltaG + + # test convergence + if it >= numItermax: + loop = 0 + + abs_delta_fval = abs(f_val - old_fval) + relative_delta_fval = abs_delta_fval / abs(f_val) + + if relative_delta_fval < stopThr or abs_delta_fval < stopThr2: + loop = 0 + + if log: + log['loss'].append(f_val) + + if verbose: + if it % 20 == 0: + print('{:5s}|{:12s}|{:8s}|{:8s}'.format( + 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) + + if log: + return G, log + else: + return G + + +def solve_1d_linesearch_quad(a, b, c): + """ + For any convex or non-convex 1d quadratic function f, solve on [0,1] the following problem: + .. 
math::
+ arg\min_{x \in [0, 1]} f(x) = ax^{2} + bx + c
+
+ Parameters
+ ----------
+ a,b,c : float
+ The coefficients of the quadratic function
+
+ Returns
+ -------
+ x : float
+ The optimal value which leads to the minimal cost
+ """
+ f0 = c
+ df0 = b
+ f1 = a + f0 + df0
+
+ if a > 0: # convex
+ minimum = min(1, max(0, np.divide(-b, 2.0 * a)))
+ return minimum
+ else: # non convex
+ if f0 > f1:
+ return 1
+ else:
+ return 0
diff --git a/ot/plot.py b/ot/plot.py
new file mode 100644
index 0000000..f403e98
--- /dev/null
+++ b/ot/plot.py
@@ -0,0 +1,91 @@
+"""
+Functions for plotting OT matrices
+
+.. warning::
+ Note that by default the module is not imported in :mod:`ot`. In order to
+ use it you need to explicitly import :mod:`ot.plot`
+
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+from matplotlib import gridspec
+
+
+def plot1D_mat(a, b, M, title=''):
+ """ Plot matrix M with the source and target 1D distributions
+
+ Creates a subplot with the source distribution a on the left and
+ the target distribution b on the top. The matrix M is shown in between.
+
+
+ Parameters
+ ----------
+ a : ndarray, shape (na,)
+ Source distribution
+ b : ndarray, shape (nb,)
+ Target distribution
+ M : ndarray, shape (na, nb)
+ Matrix to plot
+ title : str, optional
+ Title of the plot
+ """
+ na, nb = M.shape
+
+ gs = gridspec.GridSpec(3, 3)
+
+ xa = np.arange(na)
+ xb = np.arange(nb)
+
+ ax1 = pl.subplot(gs[0, 1:])
+ pl.plot(xb, b, 'r', label='Target distribution')
+ pl.yticks(())
+ pl.title(title)
+
+ ax2 = pl.subplot(gs[1:, 0])
+ pl.plot(a, xa, 'b', label='Source distribution')
+ pl.gca().invert_xaxis()
+ pl.gca().invert_yaxis()
+ pl.xticks(())
+
+ pl.subplot(gs[1:, 1:], sharex=ax1, sharey=ax2)
+ pl.imshow(M, interpolation='nearest')
+ pl.axis('off')
+
+ pl.xlim((0, nb))
+ pl.tight_layout()
+ pl.subplots_adjust(wspace=0., hspace=0.2)
+
+
+def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
+ """ Plot matrix G in 2D with lines using alpha values
+
+ Plot lines between source and target 2D samples with a color
+ proportional to the value of the matrix G between samples.
+
+
+ Parameters
+ ----------
+ xs : ndarray, shape (ns,2)
+ Source samples positions
+ xt : ndarray, shape (nt,2)
+ Target samples positions
+ G : ndarray, shape (ns,nt)
+ OT matrix
+ thr : float, optional
+ threshold above which the line is drawn
+ **kwargs : dict
+ parameters given to the plot functions (default color is black if
+ nothing given)
+ """
+ if ('color' not in kwargs) and ('c' not in kwargs):
+ kwargs['color'] = 'k'
+ mx = G.max()
+ for i in range(xs.shape[0]):
+ for j in range(xt.shape[0]):
+ if G[i, j] / mx > thr:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]],
+ alpha=G[i, j] / mx, **kwargs)
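+
+# Editor's note: a minimal usage sketch for the two helpers above
+# (illustrative only, not part of the upstream file; it assumes POT's toy
+# data helper ot.datasets.make_1D_gauss):
+#
+# import numpy as np
+# import ot
+# import ot.plot # not imported automatically, see the module docstring
+#
+# x = np.arange(100, dtype=np.float64).reshape(-1, 1)
+# a = ot.datasets.make_1D_gauss(100, m=20, s=5) # source histogram
+# b = ot.datasets.make_1D_gauss(100, m=60, s=10) # target histogram
+# M = ot.dist(x, x)
+# G = ot.emd(a, b, M)
+# ot.plot.plot1D_mat(a, b, G, 'OT matrix G')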
diff --git a/ot/smooth.py b/ot/smooth.py
new file mode 100644
index 0000000..5a8e4b5
--- /dev/null
+++ b/ot/smooth.py
@@ -0,0 +1,600 @@
+#Copyright (c) 2018, Mathieu Blondel
+#All rights reserved.
+#
+#Redistribution and use in source and binary forms, with or without
+#modification, are permitted provided that the following conditions are met:
+#
+#1. Redistributions of source code must retain the above copyright notice, this
+#list of conditions and the following disclaimer.
+#
+#2. Redistributions in binary form must reproduce the above copyright notice,
+#this list of conditions and the following disclaimer in the documentation and/or
+#other materials provided with the distribution.
+#
+#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+#ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+#WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+#IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
+#INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
+#NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
+#OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+#LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+#OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+#THE POSSIBILITY OF SUCH DAMAGE.
+
+# Author: Mathieu Blondel
+# Remi Flamary <remi.flamary@unice.fr>
+
+"""
+Implementation of
+Smooth and Sparse Optimal Transport.
+Mathieu Blondel, Vivien Seguy, Antoine Rolet.
+In Proc. of AISTATS 2018.
+https://arxiv.org/abs/1710.06276
+
+[17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal
+Transport. Proceedings of the Twenty-First International Conference on
+Artificial Intelligence and Statistics (AISTATS).
+
+Original code from https://github.com/mblondel/smooth-ot/
+
+"""
+
+import numpy as np
+from scipy.optimize import minimize
+
+
+def projection_simplex(V, z=1, axis=None):
+ """ Projection of x onto the simplex, scaled by z
+
+ P(x; z) = argmin_{y >= 0, sum(y) = z} ||y - x||^2
+
+ Parameters
+ ----------
+ V: ndarray
+ Input array to project
+ z: float or array
+ If array, len(z) must be compatible with V
+ axis: None or int
+ - axis=None: project V by P(V.ravel(); z)
+ - axis=1: project each V[i] by P(V[i]; z[i])
+ - axis=0: project each V[:, j] by P(V[:, j]; z[j])
+ """
+ if axis == 1:
+ n_features = V.shape[1]
+ U = np.sort(V, axis=1)[:, ::-1]
+ z = np.ones(len(V)) * z
+ cssv = np.cumsum(U, axis=1) - z[:, np.newaxis]
+ ind = np.arange(n_features) + 1
+ cond = U - cssv / ind > 0
+ rho = np.count_nonzero(cond, axis=1)
+ theta = cssv[np.arange(len(V)), rho - 1] / rho
+ return np.maximum(V - theta[:, np.newaxis], 0)
+
+ elif axis == 0:
+ return projection_simplex(V.T, z, axis=1).T
+
+ else:
+ V = V.ravel().reshape(1, -1)
+ return projection_simplex(V, z, axis=1).ravel()
+
+
+class Regularization(object):
+ """Base class for Regularization objects
+
+ Notes
+ -----
+ This class is not intended for direct use but as a parent class for
+ concrete regularization implementations.
+ """
+
+ def __init__(self, gamma=1.0):
+ """
+
+ Parameters
+ ----------
+ gamma: float
+ Regularization parameter.
+ We recover unregularized OT when gamma -> 0.
+
+ """
+ self.gamma = gamma
+
+ def delta_Omega(self, X):
+ """
+ Compute delta_Omega(X[:, j]) for each X[:, j].
+ delta_Omega(x) = sup_{y >= 0} y^T x - Omega(y).
+
+ Parameters
+ ----------
+ X: array, shape = len(a) x len(b)
+ Input array.
+
+ Returns
+ -------
+ v: array, len(b)
+ Values: v[j] = delta_Omega(X[:, j])
+ G: array, len(a) x len(b)
+ Gradients: G[:, j] = nabla delta_Omega(X[:, j])
+ """
+ raise NotImplementedError
+
+ def max_Omega(self, X, b):
+ """
+ Compute max_Omega_j(X[:, j]) for each X[:, j].
+ max_Omega_j(x) = sup_{y >= 0, sum(y) = 1} y^T x - Omega(b[j] y) / b[j].
+
+ Parameters
+ ----------
+ X: array, shape = len(a) x len(b)
+ Input array.
+ b: array, len(b)
+ Second input histogram.
+
+ Returns
+ -------
+ v: array, len(b)
+ Values: v[j] = max_Omega_j(X[:, j])
+ G: array, len(a) x len(b)
+ Gradients: G[:, j] = nabla max_Omega_j(X[:, j])
+ """
+ raise NotImplementedError
+
+ def Omega(self, T):
+ """
+ Compute regularization term.
+ + Parameters + ---------- + T: array, shape = len(a) x len(b) + Input array. + + Returns + ------- + value: float + Regularization term. + """ + raise NotImplementedError + + +class NegEntropy(Regularization): + """ NegEntropy regularization """ + + def delta_Omega(self, X): + G = np.exp(X / self.gamma - 1) + val = self.gamma * np.sum(G, axis=0) + return val, G + + def max_Omega(self, X, b): + max_X = np.max(X, axis=0) / self.gamma + exp_X = np.exp(X / self.gamma - max_X) + val = self.gamma * (np.log(np.sum(exp_X, axis=0)) + max_X) + val -= self.gamma * np.log(b) + G = exp_X / np.sum(exp_X, axis=0) + return val, G + + def Omega(self, T): + return self.gamma * np.sum(T * np.log(T)) + + +class SquaredL2(Regularization): + """ Squared L2 regularization """ + + def delta_Omega(self, X): + max_X = np.maximum(X, 0) + val = np.sum(max_X ** 2, axis=0) / (2 * self.gamma) + G = max_X / self.gamma + return val, G + + def max_Omega(self, X, b): + G = projection_simplex(X / (b * self.gamma), axis=0) + val = np.sum(X * G, axis=0) + val -= 0.5 * self.gamma * b * np.sum(G * G, axis=0) + return val, G + + def Omega(self, T): + return 0.5 * self.gamma * np.sum(T ** 2) + + +def dual_obj_grad(alpha, beta, a, b, C, regul): + """ + Compute objective value and gradients of dual objective. + + Parameters + ---------- + alpha: array, shape = len(a) + beta: array, shape = len(b) + Current iterate of dual potentials. + a: array, shape = len(a) + b: array, shape = len(b) + Input histograms (should be non-negative and sum to 1). + C: array, shape = len(a) x len(b) + Ground cost matrix. + regul: Regularization object + Should implement a delta_Omega(X) method. + + Returns + ------- + obj: float + Objective value (higher is better). + grad_alpha: array, shape = len(a) + Gradient w.r.t. alpha. + grad_beta: array, shape = len(b) + Gradient w.r.t. beta. + """ + obj = np.dot(alpha, a) + np.dot(beta, b) + grad_alpha = a.copy() + grad_beta = b.copy() + + # X[:, j] = alpha + beta[j] - C[:, j] + X = alpha[:, np.newaxis] + beta - C + + # val.shape = len(b) + # G.shape = len(a) x len(b) + val, G = regul.delta_Omega(X) + + obj -= np.sum(val) + grad_alpha -= G.sum(axis=1) + grad_beta -= G.sum(axis=0) + + return obj, grad_alpha, grad_beta + + +def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500, + verbose=False): + """ + Solve the "smoothed" dual objective. + + Parameters + ---------- + a: array, shape = len(a) + b: array, shape = len(b) + Input histograms (should be non-negative and sum to 1). + C: array, shape = len(a) x len(b) + Ground cost matrix. + regul: Regularization object + Should implement a delta_Omega(X) method. + method: str + Solver to be used (passed to `scipy.optimize.minimize`). + tol: float + Tolerance parameter. + max_iter: int + Maximum number of iterations. + + Returns + ------- + alpha: array, shape = len(a) + beta: array, shape = len(b) + Dual potentials. + """ + + def _func(params): + # Unpack alpha and beta. + alpha = params[:len(a)] + beta = params[len(a):] + + obj, grad_alpha, grad_beta = dual_obj_grad(alpha, beta, a, b, C, regul) + + # Pack grad_alpha and grad_beta. + grad = np.concatenate((grad_alpha, grad_beta)) + + # We need to maximize the dual. + return -obj, -grad + + # Unfortunately, `minimize` only supports functions whose argument is a + # vector. So, we need to concatenate alpha and beta. 
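+    # For reference, the packing convention is params = [alpha, beta]:
+    # params[:len(a)] recovers alpha and params[len(a):] recovers beta,
+    # mirroring the unpacking done in _func above.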
+ alpha_init = np.zeros(len(a)) + beta_init = np.zeros(len(b)) + params_init = np.concatenate((alpha_init, beta_init)) + + res = minimize(_func, params_init, method=method, jac=True, + tol=tol, options=dict(maxiter=max_iter, disp=verbose)) + + alpha = res.x[:len(a)] + beta = res.x[len(a):] + + return alpha, beta, res + + +def semi_dual_obj_grad(alpha, a, b, C, regul): + """ + Compute objective value and gradient of semi-dual objective. + + Parameters + ---------- + alpha: array, shape = len(a) + Current iterate of semi-dual potentials. + a: array, shape = len(a) + b: array, shape = len(b) + Input histograms (should be non-negative and sum to 1). + C: array, shape = len(a) x len(b) + Ground cost matrix. + regul: Regularization object + Should implement a max_Omega(X) method. + + Returns + ------- + obj: float + Objective value (higher is better). + grad: array, shape = len(a) + Gradient w.r.t. alpha. + """ + obj = np.dot(alpha, a) + grad = a.copy() + + # X[:, j] = alpha - C[:, j] + X = alpha[:, np.newaxis] - C + + # val.shape = len(b) + # G.shape = len(a) x len(b) + val, G = regul.max_Omega(X, b) + + obj -= np.dot(b, val) + grad -= np.dot(G, b) + + return obj, grad + + +def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500, + verbose=False): + """ + Solve the "smoothed" semi-dual objective. + + Parameters + ---------- + a: array, shape = len(a) + b: array, shape = len(b) + Input histograms (should be non-negative and sum to 1). + C: array, shape = len(a) x len(b) + Ground cost matrix. + regul: Regularization object + Should implement a max_Omega(X) method. + method: str + Solver to be used (passed to `scipy.optimize.minimize`). + tol: float + Tolerance parameter. + max_iter: int + Maximum number of iterations. + + Returns + ------- + alpha: array, shape = len(a) + Semi-dual potentials. + """ + + def _func(alpha): + obj, grad = semi_dual_obj_grad(alpha, a, b, C, regul) + # We need to maximize the semi-dual. + return -obj, -grad + + alpha_init = np.zeros(len(a)) + + res = minimize(_func, alpha_init, method=method, jac=True, + tol=tol, options=dict(maxiter=max_iter, disp=verbose)) + + return res.x, res + + +def get_plan_from_dual(alpha, beta, C, regul): + """ + Retrieve optimal transportation plan from optimal dual potentials. + + Parameters + ---------- + alpha: array, shape = len(a) + beta: array, shape = len(b) + Optimal dual potentials. + C: array, shape = len(a) x len(b) + Ground cost matrix. + regul: Regularization object + Should implement a delta_Omega(X) method. + + Returns + ------- + T: array, shape = len(a) x len(b) + Optimal transportation plan. + """ + X = alpha[:, np.newaxis] + beta - C + return regul.delta_Omega(X)[1] + + +def get_plan_from_semi_dual(alpha, b, C, regul): + """ + Retrieve optimal transportation plan from optimal semi-dual potentials. + + Parameters + ---------- + alpha: array, shape = len(a) + Optimal semi-dual potentials. + b: array, shape = len(b) + Second input histogram (should be non-negative and sum to 1). + C: array, shape = len(a) x len(b) + Ground cost matrix. + regul: Regularization object + Should implement a delta_Omega(X) method. + + Returns + ------- + T: array, shape = len(a) x len(b) + Optimal transportation plan. 
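+
+    Notes
+    -----
+    Following Proposition 2 in [17]_, each column of the plan is obtained
+    from the gradient of the corresponding max_Omega term, rescaled by b:
+    T[:, j] = b[j] * nabla max_Omega_j(alpha - m_j), which is what the code
+    below computes via the regularization object's max_Omega method.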
+ """ + X = alpha[:, np.newaxis] - C + return regul.max_Omega(X, b)[1] * b + + +def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, + numItermax=500, verbose=False, log=False): + r""" + Solve the regularized OT problem in the dual and return the OT matrix + + The function solves the smooth relaxed dual formulation (7) in [17]_ : + + .. math:: + \max_{\alpha,\beta}\quad a^T\alpha+b^T\beta-\sum_j\delta_\Omega(\alpha+\beta_j-\mathbf{m}_j) + + where : + + - :math:`\mathbf{m}_j` is the jth column of the cost matrix + - :math:`\delta_\Omega` is the convex conjugate of the regularization term :math:`\Omega` + - a and b are source and target weights (sum to 1) + + The OT matrix can is reconstructed from the gradient of :math:`\delta_\Omega` + (See [17]_ Proposition 1). + The optimization algorithm is using gradient decent (L-BFGS by default). + + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) or np.ndarray (nt,nbb) + samples in the target domain, compute sinkhorn with multiple targets + and fixed M if b is a matrix (return OT loss + dual variables in log) + M : np.ndarray (ns,nt) + loss matrix + reg : float + Regularization term >0 + reg_type : str + Regularization type, can be the following (default ='l2'): + - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_) + - 'l2' : Squared Euclidean regularization + method : str + Solver to use for scipy.optimize.minimize + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.sinhorn : Entropic regularized OT + ot.optim.cg : General regularized OT + + """ + + if reg_type.lower() in ['l2', 'squaredl2']: + regul = SquaredL2(gamma=reg) + elif reg_type.lower() in ['entropic', 'negentropy', 'kl']: + regul = NegEntropy(gamma=reg) + else: + raise NotImplementedError('Unknown regularization') + + # solve dual + alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax, + tol=stopThr, verbose=verbose) + + # reconstruct transport matrix + G = get_plan_from_dual(alpha, beta, M, regul) + + if log: + log = {'alpha': alpha, 'beta': beta, 'res': res} + return G, log + else: + return G + + +def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, + numItermax=500, verbose=False, log=False): + r""" + Solve the regularized OT problem in the semi-dual and return the OT matrix + + The function solves the smooth relaxed dual formulation (10) in [17]_ : + + .. math:: + \max_{\alpha}\quad a^T\alpha-OT_\Omega^*(\alpha,b) + + where : + + .. math:: + OT_\Omega^*(\alpha,b)=\sum_j b_j + + - :math:`\mathbf{m}_j` is the jth column of the cost matrix + - :math:`OT_\Omega^*(\alpha,b)` is defined in Eq. 
+
+
+def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
+                        numItermax=500, verbose=False, log=False):
+    r"""
+    Solve the regularized OT problem in the semi-dual and return the OT matrix
+
+    The function solves the smooth relaxed semi-dual formulation (10) in [17]_ :
+
+    .. math::
+        \max_{\alpha}\quad a^T\alpha - OT_\Omega^*(\alpha, b)
+
+    where :
+
+    .. math::
+        OT_\Omega^*(\alpha, b) = \sum_j b_j \max_{\Omega, b_j}(\alpha - \mathbf{m}_j)
+
+    - :math:`\mathbf{m}_j` is the jth column of the cost matrix
+    - :math:`\max_{\Omega, b_j}` is the weighted max operator defined in Eq. (9)
+      in [17]_ (computed by the max_Omega method of the regularization object)
+    - a and b are source and target weights (sum to 1)
+
+    The OT matrix is reconstructed using [17]_ Proposition 2.
+    The optimization algorithm uses gradient descent (L-BFGS by default).
+
+
+    Parameters
+    ----------
+    a : np.ndarray (ns,)
+        samples weights in the source domain
+    b : np.ndarray (nt,)
+        samples weights in the target domain
+    M : np.ndarray (ns,nt)
+        loss matrix
+    reg : float
+        Regularization term >0
+    reg_type : str
+        Regularization type, can be the following (default ='l2'):
+        - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_)
+        - 'l2' : Squared Euclidean regularization
+    method : str
+        Solver to use for scipy.optimize.minimize
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    gamma : (ns x nt) ndarray
+        Optimal transportation matrix for the given parameters
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+
+    References
+    ----------
+
+    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+    .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
+
+    See Also
+    --------
+    ot.lp.emd : Unregularized OT
+    ot.sinkhorn : Entropic regularized OT
+    ot.optim.cg : General regularized OT
+
+    """
+    if reg_type.lower() in ['l2', 'squaredl2']:
+        regul = SquaredL2(gamma=reg)
+    elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
+        regul = NegEntropy(gamma=reg)
+    else:
+        raise NotImplementedError('Unknown regularization')
+
+    # solve semi-dual
+    alpha, res = solve_semi_dual(a, b, M, regul, max_iter=numItermax,
+                                 tol=stopThr, verbose=verbose)
+
+    # reconstruct transport matrix
+    G = get_plan_from_semi_dual(alpha, b, M, regul)
+
+    if log:
+        log = {'alpha': alpha, 'res': res}
+        return G, log
+    else:
+        return G
diff --git a/ot/stochastic.py b/ot/stochastic.py
new file mode 100644
index 0000000..13ed9cc
--- /dev/null
+++ b/ot/stochastic.py
@@ -0,0 +1,755 @@
+"""
+Stochastic solvers for regularized OT.
+
+
+"""
+
+# Author: Kilian Fatras <kilian.fatras@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+
+
+##############################################################################
+# Optimization toolbox for SEMI - DUAL problems
+##############################################################################
+
+
+def coordinate_grad_semi_dual(b, M, reg, beta, i):
+    r'''
+    Compute the coordinate gradient update for regularized discrete distributions for (i, :)
+
+    The function computes the gradient of the semi dual problem:
+
+    .. math::
+        \max_v \sum_i (\sum_j v_j * b_j - reg * log(\sum_j exp((v_j - M_{i,j})/reg) * b_j)) * a_i
+
+    Where :
+
+    - M is the (ns,nt) metric cost matrix
+    - v is a dual variable in R^J
+    - reg is the regularization term
+    - a and b are source and target weights (sum to 1)
+
+    The algorithms used for solving the problem are the ASGD and SAG algorithms
+    as proposed in [18]_ [alg.1 & alg.2]
+
+
+    Parameters
+    ----------
+    b : ndarray, shape (nt,)
+        Target measure.
+    M : ndarray, shape (ns, nt)
+        Cost matrix.
+    reg : float
+        Regularization term > 0.
+ v : ndarray, shape (nt,) + Dual variable. + i : int + Picked number i. + + Returns + ------- + coordinate gradient : ndarray, shape (nt,) + + Examples + -------- + >>> import ot + >>> np.random.seed(0) + >>> n_source = 7 + >>> n_target = 4 + >>> a = ot.utils.unif(n_source) + >>> b = ot.utils.unif(n_target) + >>> X_source = np.random.randn(n_source, 2) + >>> Y_target = np.random.randn(n_target, 2) + >>> M = ot.dist(X_source, Y_target) + >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000) + array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06], + [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03], + [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07], + [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04], + [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01], + [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01], + [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]]) + + + References + ---------- + [Genevay et al., 2016] : + Stochastic Optimization for Large-scale Optimal Transport, + Advances in Neural Information Processing Systems (2016), + arXiv preprint arxiv:1605.08527. + ''' + r = M[i, :] - beta + exp_beta = np.exp(-r / reg) * b + khi = exp_beta / (np.sum(exp_beta)) + return b - khi + + +def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None): + r''' + Compute the SAG algorithm to solve the regularized discrete measures + optimal transport max problem + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1 = b + + \gamma \geq 0 + + Where : + + - M is the (ns,nt) metric cost matrix + - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights (sum to 1) + + The algorithm used for solving the problem is the SAG algorithm + as proposed in [18]_ [alg.1] + + + Parameters + ---------- + + a : ndarray, shape (ns,), + Source measure. + b : ndarray, shape (nt,), + Target measure. + M : ndarray, shape (ns, nt), + Cost matrix. + reg : float + Regularization term > 0 + numItermax : int + Number of iteration. + lr : float + Learning rate. + + Returns + ------- + v : ndarray, shape (nt,) + Dual variable. + + Examples + -------- + >>> import ot + >>> np.random.seed(0) + >>> n_source = 7 + >>> n_target = 4 + >>> a = ot.utils.unif(n_source) + >>> b = ot.utils.unif(n_target) + >>> X_source = np.random.randn(n_source, 2) + >>> Y_target = np.random.randn(n_target, 2) + >>> M = ot.dist(X_source, Y_target) + >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000) + array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06], + [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03], + [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07], + [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04], + [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01], + [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01], + [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]]) + + References + ---------- + + [Genevay et al., 2016] : + Stochastic Optimization for Large-scale Optimal Transport, + Advances in Neural Information Processing Systems (2016), + arXiv preprint arxiv:1605.08527. + ''' + + if lr is None: + lr = 1. 
/ max(a / reg) + n_source = np.shape(M)[0] + n_target = np.shape(M)[1] + cur_beta = np.zeros(n_target) + stored_gradient = np.zeros((n_source, n_target)) + sum_stored_gradient = np.zeros(n_target) + for _ in range(numItermax): + i = np.random.randint(n_source) + cur_coord_grad = a[i] * coordinate_grad_semi_dual(b, M, reg, + cur_beta, i) + sum_stored_gradient += (cur_coord_grad - stored_gradient[i]) + stored_gradient[i] = cur_coord_grad + cur_beta += lr * (1. / n_source) * sum_stored_gradient + return cur_beta + + +def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None): + r''' + Compute the ASGD algorithm to solve the regularized semi continous measures optimal transport max problem + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma \geq 0 + + Where : + + - M is the (ns,nt) metric cost matrix + - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights (sum to 1) + + The algorithm used for solving the problem is the ASGD algorithm + as proposed in [18]_ [alg.2] + + + Parameters + ---------- + b : ndarray, shape (nt,) + target measure + M : ndarray, shape (ns, nt) + cost matrix + reg : float + Regularization term > 0 + numItermax : int + Number of iteration. + lr : float + Learning rate. + + Returns + ------- + ave_v : ndarray, shape (nt,) + dual variable + + Examples + -------- + >>> import ot + >>> np.random.seed(0) + >>> n_source = 7 + >>> n_target = 4 + >>> a = ot.utils.unif(n_source) + >>> b = ot.utils.unif(n_target) + >>> X_source = np.random.randn(n_source, 2) + >>> Y_target = np.random.randn(n_target, 2) + >>> M = ot.dist(X_source, Y_target) + >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000) + array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06], + [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03], + [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07], + [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04], + [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01], + [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01], + [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]]) + + References + ---------- + + [Genevay et al., 2016] : + Stochastic Optimization for Large-scale Optimal Transport, + Advances in Neural Information Processing Systems (2016), + arXiv preprint arxiv:1605.08527. + ''' + + if lr is None: + lr = 1. / max(a / reg) + n_source = np.shape(M)[0] + n_target = np.shape(M)[1] + cur_beta = np.zeros(n_target) + ave_beta = np.zeros(n_target) + for cur_iter in range(numItermax): + k = cur_iter + 1 + i = np.random.randint(n_source) + cur_coord_grad = coordinate_grad_semi_dual(b, M, reg, cur_beta, i) + cur_beta += (lr / np.sqrt(k)) * cur_coord_grad + ave_beta = (1. / k) * cur_beta + (1 - 1. / k) * ave_beta + return ave_beta + + +def c_transform_entropic(b, M, reg, beta): + r''' + The goal is to recover u from the c-transform. + + The function computes the c_transform of a dual variable from the other + dual variable: + + .. 
math:: + u = v^{c,reg} = -reg \sum_j exp((v - M)/reg) b_j + + Where : + + - M is the (ns,nt) metric cost matrix + - u, v are dual variables in R^IxR^J + - reg is the regularization term + + It is used to recover an optimal u from optimal v solving the semi dual + problem, see Proposition 2.1 of [18]_ + + + Parameters + ---------- + b : ndarray, shape (nt,) + Target measure + M : ndarray, shape (ns, nt) + Cost matrix + reg : float + Regularization term > 0 + v : ndarray, shape (nt,) + Dual variable. + + Returns + ------- + u : ndarray, shape (ns,) + Dual variable. + + Examples + -------- + >>> import ot + >>> np.random.seed(0) + >>> n_source = 7 + >>> n_target = 4 + >>> a = ot.utils.unif(n_source) + >>> b = ot.utils.unif(n_target) + >>> X_source = np.random.randn(n_source, 2) + >>> Y_target = np.random.randn(n_target, 2) + >>> M = ot.dist(X_source, Y_target) + >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000) + array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06], + [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03], + [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07], + [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04], + [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01], + [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01], + [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]]) + + References + ---------- + + [Genevay et al., 2016] : + Stochastic Optimization for Large-scale Optimal Transport, + Advances in Neural Information Processing Systems (2016), + arXiv preprint arxiv:1605.08527. + ''' + + n_source = np.shape(M)[0] + alpha = np.zeros(n_source) + for i in range(n_source): + r = M[i, :] - beta + min_r = np.min(r) + exp_beta = np.exp(-(r - min_r) / reg) * b + alpha[i] = min_r - reg * np.log(np.sum(exp_beta)) + return alpha + + +def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None, + log=False): + r''' + Compute the transportation matrix to solve the regularized discrete + measures optimal transport max problem + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. 
\gamma 1 = a
+
+             \gamma^T 1 = b
+
+             \gamma \geq 0
+
+    Where :
+
+    - M is the (ns,nt) metric cost matrix
+    - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+    - a and b are source and target weights (sum to 1)
+
+    The algorithms used for solving the problem are the SAG and ASGD algorithms
+    as proposed in [18]_
+
+
+    Parameters
+    ----------
+    a : ndarray, shape (ns,)
+        source measure
+    b : ndarray, shape (nt,)
+        target measure
+    M : ndarray, shape (ns, nt)
+        cost matrix
+    reg : float
+        Regularization term > 0
+    method : str
+        Name of the method used: 'SAG' or 'ASGD'
+    numItermax : int
+        Number of iterations
+    lr : float
+        Learning rate
+    log : bool, optional
+        record log if True
+
+    Returns
+    -------
+    pi : ndarray, shape (ns, nt)
+        transportation matrix
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+    Examples
+    --------
+    >>> import ot
+    >>> np.random.seed(0)
+    >>> n_source = 7
+    >>> n_target = 4
+    >>> a = ot.utils.unif(n_source)
+    >>> b = ot.utils.unif(n_target)
+    >>> X_source = np.random.randn(n_source, 2)
+    >>> Y_target = np.random.randn(n_target, 2)
+    >>> M = ot.dist(X_source, Y_target)
+    >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000)
+    array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06],
+           [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03],
+           [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07],
+           [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04],
+           [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01],
+           [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
+           [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
+
+    References
+    ----------
+
+    [Genevay et al., 2016] :
+    Stochastic Optimization for Large-scale Optimal Transport,
+    Advances in Neural Information Processing Systems (2016),
+    arXiv preprint arxiv:1605.08527.
+    '''
+
+    if method.lower() == "sag":
+        opt_beta = sag_entropic_transport(a, b, M, reg, numItermax, lr)
+    elif method.lower() == "asgd":
+        opt_beta = averaged_sgd_entropic_transport(a, b, M, reg, numItermax, lr)
+    else:
+        raise ValueError("Unknown method '%s'. Use 'SAG' or 'ASGD'." % method)
+
+    opt_alpha = c_transform_entropic(b, M, reg, opt_beta)
+    pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg) *
+          a[:, None] * b[None, :])
+
+    if log:
+        log = {}
+        log['alpha'] = opt_alpha
+        log['beta'] = opt_beta
+        return pi, log
+    else:
+        return pi
+
+
+##############################################################################
+# Optimization toolbox for DUAL problems
+##############################################################################
+
+
+def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
+                    batch_beta):
+    r'''
+    Computes the partial gradient of the dual optimal transport problem.
+
+    For each (i,j) in a batch of coordinates, the partial gradients are:
+
+    ..
math:: + \partial_{u_i} F = u_i * b_s/l_{v} - \sum_{j \in B_v} exp((u_i + v_j - M_{i,j})/reg) * a_i * b_j + + \partial_{v_j} F = v_j * b_s/l_{u} - \sum_{i \in B_u} exp((u_i + v_j - M_{i,j})/reg) * a_i * b_j + + Where : + + - M is the (ns,nt) metric cost matrix + - u, v are dual variables in R^ixR^J + - reg is the regularization term + - :math:`B_u` and :math:`B_v` are lists of index + - :math:`b_s` is the size of the batchs :math:`B_u` and :math:`B_v` + - :math:`l_u` and :math:`l_v` are the lenghts of :math:`B_u` and :math:`B_v` + - a and b are source and target weights (sum to 1) + + + The algorithm used for solving the dual problem is the SGD algorithm + as proposed in [19]_ [alg.1] + + + Parameters + ---------- + a : ndarray, shape (ns,) + source measure + b : ndarray, shape (nt,) + target measure + M : ndarray, shape (ns, nt) + cost matrix + reg : float + Regularization term > 0 + alpha : ndarray, shape (ns,) + dual variable + beta : ndarray, shape (nt,) + dual variable + batch_size : int + size of the batch + batch_alpha : ndarray, shape (bs,) + batch of index of alpha + batch_beta : ndarray, shape (bs,) + batch of index of beta + + Returns + ------- + grad : ndarray, shape (ns,) + partial grad F + + Examples + -------- + >>> import ot + >>> np.random.seed(0) + >>> n_source = 7 + >>> n_target = 4 + >>> a = ot.utils.unif(n_source) + >>> b = ot.utils.unif(n_target) + >>> X_source = np.random.randn(n_source, 2) + >>> Y_target = np.random.randn(n_target, 2) + >>> M = ot.dist(X_source, Y_target) + >>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg=1, batch_size=3, numItermax=30000, lr=0.1, log=True) + >>> log['alpha'] + array([0.71759102, 1.57057384, 0.85576566, 0.1208211 , 0.59190466, + 1.197148 , 0.17805133]) + >>> log['beta'] + array([0.49741367, 0.57478564, 1.40075528, 2.75890102]) + >>> sgd_dual_pi + array([[2.09730063e-02, 8.38169324e-02, 7.50365455e-03, 8.72731415e-09], + [5.58432437e-03, 5.89881299e-04, 3.09558411e-05, 8.35469849e-07], + [3.26489515e-03, 7.15536035e-02, 2.99778211e-02, 3.02601593e-10], + [4.05390622e-02, 5.31085068e-02, 6.65191787e-02, 1.55812785e-06], + [7.82299812e-02, 6.12099102e-03, 4.44989098e-02, 2.37719187e-03], + [5.06266486e-02, 2.16230494e-03, 2.26215141e-03, 6.81514609e-04], + [6.06713990e-02, 3.98139808e-02, 5.46829338e-02, 8.62371424e-06]]) + + References + ---------- + + [Seguy et al., 2018] : + International Conference on Learning Representation (2018), + arXiv preprint arxiv:1711.02283. + ''' + G = - (np.exp((alpha[batch_alpha, None] + beta[None, batch_beta] - + M[batch_alpha, :][:, batch_beta]) / reg) * + a[batch_alpha, None] * b[None, batch_beta]) + grad_beta = np.zeros(np.shape(M)[1]) + grad_alpha = np.zeros(np.shape(M)[0]) + grad_beta[batch_beta] = (b[batch_beta] * len(batch_alpha) / np.shape(M)[0] + + G.sum(0)) + grad_alpha[batch_alpha] = (a[batch_alpha] * len(batch_beta) + / np.shape(M)[1] + G.sum(1)) + + return grad_alpha, grad_beta + + +def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr): + r''' + Compute the sgd algorithm to solve the regularized discrete measures + optimal transport dual problem + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. 
\gamma 1 = a + + \gamma^T 1= b + + \gamma \geq 0 + + Where : + + - M is the (ns,nt) metric cost matrix + - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights (sum to 1) + + Parameters + ---------- + a : ndarray, shape (ns,) + source measure + b : ndarray, shape (nt,) + target measure + M : ndarray, shape (ns, nt) + cost matrix + reg : float + Regularization term > 0 + batch_size : int + size of the batch + numItermax : int + number of iteration + lr : float + learning rate + + Returns + ------- + alpha : ndarray, shape (ns,) + dual variable + beta : ndarray, shape (nt,) + dual variable + + Examples + -------- + >>> import ot + >>> n_source = 7 + >>> n_target = 4 + >>> reg = 1 + >>> numItermax = 20000 + >>> lr = 0.1 + >>> batch_size = 3 + >>> log = True + >>> a = ot.utils.unif(n_source) + >>> b = ot.utils.unif(n_target) + >>> rng = np.random.RandomState(0) + >>> X_source = rng.randn(n_source, 2) + >>> Y_target = rng.randn(n_target, 2) + >>> M = ot.dist(X_source, Y_target) + >>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, numItermax, lr, log) + >>> log['alpha'] + array([0.64171798, 1.27932201, 0.78132257, 0.15638935, 0.54888354, + 1.03663469, 0.20595781]) + >>> log['beta'] + array([0.51207194, 0.58033189, 1.28922676, 2.26859736]) + >>> sgd_dual_pi + array([[1.97276541e-02, 7.81248547e-02, 6.22136048e-03, 4.95442423e-09], + [4.23494310e-03, 4.43286263e-04, 2.06927079e-05, 3.82389139e-07], + [3.07542414e-03, 6.67897769e-02, 2.48904999e-02, 1.72030247e-10], + [4.26271990e-02, 5.53375455e-02, 6.16535024e-02, 9.88812650e-07], + [7.60423265e-02, 5.89585256e-03, 3.81267087e-02, 1.39458256e-03], + [4.37557504e-02, 1.85189176e-03, 1.72335760e-03, 3.55491279e-04], + [6.33096109e-02, 4.11683954e-02, 5.02962051e-02, 5.43097516e-06]]) + + References + ---------- + [Seguy et al., 2018] : + International Conference on Learning Representation (2018), + arXiv preprint arxiv:1711.02283. + ''' + + n_source = np.shape(M)[0] + n_target = np.shape(M)[1] + cur_alpha = np.zeros(n_source) + cur_beta = np.zeros(n_target) + for cur_iter in range(numItermax): + k = np.sqrt(cur_iter + 1) + batch_alpha = np.random.choice(n_source, batch_size, replace=False) + batch_beta = np.random.choice(n_target, batch_size, replace=False) + update_alpha, update_beta = batch_grad_dual(a, b, M, reg, cur_alpha, + cur_beta, batch_size, + batch_alpha, batch_beta) + cur_alpha[batch_alpha] += (lr / k) * update_alpha[batch_alpha] + cur_beta[batch_beta] += (lr / k) * update_beta[batch_beta] + + return cur_alpha, cur_beta + + +def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1, + log=False): + r''' + Compute the transportation matrix to solve the regularized discrete measures + optimal transport dual problem + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. 
\gamma 1 = a + + \gamma^T 1= b + + \gamma \geq 0 + + Where : + + - M is the (ns,nt) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights (sum to 1) + + Parameters + ---------- + a : ndarray, shape (ns,) + source measure + b : ndarray, shape (nt,) + target measure + M : ndarray, shape (ns, nt) + cost matrix + reg : float + Regularization term > 0 + batch_size : int + size of the batch + numItermax : int + number of iteration + lr : float + learning rate + log : bool, optional + record log if True + + Returns + ------- + pi : ndarray, shape (ns, nt) + transportation matrix + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + >>> import ot + >>> n_source = 7 + >>> n_target = 4 + >>> reg = 1 + >>> numItermax = 20000 + >>> lr = 0.1 + >>> batch_size = 3 + >>> log = True + >>> a = ot.utils.unif(n_source) + >>> b = ot.utils.unif(n_target) + >>> rng = np.random.RandomState(0) + >>> X_source = rng.randn(n_source, 2) + >>> Y_target = rng.randn(n_target, 2) + >>> M = ot.dist(X_source, Y_target) + >>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, numItermax, lr, log) + >>> log['alpha'] + array([0.64057733, 1.2683513 , 0.75610161, 0.16024284, 0.54926534, + 1.0514201 , 0.19958936]) + >>> log['beta'] + array([0.51372571, 0.58843489, 1.27993921, 2.24344807]) + >>> sgd_dual_pi + array([[1.97377795e-02, 7.86706853e-02, 6.15682001e-03, 4.82586997e-09], + [4.19566963e-03, 4.42016865e-04, 2.02777272e-05, 3.68823708e-07], + [3.00379244e-03, 6.56562018e-02, 2.40462171e-02, 1.63579656e-10], + [4.28626062e-02, 5.60031599e-02, 6.13193826e-02, 9.67977735e-07], + [7.61972739e-02, 5.94609051e-03, 3.77886693e-02, 1.36046648e-03], + [4.44810042e-02, 1.89476742e-03, 1.73285847e-03, 3.51826036e-04], + [6.30118293e-02, 4.12398660e-02, 4.95148998e-02, 5.26247246e-06]]) + + References + ---------- + + [Seguy et al., 2018] : + International Conference on Learning Representation (2018), + arXiv preprint arxiv:1711.02283. + ''' + + opt_alpha, opt_beta = sgd_entropic_regularization(a, b, M, reg, batch_size, + numItermax, lr) + pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg) * + a[:, None] * b[None, :]) + if log: + log = {} + log['alpha'] = opt_alpha + log['beta'] = opt_beta + return pi, log + else: + return pi diff --git a/ot/unbalanced.py b/ot/unbalanced.py new file mode 100644 index 0000000..d516dfc --- /dev/null +++ b/ot/unbalanced.py @@ -0,0 +1,1022 @@ +# -*- coding: utf-8 -*- +""" +Regularized Unbalanced OT +""" + +# Author: Hicham Janati <hicham.janati@inria.fr> +# License: MIT License + +from __future__ import division +import warnings +import numpy as np +from scipy.special import logsumexp + +# from .utils import unif, dist + + +def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, + stopThr=1e-6, verbose=False, log=False, **kwargs): + r""" + Solve the unbalanced entropic regularization optimal transport problem + and return the OT plan + + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b) + + s.t. 
+ \gamma\geq 0 + where : + + - M is the (dim_a, dim_b) metric cost matrix + - :math:`\Omega` is the entropic regularization + term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target unbalanced distributions + - KL is the Kullback-Leibler divergence + + The algorithm used for solving the problem is the generalized + Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + + + Parameters + ---------- + a : np.ndarray (dim_a,) + Unnormalized histogram of dimension dim_a + b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + One or multiple unnormalized histograms of dimension dim_b + If many, compute all the OT distances (a, b_i) + M : np.ndarray (dim_a, dim_b) + loss matrix + reg : float + Entropy regularization term > 0 + reg_m: float + Marginal relaxation term > 0 + method : str + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + 'sinkhorn_reg_scaling', see those function for specific parameters + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + if n_hists == 1: + gamma : (dim_a x dim_b) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + else: + ot_distance : (n_hists,) ndarray + the OT distance between `a` and each of the histograms `b_i` + log : dict + log dictionary returned only if `log` is `True` + + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.], [1., 0.]] + >>> ot.sinkhorn_unbalanced(a, b, M, 1, 1) + array([[0.51122823, 0.18807035], + [0.18807035, 0.51122823]]) + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems + (NIPS) 26, 2013 + + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for + Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint + arXiv:1607.05816. + + .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : + Learning with a Wasserstein Loss, Advances in Neural Information + Processing Systems (NIPS) 2015 + + + See Also + -------- + ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10] + ot.unbalanced.sinkhorn_stabilized_unbalanced: + Unbalanced Stabilized sinkhorn [9][10] + ot.unbalanced.sinkhorn_reg_scaling_unbalanced: + Unbalanced Sinkhorn with epslilon scaling [9][10] + + """ + + if method.lower() == 'sinkhorn': + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + + elif method.lower() == 'sinkhorn_stabilized': + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: + warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError("Unknown method '%s'." 
% method) + + +def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', + numItermax=1000, stopThr=1e-6, verbose=False, + log=False, **kwargs): + r""" + Solve the entropic regularization unbalanced optimal transport problem and + return the loss + + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b) + + s.t. + \gamma\geq 0 + where : + + - M is the (dim_a, dim_b) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target unbalanced distributions + - KL is the Kullback-Leibler divergence + + The algorithm used for solving the problem is the generalized + Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + + + Parameters + ---------- + a : np.ndarray (dim_a,) + Unnormalized histogram of dimension dim_a + b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + One or multiple unnormalized histograms of dimension dim_b + If many, compute all the OT distances (a, b_i) + M : np.ndarray (dim_a, dim_b) + loss matrix + reg : float + Entropy regularization term > 0 + reg_m: float + Marginal relaxation term > 0 + method : str + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + 'sinkhorn_reg_scaling', see those function for specific parameters + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + ot_distance : (n_hists,) ndarray + the OT distance between `a` and each of the histograms `b_i` + log : dict + log dictionary returned only if `log` is `True` + + Examples + -------- + + >>> import ot + >>> a=[.5, .10] + >>> b=[.5, .5] + >>> M=[[0., 1.],[1., 0.]] + >>> ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.) + array([0.31912866]) + + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems + (NIPS) 26, 2013 + + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for + Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint + arXiv:1607.05816. + + .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : + Learning with a Wasserstein Loss, Advances in Neural Information + Processing Systems (NIPS) 2015 + + See Also + -------- + ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10] + ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10] + ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] + + """ + b = np.asarray(b, dtype=np.float64) + if len(b.shape) < 2: + b = b[:, None] + if method.lower() == 'sinkhorn': + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + + elif method.lower() == 'sinkhorn_stabilized': + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: + warnings.warn('Method not implemented yet. 
Using classic Sinkhorn Knopp')
+        return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+                                         numItermax=numItermax,
+                                         stopThr=stopThr, verbose=verbose,
+                                         log=log, **kwargs)
+    else:
+        raise ValueError('Unknown method %s.' % method)
+
+
+def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
+                              stopThr=1e-6, verbose=False, log=False, **kwargs):
+    r"""
+    Solve the entropic regularization unbalanced optimal transport problem and return the loss
+
+    The function solves the following optimization problem:
+
+    .. math::
+        W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
+
+        s.t.
+             \gamma\geq 0
+    where :
+
+    - M is the (dim_a, dim_b) metric cost matrix
+    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+    - a and b are source and target unbalanced distributions
+    - KL is the Kullback-Leibler divergence
+
+    The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+
+
+    Parameters
+    ----------
+    a : np.ndarray (dim_a,)
+        Unnormalized histogram of dimension dim_a
+    b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+        One or multiple unnormalized histograms of dimension dim_b
+        If many, compute all the OT distances (a, b_i)
+    M : np.ndarray (dim_a, dim_b)
+        loss matrix
+    reg : float
+        Entropy regularization term > 0
+    reg_m: float
+        Marginal relaxation term > 0
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (> 0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    if n_hists == 1:
+        gamma : (dim_a x dim_b) ndarray
+            Optimal transportation matrix for the given parameters
+        log : dict
+            log dictionary returned only if `log` is `True`
+    else:
+        ot_distance : (n_hists,) ndarray
+            the OT distance between `a` and each of the histograms `b_i`
+        log : dict
+            log dictionary returned only if `log` is `True`
+
+    Examples
+    --------
+
+    >>> import ot
+    >>> a=[.5, .5]
+    >>> b=[.5, .5]
+    >>> M=[[0., 1.],[1., 0.]]
+    >>> ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.)
+    array([[0.51122823, 0.18807035],
+           [0.18807035, 0.51122823]])
+
+    References
+    ----------
+
+    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+        Scaling algorithms for unbalanced transport problems. arXiv preprint
+        arXiv:1607.05816.
+
+    .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T.
: + Learning with a Wasserstein Loss, Advances in Neural Information + Processing Systems (NIPS) 2015 + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64) + + dim_a, dim_b = M.shape + + if len(a) == 0: + a = np.ones(dim_a, dtype=np.float64) / dim_a + if len(b) == 0: + b = np.ones(dim_b, dtype=np.float64) / dim_b + + if len(b.shape) > 1: + n_hists = b.shape[1] + else: + n_hists = 0 + + if log: + log = {'err': []} + + # we assume that no distances are null except those of the diagonal of + # distances + if n_hists: + u = np.ones((dim_a, 1)) / dim_a + v = np.ones((dim_b, n_hists)) / dim_b + a = a.reshape(dim_a, 1) + else: + u = np.ones(dim_a) / dim_a + v = np.ones(dim_b) / dim_b + + # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + + fi = reg_m / (reg_m + reg) + + cpt = 0 + err = 1. + + while (err > stopThr and cpt < numItermax): + uprev = u + vprev = v + + Kv = K.dot(v) + u = (a / Kv) ** fi + Ktu = K.T.dot(u) + v = (b / Ktu) ** fi + + if (np.any(Ktu == 0.) + or np.any(np.isnan(u)) or np.any(np.isnan(v)) + or np.any(np.isinf(u)) or np.any(np.isinf(v))): + # we have reached the machine precision + # come back to previous solution and quit loop + warnings.warn('Numerical errors at iteration %s' % cpt) + u = uprev + v = vprev + break + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.) + err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.) + err = 0.5 * (err_u + err_v) + if log: + log['err'].append(err) + if verbose: + if cpt % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + cpt += 1 + + if log: + log['logu'] = np.log(u + 1e-16) + log['logv'] = np.log(v + 1e-16) + + if n_hists: # return only loss + res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) + if log: + return res, log + else: + return res + + else: # return OT matrix + + if log: + return u[:, None] * K * v[None, :], log + else: + return u[:, None] * K * v[None, :] + + +def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000, + stopThr=1e-6, verbose=False, log=False, + **kwargs): + r""" + Solve the entropic regularization unbalanced optimal transport + problem and return the loss + + The function solves the following optimization problem using log-domain + stabilization as proposed in [10]: + + .. math:: + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b) + + s.t. 
+ \gamma\geq 0 + where : + + - M is the (dim_a, dim_b) metric cost matrix + - :math:`\Omega` is the entropic regularization + term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target unbalanced distributions + - KL is the Kullback-Leibler divergence + + The algorithm used for solving the problem is the generalized + Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + + + Parameters + ---------- + a : np.ndarray (dim_a,) + Unnormalized histogram of dimension dim_a + b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + One or multiple unnormalized histograms of dimension dim_b + If many, compute all the OT distances (a, b_i) + M : np.ndarray (dim_a, dim_b) + loss matrix + reg : float + Entropy regularization term > 0 + reg_m: float + Marginal relaxation term > 0 + tau : float + thershold for max value in u or v for log scaling + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + if n_hists == 1: + gamma : (dim_a x dim_b) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + else: + ot_distance : (n_hists,) ndarray + the OT distance between `a` and each of the histograms `b_i` + log : dict + log dictionary returned only if `log` is `True` + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.],[1., 0.]] + >>> ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.) + array([[0.51122823, 0.18807035], + [0.18807035, 0.51122823]]) + + References + ---------- + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + + .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : + Learning with a Wasserstein Loss, Advances in Neural Information + Processing Systems (NIPS) 2015 + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64) + + dim_a, dim_b = M.shape + + if len(a) == 0: + a = np.ones(dim_a, dtype=np.float64) / dim_a + if len(b) == 0: + b = np.ones(dim_b, dtype=np.float64) / dim_b + + if len(b.shape) > 1: + n_hists = b.shape[1] + else: + n_hists = 0 + + if log: + log = {'err': []} + + # we assume that no distances are null except those of the diagonal of + # distances + if n_hists: + u = np.ones((dim_a, n_hists)) / dim_a + v = np.ones((dim_b, n_hists)) / dim_b + a = a.reshape(dim_a, 1) + else: + u = np.ones(dim_a) / dim_a + v = np.ones(dim_b) / dim_b + + # print(reg) + # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + + fi = reg_m / (reg_m + reg) + + cpt = 0 + err = 1. 
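+    # alpha and beta below are log-domain absorption vectors: whenever a
+    # scaling vector u or v exceeds the threshold tau, its magnitude is
+    # absorbed into alpha/beta and the kernel K is recomputed, which keeps
+    # u and v in a numerically safe range (log-stabilization of [10]).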
+ alpha = np.zeros(dim_a) + beta = np.zeros(dim_b) + while (err > stopThr and cpt < numItermax): + uprev = u + vprev = v + + Kv = K.dot(v) + f_alpha = np.exp(- alpha / (reg + reg_m)) + f_beta = np.exp(- beta / (reg + reg_m)) + + if n_hists: + f_alpha = f_alpha[:, None] + f_beta = f_beta[:, None] + u = ((a / (Kv + 1e-16)) ** fi) * f_alpha + Ktu = K.T.dot(u) + v = ((b / (Ktu + 1e-16)) ** fi) * f_beta + absorbing = False + if (u > tau).any() or (v > tau).any(): + absorbing = True + if n_hists: + alpha = alpha + reg * np.log(np.max(u, 1)) + beta = beta + reg * np.log(np.max(v, 1)) + else: + alpha = alpha + reg * np.log(np.max(u)) + beta = beta + reg * np.log(np.max(v)) + K = np.exp((alpha[:, None] + beta[None, :] - + M) / reg) + v = np.ones_like(v) + Kv = K.dot(v) + + if (np.any(Ktu == 0.) + or np.any(np.isnan(u)) or np.any(np.isnan(v)) + or np.any(np.isinf(u)) or np.any(np.isinf(v))): + # we have reached the machine precision + # come back to previous solution and quit loop + warnings.warn('Numerical errors at iteration %s' % cpt) + u = uprev + v = vprev + break + if (cpt % 10 == 0 and not absorbing) or cpt == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), + 1.) + if log: + log['err'].append(err) + if verbose: + if cpt % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + cpt = cpt + 1 + + if err > stopThr: + warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." + + "Try a larger entropy `reg` or a lower mass `reg_m`." + + "Or a larger absorption threshold `tau`.") + if n_hists: + logu = alpha[:, None] / reg + np.log(u) + logv = beta[:, None] / reg + np.log(v) + else: + logu = alpha / reg + np.log(u) + logv = beta / reg + np.log(v) + if log: + log['logu'] = logu + log['logv'] = logv + if n_hists: # return only loss + res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] + + logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1)) + res = np.exp(res) + if log: + return res, log + else: + return res + + else: # return OT matrix + ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg) + if log: + return ot_matrix, log + else: + return ot_matrix + + +def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, + numItermax=1000, stopThr=1e-6, + verbose=False, log=False): + r"""Compute the entropic unbalanced wasserstein barycenter of A with stabilization. + + The function solves the following optimization problem: + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized + Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced) + - :math:`\mathbf{a}_i` are training distributions in the columns of + matrix :math:`\mathbf{A}` + - reg and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT + - reg_mis the marginal relaxation hyperparameter + The algorithm used for solving the problem is the generalized + Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + + Parameters + ---------- + A : np.ndarray (dim, n_hists) + `n_hists` training distributions a_i of dimension dim + M : np.ndarray (dim, dim) + ground metric matrix for OT. + reg : float + Entropy regularization term > 0 + reg_m : float + Marginal relaxation term > 0 + tau : float + Stabilization threshold for log domain absorption. 
+ weights : np.ndarray (n_hists,) optional + Weight of each distribution (barycentric coodinates) + If None, uniform weights are used. + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (> 0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + a : (dim,) ndarray + Unbalanced Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, + G. (2015). Iterative Bregman projections for regularized transportation + problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint + arXiv:1607.05816. + + + """ + dim, n_hists = A.shape + if weights is None: + weights = np.ones(n_hists) / n_hists + else: + assert(len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + fi = reg_m / (reg_m + reg) + + u = np.ones((dim, n_hists)) / dim + v = np.ones((dim, n_hists)) / dim + + # print(reg) + # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + + fi = reg_m / (reg_m + reg) + + cpt = 0 + err = 1. + alpha = np.zeros(dim) + beta = np.zeros(dim) + q = np.ones(dim) / dim + while (err > stopThr and cpt < numItermax): + qprev = q + Kv = K.dot(v) + f_alpha = np.exp(- alpha / (reg + reg_m)) + f_beta = np.exp(- beta / (reg + reg_m)) + f_alpha = f_alpha[:, None] + f_beta = f_beta[:, None] + u = ((A / (Kv + 1e-16)) ** fi) * f_alpha + Ktu = K.T.dot(u) + q = (Ktu ** (1 - fi)) * f_beta + q = q.dot(weights) ** (1 / (1 - fi)) + Q = q[:, None] + v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta + absorbing = False + if (u > tau).any() or (v > tau).any(): + absorbing = True + alpha = alpha + reg * np.log(np.max(u, 1)) + beta = beta + reg * np.log(np.max(v, 1)) + K = np.exp((alpha[:, None] + beta[None, :] - + M) / reg) + v = np.ones_like(v) + Kv = K.dot(v) + if (np.any(Ktu == 0.) + or np.any(np.isnan(u)) or np.any(np.isnan(v)) + or np.any(np.isinf(u)) or np.any(np.isinf(v))): + # we have reached the machine precision + # come back to previous solution and quit loop + warnings.warn('Numerical errors at iteration %s' % cpt) + q = qprev + break + if (cpt % 10 == 0 and not absorbing) or cpt == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = abs(q - qprev).max() / max(abs(q).max(), + abs(qprev).max(), 1.) + if log: + log['err'].append(err) + if verbose: + if cpt % 50 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + cpt += 1 + if err > stopThr: + warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." + + "Try a larger entropy `reg` or a lower mass `reg_m`." + + "Or a larger absorption threshold `tau`.") + if log: + log['niter'] = cpt + log['logu'] = np.log(u + 1e-16) + log['logv'] = np.log(v + 1e-16) + return q, log + else: + return q + + +def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, + numItermax=1000, stopThr=1e-6, + verbose=False, log=False): + r"""Compute the entropic unbalanced wasserstein barycenter of A. + + The function solves the following optimization problem with a + + .. 
+
+
+def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
+                                   numItermax=1000, stopThr=1e-6,
+                                   verbose=False, log=False):
+    r"""Compute the entropic unbalanced Wasserstein barycenter of A.
+
+    The function solves the following optimization problem:
+
+    .. math::
+       \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+
+    where :
+
+    - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+      Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+    - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+      :math:`\mathbf{A}`
+    - reg and :math:`\mathbf{M}` are respectively the regularization term and
+      the cost matrix for OT
+    - reg_m is the marginal relaxation hyperparameter
+
+    The algorithm used for solving the problem is the generalized
+    Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+    Parameters
+    ----------
+    A : np.ndarray (dim, n_hists)
+        `n_hists` training distributions a_i of dimension dim
+    M : np.ndarray (dim, dim)
+        ground metric matrix for OT.
+    reg : float
+        Entropy regularization term > 0
+    reg_m : float
+        Marginal relaxation term > 0
+    weights : np.ndarray (n_hists,) optional
+        Weight of each distribution (barycentric coordinates)
+        If None, uniform weights are used.
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (> 0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+    Returns
+    -------
+    a : (dim,) ndarray
+        Unbalanced Wasserstein barycenter
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+    References
+    ----------
+
+    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
+        (2015). Iterative Bregman projections for regularized transportation
+        problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+        Scaling algorithms for unbalanced transport problems. arXiv preprint
+        arXiv:1607.05816.
+
+    """
+    dim, n_hists = A.shape
+    if weights is None:
+        weights = np.ones(n_hists) / n_hists
+    else:
+        assert(len(weights) == A.shape[1])
+
+    if log:
+        log = {'err': []}
+
+    K = np.exp(- M / reg)
+
+    fi = reg_m / (reg_m + reg)
+
+    v = np.ones((dim, n_hists)) / dim
+    u = np.ones((dim, 1)) / dim
+
+    cpt = 0
+    err = 1.
+
+    while (err > stopThr and cpt < numItermax):
+        uprev = u
+        vprev = v
+
+        Kv = K.dot(v)
+        u = (A / Kv) ** fi
+        Ktu = K.T.dot(u)
+        q = ((Ktu ** (1 - fi)).dot(weights))
+        q = q ** (1 / (1 - fi))
+        Q = q[:, None]
+        v = (Q / Ktu) ** fi
+
+        if (np.any(Ktu == 0.)
+                or np.any(np.isnan(u)) or np.any(np.isnan(v))
+                or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+            # we have reached the machine precision
+            # come back to previous solution and quit loop
+            warnings.warn('Numerical errors at iteration %s' % cpt)
+            u = uprev
+            v = vprev
+            break
+        if cpt % 10 == 0:
+            # we can speed up the process by checking for the error only
+            # every 10th iteration
+            err_u = abs(u - uprev).max()
+            err_u /= max(abs(u).max(), abs(uprev).max(), 1.)
+            err_v = abs(v - vprev).max()
+            err_v /= max(abs(v).max(), abs(vprev).max(), 1.)
+            err = 0.5 * (err_u + err_v)
+            if log:
+                log['err'].append(err)
+            if verbose:
+                if cpt % 50 == 0:
+                    print(
+                        '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+                print('{:5d}|{:8e}|'.format(cpt, err))
+
+        cpt += 1
+    if log:
+        log['niter'] = cpt
+        log['logu'] = np.log(u + 1e-16)
+        log['logv'] = np.log(v + 1e-16)
+        return q, log
+    else:
+        return q
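Since the plain iteration above has no absorption step, it can hit numerical trouble for small `reg`; an editorial sketch (not part of the patch) of monitoring its convergence through the returned log, with illustrative inputs:

    import numpy as np
    import ot

    A = np.array([[0.2, 0.3, 0.3, 0.2],
                  [0.1, 0.2, 0.3, 0.4]]).T
    M = ot.utils.dist0(4)
    M /= M.max()
    q, log = ot.unbalanced.barycenter_unbalanced_sinkhorn(A, M, reg=0.05,
                                                          reg_m=1.0, log=True)
    print(log['niter'], log['err'][-1])   # iterations used and final error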
+
+
+def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
+                          numItermax=1000, stopThr=1e-6,
+                          verbose=False, log=False, **kwargs):
+    r"""Compute the entropic unbalanced Wasserstein barycenter of A.
+
+    The function solves the following optimization problem:
+
+    .. math::
+       \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+
+    where :
+
+    - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+      Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+    - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+      :math:`\mathbf{A}`
+    - reg and :math:`\mathbf{M}` are respectively the regularization term and
+      the cost matrix for OT
+    - reg_m is the marginal relaxation hyperparameter
+
+    The algorithm used for solving the problem is the generalized
+    Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+    Parameters
+    ----------
+    A : np.ndarray (dim, n_hists)
+        `n_hists` training distributions a_i of dimension dim
+    M : np.ndarray (dim, dim)
+        ground metric matrix for OT.
+    reg : float
+        Entropy regularization term > 0
+    reg_m : float
+        Marginal relaxation term > 0
+    method : str, optional
+        Solver to use: 'sinkhorn' or 'sinkhorn_stabilized'
+    weights : np.ndarray (n_hists,) optional
+        Weight of each distribution (barycentric coordinates)
+        If None, uniform weights are used.
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (> 0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+    Returns
+    -------
+    a : (dim,) ndarray
+        Unbalanced Wasserstein barycenter
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+    References
+    ----------
+
+    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
+        (2015). Iterative Bregman projections for regularized transportation
+        problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+        Scaling algorithms for unbalanced transport problems. arXiv preprint
+        arXiv:1607.05816.
+
+    """
+
+    if method.lower() == 'sinkhorn':
+        return barycenter_unbalanced_sinkhorn(A, M, reg, reg_m,
+                                              weights=weights,
+                                              numItermax=numItermax,
+                                              stopThr=stopThr, verbose=verbose,
+                                              log=log, **kwargs)
+
+    elif method.lower() == 'sinkhorn_stabilized':
+        return barycenter_unbalanced_stabilized(A, M, reg, reg_m,
+                                                weights=weights,
+                                                numItermax=numItermax,
+                                                stopThr=stopThr,
+                                                verbose=verbose,
+                                                log=log, **kwargs)
+    elif method.lower() in ['sinkhorn_reg_scaling']:
+        warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp.')
+        return barycenter_unbalanced(A, M, reg, reg_m,
+                                     weights=weights,
+                                     numItermax=numItermax,
+                                     stopThr=stopThr, verbose=verbose,
+                                     log=log, **kwargs)
+    else:
+        raise ValueError("Unknown method '%s'." % method)
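Editorial sketch (not part of the patch) of the dispatcher: the same barycenter computed with both implemented back-ends, with illustrative inputs as in the previous sketches.

    import numpy as np
    import ot

    A = np.array([[0.2, 0.3, 0.3, 0.2],
                  [0.1, 0.2, 0.3, 0.4]]).T
    M = ot.utils.dist0(4)
    M /= M.max()
    for method in ('sinkhorn', 'sinkhorn_stabilized'):
        q = ot.unbalanced.barycenter_unbalanced(A, M, reg=0.05, reg_m=1.0,
                                                method=method)
        print(method, q.sum())   # the two back-ends should roughly agree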
diff --git a/ot/utils.py b/ot/utils.py
new file mode 100644
index 0000000..b71458b
--- /dev/null
+++ b/ot/utils.py
@@ -0,0 +1,498 @@
+# -*- coding: utf-8 -*-
+"""
+Various useful functions
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import multiprocessing
+from functools import reduce
+import time
+
+import numpy as np
+from scipy.spatial.distance import cdist
+import sys
+import warnings
+try:
+    from inspect import signature
+except ImportError:
+    from .externals.funcsigs import signature
+
+__time_tic_toc = time.time()
+
+
+def tic():
+    """ Python implementation of Matlab tic() function """
+    global __time_tic_toc
+    __time_tic_toc = time.time()
+
+
+def toc(message='Elapsed time : {} s'):
+    """ Python implementation of Matlab toc() function """
+    t = time.time()
+    print(message.format(t - __time_tic_toc))
+    return t - __time_tic_toc
+
+
+def toq():
+    """ Python implementation of Julia toq() function """
+    t = time.time()
+    return t - __time_tic_toc
+
+
+def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
+    """Compute kernel matrix"""
+    if method.lower() in ['gaussian', 'gauss', 'rbf']:
+        K = np.exp(-dist(x1, x2) / (2 * sigma**2))
+    return K
+
+
+def unif(n):
+    """ return a uniform histogram of length n (simplex)
+
+    Parameters
+    ----------
+
+    n : int
+        number of bins in the histogram
+
+    Returns
+    -------
+    h : np.array (n,)
+        histogram of length n such that h_i=1/n for all i
+
+    """
+    return np.ones((n,)) / n
+
+
+def clean_zeros(a, b, M):
+    """ Remove all components with zero weights in a and b
+    """
+    M2 = M[a > 0, :][:, b > 0].copy()  # copy forces a C-style matrix (for emd)
+    a2 = a[a > 0]
+    b2 = b[b > 0]
+    return a2, b2, M2
+
+
+def euclidean_distances(X, Y, squared=False):
+    """
+    Considering the rows of X (and Y=X) as vectors, compute the
+    distance matrix between each pair of vectors.
+
+    Parameters
+    ----------
+    X : {array-like}, shape (n_samples_1, n_features)
+    Y : {array-like}, shape (n_samples_2, n_features)
+    squared : boolean, optional
+        Return squared Euclidean distances.
+
+    Returns
+    -------
+    distances : {array}, shape (n_samples_1, n_samples_2)
+    """
+    XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis]
+    YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :]
+    distances = np.dot(X, Y.T)
+    distances *= -2
+    distances += XX
+    distances += YY
+    np.maximum(distances, 0, out=distances)
+    if X is Y:
+        # Ensure that distances between vectors and themselves are set to 0.0.
+        # This may not be the case due to floating point rounding errors.
+        distances.flat[::distances.shape[0] + 1] = 0.0
+    return distances if squared else np.sqrt(distances, out=distances)
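The einsum expansion in `euclidean_distances` implements ||x - y||^2 = ||x||^2 - 2<x, y> + ||y||^2; an editorial sanity check (not part of the patch) against scipy's `cdist`:

    import numpy as np
    from scipy.spatial.distance import cdist
    from ot.utils import euclidean_distances  # the function defined above

    rng = np.random.RandomState(0)
    X, Y = rng.randn(5, 3), rng.randn(4, 3)
    D2 = euclidean_distances(X, Y, squared=True)
    assert np.allclose(D2, cdist(X, Y, metric='sqeuclidean'))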
+
+
+def dist(x1, x2=None, metric='sqeuclidean'):
+    """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
+
+    Parameters
+    ----------
+    x1 : ndarray, shape (n1,d)
+        matrix with n1 samples of size d
+    x2 : ndarray, shape (n2,d), optional
+        matrix with n2 samples of size d (if None then x2=x1)
+    metric : str | callable, optional
+        Name of the metric to be computed (full list in the doc of scipy). If a string,
+        the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock',
+        'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski',
+        'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
+        'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'.
+
+    Returns
+    -------
+    M : np.array (n1,n2)
+        distance matrix computed with given metric
+
+    """
+    if x2 is None:
+        x2 = x1
+    if metric == "sqeuclidean":
+        return euclidean_distances(x1, x2, squared=True)
+    return cdist(x1, x2, metric=metric)
+
+
+def dist0(n, method='lin_square'):
+    """Compute standard cost matrices of size (n, n) for OT problems
+
+    Parameters
+    ----------
+    n : int
+        Size of the cost matrix.
+    method : str, optional
+        Type of loss matrix chosen from:
+
+        * 'lin_square' : linear sampling between 0 and n-1, quadratic loss
+
+    Returns
+    -------
+    M : ndarray, shape (n, n)
+        Distance matrix computed with given metric.
+    """
+    res = 0
+    if method == 'lin_square':
+        x = np.arange(n, dtype=np.float64).reshape((n, 1))
+        res = dist(x, x)
+    return res
+
+
+def cost_normalization(C, norm=None):
+    """ Apply normalization to the loss matrix
+
+    Parameters
+    ----------
+    C : ndarray, shape (n1, n2)
+        The cost matrix to normalize.
+    norm : str
+        Type of normalization from 'median', 'max', 'log', 'loglog'.
+        Any other value raises a ValueError.
+
+    Returns
+    -------
+    C : ndarray, shape (n1, n2)
+        The input cost matrix normalized according to given norm.
+    """
+
+    if norm is None:
+        pass
+    elif norm == "median":
+        C /= float(np.median(C))
+    elif norm == "max":
+        C /= float(np.max(C))
+    elif norm == "log":
+        C = np.log(1 + C)
+    elif norm == "loglog":
+        C = np.log1p(np.log1p(C))
+    else:
+        raise ValueError('Norm %s is not a valid option.\n'
+                         'Valid options are:\n'
+                         'median, max, log, loglog' % norm)
+    return C
+
+
+def dots(*args):
+    """ dots function for multiple matrix multiply """
+    return reduce(np.dot, args)
+
+
+def fun(f, q_in, q_out):
+    """ Utility function for parmap with no serializing problems """
+    while True:
+        i, x = q_in.get()
+        if i is None:
+            break
+        q_out.put((i, f(x)))
+
+
+def parmap(f, X, nprocs=multiprocessing.cpu_count()):
+    """ parallel map for multiprocessing (plain map on Windows) """
+
+    if not sys.platform.endswith('win32'):
+
+        q_in = multiprocessing.Queue(1)
+        q_out = multiprocessing.Queue()
+
+        proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out))
+                for _ in range(nprocs)]
+        for p in proc:
+            p.daemon = True
+            p.start()
+
+        sent = [q_in.put((i, x)) for i, x in enumerate(X)]
+        [q_in.put((None, None)) for _ in range(nprocs)]
+        res = [q_out.get() for _ in range(len(sent))]
+
+        [p.join() for p in proc]
+
+        return [x for i, x in sorted(res)]
+    else:
+        return list(map(f, X))
+
+
+def check_params(**kwargs):
+    """check_params: check whether some parameters are missing
+    """
+
+    missing_params = []
+    check = True
+
+    for param in kwargs:
+        if kwargs[param] is None:
+            missing_params.append(param)
+
+    if len(missing_params) > 0:
+        print("POT - Warning: following necessary parameters are missing")
+        for p in missing_params:
+            print("\n", p)
+
+        check = False
+
+    return check
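Editorial sketch (not part of the patch) of `check_params` guarding optional arguments; the wrapper name and arguments are purely illustrative.

    import numpy as np
    from ot.utils import check_params, dist

    def pairwise_cost(Xs=None, Xt=None):
        # returns None and prints a warning if either argument is missing
        if not check_params(Xs=Xs, Xt=Xt):
            return None
        return dist(Xs, Xt)

    print(check_params(Xs=np.ones((2, 2)), Xt=None))  # False, with a warning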
+ """ + if seed is None or seed is np.random: + return np.random.mtrand._rand + if isinstance(seed, (int, np.integer)): + return np.random.RandomState(seed) + if isinstance(seed, np.random.RandomState): + return seed + raise ValueError('{} cannot be used to seed a numpy.random.RandomState' + ' instance'.format(seed)) + + +class deprecated(object): + """Decorator to mark a function or class as deprecated. + + deprecated class from scikit-learn package + https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py + Issue a warning when the function is called/the class is instantiated and + adds a warning to the docstring. + The optional extra argument will be appended to the deprecation message + and the docstring. Note: to use this with the default value for extra, put + in an empty of parentheses: + >>> from ot.deprecation import deprecated # doctest: +SKIP + >>> @deprecated() # doctest: +SKIP + ... def some_function(): pass # doctest: +SKIP + + Parameters + ---------- + extra : str + To be added to the deprecation messages. + """ + + # Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary, + # but with many changes. + + def __init__(self, extra=''): + self.extra = extra + + def __call__(self, obj): + """Call method + Parameters + ---------- + obj : object + """ + if isinstance(obj, type): + return self._decorate_class(obj) + else: + return self._decorate_fun(obj) + + def _decorate_class(self, cls): + msg = "Class %s is deprecated" % cls.__name__ + if self.extra: + msg += "; %s" % self.extra + + # FIXME: we should probably reset __new__ for full generality + init = cls.__init__ + + def wrapped(*args, **kwargs): + warnings.warn(msg, category=DeprecationWarning) + return init(*args, **kwargs) + + cls.__init__ = wrapped + + wrapped.__name__ = '__init__' + wrapped.__doc__ = self._update_doc(init.__doc__) + wrapped.deprecated_original = init + + return cls + + def _decorate_fun(self, fun): + """Decorate function fun""" + + msg = "Function %s is deprecated" % fun.__name__ + if self.extra: + msg += "; %s" % self.extra + + def wrapped(*args, **kwargs): + warnings.warn(msg, category=DeprecationWarning) + return fun(*args, **kwargs) + + wrapped.__name__ = fun.__name__ + wrapped.__dict__ = fun.__dict__ + wrapped.__doc__ = self._update_doc(fun.__doc__) + + return wrapped + + def _update_doc(self, olddoc): + newdoc = "DEPRECATED" + if self.extra: + newdoc = "%s: %s" % (newdoc, self.extra) + if olddoc: + newdoc = "%s\n\n%s" % (newdoc, olddoc) + return newdoc + + +def _is_deprecated(func): + """Helper to check if func is wraped by our deprecated decorator""" + if sys.version_info < (3, 5): + raise NotImplementedError("This is only available for python3.5 " + "or above") + closures = getattr(func, '__closure__', []) + if closures is None: + closures = [] + is_deprecated = ('deprecated' in ''.join([c.cell_contents + for c in closures + if isinstance(c.cell_contents, str)])) + return is_deprecated + + +class BaseEstimator(object): + """Base class for most objects in POT + + Code adapted from sklearn BaseEstimator class + + Notes + ----- + All estimators should specify all the parameters that can be set + at the class level in their ``__init__`` as explicit keyword + arguments (no ``*args`` or ``**kwargs``). 
+ """ + + @classmethod + def _get_param_names(cls): + """Get parameter names for the estimator""" + + # fetch the constructor or the original constructor before + # deprecation wrapping if any + init = getattr(cls.__init__, 'deprecated_original', cls.__init__) + if init is object.__init__: + # No explicit constructor to introspect + return [] + + # introspect the constructor arguments to find the model parameters + # to represent + init_signature = signature(init) + # Consider the constructor parameters excluding 'self' + parameters = [p for p in init_signature.parameters.values() + if p.name != 'self' and p.kind != p.VAR_KEYWORD] + for p in parameters: + if p.kind == p.VAR_POSITIONAL: + raise RuntimeError("POT estimators should always " + "specify their parameters in the signature" + " of their __init__ (no varargs)." + " %s with constructor %s doesn't " + " follow this convention." + % (cls, init_signature)) + # Extract and sort argument names excluding 'self' + return sorted([p.name for p in parameters]) + + def get_params(self, deep=True): + """Get parameters for this estimator. + + Parameters + ---------- + deep : bool, optional + If True, will return the parameters for this estimator and + contained subobjects that are estimators. + + Returns + ------- + params : mapping of string to any + Parameter names mapped to their values. + """ + out = dict() + for key in self._get_param_names(): + # We need deprecation warnings to always be on in order to + # catch deprecated param values. + # This is set in utils/__init__.py but it gets overwritten + # when running under python3 somehow. + warnings.simplefilter("always", DeprecationWarning) + try: + with warnings.catch_warnings(record=True) as w: + value = getattr(self, key, None) + if len(w) and w[0].category == DeprecationWarning: + # if the parameter is deprecated, don't show it + continue + finally: + warnings.filters.pop(0) + + # XXX: should we rather test if instance of estimator? + if deep and hasattr(value, 'get_params'): + deep_items = value.get_params().items() + out.update((key + '__' + k, val) for k, val in deep_items) + out[key] = value + return out + + def set_params(self, **params): + """Set the parameters of this estimator. + + The method works on simple estimators as well as on nested objects + (such as pipelines). The latter have parameters of the form + ``<component>__<parameter>`` so that it's possible to update each + component of a nested object. + + Returns + ------- + self + """ + if not params: + # Simple optimisation to gain speed (inspect is slow) + return self + valid_params = self.get_params(deep=True) + # for key, value in iteritems(params): + for key, value in params.items(): + split = key.split('__', 1) + if len(split) > 1: + # nested objects case + name, sub_name = split + if name not in valid_params: + raise ValueError('Invalid parameter %s for estimator %s. ' + 'Check the list of available parameters ' + 'with `estimator.get_params().keys()`.' % + (name, self)) + sub_object = valid_params[name] + sub_object.set_params(**{sub_name: value}) + else: + # simple objects case + if key not in valid_params: + raise ValueError('Invalid parameter %s for estimator %s. ' + 'Check the list of available parameters ' + 'with `estimator.get_params().keys()`.' % + (key, self.__class__.__name__)) + setattr(self, key, value) + return self + + +class UndefinedParameter(Exception): + """ + Aim at raising an Exception when a undefined parameter is called + + """ + pass |