authorRémi Flamary <remi.flamary@gmail.com>2018-06-11 13:13:33 +0200
committerGitHub <noreply@github.com>2018-06-11 13:13:33 +0200
commit327b0c6e0ccb0c9453179eb316021c34bcdffec4 (patch)
tree7d6f87bbbe776b7194a8fd1469094d1ce9506f97
parent530dc93a60e9b81fb8d1b44680deea77dacf660b (diff)
parentecb093b95ddbeaeb800b443d3a5c9d5c24c5897c (diff)
Merge pull request #50 from rflamary/smooth_ot
Smooth and Sparse OT
-rw-r--r--README.md14
-rw-r--r--docs/source/all.rst5
-rw-r--r--docs/source/readme.rst46
-rw-r--r--examples/plot_OT_1D_smooth.py110
-rw-r--r--ot/__init__.py2
-rw-r--r--ot/lp/__init__.py2
-rw-r--r--ot/smooth.py600
-rw-r--r--test/test_smooth.py79
8 files changed, 837 insertions, 21 deletions
diff --git a/README.md b/README.md
index c5096a8..a22306d 100644
--- a/README.md
+++ b/README.md
@@ -15,7 +15,8 @@ It provides the following solvers:
* OT Network Flow solver for the linear program/ Earth Movers Distance [1].
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (requires cudamat).
-* Non regularized Wasserstein barycenters [16] with LP solver.
+* Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations [17].
+* Non regularized Wasserstein barycenters [16] with LP solver (only small scale).
* Bregman projections for Wasserstein barycenter [3] and unmixing [4].
* Optimal transport for domain adaptation with group lasso regularization [5]
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
@@ -29,10 +30,11 @@ Some demonstrations (both in Python and Jupyter Notebook format) are available i
If you use this toolbox in your research and find it useful, please cite POT using the following bibtex reference:
```
-@article{flamary2017pot,
- title={POT Python Optimal Transport library},
- author={Flamary, R{\'e}mi and Courty, Nicolas},
- year={2017}
+@misc{flamary2017pot,
+title={POT Python Optimal Transport library},
+author={Flamary, R{\'e}mi and Courty, Nicolas},
+url={https://github.com/rflamary/POT},
+year={2017}
}
```
@@ -215,3 +217,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[15] Peyré, G., & Cuturi, M. (2018). [Computational Optimal Transport](https://arxiv.org/pdf/1803.00567.pdf) .
[16] Agueh, M., & Carlier, G. (2011). [Barycenters in the Wasserstein space](https://hal.archives-ouvertes.fr/hal-00637399/document). SIAM Journal on Mathematical Analysis, 43(2), 904-924.
+
+[17] Blondel, M., Seguy, V., & Rolet, A. (2018). [Smooth and Sparse Optimal Transport](https://arxiv.org/abs/1710.06276). Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
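The squared-L2 solvers announced above rebuild the OT plan from a closed-form convex conjugate (implemented as `SquaredL2.delta_Omega` in the new `ot/smooth.py` in this PR). A quick standalone NumPy sanity check of that formula — the helper name `delta_omega_l2` is ours, not part of the library:

```python
import numpy as np

def delta_omega_l2(X, gamma):
    # Conjugate of the squared-L2 regularizer, mirroring SquaredL2.delta_Omega:
    # delta_Omega(x) = ||max(x, 0)||^2 / (2*gamma), gradient = max(x, 0) / gamma
    max_X = np.maximum(X, 0)
    val = np.sum(max_X ** 2, axis=0) / (2 * gamma)
    return val, max_X / gamma

rng = np.random.RandomState(0)
X = rng.randn(4, 3)
gamma = 0.5
val, G = delta_omega_l2(X, gamma)

# finite-difference check that G really is the gradient of val, entry by entry
eps = 1e-6
for i in range(4):
    for j in range(3):
        Xp, Xm = X.copy(), X.copy()
        Xp[i, j] += eps
        Xm[i, j] -= eps
        num = (delta_omega_l2(Xp, gamma)[0][j] - delta_omega_l2(Xm, gamma)[0][j]) / (2 * eps)
        assert abs(num - G[i, j]) < 1e-5
```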
diff --git a/docs/source/all.rst b/docs/source/all.rst
index c84d968..bbb9833 100644
--- a/docs/source/all.rst
+++ b/docs/source/all.rst
@@ -19,6 +19,11 @@ ot.bregman
.. automodule:: ot.bregman
:members:
+
+ot.smooth
+---------
+.. automodule:: ot.smooth
+ :members:
ot.gromov
----------
diff --git a/docs/source/readme.rst b/docs/source/readme.rst
index 725c207..5d37f64 100644
--- a/docs/source/readme.rst
+++ b/docs/source/readme.rst
@@ -14,7 +14,11 @@ It provides the following solvers:
[1].
- Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2]
and stabilized version [9][10] with optional GPU implementation
- (required cudamat).
+ (requires cudamat).
+- Smooth optimal transport solvers (dual and semi-dual) for KL and
+ squared L2 regularizations [17].
+- Non regularized Wasserstein barycenters [16] with LP solver (only
+ small scale).
- Bregman projections for Wasserstein barycenter [3] and unmixing [4].
- Optimal transport for domain adaptation with group lasso
regularization [5]
@@ -29,6 +33,21 @@ It provides the following solvers:
Some demonstrations (both in Python and Jupyter Notebook format) are
available in the examples folder.
+Using and citing the toolbox
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+If you use this toolbox in your research and find it useful, please cite
+POT using the following bibtex reference:
+
+::
+
+ @misc{flamary2017pot,
+ title={POT Python Optimal Transport library},
+   author={Flamary, R{\'e}mi and Courty, Nicolas},
+ url={https://github.com/rflamary/POT},
+ year={2017}
+ }
+
Installation
------------
@@ -37,7 +56,7 @@ C++ compiler for using the EMD solver and relies on the following Python
modules:
- Numpy (>=1.11)
-- Scipy (>=0.17)
+- Scipy (>=1.0)
- Cython (>=0.23)
- Matplotlib (>=1.5)
@@ -212,20 +231,6 @@ languages):
- `Marco Cuturi <http://marcocuturi.net/>`__ (Sinkhorn Knopp in
Matlab/Cuda)
-Using and citing the toolbox
-----------------------------
-
-If you use this toolbox in your research and find it useful, please cite
-POT using the following bibtex reference:
-
-::
-
- @article{flamary2017pot,
- title={POT Python Optimal Transport library},
- author={Flamary, R{\'e}mi and Courty, Nicolas},
- year={2017}
- }
-
Contributions and code of conduct
---------------------------------
@@ -320,6 +325,15 @@ Journal of Optimization Theory and Applications Vol 43.
[15] Peyré, G., & Cuturi, M. (2018). `Computational Optimal
Transport <https://arxiv.org/pdf/1803.00567.pdf>`__ .
+[16] Agueh, M., & Carlier, G. (2011). `Barycenters in the Wasserstein
+space <https://hal.archives-ouvertes.fr/hal-00637399/document>`__. SIAM
+Journal on Mathematical Analysis, 43(2), 904-924.
+
+[17] Blondel, M., Seguy, V., & Rolet, A. (2018). `Smooth and Sparse
+Optimal Transport <https://arxiv.org/abs/1710.06276>`__. Proceedings of
+the Twenty-First International Conference on Artificial Intelligence and
+Statistics (AISTATS).
+
.. |PyPI version| image:: https://badge.fury.io/py/POT.svg
:target: https://badge.fury.io/py/POT
.. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg
diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py
new file mode 100644
index 0000000..b690751
--- /dev/null
+++ b/examples/plot_OT_1D_smooth.py
@@ -0,0 +1,110 @@
+# -*- coding: utf-8 -*-
+"""
+===========================
+1D smooth optimal transport
+===========================
+
+This example illustrates the computation of EMD, Sinkhorn and smooth OT plans
+and their visualization.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+from ot.datasets import make_1D_gauss as gauss
+
+##############################################################################
+# Generate data
+# -------------
+
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a = gauss(n, m=20, s=5) # m= mean, s= std
+b = gauss(n, m=60, s=10)
+
+# loss matrix
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+M /= M.max()
+
+
+##############################################################################
+# Plot distributions and loss matrix
+# ----------------------------------
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.legend()
+
+#%% plot distributions and loss matrix
+
+pl.figure(2, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
+
+##############################################################################
+# Solve EMD
+# ---------
+
+
+#%% EMD
+
+G0 = ot.emd(a, b, M)
+
+pl.figure(3, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
+
+##############################################################################
+# Solve Sinkhorn
+# --------------
+
+
+#%% Sinkhorn
+
+lambd = 2e-3
+Gs = ot.sinkhorn(a, b, M, lambd, verbose=True)
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')
+
+pl.show()
+
+##############################################################################
+# Solve Smooth OT
+# ---------------
+
+
+#%% Smooth OT with KL regularization
+
+lambd = 2e-3
+Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='kl')
+
+pl.figure(5, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT KL reg.')
+
+pl.show()
+
+
+#%% Smooth OT with squared l2 regularization
+
+lambd = 1e-1
+Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='l2')
+
+pl.figure(6, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT l2 reg.')
+
+pl.show()
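One way to see why the l2 plan in the last figure looks sparser than the Sinkhorn and KL ones: the two regularizers rebuild the plan from the dual variables with different conjugate gradients (squared-L2 thresholds at zero, neg-entropy exponentiates). A sketch using a random stand-in for the dual residual, not an actual solve:

```python
import numpy as np

rng = np.random.RandomState(42)
X = rng.randn(50, 50)  # stand-in for the dual residual alpha_i + beta_j - M_ij
gamma = 0.1

# plan reconstruction used by the two regularizers in ot/smooth.py:
G_l2 = np.maximum(X, 0) / gamma  # SquaredL2: exact zeros wherever X <= 0
G_kl = np.exp(X / gamma - 1)     # NegEntropy: strictly positive everywhere

print("fraction of exact zeros (l2):", np.mean(G_l2 == 0))
print("smallest entry (kl):", G_kl.min())
```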
diff --git a/ot/__init__.py b/ot/__init__.py
index 1500e59..995ce33 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -18,6 +18,8 @@ from . import utils
from . import datasets
from . import da
from . import gromov
+from . import smooth
+
# OT functions
from .lp import emd, emd2
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 5dda82a..4c0d170 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -18,6 +18,8 @@ from .emd_wrap import emd_c, check_result
from ..utils import parmap
from .cvx import barycenter
+__all__ = ['emd', 'emd2', 'barycenter', 'cvx']
+
def emd(a, b, M, numItermax=100000, log=False):
"""Solves the Earth Movers distance problem and returns the OT matrix
diff --git a/ot/smooth.py b/ot/smooth.py
new file mode 100644
index 0000000..5a8e4b5
--- /dev/null
+++ b/ot/smooth.py
@@ -0,0 +1,600 @@
+#Copyright (c) 2018, Mathieu Blondel
+#All rights reserved.
+#
+#Redistribution and use in source and binary forms, with or without
+#modification, are permitted provided that the following conditions are met:
+#
+#1. Redistributions of source code must retain the above copyright notice, this
+#list of conditions and the following disclaimer.
+#
+#2. Redistributions in binary form must reproduce the above copyright notice,
+#this list of conditions and the following disclaimer in the documentation and/or
+#other materials provided with the distribution.
+#
+#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+#ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+#WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+#IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
+#INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
+#NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
+#OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+#LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+#OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+#THE POSSIBILITY OF SUCH DAMAGE.
+
+# Author: Mathieu Blondel
+# Remi Flamary <remi.flamary@unice.fr>
+
+"""
+Implementation of
+Smooth and Sparse Optimal Transport.
+Mathieu Blondel, Vivien Seguy, Antoine Rolet.
+In Proc. of AISTATS 2018.
+https://arxiv.org/abs/1710.06276
+
+[17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal
+Transport. Proceedings of the Twenty-First International Conference on
+Artificial Intelligence and Statistics (AISTATS).
+
+Original code from https://github.com/mblondel/smooth-ot/
+
+"""
+
+import numpy as np
+from scipy.optimize import minimize
+
+
+def projection_simplex(V, z=1, axis=None):
+ """ Projection of x onto the simplex, scaled by z
+
+ P(x; z) = argmin_{y >= 0, sum(y) = z} ||y - x||^2
+ z: float or array
+ If array, len(z) must be compatible with V
+ axis: None or int
+ - axis=None: project V by P(V.ravel(); z)
+ - axis=1: project each V[i] by P(V[i]; z[i])
+ - axis=0: project each V[:, j] by P(V[:, j]; z[j])
+ """
+ if axis == 1:
+ n_features = V.shape[1]
+ U = np.sort(V, axis=1)[:, ::-1]
+ z = np.ones(len(V)) * z
+ cssv = np.cumsum(U, axis=1) - z[:, np.newaxis]
+ ind = np.arange(n_features) + 1
+ cond = U - cssv / ind > 0
+ rho = np.count_nonzero(cond, axis=1)
+ theta = cssv[np.arange(len(V)), rho - 1] / rho
+ return np.maximum(V - theta[:, np.newaxis], 0)
+
+ elif axis == 0:
+ return projection_simplex(V.T, z, axis=1).T
+
+ else:
+ V = V.ravel().reshape(1, -1)
+ return projection_simplex(V, z, axis=1).ravel()
+
+
+class Regularization(object):
+ """Base class for Regularization objects
+
+ Notes
+ -----
+    This class is not intended for direct use but as a parent class for
+    concrete regularization implementations.
+ """
+
+ def __init__(self, gamma=1.0):
+ """
+
+ Parameters
+ ----------
+ gamma: float
+ Regularization parameter.
+ We recover unregularized OT when gamma -> 0.
+
+ """
+ self.gamma = gamma
+
+    def delta_Omega(self, X):
+ """
+ Compute delta_Omega(X[:, j]) for each X[:, j].
+ delta_Omega(x) = sup_{y >= 0} y^T x - Omega(y).
+
+ Parameters
+ ----------
+ X: array, shape = len(a) x len(b)
+ Input array.
+
+ Returns
+ -------
+ v: array, len(b)
+ Values: v[j] = delta_Omega(X[:, j])
+ G: array, len(a) x len(b)
+ Gradients: G[:, j] = nabla delta_Omega(X[:, j])
+ """
+ raise NotImplementedError
+
+    def max_Omega(self, X, b):
+ """
+ Compute max_Omega_j(X[:, j]) for each X[:, j].
+ max_Omega_j(x) = sup_{y >= 0, sum(y) = 1} y^T x - Omega(b[j] y) / b[j].
+
+ Parameters
+ ----------
+ X: array, shape = len(a) x len(b)
+ Input array.
+
+ Returns
+ -------
+ v: array, len(b)
+ Values: v[j] = max_Omega_j(X[:, j])
+ G: array, len(a) x len(b)
+ Gradients: G[:, j] = nabla max_Omega_j(X[:, j])
+ """
+ raise NotImplementedError
+
+    def Omega(self, T):
+ """
+ Compute regularization term.
+
+ Parameters
+ ----------
+ T: array, shape = len(a) x len(b)
+ Input array.
+
+ Returns
+ -------
+ value: float
+ Regularization term.
+ """
+ raise NotImplementedError
+
+
+class NegEntropy(Regularization):
+ """ NegEntropy regularization """
+
+ def delta_Omega(self, X):
+ G = np.exp(X / self.gamma - 1)
+ val = self.gamma * np.sum(G, axis=0)
+ return val, G
+
+ def max_Omega(self, X, b):
+ max_X = np.max(X, axis=0) / self.gamma
+ exp_X = np.exp(X / self.gamma - max_X)
+ val = self.gamma * (np.log(np.sum(exp_X, axis=0)) + max_X)
+ val -= self.gamma * np.log(b)
+ G = exp_X / np.sum(exp_X, axis=0)
+ return val, G
+
+ def Omega(self, T):
+ return self.gamma * np.sum(T * np.log(T))
+
+
+class SquaredL2(Regularization):
+ """ Squared L2 regularization """
+
+ def delta_Omega(self, X):
+ max_X = np.maximum(X, 0)
+ val = np.sum(max_X ** 2, axis=0) / (2 * self.gamma)
+ G = max_X / self.gamma
+ return val, G
+
+ def max_Omega(self, X, b):
+ G = projection_simplex(X / (b * self.gamma), axis=0)
+ val = np.sum(X * G, axis=0)
+ val -= 0.5 * self.gamma * b * np.sum(G * G, axis=0)
+ return val, G
+
+ def Omega(self, T):
+ return 0.5 * self.gamma * np.sum(T ** 2)
+
+
+def dual_obj_grad(alpha, beta, a, b, C, regul):
+ """
+ Compute objective value and gradients of dual objective.
+
+ Parameters
+ ----------
+ alpha: array, shape = len(a)
+ beta: array, shape = len(b)
+ Current iterate of dual potentials.
+ a: array, shape = len(a)
+ b: array, shape = len(b)
+ Input histograms (should be non-negative and sum to 1).
+ C: array, shape = len(a) x len(b)
+ Ground cost matrix.
+ regul: Regularization object
+ Should implement a delta_Omega(X) method.
+
+ Returns
+ -------
+ obj: float
+ Objective value (higher is better).
+ grad_alpha: array, shape = len(a)
+ Gradient w.r.t. alpha.
+ grad_beta: array, shape = len(b)
+ Gradient w.r.t. beta.
+ """
+ obj = np.dot(alpha, a) + np.dot(beta, b)
+ grad_alpha = a.copy()
+ grad_beta = b.copy()
+
+ # X[:, j] = alpha + beta[j] - C[:, j]
+ X = alpha[:, np.newaxis] + beta - C
+
+ # val.shape = len(b)
+ # G.shape = len(a) x len(b)
+ val, G = regul.delta_Omega(X)
+
+ obj -= np.sum(val)
+ grad_alpha -= G.sum(axis=1)
+ grad_beta -= G.sum(axis=0)
+
+ return obj, grad_alpha, grad_beta
+
+
+def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
+ verbose=False):
+ """
+ Solve the "smoothed" dual objective.
+
+ Parameters
+ ----------
+ a: array, shape = len(a)
+ b: array, shape = len(b)
+ Input histograms (should be non-negative and sum to 1).
+ C: array, shape = len(a) x len(b)
+ Ground cost matrix.
+ regul: Regularization object
+ Should implement a delta_Omega(X) method.
+ method: str
+ Solver to be used (passed to `scipy.optimize.minimize`).
+ tol: float
+ Tolerance parameter.
+ max_iter: int
+ Maximum number of iterations.
+
+ Returns
+ -------
+ alpha: array, shape = len(a)
+ beta: array, shape = len(b)
+ Dual potentials.
+ """
+
+ def _func(params):
+ # Unpack alpha and beta.
+ alpha = params[:len(a)]
+ beta = params[len(a):]
+
+ obj, grad_alpha, grad_beta = dual_obj_grad(alpha, beta, a, b, C, regul)
+
+ # Pack grad_alpha and grad_beta.
+ grad = np.concatenate((grad_alpha, grad_beta))
+
+ # We need to maximize the dual.
+ return -obj, -grad
+
+ # Unfortunately, `minimize` only supports functions whose argument is a
+ # vector. So, we need to concatenate alpha and beta.
+ alpha_init = np.zeros(len(a))
+ beta_init = np.zeros(len(b))
+ params_init = np.concatenate((alpha_init, beta_init))
+
+ res = minimize(_func, params_init, method=method, jac=True,
+ tol=tol, options=dict(maxiter=max_iter, disp=verbose))
+
+ alpha = res.x[:len(a)]
+ beta = res.x[len(a):]
+
+ return alpha, beta, res
+
+
+def semi_dual_obj_grad(alpha, a, b, C, regul):
+ """
+ Compute objective value and gradient of semi-dual objective.
+
+ Parameters
+ ----------
+ alpha: array, shape = len(a)
+ Current iterate of semi-dual potentials.
+ a: array, shape = len(a)
+ b: array, shape = len(b)
+ Input histograms (should be non-negative and sum to 1).
+ C: array, shape = len(a) x len(b)
+ Ground cost matrix.
+ regul: Regularization object
+ Should implement a max_Omega(X) method.
+
+ Returns
+ -------
+ obj: float
+ Objective value (higher is better).
+ grad: array, shape = len(a)
+ Gradient w.r.t. alpha.
+ """
+ obj = np.dot(alpha, a)
+ grad = a.copy()
+
+ # X[:, j] = alpha - C[:, j]
+ X = alpha[:, np.newaxis] - C
+
+ # val.shape = len(b)
+ # G.shape = len(a) x len(b)
+ val, G = regul.max_Omega(X, b)
+
+ obj -= np.dot(b, val)
+ grad -= np.dot(G, b)
+
+ return obj, grad
+
+
+def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
+ verbose=False):
+ """
+ Solve the "smoothed" semi-dual objective.
+
+ Parameters
+ ----------
+ a: array, shape = len(a)
+ b: array, shape = len(b)
+ Input histograms (should be non-negative and sum to 1).
+ C: array, shape = len(a) x len(b)
+ Ground cost matrix.
+ regul: Regularization object
+ Should implement a max_Omega(X) method.
+ method: str
+ Solver to be used (passed to `scipy.optimize.minimize`).
+ tol: float
+ Tolerance parameter.
+ max_iter: int
+ Maximum number of iterations.
+
+ Returns
+ -------
+ alpha: array, shape = len(a)
+ Semi-dual potentials.
+ """
+
+ def _func(alpha):
+ obj, grad = semi_dual_obj_grad(alpha, a, b, C, regul)
+ # We need to maximize the semi-dual.
+ return -obj, -grad
+
+ alpha_init = np.zeros(len(a))
+
+ res = minimize(_func, alpha_init, method=method, jac=True,
+ tol=tol, options=dict(maxiter=max_iter, disp=verbose))
+
+ return res.x, res
+
+
+def get_plan_from_dual(alpha, beta, C, regul):
+ """
+ Retrieve optimal transportation plan from optimal dual potentials.
+
+ Parameters
+ ----------
+ alpha: array, shape = len(a)
+ beta: array, shape = len(b)
+ Optimal dual potentials.
+ C: array, shape = len(a) x len(b)
+ Ground cost matrix.
+ regul: Regularization object
+ Should implement a delta_Omega(X) method.
+
+ Returns
+ -------
+ T: array, shape = len(a) x len(b)
+ Optimal transportation plan.
+ """
+ X = alpha[:, np.newaxis] + beta - C
+ return regul.delta_Omega(X)[1]
+
+
+def get_plan_from_semi_dual(alpha, b, C, regul):
+ """
+ Retrieve optimal transportation plan from optimal semi-dual potentials.
+
+ Parameters
+ ----------
+ alpha: array, shape = len(a)
+ Optimal semi-dual potentials.
+ b: array, shape = len(b)
+ Second input histogram (should be non-negative and sum to 1).
+ C: array, shape = len(a) x len(b)
+ Ground cost matrix.
+ regul: Regularization object
+ Should implement a delta_Omega(X) method.
+
+ Returns
+ -------
+ T: array, shape = len(a) x len(b)
+ Optimal transportation plan.
+ """
+ X = alpha[:, np.newaxis] - C
+ return regul.max_Omega(X, b)[1] * b
+
+
+def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
+ numItermax=500, verbose=False, log=False):
+ r"""
+ Solve the regularized OT problem in the dual and return the OT matrix
+
+ The function solves the smooth relaxed dual formulation (7) in [17]_ :
+
+ .. math::
+ \max_{\alpha,\beta}\quad a^T\alpha+b^T\beta-\sum_j\delta_\Omega(\alpha+\beta_j-\mathbf{m}_j)
+
+ where :
+
+ - :math:`\mathbf{m}_j` is the jth column of the cost matrix
+ - :math:`\delta_\Omega` is the convex conjugate of the regularization term :math:`\Omega`
+ - a and b are source and target weights (sum to 1)
+
+    The OT matrix is reconstructed from the gradient of :math:`\delta_\Omega`
+    (see [17]_ Proposition 1).
+    The optimization algorithm uses gradient descent (L-BFGS by default).
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+    b : np.ndarray (nt,)
+        samples weights in the target domain
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ reg_type : str
+ Regularization type, can be the following (default ='l2'):
+ - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_)
+ - 'l2' : Squared Euclidean regularization
+ method : str
+ Solver to use for scipy.optimize.minimize
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+        log dictionary returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+    ot.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ if reg_type.lower() in ['l2', 'squaredl2']:
+ regul = SquaredL2(gamma=reg)
+ elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
+ regul = NegEntropy(gamma=reg)
+ else:
+ raise NotImplementedError('Unknown regularization')
+
+ # solve dual
+ alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax,
+ tol=stopThr, verbose=verbose)
+
+ # reconstruct transport matrix
+ G = get_plan_from_dual(alpha, beta, M, regul)
+
+ if log:
+ log = {'alpha': alpha, 'beta': beta, 'res': res}
+ return G, log
+ else:
+ return G
+
+
+def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
+ numItermax=500, verbose=False, log=False):
+ r"""
+ Solve the regularized OT problem in the semi-dual and return the OT matrix
+
+ The function solves the smooth relaxed dual formulation (10) in [17]_ :
+
+ .. math::
+ \max_{\alpha}\quad a^T\alpha-OT_\Omega^*(\alpha,b)
+
+    where :
+
+    .. math::
+        OT_\Omega^*(\alpha,b)=\sum_j b_j\, \mathrm{max}_{\Omega,j}(\alpha-\mathbf{m}_j)
+
+    - :math:`\mathbf{m}_j` is the jth column of the cost matrix
+    - :math:`\mathrm{max}_{\Omega,j}(x)=\sup_{y\geq 0,\,\sum_i y_i=1}\; y^Tx-\Omega(b_j y)/b_j`
+      is the smoothed max operator defined in Eq. (9) in [17]_
+ - a and b are source and target weights (sum to 1)
+
+    The OT matrix is reconstructed using [17]_ Proposition 2.
+    The optimization algorithm uses gradient descent (L-BFGS by default).
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+    b : np.ndarray (nt,)
+        samples weights in the target domain
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ reg_type : str
+ Regularization type, can be the following (default ='l2'):
+ - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_)
+ - 'l2' : Squared Euclidean regularization
+ method : str
+ Solver to use for scipy.optimize.minimize
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+        log dictionary returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+    ot.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+ if reg_type.lower() in ['l2', 'squaredl2']:
+ regul = SquaredL2(gamma=reg)
+ elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
+ regul = NegEntropy(gamma=reg)
+ else:
+ raise NotImplementedError('Unknown regularization')
+
+ # solve dual
+ alpha, res = solve_semi_dual(a, b, M, regul, max_iter=numItermax,
+ tol=stopThr, verbose=verbose)
+
+ # reconstruct transport matrix
+ G = get_plan_from_semi_dual(alpha, b, M, regul)
+
+ if log:
+ log = {'alpha': alpha, 'res': res}
+ return G, log
+ else:
+ return G
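`projection_simplex` above is the workhorse behind `SquaredL2.max_Omega`. A standalone sanity check of its row-wise (axis=1) branch, reproduced here so it runs without the package — the helper name `projection_simplex_rows` is ours:

```python
import numpy as np

def projection_simplex_rows(V, z=1):
    # Row-wise Euclidean projection onto {y >= 0, sum(y) = z}; this is the
    # axis=1 branch of projection_simplex from ot/smooth.py
    n_features = V.shape[1]
    U = np.sort(V, axis=1)[:, ::-1]          # sort each row in decreasing order
    z = np.ones(len(V)) * z
    cssv = np.cumsum(U, axis=1) - z[:, np.newaxis]
    ind = np.arange(n_features) + 1
    cond = U - cssv / ind > 0                # active-set condition per entry
    rho = np.count_nonzero(cond, axis=1)     # size of the support per row
    theta = cssv[np.arange(len(V)), rho - 1] / rho
    return np.maximum(V - theta[:, np.newaxis], 0)

rng = np.random.RandomState(0)
P = projection_simplex_rows(rng.randn(5, 8))

assert np.all(P >= 0)                   # non-negativity
assert np.allclose(P.sum(axis=1), 1.0)  # every row lands on the simplex
```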
diff --git a/test/test_smooth.py b/test/test_smooth.py
new file mode 100644
index 0000000..2afa4f8
--- /dev/null
+++ b/test/test_smooth.py
@@ -0,0 +1,79 @@
+"""Tests for ot.smooth model """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+import pytest
+
+
+def test_smooth_ot_dual():
+
+ # get data
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ with pytest.raises(NotImplementedError):
+ Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='none')
+
+ Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10)
+
+    # check constraints
+ np.testing.assert_allclose(
+ u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn
+ np.testing.assert_allclose(
+ u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn
+
+    # kl regularisation
+ G = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10)
+
+    # check constraints
+ np.testing.assert_allclose(
+ u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
+ np.testing.assert_allclose(
+ u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+
+ G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
+ np.testing.assert_allclose(G, G2, atol=1e-05)
+
+
+def test_smooth_ot_semi_dual():
+
+ # get data
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ with pytest.raises(NotImplementedError):
+ Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='none')
+
+ Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10)
+
+    # check constraints
+ np.testing.assert_allclose(
+ u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn
+ np.testing.assert_allclose(
+ u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn
+
+    # kl regularisation
+ G = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10)
+
+    # check constraints
+ np.testing.assert_allclose(
+ u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
+ np.testing.assert_allclose(
+ u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+
+ G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
+ np.testing.assert_allclose(G, G2, atol=1e-05)
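The marginal checks these tests perform through `ot.smooth.smooth_ot_dual(..., reg_type='kl')` can be mirrored without POT installed by hand-rolling the entropic dual from the formulas in `ot/smooth.py` (NegEntropy's `delta_Omega` and `dual_obj_grad`). A minimal NumPy/SciPy sketch, not the library's API:

```python
import numpy as np
from scipy.optimize import minimize

rng = np.random.RandomState(0)
n = 10
a = np.full(n, 1.0 / n)   # uniform source histogram
b = np.full(n, 1.0 / n)   # uniform target histogram
C = rng.rand(n, n)        # random ground cost
gamma = 1.0               # regularization strength

def neg_dual(params):
    # dual objective a^T alpha + b^T beta - sum_j delta_Omega(alpha + beta_j - C_j)
    alpha, beta = params[:n], params[n:]
    X = alpha[:, np.newaxis] + beta - C
    G = np.exp(X / gamma - 1)  # gradient of the entropic conjugate delta_Omega
    obj = a @ alpha + b @ beta - gamma * G.sum()
    grad = np.concatenate((a - G.sum(axis=1), b - G.sum(axis=0)))
    return -obj, -grad         # scipy minimizes, so negate both

res = minimize(neg_dual, np.zeros(2 * n), jac=True, method="L-BFGS-B",
               tol=1e-12, options=dict(maxiter=500))
alpha, beta = res.x[:n], res.x[n:]
G = np.exp((alpha[:, np.newaxis] + beta - C) / gamma - 1)  # plan from duals

# the dual gradient vanishes at the optimum, so the plan marginals match a and b
assert np.allclose(G.sum(axis=1), a, atol=1e-4)
assert np.allclose(G.sum(axis=0), b, atol=1e-4)
```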