diff options
author | Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com> | 2023-06-12 12:01:48 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-12 12:01:48 +0200 |
commit | 9076f02903ba2fb9ea9fe704764a755cad8dcd63 (patch) | |
tree | b7fda84880c5dabd1c441a1655741493e0683342 /ot/gromov/_utils.py | |
parent | f0dab2f684f4fc768fd50e0b70918e075dcdd0f3 (diff) |
[FEAT] Entropic gw/fgw/srgw/srfgw solvers (#455)upstream/latest
* add entropic fgw + fgw bary + srgw + srfgw with tests
* add exemples for entropic srgw - srfgw solvers
* add PPA solvers for GW/FGW + complete previous commits
* update readme
* add tests
* add examples + tests + warning in entropic solvers + releases
* reduce testing runtimes for test_gromov
* fix conflicts
* optional marginals
* improve coverage
* gromov doc harmonization
* fix pep8
* complete optional marginal for entropic srfgw
---------
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/gromov/_utils.py')
-rw-r--r-- | ot/gromov/_utils.py | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py index ef8cd88..0b8bb00 100644 --- a/ot/gromov/_utils.py +++ b/ot/gromov/_utils.py @@ -324,6 +324,49 @@ def update_kl_loss(p, lambdas, T, Cs): return nx.exp(tmpsum / ppt) +def update_feature_matrix(lambdas, Ys, Ts, p): + r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. + + + See "Solving the barycenter problem with Block Coordinate Descent (BCD)" + in :ref:`[24] <references-update-feature-matrix>` calculated at each iteration + + Parameters + ---------- + p : array-like, shape (N,) + masses in the targeted barycenter + lambdas : list of float + List of the `S` spaces' weights + Ts : list of S array-like, shape (ns,N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration + Ys : list of S array-like, shape (d,ns) + The features. + + Returns + ------- + X : array-like, shape (`d`, `N`) + + + .. _references-update-feature-matrix: + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + """ + p = list_to_array(p) + Ts = list_to_array(*Ts) + Ys = list_to_array(*Ys) + nx = get_backend(*Ys, *Ts, p) + + p = 1. / p + tmpsum = sum([ + lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :] + for s in range(len(Ts)) + ]) + return tmpsum + + def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None): r"""Return loss matrices and tensors for semi-relaxed Gromov-Wasserstein fast computation |