diff options
author | Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com> | 2023-06-12 12:01:48 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-12 12:01:48 +0200 |
commit | 9076f02903ba2fb9ea9fe704764a755cad8dcd63 (patch) | |
tree | b7fda84880c5dabd1c441a1655741493e0683342 /ot/gromov/_gw.py | |
parent | f0dab2f684f4fc768fd50e0b70918e075dcdd0f3 (diff) |
[FEAT] Entropic gw/fgw/srgw/srfgw solvers (#455)upstream/latest
* add entropic fgw + fgw bary + srgw + srfgw with tests
* add exemples for entropic srgw - srfgw solvers
* add PPA solvers for GW/FGW + complete previous commits
* update readme
* add tests
* add examples + tests + warning in entropic solvers + releases
* reduce testing runtimes for test_gromov
* fix conflicts
* optional marginals
* improve coverage
* gromov doc harmonization
* fix pep8
* complete optional marginal for entropic srfgw
---------
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/gromov/_gw.py')
-rw-r--r-- | ot/gromov/_gw.py | 318 |
1 files changed, 163 insertions, 155 deletions
diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index cdfa9a3..adf6b82 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -16,14 +16,14 @@ import numpy as np from ..utils import dist, UndefinedParameter, list_to_array from ..optim import cg, line_search_armijo, solve_1d_linesearch_quad -from ..utils import check_random_state +from ..utils import check_random_state, unif from ..backend import get_backend, NumpyBackend from ._utils import init_matrix, gwloss, gwggrad -from ._utils import update_square_loss, update_kl_loss +from ._utils import update_square_loss, update_kl_loss, update_feature_matrix -def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None, +def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): r""" Returns the Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -31,7 +31,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log The function solves the following optimization problem: .. math:: - \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} + \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} @@ -60,11 +60,13 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log Metric cost matrix in the source space C2 : array-like, shape (nt, nt) Metric cost matrix in the target space - p : array-like, shape (ns,) - Distribution in the source space - q : array-like, shape (nt,) - Distribution in the target space - loss_fun : str + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. + q : array-like, shape (nt,), optional + Distribution in the target space. + If let to its default value None, uniform distribution is taken. + loss_fun : str, optional loss function used for the solver either 'square_loss' or 'kl_loss' symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. @@ -112,15 +114,24 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log distance between networks and stable network invariants. Information and Inference: A Journal of the IMA, 8(4), 757-787. """ - p, q = list_to_array(p, q) - p0, q0, C10, C20 = p, q, C1, C2 - if G0 is None: - nx = get_backend(p0, q0, C10, C20) + arr = [C1, C2] + if p is not None: + arr.append(list_to_array(p)) + else: + p = unif(C1.shape[0], type_as=C1) + if q is not None: + arr.append(list_to_array(q)) else: + q = unif(C2.shape[0], type_as=C2) + if G0 is not None: G0_ = G0 - nx = get_backend(p0, q0, C10, C20, G0_) - p = nx.to_numpy(p) - q = nx.to_numpy(q) + arr.append(G0) + + nx = get_backend(*arr) + p0, q0, C10, C20 = p, q, C1, C2 + + p = nx.to_numpy(p0) + q = nx.to_numpy(q0) C1 = nx.to_numpy(C10) C2 = nx.to_numpy(C20) if symmetric is None: @@ -168,7 +179,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log return nx.from_numpy(cg(p, q, 0., 1., f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10) -def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None, +def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): r""" Returns the Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -176,7 +187,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, lo The function solves the following optimization problem: .. math:: - GW = \min_\mathbf{T} \quad \sum_{i,j,k,l} + \mathbf{GW} = \min_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} @@ -209,10 +220,12 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, lo Metric cost matrix in the source space C2 : array-like, shape (nt, nt) Metric cost matrix in the target space - p : array-like, shape (ns,) + p : array-like, shape (ns,), optional Distribution in the source space. - q : array-like, shape (nt,) + If let to its default value None, uniform distribution is taken. + q : array-like, shape (nt,), optional Distribution in the target space. + If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss' symmetric : bool, optional @@ -266,6 +279,12 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, lo # simple get_backend as the full one will be handled in gromov_wasserstein nx = get_backend(C1, C2) + # init marginals if set as None + if p is None: + p = unif(C1.shape[0], type_as=C1) + if q is None: + q = unif(C2.shape[0], type_as=C2) + T, log_gw = gromov_wasserstein( C1, C2, p, q, loss_fun, symmetric, log=True, armijo=armijo, G0=G0, max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs) @@ -286,20 +305,20 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, lo return gw -def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=None, alpha=0.5, +def fused_gromov_wasserstein(M, C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, alpha=0.5, armijo=False, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): r""" Computes the FGW transport between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein>`) .. math:: - \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + + \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} - \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} + \mathbf{T}^T \mathbf{1} &= \mathbf{q} - \mathbf{\gamma} &\geq 0 + \mathbf{T} &\geq 0 where : @@ -323,10 +342,12 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric= Metric cost matrix representative of the structure in the source space C2 : array-like, shape (nt, nt) Metric cost matrix representative of the structure in the target space - p : array-like, shape (ns,) - Distribution in the source space - q : array-like, shape (nt,) - Distribution in the target space + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. + q : array-like, shape (nt,), optional + Distribution in the target space. + If let to its default value None, uniform distribution is taken. loss_fun : str, optional Loss function used for the solver symmetric : bool, optional @@ -354,7 +375,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric= Returns ------- - gamma : array-like, shape (`ns`, `nt`) + T : array-like, shape (`ns`, `nt`) Optimal transportation matrix for the given parameters. log : dict Log dictionary return only if log==True in parameters. @@ -372,16 +393,24 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric= distance between networks and stable network invariants. Information and Inference: A Journal of the IMA, 8(4), 757-787. """ - p, q = list_to_array(p, q) - p0, q0, C10, C20, M0, alpha0 = p, q, C1, C2, M, alpha - if G0 is None: - nx = get_backend(p0, q0, C10, C20, M0) + arr = [C1, C2, M] + if p is not None: + arr.append(list_to_array(p)) + else: + p = unif(C1.shape[0], type_as=C1) + if q is not None: + arr.append(list_to_array(q)) else: + q = unif(C2.shape[0], type_as=C2) + if G0 is not None: G0_ = G0 - nx = get_backend(p0, q0, C10, C20, M0, G0_) + arr.append(G0) - p = nx.to_numpy(p) - q = nx.to_numpy(q) + nx = get_backend(*arr) + p0, q0, C10, C20, M0, alpha0 = p, q, C1, C2, M, alpha + + p = nx.to_numpy(p0) + q = nx.to_numpy(q0) C1 = nx.to_numpy(C10) C2 = nx.to_numpy(C20) M = nx.to_numpy(M0) @@ -433,20 +462,20 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric= return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10) -def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric=None, alpha=0.5, +def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, alpha=0.5, armijo=False, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): r""" Computes the FGW distance between two graphs see (see :ref:`[24] <references-fused-gromov-wasserstein2>`) .. math:: - \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l} + \mathbf{GW} = \min_\mathbf{T} \quad (1 - \alpha) \langle \mathbf(T), \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + s.t. \ \mathbf(T)\mathbf{1} &= \mathbf{p} - \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} + \mathbf(T)^T \mathbf{1} &= \mathbf{q} - \mathbf{\gamma} &\geq 0 + \mathbf(T) &\geq 0 where : @@ -474,10 +503,12 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric Metric cost matrix representative of the structure in the source space. C2 : array-like, shape (nt, nt) Metric cost matrix representative of the structure in the target space. - p : array-like, shape (ns,) + p : array-like, shape (ns,), optional Distribution in the source space. - q : array-like, shape (nt,) + If let to its default value None, uniform distribution is taken. + q : array-like, shape (nt,), optional Distribution in the target space. + If let to its default value None, uniform distribution is taken. loss_fun : str, optional Loss function used for the solver. symmetric : bool, optional @@ -529,6 +560,12 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric """ nx = get_backend(C1, C2, M) + # init marginals if set as None + if p is None: + p = unif(C1.shape[0], type_as=C1) + if q is None: + q = unif(C2.shape[0], type_as=C2) + T, log_fgw = fused_gromov_wasserstein( M, C1, C2, p, q, loss_fun, symmetric, alpha, armijo, G0, log=True, max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs) @@ -626,9 +663,10 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, return alpha, 1, cost_G -def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=False, - max_iter=1000, tol=1e-9, verbose=False, log=False, - init_C=None, random_state=None, **kwargs): +def gromov_barycenters( + N, Cs, ps=None, p=None, lambdas=None, loss_fun='square_loss', symmetric=True, armijo=False, + max_iter=1000, tol=1e-9, warmstartT=False, verbose=False, log=False, + init_C=None, random_state=None, **kwargs): r""" Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` @@ -649,13 +687,16 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=F Size of the targeted barycenter Cs : list of S array-like of shape (ns, ns) Metric cost matrices - ps : list of S array-like of shape (ns,) - Sample weights in the `S` spaces - p : array-like, shape (N,) - Weights in the targeted barycenter - lambdas : list of float - List of the `S` spaces' weights - loss_fun : callable + ps : list of S array-like of shape (ns,), optional + Sample weights in the `S` spaces. + If let to its default value None, uniform distributions are taken. + p : array-like, shape (N,), optional + Weights in the targeted barycenter. + If let to its default value None, uniform distribution is taken. + lambdas : list of float, optional + List of the `S` spaces' weights. + If let to its default value None, uniform weights are taken. + loss_fun : callable, optional tensor-matrix multiplication function based on specific loss function symmetric : bool, optional. Either structures are to be assumed symmetric or not. Default value is True. @@ -668,6 +709,9 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=F Max number of iterations tol : float, optional Stop threshold on relative error (>0) + warmstartT: bool, optional + Either to perform warmstart of transport plans in the successive + fused gromov-wasserstein transport problems.s verbose : bool, optional Print information along iterations. log : bool, optional @@ -692,11 +736,21 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=F """ Cs = list_to_array(*Cs) - ps = list_to_array(*ps) - p = list_to_array(p) - nx = get_backend(*Cs, *ps, p) + arr = [*Cs] + if ps is not None: + arr += list_to_array(*ps) + else: + ps = [unif(C.shape[0], type_as=C) for C in Cs] + if p is not None: + arr.append(list_to_array(p)) + else: + p = unif(N, type_as=Cs[0]) + + nx = get_backend(*arr) S = len(Cs) + if lambdas is None: + lambdas = [1. / S] * S # Initialization of C : random SPD matrix (if not provided by user) if init_C is None: @@ -714,13 +768,19 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=F cpt = 0 err = 1 + if warmstartT: + T = [None] * S error = [] while (err > tol and cpt < max_iter): Cprev = C - T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric, armijo=armijo, - max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)] + if warmstartT: + T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric, armijo=armijo, G0=T[s], + max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)] + else: + T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric, armijo=armijo, G0=None, + max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)] if loss_fun == 'square_loss': C = update_square_loss(p, lambdas, T, Cs) @@ -747,9 +807,11 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=F return C -def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, - p=None, loss_fun='square_loss', armijo=False, symmetric=True, max_iter=100, tol=1e-9, - verbose=False, log=False, init_C=None, init_X=None, random_state=None, **kwargs): +def fgw_barycenters( + N, Ys, Cs, ps=None, lambdas=None, alpha=0.5, fixed_structure=False, + fixed_features=False, p=None, loss_fun='square_loss', armijo=False, + symmetric=True, max_iter=100, tol=1e-9, warmstartT=False, verbose=False, + log=False, init_C=None, init_X=None, random_state=None, **kwargs): r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] <references-fgw-barycenters>` Parameters @@ -760,16 +822,21 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ Features of all samples Cs : list of array-like, each element has shape (ns,ns) Structure matrices of all samples - ps : list of array-like, each element has shape (ns,) + ps : list of array-like, each element has shape (ns,), optional Masses of all samples. - lambdas : list of float - List of the `S` spaces' weights - alpha : float - Alpha parameter for the fgw distance + If let to its default value None, uniform distributions are taken. + lambdas : list of float, optional + List of the `S` spaces' weights. + If let to its default value None, uniform weights are taken. + alpha : float, optional + Alpha parameter for the fgw distance. fixed_structure : bool Whether to fix the structure of the barycenter during the updates fixed_features : bool Whether to fix the feature of the barycenter during the updates + p : array-like, shape (N,), optional + Weights in the targeted barycenter. + If let to its default value None, uniform distribution is taken. loss_fun : str Loss function used for the solver either 'square_loss' or 'kl_loss' symmetric : bool, optional @@ -779,6 +846,9 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ Max number of iterations tol : float, optional Stop threshold on relative error (>0) + warmstartT: bool, optional + Either to perform warmstart of transport plans in the successive + fused gromov-wasserstein transport problems. verbose : bool, optional Print information along iterations. log : bool, optional @@ -814,15 +884,24 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ International Conference on Machine Learning (ICML). 2019. """ Cs = list_to_array(*Cs) - ps = list_to_array(*ps) Ys = list_to_array(*Ys) - p = list_to_array(p) - nx = get_backend(*Cs, *Ys, *ps) + arr = [*Cs, *Ys] + if ps is not None: + arr += list_to_array(*ps) + else: + ps = [unif(C.shape[0], type_as=C) for C in Cs] + if p is not None: + arr.append(list_to_array(p)) + else: + p = unif(N, type_as=Cs[0]) + + nx = get_backend(*arr) S = len(Cs) + if lambdas is None: + lambdas = [1. / S] * S + d = Ys[0].shape[1] # dimension on the node features - if p is None: - p = nx.ones(N, type_as=Cs[0]) / N if fixed_structure: if init_C is None: @@ -877,13 +956,21 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ Ms = [dist(X, Ys[s]) for s in range(len(Ys))] if not fixed_structure: + T_temp = [t.T for t in T] if loss_fun == 'square_loss': - T_temp = [t.T for t in T] - C = update_structure_matrix(p, lambdas, T_temp, Cs) + C = update_square_loss(p, lambdas, T_temp, Cs) - T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric, - max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)] + elif loss_fun == 'kl_loss': + C = update_kl_loss(p, lambdas, T_temp, Cs) + if warmstartT: + T = [fused_gromov_wasserstein( + Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric, + G0=T[s], max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)] + else: + T = [fused_gromov_wasserstein( + Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric, + G0=None, max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)] # T is N,ns err_feature = nx.norm(X - nx.reshape(Xprev, (N, d))) err_structure = nx.norm(C - Cprev) @@ -910,82 +997,3 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ return X, C, log_ else: return X, C - - -def update_structure_matrix(p, lambdas, T, Cs): - r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings. - - It is calculated at each iteration - - Parameters - ---------- - p : array-like, shape (N,) - Masses in the targeted barycenter. - lambdas : list of float - List of the `S` spaces' weights. - T : list of S array-like of shape (ns, N) - The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. - Cs : list of S array-like, shape (ns, ns) - Metric cost matrices. - - Returns - ------- - C : array-like, shape (`nt`, `nt`) - Updated :math:`\mathbf{C}` matrix. - """ - p = list_to_array(p) - T = list_to_array(*T) - Cs = list_to_array(*Cs) - nx = get_backend(*Cs, *T, p) - - tmpsum = sum([ - lambdas[s] * nx.dot( - nx.dot(T[s].T, Cs[s]), - T[s] - ) for s in range(len(T)) - ]) - ppt = nx.outer(p, p) - return tmpsum / ppt - - -def update_feature_matrix(lambdas, Ys, Ts, p): - r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. - - - See "Solving the barycenter problem with Block Coordinate Descent (BCD)" - in :ref:`[24] <references-update-feature-matrix>` calculated at each iteration - - Parameters - ---------- - p : array-like, shape (N,) - masses in the targeted barycenter - lambdas : list of float - List of the `S` spaces' weights - Ts : list of S array-like, shape (ns,N) - The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration - Ys : list of S array-like, shape (d,ns) - The features. - - Returns - ------- - X : array-like, shape (`d`, `N`) - - - .. _references-update-feature-matrix: - References - ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas - "Optimal Transport for structured data with application on graphs" - International Conference on Machine Learning (ICML). 2019. - """ - p = list_to_array(p) - Ts = list_to_array(*Ts) - Ys = list_to_array(*Ys) - nx = get_backend(*Ys, *Ts, p) - - p = 1. / p - tmpsum = sum([ - lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :] - for s in range(len(Ts)) - ]) - return tmpsum |