summaryrefslogtreecommitdiff
path: root/ot/partial.py
diff options
context:
space:
mode:
authorLaetitia Chapel <laetitia.chapel@univ-ubs.fr>2020-04-15 15:32:56 +0200
committerLaetitia Chapel <laetitia.chapel@univ-ubs.fr>2020-04-15 15:32:56 +0200
commit8c724ad3579959e9d369c0b7fbaa22ea19ced614 (patch)
tree528d280492f33f71eac2d6522e0e3c05a4ae8568 /ot/partial.py
parentfff2463aafd58343c8bc2ed7875622e16a8c1cee (diff)
partial with tests
Diffstat (limited to 'ot/partial.py')
-rwxr-xr-xot/partial.py23
1 files changed, 11 insertions, 12 deletions
diff --git a/ot/partial.py b/ot/partial.py
index 746f337..3425acb 100755
--- a/ot/partial.py
+++ b/ot/partial.py
@@ -9,12 +9,11 @@ Partial OT
import numpy as np
-from ot.lp import emd
+from .lp import emd
def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
**kwargs):
-
r"""
Solves the partial optimal transport problem for the quadratic cost
and returns the OT plan
@@ -136,7 +135,7 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
if log_emd['warning'] is not None:
raise ValueError("Error in the EMD resolution: try to increase the"
" number of dummy points")
- log_emd['cost'] = np.sum(gamma*M)
+ log_emd['cost'] = np.sum(gamma * M)
if log:
return gamma, log_emd
else:
@@ -233,7 +232,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
- M_extended = np.ones((len(a_extended), len(b_extended))) * np.max(M) * 1e2
+ M_extended = np.ones((len(a_extended), len(b_extended))) * 0
M_extended[-1, -1] = np.max(M) * 1e5
M_extended[:len(a), :len(b)] = M
@@ -381,7 +380,7 @@ def gwloss_partial(C1, C2, T):
Returns
-------
- GW loss
+ GW loss
"""
g = gwgrad_partial(C1, C2, T) * 0.5
return np.sum(g * T)
@@ -432,7 +431,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
G0 : ndarray, shape (ns, nt), optional
Initialisation of the transportation matrix
thres : float, optional
- quantile of the gradient matrix to populate the cost matrix when 0
+ quantile of the gradient matrix to populate the cost matrix when 0
(default: 1)
numItermax : int, optional
Max number of iterations
@@ -566,7 +565,7 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
where :
- M is the metric cost matrix
- - :math:`\Omega` is the entropic regularization term
+ - :math:`\Omega` is the entropic regularization term
:math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are the sample weights
- m is the amount of mass to be transported
@@ -591,7 +590,7 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
G0 : ndarray, shape (ns, nt), optional
Initialisation of the transportation matrix
thres : float, optional
- quantile of the gradient matrix to populate the cost matrix when 0
+ quantile of the gradient matrix to populate the cost matrix when 0
(default: 1)
numItermax : int, optional
Max number of iterations
@@ -666,7 +665,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
where :
- M is the metric cost matrix
- - :math:`\Omega` is the entropic regularization term
+ - :math:`\Omega` is the entropic regularization term
:math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are the sample weights
- m is the amount of mass to be transported
@@ -754,7 +753,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
K = np.empty(M.shape, dtype=M.dtype)
np.divide(M, -reg, out=K)
np.exp(K, out=K)
- np.multiply(K, m/np.sum(K), out=K)
+ np.multiply(K, m / np.sum(K), out=K)
err, cpt = 1, 0
@@ -809,7 +808,7 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
- C2 is the metric cost matrix in the target space
- p and q are the sample weights
- L : quadratic loss function
- - :math:`\Omega` is the entropic regularization term
+ - :math:`\Omega` is the entropic regularization term
:math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- m is the amount of mass to be transported
@@ -944,7 +943,7 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None,
- C2 is the metric cost matrix in the target space
- p and q are the sample weights
- L : quadratic loss function
- - :math:`\Omega` is the entropic regularization term
+ - :math:`\Omega` is the entropic regularization term
:math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- m is the amount of mass to be transported