diff options
author | Laetitia Chapel <laetitia.chapel@univ-ubs.fr> | 2020-04-15 15:32:56 +0200 |
---|---|---|
committer | Laetitia Chapel <laetitia.chapel@univ-ubs.fr> | 2020-04-15 15:32:56 +0200 |
commit | 8c724ad3579959e9d369c0b7fbaa22ea19ced614 (patch) | |
tree | 528d280492f33f71eac2d6522e0e3c05a4ae8568 | |
parent | fff2463aafd58343c8bc2ed7875622e16a8c1cee (diff) |
partial with tests
-rwxr-xr-x | examples/plot_partial_wass_and_gromov.py | 18 | ||||
-rw-r--r-- | ot/__init__.py | 81 | ||||
-rwxr-xr-x | ot/partial.py | 23 | ||||
-rw-r--r-- | ot/unbalanced.py | 66 |
4 files changed, 39 insertions, 149 deletions
diff --git a/examples/plot_partial_wass_and_gromov.py b/examples/plot_partial_wass_and_gromov.py index 2ddeb68..30b3fc0 100755 --- a/examples/plot_partial_wass_and_gromov.py +++ b/examples/plot_partial_wass_and_gromov.py @@ -33,9 +33,9 @@ mu = np.array([0, 0]) cov = np.array([[1, 0], [0, 2]])
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
-xs = np.append(xs, (np.random.rand(n_noise, 2)+1)*4).reshape((-1, 2))
+xs = np.append(xs, (np.random.rand(n_noise, 2) + 1) * 4).reshape((-1, 2))
xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
-xt = np.append(xt, (np.random.rand(n_noise, 2)+1)*-3).reshape((-1, 2))
+xt = np.append(xt, (np.random.rand(n_noise, 2) + 1) * -3).reshape((-1, 2))
M = sp.spatial.distance.cdist(xs, xt)
@@ -62,7 +62,7 @@ w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=0.1, m=0.5, log=True)
print('Partial Wasserstein distance (m = 0.5): ' + str(log0['partial_w_dist']))
-print('Entropic partial Wasserstein distance (m = 0.5): ' + \
+print('Entropic partial Wasserstein distance (m = 0.5): ' +
str(log['partial_w_dist']))
pl.figure(1, (10, 5))
@@ -98,10 +98,10 @@ cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
-xs = np.concatenate((xs, ((np.random.rand(n_noise, 2)+1)*4)), axis=0)
+xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) + 1) * 4)), axis=0)
P = sp.linalg.sqrtm(cov_t)
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
-xt = np.concatenate((xt, ((np.random.rand(n_noise, 3)+1)*10)), axis=0)
+xt = np.concatenate((xt, ((np.random.rand(n_noise, 3) + 1) * 10)), axis=0)
fig = pl.figure()
ax1 = fig.add_subplot(121)
@@ -128,7 +128,7 @@ res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10, m=m, log=True)
print('Partial Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist']))
-print('Entropic partial Wasserstein distance (m = 1): ' + \
+print('Entropic partial Wasserstein distance (m = 1): ' +
str(log['partial_gw_dist']))
pl.figure(1, (10, 5))
@@ -142,14 +142,14 @@ pl.title('Entropic partial Wasserstein') pl.show()
print('-----m = 2/3')
-m = 2/3
+m = 2 / 3
res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
m=m, log=True)
-print('Partial Wasserstein distance (m = 2/3): ' + \
+print('Partial Wasserstein distance (m = 2/3): ' +
str(log0['partial_gw_dist']))
-print('Entropic partial Wasserstein distance (m = 2/3): ' + \
+print('Entropic partial Wasserstein distance (m = 2/3): ' +
str(log['partial_gw_dist']))
pl.figure(1, (10, 5))
diff --git a/ot/__init__.py b/ot/__init__.py deleted file mode 100644 index 89c7936..0000000 --- a/ot/__init__.py +++ /dev/null @@ -1,81 +0,0 @@ -""" - -This is the main module of the POT toolbox. It provides easy access to -a number of sub-modules and functions described below. - -.. note:: - - - Here is a list of the submodules and short description of what they contain. - - - :any:`ot.lp` contains OT solvers for the exact (Linear Program) OT problems. - - :any:`ot.bregman` contains OT solvers for the entropic OT problems using - Bregman projections. - - :any:`ot.lp` contains OT solvers for the exact (Linear Program) OT problems. - - :any:`ot.smooth` contains OT solvers for the regularized (l2 and kl) smooth OT - problems. - - :any:`ot.gromov` contains solvers for Gromov-Wasserstein and Fused Gromov - Wasserstein problems. - - :any:`ot.optim` contains generic solvers OT based optimization problems - - :any:`ot.da` contains classes and function related to Monge mapping - estimation and Domain Adaptation (DA). - - :any:`ot.gpu` contains GPU (cupy) implementation of some OT solvers - - :any:`ot.dr` contains Dimension Reduction (DR) methods such as Wasserstein - Discriminant Analysis. - - :any:`ot.utils` contains utility functions such as distance computation and - timing. - - :any:`ot.datasets` contains toy dataset generation functions. - - :any:`ot.plot` contains visualization functions - - :any:`ot.stochastic` contains stochastic solvers for regularized OT. - - :any:`ot.unbalanced` contains solvers for regularized unbalanced OT. - -.. warning:: - The list of automatically imported sub-modules is as follows: - :py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim` - :py:mod:`ot.utils`, :py:mod:`ot.datasets`, - :py:mod:`ot.gromov`, :py:mod:`ot.smooth` - :py:mod:`ot.stochastic` - - The following sub-modules are not imported due to additional dependencies: - - - :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`. - - :any:`ot.gpu` : depends on :code:`cupy` and a CUDA GPU. - - :any:`ot.plot` : depends on :code:`matplotlib` - -""" - -# Author: Remi Flamary <remi.flamary@unice.fr> -# Nicolas Courty <ncourty@irisa.fr> -# -# License: MIT License - - -# All submodules and packages -from . import lp -from . import bregman -from . import optim -from . import utils -from . import datasets -from . import da -from . import gromov -from . import smooth -from . import stochastic -from . import unbalanced - -# OT functions -from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d -from .bregman import sinkhorn, sinkhorn2, barycenter -from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2 -from .da import sinkhorn_lpl1_mm - -# utils functions -from .utils import dist, unif, tic, toc, toq - -__version__ = "0.6.0" - -__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', - 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', - 'emd_1d', 'emd2_1d', 'wasserstein_1d', - 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', - 'sinkhorn_unbalanced', 'barycenter_unbalanced', - 'sinkhorn_unbalanced2'] diff --git a/ot/partial.py b/ot/partial.py index 746f337..3425acb 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -9,12 +9,11 @@ Partial OT import numpy as np -from ot.lp import emd +from .lp import emd def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, **kwargs): - r""" Solves the partial optimal transport problem for the quadratic cost and returns the OT plan @@ -136,7 +135,7 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, if log_emd['warning'] is not None: raise ValueError("Error in the EMD resolution: try to increase the" " number of dummy points") - log_emd['cost'] = np.sum(gamma*M) + log_emd['cost'] = np.sum(gamma * M) if log: return gamma, log_emd else: @@ -233,7 +232,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies) a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies) - M_extended = np.ones((len(a_extended), len(b_extended))) * np.max(M) * 1e2 + M_extended = np.ones((len(a_extended), len(b_extended))) * 0 M_extended[-1, -1] = np.max(M) * 1e5 M_extended[:len(a), :len(b)] = M @@ -381,7 +380,7 @@ def gwloss_partial(C1, C2, T): Returns ------- - GW loss + GW loss """ g = gwgrad_partial(C1, C2, T) * 0.5 return np.sum(g * T) @@ -432,7 +431,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, G0 : ndarray, shape (ns, nt), optional Initialisation of the transportation matrix thres : float, optional - quantile of the gradient matrix to populate the cost matrix when 0 + quantile of the gradient matrix to populate the cost matrix when 0 (default: 1) numItermax : int, optional Max number of iterations @@ -566,7 +565,7 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, where : - M is the metric cost matrix - - :math:`\Omega` is the entropic regularization term + - :math:`\Omega` is the entropic regularization term :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are the sample weights - m is the amount of mass to be transported @@ -591,7 +590,7 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, G0 : ndarray, shape (ns, nt), optional Initialisation of the transportation matrix thres : float, optional - quantile of the gradient matrix to populate the cost matrix when 0 + quantile of the gradient matrix to populate the cost matrix when 0 (default: 1) numItermax : int, optional Max number of iterations @@ -666,7 +665,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, where : - M is the metric cost matrix - - :math:`\Omega` is the entropic regularization term + - :math:`\Omega` is the entropic regularization term :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are the sample weights - m is the amount of mass to be transported @@ -754,7 +753,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, K = np.empty(M.shape, dtype=M.dtype) np.divide(M, -reg, out=K) np.exp(K, out=K) - np.multiply(K, m/np.sum(K), out=K) + np.multiply(K, m / np.sum(K), out=K) err, cpt = 1, 0 @@ -809,7 +808,7 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, - C2 is the metric cost matrix in the target space - p and q are the sample weights - L : quadratic loss function - - :math:`\Omega` is the entropic regularization term + - :math:`\Omega` is the entropic regularization term :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - m is the amount of mass to be transported @@ -944,7 +943,7 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, - C2 is the metric cost matrix in the target space - p and q are the sample weights - L : quadratic loss function - - :math:`\Omega` is the entropic regularization term + - :math:`\Omega` is the entropic regularization term :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - m is the amount of mass to be transported diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 66a8830..23f6607 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -14,7 +14,7 @@ from scipy.special import logsumexp # from .utils import unif, dist -def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', div = "TV", numItermax=1000, +def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the unbalanced entropic regularization optimal transport problem @@ -120,20 +120,20 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', div = "TV", numI """ if method.lower() == 'sinkhorn': - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div, + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div, + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg, + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) @@ -261,8 +261,8 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', else: raise ValueError('Unknown method %s.' % method) -# TODO: update the doc -def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000, + +def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -349,7 +349,6 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000, """ a = np.asarray(a, dtype=np.float64) - print(a) b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64) @@ -377,39 +376,24 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000, else: u = np.ones(dim_a) / dim_a v = np.ones(dim_b) / dim_b - u = np.ones(dim_a) - v = np.ones(dim_b) # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute K = np.empty(M.shape, dtype=M.dtype) - np.true_divide(M, -reg, out=K) + np.divide(M, -reg, out=K) np.exp(K, out=K) - - if div == "KL": - fi = reg_m / (reg_m + reg) - elif div == "TV": - fi = reg_m / reg + + fi = reg_m / (reg_m + reg) err = 1. - - dx = np.ones(dim_a) / dim_a - dy = np.ones(dim_b) / dim_b - z = 1 for i in range(numItermax): uprev = u vprev = v - Kv = z*K.dot(v*dy) - u = scaling_iter_prox(Kv, a, fi, div) - #u = (a / Kv) ** fi - Ktu = z*K.T.dot(u*dx) - v = scaling_iter_prox(Ktu, b, fi, div) - #v = (b / Ktu) ** fi - #print(v*dy) - z = np.dot((u*dx).T, np.dot(K,v*dy))/0.35 - print(z) - + Kv = K.dot(v) + u = (a / Kv) ** fi + Ktu = K.T.dot(u) + v = (b / Ktu) ** fi if (np.any(Ktu == 0.) or np.any(np.isnan(u)) or np.any(np.isnan(v)) @@ -450,12 +434,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000, if log: return u[:, None] * K * v[None, :], log else: - return z*u[:, None] * K * v[None, :] + return u[:, None] * K * v[None, :] + -# TODO: update the doc -def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div = "KL", tau=1e5, - numItermax=1000, stopThr=1e-6, - verbose=False, log=False, **kwargs): +def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000, + stopThr=1e-6, verbose=False, log=False, + **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -580,10 +564,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div = "KL", tau=1e5, np.divide(M, -reg, out=K) np.exp(K, out=K) - if div == "KL": - fi = reg_m / (reg_m + reg) - elif div == "TV": - fi = reg_m / reg + fi = reg_m / (reg_m + reg) cpt = 0 err = 1. @@ -669,15 +650,6 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div = "KL", tau=1e5, else: return ot_matrix -def scaling_iter_prox(s, p, fi, div): - if div == "KL": - return (p / s) ** fi - elif div == "TV": - return np.minimum(s*np.exp(fi), np.maximum(s*np.exp(-fi), p)) / s - else: - raise ValueError("Unknown divergence '%s'." % div) - - def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, numItermax=1000, stopThr=1e-6, |