summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorLaetitia Chapel <laetitia.chapel@univ-ubs.fr>2020-04-16 15:52:00 +0200
committerLaetitia Chapel <laetitia.chapel@univ-ubs.fr>2020-04-16 15:52:00 +0200
commitef7c11a5df3cf6c82864472f0cfa65d6b2036f2f (patch)
tree1b9cc1d6c281a1eca885212aa3857d48ded4d695 /ot
parent18b64556aaa477b5499dc05110c96d32b04147ff (diff)
partial with python 3.8
Diffstat (limited to 'ot')
-rwxr-xr-xot/partial.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/ot/partial.py b/ot/partial.py
index 8698d9d..726a590 100755
--- a/ot/partial.py
+++ b/ot/partial.py
@@ -232,7 +232,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
- M_extended = np.ones((len(a_extended), len(b_extended))) * 0
+ M_extended = np.zeros((len(a_extended), len(b_extended)))
M_extended[-1, -1] = np.max(M) * 1e5
M_extended[:len(a), :len(b)] = M
@@ -510,9 +510,9 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
Gprev = G0
M = gwgrad_partial(C1, C2, G0)
- M[M < eps] = np.quantile(M[M > eps], thres)
+ M[M < eps] = np.quantile(M, thres)
- M_emd = np.ones(dim_G_extended) * np.max(M) * 1e2
+ M_emd = np.zeros(dim_G_extended)
M_emd[:len(p), :len(q)] = M
M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5
M_emd = np.asarray(M_emd, dtype=np.float64)
@@ -729,8 +729,8 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
M = np.asarray(M, dtype=np.float64)
dim_a, dim_b = M.shape
- dx = np.ones(dim_a)
- dy = np.ones(dim_b)
+ dx = np.ones(dim_a, dtype=np.float64)
+ dy = np.ones(dim_b, dtype=np.float64)
if len(a) == 0:
a = np.ones(dim_a, dtype=np.float64) / dim_a
@@ -738,7 +738,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
b = np.ones(dim_b, dtype=np.float64) / dim_b
if m is None:
- m = np.min((np.sum(a), np.sum(b)))
+ m = np.min((np.sum(a), np.sum(b))) * 1.0
if m < 0:
raise ValueError("Problem infeasible. Parameter m should be greater"
" than 0.")