From c0c959da8e62d57587ed36e8ba359ca095c5b423 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Tue, 24 Jul 2018 13:55:55 +0200 Subject: speedup einsum constraint violation --- ot/bregman.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index c8e69ce..26b7b53 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -350,7 +350,6 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, np.exp(K, out=K) # print(np.min(K)) - tmp = np.empty(K.shape, dtype=M.dtype) tmp2 = np.empty(b.shape, dtype=M.dtype) Kp = (1 / a).reshape(-1, 1) * K @@ -379,11 +378,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, err = np.sum((u - uprev)**2) / np.sum((u)**2) + \ np.sum((v - vprev)**2) / np.sum((v)**2) else: - np.multiply(u.reshape(-1, 1), K, out=tmp) - np.multiply(tmp, v.reshape(1, -1), out=tmp) - np.sum(tmp, axis=0, out=tmp2) - tmp2 -= b - err = np.linalg.norm(tmp2)**2 + # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 + np.einsum('i,ij,j->j',u,K,v,out=tmp2) + err = np.linalg.norm(tmp2-b)**2 # violation of marginal if log: log['err'].append(err) -- cgit v1.2.3 From 66816cba7cd666706d054bddded1da6035e78c2a Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Tue, 24 Jul 2018 14:13:23 +0200 Subject: pep8 all the way --- ot/bregman.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index 26b7b53..ab84bcf 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -378,9 +378,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, err = np.sum((u - uprev)**2) / np.sum((u)**2) + \ np.sum((v - vprev)**2) / np.sum((v)**2) else: - # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 - np.einsum('i,ij,j->j',u,K,v,out=tmp2) - err = np.linalg.norm(tmp2-b)**2 # violation of marginal + # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 + np.einsum('i,ij,j->j', u, K, v, out=tmp2) + err = np.linalg.norm(tmp2 - b)**2 # violation of marginal if log: log['err'].append(err) -- cgit v1.2.3 From bbe411775b3d5abb5d6fb525262cccce3f73d345 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Tue, 24 Jul 2018 14:31:45 +0200 Subject: test eisum instead of dot --- ot/bregman.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index ab84bcf..d2ade46 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -358,9 +358,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, while (err > stopThr and cpt < numItermax): uprev = u vprev = v - KtransposeU = np.dot(K.T, u) + KtransposeU = np.einsum('ij,i->j',K,u)#np.dot(K.T, u) v = np.divide(b, KtransposeU) - u = 1. / np.dot(Kp, v) + u = 1. / np.einsum('ij,j->i',Kp,v)#np.dot(Kp, v) if (np.any(KtransposeU == 0) or np.any(np.isnan(u)) or np.any(np.isnan(v)) or -- cgit v1.2.3 From a04112c69a62182c061d4b65e71ebb43c866d3e1 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Tue, 24 Jul 2018 14:33:32 +0200 Subject: correction size --- ot/bregman.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index d2ade46..29ca9fd 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -358,9 +358,14 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, while (err > stopThr and cpt < numItermax): uprev = u vprev = v - KtransposeU = np.einsum('ij,i->j',K,u)#np.dot(K.T, u) - v = np.divide(b, KtransposeU) - u = 1. / np.einsum('ij,j->i',Kp,v)#np.dot(Kp, v) + if nbb: + KtransposeU = np.einsum('ij,i,k->jk',K,u)#np.dot(K.T, u) + v = np.divide(b, KtransposeU) + u = 1. / np.einsum('ij,jk->ik',Kp,v)#np.dot(Kp, v) + else: + KtransposeU = np.einsum('ij,i->j',K,u)#np.dot(K.T, u) + v = np.divide(b, KtransposeU) + u = 1. / np.einsum('ij,j->i',Kp,v)#np.dot(Kp, v) if (np.any(KtransposeU == 0) or np.any(np.isnan(u)) or np.any(np.isnan(v)) or -- cgit v1.2.3 From 603c0eee29db890b0092ea8c848473bf413e186f Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Tue, 24 Jul 2018 14:34:17 +0200 Subject: pb index --- ot/bregman.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index 29ca9fd..57cedb2 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -359,7 +359,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, uprev = u vprev = v if nbb: - KtransposeU = np.einsum('ij,i,k->jk',K,u)#np.dot(K.T, u) + KtransposeU = np.einsum('ij,ik->jk',K,u)#np.dot(K.T, u) v = np.divide(b, KtransposeU) u = 1. / np.einsum('ij,jk->ik',Kp,v)#np.dot(Kp, v) else: -- cgit v1.2.3 From 5e3392a029e675c7e19f8b1723fcfdb9aa9142aa Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Tue, 24 Jul 2018 14:35:58 +0200 Subject: cancel einsum --- ot/bregman.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index 57cedb2..1873c46 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -358,14 +358,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, while (err > stopThr and cpt < numItermax): uprev = u vprev = v - if nbb: - KtransposeU = np.einsum('ij,ik->jk',K,u)#np.dot(K.T, u) - v = np.divide(b, KtransposeU) - u = 1. / np.einsum('ij,jk->ik',Kp,v)#np.dot(Kp, v) - else: - KtransposeU = np.einsum('ij,i->j',K,u)#np.dot(K.T, u) - v = np.divide(b, KtransposeU) - u = 1. / np.einsum('ij,j->i',Kp,v)#np.dot(Kp, v) + + KtransposeU = np.dot(K.T, u) + v = np.divide(b, KtransposeU) + u = 1. / np.dot(Kp, v) + if (np.any(KtransposeU == 0) or np.any(np.isnan(u)) or np.any(np.isnan(v)) or -- cgit v1.2.3 From ace77962d2ae6407916ee7e4377f5c7ed0a8d8f2 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Tue, 24 Jul 2018 14:40:42 +0200 Subject: final makefile bench --- ot/bregman.py | 1 - 1 file changed, 1 deletion(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index 1873c46..58e74de 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -362,7 +362,6 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, KtransposeU = np.dot(K.T, u) v = np.divide(b, KtransposeU) u = 1. / np.dot(Kp, v) - if (np.any(KtransposeU == 0) or np.any(np.isnan(u)) or np.any(np.isnan(v)) or -- cgit v1.2.3 From f4bfeb73da098384aa67599e7f729fb683a1bcc9 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Tue, 24 Jul 2018 15:54:56 +0200 Subject: ensum tets marginals sinkhorn --- ot/bregman.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index 58e74de..c755f51 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -396,10 +396,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, log['v'] = v if nbb: # return only loss - res = np.zeros((nbb)) - for i in range(nbb): - res[i] = np.sum( - u[:, i].reshape((-1, 1)) * K * v[:, i].reshape((1, -1)) * M) + res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) if log: return res, log else: -- cgit v1.2.3 From d99abf078537acf6cf49480b9790a9c450889031 Mon Sep 17 00:00:00 2001 From: Nicolas Courty Date: Fri, 7 Sep 2018 11:58:42 +0200 Subject: Wasserstein convolutional barycenter --- README.md | 4 +- data/duck.png | Bin 0 -> 5112 bytes data/heart.png | Bin 0 -> 5225 bytes data/redcross.png | Bin 0 -> 1683 bytes data/tooth.png | Bin 0 -> 4931 bytes examples/plot_convolutional_barycenter.py | 92 ++++++++++++++++++++++++++ ot/bregman.py | 106 ++++++++++++++++++++++++++++++ 7 files changed, 201 insertions(+), 1 deletion(-) create mode 100644 data/duck.png create mode 100644 data/heart.png create mode 100644 data/redcross.png create mode 100644 data/tooth.png create mode 100644 examples/plot_convolutional_barycenter.py (limited to 'ot/bregman.py') diff --git a/README.md b/README.md index dded582..1105362 100644 --- a/README.md +++ b/README.md @@ -227,4 +227,6 @@ You can also post bug reports and feature requests in Github issues. Make sure t [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representation (2018) -[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning \ No newline at end of file +[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning + +[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66. \ No newline at end of file diff --git a/data/duck.png b/data/duck.png new file mode 100644 index 0000000..9181697 Binary files /dev/null and b/data/duck.png differ diff --git a/data/heart.png b/data/heart.png new file mode 100644 index 0000000..44a6385 Binary files /dev/null and b/data/heart.png differ diff --git a/data/redcross.png b/data/redcross.png new file mode 100644 index 0000000..8d0a6fa Binary files /dev/null and b/data/redcross.png differ diff --git a/data/tooth.png b/data/tooth.png new file mode 100644 index 0000000..cd92c9d Binary files /dev/null and b/data/tooth.png differ diff --git a/examples/plot_convolutional_barycenter.py b/examples/plot_convolutional_barycenter.py new file mode 100644 index 0000000..d231da9 --- /dev/null +++ b/examples/plot_convolutional_barycenter.py @@ -0,0 +1,92 @@ + +#%% +# -*- coding: utf-8 -*- +""" +============================================ +Convolutional Wasserstein Barycenter example +============================================ + +This example is designed to illustrate how the Convolutional Wasserstein Barycenter +function of POT works. +""" + +# Author: Nicolas Courty +# +# License: MIT License + + +import numpy as np +import pylab as pl +import ot + +############################################################################## +# Data preparation +# ---------------- +# +# The four distributions are constructed from 4 simple images + + +f1 = 1 - pl.imread('../data/redcross.png')[:, :, 2] +f2 = 1 - pl.imread('../data/duck.png')[:, :, 2] +f3 = 1 - pl.imread('../data/heart.png')[:, :, 2] +f4 = 1 - pl.imread('../data/tooth.png')[:, :, 2] + +A = [] +f1=f1/np.sum(f1) +f2=f2/np.sum(f2) +f3=f3/np.sum(f3) +f4=f4/np.sum(f4) +A.append(f1) +A.append(f2) +A.append(f3) +A.append(f4) +A=np.array(A) + +nb_images = 5 + +# those are the four corners coordinates that will be interpolated by bilinear +# interpolation +v1=np.array((1,0,0,0)) +v2=np.array((0,1,0,0)) +v3=np.array((0,0,1,0)) +v4=np.array((0,0,0,1)) + + +############################################################################## +# Barycenter computation and visualization +# ---------------------------------------- +# + +pl.figure(figsize=(10,10)) +pl.title('Convolutional Wasserstein Barycenters in POT') +cm='Blues' +# regularization parameter +reg=0.004 +for i in range(nb_images): + for j in range(nb_images): + pl.subplot(nb_images,nb_images,i*nb_images+j+1) + tx=float(i)/(nb_images-1) + ty=float(j)/(nb_images-1) + + # weights are constructed by bilinear interpolation + tmp1=(1-tx)*v1+tx*v2 + tmp2=(1-tx)*v3+tx*v4 + weights=(1-ty)*tmp1+ty*tmp2 + + if i==0 and j==0: + pl.imshow(f1,cmap=cm) + pl.axis('off') + elif i==0 and j==(nb_images-1): + pl.imshow(f3,cmap=cm) + pl.axis('off') + elif i==(nb_images-1) and j==0: + pl.imshow(f2,cmap=cm) + pl.axis('off') + elif i==(nb_images-1) and j==(nb_images-1): + pl.imshow(f4,cmap=cm) + pl.axis('off') + else: + # call to barycenter computation + pl.imshow(ot.convolutional_barycenter2d(A,reg,weights),cmap=cm) + pl.axis('off') +pl.show() \ No newline at end of file diff --git a/ot/bregman.py b/ot/bregman.py index c755f51..05f4d9d 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -918,6 +918,112 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, else: return geometricBar(weights, UKv) +def convolutional_barycenter2d(A,reg,weights=None,numItermax = 10000, stopThr=1e-9, verbose=False, log=False): + """Compute the entropic regularized wasserstein barycenter of distributions A + where A is a collection of 2D images. + + The function solves the following optimization problem: + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) + - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}` + - reg is the regularization strength scalar value + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [21]_ + + Parameters + ---------- + A : np.ndarray (n,w,h) + n distributions (2D images) of size w x h + reg : float + Regularization term >0 + weights : np.ndarray (n,) + Weights of each image on the simplex (barycentric coodinates) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + a : (w,h) ndarray + 2D Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). + Convolutional wasserstein distances: Efficient optimal transportation on geometric domains + ACM Transactions on Graphics (TOG), 34(4), 66 + + + """ + + if weights is None: + weights = np.ones(A.shape[0]) / A.shape[0] + else: + assert(len(weights) == A.shape[0]) + + if log: + log = {'err': []} + + b=np.zeros_like(A[0,:,:]) + U=np.ones_like(A) + KV=np.ones_like(A) + threshold = 1e-30 # in order to avoids numerical precision issues + + cpt = 0 + err=1 + + # build the convolution operator + t = np.linspace(0,1,A.shape[1]) + [Y,X] = np.meshgrid(t,t) + xi1 = np.exp(-(X-Y)**2/reg) + K = lambda x: np.dot(np.dot(xi1,x),xi1) + + while (err>stopThr and cpt Date: Fri, 7 Sep 2018 12:04:44 +0200 Subject: pep8 normalization --- examples/plot_convolutional_barycenter.py | 70 +++++++++++++++---------------- ot/bregman.py | 68 +++++++++++++++--------------- 2 files changed, 70 insertions(+), 68 deletions(-) (limited to 'ot/bregman.py') diff --git a/examples/plot_convolutional_barycenter.py b/examples/plot_convolutional_barycenter.py index d231da9..7ccdbe3 100644 --- a/examples/plot_convolutional_barycenter.py +++ b/examples/plot_convolutional_barycenter.py @@ -1,4 +1,4 @@ - + #%% # -*- coding: utf-8 -*- """ @@ -32,24 +32,24 @@ f3 = 1 - pl.imread('../data/heart.png')[:, :, 2] f4 = 1 - pl.imread('../data/tooth.png')[:, :, 2] A = [] -f1=f1/np.sum(f1) -f2=f2/np.sum(f2) -f3=f3/np.sum(f3) -f4=f4/np.sum(f4) +f1 = f1 / np.sum(f1) +f2 = f2 / np.sum(f2) +f3 = f3 / np.sum(f3) +f4 = f4 / np.sum(f4) A.append(f1) A.append(f2) A.append(f3) A.append(f4) -A=np.array(A) +A = np.array(A) nb_images = 5 # those are the four corners coordinates that will be interpolated by bilinear # interpolation -v1=np.array((1,0,0,0)) -v2=np.array((0,1,0,0)) -v3=np.array((0,0,1,0)) -v4=np.array((0,0,0,1)) +v1 = np.array((1, 0, 0, 0)) +v2 = np.array((0, 1, 0, 0)) +v3 = np.array((0, 0, 1, 0)) +v4 = np.array((0, 0, 0, 1)) ############################################################################## @@ -57,36 +57,36 @@ v4=np.array((0,0,0,1)) # ---------------------------------------- # -pl.figure(figsize=(10,10)) +pl.figure(figsize=(10, 10)) pl.title('Convolutional Wasserstein Barycenters in POT') -cm='Blues' +cm = 'Blues' # regularization parameter -reg=0.004 +reg = 0.004 for i in range(nb_images): for j in range(nb_images): - pl.subplot(nb_images,nb_images,i*nb_images+j+1) - tx=float(i)/(nb_images-1) - ty=float(j)/(nb_images-1) - + pl.subplot(nb_images, nb_images, i * nb_images + j + 1) + tx = float(i) / (nb_images - 1) + ty = float(j) / (nb_images - 1) + # weights are constructed by bilinear interpolation - tmp1=(1-tx)*v1+tx*v2 - tmp2=(1-tx)*v3+tx*v4 - weights=(1-ty)*tmp1+ty*tmp2 - - if i==0 and j==0: - pl.imshow(f1,cmap=cm) - pl.axis('off') - elif i==0 and j==(nb_images-1): - pl.imshow(f3,cmap=cm) - pl.axis('off') - elif i==(nb_images-1) and j==0: - pl.imshow(f2,cmap=cm) - pl.axis('off') - elif i==(nb_images-1) and j==(nb_images-1): - pl.imshow(f4,cmap=cm) - pl.axis('off') + tmp1 = (1 - tx) * v1 + tx * v2 + tmp2 = (1 - tx) * v3 + tx * v4 + weights = (1 - ty) * tmp1 + ty * tmp2 + + if i == 0 and j == 0: + pl.imshow(f1, cmap=cm) + pl.axis('off') + elif i == 0 and j == (nb_images - 1): + pl.imshow(f3, cmap=cm) + pl.axis('off') + elif i == (nb_images - 1) and j == 0: + pl.imshow(f2, cmap=cm) + pl.axis('off') + elif i == (nb_images - 1) and j == (nb_images - 1): + pl.imshow(f4, cmap=cm) + pl.axis('off') else: # call to barycenter computation - pl.imshow(ot.convolutional_barycenter2d(A,reg,weights),cmap=cm) + pl.imshow(ot.convolutional_barycenter2d(A, reg, weights), cmap=cm) pl.axis('off') -pl.show() \ No newline at end of file +pl.show() diff --git a/ot/bregman.py b/ot/bregman.py index 05f4d9d..f844f03 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -918,7 +918,8 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, else: return geometricBar(weights, UKv) -def convolutional_barycenter2d(A,reg,weights=None,numItermax = 10000, stopThr=1e-9, verbose=False, log=False): + +def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, verbose=False, log=False): """Compute the entropic regularized wasserstein barycenter of distributions A where A is a collection of 2D images. @@ -979,51 +980,52 @@ def convolutional_barycenter2d(A,reg,weights=None,numItermax = 10000, stopThr=1e if log: log = {'err': []} - b=np.zeros_like(A[0,:,:]) - U=np.ones_like(A) - KV=np.ones_like(A) - threshold = 1e-30 # in order to avoids numerical precision issues + b = np.zeros_like(A[0, :, :]) + U = np.ones_like(A) + KV = np.ones_like(A) + threshold = 1e-30 # in order to avoids numerical precision issues cpt = 0 - err=1 - - # build the convolution operator - t = np.linspace(0,1,A.shape[1]) - [Y,X] = np.meshgrid(t,t) - xi1 = np.exp(-(X-Y)**2/reg) - K = lambda x: np.dot(np.dot(xi1,x),xi1) - - while (err>stopThr and cpt stopThr and cpt < numItermax): + + bold = b + cpt = cpt + 1 + + b = np.zeros_like(A[0, :, :]) for r in range(A.shape[0]): - KV[r,:,:]=K(A[r,:,:]/np.maximum(threshold,K(U[r,:,:]))) - b += weights[r] * np.log(np.maximum(threshold, U[r,:,:]*KV[r,:,:])) + KV[r, :, :] = K(A[r, :, :] / np.maximum(threshold, K(U[r, :, :]))) + b += weights[r] * np.log(np.maximum(threshold, U[r, :, :] * KV[r, :, :])) b = np.exp(b) for r in range(A.shape[0]): - U[r,:,:]=b/np.maximum(threshold,KV[r,:,:]) - - if cpt%10==1: - err=np.sum(np.abs(bold-b)) + U[r, :, :] = b / np.maximum(threshold, KV[r, :, :]) + + if cpt % 10 == 1: + err = np.sum(np.abs(bold - b)) # log and verbose print if log: log['err'].append(err) if verbose: - if cpt%200 ==0: - print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19) - print('{:5d}|{:8e}|'.format(cpt,err)) + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) if log: - log['niter']=cpt - log['U']=U - return b,log + log['niter'] = cpt + log['U'] = U + return b, log else: - return b - + return b + def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, stopThr=1e-3, verbose=False, log=False): -- cgit v1.2.3 From e8c6d2fc9c6b08bbed11628326711ab29c155bac Mon Sep 17 00:00:00 2001 From: Nicolas Courty Date: Fri, 7 Sep 2018 12:34:14 +0200 Subject: pep8 fixed (contd) --- ot/bregman.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index f844f03..5327dbc 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -920,8 +920,8 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, verbose=False, log=False): - """Compute the entropic regularized wasserstein barycenter of distributions A - where A is a collection of 2D images. + """Compute the entropic regularized wasserstein barycenter of distributions A + where A is a collection of 2D images. The function solves the following optimization problem: @@ -966,8 +966,8 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1 ---------- .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). - Convolutional wasserstein distances: Efficient optimal transportation on geometric domains - ACM Transactions on Graphics (TOG), 34(4), 66 + Convolutional wasserstein distances: Efficient optimal transportation on geometric domains + ACM Transactions on Graphics (TOG), 34(4), 66 """ @@ -993,7 +993,8 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1 [Y, X] = np.meshgrid(t, t) xi1 = np.exp(-(X - Y)**2 / reg) - def K(x): return np.dot(np.dot(xi1, x), xi1) + def K(x): + return np.dot(np.dot(xi1, x), xi1) while (err > stopThr and cpt < numItermax): -- cgit v1.2.3 From d19295b9cb29d21e09eeb28ac4b0e61990727023 Mon Sep 17 00:00:00 2001 From: Nicolas Courty Date: Fri, 7 Sep 2018 14:41:00 +0200 Subject: stabThr and pep8 --- ot/bregman.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index 5327dbc..748ac30 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -919,7 +919,7 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, return geometricBar(weights, UKv) -def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, verbose=False, log=False): +def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False): """Compute the entropic regularized wasserstein barycenter of distributions A where A is a collection of 2D images. @@ -948,6 +948,8 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1 Max number of iterations stopThr : float, optional Stop threshol on error (>0) + stabThr : float, optional + Stabilization threshold to avoid numerical precision issue verbose : bool, optional Print information along iterations log : bool, optional @@ -983,7 +985,6 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1 b = np.zeros_like(A[0, :, :]) U = np.ones_like(A) KV = np.ones_like(A) - threshold = 1e-30 # in order to avoids numerical precision issues cpt = 0 err = 1 @@ -993,7 +994,7 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1 [Y, X] = np.meshgrid(t, t) xi1 = np.exp(-(X - Y)**2 / reg) - def K(x): + def K(x): return np.dot(np.dot(xi1, x), xi1) while (err > stopThr and cpt < numItermax): @@ -1003,11 +1004,11 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1 b = np.zeros_like(A[0, :, :]) for r in range(A.shape[0]): - KV[r, :, :] = K(A[r, :, :] / np.maximum(threshold, K(U[r, :, :]))) - b += weights[r] * np.log(np.maximum(threshold, U[r, :, :] * KV[r, :, :])) + KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :]))) + b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :])) b = np.exp(b) for r in range(A.shape[0]): - U[r, :, :] = b / np.maximum(threshold, KV[r, :, :]) + U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :]) if cpt % 10 == 1: err = np.sum(np.abs(bold - b)) -- cgit v1.2.3 From dab572396be97fcf5439e4e20f887165b1ade62c Mon Sep 17 00:00:00 2001 From: Nicolas Courty Date: Fri, 7 Sep 2018 14:49:20 +0200 Subject: whitetrail pep8 --- ot/bregman.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index 748ac30..35e51f8 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -920,8 +920,8 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False): - """Compute the entropic regularized wasserstein barycenter of distributions A - where A is a collection of 2D images. + """Compute the entropic regularized wasserstein barycenter of distributions A + where A is a collection of 2D images. The function solves the following optimization problem: @@ -949,7 +949,7 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1 stopThr : float, optional Stop threshol on error (>0) stabThr : float, optional - Stabilization threshold to avoid numerical precision issue + Stabilization threshold to avoid numerical precision issue verbose : bool, optional Print information along iterations log : bool, optional @@ -967,9 +967,9 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1 References ---------- - .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). - Convolutional wasserstein distances: Efficient optimal transportation on geometric domains - ACM Transactions on Graphics (TOG), 34(4), 66 + .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). + Convolutional wasserstein distances: Efficient optimal transportation on geometric domains + ACM Transactions on Graphics (TOG), 34(4), 66 """ -- cgit v1.2.3