-rw-r--r--   Makefile                                    |  11
-rw-r--r--   README.md                                   |  10
-rw-r--r--   data/duck.png                               | bin 0 -> 5112 bytes
-rw-r--r--   data/heart.png                              | bin 0 -> 5225 bytes
-rw-r--r--   data/redcross.png                           | bin 0 -> 1683 bytes
-rw-r--r--   data/tooth.png                              | bin 0 -> 4931 bytes
-rw-r--r--   examples/plot_convolutional_barycenter.py   |  92
-rw-r--r--   ot/bregman.py                               | 125
-rw-r--r--   ot/da.py                                    | 284
-rw-r--r--   ot/stochastic.py                            | 179
-rw-r--r--   test/test_bregman.py                        |  24
-rw-r--r--   test/test_da.py                             |  63
-rw-r--r--   test/test_stochastic.py                     |  60
13 files changed, 337 insertions(+), 511 deletions(-)
diff --git a/Makefile b/Makefile
index 1abc6e9..84a644b 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,7 @@
PYTHON=python3
+branch := $(shell git symbolic-ref --short -q HEAD)
help :
@echo "The following make targets are available:"
@@ -57,6 +58,16 @@ rdoc :
notebook :
ipython notebook --matplotlib=inline --notebook-dir=notebooks/
+bench :
+ @git stash >/dev/null 2>&1
+ @echo 'Branch master'
+ @git checkout master >/dev/null 2>&1
+ python3 $(script)
+ @echo 'Branch $(branch)'
+ @git checkout $(branch) >/dev/null 2>&1
+ python3 $(script)
+ @git stash apply >/dev/null 2>&1
+
autopep8 :
autopep8 -ir test ot examples --jobs -1
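The new bench target is a small convenience for timing comparisons: it stashes local changes, runs a benchmark script on master, runs the same script again on the current branch, and then restores the stash. A typical invocation would be `make bench script=examples/plot_OT_2D_samples.py` (the script path is only an illustration; any Python file can be passed through the `script` variable).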
diff --git a/README.md b/README.md
index d2f0fea..e56ae40 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,6 @@ It provides the following solvers:
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (requires cudamat).
* Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations [17].
* Non regularized Wasserstein barycenters [16] with LP solver (only small scale).
-* Non regularized free support Wasserstein barycenters [20].
* Bregman projections for Wasserstein barycenter [3] and unmixing [4].
* Optimal transport for domain adaptation with group lasso regularization [5]
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
@@ -26,6 +25,7 @@ It provides the following solvers:
* Wasserstein Discriminant Analysis [11] (requires autograd + pymanopt).
* Gromov-Wasserstein distances and barycenters ([13] and regularized [12])
* Stochastic Optimization for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
+* Non regularized free support Wasserstein barycenters [20].
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
@@ -165,7 +165,7 @@ The contributors to this library are:
* [Stanislas Chambon](https://slasnista.github.io/)
* [Antoine Rolet](https://arolet.github.io/)
* Erwan Vautier (Gromov-Wasserstein)
-* [Kilian Fatras](https://kilianfatras.github.io/) (Stochastic optimization)
+* [Kilian Fatras](https://kilianfatras.github.io/)
This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
@@ -224,8 +224,10 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[17] Blondel, M., Seguy, V., & Rolet, A. (2018). [Smooth and Sparse Optimal Transport](https://arxiv.org/abs/1710.06276). Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
-[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](arXiv preprint arxiv:1605.08527). Advances in Neural Information Processing Systems (2016).
+[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](https://arxiv.org/abs/1605.08527). Advances in Neural Information Processing Systems (2016).
[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representation (2018)
-[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning
\ No newline at end of file
+[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference on Machine Learning
+
+[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66.
diff --git a/data/duck.png b/data/duck.png
new file mode 100644
index 0000000..9181697
--- /dev/null
+++ b/data/duck.png
Binary files differ
diff --git a/data/heart.png b/data/heart.png
new file mode 100644
index 0000000..44a6385
--- /dev/null
+++ b/data/heart.png
Binary files differ
diff --git a/data/redcross.png b/data/redcross.png
new file mode 100644
index 0000000..8d0a6fa
--- /dev/null
+++ b/data/redcross.png
Binary files differ
diff --git a/data/tooth.png b/data/tooth.png
new file mode 100644
index 0000000..cd92c9d
--- /dev/null
+++ b/data/tooth.png
Binary files differ
diff --git a/examples/plot_convolutional_barycenter.py b/examples/plot_convolutional_barycenter.py
new file mode 100644
index 0000000..e74db04
--- /dev/null
+++ b/examples/plot_convolutional_barycenter.py
@@ -0,0 +1,92 @@
+
+# -*- coding: utf-8 -*-
+#%%
+"""
+============================================
+Convolutional Wasserstein Barycenter example
+============================================
+
+This example is designed to illustrate how the Convolutional Wasserstein Barycenter
+function of POT works.
+"""
+
+# Author: Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+
+import numpy as np
+import pylab as pl
+import ot
+
+##############################################################################
+# Data preparation
+# ----------------
+#
+# The four distributions are constructed from 4 simple images
+
+
+f1 = 1 - pl.imread('../data/redcross.png')[:, :, 2]
+f2 = 1 - pl.imread('../data/duck.png')[:, :, 2]
+f3 = 1 - pl.imread('../data/heart.png')[:, :, 2]
+f4 = 1 - pl.imread('../data/tooth.png')[:, :, 2]
+
+A = []
+f1 = f1 / np.sum(f1)
+f2 = f2 / np.sum(f2)
+f3 = f3 / np.sum(f3)
+f4 = f4 / np.sum(f4)
+A.append(f1)
+A.append(f2)
+A.append(f3)
+A.append(f4)
+A = np.array(A)
+
+nb_images = 5
+
+# these are the four corner coordinates that will be interpolated by
+# bilinear interpolation
+v1 = np.array((1, 0, 0, 0))
+v2 = np.array((0, 1, 0, 0))
+v3 = np.array((0, 0, 1, 0))
+v4 = np.array((0, 0, 0, 1))
+
+
+##############################################################################
+# Barycenter computation and visualization
+# ----------------------------------------
+#
+
+pl.figure(figsize=(10, 10))
+pl.title('Convolutional Wasserstein Barycenters in POT')
+cm = 'Blues'
+# regularization parameter
+reg = 0.004
+for i in range(nb_images):
+ for j in range(nb_images):
+ pl.subplot(nb_images, nb_images, i * nb_images + j + 1)
+ tx = float(i) / (nb_images - 1)
+ ty = float(j) / (nb_images - 1)
+
+ # weights are constructed by bilinear interpolation
+ tmp1 = (1 - tx) * v1 + tx * v2
+ tmp2 = (1 - tx) * v3 + tx * v4
+ weights = (1 - ty) * tmp1 + ty * tmp2
+
+ if i == 0 and j == 0:
+ pl.imshow(f1, cmap=cm)
+ pl.axis('off')
+ elif i == 0 and j == (nb_images - 1):
+ pl.imshow(f3, cmap=cm)
+ pl.axis('off')
+ elif i == (nb_images - 1) and j == 0:
+ pl.imshow(f2, cmap=cm)
+ pl.axis('off')
+ elif i == (nb_images - 1) and j == (nb_images - 1):
+ pl.imshow(f4, cmap=cm)
+ pl.axis('off')
+ else:
+ # call to barycenter computation
+ pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm)
+ pl.axis('off')
+pl.show()
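In the resulting 5 x 5 figure, the four corners display the input images (red cross, duck, heart, tooth) and every interior cell shows the convolutional barycenter whose weights come from the bilinear interpolation of the cell position, so moving across the grid smoothly morphs one shape into another.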
diff --git a/ot/bregman.py b/ot/bregman.py
index c8e69ce..35e51f8 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -350,7 +350,6 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
np.exp(K, out=K)
# print(np.min(K))
- tmp = np.empty(K.shape, dtype=M.dtype)
tmp2 = np.empty(b.shape, dtype=M.dtype)
Kp = (1 / a).reshape(-1, 1) * K
@@ -359,6 +358,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
while (err > stopThr and cpt < numItermax):
uprev = u
vprev = v
+
KtransposeU = np.dot(K.T, u)
v = np.divide(b, KtransposeU)
u = 1. / np.dot(Kp, v)
@@ -379,11 +379,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
np.sum((v - vprev)**2) / np.sum((v)**2)
else:
- np.multiply(u.reshape(-1, 1), K, out=tmp)
- np.multiply(tmp, v.reshape(1, -1), out=tmp)
- np.sum(tmp, axis=0, out=tmp2)
- tmp2 -= b
- err = np.linalg.norm(tmp2)**2
+ # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
+ np.einsum('i,ij,j->j', u, K, v, out=tmp2)
+ err = np.linalg.norm(tmp2 - b)**2 # violation of marginal
if log:
log['err'].append(err)
@@ -398,10 +396,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
log['v'] = v
if nbb: # return only loss
- res = np.zeros((nbb))
- for i in range(nbb):
- res[i] = np.sum(
- u[:, i].reshape((-1, 1)) * K * v[:, i].reshape((1, -1)) * M)
+ res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
if log:
return res, log
else:
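As a side note (not part of the patch), here is a minimal sketch on synthetic arrays of the identity behind the first np.einsum call above: it yields the right marginal of diag(u) K diag(v) without the explicit temporary matrices used previously. The second einsum, 'ik,ij,jk,ij->k', plays the same role for the batched loss and replaces the per-column Python loop.

    import numpy as np

    rng = np.random.RandomState(0)
    u, v = rng.rand(5), rng.rand(4)   # scaling vectors
    K = rng.rand(5, 4)                # kernel matrix

    # right marginal with explicit temporaries (old code path)
    marg_tmp = (u.reshape(-1, 1) * K * v.reshape(1, -1)).sum(axis=0)
    # same quantity via einsum (new code path)
    marg_einsum = np.einsum('i,ij,j->j', u, K, v)
    assert np.allclose(marg_tmp, marg_einsum)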
@@ -924,6 +919,116 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
return geometricBar(weights, UKv)
+def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False):
+ """Compute the entropic regularized wasserstein barycenter of distributions A
+ where A is a collection of 2D images.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`\mathbf{a}_i` are training distributions (2D images) in the last two dimensions of matrix :math:`\mathbf{A}`
+ - reg is the regularization strength scalar value
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [21]_
+
+ Parameters
+ ----------
+ A : np.ndarray (n,w,h)
+ n distributions (2D images) of size w x h
+ reg : float
+ Regularization term >0
+ weights : np.ndarray (n,)
+ Weights of each image on the simplex (barycentric coordinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ stabThr : float, optional
+ Stabilization threshold to avoid numerical precision issues
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (w,h) ndarray
+ 2D Wasserstein barycenter
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015).
+ Convolutional wasserstein distances: Efficient optimal transportation on geometric domains
+ ACM Transactions on Graphics (TOG), 34(4), 66
+
+
+ """
+
+ if weights is None:
+ weights = np.ones(A.shape[0]) / A.shape[0]
+ else:
+ assert(len(weights) == A.shape[0])
+
+ if log:
+ log = {'err': []}
+
+ b = np.zeros_like(A[0, :, :])
+ U = np.ones_like(A)
+ KV = np.ones_like(A)
+
+ cpt = 0
+ err = 1
+
+ # build the convolution operator
+ t = np.linspace(0, 1, A.shape[1])
+ [Y, X] = np.meshgrid(t, t)
+ xi1 = np.exp(-(X - Y)**2 / reg)
+
+ def K(x):
+ return np.dot(np.dot(xi1, x), xi1)
+
+ while (err > stopThr and cpt < numItermax):
+
+ bold = b
+ cpt = cpt + 1
+
+ b = np.zeros_like(A[0, :, :])
+ for r in range(A.shape[0]):
+ KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :])))
+ b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :]))
+ b = np.exp(b)
+ for r in range(A.shape[0]):
+ U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :])
+
+ if cpt % 10 == 1:
+ err = np.sum(np.abs(bold - b))
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ if log:
+ log['niter'] = cpt
+ log['U'] = U
+ return b, log
+ else:
+ return b
+
+
def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
stopThr=1e-3, verbose=False, log=False):
"""
diff --git a/ot/da.py b/ot/da.py
index 48b418f..bc09e3c 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -15,7 +15,7 @@ import scipy.linalg as linalg
from .bregman import sinkhorn
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization
-from .utils import check_params, deprecated, BaseEstimator
+from .utils import check_params, BaseEstimator
from .optim import cg
from .optim import gcg
@@ -740,288 +740,6 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
return A, b
-@deprecated("The class OTDA is deprecated in 0.3.1 and will be "
- "removed in 0.5"
- "\n\tfor standard transport use class EMDTransport instead.")
-class OTDA(object):
-
- """Class for domain adaptation with optimal transport as proposed in [5]
-
-
- References
- ----------
-
- .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
- "Optimal Transport for Domain Adaptation," in IEEE Transactions on
- Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
-
- """
-
- def __init__(self, metric='sqeuclidean', norm=None):
- """ Class initialization"""
- self.xs = 0
- self.xt = 0
- self.G = 0
- self.metric = metric
- self.norm = norm
- self.computed = False
-
- def fit(self, xs, xt, ws=None, wt=None, max_iter=100000):
- """Fit domain adaptation between samples is xs and xt
- (with optional weights)"""
- self.xs = xs
- self.xt = xt
-
- if wt is None:
- wt = unif(xt.shape[0])
- if ws is None:
- ws = unif(xs.shape[0])
-
- self.ws = ws
- self.wt = wt
-
- self.M = dist(xs, xt, metric=self.metric)
- self.M = cost_normalization(self.M, self.norm)
- self.G = emd(ws, wt, self.M, max_iter)
- self.computed = True
-
- def interp(self, direction=1):
- """Barycentric interpolation for the source (1) or target (-1) samples
-
- This Barycentric interpolation solves for each source (resp target)
- sample xs (resp xt) the following optimization problem:
-
- .. math::
- arg\min_x \sum_i \gamma_{k,i} c(x,x_i^t)
-
- where k is the index of the sample in xs
-
- For the moment only squared euclidean distance is provided but more
- metric could be used in the future.
-
- """
- if direction > 0: # >0 then source to target
- G = self.G
- w = self.ws.reshape((self.xs.shape[0], 1))
- x = self.xt
- else:
- G = self.G.T
- w = self.wt.reshape((self.xt.shape[0], 1))
- x = self.xs
-
- if self.computed:
- if self.metric == 'sqeuclidean':
- return np.dot(G / w, x) # weighted mean
- else:
- print(
- "Warning, metric not handled yet, using weighted average")
- return np.dot(G / w, x) # weighted mean
- return None
- else:
- print("Warning, model not fitted yet, returning None")
- return None
-
- def predict(self, x, direction=1):
- """ Out of sample mapping using the formulation from [6]
-
- For each sample x to map, it finds the nearest source sample xs and
- map the samle x to the position xst+(x-xs) wher xst is the barycentric
- interpolation of source sample xs.
-
- References
- ----------
-
- .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
- Regularized discrete optimal transport. SIAM Journal on Imaging
- Sciences, 7(3), 1853-1882.
-
- """
- if direction > 0: # >0 then source to target
- xf = self.xt
- x0 = self.xs
- else:
- xf = self.xs
- x0 = self.xt
-
- D0 = dist(x, x0) # dist netween new samples an source
- idx = np.argmin(D0, 1) # closest one
- xf = self.interp(direction) # interp the source samples
- # aply the delta to the interpolation
- return xf[idx, :] + x - x0[idx, :]
-
-
-@deprecated("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be"
- " removed in 0.5 \nUse class SinkhornTransport instead.")
-class OTDA_sinkhorn(OTDA):
-
- """Class for domain adaptation with optimal transport with entropic
- regularization
-
-
- """
-
- def fit(self, xs, xt, reg=1, ws=None, wt=None, **kwargs):
- """Fit regularized domain adaptation between samples is xs and xt
- (with optional weights)"""
- self.xs = xs
- self.xt = xt
-
- if wt is None:
- wt = unif(xt.shape[0])
- if ws is None:
- ws = unif(xs.shape[0])
-
- self.ws = ws
- self.wt = wt
-
- self.M = dist(xs, xt, metric=self.metric)
- self.M = cost_normalization(self.M, self.norm)
- self.G = sinkhorn(ws, wt, self.M, reg, **kwargs)
- self.computed = True
-
-
-@deprecated("The class OTDA_lpl1 is deprecated in 0.3.1 and will be"
- " removed in 0.5 \nUse class SinkhornLpl1Transport instead.")
-class OTDA_lpl1(OTDA):
-
- """Class for domain adaptation with optimal transport with entropic and
- group regularization"""
-
- def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs):
- """Fit regularized domain adaptation between samples is xs and xt
- (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit
- parameters"""
- self.xs = xs
- self.xt = xt
-
- if wt is None:
- wt = unif(xt.shape[0])
- if ws is None:
- ws = unif(xs.shape[0])
-
- self.ws = ws
- self.wt = wt
-
- self.M = dist(xs, xt, metric=self.metric)
- self.M = cost_normalization(self.M, self.norm)
- self.G = sinkhorn_lpl1_mm(ws, ys, wt, self.M, reg, eta, **kwargs)
- self.computed = True
-
-
-@deprecated("The class OTDA_l1L2 is deprecated in 0.3.1 and will be"
- " removed in 0.5 \nUse class SinkhornL1l2Transport instead.")
-class OTDA_l1l2(OTDA):
-
- """Class for domain adaptation with optimal transport with entropic
- and group lasso regularization"""
-
- def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs):
- """Fit regularized domain adaptation between samples is xs and xt
- (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit
- parameters"""
- self.xs = xs
- self.xt = xt
-
- if wt is None:
- wt = unif(xt.shape[0])
- if ws is None:
- ws = unif(xs.shape[0])
-
- self.ws = ws
- self.wt = wt
-
- self.M = dist(xs, xt, metric=self.metric)
- self.M = cost_normalization(self.M, self.norm)
- self.G = sinkhorn_l1l2_gl(ws, ys, wt, self.M, reg, eta, **kwargs)
- self.computed = True
-
-
-@deprecated("The class OTDA_mapping_linear is deprecated in 0.3.1 and will be"
- " removed in 0.5 \nUse class MappingTransport instead.")
-class OTDA_mapping_linear(OTDA):
-
- """Class for optimal transport with joint linear mapping estimation as in
- [8]
- """
-
- def __init__(self):
- """ Class initialization"""
-
- self.xs = 0
- self.xt = 0
- self.G = 0
- self.L = 0
- self.bias = False
- self.computed = False
- self.metric = 'sqeuclidean'
-
- def fit(self, xs, xt, mu=1, eta=1, bias=False, **kwargs):
- """ Fit domain adaptation between samples is xs and xt (with optional
- weights)"""
- self.xs = xs
- self.xt = xt
- self.bias = bias
-
- self.ws = unif(xs.shape[0])
- self.wt = unif(xt.shape[0])
-
- self.G, self.L = joint_OT_mapping_linear(
- xs, xt, mu=mu, eta=eta, bias=bias, **kwargs)
- self.computed = True
-
- def mapping(self):
- return lambda x: self.predict(x)
-
- def predict(self, x):
- """ Out of sample mapping estimated during the call to fit"""
- if self.computed:
- if self.bias:
- x = np.hstack((x, np.ones((x.shape[0], 1))))
- return x.dot(self.L) # aply the delta to the interpolation
- else:
- print("Warning, model not fitted yet, returning None")
- return None
-
-
-@deprecated("The class OTDA_mapping_kernel is deprecated in 0.3.1 and will be"
- " removed in 0.5 \nUse class MappingTransport instead.")
-class OTDA_mapping_kernel(OTDA_mapping_linear):
-
- """Class for optimal transport with joint nonlinear mapping
- estimation as in [8]"""
-
- def fit(self, xs, xt, mu=1, eta=1, bias=False, kerneltype='gaussian',
- sigma=1, **kwargs):
- """ Fit domain adaptation between samples is xs and xt """
- self.xs = xs
- self.xt = xt
- self.bias = bias
-
- self.ws = unif(xs.shape[0])
- self.wt = unif(xt.shape[0])
- self.kernel = kerneltype
- self.sigma = sigma
- self.kwargs = kwargs
-
- self.G, self.L = joint_OT_mapping_kernel(
- xs, xt, mu=mu, eta=eta, bias=bias, **kwargs)
- self.computed = True
-
- def predict(self, x):
- """ Out of sample mapping estimated during the call to fit"""
-
- if self.computed:
- K = kernel(
- x, self.xs, method=self.kernel, sigma=self.sigma,
- **self.kwargs)
- if self.bias:
- K = np.hstack((K, np.ones((x.shape[0], 1))))
- return K.dot(self.L)
- else:
- print("Warning, model not fitted yet, returning None")
- return None
-
-
def distribution_estimation_uniform(X):
"""estimates a uniform distribution from an array of samples X
diff --git a/ot/stochastic.py b/ot/stochastic.py
index 5e8206e..ec53015 100644
--- a/ot/stochastic.py
+++ b/ot/stochastic.py
@@ -435,8 +435,8 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
##############################################################################
-def batch_grad_dual_alpha(M, reg, alpha, beta, batch_size, batch_alpha,
- batch_beta):
+def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
+ batch_beta):
'''
Computes the partial gradient of F_\W_varepsilon
@@ -444,104 +444,31 @@ def batch_grad_dual_alpha(M, reg, alpha, beta, batch_size, batch_alpha,
..math:
\forall i in batch_alpha,
- grad_alpha_i = 1 * batch_size -
- sum_{j in batch_beta} exp((alpha_i + beta_j - M_{i,j})/reg)
+ grad_alpha_i = a_i * batch_size/len(beta) -
+ sum_{j in batch_beta} exp((alpha_i + beta_j - M_{i,j})/reg)
+ * a_i * b_j
- where :
- - M is the (ns,nt) metric cost matrix
- - alpha, beta are dual variables in R^ixR^J
- - reg is the regularization term
- - batch_alpha and batch_beta are list of index
-
- The algorithm used for solving the dual problem is the SGD algorithm
- as proposed in [19]_ [alg.1]
-
- Parameters
- ----------
-
- reg : float number,
- Regularization term > 0
- M : np.ndarray(ns, nt),
- cost matrix
- alpha : np.ndarray(ns,)
- dual variable
- beta : np.ndarray(nt,)
- dual variable
- batch_size : int number
- size of the batch
- batch_alpha : np.ndarray(bs,)
- batch of index of alpha
- batch_beta : np.ndarray(bs,)
- batch of index of beta
-
- Returns
- -------
-
- grad : np.ndarray(ns,)
- partial grad F in alpha
-
- Examples
- --------
-
- >>> n_source = 7
- >>> n_target = 4
- >>> reg = 1
- >>> numItermax = 20000
- >>> lr = 0.1
- >>> batch_size = 3
- >>> log = True
- >>> a = ot.utils.unif(n_source)
- >>> b = ot.utils.unif(n_target)
- >>> rng = np.random.RandomState(0)
- >>> X_source = rng.randn(n_source, 2)
- >>> Y_target = rng.randn(n_target, 2)
- >>> M = ot.dist(X_source, Y_target)
- >>> sgd_dual_pi, log = stochastic.solve_dual_entropic(a, b, M, reg,
- batch_size,
- numItermax, lr, log)
- >>> print(log['alpha'], log['beta'])
- >>> print(sgd_dual_pi)
-
- References
- ----------
-
- [Seguy et al., 2018] :
- International Conference on Learning Representation (2018),
- arXiv preprint arxiv:1711.02283.
- '''
-
- grad_alpha = np.zeros(batch_size)
- grad_alpha[:] = batch_size
- for j in batch_beta:
- grad_alpha -= np.exp((alpha[batch_alpha] + beta[j] -
- M[batch_alpha, j]) / reg)
- return grad_alpha
-
-
-def batch_grad_dual_beta(M, reg, alpha, beta, batch_size, batch_alpha,
- batch_beta):
- '''
- Computes the partial gradient of F_\W_varepsilon
-
- Compute the partial gradient of the dual problem:
-
- ..math:
- \forall j in batch_beta,
- grad_beta_j = 1 * batch_size -
+ \forall j in batch_beta,
+ grad_beta_j = b_j * batch_size/len(alpha) -
sum_{i in batch_alpha} exp((alpha_i + beta_j - M_{i,j})/reg)
-
+ * a_i * b_j
where :
- M is the (ns,nt) metric cost matrix
- alpha, beta are dual variables in R^ixR^J
- reg is the regularization term
- - batch_alpha and batch_beta are list of index
+ - batch_alpha and batch_beta are lists of indices
+ - a and b are source and target weights (sum to 1)
+
The algorithm used for solving the dual problem is the SGD algorithm
as proposed in [19]_ [alg.1]
Parameters
----------
-
+ a : np.ndarray(ns,),
+ source measure
+ b : np.ndarray(nt,),
+ target measure
M : np.ndarray(ns, nt),
cost matrix
reg : float number,
@@ -561,7 +488,7 @@ def batch_grad_dual_beta(M, reg, alpha, beta, batch_size, batch_alpha,
-------
grad : np.ndarray(ns,)
- partial grad F in beta
+ partial grad F
Examples
--------
@@ -591,19 +518,22 @@ def batch_grad_dual_beta(M, reg, alpha, beta, batch_size, batch_alpha,
[Seguy et al., 2018] :
International Conference on Learning Representation (2018),
arXiv preprint arxiv:1711.02283.
-
'''
- grad_beta = np.zeros(batch_size)
- grad_beta[:] = batch_size
- for i in batch_alpha:
- grad_beta -= np.exp((alpha[i] +
- beta[batch_beta] - M[i, batch_beta]) / reg)
- return grad_beta
+ G = - (np.exp((alpha[batch_alpha, None] + beta[None, batch_beta] -
+ M[batch_alpha, :][:, batch_beta]) / reg) *
+ a[batch_alpha, None] * b[None, batch_beta])
+ grad_beta = np.zeros(np.shape(M)[1])
+ grad_alpha = np.zeros(np.shape(M)[0])
+ grad_beta[batch_beta] = (b[batch_beta] * len(batch_alpha) / np.shape(M)[0] +
+ G.sum(0))
+ grad_alpha[batch_alpha] = (a[batch_alpha] * len(batch_beta) /
+ np.shape(M)[1] + G.sum(1))
+
+ return grad_alpha, grad_beta
-def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr,
- alternate=True):
+def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
'''
Compute the sgd algorithm to solve the regularized discrete measures
optimal transport dual problem
@@ -623,7 +553,10 @@ def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr,
Parameters
----------
-
+ a : np.ndarray(ns,),
+ source measure
+ b : np.ndarray(nt,),
+ target measure
M : np.ndarray(ns, nt),
cost matrix
reg : float number,
@@ -634,8 +567,6 @@ def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr,
number of iteration
lr : float number
learning rate
- alternate : bool, optional
- alternating algorithm
Returns
-------
@@ -662,8 +593,8 @@ def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr,
>>> Y_target = rng.randn(n_target, 2)
>>> M = ot.dist(X_source, Y_target)
>>> sgd_dual_pi, log = stochastic.solve_dual_entropic(a, b, M, reg,
- batch_size,
- numItermax, lr, log)
+ batch_size,
+ numItermax, lr, log)
>>> print(log['alpha'], log['beta'])
>>> print(sgd_dual_pi)
@@ -677,35 +608,17 @@ def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr,
n_source = np.shape(M)[0]
n_target = np.shape(M)[1]
- cur_alpha = np.random.randn(n_source)
- cur_beta = np.random.randn(n_target)
- if alternate:
- for cur_iter in range(numItermax):
- k = np.sqrt(cur_iter + 1)
- batch_alpha = np.random.choice(n_source, batch_size, replace=False)
- batch_beta = np.random.choice(n_target, batch_size, replace=False)
- grad_F_alpha = batch_grad_dual_alpha(M, reg, cur_alpha, cur_beta,
- batch_size, batch_alpha,
- batch_beta)
- cur_alpha[batch_alpha] += (lr / k) * grad_F_alpha
- grad_F_beta = batch_grad_dual_beta(M, reg, cur_alpha, cur_beta,
- batch_size, batch_alpha,
- batch_beta)
- cur_beta[batch_beta] += (lr / k) * grad_F_beta
-
- else:
- for cur_iter in range(numItermax):
- k = np.sqrt(cur_iter + 1)
- batch_alpha = np.random.choice(n_source, batch_size, replace=False)
- batch_beta = np.random.choice(n_target, batch_size, replace=False)
- grad_F_alpha = batch_grad_dual_alpha(M, reg, cur_alpha, cur_beta,
- batch_size, batch_alpha,
- batch_beta)
- grad_F_beta = batch_grad_dual_beta(M, reg, cur_alpha, cur_beta,
- batch_size, batch_alpha,
- batch_beta)
- cur_alpha[batch_alpha] += (lr / k) * grad_F_alpha
- cur_beta[batch_beta] += (lr / k) * grad_F_beta
+ cur_alpha = np.zeros(n_source)
+ cur_beta = np.zeros(n_target)
+ for cur_iter in range(numItermax):
+ k = np.sqrt(cur_iter + 1)
+ batch_alpha = np.random.choice(n_source, batch_size, replace=False)
+ batch_beta = np.random.choice(n_target, batch_size, replace=False)
+ update_alpha, update_beta = batch_grad_dual(a, b, M, reg, cur_alpha,
+ cur_beta, batch_size,
+ batch_alpha, batch_beta)
+ cur_alpha[batch_alpha] += (lr / k) * update_alpha[batch_alpha]
+ cur_beta[batch_beta] += (lr / k) * update_beta[batch_beta]
return cur_alpha, cur_beta
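A design note on the refactoring above: batch_grad_dual merges the former batch_grad_dual_alpha / batch_grad_dual_beta pair, weights the exponential term by the marginals a_i b_j (which the old functions ignored), and returns full-size gradient arrays that are zero outside the sampled batches. The SGD loop then indexes them with batch_alpha and batch_beta, so only the sampled dual variables move at each iteration, and the dual variables now start from zero instead of a random initialization.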
@@ -787,7 +700,7 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
arXiv preprint arxiv:1711.02283.
'''
- opt_alpha, opt_beta = sgd_entropic_regularization(M, reg, batch_size,
+ opt_alpha, opt_beta = sgd_entropic_regularization(a, b, M, reg, batch_size,
numItermax, lr)
pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg) *
a[:, None] * b[None, :])
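For completeness, a minimal end-to-end sketch of the updated dual solver on synthetic data (parameter values are illustrative and mirror the updated tests):

    import numpy as np
    import ot

    n = 10
    rng = np.random.RandomState(0)
    x = rng.randn(n, 2)
    a = b = ot.utils.unif(n)      # uniform marginals
    M = ot.dist(x, x)

    # solve_dual_entropic now threads a and b through the stochastic gradient
    pi = ot.stochastic.solve_dual_entropic(a, b, M, reg=1, batch_size=n,
                                           numItermax=15000, lr=0.1)
    # row and column sums of the plan should be close to a and b
    print(np.abs(pi.sum(axis=1) - a).max())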
diff --git a/test/test_bregman.py b/test/test_bregman.py
index c8e9179..01ec655 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -105,6 +105,30 @@ def test_bary():
ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
+def test_wassersteinbary():
+
+ size = 100 # size of a square image
+ a1 = np.random.randn(size, size)
+ a1 -= a1.min()  # shift to non-negative values before normalizing
+ a1 = a1 / np.sum(a1)
+ a2 = np.random.randn(size, size)
+ a2 -= a2.min()
+ a2 = a2 / np.sum(a2)
+ # creating matrix A containing all distributions
+ A = np.zeros((2, size, size))
+ A[0, :, :] = a1
+ A[1, :, :] = a2
+
+ # wasserstein
+ reg = 1e-3
+ bary_wass = ot.bregman.convolutional_barycenter2d(A, reg)
+
+ np.testing.assert_allclose(1, np.sum(bary_wass))
+
+ # check that the log and verbose options do not break the function
+ ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
+
+
def test_unmix():
n_bins = 50 # nb bins
diff --git a/test/test_da.py b/test/test_da.py
index 97e23da..f7f3a9d 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -484,66 +484,3 @@ def test_linear_mapping_class():
Cst = np.cov(Xst.T)
np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
-
-
-def test_otda():
-
- n_samples = 150 # nb samples
- np.random.seed(0)
-
- xs, ys = ot.datasets.make_data_classif('3gauss', n_samples)
- xt, yt = ot.datasets.make_data_classif('3gauss2', n_samples)
-
- a, b = ot.unif(n_samples), ot.unif(n_samples)
-
- # LP problem
- da_emd = ot.da.OTDA() # init class
- da_emd.fit(xs, xt) # fit distributions
- da_emd.interp() # interpolation of source samples
- da_emd.predict(xs) # interpolation of source samples
-
- np.testing.assert_allclose(a, np.sum(da_emd.G, 1))
- np.testing.assert_allclose(b, np.sum(da_emd.G, 0))
-
- # sinkhorn regularization
- lambd = 1e-1
- da_entrop = ot.da.OTDA_sinkhorn()
- da_entrop.fit(xs, xt, reg=lambd)
- da_entrop.interp()
- da_entrop.predict(xs)
-
- np.testing.assert_allclose(
- a, np.sum(da_entrop.G, 1), rtol=1e-3, atol=1e-3)
- np.testing.assert_allclose(b, np.sum(da_entrop.G, 0), rtol=1e-3, atol=1e-3)
-
- # non-convex Group lasso regularization
- reg = 1e-1
- eta = 1e0
- da_lpl1 = ot.da.OTDA_lpl1()
- da_lpl1.fit(xs, ys, xt, reg=reg, eta=eta)
- da_lpl1.interp()
- da_lpl1.predict(xs)
-
- np.testing.assert_allclose(a, np.sum(da_lpl1.G, 1), rtol=1e-3, atol=1e-3)
- np.testing.assert_allclose(b, np.sum(da_lpl1.G, 0), rtol=1e-3, atol=1e-3)
-
- # True Group lasso regularization
- reg = 1e-1
- eta = 2e0
- da_l1l2 = ot.da.OTDA_l1l2()
- da_l1l2.fit(xs, ys, xt, reg=reg, eta=eta, numItermax=20, verbose=True)
- da_l1l2.interp()
- da_l1l2.predict(xs)
-
- np.testing.assert_allclose(a, np.sum(da_l1l2.G, 1), rtol=1e-3, atol=1e-3)
- np.testing.assert_allclose(b, np.sum(da_l1l2.G, 0), rtol=1e-3, atol=1e-3)
-
- # linear mapping
- da_emd = ot.da.OTDA_mapping_linear() # init class
- da_emd.fit(xs, xt, numItermax=10) # fit distributions
- da_emd.predict(xs) # interpolation of source samples
-
- # nonlinear mapping
- da_emd = ot.da.OTDA_mapping_kernel() # init class
- da_emd.fit(xs, xt, numItermax=10) # fit distributions
- da_emd.predict(xs) # interpolation of source samples
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
index f315c88..f0f3fc8 100644
--- a/test/test_stochastic.py
+++ b/test/test_stochastic.py
@@ -32,7 +32,7 @@ def test_stochastic_sag():
# test sag
n = 15
reg = 1
- numItermax = 300000
+ numItermax = 30000
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -62,7 +62,7 @@ def test_stochastic_asgd():
# test asgd
n = 15
reg = 1
- numItermax = 300000
+ numItermax = 100000
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -92,12 +92,11 @@ def test_sag_asgd_sinkhorn():
# test all algorithms
n = 15
reg = 1
- nb_iter = 300000
+ nb_iter = 100000
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
u = ot.utils.unif(n)
- zero = np.zeros(n)
M = ot.dist(x, x)
G_asgd = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
@@ -108,13 +107,13 @@ def test_sag_asgd_sinkhorn():
# check constratints
np.testing.assert_allclose(
- zero, (G_sag - G_sinkhorn).sum(1), atol=1e-03) # cf convergence sag
+ G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-03)
np.testing.assert_allclose(
- zero, (G_sag - G_sinkhorn).sum(0), atol=1e-03) # cf convergence sag
+ G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-03)
np.testing.assert_allclose(
- zero, (G_asgd - G_sinkhorn).sum(1), atol=1e-03) # cf convergence asgd
+ G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
np.testing.assert_allclose(
- zero, (G_asgd - G_sinkhorn).sum(0), atol=1e-03) # cf convergence asgd
+ G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
np.testing.assert_allclose(
G_sag, G_sinkhorn, atol=1e-03) # cf convergence sag
np.testing.assert_allclose(
@@ -137,8 +136,8 @@ def test_stochastic_dual_sgd():
# test sgd
n = 10
reg = 1
- numItermax = 300000
- batch_size = 8
+ numItermax = 15000
+ batch_size = 10
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -151,9 +150,9 @@ def test_stochastic_dual_sgd():
# check constratints
np.testing.assert_allclose(
- u, G.sum(1), atol=1e-02) # cf convergence sgd
+ u, G.sum(1), atol=1e-03) # cf convergence sgd
np.testing.assert_allclose(
- u, G.sum(0), atol=1e-02) # cf convergence sgd
+ u, G.sum(0), atol=1e-03) # cf convergence sgd
#############################################################################
@@ -168,13 +167,13 @@ def test_dual_sgd_sinkhorn():
# test all dual algorithms
n = 10
reg = 1
- nb_iter = 300000
- batch_size = 8
+ nb_iter = 15000
+ batch_size = 10
rng = np.random.RandomState(0)
+ # test uniform marginals
x = rng.randn(n, 2)
u = ot.utils.unif(n)
- zero = np.zeros(n)
M = ot.dist(x, x)
G_sgd = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size,
@@ -184,8 +183,33 @@ def test_dual_sgd_sinkhorn():
# check constratints
np.testing.assert_allclose(
- zero, (G_sgd - G_sinkhorn).sum(1), atol=1e-02) # cf convergence sgd
+ G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
np.testing.assert_allclose(
- zero, (G_sgd - G_sinkhorn).sum(0), atol=1e-02) # cf convergence sgd
+ G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
np.testing.assert_allclose(
- G_sgd, G_sinkhorn, atol=1e-02) # cf convergence sgd
+ G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd
+
+ # test Gaussian marginals
+ n = 30
+ reg = 1
+ batch_size = 30
+
+ a = ot.datasets.make_1D_gauss(n, 15, 5) # m= mean, s= std
+ b = ot.datasets.make_1D_gauss(n, 15, 5)
+ X_source = np.arange(n, dtype=np.float64)
+ Y_target = np.arange(n, dtype=np.float64)
+ M = ot.dist(X_source.reshape((n, 1)), Y_target.reshape((n, 1)))
+ M /= M.max()
+
+ G_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size,
+ numItermax=nb_iter)
+
+ G_sinkhorn = ot.sinkhorn(a, b, M, reg)
+
+ # check constraints
+ np.testing.assert_allclose(
+ G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
+ np.testing.assert_allclose(
+ G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
+ np.testing.assert_allclose(
+ G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd