diff options
Diffstat (limited to 'ot/gromov/_gw.py')
-rw-r--r-- | ot/gromov/_gw.py | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index bc4719d..cdfa9a3 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -26,7 +26,7 @@ from ._utils import update_square_loss, update_kl_loss def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): r""" - Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + Returns the Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: @@ -39,6 +39,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} \mathbf{\gamma} &\geq 0 + Where : - :math:`\mathbf{C_1}`: Metric cost matrix in the source space @@ -68,7 +69,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. If let to its default None value, a symmetry test will be conducted. - Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). verbose : bool, optional Print information along iterations log : bool, optional @@ -170,7 +171,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): r""" - Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + Returns the Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: @@ -183,6 +184,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, lo \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} \mathbf{\gamma} &\geq 0 + Where : - :math:`\mathbf{C_1}`: Metric cost matrix in the source space @@ -216,7 +218,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, lo symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. If let to its default None value, a symmetry test will be conducted. - Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). verbose : bool, optional Print information along iterations log : bool, optional @@ -241,7 +243,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, lo gw_dist : float Gromov-Wasserstein distance log : dict - convergence information and Coupling marix + convergence information and Coupling matrix References ---------- @@ -310,6 +312,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric= which can lead to copy overhead on GPU arrays. .. note:: All computations in the conjugate gradient solver are done with numpy to limit memory overhead. + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein>` Parameters @@ -329,7 +332,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric= symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. If let to its default None value, a symmetry test will be conducted. - Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). alpha : float, optional Trade-off parameter (0 < alpha < 1) armijo : bool, optional @@ -503,7 +506,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric Returns ------- fgw-distance : float - Fused gromov wasserstein distance for the given parameters. + Fused Gromov-Wasserstein distance for the given parameters. log : dict Log dictionary return only if log==True in parameters. |