author    Kilian <kilian.fatras@gmail.com>  2018-08-29 14:22:40 -0700
committer GitHub <noreply@github.com>       2018-08-29 14:22:40 -0700
commit    15f4b29a91fda1dbd221e6e0a3443431d3d69257 (patch)
tree      82c33ee8b09112b6a67ed614e370156e4144628f /ot
parent    63b34bf012076eb89ed112122fdaa65667464ae7 (diff)
parent    5180023fc49d15ad83faccc5674d5966fe9a0385 (diff)
Merge branch 'master' into stochastic_OT
Diffstat (limited to 'ot')
-rw-r--r--ot/bregman.py18
-rw-r--r--ot/da.py284
-rw-r--r--ot/lp/__init__.py95
-rw-r--r--ot/lp/cvx.py3
-rw-r--r--ot/smooth.py600
-rw-r--r--ot/utils.py31
6 files changed, 739 insertions, 292 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index b017c1a..c755f51 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -344,8 +344,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
# print(reg)
- K = np.exp(-M / reg)
+ # Next 3 lines equivalent to K = np.exp(-M / reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
# print(np.min(K))
+ tmp2 = np.empty(b.shape, dtype=M.dtype)
Kp = (1 / a).reshape(-1, 1) * K
cpt = 0
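
A quick standalone check (plain NumPy, not part of the patch) that the three in-place operations reproduce np.exp(-M / reg) without allocating a temporary for -M / reg:

    import numpy as np

    M = np.random.rand(50, 40)
    reg = 0.1
    K = np.empty(M.shape, dtype=M.dtype)
    np.divide(M, -reg, out=K)  # K = -M / reg, written in place
    np.exp(K, out=K)           # K = exp(-M / reg), reusing the same buffer
    assert np.allclose(K, np.exp(-M / reg))
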
@@ -353,6 +358,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
while (err > stopThr and cpt < numItermax):
uprev = u
vprev = v
+
KtransposeU = np.dot(K.T, u)
v = np.divide(b, KtransposeU)
u = 1. / np.dot(Kp, v)
@@ -373,8 +379,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
np.sum((v - vprev)**2) / np.sum((v)**2)
else:
- transp = u.reshape(-1, 1) * (K * v)
- err = np.linalg.norm((np.sum(transp, axis=0) - b))**2
+ # compute right marginal tmp2 = (diag(u) K diag(v))^T 1
+ np.einsum('i,ij,j->j', u, K, v, out=tmp2)
+ err = np.linalg.norm(tmp2 - b)**2 # violation of marginal
if log:
log['err'].append(err)
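
The einsum above computes the column sums of the implicit plan diag(u) K diag(v) without ever materializing it. A standalone sketch of the equivalence:

    import numpy as np

    rng = np.random.RandomState(0)
    u, v, K = rng.rand(5), rng.rand(4), rng.rand(5, 4)
    T = u.reshape(-1, 1) * K * v.reshape(1, -1)  # explicit plan diag(u) K diag(v)
    tmp2 = np.einsum('i,ij,j->j', u, K, v)       # its column sums, plan never formed
    assert np.allclose(tmp2, T.sum(axis=0))
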
@@ -389,10 +396,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
log['v'] = v
if nbb: # return only loss
- res = np.zeros((nbb))
- for i in range(nbb):
- res[i] = np.sum(
- u[:, i].reshape((-1, 1)) * K * v[:, i].reshape((1, -1)) * M)
+ res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
if log:
return res, log
else:
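
The new einsum evaluates all nbb losses in one call and agrees with the loop it replaces. A standalone sketch (plain NumPy):

    import numpy as np

    rng = np.random.RandomState(0)
    nbb = 3
    u, v = rng.rand(5, nbb), rng.rand(4, nbb)
    K, M = rng.rand(5, 4), rng.rand(5, 4)
    res_loop = np.zeros(nbb)
    for i in range(nbb):
        res_loop[i] = np.sum(u[:, i].reshape((-1, 1)) * K * v[:, i].reshape((1, -1)) * M)
    res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)  # sum_ij u_ik K_ij v_jk M_ij
    assert np.allclose(res, res_loop)
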
diff --git a/ot/da.py b/ot/da.py
index 48b418f..bc09e3c 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -15,7 +15,7 @@ import scipy.linalg as linalg
from .bregman import sinkhorn
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization
-from .utils import check_params, deprecated, BaseEstimator
+from .utils import check_params, BaseEstimator
from .optim import cg
from .optim import gcg
@@ -740,288 +740,6 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
return A, b
-@deprecated("The class OTDA is deprecated in 0.3.1 and will be "
- "removed in 0.5"
- "\n\tfor standard transport use class EMDTransport instead.")
-class OTDA(object):
-
- """Class for domain adaptation with optimal transport as proposed in [5]
-
-
- References
- ----------
-
- .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
- "Optimal Transport for Domain Adaptation," in IEEE Transactions on
- Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
-
- """
-
- def __init__(self, metric='sqeuclidean', norm=None):
- """ Class initialization"""
- self.xs = 0
- self.xt = 0
- self.G = 0
- self.metric = metric
- self.norm = norm
- self.computed = False
-
- def fit(self, xs, xt, ws=None, wt=None, max_iter=100000):
- """Fit domain adaptation between samples is xs and xt
- (with optional weights)"""
- self.xs = xs
- self.xt = xt
-
- if wt is None:
- wt = unif(xt.shape[0])
- if ws is None:
- ws = unif(xs.shape[0])
-
- self.ws = ws
- self.wt = wt
-
- self.M = dist(xs, xt, metric=self.metric)
- self.M = cost_normalization(self.M, self.norm)
- self.G = emd(ws, wt, self.M, max_iter)
- self.computed = True
-
- def interp(self, direction=1):
- """Barycentric interpolation for the source (1) or target (-1) samples
-
- This Barycentric interpolation solves for each source (resp target)
- sample xs (resp xt) the following optimization problem:
-
- .. math::
- \arg\min_x \sum_i \gamma_{k,i} c(x,x_i^t)
-
- where k is the index of the sample in xs
-
- For the moment only the squared Euclidean distance is provided, but
- more metrics could be added in the future.
-
- """
- if direction > 0: # >0 then source to target
- G = self.G
- w = self.ws.reshape((self.xs.shape[0], 1))
- x = self.xt
- else:
- G = self.G.T
- w = self.wt.reshape((self.xt.shape[0], 1))
- x = self.xs
-
- if self.computed:
- if self.metric == 'sqeuclidean':
- return np.dot(G / w, x) # weighted mean
- else:
- print(
- "Warning, metric not handled yet, using weighted average")
- return np.dot(G / w, x) # weighted mean
- return None
- else:
- print("Warning, model not fitted yet, returning None")
- return None
-
- def predict(self, x, direction=1):
- """ Out of sample mapping using the formulation from [6]
-
- For each sample x to map, it finds the nearest source sample xs and
- maps the sample x to the position xst + (x - xs), where xst is the
- barycentric interpolation of the source sample xs.
-
- References
- ----------
-
- .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
- Regularized discrete optimal transport. SIAM Journal on Imaging
- Sciences, 7(3), 1853-1882.
-
- """
- if direction > 0: # >0 then source to target
- xf = self.xt
- x0 = self.xs
- else:
- xf = self.xs
- x0 = self.xt
-
- D0 = dist(x, x0) # dist between new samples and source
- idx = np.argmin(D0, 1) # closest one
- xf = self.interp(direction) # interp the source samples
- # apply the delta to the interpolation
- return xf[idx, :] + x - x0[idx, :]
-
-
-@deprecated("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be"
- " removed in 0.5 \nUse class SinkhornTransport instead.")
-class OTDA_sinkhorn(OTDA):
-
- """Class for domain adaptation with optimal transport with entropic
- regularization
-
-
- """
-
- def fit(self, xs, xt, reg=1, ws=None, wt=None, **kwargs):
- """Fit regularized domain adaptation between samples is xs and xt
- (with optional weights)"""
- self.xs = xs
- self.xt = xt
-
- if wt is None:
- wt = unif(xt.shape[0])
- if ws is None:
- ws = unif(xs.shape[0])
-
- self.ws = ws
- self.wt = wt
-
- self.M = dist(xs, xt, metric=self.metric)
- self.M = cost_normalization(self.M, self.norm)
- self.G = sinkhorn(ws, wt, self.M, reg, **kwargs)
- self.computed = True
-
-
-@deprecated("The class OTDA_lpl1 is deprecated in 0.3.1 and will be"
- " removed in 0.5 \nUse class SinkhornLpl1Transport instead.")
-class OTDA_lpl1(OTDA):
-
- """Class for domain adaptation with optimal transport with entropic and
- group regularization"""
-
- def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs):
- """Fit regularized domain adaptation between samples is xs and xt
- (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit
- parameters"""
- self.xs = xs
- self.xt = xt
-
- if wt is None:
- wt = unif(xt.shape[0])
- if ws is None:
- ws = unif(xs.shape[0])
-
- self.ws = ws
- self.wt = wt
-
- self.M = dist(xs, xt, metric=self.metric)
- self.M = cost_normalization(self.M, self.norm)
- self.G = sinkhorn_lpl1_mm(ws, ys, wt, self.M, reg, eta, **kwargs)
- self.computed = True
-
-
-@deprecated("The class OTDA_l1L2 is deprecated in 0.3.1 and will be"
- " removed in 0.5 \nUse class SinkhornL1l2Transport instead.")
-class OTDA_l1l2(OTDA):
-
- """Class for domain adaptation with optimal transport with entropic
- and group lasso regularization"""
-
- def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs):
- """Fit regularized domain adaptation between samples is xs and xt
- (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit
- parameters"""
- self.xs = xs
- self.xt = xt
-
- if wt is None:
- wt = unif(xt.shape[0])
- if ws is None:
- ws = unif(xs.shape[0])
-
- self.ws = ws
- self.wt = wt
-
- self.M = dist(xs, xt, metric=self.metric)
- self.M = cost_normalization(self.M, self.norm)
- self.G = sinkhorn_l1l2_gl(ws, ys, wt, self.M, reg, eta, **kwargs)
- self.computed = True
-
-
-@deprecated("The class OTDA_mapping_linear is deprecated in 0.3.1 and will be"
- " removed in 0.5 \nUse class MappingTransport instead.")
-class OTDA_mapping_linear(OTDA):
-
- """Class for optimal transport with joint linear mapping estimation as in
- [8]
- """
-
- def __init__(self):
- """ Class initialization"""
-
- self.xs = 0
- self.xt = 0
- self.G = 0
- self.L = 0
- self.bias = False
- self.computed = False
- self.metric = 'sqeuclidean'
-
- def fit(self, xs, xt, mu=1, eta=1, bias=False, **kwargs):
- """ Fit domain adaptation between samples is xs and xt (with optional
- weights)"""
- self.xs = xs
- self.xt = xt
- self.bias = bias
-
- self.ws = unif(xs.shape[0])
- self.wt = unif(xt.shape[0])
-
- self.G, self.L = joint_OT_mapping_linear(
- xs, xt, mu=mu, eta=eta, bias=bias, **kwargs)
- self.computed = True
-
- def mapping(self):
- return lambda x: self.predict(x)
-
- def predict(self, x):
- """ Out of sample mapping estimated during the call to fit"""
- if self.computed:
- if self.bias:
- x = np.hstack((x, np.ones((x.shape[0], 1))))
- return x.dot(self.L) # apply the linear mapping
- else:
- print("Warning, model not fitted yet, returning None")
- return None
-
-
-@deprecated("The class OTDA_mapping_kernel is deprecated in 0.3.1 and will be"
- " removed in 0.5 \nUse class MappingTransport instead.")
-class OTDA_mapping_kernel(OTDA_mapping_linear):
-
- """Class for optimal transport with joint nonlinear mapping
- estimation as in [8]"""
-
- def fit(self, xs, xt, mu=1, eta=1, bias=False, kerneltype='gaussian',
- sigma=1, **kwargs):
- """ Fit domain adaptation between samples is xs and xt """
- self.xs = xs
- self.xt = xt
- self.bias = bias
-
- self.ws = unif(xs.shape[0])
- self.wt = unif(xt.shape[0])
- self.kernel = kerneltype
- self.sigma = sigma
- self.kwargs = kwargs
-
- self.G, self.L = joint_OT_mapping_kernel(
- xs, xt, mu=mu, eta=eta, bias=bias, **kwargs)
- self.computed = True
-
- def predict(self, x):
- """ Out of sample mapping estimated during the call to fit"""
-
- if self.computed:
- K = kernel(
- x, self.xs, method=self.kernel, sigma=self.sigma,
- **self.kwargs)
- if self.bias:
- K = np.hstack((K, np.ones((x.shape[0], 1))))
- return K.dot(self.L)
- else:
- print("Warning, model not fitted yet, returning None")
- return None
-
-
def distribution_estimation_uniform(X):
"""estimates a uniform distribution from an array of samples X
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 5dda82a..02cbd8c 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -17,6 +17,9 @@ from .import cvx
from .emd_wrap import emd_c, check_result
from ..utils import parmap
from .cvx import barycenter
+from ..utils import dist
+
+__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx']
def emd(a, b, M, numItermax=100000, log=False):
@@ -214,3 +217,95 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
res = parmap(f, [b[:, i] for i in range(nb)], processes)
return res
+
+
+
+def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None):
+ """
+ Solves the free-support Wasserstein barycenter problem: the support locations of the barycenter are optimized, not its weights (i.e. the weighted Fréchet mean for the 2-Wasserstein distance)
+
+ The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms.
+ This problem is considered in [1] (Algorithm 2). There are two differences with the code below:
+ - we do not optimize over the weights
+ - we do not perform a line search for the locations update; i.e. we use theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting.
+
+ Parameters
+ ----------
+ measures_locations : list of (k_i,d) np.ndarray
+ The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list)
+ measures_weights : list of (k_i,) np.ndarray
+ Numpy arrays where each numpy array has k_i non-negative values summing to one, representing the weights of each discrete input measure
+
+ X_init : (k,d) np.ndarray
+ Initialization of the support locations (on k atoms) of the barycenter
+ b : (k,) np.ndarray
+ Initialization of the weights of the barycenter (non-negative, sums to 1)
+ weights : (N,) np.ndarray
+ Weights of each of the N input measures in the barycenter computation (non-negative, sums to 1)
+
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ X : (k,d) np.ndarray
+ Support locations (on k atoms) of the barycenter
+
+ References
+ ----------
+
+ .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+
+ .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+
+ """
+
+ iter_count = 0
+
+ N = len(measures_locations)
+ k = X_init.shape[0]
+ d = X_init.shape[1]
+ if b is None:
+ b = np.ones((k,))/k
+ if weights is None:
+ weights = np.ones((N,)) / N
+
+ X = X_init
+
+ log_dict = {}
+ displacement_square_norms = []
+
+ displacement_square_norm = stopThr + 1.
+
+ while ( displacement_square_norm > stopThr and iter_count < numItermax ):
+
+ T_sum = np.zeros((k, d))
+
+ for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()):
+
+ M_i = dist(X, measure_locations_i)
+ T_i = emd(b, measure_weights_i, M_i)
+ T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
+
+ displacement_square_norm = np.sum(np.square(T_sum-X))
+ if log:
+ displacement_square_norms.append(displacement_square_norm)
+
+ X = T_sum
+
+ if verbose:
+ print('iteration %d, displacement_square_norm=%f\n' % (iter_count, displacement_square_norm))
+
+ iter_count += 1
+
+ if log:
+ log_dict['displacement_square_norms'] = displacement_square_norms
+ return X, log_dict
+ else:
+ return X
\ No newline at end of file
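
A minimal usage sketch for the new function (assuming this branch of POT is installed; the point clouds and atom count are arbitrary):

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    measures_locations = [rng.randn(7, 2), rng.randn(9, 2) + 2.]  # two 2-D point clouds
    measures_weights = [ot.unif(7), ot.unif(9)]                   # uniform weights on each
    X_init = rng.randn(4, 2)                                      # barycenter on k=4 free atoms
    X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init)
    print(X.shape)  # (4, 2)
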
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index c8c75bc..8e763be 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -11,6 +11,7 @@ import numpy as np
import scipy as sp
import scipy.sparse as sps
+
try:
import cvxopt
from cvxopt import solvers, matrix, spmatrix
@@ -26,7 +27,7 @@ def scipy_sparse_to_spmatrix(A):
def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'):
- """Compute the entropic regularized wasserstein barycenter of distributions A
+ """Compute the Wasserstein barycenter of distributions A
The function solves the following optimization problem [16]:
diff --git a/ot/smooth.py b/ot/smooth.py
new file mode 100644
index 0000000..5a8e4b5
--- /dev/null
+++ b/ot/smooth.py
@@ -0,0 +1,600 @@
+#Copyright (c) 2018, Mathieu Blondel
+#All rights reserved.
+#
+#Redistribution and use in source and binary forms, with or without
+#modification, are permitted provided that the following conditions are met:
+#
+#1. Redistributions of source code must retain the above copyright notice, this
+#list of conditions and the following disclaimer.
+#
+#2. Redistributions in binary form must reproduce the above copyright notice,
+#this list of conditions and the following disclaimer in the documentation and/or
+#other materials provided with the distribution.
+#
+#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+#ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+#WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+#IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
+#INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
+#NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
+#OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+#LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+#OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+#THE POSSIBILITY OF SUCH DAMAGE.
+
+# Author: Mathieu Blondel
+# Remi Flamary <remi.flamary@unice.fr>
+
+"""
+Implementation of
+Smooth and Sparse Optimal Transport.
+Mathieu Blondel, Vivien Seguy, Antoine Rolet.
+In Proc. of AISTATS 2018.
+https://arxiv.org/abs/1710.06276
+
+[17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal
+Transport. Proceedings of the Twenty-First International Conference on
+Artificial Intelligence and Statistics (AISTATS).
+
+Original code from https://github.com/mblondel/smooth-ot/
+
+"""
+
+import numpy as np
+from scipy.optimize import minimize
+
+
+def projection_simplex(V, z=1, axis=None):
+ """ Projection of x onto the simplex, scaled by z
+
+ P(x; z) = argmin_{y >= 0, sum(y) = z} ||y - x||^2
+ z: float or array
+ If array, len(z) must be compatible with V
+ axis: None or int
+ - axis=None: project V by P(V.ravel(); z)
+ - axis=1: project each V[i] by P(V[i]; z[i])
+ - axis=0: project each V[:, j] by P(V[:, j]; z[j])
+ """
+ if axis == 1:
+ n_features = V.shape[1]
+ U = np.sort(V, axis=1)[:, ::-1]
+ z = np.ones(len(V)) * z
+ cssv = np.cumsum(U, axis=1) - z[:, np.newaxis]
+ ind = np.arange(n_features) + 1
+ cond = U - cssv / ind > 0
+ rho = np.count_nonzero(cond, axis=1)
+ theta = cssv[np.arange(len(V)), rho - 1] / rho
+ return np.maximum(V - theta[:, np.newaxis], 0)
+
+ elif axis == 0:
+ return projection_simplex(V.T, z, axis=1).T
+
+ else:
+ V = V.ravel().reshape(1, -1)
+ return projection_simplex(V, z, axis=1).ravel()
+
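
A short standalone sketch: project a vector onto the probability simplex and check the constraints hold.

    import numpy as np
    from ot.smooth import projection_simplex

    x = np.array([0.9, 1.2, -0.3])
    p = projection_simplex(x)  # axis=None: project the flattened vector with z=1
    assert np.isclose(p.sum(), 1.0) and np.all(p >= 0)
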
+
+class Regularization(object):
+ """Base class for Regularization objects
+
+ Notes
+ -----
+ This class is not intended for direct use but as a parent class for
+ concrete regularization implementations.
+ """
+
+ def __init__(self, gamma=1.0):
+ """
+
+ Parameters
+ ----------
+ gamma: float
+ Regularization parameter.
+ We recover unregularized OT when gamma -> 0.
+
+ """
+ self.gamma = gamma
+
+ def delta_Omega(self, X):
+ """
+ Compute delta_Omega(X[:, j]) for each X[:, j].
+ delta_Omega(x) = sup_{y >= 0} y^T x - Omega(y).
+
+ Parameters
+ ----------
+ X: array, shape = len(a) x len(b)
+ Input array.
+
+ Returns
+ -------
+ v: array, len(b)
+ Values: v[j] = delta_Omega(X[:, j])
+ G: array, len(a) x len(b)
+ Gradients: G[:, j] = nabla delta_Omega(X[:, j])
+ """
+ raise NotImplementedError
+
+ def max_Omega(self, X, b):
+ """
+ Compute max_Omega_j(X[:, j]) for each X[:, j].
+ max_Omega_j(x) = sup_{y >= 0, sum(y) = 1} y^T x - Omega(b[j] y) / b[j].
+
+ Parameters
+ ----------
+ X: array, shape = len(a) x len(b)
+ Input array.
+
+ Returns
+ -------
+ v: array, len(b)
+ Values: v[j] = max_Omega_j(X[:, j])
+ G: array, len(a) x len(b)
+ Gradients: G[:, j] = nabla max_Omega_j(X[:, j])
+ """
+ raise NotImplementedError
+
+ def Omega(self, T):
+ """
+ Compute regularization term.
+
+ Parameters
+ ----------
+ T: array, shape = len(a) x len(b)
+ Input array.
+
+ Returns
+ -------
+ value: float
+ Regularization term.
+ """
+ raise NotImplementedError
+
+
+class NegEntropy(Regularization):
+ """ NegEntropy regularization """
+
+ def delta_Omega(self, X):
+ G = np.exp(X / self.gamma - 1)
+ val = self.gamma * np.sum(G, axis=0)
+ return val, G
+
+ def max_Omega(self, X, b):
+ max_X = np.max(X, axis=0) / self.gamma
+ exp_X = np.exp(X / self.gamma - max_X)
+ val = self.gamma * (np.log(np.sum(exp_X, axis=0)) + max_X)
+ val -= self.gamma * np.log(b)
+ G = exp_X / np.sum(exp_X, axis=0)
+ return val, G
+
+ def Omega(self, T):
+ return self.gamma * np.sum(T * np.log(T))
+
+
+class SquaredL2(Regularization):
+ """ Squared L2 regularization """
+
+ def delta_Omega(self, X):
+ max_X = np.maximum(X, 0)
+ val = np.sum(max_X ** 2, axis=0) / (2 * self.gamma)
+ G = max_X / self.gamma
+ return val, G
+
+ def max_Omega(self, X, b):
+ G = projection_simplex(X / (b * self.gamma), axis=0)
+ val = np.sum(X * G, axis=0)
+ val -= 0.5 * self.gamma * b * np.sum(G * G, axis=0)
+ return val, G
+
+ def Omega(self, T):
+ return 0.5 * self.gamma * np.sum(T ** 2)
+
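
Since the gradient of a convex conjugate is the maximizer itself, the (val, G) pair returned by delta_Omega must satisfy val[j] = G[:, j]^T X[:, j] - Omega(G[:, j]). A standalone numeric check of that identity:

    import numpy as np
    from ot.smooth import NegEntropy

    regul = NegEntropy(gamma=1.0)
    X = np.array([[0.2, -0.1], [-0.5, 0.3]])
    val, G = regul.delta_Omega(X)
    for j in range(X.shape[1]):
        assert np.isclose(val[j], np.dot(G[:, j], X[:, j]) - regul.Omega(G[:, j]))
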
+
+def dual_obj_grad(alpha, beta, a, b, C, regul):
+ """
+ Compute objective value and gradients of dual objective.
+
+ Parameters
+ ----------
+ alpha: array, shape = len(a)
+ beta: array, shape = len(b)
+ Current iterate of dual potentials.
+ a: array, shape = len(a)
+ b: array, shape = len(b)
+ Input histograms (should be non-negative and sum to 1).
+ C: array, shape = len(a) x len(b)
+ Ground cost matrix.
+ regul: Regularization object
+ Should implement a delta_Omega(X) method.
+
+ Returns
+ -------
+ obj: float
+ Objective value (higher is better).
+ grad_alpha: array, shape = len(a)
+ Gradient w.r.t. alpha.
+ grad_beta: array, shape = len(b)
+ Gradient w.r.t. beta.
+ """
+ obj = np.dot(alpha, a) + np.dot(beta, b)
+ grad_alpha = a.copy()
+ grad_beta = b.copy()
+
+ # X[:, j] = alpha + beta[j] - C[:, j]
+ X = alpha[:, np.newaxis] + beta - C
+
+ # val.shape = len(b)
+ # G.shape = len(a) x len(b)
+ val, G = regul.delta_Omega(X)
+
+ obj -= np.sum(val)
+ grad_alpha -= G.sum(axis=1)
+ grad_beta -= G.sum(axis=0)
+
+ return obj, grad_alpha, grad_beta
+
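
A finite-difference check of the returned gradients (standalone sketch; SquaredL2 is smooth away from zero, so random inputs are safe):

    import numpy as np
    from ot.smooth import SquaredL2, dual_obj_grad

    rng = np.random.RandomState(0)
    a, b = np.full(3, 1. / 3.), np.full(4, 1. / 4.)
    C = rng.rand(3, 4)
    alpha, beta = rng.randn(3), rng.randn(4)
    obj, g_alpha, g_beta = dual_obj_grad(alpha, beta, a, b, C, SquaredL2(gamma=1.0))
    eps = 1e-6
    d = np.zeros(3)
    d[0] = eps
    obj_eps = dual_obj_grad(alpha + d, beta, a, b, C, SquaredL2(gamma=1.0))[0]
    assert np.isclose((obj_eps - obj) / eps, g_alpha[0], atol=1e-4)
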
+
+def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
+ verbose=False):
+ """
+ Solve the "smoothed" dual objective.
+
+ Parameters
+ ----------
+ a: array, shape = len(a)
+ b: array, shape = len(b)
+ Input histograms (should be non-negative and sum to 1).
+ C: array, shape = len(a) x len(b)
+ Ground cost matrix.
+ regul: Regularization object
+ Should implement a delta_Omega(X) method.
+ method: str
+ Solver to be used (passed to `scipy.optimize.minimize`).
+ tol: float
+ Tolerance parameter.
+ max_iter: int
+ Maximum number of iterations.
+
+ Returns
+ -------
+ alpha: array, shape = len(a)
+ beta: array, shape = len(b)
+ Dual potentials.
+ res: scipy.optimize.OptimizeResult
+ Result of the underlying scipy minimize call.
+ """
+
+ def _func(params):
+ # Unpack alpha and beta.
+ alpha = params[:len(a)]
+ beta = params[len(a):]
+
+ obj, grad_alpha, grad_beta = dual_obj_grad(alpha, beta, a, b, C, regul)
+
+ # Pack grad_alpha and grad_beta.
+ grad = np.concatenate((grad_alpha, grad_beta))
+
+ # We need to maximize the dual.
+ return -obj, -grad
+
+ # Unfortunately, `minimize` only supports functions whose argument is a
+ # vector. So, we need to concatenate alpha and beta.
+ alpha_init = np.zeros(len(a))
+ beta_init = np.zeros(len(b))
+ params_init = np.concatenate((alpha_init, beta_init))
+
+ res = minimize(_func, params_init, method=method, jac=True,
+ tol=tol, options=dict(maxiter=max_iter, disp=verbose))
+
+ alpha = res.x[:len(a)]
+ beta = res.x[len(a):]
+
+ return alpha, beta, res
+
+
+def semi_dual_obj_grad(alpha, a, b, C, regul):
+ """
+ Compute objective value and gradient of semi-dual objective.
+
+ Parameters
+ ----------
+ alpha: array, shape = len(a)
+ Current iterate of semi-dual potentials.
+ a: array, shape = len(a)
+ b: array, shape = len(b)
+ Input histograms (should be non-negative and sum to 1).
+ C: array, shape = len(a) x len(b)
+ Ground cost matrix.
+ regul: Regularization object
+ Should implement a max_Omega(X, b) method.
+
+ Returns
+ -------
+ obj: float
+ Objective value (higher is better).
+ grad: array, shape = len(a)
+ Gradient w.r.t. alpha.
+ """
+ obj = np.dot(alpha, a)
+ grad = a.copy()
+
+ # X[:, j] = alpha - C[:, j]
+ X = alpha[:, np.newaxis] - C
+
+ # val.shape = len(b)
+ # G.shape = len(a) x len(b)
+ val, G = regul.max_Omega(X, b)
+
+ obj -= np.dot(b, val)
+ grad -= np.dot(G, b)
+
+ return obj, grad
+
+
+def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
+ verbose=False):
+ """
+ Solve the "smoothed" semi-dual objective.
+
+ Parameters
+ ----------
+ a: array, shape = len(a)
+ b: array, shape = len(b)
+ Input histograms (should be non-negative and sum to 1).
+ C: array, shape = len(a) x len(b)
+ Ground cost matrix.
+ regul: Regularization object
+ Should implement a max_Omega(X, b) method.
+ method: str
+ Solver to be used (passed to `scipy.optimize.minimize`).
+ tol: float
+ Tolerance parameter.
+ max_iter: int
+ Maximum number of iterations.
+
+ Returns
+ -------
+ alpha: array, shape = len(a)
+ Semi-dual potentials.
+ res: scipy.optimize.OptimizeResult
+ Result of the underlying scipy minimize call.
+ """
+
+ def _func(alpha):
+ obj, grad = semi_dual_obj_grad(alpha, a, b, C, regul)
+ # We need to maximize the semi-dual.
+ return -obj, -grad
+
+ alpha_init = np.zeros(len(a))
+
+ res = minimize(_func, alpha_init, method=method, jac=True,
+ tol=tol, options=dict(maxiter=max_iter, disp=verbose))
+
+ return res.x, res
+
+
+def get_plan_from_dual(alpha, beta, C, regul):
+ """
+ Retrieve optimal transportation plan from optimal dual potentials.
+
+ Parameters
+ ----------
+ alpha: array, shape = len(a)
+ beta: array, shape = len(b)
+ Optimal dual potentials.
+ C: array, shape = len(a) x len(b)
+ Ground cost matrix.
+ regul: Regularization object
+ Should implement a delta_Omega(X) method.
+
+ Returns
+ -------
+ T: array, shape = len(a) x len(b)
+ Optimal transportation plan.
+ """
+ X = alpha[:, np.newaxis] + beta - C
+ return regul.delta_Omega(X)[1]
+
+
+def get_plan_from_semi_dual(alpha, b, C, regul):
+ """
+ Retrieve optimal transportation plan from optimal semi-dual potentials.
+
+ Parameters
+ ----------
+ alpha: array, shape = len(a)
+ Optimal semi-dual potentials.
+ b: array, shape = len(b)
+ Second input histogram (should be non-negative and sum to 1).
+ C: array, shape = len(a) x len(b)
+ Ground cost matrix.
+ regul: Regularization object
+ Should implement a delta_Omega(X) method.
+
+ Returns
+ -------
+ T: array, shape = len(a) x len(b)
+ Optimal transportation plan.
+ """
+ X = alpha[:, np.newaxis] - C
+ return regul.max_Omega(X, b)[1] * b
+
+
+def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
+ numItermax=500, verbose=False, log=False):
+ r"""
+ Solve the regularized OT problem in the dual and return the OT matrix
+
+ The function solves the smooth relaxed dual formulation (7) in [17]_ :
+
+ .. math::
+ \max_{\alpha,\beta}\quad a^T\alpha+b^T\beta-\sum_j\delta_\Omega(\alpha+\beta_j-\mathbf{m}_j)
+
+ where :
+
+ - :math:`\mathbf{m}_j` is the jth column of the cost matrix
+ - :math:`\delta_\Omega` is the convex conjugate of the regularization term :math:`\Omega`
+ - a and b are source and target weights (sum to 1)
+
+ The OT matrix can be reconstructed from the gradient of :math:`\delta_\Omega`
+ (see [17]_ Proposition 1).
+ The optimization algorithm uses gradient descent (L-BFGS by default).
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ b : np.ndarray (nt,)
+ samples weights in the target domain
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ reg_type : str
+ Regularization type, can be the following (default ='l2'):
+ - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_)
+ - 'l2' : Squared Euclidean regularization
+ method : str
+ Solver to use for scipy.optimize.minimize
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ if reg_type.lower() in ['l2', 'squaredl2']:
+ regul = SquaredL2(gamma=reg)
+ elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
+ regul = NegEntropy(gamma=reg)
+ else:
+ raise NotImplementedError('Unknown regularization')
+
+ # solve dual
+ alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax,
+ tol=stopThr, verbose=verbose)
+
+ # reconstruct transport matrix
+ G = get_plan_from_dual(alpha, beta, M, regul)
+
+ if log:
+ log = {'alpha': alpha, 'beta': beta, 'res': res}
+ return G, log
+ else:
+ return G
+
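
A minimal usage sketch (assuming this branch of POT is installed). Note that with the relaxed dual the marginals are only satisfied approximately, with the violation shrinking as reg decreases:

    import numpy as np
    import ot
    from ot.smooth import smooth_ot_dual

    rng = np.random.RandomState(0)
    a, b = ot.unif(3), ot.unif(4)
    M = ot.dist(rng.randn(3, 2), rng.randn(4, 2))
    G = smooth_ot_dual(a, b, M, reg=1., reg_type='l2')
    print(np.abs(G.sum(axis=1) - a).max(), np.abs(G.sum(axis=0) - b).max())
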
+
+def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
+ numItermax=500, verbose=False, log=False):
+ r"""
+ Solve the regularized OT problem in the semi-dual and return the OT matrix
+
+ The function solves the smooth relaxed dual formulation (10) in [17]_ :
+
+ .. math::
+ \max_{\alpha}\quad a^T\alpha-OT_\Omega^*(\alpha,b)
+
+ where :
+
+ .. math::
+ OT_\Omega^*(\alpha,b)=\sum_j b_j \max_{\Omega,j}(\alpha-\mathbf{m}_j)
+
+ - :math:`\mathbf{m}_j` is the jth column of the cost matrix
+ - :math:`OT_\Omega^*(\alpha,b)` is defined in Eq. (9) in [17]_ and uses the smoothed max implemented by `Regularization.max_Omega` above
+ - a and b are source and target weights (sum to 1)
+
+ The OT matrix can be reconstructed using [17]_ Proposition 2.
+ The optimization algorithm uses gradient descent (L-BFGS by default).
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ b : np.ndarray (nt,)
+ samples weights in the target domain
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ reg_type : str
+ Regularization type, can be the following (default ='l2'):
+ - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_)
+ - 'l2' : Squared Euclidean regularization
+ method : str
+ Solver to use for scipy.optimize.minimize
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+ if reg_type.lower() in ['l2', 'squaredl2']:
+ regul = SquaredL2(gamma=reg)
+ elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
+ regul = NegEntropy(gamma=reg)
+ else:
+ raise NotImplementedError('Unknown regularization')
+
+ # solve dual
+ alpha, res = solve_semi_dual(a, b, M, regul, max_iter=numItermax,
+ tol=stopThr, verbose=verbose)
+
+ # reconstruct transport matrix
+ G = get_plan_from_semi_dual(alpha, b, M, regul)
+
+ if log:
+ log = {'alpha': alpha, 'res': res}
+ return G, log
+ else:
+ return G
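
The same toy problem through the semi-dual (same assumptions as the sketch above). One design note: each column of the semi-dual plan is b[j] times a point on the simplex, so the target marginal is satisfied exactly while the source marginal remains approximate:

    import numpy as np
    import ot
    from ot.smooth import smooth_ot_semi_dual

    rng = np.random.RandomState(0)
    a, b = ot.unif(3), ot.unif(4)
    M = ot.dist(rng.randn(3, 2), rng.randn(4, 2))
    G = smooth_ot_semi_dual(a, b, M, reg=1., reg_type='l2')
    assert np.allclose(G.sum(axis=0), b)  # exact by construction
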
diff --git a/ot/utils.py b/ot/utils.py
index 7dac283..bb21b38 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -77,6 +77,34 @@ def clean_zeros(a, b, M):
return a2, b2, M2
+def euclidean_distances(X, Y, squared=False):
+ """
+ Considering the rows of X (and Y=X) as vectors, compute the
+ distance matrix between each pair of vectors.
+
+ Parameters
+ ----------
+ X : {array-like}, shape (n_samples_1, n_features)
+ Y : {array-like}, shape (n_samples_2, n_features)
+ squared : boolean, optional
+ Return squared Euclidean distances.
+
+ Returns
+ -------
+ distances : {array}, shape (n_samples_1, n_samples_2)
+ """
+ XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis]
+ YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :]
+ distances = np.dot(X, Y.T)
+ distances *= -2
+ distances += XX
+ distances += YY
+ np.maximum(distances, 0, out=distances)
+ if X is Y:
+ # Ensure that distances between vectors and themselves are set to 0.0.
+ # This may not be the case due to floating point rounding errors.
+ distances.flat[::distances.shape[0] + 1] = 0.0
+ return distances if squared else np.sqrt(distances, out=distances)
+
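
A quick standalone equivalence check against scipy's cdist:

    import numpy as np
    from scipy.spatial.distance import cdist
    from ot.utils import euclidean_distances

    rng = np.random.RandomState(0)
    X, Y = rng.randn(5, 3), rng.randn(6, 3)
    assert np.allclose(euclidean_distances(X, Y, squared=True),
                       cdist(X, Y, metric='sqeuclidean'))
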
+
def dist(x1, x2=None, metric='sqeuclidean'):
"""Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
@@ -104,7 +132,8 @@ def dist(x1, x2=None, metric='sqeuclidean'):
"""
if x2 is None:
x2 = x1
-
+ if metric == "sqeuclidean":
+ return euclidean_distances(x1, x2, squared=True)
return cdist(x1, x2, metric=metric)