From 8c724ad3579959e9d369c0b7fbaa22ea19ced614 Mon Sep 17 00:00:00 2001 From: Laetitia Chapel Date: Wed, 15 Apr 2020 15:32:56 +0200 Subject: partial with tests --- examples/plot_partial_wass_and_gromov.py | 18 +++---- ot/__init__.py | 81 -------------------------------- ot/partial.py | 23 +++++---- ot/unbalanced.py | 66 ++++++++------------------ 4 files changed, 39 insertions(+), 149 deletions(-) delete mode 100644 ot/__init__.py diff --git a/examples/plot_partial_wass_and_gromov.py b/examples/plot_partial_wass_and_gromov.py index 2ddeb68..30b3fc0 100755 --- a/examples/plot_partial_wass_and_gromov.py +++ b/examples/plot_partial_wass_and_gromov.py @@ -33,9 +33,9 @@ mu = np.array([0, 0]) cov = np.array([[1, 0], [0, 2]]) xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov) -xs = np.append(xs, (np.random.rand(n_noise, 2)+1)*4).reshape((-1, 2)) +xs = np.append(xs, (np.random.rand(n_noise, 2) + 1) * 4).reshape((-1, 2)) xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov) -xt = np.append(xt, (np.random.rand(n_noise, 2)+1)*-3).reshape((-1, 2)) +xt = np.append(xt, (np.random.rand(n_noise, 2) + 1) * -3).reshape((-1, 2)) M = sp.spatial.distance.cdist(xs, xt) @@ -62,7 +62,7 @@ w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=0.1, m=0.5, log=True) print('Partial Wasserstein distance (m = 0.5): ' + str(log0['partial_w_dist'])) -print('Entropic partial Wasserstein distance (m = 0.5): ' + \ +print('Entropic partial Wasserstein distance (m = 0.5): ' + str(log['partial_w_dist'])) pl.figure(1, (10, 5)) @@ -98,10 +98,10 @@ cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s) -xs = np.concatenate((xs, ((np.random.rand(n_noise, 2)+1)*4)), axis=0) +xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) + 1) * 4)), axis=0) P = sp.linalg.sqrtm(cov_t) xt = np.random.randn(n_samples, 3).dot(P) + mu_t -xt = np.concatenate((xt, ((np.random.rand(n_noise, 3)+1)*10)), axis=0) +xt = np.concatenate((xt, ((np.random.rand(n_noise, 3) + 1) * 10)), axis=0) fig = pl.figure() ax1 = fig.add_subplot(121) @@ -128,7 +128,7 @@ res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10, m=m, log=True) print('Partial Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist'])) -print('Entropic partial Wasserstein distance (m = 1): ' + \ +print('Entropic partial Wasserstein distance (m = 1): ' + str(log['partial_gw_dist'])) pl.figure(1, (10, 5)) @@ -142,14 +142,14 @@ pl.title('Entropic partial Wasserstein') pl.show() print('-----m = 2/3') -m = 2/3 +m = 2 / 3 res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True) res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10, m=m, log=True) -print('Partial Wasserstein distance (m = 2/3): ' + \ +print('Partial Wasserstein distance (m = 2/3): ' + str(log0['partial_gw_dist'])) -print('Entropic partial Wasserstein distance (m = 2/3): ' + \ +print('Entropic partial Wasserstein distance (m = 2/3): ' + str(log['partial_gw_dist'])) pl.figure(1, (10, 5)) diff --git a/ot/__init__.py b/ot/__init__.py deleted file mode 100644 index 89c7936..0000000 --- a/ot/__init__.py +++ /dev/null @@ -1,81 +0,0 @@ -""" - -This is the main module of the POT toolbox. It provides easy access to -a number of sub-modules and functions described below. - -.. note:: - - - Here is a list of the submodules and short description of what they contain. - - - :any:`ot.lp` contains OT solvers for the exact (Linear Program) OT problems. 
- - :any:`ot.bregman` contains OT solvers for the entropic OT problems using - Bregman projections. - - :any:`ot.lp` contains OT solvers for the exact (Linear Program) OT problems. - - :any:`ot.smooth` contains OT solvers for the regularized (l2 and kl) smooth OT - problems. - - :any:`ot.gromov` contains solvers for Gromov-Wasserstein and Fused Gromov - Wasserstein problems. - - :any:`ot.optim` contains generic solvers OT based optimization problems - - :any:`ot.da` contains classes and function related to Monge mapping - estimation and Domain Adaptation (DA). - - :any:`ot.gpu` contains GPU (cupy) implementation of some OT solvers - - :any:`ot.dr` contains Dimension Reduction (DR) methods such as Wasserstein - Discriminant Analysis. - - :any:`ot.utils` contains utility functions such as distance computation and - timing. - - :any:`ot.datasets` contains toy dataset generation functions. - - :any:`ot.plot` contains visualization functions - - :any:`ot.stochastic` contains stochastic solvers for regularized OT. - - :any:`ot.unbalanced` contains solvers for regularized unbalanced OT. - -.. warning:: - The list of automatically imported sub-modules is as follows: - :py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim` - :py:mod:`ot.utils`, :py:mod:`ot.datasets`, - :py:mod:`ot.gromov`, :py:mod:`ot.smooth` - :py:mod:`ot.stochastic` - - The following sub-modules are not imported due to additional dependencies: - - - :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`. - - :any:`ot.gpu` : depends on :code:`cupy` and a CUDA GPU. - - :any:`ot.plot` : depends on :code:`matplotlib` - -""" - -# Author: Remi Flamary -# Nicolas Courty -# -# License: MIT License - - -# All submodules and packages -from . import lp -from . import bregman -from . import optim -from . import utils -from . import datasets -from . import da -from . import gromov -from . import smooth -from . import stochastic -from . 
import unbalanced - -# OT functions -from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d -from .bregman import sinkhorn, sinkhorn2, barycenter -from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2 -from .da import sinkhorn_lpl1_mm - -# utils functions -from .utils import dist, unif, tic, toc, toq - -__version__ = "0.6.0" - -__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', - 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', - 'emd_1d', 'emd2_1d', 'wasserstein_1d', - 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', - 'sinkhorn_unbalanced', 'barycenter_unbalanced', - 'sinkhorn_unbalanced2'] diff --git a/ot/partial.py b/ot/partial.py index 746f337..3425acb 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -9,12 +9,11 @@ Partial OT import numpy as np -from ot.lp import emd +from .lp import emd def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, **kwargs): - r""" Solves the partial optimal transport problem for the quadratic cost and returns the OT plan @@ -136,7 +135,7 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, if log_emd['warning'] is not None: raise ValueError("Error in the EMD resolution: try to increase the" " number of dummy points") - log_emd['cost'] = np.sum(gamma*M) + log_emd['cost'] = np.sum(gamma * M) if log: return gamma, log_emd else: @@ -233,7 +232,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies) a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies) - M_extended = np.ones((len(a_extended), len(b_extended))) * np.max(M) * 1e2 + M_extended = np.ones((len(a_extended), len(b_extended))) * 0 M_extended[-1, -1] = np.max(M) * 1e5 M_extended[:len(a), :len(b)] = M @@ -381,7 +380,7 @@ def gwloss_partial(C1, C2, T): Returns ------- - GW loss + GW loss """ g = gwgrad_partial(C1, C2, T) * 0.5 return np.sum(g * T) @@ -432,7 +431,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, G0 : ndarray, shape (ns, nt), optional Initialisation of the transportation matrix thres : float, optional - quantile of the gradient matrix to populate the cost matrix when 0 + quantile of the gradient matrix to populate the cost matrix when 0 (default: 1) numItermax : int, optional Max number of iterations @@ -566,7 +565,7 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, where : - M is the metric cost matrix - - :math:`\Omega` is the entropic regularization term + - :math:`\Omega` is the entropic regularization term :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are the sample weights - m is the amount of mass to be transported @@ -591,7 +590,7 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, G0 : ndarray, shape (ns, nt), optional Initialisation of the transportation matrix thres : float, optional - quantile of the gradient matrix to populate the cost matrix when 0 + quantile of the gradient matrix to populate the cost matrix when 0 (default: 1) numItermax : int, optional Max number of iterations @@ -666,7 +665,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, where : - M is the metric cost matrix - - :math:`\Omega` is the entropic regularization term + - :math:`\Omega` is the entropic regularization term :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are the sample weights - m is the amount 
of mass to be transported @@ -754,7 +753,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, K = np.empty(M.shape, dtype=M.dtype) np.divide(M, -reg, out=K) np.exp(K, out=K) - np.multiply(K, m/np.sum(K), out=K) + np.multiply(K, m / np.sum(K), out=K) err, cpt = 1, 0 @@ -809,7 +808,7 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, - C2 is the metric cost matrix in the target space - p and q are the sample weights - L : quadratic loss function - - :math:`\Omega` is the entropic regularization term + - :math:`\Omega` is the entropic regularization term :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - m is the amount of mass to be transported @@ -944,7 +943,7 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, - C2 is the metric cost matrix in the target space - p and q are the sample weights - L : quadratic loss function - - :math:`\Omega` is the entropic regularization term + - :math:`\Omega` is the entropic regularization term :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - m is the amount of mass to be transported diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 66a8830..23f6607 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -14,7 +14,7 @@ from scipy.special import logsumexp # from .utils import unif, dist -def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', div = "TV", numItermax=1000, +def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the unbalanced entropic regularization optimal transport problem @@ -120,20 +120,20 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', div = "TV", numI """ if method.lower() == 'sinkhorn': - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div, + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div, + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg, + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) @@ -261,8 +261,8 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', else: raise ValueError('Unknown method %s.' 
% method) -# TODO: update the doc -def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000, + +def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -349,7 +349,6 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000, """ a = np.asarray(a, dtype=np.float64) - print(a) b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64) @@ -377,39 +376,24 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000, else: u = np.ones(dim_a) / dim_a v = np.ones(dim_b) / dim_b - u = np.ones(dim_a) - v = np.ones(dim_b) # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute K = np.empty(M.shape, dtype=M.dtype) - np.true_divide(M, -reg, out=K) + np.divide(M, -reg, out=K) np.exp(K, out=K) - - if div == "KL": - fi = reg_m / (reg_m + reg) - elif div == "TV": - fi = reg_m / reg + + fi = reg_m / (reg_m + reg) err = 1. - - dx = np.ones(dim_a) / dim_a - dy = np.ones(dim_b) / dim_b - z = 1 for i in range(numItermax): uprev = u vprev = v - Kv = z*K.dot(v*dy) - u = scaling_iter_prox(Kv, a, fi, div) - #u = (a / Kv) ** fi - Ktu = z*K.T.dot(u*dx) - v = scaling_iter_prox(Ktu, b, fi, div) - #v = (b / Ktu) ** fi - #print(v*dy) - z = np.dot((u*dx).T, np.dot(K,v*dy))/0.35 - print(z) - + Kv = K.dot(v) + u = (a / Kv) ** fi + Ktu = K.T.dot(u) + v = (b / Ktu) ** fi if (np.any(Ktu == 0.) or np.any(np.isnan(u)) or np.any(np.isnan(v)) @@ -450,12 +434,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000, if log: return u[:, None] * K * v[None, :], log else: - return z*u[:, None] * K * v[None, :] + return u[:, None] * K * v[None, :] + -# TODO: update the doc -def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div = "KL", tau=1e5, - numItermax=1000, stopThr=1e-6, - verbose=False, log=False, **kwargs): +def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000, + stopThr=1e-6, verbose=False, log=False, + **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -580,10 +564,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div = "KL", tau=1e5, np.divide(M, -reg, out=K) np.exp(K, out=K) - if div == "KL": - fi = reg_m / (reg_m + reg) - elif div == "TV": - fi = reg_m / reg + fi = reg_m / (reg_m + reg) cpt = 0 err = 1. @@ -669,15 +650,6 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div = "KL", tau=1e5, else: return ot_matrix -def scaling_iter_prox(s, p, fi, div): - if div == "KL": - return (p / s) ** fi - elif div == "TV": - return np.minimum(s*np.exp(fi), np.maximum(s*np.exp(-fi), p)) / s - else: - raise ValueError("Unknown divergence '%s'." % div) - - def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, numItermax=1000, stopThr=1e-6, -- cgit v1.2.3
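
Usage note (not part of the patch): a minimal sketch of how the solvers added in ot/partial.py are called, mirroring the example script touched above. The point-cloud setup here is arbitrary; only the partial_wasserstein / entropic_partial_wasserstein calls and the 'partial_w_dist' log key are taken from the patch.

import numpy as np
import ot

n = 20
xs = np.random.randn(n, 2)             # source samples
xt = np.random.randn(n, 2) + 3         # target samples, shifted away
M = ot.dist(xs, xt)                    # pairwise squared Euclidean cost matrix
p, q = ot.unif(n), ot.unif(n)          # uniform sample weights

# transport only half of the total mass
w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=0.5, log=True)
w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=0.1, m=0.5,
                                                 log=True)
print('exact    :', log0['partial_w_dist'])
print('entropic :', log['partial_w_dist'])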
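The change to M_extended in ot/partial.py (dummy costs set to 0 instead of 1e2 * max(M)) is the core of the reduction from partial OT to a standard EMD problem. The sketch below restates that reduction with an illustrative helper (partial_wasserstein_sketch is a hypothetical name, not the library function); it follows the lines visible in the hunk, assuming the real rows/columns of the extended plan are what gets returned.

import numpy as np
from ot.lp import emd

def partial_wasserstein_sketch(a, b, M, m, nb_dummies=1):
    # surplus mass of each marginal is absorbed by dummy points at zero cost
    b_ext = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
    a_ext = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
    M_ext = np.zeros((len(a_ext), len(b_ext)))
    M_ext[-1, -1] = np.max(M) * 1e5        # discourage dummy-to-dummy transport
    M_ext[:len(a), :len(b)] = M
    gamma = emd(a_ext, b_ext, M_ext)       # balanced EMD on the extended problem
    return gamma[:len(a), :len(b)]         # keep only the real part of the plan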
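The hunk in entropic_partial_wasserstein only shows the kernel initialisation, K = exp(-M/reg) rescaled so that its total mass equals m. The iteration itself is not visible in this diff; the sketch below is the scaling scheme of Benamou et al. (2015) for the partial constraints (row sums <= a, column sums <= b, total mass = m), which is assumed here to be what the function implements.

import numpy as np

def entropic_partial_sketch(a, b, M, reg, m, numItermax=1000):
    K = np.exp(-M / reg)
    K = K * m / np.sum(K)                                     # total mass = m
    for _ in range(numItermax):
        K = K * np.minimum(a / K.sum(axis=1), 1.0)[:, None]   # row sums <= a
        K = K * np.minimum(b / K.sum(axis=0), 1.0)[None, :]   # col sums <= b
        K = K * (m / np.sum(K))                               # renormalise mass
    return K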
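In ot/unbalanced.py the patch removes the experimental 'div' argument and the rescaling variable z, restoring the plain KL scaling iteration. Stripped of logging, warnings and convergence checks, the loop that remains is equivalent to the sketch below (the function name is illustrative, not the library API):

import numpy as np

def sinkhorn_knopp_unbalanced_sketch(a, b, M, reg, reg_m, numItermax=1000):
    K = np.exp(-M / reg)                 # Gibbs kernel
    fi = reg_m / (reg_m + reg)           # exponent from the KL marginal relaxation
    u = np.ones(len(a)) / len(a)
    v = np.ones(len(b)) / len(b)
    for _ in range(numItermax):
        u = (a / K.dot(v)) ** fi         # left scaling update
        v = (b / K.T.dot(u)) ** fi       # right scaling update
    return u[:, None] * K * v[None, :]   # unbalanced transport plan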