summaryrefslogtreecommitdiff
path: root/ot/gromov/_utils.py
diff options
context:
space:
mode:
authorCédric Vincent-Cuaz <cedvincentcuaz@gmail.com>2023-06-12 12:01:48 +0200
committerGitHub <noreply@github.com>2023-06-12 12:01:48 +0200
commit9076f02903ba2fb9ea9fe704764a755cad8dcd63 (patch)
treeb7fda84880c5dabd1c441a1655741493e0683342 /ot/gromov/_utils.py
parentf0dab2f684f4fc768fd50e0b70918e075dcdd0f3 (diff)
[FEAT] Entropic gw/fgw/srgw/srfgw solvers (#455)upstream/latest
* add entropic fgw + fgw bary + srgw + srfgw with tests * add exemples for entropic srgw - srfgw solvers * add PPA solvers for GW/FGW + complete previous commits * update readme * add tests * add examples + tests + warning in entropic solvers + releases * reduce testing runtimes for test_gromov * fix conflicts * optional marginals * improve coverage * gromov doc harmonization * fix pep8 * complete optional marginal for entropic srfgw --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/gromov/_utils.py')
-rw-r--r--ot/gromov/_utils.py43
1 files changed, 43 insertions, 0 deletions
diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py
index ef8cd88..0b8bb00 100644
--- a/ot/gromov/_utils.py
+++ b/ot/gromov/_utils.py
@@ -324,6 +324,49 @@ def update_kl_loss(p, lambdas, T, Cs):
return nx.exp(tmpsum / ppt)
+def update_feature_matrix(lambdas, Ys, Ts, p):
+ r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings.
+
+
+ See "Solving the barycenter problem with Block Coordinate Descent (BCD)"
+ in :ref:`[24] <references-update-feature-matrix>` calculated at each iteration
+
+ Parameters
+ ----------
+ p : array-like, shape (N,)
+ masses in the targeted barycenter
+ lambdas : list of float
+ List of the `S` spaces' weights
+ Ts : list of S array-like, shape (ns,N)
+ The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
+ Ys : list of S array-like, shape (d,ns)
+ The features.
+
+ Returns
+ -------
+ X : array-like, shape (`d`, `N`)
+
+
+ .. _references-update-feature-matrix:
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+ p = list_to_array(p)
+ Ts = list_to_array(*Ts)
+ Ys = list_to_array(*Ys)
+ nx = get_backend(*Ys, *Ts, p)
+
+ p = 1. / p
+ tmpsum = sum([
+ lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :]
+ for s in range(len(Ts))
+ ])
+ return tmpsum
+
+
def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None):
r"""Return loss matrices and tensors for semi-relaxed Gromov-Wasserstein fast computation