summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoreloitanguy <69361683+eloitanguy@users.noreply.github.com>2022-05-06 08:43:21 +0200
committerGitHub <noreply@github.com>2022-05-06 08:43:21 +0200
commitccc076e0fc535b2c734214c0ac1936e9e2cbeb62 (patch)
treeb5a20af6fabcaefa0de4bc27afd9049bd15612f6
parenteccb1386eea52b94b82456d126bd20cbe3198e05 (diff)
[WIP] Generalized Wasserstein Barycenters (#372)
* GWB first solver version * tests + example for gwb (untested) + free_bar doc fix * improved doc, fixed minor bugs, better example visu * minor doc + visu fixes * plot GWB pep8 fix * fixed partial gromov test reproductibility * added an animation for the GWB visu * added PR num * minor doc fixes + better gwb logo
-rw-r--r--CONTRIBUTORS.md1
-rw-r--r--README.md6
-rw-r--r--RELEASES.md6
-rw-r--r--examples/backends/plot_sliced_wass_grad_flow_pytorch.py6
-rw-r--r--examples/barycenters/plot_generalized_free_support_barycenter.py152
-rw-r--r--ot/__init__.py2
-rw-r--r--ot/lp/__init__.py145
-rw-r--r--test/test_ot.py40
-rwxr-xr-xtest/test_partial.py1
9 files changed, 334 insertions, 25 deletions
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index ab64fba..0909b14 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -37,6 +37,7 @@ The contributors to this library are:
* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning)
+* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters)
## Acknowledgments
diff --git a/README.md b/README.md
index e2b33d9..12340d5 100644
--- a/README.md
+++ b/README.md
@@ -288,4 +288,8 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
[40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR.
-[41] Chapel*, L., Flamary*, R., Wu, H., Févotte, C., Gasso, G. (2021). [Unbalanced Optimal Transport through Non-negative Penalized Linear Regression](https://proceedings.neurips.cc/paper/2021/file/c3c617a9b80b3ae1ebd868b0017cc349-Paper.pdf) Advances in Neural Information Processing Systems (NeurIPS), 2020. (Two first co-authors) \ No newline at end of file
+[41] Chapel*, L., Flamary*, R., Wu, H., Févotte, C., Gasso, G. (2021). [Unbalanced Optimal Transport through Non-negative Penalized Linear Regression](https://proceedings.neurips.cc/paper/2021/file/c3c617a9b80b3ae1ebd868b0017cc349-Paper.pdf) Advances in Neural Information Processing Systems (NeurIPS), 2020. (Two first co-authors)
+
+[42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021.
+
+[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. \ No newline at end of file
diff --git a/RELEASES.md b/RELEASES.md
index be2192e..3461832 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -1,5 +1,11 @@
# Releases
+## 0.8.3dev
+
+#### New features
+
+- Added Generalized Wasserstein Barycenter solver + example (PR #372)
+
## 0.8.2
diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
index cf5d64d..59e0042 100644
--- a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
+++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
@@ -103,7 +103,7 @@ ax = pl.axis()
# %%
# Animate trajectories of the gradient flow along iteration
-# -------------------------------------------------------
+# ---------------------------------------------------------
pl.figure(3, (8, 4))
@@ -122,7 +122,7 @@ ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100,
# %%
# Compute the Sliced Wasserstein Barycenter
-#
+# -----------------------------------------
x1_torch = torch.tensor(x1).to(device=device)
x3_torch = torch.tensor(x3).to(device=device)
xbinit = np.random.randn(500, 2) * 10 + 16
@@ -169,7 +169,7 @@ ax = pl.axis()
# %%
# Animate trajectories of the barycenter along gradient descent
-# -------------------------------------------------------
+# -------------------------------------------------------------
pl.figure(5, (8, 4))
diff --git a/examples/barycenters/plot_generalized_free_support_barycenter.py b/examples/barycenters/plot_generalized_free_support_barycenter.py
new file mode 100644
index 0000000..9af1953
--- /dev/null
+++ b/examples/barycenters/plot_generalized_free_support_barycenter.py
@@ -0,0 +1,152 @@
+# -*- coding: utf-8 -*-
+"""
+=======================================
+Generalized Wasserstein Barycenter Demo
+=======================================
+
+This example illustrates the computation of Generalized Wasserstein Barycenter
+as proposed in [42].
+
+
+[42] Delon, J., Gozlan, N., and Saint-Dizier, A..
+Generalized Wasserstein barycenters between probability measures living on different subspaces.
+arXiv preprint arXiv:2105.09755, 2021.
+
+"""
+
+# Author: Eloi Tanguy <eloi.tanguy@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.pylab as pl
+import ot
+import matplotlib.animation as animation
+
+########################
+# Generate and plot data
+# ----------------------
+
+# Input measures
+sub_sample_factor = 8
+I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2]
+I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2]
+I3 = pl.imread('../../data/heart.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2]
+
+sz = I1.shape[0]
+UU, VV = np.meshgrid(np.arange(sz), np.arange(sz))
+
+# Input measure locations in their respective 2D spaces
+X_list = [np.stack((UU[im == 0], VV[im == 0]), 1) * 1.0 for im in [I1, I2, I3]]
+
+# Input measure weights
+a_list = [ot.unif(x.shape[0]) for x in X_list]
+
+# Projections 3D -> 2D
+P1 = np.array([[1, 0, 0], [0, 1, 0]])
+P2 = np.array([[0, 1, 0], [0, 0, 1]])
+P3 = np.array([[1, 0, 0], [0, 0, 1]])
+P_list = [P1, P2, P3]
+
+# Barycenter weights
+weights = np.array([1 / 3, 1 / 3, 1 / 3])
+
+# Number of barycenter points to compute
+n_samples_bary = 150
+
+# Send the input measures into 3D space for visualisation
+X_visu = [Xi @ Pi for (Xi, Pi) in zip(X_list, P_list)]
+
+# Plot the input data
+fig = plt.figure(figsize=(3, 3))
+axis = fig.add_subplot(1, 1, 1, projection="3d")
+for Xi in X_visu:
+ axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+axis.view_init(azim=45)
+axis.set_xticks([])
+axis.set_yticks([])
+axis.set_zticks([])
+plt.show()
+
+#################################
+# Barycenter computation and plot
+# -------------------------------
+
+Y = ot.lp.generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary)
+fig = plt.figure(figsize=(3, 3))
+
+axis = fig.add_subplot(1, 1, 1, projection="3d")
+for Xi in X_visu:
+ axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+axis.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
+axis.view_init(azim=45)
+axis.set_xticks([])
+axis.set_yticks([])
+axis.set_zticks([])
+plt.show()
+
+
+#############################
+# Plotting projection matches
+# ---------------------------
+
+fig = plt.figure(figsize=(9, 3))
+
+ax = fig.add_subplot(1, 3, 1, projection='3d')
+for Xi in X_visu:
+ ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
+ax.view_init(elev=0, azim=0)
+ax.set_xticks([])
+ax.set_yticks([])
+ax.set_zticks([])
+
+ax = fig.add_subplot(1, 3, 2, projection='3d')
+for Xi in X_visu:
+ ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
+ax.view_init(elev=0, azim=90)
+ax.set_xticks([])
+ax.set_yticks([])
+ax.set_zticks([])
+
+ax = fig.add_subplot(1, 3, 3, projection='3d')
+for Xi in X_visu:
+ ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
+ax.view_init(elev=90, azim=0)
+ax.set_xticks([])
+ax.set_yticks([])
+ax.set_zticks([])
+
+plt.tight_layout()
+plt.show()
+
+##############################################
+# Rotation animation
+# --------------------------------------------
+
+fig = plt.figure(figsize=(7, 7))
+ax = fig.add_subplot(1, 1, 1, projection="3d")
+
+
+def _init():
+ for Xi in X_visu:
+ ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+ ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
+ ax.view_init(elev=0, azim=0)
+ ax.set_xticks([])
+ ax.set_yticks([])
+ ax.set_zticks([])
+ return fig,
+
+
+def _update_plot(i):
+ ax.view_init(elev=i, azim=4 * i)
+ return fig,
+
+
+ani = animation.FuncAnimation(fig, _update_plot, init_func=_init, frames=90, interval=50, blit=True, repeat_delay=2000)
diff --git a/ot/__init__.py b/ot/__init__.py
index 86ed94e..15d8351 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -51,7 +51,7 @@ from .factored import factored_optimal_transport
# utils functions
from .utils import dist, unif, tic, toc, toq
-__version__ = "0.8.2"
+__version__ = "0.8.3dev"
__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 390c32d..572781d 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -26,10 +26,8 @@ from ..utils import dist, list_to_array
from ..utils import parmap
from ..backend import get_backend
-
-
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted',
- 'emd_1d', 'emd2_1d', 'wasserstein_1d']
+ 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter']
def check_number_threads(numThreads):
@@ -517,8 +515,8 @@ def emd2(a, b, M, processes=1,
log['warning'] = result_code_string
log['result_code'] = result_code
cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
- (a0, b0, M0), (log['u'] - nx.mean(log['u']),
- log['v'] - nx.mean(log['v']), G))
+ (a0, b0, M0), (log['u'] - nx.mean(log['u']),
+ log['v'] - nx.mean(log['v']), G))
return [cost, log]
else:
def f(b):
@@ -572,18 +570,18 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
where :
- :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one
- - the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the empirical measures weights and sum to one for each :math:`i`
- - the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}` are the empirical measures atoms locations
+ - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex)
+ - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations
- :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter
- This problem is considered in :ref:`[1] <references-free-support-barycenter>` (Algorithm 2).
+ This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
There are two differences with the following codes:
- we do not optimize over the weights
- we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in
- :ref:`[1] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
+ :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
implementation of the fixed-point algorithm of
- :ref:`[2] <references-free-support-barycenter>` proposed in the continuous setting.
+ :ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting.
Parameters
----------
@@ -623,13 +621,13 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
.. _references-free-support-barycenter:
References
----------
- .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+ .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
- .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+ .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
"""
- nx = get_backend(*measures_locations,*measures_weights,X_init)
+ nx = get_backend(*measures_locations, *measures_weights, X_init)
iter_count = 0
@@ -637,9 +635,9 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
k = X_init.shape[0]
d = X_init.shape[1]
if b is None:
- b = nx.ones((k,),type_as=X_init) / k
+ b = nx.ones((k,), type_as=X_init) / k
if weights is None:
- weights = nx.ones((N,),type_as=X_init) / N
+ weights = nx.ones((N,), type_as=X_init) / N
X = X_init
@@ -650,15 +648,14 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
while (displacement_square_norm > stopThr and iter_count < numItermax):
- T_sum = nx.zeros((k, d),type_as=X_init)
-
+ T_sum = nx.zeros((k, d), type_as=X_init)
- for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
+ for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
M_i = dist(X, measure_locations_i)
T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads)
- T_sum = T_sum + weight_i * 1. / b[:,None] * nx.dot(T_i, measure_locations_i)
+ T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i)
- displacement_square_norm = nx.sum((T_sum - X)**2)
+ displacement_square_norm = nx.sum((T_sum - X) ** 2)
if log:
displacement_square_norms.append(displacement_square_norm)
@@ -675,3 +672,111 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
else:
return X
+
+def generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary, Y_init=None, b=None, weights=None,
+ numItermax=100, stopThr=1e-7, verbose=False, log=None, numThreads=1, eps=0):
+ r"""
+ Solves the free support generalised Wasserstein barycenter problem: finding a barycenter (a discrete measure with
+ a fixed amount of points of uniform weights) whose respective projections fit the input measures.
+ More formally:
+
+ .. math::
+ \min_\gamma \quad \sum_{i=1}^p w_i W_2^2(\nu_i, \mathbf{P}_i\#\gamma)
+
+ where :
+
+ - :math:`\gamma = \sum_{l=1}^n b_l\delta_{y_l}` is the desired barycenter with each :math:`y_l \in \mathbb{R}^d`
+ - :math:`\mathbf{b} \in \mathbb{R}^{n}` is the desired weights vector of the barycenter
+ - The input measures are :math:`\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{x_{i,j}}`
+ - The :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the respective empirical measures weights (on the simplex)
+ - The :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d_i}` are the respective empirical measures atoms locations
+ - :math:`w = (w_1, \cdots w_p)` are the barycenter coefficients (on the simplex)
+ - Each :math:`\mathbf{P}_i \in \mathbb{R}^{d, d_i}`, and :math:`P_i\#\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{P_ix_{i,j}}`
+
+ As show by :ref:`[42] <references-generalized-free-support-barycenter>`,
+ this problem can be re-written as a Wasserstein Barycenter problem,
+ which we solve using the free support method :ref:`[20] <references-generalized-free-support-barycenter>`
+ (Algorithm 2).
+
+ Parameters
+ ----------
+ X_list : list of p (k_i,d_i) array-like
+ Discrete supports of the input measures: each consists of :math:`k_i` locations of a `d_i`-dimensional space
+ (:math:`k_i` can be different for each element of the list)
+ a_list : list of p (k_i,) array-like
+ Measure weights: each element is a vector (k_i) on the simplex
+ P_list : list of p (d_i,d) array-like
+ Each :math:`P_i` is a linear map :math:`\mathbb{R}^{d} \rightarrow \mathbb{R}^{d_i}`
+ n_samples_bary : int
+ Number of barycenter points
+ Y_init : (n_samples_bary,d) array-like
+ Initialization of the support locations (on `k` atoms) of the barycenter
+ b : (n_samples_bary,) array-like
+ Initialization of the weights of the barycenter measure (on the simplex)
+ weights : (p,) array-like
+ Initialization of the coefficients of the barycenter (on the simplex)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
+ If compiled with OpenMP, chooses the number of threads to parallelize.
+ "max" selects the highest number possible.
+ eps: Stability coefficient for the change of variable matrix inversion
+ If the :math:`\mathbf{P}_i^T` matrices don't span :math:`\mathbb{R}^d`, the problem is ill-defined and a matrix
+ inversion will fail. In this case one may set eps=1e-8 and get a solution anyway (which may make little sense)
+
+
+ Returns
+ -------
+ Y : (n_samples_bary,d) array-like
+ Support locations (on n_samples_bary atoms) of the barycenter
+
+
+ .. _references-generalized-free-support-barycenter:
+ References
+ ----------
+ .. [20] Cuturi, M. and Doucet, A.. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+
+ .. [42] Delon, J., Gozlan, N., and Saint-Dizier, A.. Generalized Wasserstein barycenters between probability measures living on different subspaces. arXiv preprint arXiv:2105.09755, 2021.
+
+ """
+ nx = get_backend(*X_list, *a_list, *P_list)
+ d = P_list[0].shape[1]
+ p = len(X_list)
+
+ if weights is None:
+ weights = nx.ones(p, type_as=X_list[0]) / p
+
+ # variable change matrix to reduce the problem to a Wasserstein Barycenter (WB)
+ A = eps * nx.eye(d, type_as=X_list[0]) # if eps nonzero: will force the invertibility of A
+ for (P_i, lambda_i) in zip(P_list, weights):
+ A = A + lambda_i * P_i.T @ P_i
+ B = nx.inv(nx.sqrtm(A))
+
+ Z_list = [x @ Pi @ B.T for (x, Pi) in zip(X_list, P_list)] # change of variables -> (WB) problem on Z
+
+ if Y_init is None:
+ Y_init = nx.randn(n_samples_bary, d, type_as=X_list[0])
+
+ if b is None:
+ b = nx.ones(n_samples_bary, type_as=X_list[0]) / n_samples_bary # not optimised
+
+ out = free_support_barycenter(Z_list, a_list, Y_init, b, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, numThreads=numThreads)
+
+ if log: # unpack
+ Y, log_dict = out
+ else:
+ Y = out
+ log_dict = None
+ Y = Y @ B.T # return to the Generalised WB formulation
+
+ if log:
+ return Y, log_dict
+ else:
+ return Y
diff --git a/test/test_ot.py b/test/test_ot.py
index bf832f6..ba3ef6a 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -320,6 +320,46 @@ def test_free_support_barycenter_backends(nx):
np.testing.assert_allclose(X, nx.to_numpy(X2))
+def test_generalised_free_support_barycenter():
+ np.random.seed(42) # random inits
+ X = [np.array([-1., -1.]).reshape((1, 2)), np.array([1., 1.]).reshape((1, 2))] # two 2D points bar is obviously 0
+ a = [np.array([1.]), np.array([1.])]
+
+ P = [np.eye(2), np.eye(2)]
+
+ Y_init = np.array([-12., 7.]).reshape((1, 2))
+
+ # obvious barycenter location between two 2D diracs
+ Y_true = np.array([0., .0]).reshape((1, 2))
+
+ # test without log and no init
+ Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1)
+ np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7)
+
+ # test with log and init
+ Y, _ = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init, b=np.array([1.]), log=True)
+ np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7)
+
+
+def test_generalised_free_support_barycenter_backends(nx):
+ np.random.seed(42)
+ X = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
+ a = [np.array([1.]), np.array([1.])]
+ P = [np.array([1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
+ Y_init = np.array([-12.]).reshape((1, 1))
+
+ Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init)
+
+ X2 = nx.from_numpy(*X)
+ a2 = nx.from_numpy(*a)
+ P2 = nx.from_numpy(*P)
+ Y_init2 = nx.from_numpy(Y_init)
+
+ Y2 = ot.lp.generalized_free_support_barycenter(X2, a2, P2, 1, Y_init=Y_init2)
+
+ np.testing.assert_allclose(Y, nx.to_numpy(Y2))
+
+
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
def test_lp_barycenter_cvxopt():
a1 = np.array([1.0, 0, 0])[:, None]
diff --git a/test/test_partial.py b/test/test_partial.py
index 97c611b..e07377b 100755
--- a/test/test_partial.py
+++ b/test/test_partial.py
@@ -137,6 +137,7 @@ def test_partial_wasserstein():
def test_partial_gromov_wasserstein():
+ np.random.seed(42)
n_samples = 20 # nb samples
n_noise = 10 # nb of samples (noise)