summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaetitia Chapel <laetitia.chapel@univ-ubs.fr>2020-04-15 15:32:56 +0200
committerLaetitia Chapel <laetitia.chapel@univ-ubs.fr>2020-04-15 15:32:56 +0200
commit8c724ad3579959e9d369c0b7fbaa22ea19ced614 (patch)
tree528d280492f33f71eac2d6522e0e3c05a4ae8568
parentfff2463aafd58343c8bc2ed7875622e16a8c1cee (diff)
partial with tests
-rwxr-xr-xexamples/plot_partial_wass_and_gromov.py18
-rw-r--r--ot/__init__.py81
-rwxr-xr-xot/partial.py23
-rw-r--r--ot/unbalanced.py66
4 files changed, 39 insertions, 149 deletions
diff --git a/examples/plot_partial_wass_and_gromov.py b/examples/plot_partial_wass_and_gromov.py
index 2ddeb68..30b3fc0 100755
--- a/examples/plot_partial_wass_and_gromov.py
+++ b/examples/plot_partial_wass_and_gromov.py
@@ -33,9 +33,9 @@ mu = np.array([0, 0])
cov = np.array([[1, 0], [0, 2]])
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
-xs = np.append(xs, (np.random.rand(n_noise, 2)+1)*4).reshape((-1, 2))
+xs = np.append(xs, (np.random.rand(n_noise, 2) + 1) * 4).reshape((-1, 2))
xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
-xt = np.append(xt, (np.random.rand(n_noise, 2)+1)*-3).reshape((-1, 2))
+xt = np.append(xt, (np.random.rand(n_noise, 2) + 1) * -3).reshape((-1, 2))
M = sp.spatial.distance.cdist(xs, xt)
@@ -62,7 +62,7 @@ w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=0.1, m=0.5,
log=True)
print('Partial Wasserstein distance (m = 0.5): ' + str(log0['partial_w_dist']))
-print('Entropic partial Wasserstein distance (m = 0.5): ' + \
+print('Entropic partial Wasserstein distance (m = 0.5): ' +
str(log['partial_w_dist']))
pl.figure(1, (10, 5))
@@ -98,10 +98,10 @@ cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
-xs = np.concatenate((xs, ((np.random.rand(n_noise, 2)+1)*4)), axis=0)
+xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) + 1) * 4)), axis=0)
P = sp.linalg.sqrtm(cov_t)
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
-xt = np.concatenate((xt, ((np.random.rand(n_noise, 3)+1)*10)), axis=0)
+xt = np.concatenate((xt, ((np.random.rand(n_noise, 3) + 1) * 10)), axis=0)
fig = pl.figure()
ax1 = fig.add_subplot(121)
@@ -128,7 +128,7 @@ res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
m=m, log=True)
print('Partial Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist']))
-print('Entropic partial Wasserstein distance (m = 1): ' + \
+print('Entropic partial Wasserstein distance (m = 1): ' +
str(log['partial_gw_dist']))
pl.figure(1, (10, 5))
@@ -142,14 +142,14 @@ pl.title('Entropic partial Wasserstein')
pl.show()
print('-----m = 2/3')
-m = 2/3
+m = 2 / 3
res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
m=m, log=True)
-print('Partial Wasserstein distance (m = 2/3): ' + \
+print('Partial Wasserstein distance (m = 2/3): ' +
str(log0['partial_gw_dist']))
-print('Entropic partial Wasserstein distance (m = 2/3): ' + \
+print('Entropic partial Wasserstein distance (m = 2/3): ' +
str(log['partial_gw_dist']))
pl.figure(1, (10, 5))
diff --git a/ot/__init__.py b/ot/__init__.py
deleted file mode 100644
index 89c7936..0000000
--- a/ot/__init__.py
+++ /dev/null
@@ -1,81 +0,0 @@
-"""
-
-This is the main module of the POT toolbox. It provides easy access to
-a number of sub-modules and functions described below.
-
-.. note::
-
-
- Here is a list of the submodules and short description of what they contain.
-
- - :any:`ot.lp` contains OT solvers for the exact (Linear Program) OT problems.
- - :any:`ot.bregman` contains OT solvers for the entropic OT problems using
- Bregman projections.
- - :any:`ot.lp` contains OT solvers for the exact (Linear Program) OT problems.
- - :any:`ot.smooth` contains OT solvers for the regularized (l2 and kl) smooth OT
- problems.
- - :any:`ot.gromov` contains solvers for Gromov-Wasserstein and Fused Gromov
- Wasserstein problems.
- - :any:`ot.optim` contains generic solvers OT based optimization problems
- - :any:`ot.da` contains classes and function related to Monge mapping
- estimation and Domain Adaptation (DA).
- - :any:`ot.gpu` contains GPU (cupy) implementation of some OT solvers
- - :any:`ot.dr` contains Dimension Reduction (DR) methods such as Wasserstein
- Discriminant Analysis.
- - :any:`ot.utils` contains utility functions such as distance computation and
- timing.
- - :any:`ot.datasets` contains toy dataset generation functions.
- - :any:`ot.plot` contains visualization functions
- - :any:`ot.stochastic` contains stochastic solvers for regularized OT.
- - :any:`ot.unbalanced` contains solvers for regularized unbalanced OT.
-
-.. warning::
- The list of automatically imported sub-modules is as follows:
- :py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim`
- :py:mod:`ot.utils`, :py:mod:`ot.datasets`,
- :py:mod:`ot.gromov`, :py:mod:`ot.smooth`
- :py:mod:`ot.stochastic`
-
- The following sub-modules are not imported due to additional dependencies:
-
- - :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`.
- - :any:`ot.gpu` : depends on :code:`cupy` and a CUDA GPU.
- - :any:`ot.plot` : depends on :code:`matplotlib`
-
-"""
-
-# Author: Remi Flamary <remi.flamary@unice.fr>
-# Nicolas Courty <ncourty@irisa.fr>
-#
-# License: MIT License
-
-
-# All submodules and packages
-from . import lp
-from . import bregman
-from . import optim
-from . import utils
-from . import datasets
-from . import da
-from . import gromov
-from . import smooth
-from . import stochastic
-from . import unbalanced
-
-# OT functions
-from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
-from .bregman import sinkhorn, sinkhorn2, barycenter
-from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2
-from .da import sinkhorn_lpl1_mm
-
-# utils functions
-from .utils import dist, unif, tic, toc, toq
-
-__version__ = "0.6.0"
-
-__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets',
- 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
- 'emd_1d', 'emd2_1d', 'wasserstein_1d',
- 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
- 'sinkhorn_unbalanced', 'barycenter_unbalanced',
- 'sinkhorn_unbalanced2']
diff --git a/ot/partial.py b/ot/partial.py
index 746f337..3425acb 100755
--- a/ot/partial.py
+++ b/ot/partial.py
@@ -9,12 +9,11 @@ Partial OT
import numpy as np
-from ot.lp import emd
+from .lp import emd
def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
**kwargs):
-
r"""
Solves the partial optimal transport problem for the quadratic cost
and returns the OT plan
@@ -136,7 +135,7 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
if log_emd['warning'] is not None:
raise ValueError("Error in the EMD resolution: try to increase the"
" number of dummy points")
- log_emd['cost'] = np.sum(gamma*M)
+ log_emd['cost'] = np.sum(gamma * M)
if log:
return gamma, log_emd
else:
@@ -233,7 +232,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
- M_extended = np.ones((len(a_extended), len(b_extended))) * np.max(M) * 1e2
+ M_extended = np.ones((len(a_extended), len(b_extended))) * 0
M_extended[-1, -1] = np.max(M) * 1e5
M_extended[:len(a), :len(b)] = M
@@ -381,7 +380,7 @@ def gwloss_partial(C1, C2, T):
Returns
-------
- GW loss
+ GW loss
"""
g = gwgrad_partial(C1, C2, T) * 0.5
return np.sum(g * T)
@@ -432,7 +431,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
G0 : ndarray, shape (ns, nt), optional
Initialisation of the transportation matrix
thres : float, optional
- quantile of the gradient matrix to populate the cost matrix when 0
+ quantile of the gradient matrix to populate the cost matrix when 0
(default: 1)
numItermax : int, optional
Max number of iterations
@@ -566,7 +565,7 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
where :
- M is the metric cost matrix
- - :math:`\Omega` is the entropic regularization term
+ - :math:`\Omega` is the entropic regularization term
:math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are the sample weights
- m is the amount of mass to be transported
@@ -591,7 +590,7 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
G0 : ndarray, shape (ns, nt), optional
Initialisation of the transportation matrix
thres : float, optional
- quantile of the gradient matrix to populate the cost matrix when 0
+ quantile of the gradient matrix to populate the cost matrix when 0
(default: 1)
numItermax : int, optional
Max number of iterations
@@ -666,7 +665,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
where :
- M is the metric cost matrix
- - :math:`\Omega` is the entropic regularization term
+ - :math:`\Omega` is the entropic regularization term
:math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are the sample weights
- m is the amount of mass to be transported
@@ -754,7 +753,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
K = np.empty(M.shape, dtype=M.dtype)
np.divide(M, -reg, out=K)
np.exp(K, out=K)
- np.multiply(K, m/np.sum(K), out=K)
+ np.multiply(K, m / np.sum(K), out=K)
err, cpt = 1, 0
@@ -809,7 +808,7 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
- C2 is the metric cost matrix in the target space
- p and q are the sample weights
- L : quadratic loss function
- - :math:`\Omega` is the entropic regularization term
+ - :math:`\Omega` is the entropic regularization term
:math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- m is the amount of mass to be transported
@@ -944,7 +943,7 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None,
- C2 is the metric cost matrix in the target space
- p and q are the sample weights
- L : quadratic loss function
- - :math:`\Omega` is the entropic regularization term
+ - :math:`\Omega` is the entropic regularization term
:math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- m is the amount of mass to be transported
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 66a8830..23f6607 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -14,7 +14,7 @@ from scipy.special import logsumexp
# from .utils import unif, dist
-def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', div = "TV", numItermax=1000,
+def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
stopThr=1e-6, verbose=False, log=False, **kwargs):
r"""
Solve the unbalanced entropic regularization optimal transport problem
@@ -120,20 +120,20 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', div = "TV", numI
"""
if method.lower() == 'sinkhorn':
- return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div,
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
log=log, **kwargs)
elif method.lower() == 'sinkhorn_stabilized':
- return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div,
+ return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
numItermax=numItermax,
stopThr=stopThr,
verbose=verbose,
log=log, **kwargs)
elif method.lower() in ['sinkhorn_reg_scaling']:
warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
- return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg,
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
log=log, **kwargs)
@@ -261,8 +261,8 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
else:
raise ValueError('Unknown method %s.' % method)
-# TODO: update the doc
-def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000,
+
+def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
stopThr=1e-6, verbose=False, log=False, **kwargs):
r"""
Solve the entropic regularization unbalanced optimal transport problem and return the loss
@@ -349,7 +349,6 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000,
"""
a = np.asarray(a, dtype=np.float64)
- print(a)
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
@@ -377,39 +376,24 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000,
else:
u = np.ones(dim_a) / dim_a
v = np.ones(dim_b) / dim_b
- u = np.ones(dim_a)
- v = np.ones(dim_b)
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
K = np.empty(M.shape, dtype=M.dtype)
- np.true_divide(M, -reg, out=K)
+ np.divide(M, -reg, out=K)
np.exp(K, out=K)
-
- if div == "KL":
- fi = reg_m / (reg_m + reg)
- elif div == "TV":
- fi = reg_m / reg
+
+ fi = reg_m / (reg_m + reg)
err = 1.
-
- dx = np.ones(dim_a) / dim_a
- dy = np.ones(dim_b) / dim_b
- z = 1
for i in range(numItermax):
uprev = u
vprev = v
- Kv = z*K.dot(v*dy)
- u = scaling_iter_prox(Kv, a, fi, div)
- #u = (a / Kv) ** fi
- Ktu = z*K.T.dot(u*dx)
- v = scaling_iter_prox(Ktu, b, fi, div)
- #v = (b / Ktu) ** fi
- #print(v*dy)
- z = np.dot((u*dx).T, np.dot(K,v*dy))/0.35
- print(z)
-
+ Kv = K.dot(v)
+ u = (a / Kv) ** fi
+ Ktu = K.T.dot(u)
+ v = (b / Ktu) ** fi
if (np.any(Ktu == 0.)
or np.any(np.isnan(u)) or np.any(np.isnan(v))
@@ -450,12 +434,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000,
if log:
return u[:, None] * K * v[None, :], log
else:
- return z*u[:, None] * K * v[None, :]
+ return u[:, None] * K * v[None, :]
+
-# TODO: update the doc
-def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div = "KL", tau=1e5,
- numItermax=1000, stopThr=1e-6,
- verbose=False, log=False, **kwargs):
+def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000,
+ stopThr=1e-6, verbose=False, log=False,
+ **kwargs):
r"""
Solve the entropic regularization unbalanced optimal transport
problem and return the loss
@@ -580,10 +564,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div = "KL", tau=1e5,
np.divide(M, -reg, out=K)
np.exp(K, out=K)
- if div == "KL":
- fi = reg_m / (reg_m + reg)
- elif div == "TV":
- fi = reg_m / reg
+ fi = reg_m / (reg_m + reg)
cpt = 0
err = 1.
@@ -669,15 +650,6 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div = "KL", tau=1e5,
else:
return ot_matrix
-def scaling_iter_prox(s, p, fi, div):
- if div == "KL":
- return (p / s) ** fi
- elif div == "TV":
- return np.minimum(s*np.exp(fi), np.maximum(s*np.exp(-fi), p)) / s
- else:
- raise ValueError("Unknown divergence '%s'." % div)
-
-
def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
numItermax=1000, stopThr=1e-6,