From 3fff90eb437dce30fd83012f4c0e24f3fca041b2 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Fri, 14 Jan 2022 17:47:27 +0100 Subject: [WIP] Set dev version and add minigallery to quick start guide (#334) * change version and add minigallery in quickstart guide * remove ot.gpu from documentation because it is deprecated and bacckends should be used * start 0.8.2dev and description in releases.md * typo for gallery sinkhorn2 * test better doc update for files in .githib folder --- ot/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'ot/__init__.py') diff --git a/ot/__init__.py b/ot/__init__.py index f55819d..1ea7403 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -50,7 +50,7 @@ from .gromov import (gromov_wasserstein, gromov_wasserstein2, # utils functions from .utils import dist, unif, tic, toc, toq -__version__ = "0.8.1.0" +__version__ = "0.8.2dev" __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', -- cgit v1.2.3 From a5e0f0d40d5046a6639924347ef97e2ac80ad0c9 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Wed, 2 Feb 2022 11:53:12 +0100 Subject: [MRG] Add weak OT solver (#341) * add info in release file * update tests * pep8 * add weak OT example * update plot in doc * correction ewample with empirical sinkhorn * better thumbnail * comment from review * update documenation --- README.md | 3 + RELEASES.md | 8 ++- docs/source/all.rst | 1 + examples/others/plot_WeakOT_VS_OT.py | 98 +++++++++++++++++++++++++++ examples/plot_OT_2D_samples.py | 5 +- ot/__init__.py | 5 +- ot/gromov.py | 16 +++++ ot/lp/__init__.py | 9 ++- ot/lp/cvx.py | 1 - ot/utils.py | 12 +++- ot/weak.py | 124 +++++++++++++++++++++++++++++++++++ test/test_bregman.py | 13 ++-- test/test_ot.py | 2 +- test/test_utils.py | 18 ++++- test/test_weak.py | 54 +++++++++++++++ 15 files changed, 343 insertions(+), 26 deletions(-) create mode 100644 examples/others/plot_WeakOT_VS_OT.py create mode 100644 ot/weak.py create mode 100644 test/test_weak.py (limited to 'ot/__init__.py') diff --git a/README.md b/README.md index 17fbe81..a7627df 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ POT provides the following generic OT solvers (links to examples): * Sinkhorn divergence [23] and entropic regularization OT from empirical data. * Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37] * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17]. +* Weak OT solver between empirical distributions [39] * Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale). * [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24] @@ -301,3 +302,5 @@ Conference on Machine Learning, PMLR 119:4692-4701, 2020 [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. 
Courty, [Online Graph Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021. + +[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405. \ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index 94c853b..4d05582 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,10 +5,12 @@ #### New features -- Better list of related examples in quick start guide with `minigallery` (PR #334) +- Better list of related examples in quick start guide with `minigallery` (PR #334). - Add optional log-domain Sinkhorn implementation in WDA to support smaller values - of the regularization parameter (PR #336) -- Backend implementation for `ot.lp.free_support_barycenter` (PR #340) + of the regularization parameter (PR #336). +- Backend implementation for `ot.lp.free_support_barycenter` (PR #340). +- Add weak OT solver + example (PR #341). + #### Closed issues diff --git a/docs/source/all.rst b/docs/source/all.rst index 7f85a91..76d2ff5 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -28,6 +28,7 @@ API and modules unbalanced partial sliced + weak .. autosummary:: :toctree: ../modules/generated/ diff --git a/examples/others/plot_WeakOT_VS_OT.py b/examples/others/plot_WeakOT_VS_OT.py new file mode 100644 index 0000000..a29c875 --- /dev/null +++ b/examples/others/plot_WeakOT_VS_OT.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +""" +==================================================== +Weak Optimal Transport VS exact Optimal Transport +==================================================== + +Illustration of 2D optimal transport between distributions that are weighted +sum of diracs. The OT matrix is plotted with the samples. 
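For readers of this example, the objective solved by `ot.weak_optimal_transport` below is the barycentric quadratic problem given in the `ot/weak.py` docstring later in this patch; restated here in the same notation:

.. math::
    \gamma^\star = \mathop{\arg \min}_{\gamma \in \Pi(\mathbf{a}, \mathbf{b})} \quad \left\| X_a - \mathrm{diag}(1/\mathbf{a})\, \gamma X_b \right\|_F^2,
    \qquad \Pi(\mathbf{a}, \mathbf{b}) = \{ \gamma \geq 0 : \gamma \mathbf{1} = \mathbf{a},\ \gamma^T \mathbf{1} = \mathbf{b} \}

so each source sample is compared to the barycenter of the mass it sends to the target rather than to individual target samples.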
+ +""" + +# Author: Remi Flamary +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 4 + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot + +############################################################################## +# Generate data an plot it +# ------------------------ + +#%% parameters and data generation + +n = 50 # nb samples + +mu_s = np.array([0, 0]) +cov_s = np.array([[1, 0], [0, 1]]) + +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) + +xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) +xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + +a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples + +# loss matrix +M = ot.dist(xs, xt) +M /= M.max() + +#%% plot samples + +pl.figure(1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.legend(loc=0) +pl.title('Source and target distributions') + +pl.figure(2) +pl.imshow(M, interpolation='nearest') +pl.title('Cost matrix M') + + +############################################################################## +# Compute Weak OT and exact OT solutions +# -------------------------------------- + +#%% EMD + +G0 = ot.emd(a, b, M) + +#%% Weak OT + +Gweak = ot.weak_optimal_transport(xs, xt, a, b) + + +############################################################################## +# Plot weak OT and exact OT solutions +# -------------------------------------- + +pl.figure(3, (8, 5)) + +pl.subplot(1, 2, 1) +pl.imshow(G0, interpolation='nearest') +pl.title('OT matrix') + +pl.subplot(1, 2, 2) +pl.imshow(Gweak, interpolation='nearest') +pl.title('Weak OT matrix') + +pl.figure(4, (8, 5)) + +pl.subplot(1, 2, 1) +ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1]) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.title('OT matrix with samples') + +pl.subplot(1, 2, 2) +ot.plot.plot2D_samples_mat(xs, xt, Gweak, c=[.5, .5, 1]) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.title('Weak OT matrix with samples') diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py index af1bc12..c3a7cd8 100644 --- a/examples/plot_OT_2D_samples.py +++ b/examples/plot_OT_2D_samples.py @@ -42,7 +42,6 @@ a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples # loss matrix M = ot.dist(xs, xt) -M /= M.max() ############################################################################## # Plot data @@ -87,7 +86,7 @@ pl.title('OT matrix with samples') #%% sinkhorn # reg term -lambd = 1e-3 +lambd = 1e-1 Gs = ot.sinkhorn(a, b, M, lambd) @@ -112,7 +111,7 @@ pl.show() #%% sinkhorn # reg term -lambd = 1e-3 +lambd = 1e-1 Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd) diff --git a/ot/__init__.py b/ot/__init__.py index 1ea7403..7253318 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -36,6 +36,7 @@ from . import unbalanced from . import partial from . import backend from . import regpath +from . 
import weak # OT functions from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d @@ -46,7 +47,7 @@ from .da import sinkhorn_lpl1_mm from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance from .gromov import (gromov_wasserstein, gromov_wasserstein2, gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) - +from .weak import weak_optimal_transport # utils functions from .utils import dist, unif, tic, toc, toq @@ -59,5 +60,5 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'sinkhorn_unbalanced', 'barycenter_unbalanced', 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', - 'max_sliced_wasserstein_distance', + 'max_sliced_wasserstein_distance', 'weak_optimal_transport', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath'] diff --git a/ot/gromov.py b/ot/gromov.py index 6544260..b7e7949 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -338,6 +338,10 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F - :math:`\mathbf{q}`: distribution in the target space - `L`: loss function to account for the misfit between the similarity matrices + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + Parameters ---------- C1 : array-like, shape (ns, ns) @@ -436,6 +440,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= Note that when using backends, this loss function is differentiable wrt the marices and weights for quadratic loss using the gradients from [38]_. + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + Parameters ---------- C1 : array-like, shape (ns, ns) @@ -545,6 +553,10 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) - `L` is a loss function to account for the misfit between the similarity matrices + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` Parameters @@ -645,6 +657,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + Note that when using backends, this loss function is differentiable wrt the marices and weights for quadratic loss using the gradients from [38]_. 
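To make the new note concrete, here is a minimal sketch of what "backend-compatible but solved on CPU" means in practice for these solvers. It assumes PyTorch is installed so the torch backend is active, and uses only functions already exposed in this patch (`ot.dist`, `ot.unif`, `ot.gromov_wasserstein`):

```python
import numpy as np
import torch  # assumption: PyTorch is installed, enabling the torch backend
import ot

n = 20
rng = np.random.RandomState(0)
xs, xt = rng.randn(n, 2), rng.randn(n, 2)

C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
C1, C2 = C1 / C1.max(), C2 / C2.max()
p, q = ot.unif(n), ot.unif(n)

# NumPy inputs -> NumPy transport plan
G_np = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss')

# Torch inputs -> torch transport plan of the same type, but the
# conditional-gradient iterations still go through the C++ CPU solver,
# hence the copy overhead mentioned in the note for GPU tensors.
G_th = ot.gromov_wasserstein(
    torch.tensor(C1), torch.tensor(C2), torch.tensor(p), torch.tensor(q),
    'square_loss')

print(type(G_th))                              # torch.Tensor
print(np.allclose(G_np, G_th.cpu().numpy()))   # same plan up to numerical precision
```

The same caveat applies to the `ot.emd`/`ot.emd2` and fused GW solvers annotated in this commit.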
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 2ff7c1f..d9b6fa9 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -26,6 +26,8 @@ from ..utils import dist, list_to_array from ..utils import parmap from ..backend import get_backend + + __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] @@ -220,7 +222,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): format .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. Uses the algorithm proposed in :ref:`[1] `. @@ -358,7 +361,8 @@ def emd2(a, b, M, processes=1, - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. Uses the algorithm proposed in :ref:`[1] `. @@ -622,3 +626,4 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None return X, log_dict else: return X + diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 869d450..fbf3c0e 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -11,7 +11,6 @@ import numpy as np import scipy as sp import scipy.sparse as sps - try: import cvxopt from cvxopt import solvers, matrix, spmatrix diff --git a/ot/utils.py b/ot/utils.py index e6c93c8..725ca00 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -116,7 +116,7 @@ def proj_simplex(v, z=1): return w -def unif(n): +def unif(n, type_as=None): r""" Return a uniform histogram of length `n` (simplex). @@ -124,13 +124,19 @@ def unif(n): ---------- n : int number of bins in the histogram + type_as : array_like + array of the same type as the expected output (numpy/pytorch/jax) Returns ------- - h : np.array (`n`,) + h : array_like (`n`,) histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}` """ - return np.ones((n,)) / n + if type_as is None: + return np.ones((n,)) / n + else: + nx = get_backend(type_as) + return nx.ones((n,)) / n def clean_zeros(a, b, M): diff --git a/ot/weak.py b/ot/weak.py new file mode 100644 index 0000000..f7d5b23 --- /dev/null +++ b/ot/weak.py @@ -0,0 +1,124 @@ +""" +Weak optimal transport solvers +""" + +# Author: Remi Flamary +# +# License: MIT License + +from .backend import get_backend +from .optim import cg +import numpy as np + +__all__ = ['weak_optimal_transport'] + + +def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs): + r"""Solves the weak optimal transport problem between two empirical distributions + + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \|X_a - \mathrm{diag}(1/\mathbf{a}) \gamma X_b\|_F^2 + + s.t. \ \gamma \mathbf{1} = \mathbf{a} + + \gamma^T \mathbf{1} = \mathbf{b} + + \gamma \geq 0 + + where : + + - :math:`X_a` and :math:`X_b` are the sample matrices. + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + + Uses the conditional gradient algorithm to solve the problem proposed + in :ref:`[39] `.
+ + Parameters + ---------- + Xa : (ns,d) array-like, float + Source samples + Xb : (nt,d) array-like, float + Target samples + a : (ns,) array-like, float + Source histogram (uniform weight if empty list) + b : (nt,) array-like, float + Target histogram (uniform weight if empty list)) + numItermax : int, optional + Max number of iterations + numItermaxEmd : int, optional + Max number of iterations for emd + stopThr : float, optional + Stop threshold on the relative variation (>0) + stopThr2 : float, optional + Stop threshold on the absolute variation (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma: array-like, shape (ns, nt) + Optimal transportation matrix for the given + parameters + log: dict, optional + If input log is true, a dictionary containing the + cost and dual variables and exit status + + + .. _references-weak: + References + ---------- + .. [39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). + Kantorovich duality for general transport costs and applications. + Journal of Functional Analysis, 273(11), 3327-3405. + + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT + ot.optim.cg : General regularized OT + """ + + nx = get_backend(Xa, Xb) + + Xa2 = nx.to_numpy(Xa) + Xb2 = nx.to_numpy(Xb) + + if a is None: + a2 = np.ones((Xa.shape[0])) / Xa.shape[0] + else: + a2 = nx.to_numpy(a) + if b is None: + b2 = np.ones((Xb.shape[0])) / Xb.shape[0] + else: + b2 = nx.to_numpy(b) + + # init uniform + if G0 is None: + T0 = a2[:, None] * b2[None, :] + else: + T0 = nx.to_numpy(G0) + + # weak OT loss + def f(T): + return np.dot(a2, np.sum((Xa2 - np.dot(T, Xb2) / a2[:, None])**2, 1)) + + # weak OT gradient + def df(T): + return -2 * np.dot(Xa2 - np.dot(T, Xb2) / a2[:, None], Xb2.T) + + # solve with conditional gradient and return solution + if log: + res, log = cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs) + log['u'] = nx.from_numpy(log['u'], type_as=Xa) + log['v'] = nx.from_numpy(log['v'], type_as=Xb) + return nx.from_numpy(res, type_as=Xa), log + else: + return nx.from_numpy(cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs), type_as=Xa) diff --git a/test/test_bregman.py b/test/test_bregman.py index 6e90aa4..1419f9b 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -60,7 +60,7 @@ def test_convergence_warning(method): ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1) -def test_not_impemented_method(): +def test_not_implemented_method(): # test sinkhorn w = 10 n = w ** 2 @@ -635,7 +635,7 @@ def test_wasserstein_bary_2d(nx, method): with pytest.raises(NotImplementedError): ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) else: - bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method) + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method, verbose=True, log=True) bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)) np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) @@ -667,7 +667,7 @@ def test_wasserstein_bary_2d_debiased(nx, method): with pytest.raises(NotImplementedError): ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) else: - bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method) + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method, verbose=True, log=True) bary_wass = 
nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)) np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) @@ -940,14 +940,11 @@ def test_screenkhorn(nx): bb = nx.from_numpy(b) M_nx = nx.from_numpy(M, type_as=ab) - # np sinkhorn - G_sink_np = ot.sinkhorn(a, b, M, 1e-03) # sinkhorn - G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-03)) + G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1)) # screenkhorn - G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-03, uniform=True, verbose=True)) + G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True)) # check marginals - np.testing.assert_allclose(G_sink_np, G_sink) np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02) np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02) diff --git a/test/test_ot.py b/test/test_ot.py index e8e2d97..3e2d845 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -232,7 +232,7 @@ def test_emd2_multi(): # Gaussian distributions a = gauss(n, m=20, s=5) # m= mean, s= std - ls = np.arange(20, 500, 20) + ls = np.arange(20, 500, 100) nb = len(ls) b = np.zeros((n, nb)) for i in range(nb): diff --git a/test/test_utils.py b/test/test_utils.py index 8b23c22..5ad167b 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -62,12 +62,12 @@ def test_tic_toc(): import time ot.tic() - time.sleep(0.5) + time.sleep(0.1) t = ot.toc() t2 = ot.toq() # test timing - np.testing.assert_allclose(0.5, t, rtol=1e-1, atol=1e-1) + np.testing.assert_allclose(0.1, t, rtol=1e-1, atol=1e-1) # test toc vs toq np.testing.assert_allclose(t, t2, rtol=1e-1, atol=1e-1) @@ -94,10 +94,22 @@ def test_unif(): np.testing.assert_allclose(1, np.sum(u)) -def test_dist(): +def test_unif_backend(nx): n = 100 + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + u = ot.unif(n, type_as=tp) + + np.testing.assert_allclose(1, np.sum(nx.to_numpy(u)), atol=1e-6) + + +def test_dist(): + + n = 10 + rng = np.random.RandomState(0) x = rng.randn(n, 2) diff --git a/test/test_weak.py b/test/test_weak.py new file mode 100644 index 0000000..c4c3278 --- /dev/null +++ b/test/test_weak.py @@ -0,0 +1,54 @@ +"""Tests for main module ot.weak """ + +# Author: Remi Flamary +# +# License: MIT License + +import ot +import numpy as np + + +def test_weak_ot(): + # test weak ot solver and identity stationary point + n = 50 + rng = np.random.RandomState(0) + + xs = rng.randn(n, 2) + xt = rng.randn(n, 2) + u = ot.utils.unif(n) + + G, log = ot.weak_optimal_transport(xs, xt, u, u, log=True) + + # check constraints + np.testing.assert_allclose(u, G.sum(1)) + np.testing.assert_allclose(u, G.sum(0)) + + # chaeck that identity is recovered + G = ot.weak_optimal_transport(xs, xs, G0=np.eye(n) / n) + + # check G is identity + np.testing.assert_allclose(G, np.eye(n) / n) + + # check constraints + np.testing.assert_allclose(u, G.sum(1)) + np.testing.assert_allclose(u, G.sum(0)) + + +def test_weak_ot_bakends(nx): + # test weak ot solver for different backends + n = 50 + rng = np.random.RandomState(0) + + xs = rng.randn(n, 2) + xt = rng.randn(n, 2) + u = ot.utils.unif(n) + + G = ot.weak_optimal_transport(xs, xt, u, u) + + xs2 = nx.from_numpy(xs) + xt2 = nx.from_numpy(xt) + u2 = nx.from_numpy(u) + + G2 = ot.weak_optimal_transport(xs2, xt2, u2, u2) + + np.testing.assert_allclose(nx.to_numpy(G2), G) -- cgit v1.2.3 From 50c0f17d00e3492c4d56a356af30cf00d6d07913 Mon Sep 17 00:00:00 2001 From: Cédric Vincent-Cuaz Date: Fri, 11 Feb 2022 10:53:38 +0100 Subject: [MRG] GW 
dictionary learning (#319) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add fgw dictionary learning feature * add fgw dictionary learning feature * plot gromov wasserstein dictionary learning * Update __init__.py * fix pep8 errors exact E501 line too long * fix last pep8 issues * add unitary tests for (F)GW dictionary learning without using autodifferentiable functions * correct tests for (F)GW dictionary learning without using autodiff * correct tests for (F)GW dictionary learning without using autodiff * fix docs and notations * answer to review: improve tests, docs, examples + make node weights optional * fix pep8 and examples * improve docs + tests + thumbnail * make example faster * improve ex * update README.md * make GDL tests faster Co-authored-by: Rémi Flamary --- README.md | 2 + RELEASES.md | 2 +- .../plot_gromov_wasserstein_dictionary_learning.py | 357 +++++++ ot/__init__.py | 4 - ot/gromov.py | 1074 +++++++++++++++++++- test/test_gromov.py | 554 +++++++++- 6 files changed, 1954 insertions(+), 39 deletions(-) create mode 100755 examples/gromov/plot_gromov_wasserstein_dictionary_learning.py (limited to 'ot/__init__.py') diff --git a/README.md b/README.md index a7627df..c6bfd9c 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ POT provides the following generic OT solvers (links to examples): * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] formulations). * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. +* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. POT provides the following Machine Learning related solvers: @@ -198,6 +199,7 @@ The contributors to this library are * [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein) * [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance) * [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends) +* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): diff --git a/RELEASES.md b/RELEASES.md index 4d05582..925920a 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -10,7 +10,7 @@ of the regularization parameter (PR #336). - Backend implementation for `ot.lp.free_support_barycenter` (PR #340). - Add weak OT solver + example (PR #341). 
- +- Add (F)GW linear dictionary learning solvers + example (PR #319) #### Closed issues diff --git a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py new file mode 100755 index 0000000..1fdc3b9 --- /dev/null +++ b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py @@ -0,0 +1,357 @@ +# -*- coding: utf-8 -*- + +r""" +================================= +(Fused) Gromov-Wasserstein Linear Dictionary Learning +================================= + +In this exemple, we illustrate how to learn a Gromov-Wasserstein dictionary on +a dataset of structured data such as graphs, denoted +:math:`\{ \mathbf{C_s} \}_{s \in [S]}` where every nodes have uniform weights. +Given a dictionary :math:`\mathbf{C_{dict}}` composed of D structures of a fixed +size nt, each graph :math:`(\mathbf{C_s}, \mathbf{p_s})` +is modeled as a convex combination :math:`\mathbf{w_s} \in \Sigma_D` of these +dictionary atoms as :math:`\sum_d w_{s,d} \mathbf{C_{dict}[d]}`. + + +First, we consider a dataset composed of graphs generated by Stochastic Block models +with variable sizes taken in :math:`\{30, ... , 50\}` and quantities of clusters +varying in :math:`\{ 1, 2, 3\}`. We learn a dictionary of 3 atoms, by minimizing +the Gromov-Wasserstein distance from all samples to its model in the dictionary +with respect to the dictionary atoms. + +Second, we illustrate the extension of this dictionary learning framework to +structured data endowed with node features by using the Fused Gromov-Wasserstein +distance. Starting from the aforementioned dataset of unattributed graphs, we +add discrete labels uniformly depending on the number of clusters. Then we learn +and visualize attributed graph atoms where each sample is modeled as a joint convex +combination between atom structures and features. + + +[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online Graph +Dictionary Learning, International Conference on Machine Learning (ICML), 2021. + +""" +# Author: Cédric Vincent-Cuaz +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 4 + +import numpy as np +import matplotlib.pylab as pl +from sklearn.manifold import MDS +from ot.gromov import gromov_wasserstein_linear_unmixing, gromov_wasserstein_dictionary_learning, fused_gromov_wasserstein_linear_unmixing, fused_gromov_wasserstein_dictionary_learning +import ot +import networkx +from networkx.generators.community import stochastic_block_model as sbm +# %% +# ============================================================================= +# Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters. +# ============================================================================= + +np.random.seed(42) + +N = 60 # number of graphs in the dataset +# For every number of clusters, we generate SBM with fixed inter/intra-clusters probability. 
+clusters = [1, 2, 3] +Nc = N // len(clusters) # number of graphs by cluster +nlabels = len(clusters) +dataset = [] +labels = [] + +p_inter = 0.1 +p_intra = 0.9 +for n_cluster in clusters: + for i in range(Nc): + n_nodes = int(np.random.uniform(low=30, high=50)) + + if n_cluster > 1: + P = p_inter * np.ones((n_cluster, n_cluster)) + np.fill_diagonal(P, p_intra) + else: + P = p_intra * np.eye(1) + sizes = np.round(n_nodes * np.ones(n_cluster) / n_cluster).astype(np.int32) + G = sbm(sizes, P, seed=i, directed=False) + C = networkx.to_numpy_array(G) + dataset.append(C) + labels.append(n_cluster) + + +# Visualize samples + +def plot_graph(x, C, binary=True, color='C0', s=None): + for j in range(C.shape[0]): + for i in range(j): + if binary: + if C[i, j] > 0: + pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k') + else: # connection intensity proportional to C[i,j] + pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=C[i, j], color='k') + + pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9) + + +pl.figure(1, (12, 8)) +pl.clf() +for idx_c, c in enumerate(clusters): + C = dataset[(c - 1) * Nc] # sample with c clusters + # get 2d position for nodes + x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C) + pl.subplot(2, nlabels, c) + pl.title('(graph) sample from label ' + str(c), fontsize=14) + plot_graph(x, C, binary=True, color='C0', s=50.) + pl.axis("off") + pl.subplot(2, nlabels, nlabels + c) + pl.title('(matrix) sample from label %s \n' % c, fontsize=14) + pl.imshow(C, interpolation='nearest') + pl.axis("off") +pl.tight_layout() +pl.show() + +# %% +# ============================================================================= +# Estimate the gromov-wasserstein dictionary from the dataset +# ============================================================================= + + +np.random.seed(0) +ps = [ot.unif(C.shape[0]) for C in dataset] + +D = 3 # 3 atoms in the dictionary +nt = 6 # of 6 nodes each + +q = ot.unif(nt) +reg = 0. # regularization coefficient to promote sparsity of unmixings {w_s} + +Cdict_GW, log = gromov_wasserstein_dictionary_learning( + Cs=dataset, D=D, nt=nt, ps=ps, q=q, epochs=10, batch_size=16, + learning_rate=0.1, reg=reg, projection='nonnegative_symmetric', + tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=30, max_iter_inner=300, + use_log=True, use_adam_optimizer=True, verbose=True +) +# visualize loss evolution over epochs +pl.figure(2, (4, 3)) +pl.clf() +pl.title('loss evolution by epoch', fontsize=14) +pl.plot(log['loss_epochs']) +pl.xlabel('epochs', fontsize=12) +pl.ylabel('loss', fontsize=12) +pl.tight_layout() +pl.show() + +# %% +# ============================================================================= +# Visualization of the estimated dictionary atoms +# ============================================================================= + + +# Continuous connections between nodes of the atoms are colored in shades of grey (1: dark / 2: white) + +pl.figure(3, (12, 8)) +pl.clf() +for idx_atom, atom in enumerate(Cdict_GW): + scaled_atom = (atom - atom.min()) / (atom.max() - atom.min()) + x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - scaled_atom) + pl.subplot(2, D, idx_atom + 1) + pl.title('(graph) atom ' + str(idx_atom + 1), fontsize=14) + plot_graph(x, atom / atom.max(), binary=False, color='C0', s=100.) 
+ pl.axis("off") + pl.subplot(2, D, D + idx_atom + 1) + pl.title('(matrix) atom %s \n' % (idx_atom + 1), fontsize=14) + pl.imshow(scaled_atom, interpolation='nearest') + pl.colorbar() + pl.axis("off") +pl.tight_layout() +pl.show() +#%% +# ============================================================================= +# Visualization of the embedding space +# ============================================================================= + +unmixings = [] +reconstruction_errors = [] +for C in dataset: + p = ot.unif(C.shape[0]) + unmixing, Cembedded, OT, reconstruction_error = gromov_wasserstein_linear_unmixing( + C, Cdict_GW, p=p, q=q, reg=reg, + tol_outer=10**(-5), tol_inner=10**(-5), + max_iter_outer=30, max_iter_inner=300 + ) + unmixings.append(unmixing) + reconstruction_errors.append(reconstruction_error) +unmixings = np.array(unmixings) +print('cumulated reconstruction error:', np.array(reconstruction_errors).sum()) + + +# Compute the 2D representation of the unmixing living in the 2-simplex of probability +unmixings2D = np.zeros(shape=(N, 2)) +for i, w in enumerate(unmixings): + unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2. + unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2. +x = [0., 0.] +y = [1., 0.] +z = [0.5, np.sqrt(3) / 2.] +extremities = np.stack([x, y, z]) + +pl.figure(4, (4, 4)) +pl.clf() +pl.title('Embedding space', fontsize=14) +for cluster in range(nlabels): + start, end = Nc * cluster, Nc * (cluster + 1) + if cluster == 0: + pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='1 cluster') + else: + pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='%s clusters' % (cluster + 1)) +pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=80., label='atoms') +pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.) +pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.) +pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.) +pl.axis('off') +pl.legend(fontsize=11) +pl.tight_layout() +pl.show() +# %% +# ============================================================================= +# Endow the dataset with node features +# ============================================================================= + +# We follow this feature assignment on all nodes of a graph depending on its label/number of clusters +# 1 cluster --> 0 as nodes feature +# 2 clusters --> 1 as nodes feature +# 3 clusters --> 2 as nodes feature +# features are one-hot encoded following these assignments +dataset_features = [] +for i in range(len(dataset)): + n = dataset[i].shape[0] + F = np.zeros((n, 3)) + if i < Nc: # graph with 1 cluster + F[:, 0] = 1. + elif i < 2 * Nc: # graph with 2 clusters + F[:, 1] = 1. + else: # graph with 3 clusters + F[:, 2] = 1. 
+ dataset_features.append(F) + +pl.figure(5, (12, 8)) +pl.clf() +for idx_c, c in enumerate(clusters): + C = dataset[(c - 1) * Nc] # sample with c clusters + F = dataset_features[(c - 1) * Nc] + colors = ['C' + str(np.argmax(F[i])) for i in range(F.shape[0])] + # get 2d position for nodes + x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C) + pl.subplot(2, nlabels, c) + pl.title('(graph) sample from label ' + str(c), fontsize=14) + plot_graph(x, C, binary=True, color=colors, s=50) + pl.axis("off") + pl.subplot(2, nlabels, nlabels + c) + pl.title('(matrix) sample from label %s \n' % c, fontsize=14) + pl.imshow(C, interpolation='nearest') + pl.axis("off") +pl.tight_layout() +pl.show() +# %% +# ============================================================================= +# Estimate a Fused Gromov-Wasserstein dictionary from the dataset of attributed graphs +# ============================================================================= +np.random.seed(0) +ps = [ot.unif(C.shape[0]) for C in dataset] +D = 3 # 6 atoms instead of 3 +nt = 6 +q = ot.unif(nt) +reg = 0.001 +alpha = 0.5 # trade-off parameter between structure and feature information of Fused Gromov-Wasserstein + + +Cdict_FGW, Ydict_FGW, log = fused_gromov_wasserstein_dictionary_learning( + Cs=dataset, Ys=dataset_features, D=D, nt=nt, ps=ps, q=q, alpha=alpha, + epochs=10, batch_size=16, learning_rate_C=0.1, learning_rate_Y=0.1, reg=reg, + tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=30, max_iter_inner=300, + projection='nonnegative_symmetric', use_log=True, use_adam_optimizer=True, verbose=True +) +# visualize loss evolution +pl.figure(6, (4, 3)) +pl.clf() +pl.title('loss evolution by epoch', fontsize=14) +pl.plot(log['loss_epochs']) +pl.xlabel('epochs', fontsize=12) +pl.ylabel('loss', fontsize=12) +pl.tight_layout() +pl.show() + +# %% +# ============================================================================= +# Visualization of the estimated dictionary atoms +# ============================================================================= + +pl.figure(7, (12, 8)) +pl.clf() +max_features = Ydict_FGW.max() +min_features = Ydict_FGW.min() + +for idx_atom, (Catom, Fatom) in enumerate(zip(Cdict_FGW, Ydict_FGW)): + scaled_atom = (Catom - Catom.min()) / (Catom.max() - Catom.min()) + #scaled_F = 2 * (Fatom - min_features) / (max_features - min_features) + colors = ['C%s' % np.argmax(Fatom[i]) for i in range(Fatom.shape[0])] + x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - scaled_atom) + pl.subplot(2, D, idx_atom + 1) + pl.title('(attributed graph) atom ' + str(idx_atom + 1), fontsize=14) + plot_graph(x, Catom / Catom.max(), binary=False, color=colors, s=100) + pl.axis("off") + pl.subplot(2, D, D + idx_atom + 1) + pl.title('(matrix) atom %s \n' % (idx_atom + 1), fontsize=14) + pl.imshow(scaled_atom, interpolation='nearest') + pl.colorbar() + pl.axis("off") +pl.tight_layout() +pl.show() + +# %% +# ============================================================================= +# Visualization of the embedding space +# ============================================================================= + +unmixings = [] +reconstruction_errors = [] +for i in range(len(dataset)): + C = dataset[i] + Y = dataset_features[i] + p = ot.unif(C.shape[0]) + unmixing, Cembedded, Yembedded, OT, reconstruction_error = fused_gromov_wasserstein_linear_unmixing( + C, Y, Cdict_FGW, Ydict_FGW, p=p, q=q, alpha=alpha, + reg=reg, tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=30, max_iter_inner=300 + ) + 
unmixings.append(unmixing) + reconstruction_errors.append(reconstruction_error) +unmixings = np.array(unmixings) +print('cumulated reconstruction error:', np.array(reconstruction_errors).sum()) + +# Visualize unmixings in the 2-simplex of probability +unmixings2D = np.zeros(shape=(N, 2)) +for i, w in enumerate(unmixings): + unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2. + unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2. +x = [0., 0.] +y = [1., 0.] +z = [0.5, np.sqrt(3) / 2.] +extremities = np.stack([x, y, z]) + +pl.figure(8, (4, 4)) +pl.clf() +pl.title('Embedding space', fontsize=14) +for cluster in range(nlabels): + start, end = Nc * cluster, Nc * (cluster + 1) + if cluster == 0: + pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='1 cluster') + else: + pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='%s clusters' % (cluster + 1)) + +pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=80., label='atoms') +pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.) +pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.) +pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.) +pl.axis('off') +pl.legend(fontsize=11) +pl.tight_layout() +pl.show() diff --git a/ot/__init__.py b/ot/__init__.py index 7253318..bda7a35 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -1,5 +1,4 @@ """ - .. warning:: The list of automatically imported sub-modules is as follows: :py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim` @@ -7,13 +6,10 @@ :py:mod:`ot.gromov`, :py:mod:`ot.smooth` :py:mod:`ot.stochastic`, :py:mod:`ot.partial`, :py:mod:`ot.regpath` , :py:mod:`ot.unbalanced`. - The following sub-modules are not imported due to additional dependencies: - - :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`. - :any:`ot.gpu` : depends on :code:`cupy` and a CUDA GPU. - :any:`ot.plot` : depends on :code:`matplotlib` - """ # Author: Remi Flamary diff --git a/ot/gromov.py b/ot/gromov.py index b7e7949..f5a1f91 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -7,6 +7,7 @@ Gromov-Wasserstein and Fused-Gromov-Wasserstein solvers # Nicolas Courty # Rémi Flamary # Titouan Vayer +# Cédric Vincent-Cuaz # # License: MIT License @@ -17,7 +18,7 @@ from .bregman import sinkhorn from .utils import dist, UndefinedParameter, list_to_array from .optim import cg from .lp import emd_1d, emd -from .utils import check_random_state +from .utils import check_random_state, unif from .backend import get_backend @@ -320,7 +321,7 @@ def update_kl_loss(p, lambdas, T, Cs): return nx.exp(tmpsum / ppt) -def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs): +def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs): r""" Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -365,6 +366,9 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F armijo : bool, optional If True the step of the line-search is found via an armijo research. Else closed form is used. If there are convergence issues use False. + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. 
**kwargs : dict parameters can be directly passed to the ot.optim.cg solver @@ -389,18 +393,26 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F """ p, q = list_to_array(p, q) - p0, q0, C10, C20 = p, q, C1, C2 - nx = get_backend(p0, q0, C10, C20) - + if G0 is None: + nx = get_backend(p0, q0, C10, C20) + else: + G0_ = G0 + nx = get_backend(p0, q0, C10, C20, G0_) p = nx.to_numpy(p) q = nx.to_numpy(q) C1 = nx.to_numpy(C10) C2 = nx.to_numpy(C20) - constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) + if G0 is None: + G0 = p[:, None] * q[None, :] + else: + G0 = nx.to_numpy(G0_) + # Check marginals of G0 + np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) + np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) - G0 = p[:, None] * q[None, :] + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) def f(G): return gwloss(constC, hC1, hC2, G) @@ -418,7 +430,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F return nx.from_numpy(cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=False, **kwargs), type_as=C10) -def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs): +def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs): r""" Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -467,6 +479,9 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= armijo : bool, optional If True the step of the line-search is found via an armijo research. Else closed form is used. If there are convergence issues use False. + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. Returns ------- @@ -491,9 +506,12 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= """ p, q = list_to_array(p, q) - p0, q0, C10, C20 = p, q, C1, C2 - nx = get_backend(p0, q0, C10, C20) + if G0 is None: + nx = get_backend(p0, q0, C10, C20) + else: + G0_ = G0 + nx = get_backend(p0, q0, C10, C20, G0_) p = nx.to_numpy(p) q = nx.to_numpy(q) @@ -502,7 +520,13 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - G0 = p[:, None] * q[None, :] + if G0 is None: + G0 = p[:, None] * q[None, :] + else: + G0 = nx.to_numpy(G0_) + # Check marginals of G0 + np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) + np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) def f(G): return gwloss(constC, hC1, hC2, G) @@ -533,7 +557,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= return gw -def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): +def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs): r""" Computes the FGW transport between two graphs (see :ref:`[24] `) @@ -578,6 +602,9 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo : bool, optional If True the step of the line-search is found via an armijo research. Else closed form is used. If there are convergence issues use False. + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. 
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. log : bool, optional record log if True **kwargs : dict @@ -600,20 +627,28 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, (ICML). 2019. """ p, q = list_to_array(p, q) - p0, q0, C10, C20, M0 = p, q, C1, C2, M - nx = get_backend(p0, q0, C10, C20, M0) + if G0 is None: + nx = get_backend(p0, q0, C10, C20, M0) + else: + G0_ = G0 + nx = get_backend(p0, q0, C10, C20, M0, G0_) p = nx.to_numpy(p) q = nx.to_numpy(q) C1 = nx.to_numpy(C10) C2 = nx.to_numpy(C20) M = nx.to_numpy(M0) + if G0 is None: + G0 = p[:, None] * q[None, :] + else: + G0 = nx.to_numpy(G0_) + # Check marginals of G0 + np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) + np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - G0 = p[:, None] * q[None, :] - def f(G): return gwloss(constC, hC1, hC2, G) @@ -622,19 +657,16 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, if log: res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) - fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10) - log['fgw_dist'] = fgw_dist log['u'] = nx.from_numpy(log['u'], type_as=C10) log['v'] = nx.from_numpy(log['v'], type_as=C10) return nx.from_numpy(res, type_as=C10), log - else: return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs), type_as=C10) -def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): +def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs): r""" Computes the FGW distance between two graphs see (see :ref:`[24] `) @@ -683,6 +715,9 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 armijo : bool, optional If True the step of the line-search is found via an armijo research. Else closed form is used. If there are convergence issues use False. + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. log : bool, optional Record log if True. 
**kwargs : dict @@ -711,7 +746,11 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 p, q = list_to_array(p, q) p0, q0, C10, C20, M0 = p, q, C1, C2, M - nx = get_backend(p0, q0, C10, C20, M0) + if G0 is None: + nx = get_backend(p0, q0, C10, C20, M0) + else: + G0_ = G0 + nx = get_backend(p0, q0, C10, C20, M0, G0_) p = nx.to_numpy(p) q = nx.to_numpy(q) @@ -721,7 +760,13 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - G0 = p[:, None] * q[None, :] + if G0 is None: + G0 = p[:, None] * q[None, :] + else: + G0 = nx.to_numpy(G0_) + # Check marginals of G0 + np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) + np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) def f(G): return gwloss(constC, hC1, hC2, G) @@ -1796,3 +1841,988 @@ def update_feature_matrix(lambdas, Ys, Ts, p): for s in range(len(Ts)) ]) return tmpsum + + +def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate=1., Cdict_init=None, projection='nonnegative_symmetric', use_log=True, + tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs): + r""" + Infer Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, q) \}_{d \in [D]}` from the list of structures :math:`\{ (\mathbf{C_s},\mathbf{p_s}) \}_s` + + .. math:: + \min_{\mathbf{C_{dict}}, \{\mathbf{w_s} \}_{s \leq S}} \sum_{s=1}^S GW_2(\mathbf{C_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) - reg\| \mathbf{w_s} \|_2^2 + + such that, :math:`\forall s \leq S` : + + - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w_s} \geq \mathbf{0}_D` + + Where : + + - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. + - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}` + - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. + - reg is the regularization coefficient. + + The stochastic algorithm used for estimating the graph dictionary atoms as proposed in [38] + + Parameters + ---------- + Cs : list of S symmetric array-like, shape (ns, ns) + List of Metric/Graph cost matrices of variable size (ns, ns). + D: int + Number of dictionary atoms to learn + nt: int + Number of samples within each dictionary atoms + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. + ps : list of S array-like, shape (ns,), optional + Distribution in each source space C of Cs. Default is None and corresponds to uniform distibutions. + q : array-like, shape (nt,), optional + Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions. + epochs: int, optional + Number of epochs used to learn the dictionary. Default is 32. + batch_size: int, optional + Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32. + learning_rate: float, optional + Learning rate used for the stochastic gradient descent. Default is 1. 
+ Cdict_init: list of D array-like with shape (nt, nt), optional + Used to initialize the dictionary. + If set to None (Default), the dictionary will be initialized randomly. + Else Cdict must have shape (D, nt, nt) i.e match provided shape features. + projection: str , optional + If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary + Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric' + log: bool, optional + If set to True, losses evolution by batches and epochs are tracked. Default is False. + use_adam_optimizer: bool, optional + If set to True, adam optimizer with default settings is used as adaptative learning rate strategy. + Else perform SGD with fixed learning rate. Default is True. + tol_outer : float, optional + Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. + tol_inner : float, optional + Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. + max_iter_outer : int, optional + Maximum number of iterations for the BCD. Default is 20. + max_iter_inner : int, optional + Maximum number of iterations for the Conjugate Gradient. Default is 200. + verbose : bool, optional + Print the reconstruction loss every epoch. Default is False. + + Returns + ------- + + Cdict_best_state : D array-like, shape (D,nt,nt) + Metric/Graph cost matrices composing the dictionary. + The dictionary leading to the best loss over an epoch is saved and returned. + log: dict + If use_log is True, contains loss evolutions by batches and epochs. + References + ------- + + ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty. + "Online Graph Dictionary Learning" + International Conference on Machine Learning (ICML). 2021. + """ + # Handle backend of non-optional arguments + Cs0 = Cs + nx = get_backend(*Cs0) + Cs = [nx.to_numpy(C) for C in Cs0] + dataset_size = len(Cs) + # Handle backend of optional arguments + if ps is None: + ps = [unif(C.shape[0]) for C in Cs] + else: + ps = [nx.to_numpy(p) for p in ps] + if q is None: + q = unif(nt) + else: + q = nx.to_numpy(q) + if Cdict_init is None: + # Initialize randomly structures of dictionary atoms based on samples + dataset_means = [C.mean() for C in Cs] + Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt)) + else: + Cdict = nx.to_numpy(Cdict_init).copy() + assert Cdict.shape == (D, nt, nt) + + if 'symmetric' in projection: + Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) + if 'nonnegative' in projection: + Cdict[Cdict < 0.] = 0 + if use_adam_optimizer: + adam_moments = _initialize_adam_optimizer(Cdict) + + log = {'loss_batches': [], 'loss_epochs': []} + const_q = q[:, None] * q[None, :] + Cdict_best_state = Cdict.copy() + loss_best_state = np.inf + if batch_size > dataset_size: + batch_size = dataset_size + iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0) + + for epoch in range(epochs): + cumulated_loss_over_epoch = 0. + + for _ in range(iter_by_epoch): + # batch sampling + batch = np.random.choice(range(dataset_size), size=batch_size, replace=False) + cumulated_loss_over_batch = 0. 
+ unmixings = np.zeros((batch_size, D)) + Cs_embedded = np.zeros((batch_size, nt, nt)) + Ts = [None] * batch_size + + for batch_idx, C_idx in enumerate(batch): + # BCD solver for Gromov-Wassersteisn linear unmixing used independently on each structure of the sampled batch + unmixings[batch_idx], Cs_embedded[batch_idx], Ts[batch_idx], current_loss = gromov_wasserstein_linear_unmixing( + Cs[C_idx], Cdict, reg=reg, p=ps[C_idx], q=q, tol_outer=tol_outer, tol_inner=tol_inner, + max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner + ) + cumulated_loss_over_batch += current_loss + cumulated_loss_over_epoch += cumulated_loss_over_batch + + if use_log: + log['loss_batches'].append(cumulated_loss_over_batch) + + # Stochastic projected gradient step over dictionary atoms + grad_Cdict = np.zeros_like(Cdict) + for batch_idx, C_idx in enumerate(batch): + shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx]) + grad_Cdict += unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :] + grad_Cdict *= 2 / batch_size + if use_adam_optimizer: + Cdict, adam_moments = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate, adam_moments) + else: + Cdict -= learning_rate * grad_Cdict + if 'symmetric' in projection: + Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) + if 'nonnegative' in projection: + Cdict[Cdict < 0.] = 0. + + if use_log: + log['loss_epochs'].append(cumulated_loss_over_epoch) + if loss_best_state > cumulated_loss_over_epoch: + loss_best_state = cumulated_loss_over_epoch + Cdict_best_state = Cdict.copy() + if verbose: + print('--- epoch =', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch) + + return nx.from_numpy(Cdict_best_state), log + + +def _initialize_adam_optimizer(variable): + + # Initialization for our numpy implementation of adam optimizer + atoms_adam_m = np.zeros_like(variable) # Initialize first moment tensor + atoms_adam_v = np.zeros_like(variable) # Initialize second moment tensor + atoms_adam_count = 1 + + return {'mean': atoms_adam_m, 'var': atoms_adam_v, 'count': atoms_adam_count} + + +def _adam_stochastic_updates(variable, grad, learning_rate, adam_moments, beta_1=0.9, beta_2=0.99, eps=1e-09): + + adam_moments['mean'] = beta_1 * adam_moments['mean'] + (1 - beta_1) * grad + adam_moments['var'] = beta_2 * adam_moments['var'] + (1 - beta_2) * (grad**2) + unbiased_m = adam_moments['mean'] / (1 - beta_1**adam_moments['count']) + unbiased_v = adam_moments['var'] / (1 - beta_2**adam_moments['count']) + variable -= learning_rate * unbiased_m / (np.sqrt(unbiased_v) + eps) + adam_moments['count'] += 1 + + return variable, adam_moments + + +def gromov_wasserstein_linear_unmixing(C, Cdict, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs): + r""" + Returns the Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`. + + .. math:: + \min_{ \mathbf{w}} GW_2(\mathbf{C}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2 + + such that: + + - :math:`\mathbf{w}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w} \geq \mathbf{0}_D` + + Where : + + - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of size nt. + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights. 
+ - reg is the regularization coefficient. + + The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38], algorithm 1. + + Parameters + ---------- + C : array-like, shape (ns, ns) + Metric/Graph cost matrix. + Cdict : D array-like, shape (D,nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed C. + reg : float, optional. + Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0. + p : array-like, shape (ns,), optional + Distribution in the source space C. Default is None and corresponds to uniform distribution. + q : array-like, shape (nt,), optional + Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution. + tol_outer : float, optional + Solver precision for the BCD algorithm. + tol_inner : float, optional + Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`. + max_iter_outer : int, optional + Maximum number of iterations for the BCD. Default is 20. + max_iter_inner : int, optional + Maximum number of iterations for the Conjugate Gradient. Default is 200. + + Returns + ------- + w: array-like, shape (D,) + gromov-wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the span of the dictionary. + Cembedded: array-like, shape (nt,nt) + embedded structure of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`. + T: array-like (ns, nt) + Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \mathbf{q})` + current_loss: float + reconstruction error + References + ------- + + ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty. + "Online Graph Dictionary Learning" + International Conference on Machine Learning (ICML). 2021. + """ + C0, Cdict0 = C, Cdict + nx = get_backend(C0, Cdict0) + C = nx.to_numpy(C0) + Cdict = nx.to_numpy(Cdict0) + if p is None: + p = unif(C.shape[0]) + else: + p = nx.to_numpy(p) + + if q is None: + q = unif(Cdict.shape[-1]) + else: + q = nx.to_numpy(q) + + T = p[:, None] * q[None, :] + D = len(Cdict) + + w = unif(D) # Initialize uniformly the unmixing w + Cembedded = np.sum(w[:, None, None] * Cdict, axis=0) + + const_q = q[:, None] * q[None, :] + # Trackers for BCD convergence + convergence_criterion = np.inf + current_loss = 10**15 + outer_count = 0 + + while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer): + previous_loss = current_loss + # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w + T, log = gromov_wasserstein(C1=C, C2=Cembedded, p=p, q=q, loss_fun='square_loss', G0=T, log=True, armijo=False, **kwargs) + current_loss = log['gw_dist'] + if reg != 0: + current_loss -= reg * np.sum(w**2) + + # 2. 
Solve linear unmixing problem over w with a fixed transport plan T + w, Cembedded, current_loss = _cg_gromov_wasserstein_unmixing( + C=C, Cdict=Cdict, Cembedded=Cembedded, w=w, const_q=const_q, T=T, + starting_loss=current_loss, reg=reg, tol=tol_inner, max_iter=max_iter_inner, **kwargs + ) + + if previous_loss != 0: + convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + else: # handle numerical issues around 0 + convergence_criterion = abs(previous_loss - current_loss) / 10**(-15) + outer_count += 1 + + return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(T), nx.from_numpy(current_loss) + + +def _cg_gromov_wasserstein_unmixing(C, Cdict, Cembedded, w, const_q, T, starting_loss, reg=0., tol=10**(-5), max_iter=200, **kwargs): + r""" + Returns for a fixed admissible transport plan, + the linear unmixing w minimizing the Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w[d]*\mathbf{C_{dict}[d]}, \mathbf{q})` + + .. math:: + \min_{\mathbf{w}} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d*C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg* \| \mathbf{w} \|_2^2 + + + Such that: + + - :math:`\mathbf{w}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w} \geq \mathbf{0}_D` + + Where : + + - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of nt points. + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights. + - :math:`\mathbf{w}` is the linear unmixing of :math:`(\mathbf{C}, \mathbf{p})` onto :math:`(\sum_d w_d \mathbf{Cdict[d]}, \mathbf{q})`. + - :math:`\mathbf{T}` is the optimal transport plan conditioned by the current state of :math:`\mathbf{w}`. + - reg is the regularization coefficient. + + The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38] + + Parameters + ---------- + + C : array-like, shape (ns, ns) + Metric/Graph cost matrix. + Cdict : list of D array-like, shape (nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed C. + Each matrix in the dictionary must have the same size (nt,nt). + Cembedded: array-like, shape (nt,nt) + Embedded structure :math:`(\sum_d w[d]*Cdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations. + w: array-like, shape (D,) + Linear unmixing of the input structure onto the dictionary + const_q: array-like, shape (nt,nt) + product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations. + T: array-like, shape (ns,nt) + fixed transport plan between the input structure and its representation in the dictionary. + p : array-like, shape (ns,) + Distribution in the source space. + q : array-like, shape (nt,) + Distribution in the embedding space depicted by the dictionary. + reg : float, optional. + Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0. + + Returns + ------- + w: ndarray (D,) + optimal unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary span given OT starting from previously optimal unmixing. 
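To make the direction-finding step concrete, here is a toy sketch (the gradient values are arbitrary and purely illustrative): the linear minimization oracle over the simplex returns the uniform vertex supported on the coordinates where the gradient is minimal, which is exactly how the descent direction is built in the loop below.

>>> import numpy as np
>>> grad_w = np.array([0.3, -0.1, 0.2])  # arbitrary gradient over D = 3 atoms
>>> x = (grad_w == grad_w.min()).astype(np.float64)
>>> x /= x.sum()  # -> array([0., 1., 0.]), the simplex vertex at the argmin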
+ """ + convergence_criterion = np.inf + current_loss = starting_loss + count = 0 + const_TCT = np.transpose(C.dot(T)).dot(T) + + while (convergence_criterion > tol) and (count < max_iter): + + previous_loss = current_loss + # 1) Compute gradient at current point w + grad_w = 2 * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2)) + grad_w -= 2 * reg * w + + # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w + min_ = np.min(grad_w) + x = (grad_w == min_).astype(np.float64) + x /= np.sum(x) + + # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c + gamma, a, b, Cembedded_diff = _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg) + + # 4) Updates: w <-- (1-gamma)*w + gamma*x + w += gamma * (x - w) + Cembedded += gamma * Cembedded_diff + current_loss += a * (gamma**2) + b * gamma + + if previous_loss != 0: # not that the loss can be negative if reg >0 + convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + else: # handle numerical issues around 0 + convergence_criterion = abs(previous_loss - current_loss) / 10**(-15) + count += 1 + + return w, Cembedded, current_loss + + +def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg, **kwargs): + r""" + Compute optimal steps for the line search problem of Gromov-Wasserstein linear unmixing + .. math:: + \min_{\gamma \in [0,1]} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg\| \mathbf{z}(\gamma) \|_2^2 + + + Such that: + + - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}` + + Parameters + ---------- + + w : array-like, shape (D,) + Unmixing. + grad_w : array-like, shape (D, D) + Gradient of the reconstruction loss with respect to w. + x: array-like, shape (D,) + Conditional gradient direction. + Cdict : list of D array-like, shape (nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed C. + Each matrix in the dictionary must have the same size (nt,nt). + Cembedded: array-like, shape (nt,nt) + Embedded structure :math:`(\sum_d w_dCdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations. + const_q: array-like, shape (nt,nt) + product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations. + const_TCT: array-like, shape (nt, nt) + :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations. + Returns + ------- + gamma: float + Optimal value for the line-search step + a: float + Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss + b: float + Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss + Cembedded_diff: numpy array, shape (nt, nt) + Difference between models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`. + reg : float, optional. + Coefficient of the negative quadratic regularization used to promote sparsity of :math:`\mathbf{w}`. 
+ """ + + # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c + Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0) + Cembedded_diff = Cembedded_x - Cembedded + trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q) + trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q) + a = trace_diffx - trace_diffw + b = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT)) + if reg != 0: + a -= reg * np.sum((x - w)**2) + b -= 2 * reg * np.sum(w * (x - w)) + + if a > 0: + gamma = min(1, max(0, - b / (2 * a))) + elif a + b < 0: + gamma = 1 + else: + gamma = 0 + + return gamma, a, b, Cembedded_diff + + +def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate_C=1., learning_rate_Y=1., + Cdict_init=None, Ydict_init=None, projection='nonnegative_symmetric', use_log=False, + tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs): + r""" + Infer Fused Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` from the list of S attributed structures :math:`\{ (\mathbf{C_s}, \mathbf{Y_s},\mathbf{p_s}) \}_s` + + .. math:: + \min_{\mathbf{C_{dict}},\mathbf{Y_{dict}}, \{\mathbf{w_s}\}_{s}} \sum_{s=1}^S FGW_{2,\alpha}(\mathbf{C_s}, \mathbf{Y_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]},\sum_{d=1}^D w_{s,d}\mathbf{Y_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) \\ - reg\| \mathbf{w_s} \|_2^2 + + + Such that :math:`\forall s \leq S` : + + - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w_s} \geq \mathbf{0}_D` + + Where : + + - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns. + - :math:`\forall s \leq S, \mathbf{Y_s}` is a (ns,d) features matrix of variable size ns and fixed dimension d. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. + - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d. + - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}` + - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. + - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein + - reg is the regularization coefficient. + + + The stochastic algorithm used for estimating the attributed graph dictionary atoms as proposed in [38] + + Parameters + ---------- + Cs : list of S symmetric array-like, shape (ns, ns) + List of Metric/Graph cost matrices of variable size (ns,ns). + Ys : list of S array-like, shape (ns, d) + List of feature matrix of variable size (ns,d) with d fixed. + D: int + Number of dictionary atoms to learn + nt: int + Number of samples within each dictionary atoms + alpha : float + Trade-off parameter of Fused Gromov-Wasserstein + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. + ps : list of S array-like, shape (ns,), optional + Distribution in each source space C of Cs. Default is None and corresponds to uniform distibutions. + q : array-like, shape (nt,), optional + Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions. + epochs: int, optional + Number of epochs used to learn the dictionary. Default is 32. 
+ batch_size: int, optional + Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32. + learning_rate_C: float, optional + Learning rate used for the stochastic gradient descent on Cdict. Default is 1. + learning_rate_Y: float, optional + Learning rate used for the stochastic gradient descent on Ydict. Default is 1. + Cdict_init: list of D array-like with shape (nt, nt), optional + Used to initialize the dictionary structures Cdict. + If set to None (Default), the dictionary will be initialized randomly. + Else Cdict must have shape (D, nt, nt) i.e match provided shape features. + Ydict_init: list of D array-like with shape (nt, d), optional + Used to initialize the dictionary features Ydict. + If set to None, the dictionary features will be initialized randomly. + Else Ydict must have shape (D, nt, d) where d is the features dimension of inputs Ys and also match provided shape features. + projection: str, optional + If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary + Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric' + log: bool, optional + If set to True, losses evolution by batches and epochs are tracked. Default is False. + use_adam_optimizer: bool, optional + If set to True, adam optimizer with default settings is used as adaptative learning rate strategy. + Else perform SGD with fixed learning rate. Default is True. + tol_outer : float, optional + Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. + tol_inner : float, optional + Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. + max_iter_outer : int, optional + Maximum number of iterations for the BCD. Default is 20. + max_iter_inner : int, optional + Maximum number of iterations for the Conjugate Gradient. Default is 200. + verbose : bool, optional + Print the reconstruction loss every epoch. Default is False. + + Returns + ------- + + Cdict_best_state : D array-like, shape (D,nt,nt) + Metric/Graph cost matrices composing the dictionary. + The dictionary leading to the best loss over an epoch is saved and returned. + Ydict_best_state : D array-like, shape (D,nt,d) + Feature matrices composing the dictionary. + The dictionary leading to the best loss over an epoch is saved and returned. + log: dict + If use_log is True, contains loss evolutions by batches and epoches. + References + ------- + + ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty. + "Online Graph Dictionary Learning" + International Conference on Machine Learning (ICML). 2021. 
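A minimal usage sketch mirroring the construction used in the unit tests added later in this patch; the graph sizes, number of atoms, number of epochs and batch size are illustrative only, not recommended settings.

>>> import ot
>>> X1, _ = ot.datasets.make_data_classif('3gauss', 10, random_state=42)
>>> X2, _ = ot.datasets.make_data_classif('3gauss2', 10, random_state=42)
>>> F, _ = ot.datasets.make_data_classif('3gauss', 10, random_state=42)
>>> Cs = [ot.dist(X1), ot.dist(X2)]  # two structure matrices of shape (10, 10)
>>> Ys = [F.copy(), F.copy()]  # node features, shape (10, 2)
>>> Cdict, Ydict, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
...     Cs, Ys, D=2, nt=10, alpha=0.5, epochs=2, batch_size=2)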
+ """ + Cs0, Ys0 = Cs, Ys + nx = get_backend(*Cs0, *Ys0) + Cs = [nx.to_numpy(C) for C in Cs0] + Ys = [nx.to_numpy(Y) for Y in Ys0] + + d = Ys[0].shape[-1] + dataset_size = len(Cs) + + if ps is None: + ps = [unif(C.shape[0]) for C in Cs] + else: + ps = [nx.to_numpy(p) for p in ps] + if q is None: + q = unif(nt) + else: + q = nx.to_numpy(q) + + if Cdict_init is None: + # Initialize randomly structures of dictionary atoms based on samples + dataset_means = [C.mean() for C in Cs] + Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt)) + else: + Cdict = nx.to_numpy(Cdict_init).copy() + assert Cdict.shape == (D, nt, nt) + if Ydict_init is None: + # Initialize randomly features of dictionary atoms based on samples distribution by feature component + dataset_feature_means = np.stack([F.mean(axis=0) for F in Ys]) + Ydict = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d)) + else: + Ydict = nx.to_numpy(Ydict_init).copy() + assert Ydict.shape == (D, nt, d) + + if 'symmetric' in projection: + Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) + if 'nonnegative' in projection: + Cdict[Cdict < 0.] = 0. + + if use_adam_optimizer: + adam_moments_C = _initialize_adam_optimizer(Cdict) + adam_moments_Y = _initialize_adam_optimizer(Ydict) + + log = {'loss_batches': [], 'loss_epochs': []} + const_q = q[:, None] * q[None, :] + diag_q = np.diag(q) + Cdict_best_state = Cdict.copy() + Ydict_best_state = Ydict.copy() + loss_best_state = np.inf + if batch_size > dataset_size: + batch_size = dataset_size + iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0) + + for epoch in range(epochs): + cumulated_loss_over_epoch = 0. + + for _ in range(iter_by_epoch): + + # Batch iterations + batch = np.random.choice(range(dataset_size), size=batch_size, replace=False) + cumulated_loss_over_batch = 0. 
+ unmixings = np.zeros((batch_size, D)) + Cs_embedded = np.zeros((batch_size, nt, nt)) + Ys_embedded = np.zeros((batch_size, nt, d)) + Ts = [None] * batch_size + + for batch_idx, C_idx in enumerate(batch): + # BCD solver for Gromov-Wassersteisn linear unmixing used independently on each structure of the sampled batch + unmixings[batch_idx], Cs_embedded[batch_idx], Ys_embedded[batch_idx], Ts[batch_idx], current_loss = fused_gromov_wasserstein_linear_unmixing( + Cs[C_idx], Ys[C_idx], Cdict, Ydict, alpha, reg=reg, p=ps[C_idx], q=q, + tol_outer=tol_outer, tol_inner=tol_inner, max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner + ) + cumulated_loss_over_batch += current_loss + cumulated_loss_over_epoch += cumulated_loss_over_batch + if use_log: + log['loss_batches'].append(cumulated_loss_over_batch) + + # Stochastic projected gradient step over dictionary atoms + grad_Cdict = np.zeros_like(Cdict) + grad_Ydict = np.zeros_like(Ydict) + + for batch_idx, C_idx in enumerate(batch): + shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx]) + shared_term_features = diag_q.dot(Ys_embedded[batch_idx]) - Ts[batch_idx].T.dot(Ys[C_idx]) + grad_Cdict += alpha * unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :] + grad_Ydict += (1 - alpha) * unmixings[batch_idx][:, None, None] * shared_term_features[None, :, :] + grad_Cdict *= 2 / batch_size + grad_Ydict *= 2 / batch_size + + if use_adam_optimizer: + Cdict, adam_moments_C = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate_C, adam_moments_C) + Ydict, adam_moments_Y = _adam_stochastic_updates(Ydict, grad_Ydict, learning_rate_Y, adam_moments_Y) + else: + Cdict -= learning_rate_C * grad_Cdict + Ydict -= learning_rate_Y * grad_Ydict + + if 'symmetric' in projection: + Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) + if 'nonnegative' in projection: + Cdict[Cdict < 0.] = 0. + + if use_log: + log['loss_epochs'].append(cumulated_loss_over_epoch) + if loss_best_state > cumulated_loss_over_epoch: + loss_best_state = cumulated_loss_over_epoch + Cdict_best_state = Cdict.copy() + Ydict_best_state = Ydict.copy() + if verbose: + print('--- epoch: ', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch) + + return nx.from_numpy(Cdict_best_state), nx.from_numpy(Ydict_best_state), log + + +def fused_gromov_wasserstein_linear_unmixing(C, Y, Cdict, Ydict, alpha, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs): + r""" + Returns the Fused Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the attributed dictionary atoms :math:`\{ (\mathbf{C_{dict}[d]},\mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` + + .. math:: + \min_{\mathbf{w}} FGW_{2,\alpha}(\mathbf{C},\mathbf{Y}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]},\sum_{d=1}^D w_d\mathbf{Y_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2 + + such that, :math:`\forall s \leq S` : + + - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w_s} \geq \mathbf{0}_D` + + Where : + + - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns. + - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. + - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d. 
+ - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C_s}` + - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. + - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein + - reg is the regularization coefficient. + + The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38], algorithm 6. + + Parameters + ---------- + C : array-like, shape (ns, ns) + Metric/Graph cost matrix. + Y : array-like, shape (ns, d) + Feature matrix. + Cdict : D array-like, shape (D,nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed (C,Y). + Ydict : D array-like, shape (D,nt,d) + Feature matrices composing the dictionary on which to embed (C,Y). + alpha: float, + Trade-off parameter of Fused Gromov-Wasserstein. + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. + p : array-like, shape (ns,), optional + Distribution in the source space C. Default is None and corresponds to uniform distribution. + q : array-like, shape (nt,), optional + Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution. + tol_outer : float, optional + Solver precision for the BCD algorithm. + tol_inner : float, optional + Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`. + max_iter_outer : int, optional + Maximum number of iterations for the BCD. Default is 20. + max_iter_inner : int, optional + Maximum number of iterations for the Conjugate Gradient. Default is 200. + + Returns + ------- + w: array-like, shape (D,) + fused gromov-wasserstein linear unmixing of (C,Y,p) onto the span of the dictionary. + Cembedded: array-like, shape (nt,nt) + embedded structure of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`. + Yembedded: array-like, shape (nt,d) + embedded features of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{Y_{dict}[d]}`. + T: array-like (ns,nt) + Fused Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \sum_d w_d\mathbf{Y_{dict}[d]},\mathbf{q})`. + current_loss: float + reconstruction error + References + ------- + + ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty. + "Online Graph Dictionary Learning" + International Conference on Machine Learning (ICML). 2021. 
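A minimal usage sketch following the unit-test setup, where the two dictionary atoms are simply built from the input graphs themselves; sizes and the value of alpha are illustrative only.

>>> import numpy as np
>>> import ot
>>> X1, _ = ot.datasets.make_data_classif('3gauss', 10, random_state=42)
>>> X2, _ = ot.datasets.make_data_classif('3gauss2', 10, random_state=42)
>>> F, _ = ot.datasets.make_data_classif('3gauss', 10, random_state=42)
>>> C1, C2 = ot.dist(X1), ot.dist(X2)
>>> Cdict, Ydict = np.stack([C1, C2]), np.stack([F, F])
>>> w, Cemb, Yemb, T, loss = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
...     C1, F, Cdict, Ydict, alpha=0.5)  # w has shape (2,), Yemb shape (10, 2)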
+ """ + C0, Y0, Cdict0, Ydict0 = C, Y, Cdict, Ydict + nx = get_backend(C0, Y0, Cdict0, Ydict0) + C = nx.to_numpy(C0) + Y = nx.to_numpy(Y0) + Cdict = nx.to_numpy(Cdict0) + Ydict = nx.to_numpy(Ydict0) + + if p is None: + p = unif(C.shape[0]) + else: + p = nx.to_numpy(p) + if q is None: + q = unif(Cdict.shape[-1]) + else: + q = nx.to_numpy(q) + + T = p[:, None] * q[None, :] + D = len(Cdict) + d = Y.shape[-1] + w = unif(D) # Initialize with uniform weights + ns = C.shape[-1] + nt = Cdict.shape[-1] + + # modeling (C,Y) + Cembedded = np.sum(w[:, None, None] * Cdict, axis=0) + Yembedded = np.sum(w[:, None, None] * Ydict, axis=0) + + # constants depending on q + const_q = q[:, None] * q[None, :] + diag_q = np.diag(q) + # Trackers for BCD convergence + convergence_criterion = np.inf + current_loss = 10**15 + outer_count = 0 + Ys_constM = (Y**2).dot(np.ones((d, nt))) # constant in computing euclidean pairwise feature matrix + + while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer): + previous_loss = current_loss + + # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w + Yt_varM = (np.ones((ns, d))).dot((Yembedded**2).T) + M = Ys_constM + Yt_varM - 2 * Y.dot(Yembedded.T) # euclidean distance matrix between features + T, log = fused_gromov_wasserstein(M, C, Cembedded, p, q, loss_fun='square_loss', alpha=alpha, armijo=False, G0=T, log=True) + current_loss = log['fgw_dist'] + if reg != 0: + current_loss -= reg * np.sum(w**2) + + # 2. Solve linear unmixing problem over w with a fixed transport plan T + w, Cembedded, Yembedded, current_loss = _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, + T, p, q, const_q, diag_q, current_loss, alpha, reg, + tol=tol_inner, max_iter=max_iter_inner, **kwargs) + if previous_loss != 0: + convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + else: + convergence_criterion = abs(previous_loss - current_loss) / 10**(-12) + outer_count += 1 + + return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(Yembedded), nx.from_numpy(T), nx.from_numpy(current_loss) + + +def _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, T, p, q, const_q, diag_q, starting_loss, alpha, reg, tol=10**(-6), max_iter=200, **kwargs): + r""" + Returns for a fixed admissible transport plan, + the optimal linear unmixing :math:`\mathbf{w}` minimizing the Fused Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` and :math:`(\sum_d w_d \mathbf{C_{dict}[d]},\sum_d w_d*\mathbf{Y_{dict}[d]}, \mathbf{q})` + + .. math:: + \min_{\mathbf{w}} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\+ (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d w_d \mathbf{Y_{dict}[d]_j} \|_2^2 T_{ij}- reg \| \mathbf{w} \|_2^2 + + Such that : + + - :math:`\mathbf{w}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w} \geq \mathbf{0}_D` + + Where : + + - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns. + - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. + - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d. + - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C_s}` + - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. 
+ - :math:`\mathbf{T}` is the optimal transport plan conditioned by the previous state of :math:`\mathbf{w}` + - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein + - reg is the regularization coefficient. + + The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38], algorithm 7. + + Parameters + ---------- + + C : array-like, shape (ns, ns) + Metric/Graph cost matrix. + Y : array-like, shape (ns, d) + Feature matrix. + Cdict : list of D array-like, shape (nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed (C,Y). + Each matrix in the dictionary must have the same size (nt,nt). + Ydict : list of D array-like, shape (nt,d) + Feature matrices composing the dictionary on which to embed (C,Y). + Each matrix in the dictionary must have the same size (nt,d). + Cembedded: array-like, shape (nt,nt) + Embedded structure of (C,Y) onto the dictionary + Yembedded: array-like, shape (nt,d) + Embedded features of (C,Y) onto the dictionary + w: array-like, shape (n_D,) + Linear unmixing of (C,Y) onto (Cdict,Ydict) + const_q: array-like, shape (nt,nt) + product matrix :math:`\mathbf{qq}^\top` where :math:`\mathbf{q}` is the target space distribution. + diag_q: array-like, shape (nt,nt) + diagonal matrix with values of q on the diagonal. + T: array-like, shape (ns,nt) + fixed transport plan between (C,Y) and its model + p : array-like, shape (ns,) + Distribution in the source space (C,Y). + q : array-like, shape (nt,) + Distribution in the embedding space depicted by the dictionary. + alpha: float, + Trade-off parameter of Fused Gromov-Wasserstein. + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. + + Returns + ------- + w: ndarray (D,) + linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the span of :math:`(C_{dict},Y_{dict})` given OT corresponding to previous unmixing. 
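As a toy illustration (the numbers are arbitrary) of why the iterates remain admissible: the update :math:`\mathbf{w} \leftarrow \mathbf{w} + \gamma (\mathbf{x} - \mathbf{w})` used below keeps :math:`\mathbf{w}` on the simplex whenever :math:`\mathbf{w}` and :math:`\mathbf{x}` are on it and :math:`\gamma \in [0, 1]`.

>>> import numpy as np
>>> w, x, gamma = np.array([0.5, 0.5]), np.array([1.0, 0.0]), 0.25
>>> w_new = w + gamma * (x - w)  # -> array([0.625, 0.375]), nonnegative and summing to 1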
+ """ + convergence_criterion = np.inf + current_loss = starting_loss + count = 0 + const_TCT = np.transpose(C.dot(T)).dot(T) + ones_ns_d = np.ones(Y.shape) + + while (convergence_criterion > tol) and (count < max_iter): + previous_loss = current_loss + + # 1) Compute gradient at current point w + # structure + grad_w = alpha * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2)) + # feature + grad_w += (1 - alpha) * np.sum(Ydict * (diag_q.dot(Yembedded)[None, :, :] - T.T.dot(Y)[None, :, :]), axis=(1, 2)) + grad_w -= reg * w + grad_w *= 2 + + # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w + min_ = np.min(grad_w) + x = (grad_w == min_).astype(np.float64) + x /= np.sum(x) + + # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c + gamma, a, b, Cembedded_diff, Yembedded_diff = _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg) + + # 4) Updates: w <-- (1-gamma)*w + gamma*x + w += gamma * (x - w) + Cembedded += gamma * Cembedded_diff + Yembedded += gamma * Yembedded_diff + current_loss += a * (gamma**2) + b * gamma + + if previous_loss != 0: + convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + else: + convergence_criterion = abs(previous_loss - current_loss) / 10**(-12) + count += 1 + + return w, Cembedded, Yembedded, current_loss + + +def _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg, **kwargs): + r""" + Compute optimal steps for the line search problem of Fused Gromov-Wasserstein linear unmixing + .. math:: + \min_{\gamma \in [0,1]} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\ + (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d z_d(\gamma) \mathbf{Y_{dict}[d]_j} \|_2^2 - reg\| \mathbf{z}(\gamma) \|_2^2 + + + Such that : + + - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}` + + Parameters + ---------- + + w : array-like, shape (D,) + Unmixing. + grad_w : array-like, shape (D, D) + Gradient of the reconstruction loss with respect to w. + x: array-like, shape (D,) + Conditional gradient direction. + Y: arrat-like, shape (ns,d) + Feature matrix of the input space + Cdict : list of D array-like, shape (nt, nt) + Metric/Graph cost matrices composing the dictionary on which to embed (C,Y). + Each matrix in the dictionary must have the same size (nt,nt). + Ydict : list of D array-like, shape (nt, d) + Feature matrices composing the dictionary on which to embed (C,Y). + Each matrix in the dictionary must have the same size (nt,d). + Cembedded: array-like, shape (nt, nt) + Embedded structure of (C,Y) onto the dictionary + Yembedded: array-like, shape (nt, d) + Embedded features of (C,Y) onto the dictionary + T: array-like, shape (ns, nt) + Fixed transport plan between (C,Y) and its current model. + const_q: array-like, shape (nt,nt) + product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations. + const_TCT: array-like, shape (nt, nt) + :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations. + ones_ns_d: array-like, shape (ns, d) + :math:`\mathbf{1}_{ ns \times d}`. Used to avoid redundant computations. + alpha: float, + Trade-off parameter of Fused Gromov-Wasserstein. 
+ reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. + + Returns + ------- + gamma: float + Optimal value for the line-search step + a: float + Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss + b: float + Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss + Cembedded_diff: numpy array, shape (nt, nt) + Difference between structure matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`. + Yembedded_diff: numpy array, shape (nt, nt) + Difference between feature matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`. + """ + # polynomial coefficients from quadratic objective (with respect to w) on structures + Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0) + Cembedded_diff = Cembedded_x - Cembedded + trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q) + trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q) + # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss + a_gw = trace_diffx - trace_diffw + b_gw = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT)) + + # polynomial coefficient from quadratic objective (with respect to w) on features + Yembedded_x = np.sum(x[:, None, None] * Ydict, axis=0) + Yembedded_diff = Yembedded_x - Yembedded + # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss + a_w = np.sum(ones_ns_d.dot((Yembedded_diff**2).T) * T) + b_w = 2 * np.sum(T * (ones_ns_d.dot((Yembedded * Yembedded_diff).T) - Y.dot(Yembedded_diff.T))) + + a = alpha * a_gw + (1 - alpha) * a_w + b = alpha * b_gw + (1 - alpha) * b_w + if reg != 0: + a -= reg * np.sum((x - w)**2) + b -= 2 * reg * np.sum(w * (x - w)) + if a > 0: + gamma = min(1, max(0, -b / (2 * a))) + elif a + b < 0: + gamma = 1 + else: + gamma = 0 + + return gamma, a, b, Cembedded_diff, Yembedded_diff diff --git a/test/test_gromov.py b/test/test_gromov.py index 4b995d5..329f99c 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -3,6 +3,7 @@ # Author: Erwan Vautier # Nicolas Courty # Titouan Vayer +# Cédric Vincent-Cuaz # # License: MIT License @@ -26,6 +27,7 @@ def test_gromov(nx): p = ot.unif(n_samples) q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] C1 = ot.dist(xs, xs) C2 = ot.dist(xt, xt) @@ -37,9 +39,10 @@ def test_gromov(nx): C2b = nx.from_numpy(C2) pb = nx.from_numpy(p) qb = nx.from_numpy(q) + G0b = nx.from_numpy(G0) - G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True) - Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)) + G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, verbose=True) + Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) @@ -56,9 +59,9 @@ def test_gromov(nx): gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True) gwb = nx.to_numpy(gwb) - gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False) + gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', G0=G0, log=False) gw_valb = nx.to_numpy( - ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) + ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) ) G = log['T'] @@ -91,6 +94,7 @@ def 
test_gromov_dtype_device(nx): p = ot.unif(n_samples) q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] C1 = ot.dist(xs, xs) C2 = ot.dist(xt, xt) @@ -105,9 +109,10 @@ def test_gromov_dtype_device(nx): C2b = nx.from_numpy(C2, type_as=tp) pb = nx.from_numpy(p, type_as=tp) qb = nx.from_numpy(q, type_as=tp) + G0b = nx.from_numpy(G0, type_as=tp) - Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) + Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, gw_valb) @@ -123,6 +128,7 @@ def test_gromov_device_tf(): xt = xs[::-1].copy() p = ot.unif(n_samples) q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] C1 = ot.dist(xs, xs) C2 = ot.dist(xt, xt) C1 /= C1.max() @@ -134,8 +140,9 @@ def test_gromov_device_tf(): C2b = nx.from_numpy(C2) pb = nx.from_numpy(p) qb = nx.from_numpy(q) - Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) + G0b = nx.from_numpy(G0) + Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, gw_valb) @@ -145,6 +152,7 @@ def test_gromov_device_tf(): C2b = nx.from_numpy(C2) pb = nx.from_numpy(p) qb = nx.from_numpy(q) + G0b = nx.from_numpy(G0b) Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) nx.assert_same_dtype_device(C1b, Gb) @@ -554,6 +562,7 @@ def test_fgw(nx): p = ot.unif(n_samples) q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] C1 = ot.dist(xs, xs) C2 = ot.dist(xt, xt) @@ -569,9 +578,10 @@ def test_fgw(nx): C2b = nx.from_numpy(C2) pb = nx.from_numpy(p) qb = nx.from_numpy(q) + G0b = nx.from_numpy(G0) - G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) - Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True) + G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, G0=G0b, log=True) Gb = nx.to_numpy(Gb) # check constraints @@ -586,8 +596,8 @@ def test_fgw(nx): np.testing.assert_allclose( Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov - fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) - fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True) + fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', G0=None, alpha=0.5, log=True) + fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', G0=G0b, alpha=0.5, log=True) fgwb = nx.to_numpy(fgwb) G = log['T'] @@ -698,3 +708,523 @@ def test_fgw_barycenter(nx): Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) + + +def test_gromov_wasserstein_linear_unmixing(nx): + n = 10 + + X1, y1 = 
ot.datasets.make_data_classif('3gauss', n, random_state=42) + X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42) + + C1 = ot.dist(X1) + C2 = ot.dist(X2) + Cdict = np.stack([C1, C2]) + p = ot.unif(n) + + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + Cdictb = nx.from_numpy(Cdict) + pb = nx.from_numpy(p) + tol = 10**(-5) + # Tests without regularization + reg = 0. + unmixing1, C1_emb, OT, reconstruction1 = ot.gromov.gromov_wasserstein_linear_unmixing( + C1, Cdict, reg=reg, p=p, q=p, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing1b, C1b_emb, OTb, reconstruction1b = ot.gromov.gromov_wasserstein_linear_unmixing( + C1b, Cdictb, reg=reg, p=None, q=None, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2, C2_emb, OT, reconstruction2 = ot.gromov.gromov_wasserstein_linear_unmixing( + C2, Cdict, reg=reg, p=None, q=None, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2b, C2b_emb, OTb, reconstruction2b = ot.gromov.gromov_wasserstein_linear_unmixing( + C2b, Cdictb, reg=reg, p=pb, q=pb, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06) + np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06) + np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) + np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(C1b_emb.shape, (n, n)) + np.testing.assert_allclose(C2b_emb.shape, (n, n)) + + # Tests with regularization + + reg = 0.001 + unmixing1, C1_emb, OT, reconstruction1 = ot.gromov.gromov_wasserstein_linear_unmixing( + C1, Cdict, reg=reg, p=p, q=p, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing1b, C1b_emb, OTb, reconstruction1b = ot.gromov.gromov_wasserstein_linear_unmixing( + C1b, Cdictb, reg=reg, p=None, q=None, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2, C2_emb, OT, reconstruction2 = ot.gromov.gromov_wasserstein_linear_unmixing( + C2, Cdict, reg=reg, p=None, q=None, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2b, C2b_emb, OTb, reconstruction2b = ot.gromov.gromov_wasserstein_linear_unmixing( + C2b, Cdictb, reg=reg, p=pb, q=pb, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06) + np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06) + np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) + np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(C1b_emb.shape, (n, n)) + np.testing.assert_allclose(C2b_emb.shape, (n, n)) + + +def test_gromov_wasserstein_dictionary_learning(nx): + + # create dataset composed from 2 structures which are repeated 5 times + 
shape = 10 + n_samples = 2 + n_atoms = 2 + projection = 'nonnegative_symmetric' + X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42) + X2, y2 = ot.datasets.make_data_classif('3gauss2', shape, random_state=42) + C1 = ot.dist(X1) + C2 = ot.dist(X2) + Cs = [C1.copy() for _ in range(n_samples // 2)] + [C2.copy() for _ in range(n_samples // 2)] + ps = [ot.unif(shape) for _ in range(n_samples)] + q = ot.unif(shape) + + # Provide initialization for the graph dictionary of shape (n_atoms, shape, shape) + # following the same procedure than implemented in gromov_wasserstein_dictionary_learning. + dataset_means = [C.mean() for C in Cs] + np.random.seed(0) + Cdict_init = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(n_atoms, shape, shape)) + if projection == 'nonnegative_symmetric': + Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1))) + Cdict_init[Cdict_init < 0.] = 0. + Csb = [nx.from_numpy(C) for C in Cs] + psb = [nx.from_numpy(p) for p in ps] + qb = nx.from_numpy(q) + Cdict_initb = nx.from_numpy(Cdict_init) + + # Test: compare reconstruction error using initial dictionary and dictionary learned using this initialization + # > Compute initial reconstruction of samples on this random dictionary without backend + use_adam_optimizer = True + verbose = False + tol = 10**(-5) + epochs = 1 + + initial_total_reconstruction = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Cs[i], Cdict_init, p=ps[i], q=q, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + initial_total_reconstruction += reconstruction + + # > Learn the dictionary using this init + Cdict, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, + epochs=epochs, batch_size=2 * n_samples, learning_rate=1., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary without backend + total_reconstruction = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Cs[i], Cdict, p=None, q=None, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction += reconstruction + + np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction) + + # Test: Perform same experiments after going through backend + + Cdictb, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Csb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, + epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # Compute reconstruction of samples on learned dictionary + total_reconstruction_b = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Csb[i], Cdictb, p=psb[i], q=qb, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b += reconstruction + + np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction) + np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) + np.testing.assert_allclose(total_reconstruction_b, 
total_reconstruction, atol=1e-05) + np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03) + + # Test: Perform same comparison without providing the initial dictionary being an optional input + # knowing than the initialization scheme is the same than implemented to set the benchmarked initialization. + np.random.seed(0) + Cdict_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Cs, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, + epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_bis = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Cs[i], Cdict_bis, p=ps[i], q=q, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_bis += reconstruction + + np.testing.assert_allclose(total_reconstruction_bis, total_reconstruction, atol=1e-05) + + # Test: Same after going through backend + np.random.seed(0) + Cdictb_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=None, + epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b_bis = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Csb[i], Cdictb_bis, p=None, q=None, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b_bis += reconstruction + + np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05) + np.testing.assert_allclose(Cdict_bis, nx.to_numpy(Cdictb_bis), atol=1e-03) + + # Test: Perform same comparison without providing the initial dictionary being an optional input + # and testing other optimization settings untested until now. + # We pass previously estimated dictionaries to speed up the process. 
+ use_adam_optimizer = False + verbose = True + use_log = True + + np.random.seed(0) + Cdict_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, + epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_bis2 = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Cs[i], Cdict_bis2, p=ps[i], q=q, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_bis2 += reconstruction + + np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction) + + # Test: Same after going through backend + np.random.seed(0) + Cdictb_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=Cdictb, + epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b_bis2 = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Csb[i], Cdictb_bis2, p=psb[i], q=qb, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b_bis2 += reconstruction + + np.testing.assert_allclose(total_reconstruction_b_bis2, total_reconstruction_bis2, atol=1e-05) + + +def test_fused_gromov_wasserstein_linear_unmixing(nx): + + n = 10 + X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42) + X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42) + F, y = ot.datasets.make_data_classif('3gauss', n, random_state=42) + + C1 = ot.dist(X1) + C2 = ot.dist(X2) + Cdict = np.stack([C1, C2]) + Ydict = np.stack([F, F]) + p = ot.unif(n) + + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + Fb = nx.from_numpy(F) + Cdictb = nx.from_numpy(Cdict) + Ydictb = nx.from_numpy(Ydict) + pb = nx.from_numpy(p) + # Tests without regularization + reg = 0. 
+ + unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03) + np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03) + np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03) + np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03) + np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) + np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(C1b_emb.shape, (n, n)) + np.testing.assert_allclose(C2b_emb.shape, (n, n)) + + # Tests with regularization + reg = 0.001 + + unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03) + np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03) + np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03) + np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03) + np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) + np.testing.assert_allclose(reconstruction2, 
reconstruction2b, atol=1e-06) + np.testing.assert_allclose(C1b_emb.shape, (n, n)) + np.testing.assert_allclose(C2b_emb.shape, (n, n)) + + +def test_fused_gromov_wasserstein_dictionary_learning(nx): + + # create dataset composed from 2 structures which are repeated 5 times + shape = 10 + n_samples = 2 + n_atoms = 2 + projection = 'nonnegative_symmetric' + X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42) + X2, y2 = ot.datasets.make_data_classif('3gauss2', shape, random_state=42) + F, y = ot.datasets.make_data_classif('3gauss', shape, random_state=42) + + C1 = ot.dist(X1) + C2 = ot.dist(X2) + Cs = [C1.copy() for _ in range(n_samples // 2)] + [C2.copy() for _ in range(n_samples // 2)] + Ys = [F.copy() for _ in range(n_samples)] + ps = [ot.unif(shape) for _ in range(n_samples)] + q = ot.unif(shape) + + # Provide initialization for the graph dictionary of shape (n_atoms, shape, shape) + # following the same procedure than implemented in gromov_wasserstein_dictionary_learning. + dataset_structure_means = [C.mean() for C in Cs] + np.random.seed(0) + Cdict_init = np.random.normal(loc=np.mean(dataset_structure_means), scale=np.std(dataset_structure_means), size=(n_atoms, shape, shape)) + if projection == 'nonnegative_symmetric': + Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1))) + Cdict_init[Cdict_init < 0.] = 0. + dataset_feature_means = np.stack([Y.mean(axis=0) for Y in Ys]) + Ydict_init = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(n_atoms, shape, 2)) + + Csb = [nx.from_numpy(C) for C in Cs] + Ysb = [nx.from_numpy(Y) for Y in Ys] + psb = [nx.from_numpy(p) for p in ps] + qb = nx.from_numpy(q) + Cdict_initb = nx.from_numpy(Cdict_init) + Ydict_initb = nx.from_numpy(Ydict_init) + + # Test: Compute initial reconstruction of samples on this random dictionary + alpha = 0.5 + use_adam_optimizer = True + verbose = False + tol = 1e-05 + epochs = 1 + + initial_total_reconstruction = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Cs[i], Ys[i], Cdict_init, Ydict_init, p=ps[i], q=q, + alpha=alpha, reg=0., tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + initial_total_reconstruction += reconstruction + + # > Learn a dictionary using this given initialization and check that the reconstruction loss + # on the learned dictionary is lower than the one using its initialization. 
+ Cdict, Ydict, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, Ydict_init=Ydict_init, + epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Cs[i], Ys[i], Cdict, Ydict, p=None, q=None, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction += reconstruction + # Compare both + np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction) + + # Test: Perform same experiments after going through backend + + Cdictb, Ydictb, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, Ydict_init=Ydict_initb, + epochs=epochs, batch_size=2 * n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Csb[i], Ysb[i], Cdictb, Ydictb, p=psb[i], q=qb, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b += reconstruction + + np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction) + np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) + np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03) + np.testing.assert_allclose(Ydict, nx.to_numpy(Ydictb), atol=1e-03) + + # Test: Perform similar experiment without providing the initial dictionary being an optional input + np.random.seed(0) + Cdict_bis, Ydict_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Cs, Ys, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None, + epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_bis = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Cs[i], Ys[i], Cdict_bis, Ydict_bis, p=ps[i], q=q, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_bis += reconstruction + + np.testing.assert_allclose(total_reconstruction_bis, total_reconstruction, atol=1e-05) + + # > Same after going through backend + np.random.seed(0) + Cdictb_bis, Ydictb_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None, + epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., + tol_outer=tol, 
tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b_bis = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Csb[i], Ysb[i], Cdictb_bis, Ydictb_bis, p=psb[i], q=qb, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b_bis += reconstruction + np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05) + + # Test: without using adam optimizer, with log and verbose set to True + use_adam_optimizer = False + verbose = True + use_log = True + + # > Experiment providing previously estimated dictionary to speed up the test compared to providing initial random init. + np.random.seed(0) + Cdict_bis2, Ydict_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, Ydict_init=Ydict, + epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_bis2 = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Cs[i], Ys[i], Cdict_bis2, Ydict_bis2, p=ps[i], q=q, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_bis2 += reconstruction + + np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction) + + # > Same after going through backend + np.random.seed(0) + Cdictb_bis2, Ydictb_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdictb, Ydict_init=Ydictb, + epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b_bis2 = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Csb[i], Ysb[i], Cdictb_bis2, Ydictb_bis2, p=None, q=None, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b_bis2 += reconstruction + + # > Compare results with/without backend + np.testing.assert_allclose(total_reconstruction_bis2, total_reconstruction_b_bis2, atol=1e-05) -- cgit v1.2.3 From 82452e0f5f6dae05c7a1cc384e7a1fb62ae7e0d5 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Thu, 24 Mar 2022 14:13:25 +0100 Subject: [MRG] Add factored coupling (#358) * add gfactored ot * pep8 and add doc * add exmaple for factotred OT * final number of PR * correct test on backends * remove useless loss * better tests --- README.md | 4 +- RELEASES.md | 1 + docs/source/all.rst | 1 + examples/others/plot_factored_coupling.py | 86 ++++++++++++++++++ ot/__init__.py | 5 ++ ot/factored.py | 145 ++++++++++++++++++++++++++++++ ot/plot.py | 7 +- test/test_factored.py | 56 ++++++++++++ 8 files changed, 303 
insertions(+), 2 deletions(-) create mode 100644 examples/others/plot_factored_coupling.py create mode 100644 ot/factored.py create mode 100644 test/test_factored.py (limited to 'ot/__init__.py') diff --git a/README.md b/README.md index c6bfd9c..ec5d221 100644 --- a/README.md +++ b/README.md @@ -305,4 +305,6 @@ Conference on Machine Learning, PMLR 119:4692-4701, 2020 [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021. -[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405. \ No newline at end of file +[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405. + +[40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR. \ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index 86b401a..c2bd0d1 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,6 +5,7 @@ #### New features +- Implementation of factored OT with emd and sinkhorn (PR #358). - A brand new logo for POT (PR #357) - Better list of related examples in quick start guide with `minigallery` (PR #334). - Add optional log-domain Sinkhorn implementation in WDA to support smaller values diff --git a/docs/source/all.rst b/docs/source/all.rst index 76d2ff5..3f7d029 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -29,6 +29,7 @@ API and modules partial sliced weak + factored .. 
autosummary::
  :toctree: ../modules/generated/

diff --git a/examples/others/plot_factored_coupling.py b/examples/others/plot_factored_coupling.py
new file mode 100644
index 0000000..b5b1c9f
--- /dev/null
+++ b/examples/others/plot_factored_coupling.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+"""
+==========================================
+Optimal transport with factored couplings
+==========================================
+
+Illustration of the factored coupling OT between 2D empirical distributions
+
+"""
+
+# Author: Remi Flamary
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+
+# %%
+# Generate data and plot it
+# -------------------------
+
+# parameters and data generation
+
+np.random.seed(42)
+
+n = 100  # nb samples
+
+xs = np.random.rand(n, 2) - .5
+
+xs = xs + np.sign(xs)
+
+xt = np.random.rand(n, 2) - .5
+
+a, b = ot.unif(n), ot.unif(n)  # uniform distribution on samples
+
+#%% plot samples
+
+pl.figure(1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+
+# %%
+# Compute factored OT and exact OT solutions
+# ------------------------------------------
+
+#%% EMD
+M = ot.dist(xs, xt)
+G0 = ot.emd(a, b, M)
+
+#%% factored OT
+
+Ga, Gb, xb = ot.factored_optimal_transport(xs, xt, a, b, r=4)
+
+
+# %%
+# Plot factored OT and exact OT solutions
+# ---------------------------------------
+
+pl.figure(2, (14, 4))
+
+pl.subplot(1, 3, 1)
+ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.2, .2, .2], alpha=0.1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.title('Exact OT with samples')
+
+pl.subplot(1, 3, 2)
+ot.plot.plot2D_samples_mat(xs, xb, Ga, c=[.6, .6, .9], alpha=0.5)
+ot.plot.plot2D_samples_mat(xb, xt, Gb, c=[.9, .6, .6], alpha=0.5)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.plot(xb[:, 0], xb[:, 1], 'og', label='Template samples')
+pl.title('Factored OT with template samples')
+
+pl.subplot(1, 3, 3)
+ot.plot.plot2D_samples_mat(xs, xt, Ga.dot(Gb), c=[.2, .2, .2], alpha=0.1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.title('Factored OT low rank OT plan')
diff --git a/ot/__init__.py b/ot/__init__.py
index bda7a35..c5e1967 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -33,6 +33,7 @@ from . import partial
 from . import backend
 from . import regpath
 from . import weak
+from . import factored

 # OT functions
 from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
@@ -44,6 +45,9 @@ from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance
 from .gromov import (gromov_wasserstein, gromov_wasserstein2,
                      gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
 from .weak import weak_optimal_transport
+from .factored import factored_optimal_transport
+
+
 # utils functions
 from .utils import dist, unif, tic, toc, toq
@@ -57,4 +61,5 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
            'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
            'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters',
            'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
            'max_sliced_wasserstein_distance', 'weak_optimal_transport',
+           'factored_optimal_transport',
            'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath']
diff --git a/ot/factored.py b/ot/factored.py
new file mode 100644
index 0000000..abc2445
--- /dev/null
+++ b/ot/factored.py
@@ -0,0 +1,145 @@
+"""
+Factored OT solvers (low rank, cost or OT plan)
+"""
+
+# Author: Remi Flamary
+#
+# License: MIT License
+
+from .backend import get_backend
+from .utils import dist
+from .lp import emd
+from .bregman import sinkhorn
+
+__all__ = ['factored_optimal_transport']
+
+
+def factored_optimal_transport(Xa, Xb, a=None, b=None, reg=0.0, r=100, X0=None, stopThr=1e-7, numItermax=100, verbose=False, log=False, **kwargs):
+    r"""Solves the factored OT problem and returns the OT plans and the intermediate distribution
+
+    This function solves the following OT problem [40]_
+
+    .. math::
+        \mathop{\arg \min}_\mu \quad  W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b)
+
+    where:
+
+    - :math:`\mu_a` and :math:`\mu_b` are empirical distributions.
+    - :math:`\mu` is an empirical distribution with r samples
+
+    and returns the two OT plans, between :math:`\mu_a` and :math:`\mu` and
+    between :math:`\mu` and :math:`\mu_b`.
+
+    .. note:: This function is backend-compatible and will work on arrays
+        from all compatible backends. But the algorithm uses the C++ CPU backend
+        which can lead to copy overhead on GPU arrays.
+
+    The solution is computed by alternating between updating the two OT plans
+    and the support of the intermediate distribution, following the approach
+    proposed in :ref:`[40] <references-factored>`.
+
+    Parameters
+    ----------
+    Xa : (ns,d) array-like, float
+        Source samples
+    Xb : (nt,d) array-like, float
+        Target samples
+    a : (ns,) array-like, float
+        Source histogram (uniform weights if None)
+    b : (nt,) array-like, float
+        Target histogram (uniform weights if None)
+    reg : float, optional
+        Entropic regularization parameter (exact OT with EMD if 0, default)
+    r : int, optional
+        Number of samples in the intermediate distribution
+    X0 : (r,d) array-like, optional
+        Initial support of the intermediate distribution (random if None)
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on the relative variation (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    Ga: array-like, shape (ns, r)
+        Optimal transportation matrix between source and the intermediate
+        distribution
+    Gb: array-like, shape (r, nt)
+        Optimal transportation matrix between the intermediate and target
+        distribution
+    X: array-like, shape (r, d)
+        Support of the intermediate distribution
+    log: dict, optional
+        If input log is true, a dictionary containing the cost and dual
+        variables and exit status
+
+
+    .. _references-factored:
+    References
+    ----------
+    .. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger,
+        G., & Weed, J. (2019, April). Statistical optimal transport via factored
+        couplings. In The 22nd International Conference on Artificial
+        Intelligence and Statistics (pp. 2454-2465). PMLR.
+ + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General + regularized OT + """ + + nx = get_backend(Xa, Xb) + + n_a = Xa.shape[0] + n_b = Xb.shape[0] + d = Xa.shape[1] + + if a is None: + a = nx.ones((n_a), type_as=Xa) / n_a + if b is None: + b = nx.ones((n_b), type_as=Xb) / n_b + + if X0 is None: + X = nx.randn(r, d, type_as=Xa) + else: + X = X0 + + w = nx.ones(r, type_as=Xa) / r + + def solve_ot(X1, X2, w1, w2): + M = dist(X1, X2) + if reg > 0: + G, log = sinkhorn(w1, w2, M, reg, log=True, **kwargs) + log['cost'] = nx.sum(G * M) + return G, log + else: + return emd(w1, w2, M, log=True, **kwargs) + + norm_delta = [] + + # solve the barycenter + for i in range(numItermax): + + old_X = X + + # solve OT with template + Ga, loga = solve_ot(Xa, X, a, w) + Gb, logb = solve_ot(X, Xb, w, b) + + X = 0.5 * (nx.dot(Ga.T, Xa) + nx.dot(Gb, Xb)) * r + + delta = nx.norm(X - old_X) + if delta < stopThr: + break + if log: + norm_delta.append(delta) + + if log: + log_dic = {'delta_iter': norm_delta, + 'ua': loga['u'], + 'va': loga['v'], + 'ub': logb['u'], + 'vb': logb['v'], + 'costa': loga['cost'], + 'costb': logb['cost'], + } + return Ga, Gb, X, log_dic + + return Ga, Gb, X diff --git a/ot/plot.py b/ot/plot.py index 2208c90..8ade2eb 100644 --- a/ot/plot.py +++ b/ot/plot.py @@ -85,8 +85,13 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): if ('color' not in kwargs) and ('c' not in kwargs): kwargs['color'] = 'k' mx = G.max() + if 'alpha' in kwargs: + scale = kwargs['alpha'] + del kwargs['alpha'] + else: + scale = 1 for i in range(xs.shape[0]): for j in range(xt.shape[0]): if G[i, j] / mx > thr: pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], - alpha=G[i, j] / mx, **kwargs) + alpha=G[i, j] / mx * scale, **kwargs) diff --git a/test/test_factored.py b/test/test_factored.py new file mode 100644 index 0000000..fd2fd01 --- /dev/null +++ b/test/test_factored.py @@ -0,0 +1,56 @@ +"""Tests for main module ot.weak """ + +# Author: Remi Flamary +# +# License: MIT License + +import ot +import numpy as np + + +def test_factored_ot(): + # test weak ot solver and identity stationary point + n = 50 + rng = np.random.RandomState(0) + + xs = rng.randn(n, 2) + xt = rng.randn(n, 2) + u = ot.utils.unif(n) + + Ga, Gb, X, log = ot.factored_optimal_transport(xs, xt, u, u, r=10, log=True) + + # check constraints + np.testing.assert_allclose(u, Ga.sum(1)) + np.testing.assert_allclose(u, Gb.sum(0)) + + Ga, Gb, X, log = ot.factored_optimal_transport(xs, xt, u, u, reg=1, r=10, log=True) + + # check constraints + np.testing.assert_allclose(u, Ga.sum(1)) + np.testing.assert_allclose(u, Gb.sum(0)) + + +def test_factored_ot_backends(nx): + # test weak ot solver for different backends + n = 50 + rng = np.random.RandomState(0) + + xs = rng.randn(n, 2) + xt = rng.randn(n, 2) + u = ot.utils.unif(n) + + xs2 = nx.from_numpy(xs) + xt2 = nx.from_numpy(xt) + u2 = nx.from_numpy(u) + + Ga2, Gb2, X2 = ot.factored_optimal_transport(xs2, xt2, u2, u2, r=10) + + # check constraints + np.testing.assert_allclose(u, nx.to_numpy(Ga2).sum(1)) + np.testing.assert_allclose(u, nx.to_numpy(Gb2).sum(0)) + + Ga2, Gb2, X2 = ot.factored_optimal_transport(xs2, xt2, reg=1, r=10, X0=X2) + + # check constraints + np.testing.assert_allclose(u, nx.to_numpy(Ga2).sum(1)) + np.testing.assert_allclose(u, nx.to_numpy(Gb2).sum(0)) -- cgit v1.2.3 From eccb1386eea52b94b82456d126bd20cbe3198e05 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Thu, 21 Apr 2022 16:34:01 +0200 Subject: [MRG] Release 8.2 (#365) * release text and number 
* add examples in release fil build wheels * switch gallery to release * add much needed contributors file * debug circleci * une line of logos * working logo * back to stable sphinx galery --- .circleci/config.yml | 5 ++- CONTRIBUTORS.md | 52 ++++++++++++++++++++++++++++ README.md | 35 ++++--------------- RELEASES.md | 57 ++++++++++++++++++++++++++----- docs/source/_static/images/logo_3ia.jpg | Bin 0 -> 25029 bytes docs/source/_static/images/logo_anr.jpg | Bin 0 -> 23493 bytes docs/source/_static/images/logo_cnrs.jpg | Bin 0 -> 6918 bytes docs/source/contributors.rst | 6 ++++ docs/source/index.rst | 1 + docs/source/releases.rst | 2 +- examples/others/plot_logo.py | 10 +++--- ot/__init__.py | 2 +- 12 files changed, 122 insertions(+), 48 deletions(-) create mode 100644 CONTRIBUTORS.md create mode 100644 docs/source/_static/images/logo_3ia.jpg create mode 100644 docs/source/_static/images/logo_anr.jpg create mode 100644 docs/source/_static/images/logo_cnrs.jpg create mode 100644 docs/source/contributors.rst (limited to 'ot/__init__.py') diff --git a/.circleci/config.yml b/.circleci/config.yml index 77ab45c..7e15a65 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -48,9 +48,8 @@ jobs: python -m pip install --user -e . python -m pip install --user --upgrade --no-cache-dir --progress-bar off -r requirements.txt python -m pip install --user --upgrade --progress-bar off -r docs/requirements.txt - python -m pip install --user --upgrade --progress-bar off ipython "https://api.github.com/repos/sphinx-gallery/sphinx-gallery/zipball/master" memory_profiler - - + python -m pip install --user --upgrade --progress-bar off ipython sphinx-gallery memory_profiler + # python -m pip install --user --upgrade --progress-bar off ipython "https://api.github.com/repos/sphinx-gallery/sphinx-gallery/zipball/master" memory_profiler - save_cache: key: pip-cache paths: diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md new file mode 100644 index 0000000..ab64fba --- /dev/null +++ b/CONTRIBUTORS.md @@ -0,0 +1,52 @@ + + +## Creators and Maintainers + +This toolbox has been created and is maintained by: + +* [Rémi Flamary](http://remi.flamary.com/) +* [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/) + +## Contributors + +The contributors to this library are: + +* [Rémi Flamary](http://remi.flamary.com/) (EMD wrapper, Pytorch backend, DA + classes, conditional gradients, WDA, weak OT, linear OT mapping, documentation) +* [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/) (Original sinkhorn, + Wasserstein barycenters and convolutional barycenters, 1D wasserstein) +* [Alexandre Gramfort](http://alexandre.gramfort.net/) (CI, documentation) +* [Laetitia Chapel](http://people.irisa.fr/Laetitia.Chapel/) (Partial OT, + Unbalanced OT non-regularized) +* [Michael Perrot](http://perso.univ-st-etienne.fr/pem82055/) (Mapping estimation) +* [Léo Gautheron](https://github.com/aje) (Initial GPU implementation) +* [Nathalie Gayraud](https://www.linkedin.com/in/nathalie-t-h-gayraud/?ppe=1) (DA classes) +* [Stanislas Chambon](https://slasnista.github.io/) (DA classes) +* [Antoine Rolet](https://arolet.github.io/) (EMD solver debug) +* Erwan Vautier (Gromov-Wasserstein) +* [Kilian Fatras](https://kilianfatras.github.io/) (Stochastic solvers, + empirical sinkhorn) +* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home) (Greenkhorn) +* [Vayer Titouan](https://tvayer.github.io/) (Gromov-Wasserstein, Fused-Gromov-Wasserstein) +* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT, 
Debiased barycenters)
+* [Romain Tavenard](https://rtavenar.github.io/) (1D Wasserstein)
+* [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
+* [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT)
+* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance)
+* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein)
+* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
+* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
+* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning)
+
+## Acknowledgments
+
+This toolbox benefits a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
+
+* [Gabriel Peyré](http://gpeyre.github.io/) (Wasserstein Barycenters in Matlab)
+* [Mathieu Blondel](https://mblondel.org/) (original implementation smooth OT)
+* [Nicolas Bonneel](http://liris.cnrs.fr/~nbonneel/) (C++ code for EMD)
+* [Marco Cuturi](http://marcocuturi.net/) (Sinkhorn Knopp in Matlab/Cuda)
+
+POT has benefited from the financing or manpower from the following partners:
+
+ANRCNRS3IA
\ No newline at end of file
diff --git a/README.md b/README.md
index 1b50aeb..e2b33d9 100644
--- a/README.md
+++ b/README.md
@@ -180,35 +180,12 @@ This toolbox has been created and is maintained by
 * [Rémi Flamary](http://remi.flamary.com/)
 * [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/)

-The contributors to this library are
-
-* [Alexandre Gramfort](http://alexandre.gramfort.net/) (CI, documentation)
-* [Laetitia Chapel](http://people.irisa.fr/Laetitia.Chapel/) (Partial OT)
-* [Michael Perrot](http://perso.univ-st-etienne.fr/pem82055/) (Mapping estimation)
-* [Léo Gautheron](https://github.com/aje) (Initial GPU implementation)
-* [Nathalie Gayraud](https://www.linkedin.com/in/nathalie-t-h-gayraud/?ppe=1) (DA classes)
-* [Stanislas Chambon](https://slasnista.github.io/) (DA classes)
-* [Antoine Rolet](https://arolet.github.io/) (EMD solver debug)
-* Erwan Vautier (Gromov-Wasserstein)
-* [Kilian Fatras](https://kilianfatras.github.io/) (Stochastic solvers)
-* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home)
-* [Vayer Titouan](https://tvayer.github.io/) (Gromov-Wasserstein
-, Fused-Gromov-Wasserstein)
-* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT, Debiased barycenters)
-* [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein)
-* [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
-* [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT)
-* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance)
-* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein)
-* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
-* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
-* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning)
-
-This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
-
-* [Gabriel Peyré](http://gpeyre.github.io/) (Wasserstein Barycenters in Matlab)
-* [Mathieu Blondel](https://mblondel.org/) (original implementation smooth OT)
-* [Nicolas Bonneel](http://liris.cnrs.fr/~nbonneel/) (C++ code for EMD)
-* [Marco Cuturi](http://marcocuturi.net/) (Sinkhorn Knopp in Matlab/Cuda)
+The numerous contributors to this library are listed [here](CONTRIBUTORS.md).
+
+POT has benefited from the financing or manpower from the following partners:
+
+ANRCNRS3IA
+
 ## Contributions and code of conduct
diff --git a/RELEASES.md b/RELEASES.md
index 33d1ab6..be2192e 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -1,22 +1,61 @@
 # Releases

-## 0.8.2dev Development
+## 0.8.2
+
+This release introduces several new notable features. The less important
+but most exciting one is that we now have a logo for the toolbox (color
+and dark background):
+
+![](https://pythonot.github.io/master/_images/logo.svg)![](https://pythonot.github.io/master/_static/logo_dark.svg)
+
+This logo is generated with matplotlib, using the solution of an OT
+problem provided by POT (with `ot.emd`). Generating the logo can be done with a
+simple Python script also provided in the [documentation gallery](https://pythonot.github.io/auto_examples/others/plot_logo.html#sphx-glr-auto-examples-others-plot-logo-py).
+
+New OT solvers include [Weak
+OT](https://pythonot.github.io/gen_modules/ot.weak.html#ot.weak.weak_optimal_transport)
+ and [OT with factored
+coupling](https://pythonot.github.io/gen_modules/ot.factored.html#ot.factored.factored_optimal_transport)
+that can be used on large datasets. The [Majorization Minimization](https://pythonot.github.io/gen_modules/ot.unbalanced.html?highlight=mm_#ot.unbalanced.mm_unbalanced) solvers for
+non-regularized Unbalanced OT are now also available. We also now provide an
+implementation of [GW and FGW unmixing](https://pythonot.github.io/gen_modules/ot.gromov.html#ot.gromov.gromov_wasserstein_linear_unmixing) and [dictionary learning](https://pythonot.github.io/gen_modules/ot.gromov.html#ot.gromov.gromov_wasserstein_dictionary_learning). It is now
+possible to use autodiff to solve entropic and quadratic regularized OT in the
+dual for full or stochastic optimization thanks to the new functions to compute
+the dual loss for [entropic](https://pythonot.github.io/gen_modules/ot.stochastic.html#ot.stochastic.loss_dual_entropic) and [quadratic](https://pythonot.github.io/gen_modules/ot.stochastic.html#ot.stochastic.loss_dual_quadratic) regularized OT and reconstruct the [OT
+plan](https://pythonot.github.io/gen_modules/ot.stochastic.html#ot.stochastic.plan_dual_entropic) on part or all of the data. They can be used for instance to solve OT
+problems with stochastic gradient or for estimating the [dual potentials as
+neural networks](https://pythonot.github.io/auto_examples/backends/plot_stoch_continuous_ot_pytorch.html#sphx-glr-auto-examples-backends-plot-stoch-continuous-ot-pytorch-py).
+
+On the backend front, we now have backend compatible functions and classes in
+the domain adaptation [`ot.da`](https://pythonot.github.io/gen_modules/ot.da.html#module-ot.da) and unbalanced OT [`ot.unbalanced`](https://pythonot.github.io/gen_modules/ot.unbalanced.html) modules. This
+means that the DA classes can be used on tensors from all compatible backends.
+The [free support Wasserstein barycenter](https://pythonot.github.io/gen_modules/ot.lp.html?highlight=free%20support#ot.lp.free_support_barycenter) solver is now also backend compatible.
+
+Finally, we have worked on the documentation to provide an update of existing
+examples in the gallery and several new examples including [GW dictionary
+learning](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html#sphx-glr-auto-examples-gromov-plot-gromov-wasserstein-dictionary-learning-py),
+[weak Optimal
+Transport](https://pythonot.github.io/auto_examples/others/plot_WeakOT_VS_OT.html#sphx-glr-auto-examples-others-plot-weakot-vs-ot-py),
+[NN based dual potentials
+estimation](https://pythonot.github.io/auto_examples/backends/plot_stoch_continuous_ot_pytorch.html#sphx-glr-auto-examples-backends-plot-stoch-continuous-ot-pytorch-py)
+and [Factored coupling OT](https://pythonot.github.io/auto_examples/others/plot_factored_coupling.html#sphx-glr-auto-examples-others-plot-factored-coupling-py).
+
 #### New features

 - Remove deprecated `ot.gpu` submodule (PR #361)
-- Update examples in the gallery (PR #359).
+- Update examples in the gallery (PR #359)
 - Add stochastic loss and OT plan computation for regularized OT and
-  backend examples(PR #360).
+  backend examples (PR #360)
-- Implementation of factored OT with emd and sinkhorn (PR #358).
+- Implementation of factored OT with emd and sinkhorn (PR #358)
 - A brand new logo for POT (PR #357)
-- Better list of related examples in quick start guide with `minigallery` (PR #334).
+- Better list of related examples in quick start guide with `minigallery` (PR #334)
 - Add optional log-domain Sinkhorn implementation in WDA to support smaller values
-  of the regularization parameter (PR #336).
-- Backend implementation for `ot.lp.free_support_barycenter` (PR #340).
-- Add weak OT solver + example (PR #341).
-- Add backend support for Domain Adaptation and Unbalanced solvers (PR #343).
+ of the regularization parameter (PR #336) +- Backend implementation for `ot.lp.free_support_barycenter` (PR #340) +- Add weak OT solver + example (PR #341) +- Add backend support for Domain Adaptation and Unbalanced solvers (PR #343) - Add (F)GW linear dictionary learning solvers + example (PR #319) - Add links to related PR and Issues in the doc release page (PR #350) - Add new minimization-maximization algorithms for solving exact Unbalanced OT + example (PR #362) diff --git a/docs/source/_static/images/logo_3ia.jpg b/docs/source/_static/images/logo_3ia.jpg new file mode 100644 index 0000000..ecc56b2 Binary files /dev/null and b/docs/source/_static/images/logo_3ia.jpg differ diff --git a/docs/source/_static/images/logo_anr.jpg b/docs/source/_static/images/logo_anr.jpg new file mode 100644 index 0000000..dcef212 Binary files /dev/null and b/docs/source/_static/images/logo_anr.jpg differ diff --git a/docs/source/_static/images/logo_cnrs.jpg b/docs/source/_static/images/logo_cnrs.jpg new file mode 100644 index 0000000..902cf6f Binary files /dev/null and b/docs/source/_static/images/logo_cnrs.jpg differ diff --git a/docs/source/contributors.rst b/docs/source/contributors.rst new file mode 100644 index 0000000..f0acea6 --- /dev/null +++ b/docs/source/contributors.rst @@ -0,0 +1,6 @@ +Contributors +============ + +.. include:: ../../CONTRIBUTORS.md + :parser: myst_parser.sphinx_ + :start-line: 2 diff --git a/docs/source/index.rst b/docs/source/index.rst index 7ff7d22..3d53ef4 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -22,6 +22,7 @@ Contents auto_examples/index releases .github/CONTRIBUTING + contributors .github/CODE_OF_CONDUCT diff --git a/docs/source/releases.rst b/docs/source/releases.rst index 8250a4d..b2c7a44 100644 --- a/docs/source/releases.rst +++ b/docs/source/releases.rst @@ -3,4 +3,4 @@ Releases .. include:: ../../RELEASES.md :parser: myst_parser.sphinx_ - :start-line: 3 + :start-line: 2 diff --git a/examples/others/plot_logo.py b/examples/others/plot_logo.py index 9414371..bb4f640 100644 --- a/examples/others/plot_logo.py +++ b/examples/others/plot_logo.py @@ -18,7 +18,7 @@ matplotlib and ploting teh solution of the EMD solver from POT. # sphinx_gallery_thumbnail_number = 1 -# %% +# %% Load modules import numpy as np import matplotlib.pyplot as pl import ot @@ -36,21 +36,21 @@ p2 = np.array([[1.5, 6], [2, 4], [2, 5], [1.5, 3], [0.5, 2], [.5, 1], ]) o1 = np.array([[0, 6.], [-1, 5], [-1.5, 4], [-1.5, 3], [-1, 2], [0, 1], ]) o2 = np.array([[1, 6.], [2, 5], [2.5, 4], [2.5, 3], [2, 2], [1, 1], ]) -# scaling and translation for letter O +# Scaling and translation for letter O o1[:, 0] += 6.4 o2[:, 0] += 6.4 o1[:, 0] *= 0.6 o2[:, 0] *= 0.6 -# letter T +# Letter T t1 = np.array([[-1, 6.], [-1, 5], [0, 4], [0, 3], [0, 2], [0, 1], ]) t2 = np.array([[1.5, 6.], [1.5, 5], [0.5, 4], [0.5, 3], [0.5, 2], [0.5, 1], ]) -# translatin the T +# Translating the T t1[:, 0] += 7.1 t2[:, 0] += 7.1 -# Cocatenate all letters +# Concatenate all letters x1 = np.concatenate((p1, o1, t1), axis=0) x2 = np.concatenate((p2, o2, t2), axis=0) diff --git a/ot/__init__.py b/ot/__init__.py index c5e1967..86ed94e 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -51,7 +51,7 @@ from .factored import factored_optimal_transport # utils functions from .utils import dist, unif, tic, toc, toq -__version__ = "0.8.2dev" +__version__ = "0.8.2" __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', -- cgit v1.2.3
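
For reference, here is a minimal usage sketch of the `ot.factored_optimal_transport` solver introduced in the patches above. The data generation and the parameter values (`n`, `r`, `reg`) are illustrative assumptions, not taken from the patch series; it assumes POT 0.8.2 with the new `ot.factored` module installed.

# Minimal usage sketch (not part of the patches above): factored OT between two
# small 2D point clouds, following the signature of ot.factored_optimal_transport
# added in ot/factored.py.
import numpy as np
import ot

rng = np.random.RandomState(0)
n = 50
xs = rng.randn(n, 2)            # source samples
xt = rng.randn(n, 2) + 2        # target samples (shifted)
a, b = ot.unif(n), ot.unif(n)   # uniform sample weights

# Exact factored OT (reg=0 solves the two inner problems with emd);
# r is the number of support points of the intermediate distribution.
Ga, Gb, X, log = ot.factored_optimal_transport(xs, xt, a, b, r=10, log=True)

print(Ga.shape, Gb.shape, X.shape)      # (50, 10) (10, 50) (10, 2)
# Each factor is a valid coupling with the intermediate distribution,
# as checked in test/test_factored.py:
print(np.allclose(a, Ga.sum(axis=1)))   # source marginal
print(np.allclose(b, Gb.sum(axis=0)))   # target marginal

# Entropic variant: reg > 0 uses Sinkhorn for the two inner OT problems.
Ga_e, Gb_e, X_e = ot.factored_optimal_transport(xs, xt, a, b, reg=0.1, r=10)

The two factors `Ga` and `Gb` couple the source and target samples with an intermediate distribution supported on `r` template points `X`, so the overall plan is factored through a low-rank coupling; this is what the `plot_factored_coupling.py` example added above visualizes.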