summaryrefslogtreecommitdiff
path: root/ot/gromov/_gw.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/gromov/_gw.py')
-rw-r--r--ot/gromov/_gw.py318
1 files changed, 163 insertions, 155 deletions
diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py
index cdfa9a3..adf6b82 100644
--- a/ot/gromov/_gw.py
+++ b/ot/gromov/_gw.py
@@ -16,14 +16,14 @@ import numpy as np
from ..utils import dist, UndefinedParameter, list_to_array
from ..optim import cg, line_search_armijo, solve_1d_linesearch_quad
-from ..utils import check_random_state
+from ..utils import check_random_state, unif
from ..backend import get_backend, NumpyBackend
from ._utils import init_matrix, gwloss, gwggrad
-from ._utils import update_square_loss, update_kl_loss
+from ._utils import update_square_loss, update_kl_loss, update_feature_matrix
-def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None,
+def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None,
max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
r"""
Returns the Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
@@ -31,7 +31,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log
The function solves the following optimization problem:
.. math::
- \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
@@ -60,11 +60,13 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log
Metric cost matrix in the source space
C2 : array-like, shape (nt, nt)
Metric cost matrix in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
- q : array-like, shape (nt,)
- Distribution in the target space
- loss_fun : str
+ p : array-like, shape (ns,), optional
+ Distribution in the source space.
+ If let to its default value None, uniform distribution is taken.
+ q : array-like, shape (nt,), optional
+ Distribution in the target space.
+ If let to its default value None, uniform distribution is taken.
+ loss_fun : str, optional
loss function used for the solver either 'square_loss' or 'kl_loss'
symmetric : bool, optional
Either C1 and C2 are to be assumed symmetric or not.
@@ -112,15 +114,24 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log
distance between networks and stable network invariants.
Information and Inference: A Journal of the IMA, 8(4), 757-787.
"""
- p, q = list_to_array(p, q)
- p0, q0, C10, C20 = p, q, C1, C2
- if G0 is None:
- nx = get_backend(p0, q0, C10, C20)
+ arr = [C1, C2]
+ if p is not None:
+ arr.append(list_to_array(p))
+ else:
+ p = unif(C1.shape[0], type_as=C1)
+ if q is not None:
+ arr.append(list_to_array(q))
else:
+ q = unif(C2.shape[0], type_as=C2)
+ if G0 is not None:
G0_ = G0
- nx = get_backend(p0, q0, C10, C20, G0_)
- p = nx.to_numpy(p)
- q = nx.to_numpy(q)
+ arr.append(G0)
+
+ nx = get_backend(*arr)
+ p0, q0, C10, C20 = p, q, C1, C2
+
+ p = nx.to_numpy(p0)
+ q = nx.to_numpy(q0)
C1 = nx.to_numpy(C10)
C2 = nx.to_numpy(C20)
if symmetric is None:
@@ -168,7 +179,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log
return nx.from_numpy(cg(p, q, 0., 1., f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10)
-def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None,
+def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None,
max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
r"""
Returns the Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
@@ -176,7 +187,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, lo
The function solves the following optimization problem:
.. math::
- GW = \min_\mathbf{T} \quad \sum_{i,j,k,l}
+ \mathbf{GW} = \min_\mathbf{T} \quad \sum_{i,j,k,l}
L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
@@ -209,10 +220,12 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, lo
Metric cost matrix in the source space
C2 : array-like, shape (nt, nt)
Metric cost matrix in the target space
- p : array-like, shape (ns,)
+ p : array-like, shape (ns,), optional
Distribution in the source space.
- q : array-like, shape (nt,)
+ If let to its default value None, uniform distribution is taken.
+ q : array-like, shape (nt,), optional
Distribution in the target space.
+ If let to its default value None, uniform distribution is taken.
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'
symmetric : bool, optional
@@ -266,6 +279,12 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, lo
# simple get_backend as the full one will be handled in gromov_wasserstein
nx = get_backend(C1, C2)
+ # init marginals if set as None
+ if p is None:
+ p = unif(C1.shape[0], type_as=C1)
+ if q is None:
+ q = unif(C2.shape[0], type_as=C2)
+
T, log_gw = gromov_wasserstein(
C1, C2, p, q, loss_fun, symmetric, log=True, armijo=armijo, G0=G0,
max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs)
@@ -286,20 +305,20 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, lo
return gw
-def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=None, alpha=0.5,
+def fused_gromov_wasserstein(M, C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, alpha=0.5,
armijo=False, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
r"""
Computes the FGW transport between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein>`)
.. math::
- \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F +
+ \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
\alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
- s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
- \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
- \mathbf{\gamma} &\geq 0
+ \mathbf{T} &\geq 0
where :
@@ -323,10 +342,12 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=
Metric cost matrix representative of the structure in the source space
C2 : array-like, shape (nt, nt)
Metric cost matrix representative of the structure in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
- q : array-like, shape (nt,)
- Distribution in the target space
+ p : array-like, shape (ns,), optional
+ Distribution in the source space.
+ If let to its default value None, uniform distribution is taken.
+ q : array-like, shape (nt,), optional
+ Distribution in the target space.
+ If let to its default value None, uniform distribution is taken.
loss_fun : str, optional
Loss function used for the solver
symmetric : bool, optional
@@ -354,7 +375,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=
Returns
-------
- gamma : array-like, shape (`ns`, `nt`)
+ T : array-like, shape (`ns`, `nt`)
Optimal transportation matrix for the given parameters.
log : dict
Log dictionary return only if log==True in parameters.
@@ -372,16 +393,24 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=
distance between networks and stable network invariants.
Information and Inference: A Journal of the IMA, 8(4), 757-787.
"""
- p, q = list_to_array(p, q)
- p0, q0, C10, C20, M0, alpha0 = p, q, C1, C2, M, alpha
- if G0 is None:
- nx = get_backend(p0, q0, C10, C20, M0)
+ arr = [C1, C2, M]
+ if p is not None:
+ arr.append(list_to_array(p))
+ else:
+ p = unif(C1.shape[0], type_as=C1)
+ if q is not None:
+ arr.append(list_to_array(q))
else:
+ q = unif(C2.shape[0], type_as=C2)
+ if G0 is not None:
G0_ = G0
- nx = get_backend(p0, q0, C10, C20, M0, G0_)
+ arr.append(G0)
- p = nx.to_numpy(p)
- q = nx.to_numpy(q)
+ nx = get_backend(*arr)
+ p0, q0, C10, C20, M0, alpha0 = p, q, C1, C2, M, alpha
+
+ p = nx.to_numpy(p0)
+ q = nx.to_numpy(q0)
C1 = nx.to_numpy(C10)
C2 = nx.to_numpy(C20)
M = nx.to_numpy(M0)
@@ -433,20 +462,20 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=
return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10)
-def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric=None, alpha=0.5,
+def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, alpha=0.5,
armijo=False, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
r"""
Computes the FGW distance between two graphs see (see :ref:`[24] <references-fused-gromov-wasserstein2>`)
.. math::
- \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l}
+ \mathbf{GW} = \min_\mathbf{T} \quad (1 - \alpha) \langle \mathbf(T), \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l}
L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
- s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+ s.t. \ \mathbf(T)\mathbf{1} &= \mathbf{p}
- \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
+ \mathbf(T)^T \mathbf{1} &= \mathbf{q}
- \mathbf{\gamma} &\geq 0
+ \mathbf(T) &\geq 0
where :
@@ -474,10 +503,12 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric
Metric cost matrix representative of the structure in the source space.
C2 : array-like, shape (nt, nt)
Metric cost matrix representative of the structure in the target space.
- p : array-like, shape (ns,)
+ p : array-like, shape (ns,), optional
Distribution in the source space.
- q : array-like, shape (nt,)
+ If let to its default value None, uniform distribution is taken.
+ q : array-like, shape (nt,), optional
Distribution in the target space.
+ If let to its default value None, uniform distribution is taken.
loss_fun : str, optional
Loss function used for the solver.
symmetric : bool, optional
@@ -529,6 +560,12 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric
"""
nx = get_backend(C1, C2, M)
+ # init marginals if set as None
+ if p is None:
+ p = unif(C1.shape[0], type_as=C1)
+ if q is None:
+ q = unif(C2.shape[0], type_as=C2)
+
T, log_fgw = fused_gromov_wasserstein(
M, C1, C2, p, q, loss_fun, symmetric, alpha, armijo, G0, log=True,
max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs)
@@ -626,9 +663,10 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
return alpha, 1, cost_G
-def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=False,
- max_iter=1000, tol=1e-9, verbose=False, log=False,
- init_C=None, random_state=None, **kwargs):
+def gromov_barycenters(
+ N, Cs, ps=None, p=None, lambdas=None, loss_fun='square_loss', symmetric=True, armijo=False,
+ max_iter=1000, tol=1e-9, warmstartT=False, verbose=False, log=False,
+ init_C=None, random_state=None, **kwargs):
r"""
Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`
@@ -649,13 +687,16 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=F
Size of the targeted barycenter
Cs : list of S array-like of shape (ns, ns)
Metric cost matrices
- ps : list of S array-like of shape (ns,)
- Sample weights in the `S` spaces
- p : array-like, shape (N,)
- Weights in the targeted barycenter
- lambdas : list of float
- List of the `S` spaces' weights
- loss_fun : callable
+ ps : list of S array-like of shape (ns,), optional
+ Sample weights in the `S` spaces.
+ If let to its default value None, uniform distributions are taken.
+ p : array-like, shape (N,), optional
+ Weights in the targeted barycenter.
+ If let to its default value None, uniform distribution is taken.
+ lambdas : list of float, optional
+ List of the `S` spaces' weights.
+ If let to its default value None, uniform weights are taken.
+ loss_fun : callable, optional
tensor-matrix multiplication function based on specific loss function
symmetric : bool, optional.
Either structures are to be assumed symmetric or not. Default value is True.
@@ -668,6 +709,9 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=F
Max number of iterations
tol : float, optional
Stop threshold on relative error (>0)
+ warmstartT: bool, optional
+ Either to perform warmstart of transport plans in the successive
+ fused gromov-wasserstein transport problems.s
verbose : bool, optional
Print information along iterations.
log : bool, optional
@@ -692,11 +736,21 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=F
"""
Cs = list_to_array(*Cs)
- ps = list_to_array(*ps)
- p = list_to_array(p)
- nx = get_backend(*Cs, *ps, p)
+ arr = [*Cs]
+ if ps is not None:
+ arr += list_to_array(*ps)
+ else:
+ ps = [unif(C.shape[0], type_as=C) for C in Cs]
+ if p is not None:
+ arr.append(list_to_array(p))
+ else:
+ p = unif(N, type_as=Cs[0])
+
+ nx = get_backend(*arr)
S = len(Cs)
+ if lambdas is None:
+ lambdas = [1. / S] * S
# Initialization of C : random SPD matrix (if not provided by user)
if init_C is None:
@@ -714,13 +768,19 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=F
cpt = 0
err = 1
+ if warmstartT:
+ T = [None] * S
error = []
while (err > tol and cpt < max_iter):
Cprev = C
- T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric, armijo=armijo,
- max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)]
+ if warmstartT:
+ T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric, armijo=armijo, G0=T[s],
+ max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)]
+ else:
+ T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric, armijo=armijo, G0=None,
+ max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
@@ -747,9 +807,11 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=F
return C
-def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
- p=None, loss_fun='square_loss', armijo=False, symmetric=True, max_iter=100, tol=1e-9,
- verbose=False, log=False, init_C=None, init_X=None, random_state=None, **kwargs):
+def fgw_barycenters(
+ N, Ys, Cs, ps=None, lambdas=None, alpha=0.5, fixed_structure=False,
+ fixed_features=False, p=None, loss_fun='square_loss', armijo=False,
+ symmetric=True, max_iter=100, tol=1e-9, warmstartT=False, verbose=False,
+ log=False, init_C=None, init_X=None, random_state=None, **kwargs):
r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] <references-fgw-barycenters>`
Parameters
@@ -760,16 +822,21 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Features of all samples
Cs : list of array-like, each element has shape (ns,ns)
Structure matrices of all samples
- ps : list of array-like, each element has shape (ns,)
+ ps : list of array-like, each element has shape (ns,), optional
Masses of all samples.
- lambdas : list of float
- List of the `S` spaces' weights
- alpha : float
- Alpha parameter for the fgw distance
+ If let to its default value None, uniform distributions are taken.
+ lambdas : list of float, optional
+ List of the `S` spaces' weights.
+ If let to its default value None, uniform weights are taken.
+ alpha : float, optional
+ Alpha parameter for the fgw distance.
fixed_structure : bool
Whether to fix the structure of the barycenter during the updates
fixed_features : bool
Whether to fix the feature of the barycenter during the updates
+ p : array-like, shape (N,), optional
+ Weights in the targeted barycenter.
+ If let to its default value None, uniform distribution is taken.
loss_fun : str
Loss function used for the solver either 'square_loss' or 'kl_loss'
symmetric : bool, optional
@@ -779,6 +846,9 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Max number of iterations
tol : float, optional
Stop threshold on relative error (>0)
+ warmstartT: bool, optional
+ Either to perform warmstart of transport plans in the successive
+ fused gromov-wasserstein transport problems.
verbose : bool, optional
Print information along iterations.
log : bool, optional
@@ -814,15 +884,24 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
International Conference on Machine Learning (ICML). 2019.
"""
Cs = list_to_array(*Cs)
- ps = list_to_array(*ps)
Ys = list_to_array(*Ys)
- p = list_to_array(p)
- nx = get_backend(*Cs, *Ys, *ps)
+ arr = [*Cs, *Ys]
+ if ps is not None:
+ arr += list_to_array(*ps)
+ else:
+ ps = [unif(C.shape[0], type_as=C) for C in Cs]
+ if p is not None:
+ arr.append(list_to_array(p))
+ else:
+ p = unif(N, type_as=Cs[0])
+
+ nx = get_backend(*arr)
S = len(Cs)
+ if lambdas is None:
+ lambdas = [1. / S] * S
+
d = Ys[0].shape[1] # dimension on the node features
- if p is None:
- p = nx.ones(N, type_as=Cs[0]) / N
if fixed_structure:
if init_C is None:
@@ -877,13 +956,21 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
if not fixed_structure:
+ T_temp = [t.T for t in T]
if loss_fun == 'square_loss':
- T_temp = [t.T for t in T]
- C = update_structure_matrix(p, lambdas, T_temp, Cs)
+ C = update_square_loss(p, lambdas, T_temp, Cs)
- T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
- max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
+ elif loss_fun == 'kl_loss':
+ C = update_kl_loss(p, lambdas, T_temp, Cs)
+ if warmstartT:
+ T = [fused_gromov_wasserstein(
+ Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
+ G0=T[s], max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
+ else:
+ T = [fused_gromov_wasserstein(
+ Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
+ G0=None, max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
# T is N,ns
err_feature = nx.norm(X - nx.reshape(Xprev, (N, d)))
err_structure = nx.norm(C - Cprev)
@@ -910,82 +997,3 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
return X, C, log_
else:
return X, C
-
-
-def update_structure_matrix(p, lambdas, T, Cs):
- r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings.
-
- It is calculated at each iteration
-
- Parameters
- ----------
- p : array-like, shape (N,)
- Masses in the targeted barycenter.
- lambdas : list of float
- List of the `S` spaces' weights.
- T : list of S array-like of shape (ns, N)
- The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
- Cs : list of S array-like, shape (ns, ns)
- Metric cost matrices.
-
- Returns
- -------
- C : array-like, shape (`nt`, `nt`)
- Updated :math:`\mathbf{C}` matrix.
- """
- p = list_to_array(p)
- T = list_to_array(*T)
- Cs = list_to_array(*Cs)
- nx = get_backend(*Cs, *T, p)
-
- tmpsum = sum([
- lambdas[s] * nx.dot(
- nx.dot(T[s].T, Cs[s]),
- T[s]
- ) for s in range(len(T))
- ])
- ppt = nx.outer(p, p)
- return tmpsum / ppt
-
-
-def update_feature_matrix(lambdas, Ys, Ts, p):
- r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings.
-
-
- See "Solving the barycenter problem with Block Coordinate Descent (BCD)"
- in :ref:`[24] <references-update-feature-matrix>` calculated at each iteration
-
- Parameters
- ----------
- p : array-like, shape (N,)
- masses in the targeted barycenter
- lambdas : list of float
- List of the `S` spaces' weights
- Ts : list of S array-like, shape (ns,N)
- The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
- Ys : list of S array-like, shape (d,ns)
- The features.
-
- Returns
- -------
- X : array-like, shape (`d`, `N`)
-
-
- .. _references-update-feature-matrix:
- References
- ----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary RĂ©mi, Tavenard Romain and Courty Nicolas
- "Optimal Transport for structured data with application on graphs"
- International Conference on Machine Learning (ICML). 2019.
- """
- p = list_to_array(p)
- Ts = list_to_array(*Ts)
- Ys = list_to_array(*Ys)
- nx = get_backend(*Ys, *Ts, p)
-
- p = 1. / p
- tmpsum = sum([
- lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :]
- for s in range(len(Ts))
- ])
- return tmpsum