diff options
author | Laetitia Chapel <laetitia.chapel@univ-ubs.fr> | 2020-04-15 15:32:56 +0200 |
---|---|---|
committer | Laetitia Chapel <laetitia.chapel@univ-ubs.fr> | 2020-04-15 15:32:56 +0200 |
commit | 8c724ad3579959e9d369c0b7fbaa22ea19ced614 (patch) | |
tree | 528d280492f33f71eac2d6522e0e3c05a4ae8568 /ot/partial.py | |
parent | fff2463aafd58343c8bc2ed7875622e16a8c1cee (diff) |
partial with tests
Diffstat (limited to 'ot/partial.py')
-rwxr-xr-x | ot/partial.py | 23 |
1 files changed, 11 insertions, 12 deletions
diff --git a/ot/partial.py b/ot/partial.py index 746f337..3425acb 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -9,12 +9,11 @@ Partial OT import numpy as np -from ot.lp import emd +from .lp import emd def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, **kwargs): - r""" Solves the partial optimal transport problem for the quadratic cost and returns the OT plan @@ -136,7 +135,7 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, if log_emd['warning'] is not None: raise ValueError("Error in the EMD resolution: try to increase the" " number of dummy points") - log_emd['cost'] = np.sum(gamma*M) + log_emd['cost'] = np.sum(gamma * M) if log: return gamma, log_emd else: @@ -233,7 +232,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies) a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies) - M_extended = np.ones((len(a_extended), len(b_extended))) * np.max(M) * 1e2 + M_extended = np.ones((len(a_extended), len(b_extended))) * 0 M_extended[-1, -1] = np.max(M) * 1e5 M_extended[:len(a), :len(b)] = M @@ -381,7 +380,7 @@ def gwloss_partial(C1, C2, T): Returns ------- - GW loss + GW loss """ g = gwgrad_partial(C1, C2, T) * 0.5 return np.sum(g * T) @@ -432,7 +431,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, G0 : ndarray, shape (ns, nt), optional Initialisation of the transportation matrix thres : float, optional - quantile of the gradient matrix to populate the cost matrix when 0 + quantile of the gradient matrix to populate the cost matrix when 0 (default: 1) numItermax : int, optional Max number of iterations @@ -566,7 +565,7 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, where : - M is the metric cost matrix - - :math:`\Omega` is the entropic regularization term + - :math:`\Omega` is the entropic regularization term :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are the sample weights - m is the amount of mass to be transported @@ -591,7 +590,7 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, G0 : ndarray, shape (ns, nt), optional Initialisation of the transportation matrix thres : float, optional - quantile of the gradient matrix to populate the cost matrix when 0 + quantile of the gradient matrix to populate the cost matrix when 0 (default: 1) numItermax : int, optional Max number of iterations @@ -666,7 +665,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, where : - M is the metric cost matrix - - :math:`\Omega` is the entropic regularization term + - :math:`\Omega` is the entropic regularization term :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are the sample weights - m is the amount of mass to be transported @@ -754,7 +753,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, K = np.empty(M.shape, dtype=M.dtype) np.divide(M, -reg, out=K) np.exp(K, out=K) - np.multiply(K, m/np.sum(K), out=K) + np.multiply(K, m / np.sum(K), out=K) err, cpt = 1, 0 @@ -809,7 +808,7 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, - C2 is the metric cost matrix in the target space - p and q are the sample weights - L : quadratic loss function - - :math:`\Omega` is the entropic regularization term + - :math:`\Omega` is the entropic regularization term :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - m is the amount of mass to be transported @@ -944,7 +943,7 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, - C2 is the metric cost matrix in the target space - p and q are the sample weights - L : quadratic loss function - - :math:`\Omega` is the entropic regularization term + - :math:`\Omega` is the entropic regularization term :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - m is the amount of mass to be transported |